Stable baselines 3

Training

In this example, we showcase how the built-in Stable Baselines3 (SB3) extension can be used to seamlessly leverage the large collection of algorithms and policies provided by that library.

import math
import pathlib
import random
import shutil

from stable_baselines3.common.callbacks import (
    EvalCallback,
    StopTrainingOnRewardThreshold,
)
from stable_baselines3.common.logger import configure

from amaze import Maze, Robot, Simulation, Sign, amaze_main, StartLocation
from amaze.extensions.sb3 import (
    make_vec_maze_env,
    env_method,
    load_sb3_controller,
    PPO,
    TensorboardCallback,
    sb3_controller,
    CV2QTGuard,
)

As usual, we start by importing the necessary packages and defining some global configuration options. Note that, in addition to the core amaze classes, we also import extension-specific items (detailed below).

    train_mazes = Maze.BuildData.from_string(train_maze).all_rotations()
    eval_mazes = [d.where(seed=test_seed) for d in train_mazes]

    train_env = make_vec_maze_env(train_mazes, robot, SEED)
    eval_env = make_vec_maze_env(eval_mazes, robot, SEED, log_trajectory=True)

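The training function is much shorter than in the hand-written Q-learning case, thanks to the functionality provided by Stable Baselines3 and the extension's wrappers. While creating mazes and robots should be familiar by now, we see a new extension-specific function, make_vec_maze_env(), used to create Vectorized Environments (VecEnv). The training environment uses every rotation of the training maze, while the evaluation environment uses the same mazes with a different seed (test_seed) and additionally logs trajectories.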
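To make the VecEnv idea concrete, the following dependency-free sketch shows what vectorization means: several environments are stepped in lockstep with one action each, and an environment whose episode ends is reset immediately, its last observation being kept in the info dictionary (SB3's DummyVecEnv stores it under "terminal_observation"). ToyCorridorEnv and ToyVecEnv are hypothetical illustrations, not the amaze or SB3 classes; how make_vec_maze_env() builds its environments internally is not shown here.

```python
from typing import Any, Dict, List, Tuple


class ToyCorridorEnv:
    """Hypothetical 1-D corridor standing in for a single maze environment."""

    def __init__(self, length: int):
        self.length = length
        self.pos = 0

    def reset(self) -> int:
        self.pos = 0
        return self.pos

    def step(self, action: int) -> Tuple[int, float, bool]:
        # Action 1 moves one cell forward; any other action stays in place.
        if action == 1:
            self.pos += 1
        done = self.pos >= self.length
        reward = 1.0 if done else -0.1
        return self.pos, reward, done


class ToyVecEnv:
    """Steps several environments in lockstep and auto-resets finished ones."""

    def __init__(self, envs: List[ToyCorridorEnv]):
        self.envs = envs
        self.num_envs = len(envs)

    def reset(self) -> List[int]:
        return [env.reset() for env in self.envs]

    def step(self, actions: List[int]):
        obs, rewards, dones, infos = [], [], [], []
        for env, action in zip(self.envs, actions):
            o, r, d = env.step(action)
            info: Dict[str, Any] = {}
            if d:
                # Keep the final observation, then start a new episode right away,
                # so the batch never contains a "finished" environment.
                info["terminal_observation"] = o
                o = env.reset()
            obs.append(o)
            rewards.append(r)
            dones.append(d)
            infos.append(info)
        return obs, rewards, dones, infos


vec_env = ToyVecEnv([ToyCorridorEnv(n) for n in (1, 2, 3)])
vec_env.reset()
obs, rewards, dones, infos = vec_env.step([1, 1, 1])
print(obs, dones)  # [0, 1, 1] [True, False, False]
```

The first corridor finishes in one step and is reset to 0, while the other two simply advance; a learner therefore always receives a full batch of observations.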

    optimal_reward = sum(env_method(eval_env, "optimal_reward")) / len(eval_mazes)

We also sometimes need access to the underlying environments (regular mazes), as illustrated above: we compute the average optimal reward by calling optimal_reward() on every maze used for intermediate performance evaluation, through env_method().
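In SB3, a VecEnv exposes env_method(method_name, *method_args, indices=None, **method_kwargs), which calls the named method on each selected sub-environment and returns the results as a list. The sketch below mimics that broadcast with hypothetical ToyMaze and ToyVecEnv classes, then averages the results as the line above does. Whether the amaze env_method() helper forwards directly to this SB3 method or performs extra unwrapping is an assumption not confirmed here.

```python
from typing import Any, List, Optional, Sequence


class ToyMaze:
    """Hypothetical maze exposing an optimal_reward() method."""

    def __init__(self, optimal: float):
        self._optimal = optimal

    def optimal_reward(self) -> float:
        return self._optimal


class ToyVecEnv:
    """Minimal container mimicking the VecEnv.env_method() broadcast."""

    def __init__(self, envs: Sequence[Any]):
        self.envs = list(envs)

    def env_method(self, method_name: str, *method_args: Any,
                   indices: Optional[Sequence[int]] = None,
                   **method_kwargs: Any) -> List[Any]:
        targets = self.envs if indices is None else [self.envs[i] for i in indices]
        # Call the named method on every selected sub-environment, keeping order.
        return [getattr(env, method_name)(*method_args, **method_kwargs)
                for env in targets]


eval_env = ToyVecEnv([ToyMaze(1.0), ToyMaze(0.5)])
optimal_rewards = eval_env.env_method("optimal_reward")
average_optimal = sum(optimal_rewards) / len(optimal_rewards)
print(optimal_rewards, average_optimal)  # [1.0, 0.5] 0.75
```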

    tb_callback = TensorboardCallback(
        log_trajectory_every=1, max_timestep=BUDGET  # Eval callback (below)
    )
    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path=FOLDER,
        log_path=FOLDER,
        eval_freq=max(100, BUDGET // (10 * len(eval_mazes))),
        verbose=1,
        n_eval_episodes=len(eval_mazes),
        callback_after_eval=tb_callback,
        callback_on_new_best=StopTrainingOnRewardThreshold(
            reward_threshold=optimal_reward, verbose=1
        ),
    )

Next, we create a TensorboardCallback, an illustrative built-in callback that uses TensorBoard to provide an overview of the training process. In addition to logging numerical data such as the average rewards, it automatically generates trajectory images whenever the EvalCallback is triggered. The EvalCallback itself is defined in the traditional SB3 fashion, with two additions: our TensorboardCallback runs after each evaluation (callback_after_eval), and a StopTrainingOnRewardThreshold, set to the optimal reward, stops training as soon as the agent's mean evaluation reward reaches that optimal value.
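The eval_freq expression deserves a short worked check. In SB3, EvalCallback counts eval_freq in callback calls, and with a VecEnv each call corresponds to one step of every sub-environment, i.e. n_envs timesteps. Assuming make_vec_maze_env() creates one sub-environment per maze (so the training environment has len(eval_mazes) of them), dividing by len(eval_mazes) yields roughly ten evaluations over the budget, with a floor of 100 calls. The helper names and the budget and environment counts below are illustrative only; since PPO collects rollouts in fixed-size chunks, the exact number of timesteps (and thus evaluations) can differ slightly. The stop rule mirrors StopTrainingOnRewardThreshold, which lets training continue only while the best mean reward is below the threshold and, used as callback_on_new_best, is only checked when a new best is found.

```python
from typing import Tuple


def eval_schedule(budget: int, n_envs: int, n_eval_mazes: int,
                  min_freq: int = 100) -> Tuple[int, int, int]:
    """Return (eval_freq, timesteps between evaluations, approx. number of evaluations)."""
    # Same expression as in train(): eval_freq counts callback calls (vectorized steps).
    eval_freq = max(min_freq, budget // (10 * n_eval_mazes))
    # Each vectorized step advances every sub-environment once.
    timesteps_between_evals = eval_freq * n_envs
    return eval_freq, timesteps_between_evals, budget // timesteps_between_evals


def stops_training(best_mean_reward: float, reward_threshold: float) -> bool:
    # Training continues only while best_mean_reward < reward_threshold.
    return best_mean_reward >= reward_threshold


print(eval_schedule(budget=100_000, n_envs=4, n_eval_mazes=4))  # (2500, 10000, 10)
print(eval_schedule(budget=1_000, n_envs=4, n_eval_mazes=4))    # (100, 400, 2)
```

With a small (test-sized) budget, the floor of 100 calls dominates and only a couple of evaluations take place.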

    model = sb3_controller(
        robot,
        PPO,
        policy="MlpPolicy",
        env=train_env,
        seed=SEED,
        learning_rate=1e-3,
    )

    print("== Starting", "=" * 68)
    model.set_logger(configure(FOLDER, ["csv", "tensorboard"]))
    model.learn(BUDGET, callback=eval_callback, progress_bar=not is_test)

    tb_callback.log_step(final=True)

Finally, we create the SB3 model using the dedicated wrapper sb3_controller(), providing the robot data, the underlying model type (one of compatible_models()) and, afterwards, the usual parameters. After setting up the logger (CSV and TensorBoard outputs) and letting the training process run its course, we perform a final step of the callback to render the final trajectories.
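The general pattern behind such a wrapper (check the requested model type against a list of supported ones, build it from the usual keyword arguments, and keep the robot description with the model so a loader can rebuild a matching agent) can be sketched as follows. This is a hypothetical illustration only: AlgorithmStandIn, SUPPORTED_MODELS, make_controller and the robot dictionary keys are invented for the sketch and are not how sb3_controller() or compatible_models() are actually implemented.

```python
from typing import Any, Dict, Tuple, Type


class AlgorithmStandIn:
    """Hypothetical stand-in for an SB3 algorithm class such as PPO."""

    def __init__(self, policy: str, env: Any = None, seed: int = 0,
                 learning_rate: float = 3e-4):
        self.policy = policy
        self.env = env
        self.seed = seed
        self.learning_rate = learning_rate


# Hypothetical whitelist, playing the role compatible_models() plays in amaze.
SUPPORTED_MODELS: Tuple[Type, ...] = (AlgorithmStandIn,)


def make_controller(robot: Dict[str, Any], model_type: Type,
                    **model_kwargs: Any) -> Any:
    """Build model_type(**model_kwargs) and attach the robot description to it."""
    if model_type not in SUPPORTED_MODELS:
        raise ValueError(f"{model_type.__name__} is not a supported model type")
    model = model_type(**model_kwargs)
    # Keeping the robot description alongside the model lets a loader
    # reconstruct an agent with matching inputs and outputs later on.
    model.robot_data = dict(robot)
    return model


model = make_controller({"inputs": "DISCRETE", "outputs": "DISCRETE"},
                        AlgorithmStandIn, policy="MlpPolicy", seed=0,
                        learning_rate=1e-3)
```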

Using

def evaluate(is_test):
    model = load_sb3_controller(BEST)

Once the training process is complete, we evaluate the resulting agent's generalization capability in the same manner as in Full example: Training. The only difference is the use of the dedicated loading function load_sb3_controller(), which is a more explicitly named alias for load(). Since the remainder of this function is identical, we refer the reader to the previous example (Generalization) if needed.
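The evaluation code is not repeated here. As an illustration of one way to summarise such a generalization test, assuming each test maze yields the agent's cumulative reward together with the maze's optimal reward (as returned by optimal_reward() above), the hypothetical helper below reports the mean gap to optimal and the fraction of mazes solved optimally. Using a gap rather than a ratio avoids dividing by a zero or negative optimal reward; this metric is a choice made for the sketch, not necessarily the one used by amaze.

```python
from statistics import mean
from typing import Iterable, Tuple


def summarise_generalization(results: Iterable[Tuple[float, float]],
                             tol: float = 1e-6) -> Tuple[float, float]:
    """Summarise (obtained_reward, optimal_reward) pairs, one per test maze.

    Returns the mean gap to the optimal reward (0 means optimal everywhere)
    and the fraction of mazes on which the agent was optimal within tol.
    """
    pairs = list(results)
    if not pairs:
        raise ValueError("no evaluation results")
    gaps = [optimal - reward for reward, optimal in pairs]
    optimal_rate = sum(1 for gap in gaps if gap <= tol) / len(pairs)
    return mean(gaps), optimal_rate


print(summarise_generalization([(1.0, 1.0), (0.5, 1.0)]))  # (0.25, 0.5)
```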

Full listing for examples/extensions/sb3:main
def main(is_test=False):
    folder = pathlib.Path(FOLDER)
    if folder.exists():
        shutil.rmtree(folder)
    folder.mkdir(parents=True, exist_ok=False)

    train_maze = train(is_test)
    evaluate(is_test)

    with CV2QTGuard(platform=False):
        amaze_main(
            f"--controller {BEST} --extension sb3 --maze {train_maze}"
            f" --auto-quit --no-restore-config --width 1000"
        )

The main function should also be familiar from the previous example. Note, however, that due to incompatibilities between current versions of the OpenCV and PyQt5 libraries, one should use CV2QTGuard when combining Stable Baselines3 with the native Qt5 components.
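A commonly reported source of this kind of conflict is that the opencv-python wheels bundle their own Qt plugins and point the QT_QPA_PLATFORM_PLUGIN_PATH environment variable at them, which can prevent PyQt5 from loading its own platform plugin. Whether CV2QTGuard works this way, and what its platform argument controls, is not stated here; the hypothetical env_var_guard below only illustrates the general "save, unset, restore" context-manager pattern such a guard could follow.

```python
import os
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def env_var_guard(*names: str) -> Iterator[None]:
    """Temporarily unset environment variables, restoring their prior values on exit."""
    # Remove the variables, remembering the values of those that were set.
    saved = {name: os.environ.pop(name) for name in names if name in os.environ}
    try:
        yield
    finally:
        # Discard anything assigned inside the block, then restore the old values.
        for name in names:
            os.environ.pop(name, None)
        os.environ.update(saved)
```

For instance, wrapping a Qt viewer launch in `with env_var_guard("QT_QPA_PLATFORM_PLUGIN_PATH"):` would hide OpenCV's plugin path only for the duration of that block.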