Source code for amaze.extensions.sb3.networks

"""Half-hearted attempt a making a custom CNN.

Improvements are being worked on.
"""

from gymnasium import spaces
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from torch import nn, no_grad, as_tensor, Tensor


[docs] class CustomCNN(BaseFeaturesExtractor): """Bare-bones attempt at using a custom CNN. :param observation_space: (gym.Space) :param features_dim: (int) Number of features extracted. This corresponds to the number of unit for the last layer. """ def __init__(self, observation_space: spaces.Box, features_dim: int = 256): super().__init__(observation_space, features_dim) # We assume CxHxW images (channels first) # Re-ordering will be done by pre-preprocessing or wrapper n_input_channels = observation_space.shape[0] print("=" * 80) print("== CustomCNN") print("=" * 80) print(n_input_channels) print(observation_space) self.cnn = nn.Sequential( nn.Conv2d(n_input_channels, 8, kernel_size=5, stride=3, padding=0), nn.ReLU(), nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=0), nn.ReLU(), nn.Flatten(), ) # Compute shape by doing one forward pass with no_grad(): n_flatten = self.cnn( as_tensor(observation_space.sample()[None]).float() ).shape[1] self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU()) exit(42)
[docs] def forward(self, observations: Tensor) -> Tensor: """Performs one computational step""" return self.linear(self.cnn(observations))