Source code for amaze.extensions.sb3.maze_env

"""SB3 wrapper for the maze environment"""

import logging
from typing import Optional, List

import numpy as np
from PyQt5 import QtWidgets
from PyQt5.QtCore import Qt
from PyQt5.QtGui import QImage, QPainter
from gymnasium import spaces, Env
from stable_baselines3.common.env_checker import check_env as _check_env
from stable_baselines3.common.env_util import make_vec_env

from ...misc.utils import qt_application
from .guard import CV2QTGuard
from .utils import IOMapper
from ...misc.resources import qimage_to_numpy
from ...simu.maze import Maze
from ...simu.robot import Robot
from ...simu.simulation import Simulation
from ...simu.types import InputType, OutputType
from ...visu.widgets.maze import MazeWidget

logger = logging.getLogger(__name__)


def make_vec_maze_env(
    mazes: List[Maze.BuildData],
    robot: Robot.BuildData,
    seed,
    check_env=True,
    **kwargs,
):
    """Create a vectorized environment holding one :class:`MazeEnv` per maze

    :param mazes: build data of the mazes, one environment is created for each
    :param robot: agent build data, shared by every environment
    :param seed: forwarded as-is to ``make_vec_env``
    :param check_env: whether to run the SB3 environment checker on every
        environment right after its creation
    :param kwargs: additional keyword arguments forwarded to :class:`MazeEnv`
    :return: whatever ``make_vec_env`` returns for the generated environments
    """
    # Work on a private copy: the factory below consumes its entries and the
    # caller's sequence must be left untouched
    pending = list(mazes)
    n_envs = len(pending)

    def env_fn():
        # Every invocation takes the next maze that has not been used yet
        maze_env = MazeEnv(pending.pop(0), robot, **kwargs)
        if check_env:
            _check_env(maze_env)
        # Brings the environment back to a pristine state (reset counter at
        # zero) whether or not the checker stepped through it
        maze_env.reset(full_reset=True)
        return maze_env

    return make_vec_env(env_fn, n_envs=n_envs, seed=seed)
def env_method(env, method: str, *args, **kwargs):
    """Invoke ``method`` on the unwrapped version of every sub-environment

    :param env: object exposing its sub-environments through ``envs``
    :param method: name of the method to look up on each unwrapped environment
    :param args: positional arguments handed to every call
    :param kwargs: keyword arguments handed to every call
    :return: the list of results, in the order of ``env.envs``
    """
    results = []
    for wrapped in env.envs:
        bound = getattr(wrapped.unwrapped, method)
        results.append(bound(*args, **kwargs))
    return results
def env_attr(env, attr: str):
    """Collect ``attr`` from the unwrapped version of every sub-environment

    :param env: object exposing its sub-environments through ``envs``
    :param attr: name of the attribute to read on each unwrapped environment
    :return: the list of attribute values, in the order of ``env.envs``
    """
    values = []
    for wrapped in env.envs:
        values.append(getattr(wrapped.unwrapped, attr))
    return values
