710 lines
26 KiB
Python
710 lines
26 KiB
Python
import os
|
|
import warnings
|
|
from abc import ABC, abstractmethod
|
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
|
|
|
import gymnasium as gym
|
|
import numpy as np
|
|
|
|
from stable_baselines3.common.logger import Logger
|
|
|
|
try:
|
|
from tqdm import TqdmExperimentalWarning
|
|
|
|
# Remove experimental warning
|
|
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
|
|
from tqdm.rich import tqdm
|
|
except ImportError:
|
|
# Rich not installed, we only throw an error
|
|
# if the progress bar is used
|
|
tqdm = None
|
|
|
|
|
|
from stable_baselines3.common.evaluation import evaluate_policy
|
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
|
|
|
|
if TYPE_CHECKING:
|
|
from stable_baselines3.common import base_class
|
|
|
|
|
|
class BaseCallback(ABC):
    """
    Abstract base for every training callback.

    The public ``on_*`` methods update the bookkeeping attributes and then
    delegate to the matching ``_on_*`` hook. Only ``_on_step`` is abstract;
    the other hooks default to doing nothing.

    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    # The RL model
    # Kept as a string annotation to avoid a circular import
    model: "base_class.BaseAlgorithm"

    def __init__(self, verbose: int = 0):
        super().__init__()
        # How many times ``on_step()`` has been invoked on this callback
        self.n_calls: int = 0
        # n_envs * n times env.step() was called
        self.num_timesteps: int = 0
        self.verbose = verbose
        self.locals: Dict[str, Any] = {}
        self.globals: Dict[str, Any] = {}
        # Event callbacks set this on their children so that a child
        # can reach the object that triggers it
        self.parent: Optional["BaseCallback"] = None

    @property
    def training_env(self) -> VecEnv:
        """
        :return: The environment returned by ``model.get_env()``
            (asserted to be non-``None``).
        """
        env = self.model.get_env()
        assert (
            env is not None
        ), "`model.get_env()` returned None, you must initialize the model with an environment to use callbacks"
        return env

    @property
    def logger(self) -> Logger:
        """
        :return: The logger of the attached model.
        """
        return self.model.logger

    # Type hint as string to avoid circular import
    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
        """
        Attach the RL model to this callback, then run the
        ``_init_callback()`` hook.

        :param model: The model this callback is attached to
        """
        self.model = model
        self._init_callback()

    def _init_callback(self) -> None:
        """Hook run at the end of ``init_callback()``; no-op by default."""

    def on_training_start(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
        """
        Store the caller's local/global variable dictionaries, resync the
        timestep counter and run the ``_on_training_start()`` hook.

        :param locals_: the local variables of the caller
        :param globals_: the global variables of the caller
        """
        # The dictionaries are stored by reference, not copied
        self.locals = locals_
        self.globals = globals_
        # Update num_timesteps in case training was done before
        self.num_timesteps = self.model.num_timesteps
        self._on_training_start()

    def _on_training_start(self) -> None:
        """Hook run at the end of ``on_training_start()``; no-op by default."""

    def on_rollout_start(self) -> None:
        """Run the ``_on_rollout_start()`` hook."""
        self._on_rollout_start()

    def _on_rollout_start(self) -> None:
        """Hook run by ``on_rollout_start()``; no-op by default."""

    @abstractmethod
    def _on_step(self) -> bool:
        """
        :return: If the callback returns False, training is aborted early.
        """
        return True

    def on_step(self) -> bool:
        """
        This method will be called by the model after each call to ``env.step()``.

        For child callback (of an ``EventCallback``), this will be called
        when the event is triggered.

        :return: If the callback returns False, training is aborted early.
        """
        self.n_calls += 1
        # Mirror the model's counter before handing over to the hook
        self.num_timesteps = self.model.num_timesteps
        return self._on_step()

    def on_training_end(self) -> None:
        """Run the ``_on_training_end()`` hook."""
        self._on_training_end()

    def _on_training_end(self) -> None:
        """Hook run by ``on_training_end()``; no-op by default."""

    def on_rollout_end(self) -> None:
        """Run the ``_on_rollout_end()`` hook."""
        self._on_rollout_end()

    def _on_rollout_end(self) -> None:
        """Hook run by ``on_rollout_end()``; no-op by default."""

    def update_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Merge ``locals_`` into the stored local variables, then forward
        them to the sub callbacks through ``update_child_locals()``.

        :param locals_: the local variables during rollout collection
        """
        self.locals.update(locals_)
        self.update_child_locals(locals_)

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Update the references to the local variables on sub callbacks.
        The base class has no sub callback, so this does nothing.

        :param locals_: the local variables during rollout collection
        """
|
|
|
|
|
|
class EventCallback(BaseCallback):
    """
    Base class for callbacks that trigger another callback on an event.

    :param callback: Callback that will be called
        when an event is triggered.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.callback = callback
        if callback is not None:
            # Let the child reach the callback that triggers it
            callback.parent = self

    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
        """
        Initialize this callback, then the wrapped one (if any), with the
        same model.

        :param model: The model this callback is attached to
        """
        super().init_callback(model)
        child = self.callback
        if child is not None:
            child.init_callback(self.model)

    def _on_training_start(self) -> None:
        """Forward the training-start notification to the wrapped callback."""
        child = self.callback
        if child is not None:
            child.on_training_start(self.locals, self.globals)

    def _on_event(self) -> bool:
        """
        Step the wrapped callback.

        :return: The result of the wrapped callback's ``on_step()``,
            or ``True`` when there is no wrapped callback.
        """
        if self.callback is None:
            return True
        return self.callback.on_step()

    def _on_step(self) -> bool:
        """
        :return: Always ``True`` here; subclasses decide when to call
            ``_on_event()``.
        """
        return True

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Pass the updated local variables on to the wrapped callback.

        :param locals_: the local variables during rollout collection
        """
        child = self.callback
        if child is not None:
            child.update_locals(locals_)
|
|
|
|
|
|
class CallbackList(BaseCallback):
    """
    Chain several callbacks so they behave as a single one.

    :param callbacks: A list of callbacks that will be called
        sequentially.
    """

    def __init__(self, callbacks: List[BaseCallback]):
        super().__init__()
        assert isinstance(callbacks, list)
        self.callbacks = callbacks

    def _init_callback(self) -> None:
        """Initialize every chained callback with this list's model."""
        for child in self.callbacks:
            child.init_callback(self.model)

    def _on_training_start(self) -> None:
        """Notify every chained callback that training starts."""
        for child in self.callbacks:
            child.on_training_start(self.locals, self.globals)

    def _on_rollout_start(self) -> None:
        """Notify every chained callback that a rollout starts."""
        for child in self.callbacks:
            child.on_rollout_start()

    def _on_step(self) -> bool:
        """
        Step every chained callback, in order.

        :return: A falsy value (stop training) if at least one callback
            returned a falsy value.
        """
        keep_training = True
        for child in self.callbacks:
            # ``on_step()`` is the left operand on purpose: every callback
            # is stepped even after one of them has asked to stop
            keep_training = child.on_step() and keep_training
        return keep_training

    def _on_rollout_end(self) -> None:
        """Notify every chained callback that a rollout ends."""
        for child in self.callbacks:
            child.on_rollout_end()

    def _on_training_end(self) -> None:
        """Notify every chained callback that training ends."""
        for child in self.callbacks:
            child.on_training_end()

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Pass the updated local variables on to every chained callback.

        :param locals_: the local variables during rollout collection
        """
        for child in self.callbacks:
            child.update_locals(locals_)
|
|
|
|
|
|
class CheckpointCallback(BaseCallback):
    """
    Callback for saving a model every ``save_freq`` calls
    to ``env.step()``.
    Only the model is checkpointed by default; pass
    ``save_replay_buffer=True`` and/or ``save_vecnormalize=True``
    to also checkpoint the replay buffer and the normalization statistics.

    .. warning::

      When using multiple environments, each call to ``env.step()``
      will effectively correspond to ``n_envs`` steps.
      To account for that, you can use ``save_freq = max(save_freq // n_envs, 1)``

    :param save_freq: Save checkpoints every ``save_freq`` call of the callback.
    :param save_path: Path to the folder where the model will be saved.
    :param name_prefix: Common prefix to the saved models
    :param save_replay_buffer: Save the model replay buffer
    :param save_vecnormalize: Save the ``VecNormalize`` statistics
    :param verbose: Verbosity level: 0 for no output, 2 for indicating when saving model checkpoint
    """

    def __init__(
        self,
        save_freq: int,
        save_path: str,
        name_prefix: str = "rl_model",
        save_replay_buffer: bool = False,
        save_vecnormalize: bool = False,
        verbose: int = 0,
    ):
        super().__init__(verbose)
        self.save_freq = save_freq
        self.save_path = save_path
        self.name_prefix = name_prefix
        self.save_replay_buffer = save_replay_buffer
        self.save_vecnormalize = save_vecnormalize

    def _init_callback(self) -> None:
        """Create the checkpoint folder if it does not exist yet."""
        if self.save_path is None:
            return
        os.makedirs(self.save_path, exist_ok=True)

    def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> str:
        """
        Build the path of a checkpoint file for the current timestep.

        :param checkpoint_type: empty for the model, "replay_buffer_"
            or "vecnormalize_" for the other checkpoints.
        :param extension: Checkpoint file extension (zip for model, pkl for others)
        :return: Path to the checkpoint
        """
        filename = f"{self.name_prefix}_{checkpoint_type}{self.num_timesteps}_steps.{extension}"
        return os.path.join(self.save_path, filename)

    def _on_step(self) -> bool:
        """
        Write the checkpoint(s) on every ``save_freq``-th call.

        :return: Always ``True``: this callback never stops training.
        """
        if self.n_calls % self.save_freq != 0:
            # Not a checkpoint step
            return True

        model_path = self._checkpoint_path(extension="zip")
        self.model.save(model_path)
        if self.verbose >= 2:
            print(f"Saving model checkpoint to {model_path}")

        # Only models that expose a non-None replay buffer get it saved
        if self.save_replay_buffer and getattr(self.model, "replay_buffer", None) is not None:
            replay_buffer_path = self._checkpoint_path("replay_buffer_", extension="pkl")
            self.model.save_replay_buffer(replay_buffer_path)  # type: ignore[attr-defined]
            if self.verbose > 1:
                print(f"Saving model replay buffer checkpoint to {replay_buffer_path}")

        if self.save_vecnormalize:
            vec_normalize_env = self.model.get_vec_normalize_env()
            if vec_normalize_env is not None:
                # Save the VecNormalize statistics
                vec_normalize_path = self._checkpoint_path("vecnormalize_", extension="pkl")
                vec_normalize_env.save(vec_normalize_path)
                if self.verbose >= 2:
                    print(f"Saving model VecNormalize to {vec_normalize_path}")

        return True
|
|
|
|
|
|
class ConvertCallback(BaseCallback):
    """
    Wrap a functional (old-style) callback in a callback object.

    :param callback: Function called on each step with the stored
        ``locals`` and ``globals`` dictionaries; its return value is
        returned by ``_on_step()``. ``None`` means "always continue".
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """

    def __init__(self, callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], bool]], verbose: int = 0):
        super().__init__(verbose)
        self.callback = callback

    def _on_step(self) -> bool:
        """
        :return: The wrapped function's result, or ``True`` when no
            function was given.
        """
        if self.callback is None:
            return True
        return self.callback(self.locals, self.globals)
|
|
|
|
|
|
class EvalCallback(EventCallback):
    """
    Callback for evaluating an agent.

    .. warning::

      When using multiple environments, each call to ``env.step()``
      will effectively correspond to ``n_envs`` steps.
      To account for that, you can use ``eval_freq = max(eval_freq // n_envs, 1)``

    :param eval_env: The environment used for initialization
    :param callback_on_new_best: Callback to trigger
        when there is a new best model according to the ``mean_reward``
    :param callback_after_eval: Callback to trigger after every evaluation
    :param n_eval_episodes: The number of episodes to test the agent
    :param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback.
        A value of 0 or less disables evaluation.
    :param log_path: Path to a folder where the evaluations (``evaluations.npz``)
        will be saved. It will be updated at each evaluation.
    :param best_model_save_path: Path to a folder where the best model
        according to performance on the eval env will be saved.
    :param deterministic: Whether the evaluation should
        use a stochastic or deterministic actions.
    :param render: Whether to render or not the environment during evaluation
    :param verbose: Verbosity level: 0 for no output, 1 for indicating information about evaluation results
    :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
        wrapped with a Monitor wrapper)
    """

    def __init__(
        self,
        eval_env: Union[gym.Env, VecEnv],
        callback_on_new_best: Optional[BaseCallback] = None,
        callback_after_eval: Optional[BaseCallback] = None,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        log_path: Optional[str] = None,
        best_model_save_path: Optional[str] = None,
        deterministic: bool = True,
        render: bool = False,
        verbose: int = 1,
        warn: bool = True,
    ):
        # ``callback_after_eval`` becomes the event callback (``self.callback``)
        super().__init__(callback_after_eval, verbose=verbose)

        self.callback_on_new_best = callback_on_new_best
        if self.callback_on_new_best is not None:
            # Give access to the parent
            self.callback_on_new_best.parent = self

        self.n_eval_episodes = n_eval_episodes
        self.eval_freq = eval_freq
        self.best_mean_reward = -np.inf
        self.last_mean_reward = -np.inf
        self.deterministic = deterministic
        self.render = render
        self.warn = warn

        # Convert to VecEnv for consistency
        if not isinstance(eval_env, VecEnv):
            eval_env = DummyVecEnv([lambda: eval_env])  # type: ignore[list-item, return-value]

        self.eval_env = eval_env
        self.best_model_save_path = best_model_save_path
        # Logs will be written in ``evaluations.npz``
        if log_path is not None:
            log_path = os.path.join(log_path, "evaluations")
        self.log_path = log_path
        self.evaluations_results: List[List[float]] = []
        self.evaluations_timesteps: List[int] = []
        self.evaluations_length: List[List[int]] = []
        # For computing success rate
        self._is_success_buffer: List[bool] = []
        self.evaluations_successes: List[List[bool]] = []

    def _init_callback(self) -> None:
        """
        Warn when training and eval environments differ in type, create the
        output folders and initialize the new-best callback.
        """
        # Does not work in some corner cases, where the wrapper is not the same
        if not isinstance(self.training_env, type(self.eval_env)):
            # The original message concatenated two literals without a
            # separator, gluing "type" to the environment repr
            warnings.warn(f"Training and eval env are not of the same type: {self.training_env} != {self.eval_env}")

        # Create folders if needed
        if self.best_model_save_path is not None:
            os.makedirs(self.best_model_save_path, exist_ok=True)
        if self.log_path is not None:
            # ``log_path`` points at the ``evaluations`` file, so create its parent
            os.makedirs(os.path.dirname(self.log_path), exist_ok=True)

        # Init callback called on new best model
        if self.callback_on_new_best is not None:
            self.callback_on_new_best.init_callback(self.model)

    def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
        """
        Callback passed to the ``evaluate_policy`` function
        in order to log the success rate (when applicable),
        for instance when using HER.

        :param locals_: local variables of the caller; ``"info"`` and ``"done"`` are read
        :param globals_: global variables of the caller (unused)
        """
        info = locals_["info"]

        if locals_["done"]:
            # Only record episodes whose info dict reports a success flag
            maybe_is_success = info.get("is_success")
            if maybe_is_success is not None:
                self._is_success_buffer.append(maybe_is_success)

    def _on_step(self) -> bool:
        """
        Evaluate the agent on every ``eval_freq``-th call, log the results,
        save the best model and trigger the child callbacks.

        :return: ``False`` when a child callback asks to stop training,
            ``True`` otherwise (including on steps without evaluation).
        :raises AssertionError: if normalization statistics cannot be synced
            because the two environments are not wrapped the same way.
        """
        continue_training = True

        if self.eval_freq <= 0 or self.n_calls % self.eval_freq != 0:
            # Evaluation disabled, or not an evaluation step
            return continue_training

        # Sync training and eval env if there is VecNormalize
        if self.model.get_vec_normalize_env() is not None:
            try:
                sync_envs_normalization(self.training_env, self.eval_env)
            except AttributeError as e:
                raise AssertionError(
                    "Training and eval env are not wrapped the same way, "
                    "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
                    "and warning above."
                ) from e

        # Reset success rate buffer (a new list, so the lists already stored
        # in ``evaluations_successes`` are left untouched)
        self._is_success_buffer = []

        episode_rewards, episode_lengths = evaluate_policy(
            self.model,
            self.eval_env,
            n_eval_episodes=self.n_eval_episodes,
            render=self.render,
            deterministic=self.deterministic,
            return_episode_rewards=True,
            warn=self.warn,
            callback=self._log_success_callback,
        )

        if self.log_path is not None:
            assert isinstance(episode_rewards, list)
            assert isinstance(episode_lengths, list)
            self.evaluations_timesteps.append(self.num_timesteps)
            self.evaluations_results.append(episode_rewards)
            self.evaluations_length.append(episode_lengths)

            kwargs = {}
            # Save success log if present
            if len(self._is_success_buffer) > 0:
                self.evaluations_successes.append(self._is_success_buffer)
                kwargs = {"successes": self.evaluations_successes}

            # The whole history is rewritten at every evaluation
            np.savez(
                self.log_path,
                timesteps=self.evaluations_timesteps,
                results=self.evaluations_results,
                ep_lengths=self.evaluations_length,
                **kwargs,
            )

        mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
        mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
        self.last_mean_reward = float(mean_reward)

        if self.verbose >= 1:
            print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
            print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
        # Add to current Logger
        self.logger.record("eval/mean_reward", float(mean_reward))
        self.logger.record("eval/mean_ep_length", mean_ep_length)

        if len(self._is_success_buffer) > 0:
            success_rate = np.mean(self._is_success_buffer)
            if self.verbose >= 1:
                print(f"Success rate: {100 * success_rate:.2f}%")
            self.logger.record("eval/success_rate", success_rate)

        # Dump log so the evaluation results are printed with the correct timestep
        self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
        self.logger.dump(self.num_timesteps)

        if mean_reward > self.best_mean_reward:
            if self.verbose >= 1:
                print("New best mean reward!")
            if self.best_model_save_path is not None:
                self.model.save(os.path.join(self.best_model_save_path, "best_model"))
            self.best_mean_reward = float(mean_reward)
            # Trigger callback on new best model, if needed
            if self.callback_on_new_best is not None:
                continue_training = self.callback_on_new_best.on_step()

        # Trigger callback after every evaluation, if needed
        if self.callback is not None:
            # Skipped (short-circuit) when the new-best callback already asked to stop
            continue_training = continue_training and self._on_event()

        return continue_training

    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
        """
        Pass the updated local variables on to the after-evaluation callback.

        :param locals_: the local variables during rollout collection
        """
        # Explicit ``is not None`` to match every other check on ``self.callback``
        if self.callback is not None:
            self.callback.update_locals(locals_)
|
|
|
|
|
|
class StopTrainingOnRewardThreshold(BaseCallback):
    """
    Stop the training once a threshold in episodic reward
    has been reached (i.e. when the model is good enough).

    It must be used with the ``EvalCallback``.

    :param reward_threshold: Minimum expected reward per episode
        to stop training.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating when training ended because episodic reward
        threshold reached
    """

    parent: EvalCallback

    def __init__(self, reward_threshold: float, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.reward_threshold = reward_threshold

    def _on_step(self) -> bool:
        """
        Compare the parent's best mean reward with the threshold.

        :return: ``False`` (stop training) once ``parent.best_mean_reward``
            is no longer below ``reward_threshold``, ``True`` otherwise.
        """
        # The message previously named a non-existent ``StopTrainingOnMinimumReward`` class
        assert self.parent is not None, "``StopTrainingOnRewardThreshold`` callback must be used with an ``EvalCallback``"
        continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
        if self.verbose >= 1 and not continue_training:
            # The first literal already ends with a space; the second must not start with one
            print(
                f"Stopping training because the mean reward {self.parent.best_mean_reward:.2f} "
                f"is above the threshold {self.reward_threshold}"
            )
        return continue_training
|
|
|
|
|
|
class EveryNTimesteps(EventCallback):
    """
    Trigger a callback every ``n_steps`` timesteps

    :param n_steps: Number of timesteps between two trigger.
    :param callback: Callback that will be called
        when the event is triggered.
    """

    def __init__(self, n_steps: int, callback: BaseCallback):
        super().__init__(callback)
        self.n_steps = n_steps
        # Value of ``num_timesteps`` when the event last fired
        self.last_time_trigger = 0

    def _on_step(self) -> bool:
        """
        Fire the event once at least ``n_steps`` timesteps have elapsed
        since the previous trigger.

        :return: The result of the triggered callback, or ``True`` when
            the event does not fire on this step.
        """
        elapsed = self.num_timesteps - self.last_time_trigger
        if elapsed < self.n_steps:
            return True
        self.last_time_trigger = self.num_timesteps
        return self._on_event()
|
|
|
|
|
|
class StopTrainingOnMaxEpisodes(BaseCallback):
    """
    Stop the training once a maximum number of episodes are played.

    For multiple environments presumes that, the desired behavior is that the agent trains on each env for ``max_episodes``
    and in total for ``max_episodes * n_envs`` episodes.

    :param max_episodes: Maximum number of episodes to stop training.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating information about when training ended by
        reaching ``max_episodes``
    """

    def __init__(self, max_episodes: int, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.max_episodes = max_episodes
        # Rescaled by the number of environments in ``_init_callback()``
        self._total_max_episodes = max_episodes
        self.n_episodes = 0

    def _init_callback(self) -> None:
        """Scale the episode budget by the number of training environments."""
        n_envs = self.training_env.num_envs
        self._total_max_episodes = self.max_episodes * n_envs

    def _on_step(self) -> bool:
        """
        Add the episodes that ended on this step to the running total.

        :return: ``False`` (stop training) once the total episode budget
            is reached, ``True`` otherwise.
        """
        # Check that the `dones` local variable is defined
        assert "dones" in self.locals, "`dones` variable is not defined, please check your code next to `callback.on_step()`"
        finished_now = np.sum(self.locals["dones"]).item()
        self.n_episodes += finished_now

        continue_training = self.n_episodes < self._total_max_episodes

        if self.verbose >= 1 and not continue_training:
            n_envs = self.training_env.num_envs
            mean_episodes_per_env = self.n_episodes / n_envs
            # The per-env average is only reported with several environments
            if n_envs > 1:
                mean_ep_str = f"with an average of {mean_episodes_per_env:.2f} episodes per env"
            else:
                mean_ep_str = ""

            print(
                f"Stopping training with a total of {self.num_timesteps} steps because the "
                f"{self.locals.get('tb_log_name')} model reached max_episodes={self.max_episodes}, "
                f"by playing for {self.n_episodes} episodes "
                f"{mean_ep_str}"
            )
        return continue_training
|
|
|
|
|
|
class StopTrainingOnNoModelImprovement(BaseCallback):
    """
    Stop the training early if there is no new best model (new best mean reward) after more than N consecutive evaluations.

    It is possible to define a minimum number of evaluations before start to count evaluations without improvement.

    It must be used with the ``EvalCallback``.

    :param max_no_improvement_evals: Maximum number of consecutive evaluations without a new best model.
    :param min_evals: Number of evaluations before start to count evaluations without improvements.
    :param verbose: Verbosity level: 0 for no output, 1 for indicating when training ended because no new best model
    """

    parent: EvalCallback

    def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.max_no_improvement_evals = max_no_improvement_evals
        self.min_evals = min_evals
        # Parent's best mean reward as seen on the previous call
        self.last_best_mean_reward = -np.inf
        # Current streak of calls without a new best
        self.no_improvement_evals = 0

    def _on_step(self) -> bool:
        """
        Track whether the parent's best mean reward improved since the
        previous call.

        :return: ``False`` (stop training) when the streak without
            improvement exceeds ``max_no_improvement_evals``, ``True`` otherwise.
        """
        assert self.parent is not None, "``StopTrainingOnNoModelImprovement`` callback must be used with an ``EvalCallback``"

        keep_training = True
        current_best = self.parent.best_mean_reward

        # The streak is only counted after the first ``min_evals`` calls
        if self.n_calls > self.min_evals:
            if current_best > self.last_best_mean_reward:
                self.no_improvement_evals = 0
            else:
                self.no_improvement_evals += 1
                if self.no_improvement_evals > self.max_no_improvement_evals:
                    keep_training = False

        # Remembered on every call, including the first ``min_evals`` ones
        self.last_best_mean_reward = current_best

        if self.verbose >= 1 and not keep_training:
            print(
                f"Stopping training because there was no new best model in the last {self.no_improvement_evals:d} evaluations"
            )

        return keep_training
|
|
|
|
|
|
class ProgressBarCallback(BaseCallback):
    """
    Display a progress bar when training SB3 agent
    using tqdm and rich packages.
    """

    pbar: tqdm

    def __init__(self) -> None:
        super().__init__()
        # ``tqdm`` is set to None at import time when ``tqdm.rich`` cannot be
        # imported; the error is only raised when the callback is used
        if tqdm is None:
            raise ImportError(
                "You must install tqdm and rich in order to use the progress bar callback. "
                "It is included if you install stable-baselines with the extra packages: "
                "`pip install stable-baselines3[extra]`"
            )

    def _on_training_start(self) -> None:
        """Create the progress bar for the timesteps left in this session."""
        # Remove timesteps that were done in previous training sessions
        remaining_timesteps = self.locals["total_timesteps"] - self.model.num_timesteps
        self.pbar = tqdm(total=remaining_timesteps)

    def _on_step(self) -> bool:
        """
        Advance the bar by one step per training environment.

        :return: Always ``True``: this callback never stops training.
        """
        # We do num_envs steps per call to `env.step()`
        n_envs = self.training_env.num_envs
        self.pbar.update(n_envs)
        return True

    def _on_training_end(self) -> None:
        """Flush and close the progress bar."""
        self.pbar.refresh()
        self.pbar.close()
|