Stable baselines 3
Training
In this example, we showcase how the built-in stable baselines 3 (sb3) extension can be used to leverage the large collection of algorithms and policies that stable baselines 3 provides.
```python
import math
import pathlib
import random
import shutil

from stable_baselines3.common.callbacks import (
    EvalCallback,
    StopTrainingOnRewardThreshold,
)
from stable_baselines3.common.logger import configure

from amaze import Maze, Robot, Simulation, Sign, amaze_main, StartLocation
from amaze.extensions.sb3 import (
    make_vec_maze_env,
    env_method,
    load_sb3_controller,
    PPO,
    TensorboardCallback,
    sb3_controller,
    CV2QTGuard,
)
```
As usual, we start by importing the necessary packages and defining some global configuration options. Note that, in addition to the traditional amaze classes, we also import extension-specific items (detailed below).
```python
train_mazes = Maze.BuildData.from_string(train_maze).all_rotations()
eval_mazes = [d.where(seed=test_seed) for d in train_mazes]

train_env = make_vec_maze_env(train_mazes, robot, SEED)
eval_env = make_vec_maze_env(eval_mazes, robot, SEED, log_trajectory=True)
```
The training function is much shorter than in the hand-written Q-learning
example, thanks to the functionality provided by stable baselines 3 and the
wrappers added by the extension.
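While creating mazes and robots should be familiar by now, we also see a new
extension-specific function, make_vec_maze_env(), which is used to create
Vectorized Environments (VecEnv).
The training set consists of all rotations of the training maze, while the
evaluation mazes reuse the same build data with the seed replaced by test_seed.
Only the evaluation environment logs trajectories (log_trajectory=True).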
```python
optimal_reward = sum(env_method(eval_env, "optimal_reward")) / len(eval_mazes)
```
Sometimes we also need access to the underlying environments (regular mazes).
Here, env_method() calls optimal_reward() on every maze used for intermediate
performance evaluation, and the results are averaged to obtain the mean
optimal reward.
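SB3's VecEnv exposes env_method(method_name, *method_args, indices=None, **method_kwargs), which calls a method on each wrapped environment and returns the results as a list. Assuming the extension's env_method(eval_env, ...) helper forwards to it, the averaging step reduces to the pattern sketched below. ToyMaze and call_on_each are hypothetical stand-ins, and the per-maze optimal rewards are made-up numbers.

```python
from typing import Any, List


class ToyMaze:
    """Hypothetical stand-in for an evaluation maze."""

    def __init__(self, optimal: float):
        self._optimal = optimal

    def optimal_reward(self) -> float:
        return self._optimal


def call_on_each(envs: List[Any], method_name: str, *args: Any, **kwargs: Any) -> List[Any]:
    """Call `method_name` on every environment and collect the results in order,
    mirroring the behaviour of VecEnv.env_method on the wrapped environments."""
    return [getattr(env, method_name)(*args, **kwargs) for env in envs]


eval_mazes = [ToyMaze(1.0), ToyMaze(0.5), ToyMaze(0.75), ToyMaze(0.75)]
optimal_reward = sum(call_on_each(eval_mazes, "optimal_reward")) / len(eval_mazes)
print(optimal_reward)  # 0.75
```

Averaging is needed because each maze can have its own optimal reward, while the stopping threshold used during evaluation is a single number.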
```python
tb_callback = TensorboardCallback(
    log_trajectory_every=1, max_timestep=BUDGET  # Eval callback (below)
)
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path=FOLDER,
    log_path=FOLDER,
    eval_freq=max(100, BUDGET // (10 * len(eval_mazes))),
    verbose=1,
    n_eval_episodes=len(eval_mazes),
    callback_after_eval=tb_callback,
    callback_on_new_best=StopTrainingOnRewardThreshold(
        reward_threshold=optimal_reward, verbose=1
    ),
)
```
Next, we create a TensorboardCallback, an illustrative built-in callback that
uses TensorBoard to provide an overview of the training process.
In addition to logging numerical data such as the average rewards, it also
automatically generates trajectory images whenever it is triggered by its
parent EvalCallback, i.e. after each evaluation.
The EvalCallback itself is defined in the traditional SB3 fashion, except that
our Tensorboard callback is run after each evaluation (callback_after_eval) and
the optimal reward is used as a threshold to stop training as soon as the agent
behaves optimally (callback_on_new_best).
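Two details of this setup are worth spelling out. SB3's EvalCallback counts eval_freq in callback calls, and each call corresponds to one step of the vectorized environment, i.e. one timestep per sub-environment. Assuming the training VecEnv holds one sub-environment per maze (as many as eval_mazes), dividing BUDGET by 10 * len(eval_mazes) therefore yields roughly ten evaluations over the run, unless the floor of 100 calls kicks in. Moreover, SB3's evaluate_policy spreads n_eval_episodes = len(eval_mazes) episodes evenly over the evaluation environments, so under the same assumption each maze is evaluated once and the mean evaluation reward is directly comparable to the averaged optimal reward. The toy simulation below is a sketch of that control flow (new-best check with StopTrainingOnRewardThreshold, then the after-evaluation callback); eval_freq, simulate_training, the budget and the reward curve are all hypothetical, and the timestep accounting is simplified (PPO only checks its total budget between rollouts).

```python
from typing import Callable, Optional


def eval_freq(budget: int, n_mazes: int) -> int:
    """Formula from the training function: evaluation period, in callback calls."""
    return max(100, budget // (10 * n_mazes))


def simulate_training(
    budget: int,
    n_envs: int,
    mean_reward_at: Callable[[int], float],
    reward_threshold: float,
    after_eval: Optional[Callable[[int, float], None]] = None,
) -> int:
    """Toy model of EvalCallback + StopTrainingOnRewardThreshold.

    Each callback call is one vectorized step, i.e. `n_envs` timesteps.
    Returns the number of timesteps consumed before training ended.
    """
    freq = eval_freq(budget, n_envs)
    best = float("-inf")
    timesteps = calls = 0
    while timesteps < budget:
        timesteps += n_envs
        calls += 1
        if calls % freq != 0:
            continue  # not an evaluation step
        mean_reward = mean_reward_at(timesteps)
        stop = False
        if mean_reward > best:  # "new best" branch of EvalCallback
            best = mean_reward
            stop = best >= reward_threshold  # StopTrainingOnRewardThreshold
        if after_eval is not None:
            after_eval(timesteps, mean_reward)  # role of callback_after_eval
        if stop:
            break
    return timesteps


# Hypothetical run: 4 mazes, reward rising linearly until it is optimal (1.0).
evaluations = []
used = simulate_training(
    budget=100_000,
    n_envs=4,
    mean_reward_at=lambda t: min(1.0, t / 50_000),
    reward_threshold=1.0,
    after_eval=lambda t, r: evaluations.append((t, r)),
)
print(eval_freq(100_000, 4))  # 2500 calls, i.e. 10_000 timesteps
print(used)                   # 50000
print(evaluations[-1])        # (50000, 1.0)
```

With these made-up numbers, training stops after five of the ten planned evaluations, at the first one where the mean reward reaches the threshold.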
```python
model = sb3_controller(
    robot,
    PPO,
    policy="MlpPolicy",
    env=train_env,
    seed=SEED,
    learning_rate=1e-3,
)

print("== Starting", "=" * 68)
model.set_logger(configure(FOLDER, ["csv", "tensorboard"]))
model.learn(BUDGET, callback=eval_callback, progress_bar=not is_test)

tb_callback.log_step(final=True)
```
Finally, we create the sb3 model using the dedicated wrapper sb3_controller(),
providing the robot data, the underlying model type (one of
compatible_models()) and, afterwards, the usual parameters.
After setting up the logger and letting the training process run its course,
we perform a final step of the Tensorboard callback to render the final
trajectories.
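The call pattern of sb3_controller() (robot data first, then a model type, then the usual SB3 keyword arguments) can be illustrated with a generic factory. The sketch below is hypothetical: toy_controller, toy_compatible_models, ToyPPO and ToyDQN are illustrative names, and we merely assume that such a wrapper validates the model type and keeps the robot data alongside the model. ToyDQN only plays the role of an unsupported type and says nothing about which real SB3 algorithms amaze accepts.

```python
from typing import Any, List, Optional, Type


class ToyAlgorithm:
    """Hypothetical stand-in for an SB3 algorithm class such as PPO."""

    def __init__(self, policy: str, env: Any = None, seed: int = 0,
                 learning_rate: float = 3e-4):
        self.policy = policy
        self.env = env
        self.seed = seed
        self.learning_rate = learning_rate
        self.robot_data: Optional[str] = None  # filled in by the factory


class ToyPPO(ToyAlgorithm):
    pass


class ToyDQN(ToyAlgorithm):
    pass


def toy_compatible_models() -> List[Type[ToyAlgorithm]]:
    """Hypothetical counterpart of compatible_models(): the accepted types."""
    return [ToyPPO]


def toy_controller(robot_data: str, model_type: Type[ToyAlgorithm],
                   **kwargs: Any) -> ToyAlgorithm:
    """Validate the model type, build it with the remaining keyword arguments
    and attach the robot data to the resulting model."""
    if model_type not in toy_compatible_models():
        raise ValueError(f"{model_type.__name__} is not a compatible model type")
    model = model_type(**kwargs)
    model.robot_data = robot_data
    return model


model = toy_controller("hypothetical-robot", ToyPPO,
                       policy="MlpPolicy", seed=0, learning_rate=1e-3)
print(type(model).__name__, model.policy, model.learning_rate)  # ToyPPO MlpPolicy 0.001
```

Keeping the robot data with the model is what would allow a dedicated loader to rebuild a complete controller later, which is one plausible reason for routing model creation through such a wrapper.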
Using
```python
def evaluate(is_test):
    model = load_sb3_controller(BEST)
```
Once the training process is complete, we evaluate the resulting agent's
generalization capability in the same manner as in Full example: Training.
The only difference is the use of the dedicated loading function
load_sb3_controller(), which is a more descriptive alias for load().
As the remainder of this function is identical, we refer the reader to the
previous example (Generalization), if needed.
```python
def main(is_test=False):
    folder = pathlib.Path(FOLDER)
    if folder.exists():
        shutil.rmtree(folder)
    folder.mkdir(parents=True, exist_ok=False)

    train_maze = train(is_test)
    evaluate(is_test)

    with CV2QTGuard(platform=False):
        amaze_main(
            f"--controller {BEST} --extension sb3 --maze {train_maze}"
            f" --auto-quit --no-restore-config --width 1000"
        )
```
Finally, the main function should also be familiar from the previous example.
Note, however, that due to incompatibilities between the current OpenCV and
PyQt5 libraries, CV2QTGuard should be used when combining stable baselines 3
with the native Qt5 components.
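The mechanism used by CV2QTGuard is not described here. A commonly reported source of this kind of conflict is that non-headless opencv-python builds point the QT_QPA_PLATFORM_PLUGIN_PATH environment variable at their own bundled Qt plugins, which PyQt5 then tries, and fails, to load. Assuming that is the issue, a guard of this kind can be sketched as a context manager that hides the variable while the Qt code runs and restores it afterwards. qt_env_guard is a hypothetical name, and the platform argument of the real CV2QTGuard is not modelled.

```python
import os
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def qt_env_guard(var: str = "QT_QPA_PLATFORM_PLUGIN_PATH") -> Iterator[None]:
    """Temporarily remove an environment variable, restoring its previous
    state (value or absence) on exit, even if the body raises."""
    saved = os.environ.pop(var, None)
    try:
        yield
    finally:
        if saved is None:
            # The variable did not exist before: drop anything set meanwhile.
            os.environ.pop(var, None)
        else:
            os.environ[var] = saved
```

Used as `with qt_env_guard():` around the code that starts the Qt viewer, it leaves the OpenCV-side configuration untouched outside the block.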