class MazeEnv(Env):
    """AMaze wrapper for the stable baselines 3 library

    Drives a :class:`Simulation` through the gymnasium ``Env`` interface
    (``reset``/``step``/``render``/``close``). Observations and actions are
    converted by an :class:`IOMapper` built from the spaces declared here.
    """

    metadata = dict(
        render_modes=["human", "rgb_array"], render_fps=30, min_resolution=256
    )

    def __init__(
        self,
        maze: Maze.BuildData,
        robot: Robot.BuildData,
        log_trajectory: bool = False,
    ):
        """Built with maze data and robot data

        :param ~amaze.simu.maze.Maze.BuildData maze: maze data
        :param ~amaze.simu.robot.Robot.BuildData robot: agent data
        :param bool log_trajectory: forwarded to the simulation as
            ``save_trajectory``
        """
        super().__init__()

        self.render_mode = "rgb_array"

        # Kept in a private attribute and exposed through the ``name``
        # property: a public instance attribute would shadow any ``name``
        # member declared on the class
        self._name = maze.to_string()

        self._simulation = Simulation(
            Maze.generate(maze), robot, save_trajectory=log_trajectory
        )

        _pretty_rewards = ", ".join(
            f"{k}: {v:.2g}"
            for k, v in self._simulation.rewards.__dict__.items()
        )
        logger.debug(
            f"Creating MazeEnv with"
            f"\n {self.name}"
            f"\n maze: {maze}"
            f"\nrobot: {robot}"
            f"\nrewards: [{_pretty_rewards}]"
            f"\n{log_trajectory=}"
        )

        self.observation_type = robot.inputs
        if robot.inputs is InputType.DISCRETE:
            # Discrete inputs: 8 floats in [-1, 1]
            self.observation_space = spaces.Box(
                low=-1, high=1, shape=(8,), dtype=np.float32
            )
        else:
            # Any other input type: single-channel square image whose side is
            # the robot's vision size
            self.observation_space = spaces.Box(
                low=0,
                high=255,
                shape=(1, robot.vision, robot.vision),
                dtype=np.uint8,
            )

        self.action_type = robot.outputs
        self.action_space = {
            OutputType.DISCRETE: spaces.Discrete(4),
            OutputType.CONTINUOUS: spaces.Box(
                low=-1, high=1, shape=(2,), dtype=np.float32
            ),
        }[robot.outputs]

        self.mapper = IOMapper(
            observation_space=self.observation_space,
            action_space=self.action_space,
        )

        # Qt objects are only created on the first call to render()
        self.widget, self.app = None, None

        # State captured by reset() right before the simulation is reset
        self.prev_trajectory = None
        self.last_infos = None

        self.length = len(self._simulation.maze.solution)

        self.resets = 0

    def reset(self, seed=None, options=None, full_reset=False):
        """Reset the simulation and return the initial observation

        Before resetting, the current infos are stored in ``last_infos`` and,
        if the simulation holds a trajectory, a copy of it is stored in
        ``prev_trajectory``.

        :param seed: forwarded to ``Env.reset``
        :param options: accepted for API compatibility, not used
        :param bool full_reset: set the reset counter back to zero instead of
            incrementing it
        :return: the tuple ``(observation, infos)``
        """
        self.last_infos = self.infos()
        if self._simulation.trajectory is not None:
            self.prev_trajectory = self._simulation.trajectory.copy(True)

        super().reset(seed=seed)
        self._simulation.reset()

        maze_str = self._simulation.maze.to_string()
        if full_reset:
            self.resets = 0
            logger.debug(f"Initial reset for {maze_str}")
        else:
            self.resets += 1
            logger.debug(f"Reset {self.resets} for {maze_str}")

        return self._observations(), self.infos()

    def step(self, action):
        """Advance the simulation by one step

        :param action: action expressed in ``action_space``; converted by the
            mapper before being handed to the simulation
        :return: the tuple ``(observation, reward, terminated, truncated,
            info)`` where ``terminated`` is the simulation's success flag and
            ``truncated`` its failure flag
        """
        vec_action = self.mapper.map_action(action)

        reward = self._simulation.step(vec_action)
        observation = self._observations()
        terminated = self._simulation.success()
        truncated = self._simulation.failure()
        info = self.infos()

        return observation, reward, terminated, truncated, info

    def render(self) -> Optional[np.ndarray]:
        """Draw the current state of the maze in an off-screen image

        The maze widget is created on first use. The image is a 256x256
        ``QImage`` in RGB888 format, filled with white before the widget is
        rendered onto it.

        :return: the image converted by ``qimage_to_numpy``
        """
        with CV2QTGuard():  # Using Qt in CV2 context -> Protect
            s = 256

            if self.widget is None:
                self.widget = self._create_widget(show_robot=True)

            img = QImage(s, s, QImage.Format_RGB888)
            img.fill(Qt.white)

            painter = QPainter(img)
            self.widget.render_onto(painter, width=s)
            painter.end()

            return qimage_to_numpy(img)

    def close(self):
        """Does nothing: no resource is released here"""

    @property
    def name(self) -> str:
        """Name of this environment (string form of the maze build data)"""
        return self._name

    @name.setter
    def name(self, value: str):
        # Assignment is kept possible, as it was when ``name`` was a plain
        # instance attribute
        self._name = value

    def atomic_rewards(self):
        """Return the rewards object of the underlying simulation"""
        return self._simulation.rewards

    def optimal_reward(self):
        """Return the cumulative reward for an agent following an optimal
        trajectory"""
        return self._simulation.optimal_reward

    def maximal_duration(self):
        """Return the deadline of the underlying simulation"""
        return self._simulation.deadline

    def io_types(self):
        """Return the ``(inputs, outputs)`` types of the simulated agent"""
        return (self._simulation.data.inputs, self._simulation.data.outputs)

    def log_trajectory(self, do_log: bool):
        """Reset the simulation with trajectory saving set to ``do_log``

        :param bool do_log: forwarded to the simulation as ``save_trajectory``
        """
        self._simulation.reset(save_trajectory=do_log)

    def plot_trajectory(
        self, cb_side: int = 0, verbose: bool = True, square: bool = False
    ) -> np.ndarray:
        """Plot the trajectory saved by the latest call to :meth:`reset`

        :param int cb_side: forwarded to ``MazeWidget.plot_trajectory`` as
            ``side``
        :param bool verbose: forwarded to ``MazeWidget.plot_trajectory``
        :param bool square: forwarded to ``MazeWidget.plot_trajectory``
        :return: the plot (256 pixels requested, RGBA8888 format) converted by
            ``qimage_to_numpy``
        """
        with CV2QTGuard():
            # Called for its side effect only, presumably to make sure a Qt
            # application exists before plotting
            _ = qt_application()

            plot = MazeWidget.plot_trajectory(
                simulation=self._simulation,
                size=256,
                trajectory=self.prev_trajectory,
                config=dict(solution=True, robot=False, dark=True),
                side=cb_side,
                verbose=verbose,
                square=square,
                img_format=QImage.Format_RGBA8888,
                path=None,
            )
            img = qimage_to_numpy(plot)

        return img

    def _create_widget(self, show_robot=False):
        """Return the maze widget, creating it (and a Qt application if none
        exists) on the first call

        :param bool show_robot: whether the created widget displays the robot;
            ignored when the widget already exists
        """
        if self.widget is not None:
            return self.widget

        if QtWidgets.QApplication.instance() is None:
            # Stored on the instance, presumably so that the application is
            # not garbage collected while the widget is alive
            self.app = QtWidgets.QApplication([])

        self.widget = MazeWidget.from_simulation(simulation=self._simulation)
        self.widget.update_config(robot=show_robot, solution=True, dark=True)
        return self.widget

    def _observations(self):
        """Return the simulation's observations converted by the mapper"""
        return self.mapper.map_observation(self._simulation.observations)

    def infos(self):
        """Return the infos reported by the underlying simulation"""
        return self._simulation.infos()