Source code for amaze.extensions.sb3.utils
"""Various utility functions"""
import numpy as np
from gymnasium import Space
from gymnasium.spaces import Discrete
from ...simu.pos import Vec
from ...simu.simulation import Simulation
[docs]
class IOMapper:
    """Adapter between AMaze's input/output types and SB3 objects.

    Two callables are installed on the instance at construction time:

    * ``map_observation(obs)`` -- converts an observation for the
      given observation space.
    * ``map_action(a)`` -- converts an action into a ``Vec``.

    ``action_mapping`` is additionally set (from
    ``Simulation.discrete_actions()``) when the action space is a
    ``Discrete`` one; it is not defined otherwise.
    """

    def __init__(self, observation_space: Space, action_space: Space):
        """Select the observation/action converters for the given spaces.

        :param observation_space: space whose ``shape`` decides how
            observations are converted
        :param action_space: space whose type decides how actions are
            converted
        """
        self.o_space = observation_space

        def _passthrough(obs):
            # One-dimensional observation spaces: hand back unchanged.
            return obs

        def _as_uint8(obs):
            # Scale by 255, cast to uint8 and reshape to the space's shape
            # (looked up on every call, not cached).
            # NOTE(review): presumably obs holds values in [0, 1] -- confirm.
            scaled = obs * 255
            return scaled.astype(np.uint8).reshape(self.o_space.shape)

        n_dims = len(self.o_space.shape)
        self.map_observation = _passthrough if n_dims == 1 else _as_uint8

        self.a_space = action_space

        if not isinstance(self.a_space, Discrete):
            # Non-discrete action: its components are unpacked into a Vec.
            def _unpack(a):
                return Vec(*a)

            self.map_action = _unpack
            return

        # Discrete action: it indexes the table of discrete actions.
        self.action_mapping = Simulation.discrete_actions()

        def _lookup(a):
            return Vec(*self.action_mapping[a])

        self.map_action = _lookup