# Source code for amaze.simu.controllers.control

import json
import logging
from pathlib import Path
from typing import Union, Type, Optional
from zipfile import ZipFile

from . import (
    BaseController,
    CheaterController,
    KeyboardController,
    RandomController,
    TabularController,
)
from .base import Robot

# Logger for this module (named after the module itself).
logger = logging.getLogger(__name__)

# Registry of controller classes keyed by each class's ``short_name``.
# The factory, save and load functions below all resolve controller types
# through this mapping; insertion order is the order listed here.
CONTROLLERS: dict[str, Type[BaseController]] = {
    controller_type.short_name: controller_type
    for controller_type in (
        RandomController,
        CheaterController,
        KeyboardController,
        TabularController,
    )
}


[docs] def builtin_controllers(): """Provides the list of controllers shipped with this library""" return list(CONTROLLERS.keys())
def controller_factory(c_type: str, c_data: dict):
    """Create a controller of the given type from keyword data.

    :param c_type: short name of a registered controller class; it is
        lower-cased before being looked up in ``CONTROLLERS``
    :param c_data: keyword arguments forwarded to the class constructor.
        The caller's dict is never modified.
    :return: the newly constructed controller
    :raises KeyError: if ``c_type`` does not name a registered controller
    """
    c_class = CONTROLLERS[c_type.lower()]
    # Work on a shallow copy: dropping "simulation" below must not mutate
    # the caller's dict (which may be reused for further calls).
    kwargs = dict(c_data)
    # Only classes with a truthy ``cheats`` attribute receive the
    # "simulation" entry; it is dropped for all others (presumably their
    # constructors do not accept it).
    if not getattr(c_class, "cheats", False):
        kwargs.pop("simulation", None)
    return c_class(**kwargs)
def save(
    controller: BaseController,
    path: Union[Path, str],
    infos: Optional[dict] = None,
    *args,
    **kwargs,
) -> Path:
    """Serialize ``controller`` into a zip archive and return where it went.

    The archive holds a ``controller_class`` entry (the registered short
    name), a ``robot`` entry (``controller.robot_data.to_string()``), an
    optional ``infos`` entry (JSON) and whatever the controller's own
    :meth:`~.BaseController._save_to_archive` writes. An existing file at the
    target path is overwritten (the archive is opened in ``"w"`` mode).

    :param controller: instance of a class registered in ``CONTROLLERS``
    :param path: destination; a ``str`` is converted to :class:`~pathlib.Path`.
        Unless it already ends in ``.zip`` its suffix is *replaced* by
        ``.zip`` (``run.v2`` becomes ``run.zip``; a suffix-less path simply
        gains ``.zip``).
    :param infos: extra metadata for later reference (e.g. type of mazes,
        performance), merged over a copy of ``controller.infos``; the
        controller's own dict is left untouched. The merged result must be
        JSON-serializable and is only stored when non-empty.
    :param args: forwarded to ``_save_to_archive``
    :param kwargs: forwarded to ``_save_to_archive``
    :return: the path actually written (always with a ``.zip`` suffix)
    :raises ValueError: if ``type(controller)`` is not registered
    """
    # Invert the registry (class -> short name) to find the name to store.
    name_of = dict(zip(CONTROLLERS.values(), CONTROLLERS.keys()))
    controller_class = name_of.get(type(controller))
    if controller_class is None:
        raise ValueError(
            f"Controller class {type(controller)} is not registered"
        )

    path = Path(path) if isinstance(path, str) else path
    path = path if path.suffix == ".zip" else path.with_suffix(".zip")

    with ZipFile(path, "w") as archive:
        archive.writestr("controller_class", controller_class)
        archive.writestr("robot", controller.robot_data.to_string())

        merged_infos = controller.infos.copy()
        if infos is not None:
            merged_infos.update(infos)
        if merged_infos:
            archive.writestr("infos", json.dumps(merged_infos).encode("utf-8"))

        # noinspection PyProtectedMember
        controller._save_to_archive(archive, *args, **kwargs)

    logger.debug(f"Saved controller to {path}")
    return path
def load(path: Union[Path, str], *args, **kwargs):
    """Rebuild a controller from an archive written by :func:`save`.

    Any controller type currently present in ``CONTROLLERS`` can be loaded.
    When using extensions, import all of those used during training before
    calling this; otherwise their controller names are not found
    (presumably because importing an extension is what registers them).

    :param path: archive to read (typically the value returned by :func:`save`)
    :param args: forwarded to the class's ``_load_from_archive``
    :param kwargs: forwarded as well, alongside ``robot=`` (the robot build
        data decoded from the archive's ``robot`` entry)
    :return: the loaded controller; if the archive contains an ``infos``
        entry, its decoded JSON replaces ``controller.infos``
    :raises ValueError: if the stored class name is not registered; for a
        dotted name the message suggests the missing extension
    :raises KeyError: if the ``controller_class`` or ``robot`` entry is
        absent (raised by :meth:`zipfile.ZipFile.read`)
    """
    logger.debug(f"Loading controller from {path}")
    with ZipFile(path, "r") as archive:
        class_name = archive.read("controller_class").decode("utf-8")
        c_type = CONTROLLERS.get(class_name)
        if c_type is None:
            message = f"Unsupported controller type {class_name}."
            # For a dotted name, the part before the first dot names the
            # extension that probably provides this controller.
            extension, dot, _ = class_name.partition(".")
            if dot:
                message += (
                    f" Did you forget to include the '{extension}' extension?"
                )
            raise ValueError(message)
        logger.debug(f"> controller class: {class_name}")

        robot_text = archive.read("robot").decode("utf-8")
        robot = Robot.BuildData.from_string(robot_text)
        logger.debug(f"> Robot build data: {robot}")

        # noinspection PyProtectedMember
        controller = c_type._load_from_archive(
            archive, *args, robot=robot, **kwargs
        )
        # Stored metadata, when present, overrides whatever the loader set.
        if "infos" in archive.namelist():
            controller.infos = json.loads(
                archive.read("infos").decode("utf-8")
            )
    return controller