# Source code for amaze.extensions.sb3.controller

"""Implements a wrapper around common models from stable baselines 3"""

import io
from typing import Optional, Dict, Type, Union, Tuple
from zipfile import ZipFile

import numpy as np
import torch
from gymnasium import Space
from gymnasium.spaces import Discrete, Box
from stable_baselines3 import SAC, A2C, DQN, PPO, TD3
from stable_baselines3.common.base_class import BaseAlgorithm

from .utils import IOMapper
from ...simu.controllers.base import BaseController
from ...simu.controllers.control import save
from ...simu.pos import Vec
from ...simu.robot import Robot
from ...simu.types import InputType, OutputType, State

# Supported stable-baselines3 algorithms, indexed by their class name.
_classes = dict((algo.__name__, algo) for algo in (SAC, A2C, DQN, PPO, TD3))

# Rank of the observation space (len(shape)) -> robot input type.
# Only ranks 1 and 3 are recognized; any other rank is absent from the table.
_i_types_mapping: Dict[int, InputType] = dict(
    [
        (1, InputType.DISCRETE),
        (3, InputType.CONTINUOUS),
    ]
)

# Class of the action space -> robot output type.
_o_types_mapping: Dict[Type[Space], OutputType] = dict(
    [
        (Discrete, OutputType.DISCRETE),
        (Box, OutputType.CONTINUOUS),
    ]
)


def wrapped_sb3_model(model_type: Type[BaseAlgorithm]):
    """Creates a class wrapping a specific stable baselines 3 model.

    The generated class inherits from both ``model_type`` and
    ``BaseController``. A model of that class can therefore be trained
    with the usual SB3 API, then called on a robot ``State`` to obtain a
    ``Vec`` action. It can also be saved to / loaded from a controller
    archive.

    Internal use only.

    :param model_type: the SB3 algorithm class to wrap (e.g. PPO, DQN)
    :return: a new controller class specialized for ``model_type``
    """

    class SB3Controller(model_type, BaseController):
        simple = False
        # Wrapped algorithm class, used by __repr__, save() and the
        # archive (de)serialization helpers below.
        _model_type = model_type

        def __init__(self, robot_data: Robot.BuildData, *args, **kwargs):
            """Initializes both parents.

            :param robot_data: IO specification this controller is built for
            :param args: forwarded to the SB3 algorithm constructor
            :param kwargs: forwarded to the SB3 algorithm constructor
            """
            # BaseController is initialized first, presumably so that
            # self.robot_data (read by _setup_model) exists before the SB3
            # constructor sets the model up.
            # noinspection PyTypeChecker
            BaseController.__init__(self, robot_data=robot_data)
            model_type.__init__(self, *args, **kwargs)

        def _setup_model(self) -> None:
            """Sets up the SB3 model, then validates and records the IO mapping.

            The input type, output type and vision size are deduced from
            the observation/action spaces. They are compared against the
            robot data given at construction.

            :raises ValueError: if the deduced and declared specs differ
            """
            super()._setup_model()

            # Observation rank selects the input type (1 -> discrete,
            # 3 -> continuous); other ranks raise KeyError here.
            input_type = _i_types_mapping[len(self.observation_space.shape)]
            output_type = _o_types_mapping[self.action_space.__class__]
            # Only continuous observations carry a vision size, read from
            # the second axis of the observation shape.
            vision = (
                None
                if input_type is InputType.DISCRETE
                else self.observation_space.shape[1]
            )
            deduced_robot_data = Robot.BuildData(input_type, output_type, vision)
            if self.robot_data != deduced_robot_data:
                raise ValueError(
                    "Incompatible IO specifications:\n"
                    f"- Model created with {self.robot_data}\n"
                    f"- Deduced from environment"
                    f" {deduced_robot_data}"
                )
            self._mapper = IOMapper(
                observation_space=self.observation_space,
                action_space=self.action_space,
            )

        @classmethod
        def __repr__(cls) -> str:
            # Declared as a classmethod: every instance of a given wrapper
            # shares the same class-level description.
            return f"SB3.Controller[{cls._model_type.__name__}]"

        def __call__(self, inputs: State) -> Vec:
            """Returns the deterministic action chosen for ``inputs``.

            The state is converted to an observation, fed to the policy,
            and the predicted action is converted back to a ``Vec``.
            """
            observation = self._mapper.map_observation(inputs)
            action, _ = self.policy.predict(observation, deterministic=True)
            return self._mapper.map_action(action)

        def predict(
            self,
            observation: Union[np.ndarray, Dict[str, np.ndarray]],
            state: Optional[Tuple[np.ndarray, ...]] = None,
            episode_start: Optional[np.ndarray] = None,
            deterministic: bool = False,
        ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
            """Forwards to the SB3 algorithm's ``predict`` (via the MRO)."""
            return super().predict(observation, state, episode_start, deterministic)

        def value(self, inputs: State) -> float:
            """Evaluates every discrete action for the given state.

            Returns the ``log_prob`` output of
            ``policy.evaluate_actions`` for actions ``0 .. n-1``.

            NOTE(review): this is presumably a tensor with one entry per
            action, not a float as annotated. It also requires a policy
            exposing ``evaluate_actions`` (presumably actor-critic ones).
            Confirm with callers.

            :raises NotImplementedError: for non-Discrete action spaces
            """
            if not isinstance(self._mapper.a_space, Discrete):
                raise NotImplementedError(
                    "value() is only implemented for Discrete action spaces"
                )
            actions = range(self._mapper.a_space.n)
            obs, _ = self.policy.obs_to_tensor(self._mapper.map_observation(inputs))
            _, log_prob, _ = self.policy.evaluate_actions(obs, torch.Tensor(actions))
            return log_prob

        def reset(self):
            """No per-episode state to clear for this controller."""

        @staticmethod
        def inputs_types():
            """All input types are accepted by this wrapper."""
            return list(InputType)

        @staticmethod
        def outputs_types():
            """All output types are accepted by this wrapper."""
            return list(OutputType)

        def save(self, path: str, *_args, **_kwargs) -> None:
            """Saves through the controller archive, tagging the algorithm name."""
            # Name resolves to the module-level `save` helper, not this method.
            save(
                self,
                path,
                *_args,
                infos=dict(algo=self._model_type.__name__),
                **_kwargs,
            )

        def _save_to_archive(self, archive: ZipFile, *_args, **_kwargs) -> bool:
            """Delegates savings of the internals to the SB3 model"""
            buffer = io.BytesIO()
            self._model_type.save(self, buffer, *_args, **_kwargs)
            archive.writestr("sb3.zip", buffer.getvalue())
            return True

        @classmethod
        def _load_from_archive(
            cls, archive: ZipFile, robot: Robot.BuildData, *_args, **_kwargs
        ):
            """Loads the SB3 specific contents from the archive.

            :param archive: controller archive containing ``sb3.zip``
            :param robot: IO specification of the restored controller
            :return: a restored instance of this wrapper class
            """
            buffer = io.BytesIO(archive.read("sb3.zip"))
            # load() is called on the wrapped SB3 class, so it yields an
            # instance of that class, not of this wrapper.
            loaded_model = cls._model_type.load(buffer, *_args, **_kwargs)
            model = cls(
                robot_data=robot,
                policy=loaded_model.policy,
                env=loaded_model.env,
                device=loaded_model.device,
                _init_setup_model=False,
            )
            model.__dict__.update(loaded_model.__dict__)
            # _setup_model() is skipped above (_init_setup_model=False), so
            # the IO mapper is only present if it was restored from the
            # archive. Rebuild it from the restored spaces otherwise, since
            # __call__/value need it.
            if getattr(model, "_mapper", None) is None:
                model._mapper = IOMapper(
                    observation_space=model.observation_space,
                    action_space=model.action_space,
                )
            return model

    return SB3Controller