I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File: stable_baselines3/__init__.py

@@ -0,0 +1,34 @@
import os
from stable_baselines3.a2c import A2C
from stable_baselines3.common.utils import get_system_info
from stable_baselines3.ddpg import DDPG
from stable_baselines3.dqn import DQN
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
from stable_baselines3.ppo import PPO
from stable_baselines3.sac import SAC
from stable_baselines3.td3 import TD3
# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
with open(version_file) as file_handler:
__version__ = file_handler.read().strip()
def HER(*args, **kwargs):
raise ImportError(
"Since Stable Baselines 2.1.0, `HER` is now a replay buffer class `HerReplayBuffer`.\n "
"Please check the documentation for more information: https://stable-baselines3.readthedocs.io/"
)
__all__ = [
"A2C",
"DDPG",
"DQN",
"PPO",
"SAC",
"TD3",
"HerReplayBuffer",
"get_system_info",
]
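
For orientation, the classes re-exported above are the package's main entry points. A minimal usage sketch, assuming gymnasium is installed and "CartPole-v1" is available (PPO is shown here; the other algorithms share the same basic interface):

from stable_baselines3 import PPO

# Train a small PPO agent on CartPole-v1 (environment name and budget are illustrative only)
model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=10_000)
model.save("ppo_cartpole")  # example output path, not part of this repository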

View File: stable_baselines3/a2c/__init__.py

@@ -0,0 +1,4 @@
from stable_baselines3.a2c.a2c import A2C
from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "A2C"]

View File: stable_baselines3/a2c/a2c.py

@@ -0,0 +1,208 @@
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
import torch as th
from gymnasium import spaces
from torch.nn import functional as F
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance
SelfA2C = TypeVar("SelfA2C", bound="A2C")
class A2C(OnPolicyAlgorithm):
"""
Advantage Actor Critic (A2C)
Paper: https://arxiv.org/abs/1602.01783
Code: This implementation borrows code from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
and Stable Baselines (https://github.com/hill-a/stable-baselines)
Introduction to A2C: https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
of the current progress remaining (from 1 to 0)
:param n_steps: The number of steps to run for each environment per update
(i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
:param gamma: Discount factor
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
Equivalent to classic advantage when set to 1.
:param ent_coef: Entropy coefficient for the loss calculation
:param vf_coef: Value function coefficient for the loss calculation
:param max_grad_norm: The maximum value for the gradient clipping
:param rms_prop_eps: RMSProp epsilon. It stabilizes square root computation in denominator
of RMSProp update
:param use_rms_prop: Whether to use RMSprop (default) or Adam as optimizer
:param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
:param normalize_advantage: Whether or not to normalize the advantage
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
"MlpPolicy": ActorCriticPolicy,
"CnnPolicy": ActorCriticCnnPolicy,
"MultiInputPolicy": MultiInputActorCriticPolicy,
}
def __init__(
self,
policy: Union[str, Type[ActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 7e-4,
n_steps: int = 5,
gamma: float = 0.99,
gae_lambda: float = 1.0,
ent_coef: float = 0.0,
vf_coef: float = 0.5,
max_grad_norm: float = 0.5,
rms_prop_eps: float = 1e-5,
use_rms_prop: bool = True,
use_sde: bool = False,
sde_sample_freq: int = -1,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
normalize_advantage: bool = False,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
learning_rate=learning_rate,
n_steps=n_steps,
gamma=gamma,
gae_lambda=gae_lambda,
ent_coef=ent_coef,
vf_coef=vf_coef,
max_grad_norm=max_grad_norm,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
rollout_buffer_class=rollout_buffer_class,
rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
policy_kwargs=policy_kwargs,
verbose=verbose,
device=device,
seed=seed,
_init_setup_model=False,
supported_action_spaces=(
spaces.Box,
spaces.Discrete,
spaces.MultiDiscrete,
spaces.MultiBinary,
),
)
self.normalize_advantage = normalize_advantage
# Update optimizer inside the policy if we want to use RMSProp
# (original implementation) rather than Adam
if use_rms_prop and "optimizer_class" not in self.policy_kwargs:
self.policy_kwargs["optimizer_class"] = th.optim.RMSprop
self.policy_kwargs["optimizer_kwargs"] = dict(alpha=0.99, eps=rms_prop_eps, weight_decay=0)
if _init_setup_model:
self._setup_model()
def train(self) -> None:
"""
Update policy using the currently gathered
rollout buffer (one gradient step over the whole data).
"""
# Switch to train mode (this affects batch norm / dropout)
self.policy.set_training_mode(True)
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)
# This will only loop once (get all data in one go)
for rollout_data in self.rollout_buffer.get(batch_size=None):
actions = rollout_data.actions
if isinstance(self.action_space, spaces.Discrete):
# Convert discrete action from float to long
actions = actions.long().flatten()
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
values = values.flatten()
# Normalize advantage (not present in the original implementation)
advantages = rollout_data.advantages
if self.normalize_advantage:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Policy gradient loss
policy_loss = -(advantages * log_prob).mean()
# Value loss using the TD(gae_lambda) target
value_loss = F.mse_loss(rollout_data.returns, values)
# Entropy loss favors exploration
if entropy is None:
# Approximate entropy when no analytical form
entropy_loss = -th.mean(-log_prob)
else:
entropy_loss = -th.mean(entropy)
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
# Optimization step
self.policy.optimizer.zero_grad()
loss.backward()
# Clip grad norm
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
self._n_updates += 1
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/explained_variance", explained_var)
self.logger.record("train/entropy_loss", entropy_loss.item())
self.logger.record("train/policy_loss", policy_loss.item())
self.logger.record("train/value_loss", value_loss.item())
if hasattr(self.policy, "log_std"):
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
def learn(
self: SelfA2C,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 100,
tb_log_name: str = "A2C",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfA2C:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
tb_log_name=tb_log_name,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)
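
To make the constructor and ``learn()`` signatures above concrete, here is a short usage sketch; the environment name and hyperparameter values are arbitrary examples, not recommendations:

from stable_baselines3 import A2C

# A2C with advantage normalization and a slightly larger rollout (values are illustrative)
model = A2C(
    "MlpPolicy",
    "CartPole-v1",
    n_steps=8,
    normalize_advantage=True,
    verbose=1,
)
model.learn(total_timesteps=20_000, log_interval=10)
model.save("a2c_cartpole")  # example path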

View File: stable_baselines3/a2c/policies.py

@@ -0,0 +1,7 @@
# This file is here just to define MlpPolicy/CnnPolicy
# that work for A2C
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy
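
As a small illustration of the aliasing above, passing the string alias and passing the class directly should resolve to the same policy; "CartPole-v1" and the ``net_arch`` value are assumed examples:

from stable_baselines3 import A2C
from stable_baselines3.a2c.policies import MlpPolicy

# Both constructions resolve to ActorCriticPolicy via the policy_aliases mapping
model_from_alias = A2C("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=[64, 64]))
model_from_class = A2C(MlpPolicy, "CartPole-v1", policy_kwargs=dict(net_arch=[64, 64]))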

View File: stable_baselines3/common/atari_wrappers.py

@@ -0,0 +1,306 @@
from typing import Dict, SupportsFloat
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.type_aliases import AtariResetReturn, AtariStepReturn
try:
import cv2
cv2.ocl.setUseOpenCL(False)
except ImportError:
cv2 = None # type: ignore[assignment]
class StickyActionEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
"""
Sticky action.
Paper: https://arxiv.org/abs/1709.06009
Official implementation: https://github.com/mgbellemare/Arcade-Learning-Environment
:param env: Environment to wrap
:param action_repeat_probability: Probability of repeating the last action
"""
def __init__(self, env: gym.Env, action_repeat_probability: float) -> None:
super().__init__(env)
self.action_repeat_probability = action_repeat_probability
assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined]
def reset(self, **kwargs) -> AtariResetReturn:
self._sticky_action = 0 # NOOP
return self.env.reset(**kwargs)
def step(self, action: int) -> AtariStepReturn:
if self.np_random.random() >= self.action_repeat_probability:
self._sticky_action = action
return self.env.step(self._sticky_action)
class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
"""
Sample initial states by taking a random number of no-ops on reset.
No-op is assumed to be action 0.
:param env: Environment to wrap
:param noop_max: Maximum value of no-ops to run
"""
def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
super().__init__(env)
self.noop_max = noop_max
self.override_num_noops = None
self.noop_action = 0
assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined]
def reset(self, **kwargs) -> AtariResetReturn:
self.env.reset(**kwargs)
if self.override_num_noops is not None:
noops = self.override_num_noops
else:
noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
assert noops > 0
obs = np.zeros(0)
info: Dict = {}
for _ in range(noops):
obs, _, terminated, truncated, info = self.env.step(self.noop_action)
if terminated or truncated:
obs, info = self.env.reset(**kwargs)
return obs, info
class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
"""
Take action on reset for environments that are fixed until firing.
:param env: Environment to wrap
"""
def __init__(self, env: gym.Env) -> None:
super().__init__(env)
assert env.unwrapped.get_action_meanings()[1] == "FIRE" # type: ignore[attr-defined]
assert len(env.unwrapped.get_action_meanings()) >= 3 # type: ignore[attr-defined]
def reset(self, **kwargs) -> AtariResetReturn:
self.env.reset(**kwargs)
obs, _, terminated, truncated, _ = self.env.step(1)
if terminated or truncated:
self.env.reset(**kwargs)
obs, _, terminated, truncated, _ = self.env.step(2)
if terminated or truncated:
self.env.reset(**kwargs)
return obs, {}
class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
"""
Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
:param env: Environment to wrap
"""
def __init__(self, env: gym.Env) -> None:
super().__init__(env)
self.lives = 0
self.was_real_done = True
def step(self, action: int) -> AtariStepReturn:
obs, reward, terminated, truncated, info = self.env.step(action)
self.was_real_done = terminated or truncated
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined]
if 0 < lives < self.lives:
# for Qbert sometimes we stay in lives == 0 condition for a few frames
# so it's important to keep lives > 0, so that we only reset once
# the environment advertises done.
terminated = True
self.lives = lives
return obs, reward, terminated, truncated, info
def reset(self, **kwargs) -> AtariResetReturn:
"""
Calls the Gym environment reset, only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
:param kwargs: Extra keywords passed to env.reset() call
:return: the first observation of the environment
"""
if self.was_real_done:
obs, info = self.env.reset(**kwargs)
else:
# no-op step to advance from terminal/lost life state
obs, _, terminated, truncated, info = self.env.step(0)
# The no-op step can lead to a game over, so we need to check it again
# to see if we should reset the environment and avoid the
# monitor.py `RuntimeError: Tried to step environment that needs reset`
if terminated or truncated:
obs, info = self.env.reset(**kwargs)
self.lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined]
return obs, info
class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
"""
Return only every ``skip``-th frame (frameskipping)
and return the max between the last two frames.
:param env: Environment to wrap
:param skip: The frameskip: only every ``skip``-th frame is returned.
The same action will be taken ``skip`` times.
"""
def __init__(self, env: gym.Env, skip: int = 4) -> None:
super().__init__(env)
# most recent raw observations (for max pooling across time steps)
assert env.observation_space.dtype is not None, "No dtype specified for the observation space"
assert env.observation_space.shape is not None, "No shape defined for the observation space"
self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=env.observation_space.dtype)
self._skip = skip
def step(self, action: int) -> AtariStepReturn:
"""
Step the environment with the given action
Repeat action, sum reward, and max over last observations.
:param action: the action
:return: observation, reward, terminated, truncated, information
"""
total_reward = 0.0
terminated = truncated = False
for i in range(self._skip):
obs, reward, terminated, truncated, info = self.env.step(action)
done = terminated or truncated
if i == self._skip - 2:
self._obs_buffer[0] = obs
if i == self._skip - 1:
self._obs_buffer[1] = obs
total_reward += float(reward)
if done:
break
# Note that the observation on the done=True frame
# doesn't matter
max_frame = self._obs_buffer.max(axis=0)
return max_frame, total_reward, terminated, truncated, info
class ClipRewardEnv(gym.RewardWrapper):
"""
Clip the reward to {+1, 0, -1} by its sign.
:param env: Environment to wrap
"""
def __init__(self, env: gym.Env) -> None:
super().__init__(env)
def reward(self, reward: SupportsFloat) -> float:
"""
Bin reward to {+1, 0, -1} by its sign.
:param reward:
:return:
"""
return np.sign(float(reward))
class WarpFrame(gym.ObservationWrapper[np.ndarray, int, np.ndarray]):
"""
Convert to grayscale and warp frames to 84x84 (default)
as done in the Nature paper and later work.
:param env: Environment to wrap
:param width: New frame width
:param height: New frame height
"""
def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None:
super().__init__(env)
self.width = width
self.height = height
assert isinstance(env.observation_space, spaces.Box), f"Expected Box space, got {env.observation_space}"
self.observation_space = spaces.Box(
low=0,
high=255,
shape=(self.height, self.width, 1),
dtype=env.observation_space.dtype, # type: ignore[arg-type]
)
def observation(self, frame: np.ndarray) -> np.ndarray:
"""
returns the current observation from a frame
:param frame: environment frame
:return: the observation
"""
assert cv2 is not None, "OpenCV is not installed, you can do `pip install opencv-python`"
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
return frame[:, :, None]
class AtariWrapper(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
"""
Atari 2600 preprocessing
Specifically:
* Noop reset: obtain initial state by taking a random number of no-ops on reset.
* Frame skipping: 4 by default
* Max-pooling: most recent two observations
* Termination signal when a life is lost.
* Resize to a square image: 84x84 by default
* Grayscale observation
* Clip reward to {-1, 0, 1}
* Sticky actions: disabled by default
See https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/
for a visual explanation.
.. warning::
Use this wrapper only with Atari v4 without frame skip: ``env_id = "*NoFrameskip-v4"``.
:param env: Environment to wrap
:param noop_max: Max number of no-ops
:param frame_skip: Frequency at which the agent experiences the game.
This corresponds to repeating the action ``frame_skip`` times.
:param screen_size: Resize Atari frame
:param terminal_on_life_loss: If True, then step() returns done=True whenever a life is lost.
:param clip_reward: If True (default), the reward is clipped to {-1, 0, 1} depending on its sign.
:param action_repeat_probability: Probability of repeating the last action
"""
def __init__(
self,
env: gym.Env,
noop_max: int = 30,
frame_skip: int = 4,
screen_size: int = 84,
terminal_on_life_loss: bool = True,
clip_reward: bool = True,
action_repeat_probability: float = 0.0,
) -> None:
if action_repeat_probability > 0.0:
env = StickyActionEnv(env, action_repeat_probability)
if noop_max > 0:
env = NoopResetEnv(env, noop_max=noop_max)
# frame_skip=1 is the same as no frame-skip (action repeat)
if frame_skip > 1:
env = MaxAndSkipEnv(env, skip=frame_skip)
if terminal_on_life_loss:
env = EpisodicLifeEnv(env)
if "FIRE" in env.unwrapped.get_action_meanings(): # type: ignore[attr-defined]
env = FireResetEnv(env)
env = WarpFrame(env, width=screen_size, height=screen_size)
if clip_reward:
env = ClipRewardEnv(env)
super().__init__(env)
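
A minimal sketch of applying ``AtariWrapper`` follows; it assumes the Atari extras are installed (e.g. ``pip install "gymnasium[atari,accept-rom-license]"``) and that "BreakoutNoFrameskip-v4" is registered:

import gymnasium as gym

from stable_baselines3.common.atari_wrappers import AtariWrapper

env = gym.make("BreakoutNoFrameskip-v4")
env = AtariWrapper(env, terminal_on_life_loss=True, clip_reward=True)
obs, info = env.reset()
print(obs.shape)  # expected (84, 84, 1): grayscale frame resized by WarpFrame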

View File: stable_baselines3/common/base_class.py

@@ -0,0 +1,844 @@
"""Abstract base classes for RL algorithms."""
import io
import pathlib
import time
import warnings
from abc import ABC, abstractmethod
from collections import deque
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
import gymnasium as gym
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
from stable_baselines3.common.env_util import is_wrapped
from stable_baselines3.common.logger import Logger
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict
from stable_baselines3.common.utils import (
check_for_correct_spaces,
get_device,
get_schedule_fn,
get_system_info,
set_random_seed,
update_learning_rate,
)
from stable_baselines3.common.vec_env import (
DummyVecEnv,
VecEnv,
VecNormalize,
VecTransposeImage,
is_vecenv_wrapped,
unwrap_vec_normalize,
)
from stable_baselines3.common.vec_env.patch_gym import _convert_space, _patch_env
SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm")
def maybe_make_env(env: Union[GymEnv, str], verbose: int) -> GymEnv:
"""If env is a string, make the environment; otherwise, return env.
:param env: The environment to learn from.
:param verbose: Verbosity level: 0 for no output, 1 for indicating if environment is created
:return: A Gym (vector) environment.
"""
if isinstance(env, str):
env_id = env
if verbose >= 1:
print(f"Creating environment from the given name '{env_id}'")
# Set render_mode to `rgb_array` as default, so we can record video
try:
env = gym.make(env_id, render_mode="rgb_array")
except TypeError:
env = gym.make(env_id)
return env
class BaseAlgorithm(ABC):
"""
The base of RL algorithms
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param learning_rate: learning rate for the optimizer,
it can be a function of the current progress remaining (from 1 to 0)
:param policy_kwargs: Additional arguments to be passed to the policy on creation
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param device: Device on which the code should run.
By default, it will try to use a CUDA-compatible device and fall back to cpu
if it is not possible.
:param support_multi_env: Whether the algorithm supports training
with multiple environments (as in A2C)
:param monitor_wrapper: When creating an environment, whether or not to wrap it
in a Monitor wrapper.
:param seed: Seed for the pseudo random generators
:param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param supported_action_spaces: The action spaces supported by the algorithm.
"""
# Policy aliases (see _get_policy_from_name())
policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {}
policy: BasePolicy
observation_space: spaces.Space
action_space: spaces.Space
n_envs: int
lr_schedule: Schedule
_logger: Logger
def __init__(
self,
policy: Union[str, Type[BasePolicy]],
env: Union[GymEnv, str, None],
learning_rate: Union[float, Schedule],
policy_kwargs: Optional[Dict[str, Any]] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
verbose: int = 0,
device: Union[th.device, str] = "auto",
support_multi_env: bool = False,
monitor_wrapper: bool = True,
seed: Optional[int] = None,
use_sde: bool = False,
sde_sample_freq: int = -1,
supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
) -> None:
if isinstance(policy, str):
self.policy_class = self._get_policy_from_name(policy)
else:
self.policy_class = policy
self.device = get_device(device)
if verbose >= 1:
print(f"Using {self.device} device")
self.verbose = verbose
self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
self.num_timesteps = 0
# Used for updating schedules
self._total_timesteps = 0
# Used for computing fps, it is updated at each call of learn()
self._num_timesteps_at_start = 0
self.seed = seed
self.action_noise: Optional[ActionNoise] = None
self.start_time = 0.0
self.learning_rate = learning_rate
self.tensorboard_log = tensorboard_log
self._last_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
self._last_episode_starts = None # type: Optional[np.ndarray]
# When using VecNormalize:
self._last_original_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
self._episode_num = 0
# Used for gSDE only
self.use_sde = use_sde
self.sde_sample_freq = sde_sample_freq
# Track the training progress remaining (from 1 to 0)
# this is used to update the learning rate
self._current_progress_remaining = 1.0
# Buffers for logging
self._stats_window_size = stats_window_size
self.ep_info_buffer = None # type: Optional[deque]
self.ep_success_buffer = None # type: Optional[deque]
# For logging (and TD3 delayed updates)
self._n_updates = 0 # type: int
# Whether the user passed a custom logger or not
self._custom_logger = False
self.env: Optional[VecEnv] = None
self._vec_normalize_env: Optional[VecNormalize] = None
# Create and wrap the env if needed
if env is not None:
env = maybe_make_env(env, self.verbose)
env = self._wrap_env(env, self.verbose, monitor_wrapper)
self.observation_space = env.observation_space
self.action_space = env.action_space
self.n_envs = env.num_envs
self.env = env
# get VecNormalize object if needed
self._vec_normalize_env = unwrap_vec_normalize(env)
if supported_action_spaces is not None:
assert isinstance(self.action_space, supported_action_spaces), (
f"The algorithm only supports {supported_action_spaces} as action spaces "
f"but {self.action_space} was provided"
)
if not support_multi_env and self.n_envs > 1:
raise ValueError(
"Error: the model does not support multiple envs; it requires " "a single vectorized environment."
)
# Catch common mistake: using MlpPolicy/CnnPolicy instead of MultiInputPolicy
if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, spaces.Dict):
raise ValueError(f"You must use `MultiInputPolicy` when working with dict observation space, not {policy}")
if self.use_sde and not isinstance(self.action_space, spaces.Box):
raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")
if isinstance(self.action_space, spaces.Box):
assert np.all(
np.isfinite(np.array([self.action_space.low, self.action_space.high]))
), "Continuous action space must have a finite lower and upper bound"
@staticmethod
def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> VecEnv:
""" "
Wrap environment with the appropriate wrappers if needed.
For instance, to have a vectorized environment
or to re-order the image channels.
:param env:
:param verbose: Verbosity level: 0 for no output, 1 for indicating wrappers used
:param monitor_wrapper: Whether to wrap the env in a ``Monitor`` when possible.
:return: The wrapped environment.
"""
if not isinstance(env, VecEnv):
# Patch to support gym 0.21/0.26 and gymnasium
env = _patch_env(env)
if not is_wrapped(env, Monitor) and monitor_wrapper:
if verbose >= 1:
print("Wrapping the env with a `Monitor` wrapper")
env = Monitor(env)
if verbose >= 1:
print("Wrapping the env in a DummyVecEnv.")
env = DummyVecEnv([lambda: env]) # type: ignore[list-item, return-value]
# Make sure that dict-spaces are not nested (not supported)
check_for_nested_spaces(env.observation_space)
if not is_vecenv_wrapped(env, VecTransposeImage):
wrap_with_vectranspose = False
if isinstance(env.observation_space, spaces.Dict):
# If even one of the keys is an image-space in need of transpose, apply transpose
# If the image spaces are not consistent (for instance one is channel first,
# the other channel last), VecTransposeImage will throw an error
for space in env.observation_space.spaces.values():
wrap_with_vectranspose = wrap_with_vectranspose or (
is_image_space(space) and not is_image_space_channels_first(space) # type: ignore[arg-type]
)
else:
wrap_with_vectranspose = is_image_space(env.observation_space) and not is_image_space_channels_first(
env.observation_space # type: ignore[arg-type]
)
if wrap_with_vectranspose:
if verbose >= 1:
print("Wrapping the env in a VecTransposeImage.")
env = VecTransposeImage(env)
return env
@abstractmethod
def _setup_model(self) -> None:
"""Create networks, buffer and optimizers."""
def set_logger(self, logger: Logger) -> None:
"""
Setter for the logger object.
.. warning::
When passing a custom logger object,
this will overwrite ``tensorboard_log`` and ``verbose`` settings
passed to the constructor.
"""
self._logger = logger
# User defined logger
self._custom_logger = True
@property
def logger(self) -> Logger:
"""Getter for the logger object."""
return self._logger
def _setup_lr_schedule(self) -> None:
"""Transform to callable if needed."""
self.lr_schedule = get_schedule_fn(self.learning_rate)
def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
"""
Compute current progress remaining (starts at 1 and ends at 0)
:param num_timesteps: current number of timesteps
:param total_timesteps:
"""
self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)
def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
"""
Update the optimizers learning rate using the current learning rate schedule
and the current progress remaining (from 1 to 0).
:param optimizers:
An optimizer or a list of optimizers.
"""
# Log the current learning rate
self.logger.record("train/learning_rate", self.lr_schedule(self._current_progress_remaining))
if not isinstance(optimizers, list):
optimizers = [optimizers]
for optimizer in optimizers:
update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining))
def _excluded_save_params(self) -> List[str]:
"""
Returns the names of the parameters that should be excluded from being
saved by pickling. E.g. replay buffers are skipped by default
as they take up a lot of space. PyTorch variables should be excluded
with this so they can be stored with ``th.save``.
:return: List of parameters that should be excluded from being saved with pickle.
"""
return [
"policy",
"device",
"env",
"replay_buffer",
"rollout_buffer",
"_vec_normalize_env",
"_episode_storage",
"_logger",
"_custom_logger",
]
def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
"""
Get a policy class from its name representation.
The goal here is to standardize policy naming, e.g.
all algorithms can call upon "MlpPolicy" or "CnnPolicy",
and they receive respective policies that work for them.
:param policy_name: Alias of the policy
:return: A policy class (type)
"""
if policy_name in self.policy_aliases:
return self.policy_aliases[policy_name]
else:
raise ValueError(f"Policy {policy_name} unknown")
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
"""
Get the name of the torch variables that will be saved with
PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default
pickling strategy. This is to handle device placement correctly.
Names can point to specific variables under classes, e.g.
"policy.optimizer" would point to ``optimizer`` object of ``self.policy``
if this object.
:return:
List of Torch variables whose state dicts to save (e.g. th.nn.Modules),
and list of other Torch variables to store with ``th.save``.
"""
state_dicts = ["policy"]
return state_dicts, []
def _init_callback(
self,
callback: MaybeCallback,
progress_bar: bool = False,
) -> BaseCallback:
"""
:param callback: Callback(s) called at every step with state of the algorithm.
:param progress_bar: Display a progress bar using tqdm and rich.
:return: A hybrid callback calling `callback` and performing evaluation.
"""
# Convert a list of callbacks into a callback
if isinstance(callback, list):
callback = CallbackList(callback)
# Convert functional callback to object
if not isinstance(callback, BaseCallback):
callback = ConvertCallback(callback)
# Add progress bar callback
if progress_bar:
callback = CallbackList([callback, ProgressBarCallback()])
callback.init_callback(self)
return callback
def _setup_learn(
self,
total_timesteps: int,
callback: MaybeCallback = None,
reset_num_timesteps: bool = True,
tb_log_name: str = "run",
progress_bar: bool = False,
) -> Tuple[int, BaseCallback]:
"""
Initialize different variables needed for training.
:param total_timesteps: The total number of samples (env steps) to train on
:param callback: Callback(s) called at every step with state of the algorithm.
:param reset_num_timesteps: Whether or not to reset the ``num_timesteps`` attribute
:param tb_log_name: the name of the run for tensorboard log
:param progress_bar: Display a progress bar using tqdm and rich.
:return: Total timesteps and callback(s)
"""
self.start_time = time.time_ns()
if self.ep_info_buffer is None or reset_num_timesteps:
# Initialize buffers if they don't exist, or reinitialize if resetting counters
self.ep_info_buffer = deque(maxlen=self._stats_window_size)
self.ep_success_buffer = deque(maxlen=self._stats_window_size)
if self.action_noise is not None:
self.action_noise.reset()
if reset_num_timesteps:
self.num_timesteps = 0
self._episode_num = 0
else:
# Make sure training timesteps are ahead of the internal counter
total_timesteps += self.num_timesteps
self._total_timesteps = total_timesteps
self._num_timesteps_at_start = self.num_timesteps
# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:
assert self.env is not None
self._last_obs = self.env.reset() # type: ignore[assignment]
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:
self._last_original_obs = self._vec_normalize_env.get_original_obs()
# Configure logger's outputs if no logger was passed
if not self._custom_logger:
self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
# Create eval callback if needed
callback = self._init_callback(callback, progress_bar)
return total_timesteps, callback
def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
"""
Retrieve reward, episode length, episode success and update the buffer
if using Monitor wrapper or a GoalEnv.
:param infos: List of additional information about the transition.
:param dones: Termination signals
"""
assert self.ep_info_buffer is not None
assert self.ep_success_buffer is not None
if dones is None:
dones = np.array([False] * len(infos))
for idx, info in enumerate(infos):
maybe_ep_info = info.get("episode")
maybe_is_success = info.get("is_success")
if maybe_ep_info is not None:
self.ep_info_buffer.extend([maybe_ep_info])
if maybe_is_success is not None and dones[idx]:
self.ep_success_buffer.append(maybe_is_success)
def get_env(self) -> Optional[VecEnv]:
"""
Returns the current environment (can be None if not defined).
:return: The current environment
"""
return self.env
def get_vec_normalize_env(self) -> Optional[VecNormalize]:
"""
Return the ``VecNormalize`` wrapper of the training env
if it exists.
:return: The ``VecNormalize`` env.
"""
return self._vec_normalize_env
def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
"""
Checks the validity of the environment, and if it is coherent, sets it as the current environment.
Furthermore, wraps any non-vectorized env into a vectorized one.
Checked parameters:
- observation_space
- action_space
:param env: The environment for learning a policy
:param force_reset: Force call to ``reset()`` before training
to avoid unexpected behavior.
See issue https://github.com/DLR-RM/stable-baselines3/issues/597
"""
# if it is not a VecEnv, make it a VecEnv
# and do other transformations (dict obs, image transpose) if needed
env = self._wrap_env(env, self.verbose)
assert env.num_envs == self.n_envs, (
"The number of environments to be set is different from the number of environments in the model: "
f"({env.num_envs} != {self.n_envs}), whereas `set_env` requires them to be the same. To load a model with "
f"a different number of environments, you must use `{self.__class__.__name__}.load(path, env)` instead"
)
# Check that the observation spaces match
check_for_correct_spaces(env, self.observation_space, self.action_space)
# Update VecNormalize object
# otherwise the wrong env may be used, see https://github.com/DLR-RM/stable-baselines3/issues/637
self._vec_normalize_env = unwrap_vec_normalize(env)
# Discard `_last_obs`, this will force the env to reset before training
# See issue https://github.com/DLR-RM/stable-baselines3/issues/597
if force_reset:
self._last_obs = None
self.n_envs = env.num_envs
self.env = env
@abstractmethod
def learn(
self: SelfBaseAlgorithm,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 100,
tb_log_name: str = "run",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfBaseAlgorithm:
"""
Return a trained model.
:param total_timesteps: The total number of samples (env steps) to train on
:param callback: callback(s) called at every step with state of the algorithm.
:param log_interval: for on-policy algos (e.g., PPO, A2C, ...) this is the number of
training iterations (i.e., log_interval * n_steps * n_envs timesteps) before logging;
for off-policy algos (e.g., TD3, SAC, ...) this is the number of episodes before
logging.
:param tb_log_name: the name of the run for TensorBoard logging
:param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
:param progress_bar: Display a progress bar using tqdm and rich.
:return: the trained model
"""
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:param state: The last hidden states (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
this corresponds to the beginning of episodes,
where the hidden states of the RNN must be reset.
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next hidden state
(used in recurrent policies)
"""
return self.policy.predict(observation, state, episode_start, deterministic)
def set_random_seed(self, seed: Optional[int] = None) -> None:
"""
Set the seed of the pseudo-random generators
(python, numpy, pytorch, gym, action_space)
:param seed:
"""
if seed is None:
return
set_random_seed(seed, using_cuda=self.device.type == th.device("cuda").type)
self.action_space.seed(seed)
# self.env is always a VecEnv
if self.env is not None:
self.env.seed(seed)
def set_parameters(
self,
load_path_or_dict: Union[str, TensorDict],
exact_match: bool = True,
device: Union[th.device, str] = "auto",
) -> None:
"""
Load parameters from a given zip-file or a nested dictionary containing parameters for
different modules (see ``get_parameters``).
:param load_path_or_dict: Location of the saved data (path or file-like, see ``save``), or a nested
dictionary containing nn.Module parameters used by the policy. The dictionary maps
object names to a state-dictionary returned by ``torch.nn.Module.state_dict()``.
:param exact_match: If True, the given parameters should include parameters for each
module and each of their parameters, otherwise raises an Exception. If set to False, this
can be used to update only specific parameters.
:param device: Device on which the code should run.
"""
params = {}
if isinstance(load_path_or_dict, dict):
params = load_path_or_dict
else:
_, params, _ = load_from_zip_file(load_path_or_dict, device=device)
# Keep track of which objects were updated.
# `_get_torch_save_params` returns [params, other_pytorch_variables].
# We are only interested in the former here.
objects_needing_update = set(self._get_torch_save_params()[0])
updated_objects = set()
for name in params:
attr = None
try:
attr = recursive_getattr(self, name)
except Exception as e:
# What errors recursive_getattr could throw? KeyError, but
# possibly something else too (e.g. if key is an int?).
# Catch anything for now.
raise ValueError(f"Key {name} is an invalid object name.") from e
if isinstance(attr, th.optim.Optimizer):
# Optimizers do not support "strict" keyword...
# Seems like they will just replace the whole
# optimizer state with the given one.
# On top of this, optimizer state-dict
# seems to change (e.g. first ``optim.step()``),
# which makes comparing state dictionary keys
# invalid (there is also a nesting of dictionaries
# with lists with dictionaries with ...), adding to the
# mess.
#
# TL;DR: We might not be able to reliably say
# if given state-dict is missing keys.
#
# Solution: Just load the state-dict as is, and trust
# the user has provided a sensible state dictionary.
attr.load_state_dict(params[name]) # type: ignore[arg-type]
else:
# Assume attr is th.nn.Module
attr.load_state_dict(params[name], strict=exact_match)
updated_objects.add(name)
if exact_match and updated_objects != objects_needing_update:
raise ValueError(
"Names of parameters do not match agents' parameters: "
f"expected {objects_needing_update}, got {updated_objects}"
)
@classmethod
def load( # noqa: C901
cls: Type[SelfBaseAlgorithm],
path: Union[str, pathlib.Path, io.BufferedIOBase],
env: Optional[GymEnv] = None,
device: Union[th.device, str] = "auto",
custom_objects: Optional[Dict[str, Any]] = None,
print_system_info: bool = False,
force_reset: bool = True,
**kwargs,
) -> SelfBaseAlgorithm:
"""
Load the model from a zip-file.
Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
For an in-place load use ``set_parameters`` instead.
:param path: path to the file (or a file-like) where to
load the agent from
:param env: the new environment to run the loaded model on
(can be None if you only need prediction from a trained model); it has priority over any saved environment
:param device: Device on which the code should run.
:param custom_objects: Dictionary of objects to replace
upon loading. If a variable is present in this dictionary as a
key, it will not be deserialized and the corresponding item
will be used instead. Similar to custom_objects in
``keras.models.load_model``. Useful when you have an object in
file that can not be deserialized.
:param print_system_info: Whether to print system info from the saved model
and the current system info (useful to debug loading issues)
:param force_reset: Force call to ``reset()`` before training
to avoid unexpected behavior.
See https://github.com/DLR-RM/stable-baselines3/issues/597
:param kwargs: extra arguments to change the model when loading
:return: new model instance with loaded parameters
"""
if print_system_info:
print("== CURRENT SYSTEM INFO ==")
get_system_info()
data, params, pytorch_variables = load_from_zip_file(
path,
device=device,
custom_objects=custom_objects,
print_system_info=print_system_info,
)
assert data is not None, "No data found in the saved file"
assert params is not None, "No params found in the saved file"
# Remove stored device information and replace with ours
if "policy_kwargs" in data:
if "device" in data["policy_kwargs"]:
del data["policy_kwargs"]["device"]
# backward compatibility, convert to new format
if "net_arch" in data["policy_kwargs"] and len(data["policy_kwargs"]["net_arch"]) > 0:
saved_net_arch = data["policy_kwargs"]["net_arch"]
if isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict):
data["policy_kwargs"]["net_arch"] = saved_net_arch[0]
if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
raise ValueError(
f"The specified policy kwargs do not equal the stored policy kwargs."
f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
)
if "observation_space" not in data or "action_space" not in data:
raise KeyError("The observation_space and action_space were not given, can't verify new environments")
# Gym -> Gymnasium space conversion
for key in {"observation_space", "action_space"}:
data[key] = _convert_space(data[key])
if env is not None:
# Wrap first if needed
env = cls._wrap_env(env, data["verbose"])
# Check if given env is valid
check_for_correct_spaces(env, data["observation_space"], data["action_space"])
# Discard `_last_obs`, this will force the env to reset before training
# See issue https://github.com/DLR-RM/stable-baselines3/issues/597
if force_reset and data is not None:
data["_last_obs"] = None
# `n_envs` must be updated. See issue https://github.com/DLR-RM/stable-baselines3/issues/1018
if data is not None:
data["n_envs"] = env.num_envs
else:
# Use stored env, if one exists. If not, continue as is (can be used for predict)
if "env" in data:
env = data["env"]
model = cls(
policy=data["policy_class"],
env=env,
device=device,
_init_setup_model=False, # type: ignore[call-arg]
)
# load parameters
model.__dict__.update(data)
model.__dict__.update(kwargs)
model._setup_model()
try:
# put state_dicts back in place
model.set_parameters(params, exact_match=True, device=device)
except RuntimeError as e:
# Patch to load Policy saved using SB3 < 1.7.0
# the error is probably due to old policy being loaded
# See https://github.com/DLR-RM/stable-baselines3/issues/1233
if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e):
model.set_parameters(params, exact_match=False, device=device)
warnings.warn(
"You are probably loading a model saved with SB3 < 1.7.0, "
"we deactivated exact_match so you can save the model "
"again to avoid issues in the future "
"(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). "
f"Original error: {e} \n"
"Note: the model should still work fine, this only a warning."
)
else:
raise e
# put other pytorch variables back in place
if pytorch_variables is not None:
for name in pytorch_variables:
# Skip if PyTorch variable was not defined (to ensure backward compatibility).
# This happens when using SAC/TQC.
# SAC has an entropy coefficient which can be fixed or optimized.
# If it is optimized, an additional PyTorch variable `log_ent_coef` is defined,
# otherwise it is initialized to `None`.
if pytorch_variables[name] is None:
continue
# Set the data attribute directly to avoid issue when using optimizers
# See https://github.com/DLR-RM/stable-baselines3/issues/391
recursive_setattr(model, f"{name}.data", pytorch_variables[name].data)
# Sample gSDE exploration matrix, so it uses the right device
# see issue #44
if model.use_sde:
model.policy.reset_noise() # type: ignore[operator]
return model
def get_parameters(self) -> Dict[str, Dict]:
"""
Return the parameters of the agent. This includes parameters from different networks, e.g.
critics (value functions) and policies (pi functions).
:return: Mapping from names of the objects to PyTorch state-dicts.
"""
state_dicts_names, _ = self._get_torch_save_params()
params = {}
for name in state_dicts_names:
attr = recursive_getattr(self, name)
# Retrieve state dict
params[name] = attr.state_dict()
return params
def save(
self,
path: Union[str, pathlib.Path, io.BufferedIOBase],
exclude: Optional[Iterable[str]] = None,
include: Optional[Iterable[str]] = None,
) -> None:
"""
Save all the attributes of the object and the model parameters in a zip-file.
:param path: path to the file where the rl agent should be saved
:param exclude: name of parameters that should be excluded in addition to the default ones
:param include: name of parameters that might be excluded but should be included anyway
"""
# Copy parameter list so we don't mutate the original dict
data = self.__dict__.copy()
# Exclude is union of specified parameters (if any) and standard exclusions
if exclude is None:
exclude = []
exclude = set(exclude).union(self._excluded_save_params())
# Do not exclude params if they are specifically included
if include is not None:
exclude = exclude.difference(include)
state_dicts_names, torch_variable_names = self._get_torch_save_params()
all_pytorch_variables = state_dicts_names + torch_variable_names
for torch_var in all_pytorch_variables:
# We need to get only the name of the top most module as we'll remove that
var_name = torch_var.split(".")[0]
# Any params that are in the save vars must not be saved in `data`
exclude.add(var_name)
# Remove parameter entries of parameters which are to be excluded
for param_name in exclude:
data.pop(param_name, None)
# Build dict of torch variables
pytorch_variables = None
if torch_variable_names is not None:
pytorch_variables = {}
for name in torch_variable_names:
attr = recursive_getattr(self, name)
pytorch_variables[name] = attr
# Build dict of state_dicts
params_to_save = self.get_parameters()
save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)
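
To illustrate the ``save``/``load`` contract documented above (``load`` re-creates the model rather than updating it in place), a typical round-trip might look like the following sketch; the path and environment are placeholders:

import gymnasium as gym

from stable_baselines3 import A2C

env = gym.make("CartPole-v1")
model = A2C("MlpPolicy", env)
model.learn(total_timesteps=1_000)
model.save("a2c_cartpole")  # writes a zip-file via save_to_zip_file

# load() builds a fresh model; passing env allows further training or prediction
loaded = A2C.load("a2c_cartpole", env=env)
action, _ = loaded.predict(env.reset()[0], deterministic=True)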

View File: stable_baselines3/common/buffers.py

@@ -0,0 +1,839 @@
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.type_aliases import (
DictReplayBufferSamples,
DictRolloutBufferSamples,
ReplayBufferSamples,
RolloutBufferSamples,
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize
try:
# Check memory used by replay buffer when possible
import psutil
except ImportError:
psutil = None
class BaseBuffer(ABC):
"""
Base class that represent a buffer (rollout or replay)
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device: PyTorch device
to which the values will be converted
:param n_envs: Number of parallel environments
"""
observation_space: spaces.Space
obs_shape: Tuple[int, ...]
def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
n_envs: int = 1,
):
super().__init__()
self.buffer_size = buffer_size
self.observation_space = observation_space
self.action_space = action_space
self.obs_shape = get_obs_shape(observation_space) # type: ignore[assignment]
self.action_dim = get_action_dim(action_space)
self.pos = 0
self.full = False
self.device = get_device(device)
self.n_envs = n_envs
@staticmethod
def swap_and_flatten(arr: np.ndarray) -> np.ndarray:
"""
Swap and then flatten axes 0 (buffer_size) and 1 (n_envs)
to convert shape from [n_steps, n_envs, ...] (where ... is the shape of the features)
to [n_steps * n_envs, ...] (which maintains the order)
:param arr:
:return:
"""
shape = arr.shape
if len(shape) < 3:
shape = (*shape, 1)
return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])
def size(self) -> int:
"""
:return: The current size of the buffer
"""
if self.full:
return self.buffer_size
return self.pos
def add(self, *args, **kwargs) -> None:
"""
Add elements to the buffer.
"""
raise NotImplementedError()
def extend(self, *args, **kwargs) -> None:
"""
Add a new batch of transitions to the buffer
"""
# Do a for loop along the batch axis
for data in zip(*args):
self.add(*data)
def reset(self) -> None:
"""
Reset the buffer.
"""
self.pos = 0
self.full = False
def sample(self, batch_size: int, env: Optional[VecNormalize] = None):
"""
:param batch_size: Number of elements to sample
:param env: associated gym VecEnv
to normalize the observations/rewards when sampling
:return:
"""
upper_bound = self.buffer_size if self.full else self.pos
batch_inds = np.random.randint(0, upper_bound, size=batch_size)
return self._get_samples(batch_inds, env=env)
@abstractmethod
def _get_samples(
self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
) -> Union[ReplayBufferSamples, RolloutBufferSamples]:
"""
:param batch_inds:
:param env:
:return:
"""
raise NotImplementedError()
def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
"""
Convert a numpy array to a PyTorch tensor.
Note: it copies the data by default
:param array:
:param copy: Whether or not to copy the data (may be useful to avoid changing things
by reference). This argument is inoperative if the device is not the CPU.
:return:
"""
if copy:
return th.tensor(array, device=self.device)
return th.as_tensor(array, device=self.device)
@staticmethod
def _normalize_obs(
obs: Union[np.ndarray, Dict[str, np.ndarray]],
env: Optional[VecNormalize] = None,
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
if env is not None:
return env.normalize_obs(obs)
return obs
@staticmethod
def _normalize_reward(reward: np.ndarray, env: Optional[VecNormalize] = None) -> np.ndarray:
if env is not None:
return env.normalize_reward(reward).astype(np.float32)
return reward
class ReplayBuffer(BaseBuffer):
"""
Replay buffer used in off-policy algorithms like SAC/TD3.
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device: PyTorch device
:param n_envs: Number of parallel environments
:param optimize_memory_usage: Enable a memory efficient variant
of the replay buffer which reduces the memory used by almost a factor of two,
at the cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
Cannot be used in combination with handle_timeout_termination.
:param handle_timeout_termination: Handle timeout termination (due to timelimit)
separately and treat the task as infinite horizon task.
https://github.com/DLR-RM/stable-baselines3/issues/284
"""
observations: np.ndarray
next_observations: np.ndarray
actions: np.ndarray
rewards: np.ndarray
dones: np.ndarray
timeouts: np.ndarray
def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = False,
handle_timeout_termination: bool = True,
):
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
# Adjust buffer size
self.buffer_size = max(buffer_size // n_envs, 1)
# Check that the replay buffer can fit into the memory
if psutil is not None:
mem_available = psutil.virtual_memory().available
# there is a bug if both optimize_memory_usage and handle_timeout_termination are true
# see https://github.com/DLR-RM/stable-baselines3/issues/934
if optimize_memory_usage and handle_timeout_termination:
raise ValueError(
"ReplayBuffer does not support optimize_memory_usage = True "
"and handle_timeout_termination = True simultaneously."
)
self.optimize_memory_usage = optimize_memory_usage
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)
if not optimize_memory_usage:
# When optimizing memory, `observations` also contains the next observation
self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype)
self.actions = np.zeros(
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype)
)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
# Handle timeouts termination properly if needed
# see https://github.com/DLR-RM/stable-baselines3/issues/284
self.handle_timeout_termination = handle_timeout_termination
self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
if psutil is not None:
total_memory_usage: float = (
self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
)
if not optimize_memory_usage:
total_memory_usage += self.next_observations.nbytes
if total_memory_usage > mem_available:
# Convert to GB
total_memory_usage /= 1e9
mem_available /= 1e9
warnings.warn(
"This system does not have apparently enough memory to store the complete "
f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
)
def add(
self,
obs: np.ndarray,
next_obs: np.ndarray,
action: np.ndarray,
reward: np.ndarray,
done: np.ndarray,
infos: List[Dict[str, Any]],
) -> None:
# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space, spaces.Discrete):
obs = obs.reshape((self.n_envs, *self.obs_shape))
next_obs = next_obs.reshape((self.n_envs, *self.obs_shape))
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))
# Copy to avoid modification by reference
self.observations[self.pos] = np.array(obs)
if self.optimize_memory_usage:
self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs)
else:
self.next_observations[self.pos] = np.array(next_obs)
self.actions[self.pos] = np.array(action)
self.rewards[self.pos] = np.array(reward)
self.dones[self.pos] = np.array(done)
if self.handle_timeout_termination:
self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
self.pos = 0
def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
"""
Sample elements from the replay buffer.
Custom sampling when using memory efficient variant,
as we should not sample the element with index `self.pos`
See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
:param batch_size: Number of elements to sample
:param env: associated gym VecEnv
to normalize the observations/rewards when sampling
:return:
"""
if not self.optimize_memory_usage:
return super().sample(batch_size=batch_size, env=env)
# Do not sample the element with index `self.pos` as the transition is invalid
# (we use only one array to store `obs` and `next_obs`)
if self.full:
batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
else:
batch_inds = np.random.randint(0, self.pos, size=batch_size)
return self._get_samples(batch_inds, env=env)
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
# Randomly sample the env indices
env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))
if self.optimize_memory_usage:
next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :], env)
else:
next_obs = self._normalize_obs(self.next_observations[batch_inds, env_indices, :], env)
data = (
self._normalize_obs(self.observations[batch_inds, env_indices, :], env),
self.actions[batch_inds, env_indices, :],
next_obs,
# Only use dones that are not due to timeouts
# this is a no-op when handle_timeout_termination=False (timeouts then stays an array of False)
(self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1),
self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env),
)
return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
@staticmethod
def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike:
"""
Cast `np.float64` action datatype to `np.float32`,
keep the other dtypes unchanged.
See GH#1572 for more information.
:param dtype: The original action space dtype
:return: ``np.float32`` if the dtype was float64,
the original dtype otherwise.
"""
if dtype == np.float64:
return np.float32
return dtype
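# ---------------------------------------------------------------------------
# Editor's note: minimal usage sketch for the ``ReplayBuffer`` API documented
# above. It is illustrative only and not part of the original file; the spaces,
# sizes and dummy transitions below are assumptions.
def _example_replay_buffer_usage() -> None:
    import numpy as np
    from gymnasium import spaces

    obs_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
    action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
    buffer = ReplayBuffer(100, obs_space, action_space, device="cpu", n_envs=1)
    for _ in range(10):
        # add() expects arrays with a leading n_envs dimension
        obs = obs_space.sample().reshape(1, -1)
        next_obs = obs_space.sample().reshape(1, -1)
        action = action_space.sample().reshape(1, -1)
        buffer.add(obs, next_obs, action, np.array([0.0]), np.array([False]), infos=[{}])
    # Returns a ReplayBufferSamples named tuple of torch tensors
    batch = buffer.sample(batch_size=4)
    assert batch.observations.shape == (4, 3)
# ---------------------------------------------------------------------------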
class RolloutBuffer(BaseBuffer):
"""
Rollout buffer used in on-policy algorithms like A2C/PPO.
It corresponds to ``buffer_size`` transitions collected
using the current policy.
This experience will be discarded after the policy update.
In order to use PPO objective, we also store the current value of each state
and the log probability of each taken action.
The term rollout here refers to the model-free notion and should not
be used with the concept of rollout used in model-based RL or planning.
Hence, it is only involved in policy and value function training but not action selection.
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
:param n_envs: Number of parallel environments
"""
observations: np.ndarray
actions: np.ndarray
rewards: np.ndarray
advantages: np.ndarray
returns: np.ndarray
episode_starts: np.ndarray
log_probs: np.ndarray
values: np.ndarray
def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
self.gae_lambda = gae_lambda
self.gamma = gamma
self.generator_ready = False
self.reset()
def reset(self) -> None:
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.float32)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.generator_ready = False
super().reset()
def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
"""
Post-processing step: compute the lambda-return (TD(lambda) estimate)
and GAE(lambda) advantage.
Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S))
where R is the sum of discounted reward with value bootstrap
(because we don't always have full episode), set ``gae_lambda=1.0`` during initialization.
The TD(lambda) estimator has also two special cases:
- TD(1) is Monte-Carlo estimate (sum of discounted rewards)
- TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))
For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.
:param last_values: state value estimation for the last step (one for each env)
:param dones: if the last step was a terminal step (one bool for each env).
"""
# Convert to numpy
last_values = last_values.clone().cpu().numpy().flatten()
last_gae_lam = 0
for step in reversed(range(self.buffer_size)):
if step == self.buffer_size - 1:
next_non_terminal = 1.0 - dones
next_values = last_values
else:
next_non_terminal = 1.0 - self.episode_starts[step + 1]
next_values = self.values[step + 1]
delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
self.advantages[step] = last_gae_lam
# TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
# in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
self.returns = self.advantages + self.values
def add(
self,
obs: np.ndarray,
action: np.ndarray,
reward: np.ndarray,
episode_start: np.ndarray,
value: th.Tensor,
log_prob: th.Tensor,
) -> None:
"""
:param obs: Observation
:param action: Action
:param reward:
:param episode_start: Start of episode signal.
:param value: estimated value of the current state
following the current policy.
:param log_prob: log probability of the action
following the current policy.
"""
if len(log_prob.shape) == 0:
# Reshape 0-d tensor to avoid error
log_prob = log_prob.reshape(-1, 1)
# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space, spaces.Discrete):
obs = obs.reshape((self.n_envs, *self.obs_shape))
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))
self.observations[self.pos] = np.array(obs)
self.actions[self.pos] = np.array(action)
self.rewards[self.pos] = np.array(reward)
self.episode_starts[self.pos] = np.array(episode_start)
self.values[self.pos] = value.clone().cpu().numpy().flatten()
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
assert self.full, ""
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
if not self.generator_ready:
_tensor_names = [
"observations",
"actions",
"values",
"log_probs",
"advantages",
"returns",
]
for tensor in _tensor_names:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
self.generator_ready = True
# Return everything, don't create minibatches
if batch_size is None:
batch_size = self.buffer_size * self.n_envs
start_idx = 0
while start_idx < self.buffer_size * self.n_envs:
yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size
def _get_samples(
self,
batch_inds: np.ndarray,
env: Optional[VecNormalize] = None,
) -> RolloutBufferSamples:
data = (
self.observations[batch_inds],
self.actions[batch_inds],
self.values[batch_inds].flatten(),
self.log_probs[batch_inds].flatten(),
self.advantages[batch_inds].flatten(),
self.returns[batch_inds].flatten(),
)
return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
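# ---------------------------------------------------------------------------
# Editor's note: standalone sketch of the backward GAE(lambda) pass implemented
# in ``compute_returns_and_advantage`` above, on toy single-env data. All the
# numbers are assumptions; the recursion mirrors the buffer code.
def _example_gae_recursion() -> None:
    import numpy as np

    gamma, gae_lambda = 0.99, 0.95
    rewards = np.array([1.0, 0.0, 1.0], dtype=np.float32)
    values = np.array([0.5, 0.4, 0.3], dtype=np.float32)
    episode_starts = np.array([1.0, 0.0, 0.0], dtype=np.float32)
    last_value, last_done = 0.2, 0.0  # bootstrap value/done for the step after the buffer

    advantages = np.zeros_like(rewards)
    last_gae_lam = 0.0
    for step in reversed(range(len(rewards))):
        if step == len(rewards) - 1:
            next_non_terminal = 1.0 - last_done
            next_value = last_value
        else:
            next_non_terminal = 1.0 - episode_starts[step + 1]
            next_value = values[step + 1]
        # TD error, then the discounted, lambda-weighted accumulation
        delta = rewards[step] + gamma * next_value * next_non_terminal - values[step]
        last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
        advantages[step] = last_gae_lam
    # TD(lambda) return targets, exactly as ``self.returns = self.advantages + self.values``
    returns = advantages + values
    assert returns.shape == rewards.shape
# ---------------------------------------------------------------------------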
class DictReplayBuffer(ReplayBuffer):
"""
Dict Replay buffer used in off-policy algorithms like SAC/TD3.
Extends the ReplayBuffer to use dictionary observations
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device: PyTorch device
:param n_envs: Number of parallel environments
:param optimize_memory_usage: Enable a memory efficient variant
Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
:param handle_timeout_termination: Handle timeout termination (due to timelimit)
separately and treat the task as an infinite-horizon task.
https://github.com/DLR-RM/stable-baselines3/issues/284
"""
observation_space: spaces.Dict
obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment]
observations: Dict[str, np.ndarray] # type: ignore[assignment]
next_observations: Dict[str, np.ndarray] # type: ignore[assignment]
def __init__(
self,
buffer_size: int,
observation_space: spaces.Dict,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = False,
handle_timeout_termination: bool = True,
):
super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
assert isinstance(self.obs_shape, dict), "DictReplayBuffer must be used with Dict obs space only"
self.buffer_size = max(buffer_size // n_envs, 1)
# Check that the replay buffer can fit into the memory
if psutil is not None:
mem_available = psutil.virtual_memory().available
assert not optimize_memory_usage, "DictReplayBuffer does not support optimize_memory_usage"
# disabling as this adds quite a bit of complexity
# https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702
self.optimize_memory_usage = optimize_memory_usage
self.observations = {
key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=observation_space[key].dtype)
for key, _obs_shape in self.obs_shape.items()
}
self.next_observations = {
key: np.zeros((self.buffer_size, self.n_envs, *_obs_shape), dtype=observation_space[key].dtype)
for key, _obs_shape in self.obs_shape.items()
}
self.actions = np.zeros(
(self.buffer_size, self.n_envs, self.action_dim), dtype=self._maybe_cast_dtype(action_space.dtype)
)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
# Handle timeouts termination properly if needed
# see https://github.com/DLR-RM/stable-baselines3/issues/284
self.handle_timeout_termination = handle_timeout_termination
self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
if psutil is not None:
obs_nbytes = 0
for _, obs in self.observations.items():
obs_nbytes += obs.nbytes
total_memory_usage: float = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
if not optimize_memory_usage:
next_obs_nbytes = 0
for _, obs in self.next_observations.items():
next_obs_nbytes += obs.nbytes
total_memory_usage += next_obs_nbytes
if total_memory_usage > mem_available:
# Convert to GB
total_memory_usage /= 1e9
mem_available /= 1e9
warnings.warn(
"This system does not appear to have enough memory to store the complete "
f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
)
def add( # type: ignore[override]
self,
obs: Dict[str, np.ndarray],
next_obs: Dict[str, np.ndarray],
action: np.ndarray,
reward: np.ndarray,
done: np.ndarray,
infos: List[Dict[str, Any]],
) -> None:
# Copy to avoid modification by reference
for key in self.observations.keys():
# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
obs[key] = obs[key].reshape((self.n_envs,) + self.obs_shape[key])
self.observations[key][self.pos] = np.array(obs[key])
for key in self.next_observations.keys():
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
next_obs[key] = next_obs[key].reshape((self.n_envs,) + self.obs_shape[key])
self.next_observations[key][self.pos] = np.array(next_obs[key])
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))
self.actions[self.pos] = np.array(action)
self.rewards[self.pos] = np.array(reward)
self.dones[self.pos] = np.array(done)
if self.handle_timeout_termination:
self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
self.pos = 0
def sample( # type: ignore[override]
self,
batch_size: int,
env: Optional[VecNormalize] = None,
) -> DictReplayBufferSamples:
"""
Sample elements from the replay buffer.
:param batch_size: Number of elements to sample
:param env: associated gym VecEnv
to normalize the observations/rewards when sampling
:return:
"""
return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env)
def _get_samples( # type: ignore[override]
self,
batch_inds: np.ndarray,
env: Optional[VecNormalize] = None,
) -> DictReplayBufferSamples:
# Randomly sample the env indices
env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))
# Normalize if needed; `env_indices` selects one env per sampled transition
obs_ = self._normalize_obs({key: obs[batch_inds, env_indices, :] for key, obs in self.observations.items()}, env)
next_obs_ = self._normalize_obs(
{key: obs[batch_inds, env_indices, :] for key, obs in self.next_observations.items()}, env
)
assert isinstance(obs_, dict)
assert isinstance(next_obs_, dict)
# Convert to torch tensor
observations = {key: self.to_torch(obs) for key, obs in obs_.items()}
next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()}
return DictReplayBufferSamples(
observations=observations,
actions=self.to_torch(self.actions[batch_inds, env_indices]),
next_observations=next_observations,
# Only use dones that are not due to timeouts
# this is a no-op when handle_timeout_termination=False (timeouts then stays an array of False)
dones=self.to_torch(self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(
-1, 1
),
rewards=self.to_torch(self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env)),
)
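# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch (not part of the original file) showing
# that ``DictReplayBuffer`` mirrors ``ReplayBuffer`` except that observations
# are dictionaries of arrays. Spaces, sizes and dummy data are assumptions.
def _example_dict_replay_buffer_usage() -> None:
    import numpy as np
    from gymnasium import spaces

    obs_space = spaces.Dict(
        {
            "observation": spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32),
            "desired_goal": spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
        }
    )
    action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
    buffer = DictReplayBuffer(100, obs_space, action_space, device="cpu", n_envs=1)
    for _ in range(10):
        # add a leading n_envs dimension to every entry of the dict observation
        obs = {key: value[np.newaxis, ...] for key, value in obs_space.sample().items()}
        next_obs = {key: value[np.newaxis, ...] for key, value in obs_space.sample().items()}
        action = action_space.sample().reshape(1, -1)
        buffer.add(obs, next_obs, action, np.array([0.0]), np.array([False]), infos=[{}])
    batch = buffer.sample(batch_size=4)  # DictReplayBufferSamples with dict observations
    assert batch.observations["observation"].shape == (4, 3)
# ---------------------------------------------------------------------------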
class DictRolloutBuffer(RolloutBuffer):
"""
Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
Extends the RolloutBuffer to use dictionary observations
It corresponds to ``buffer_size`` transitions collected
using the current policy.
This experience will be discarded after the policy update.
In order to use PPO objective, we also store the current value of each state
and the log probability of each taken action.
The term rollout here refers to the model-free notion and should not
be used with the concept of rollout used in model-based RL or planning.
Hence, it is only involved in policy and value function training but not action selection.
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to Monte-Carlo advantage estimate when set to 1.
:param gamma: Discount factor
:param n_envs: Number of parallel environments
"""
observation_space: spaces.Dict
obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment]
observations: Dict[str, np.ndarray] # type: ignore[assignment]
def __init__(
self,
buffer_size: int,
observation_space: spaces.Dict,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
self.gae_lambda = gae_lambda
self.gamma = gamma
self.generator_ready = False
self.reset()
def reset(self) -> None:
self.observations = {}
for key, obs_input_shape in self.obs_shape.items():
self.observations[key] = np.zeros((self.buffer_size, self.n_envs, *obs_input_shape), dtype=np.float32)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.generator_ready = False
super(RolloutBuffer, self).reset()
def add( # type: ignore[override]
self,
obs: Dict[str, np.ndarray],
action: np.ndarray,
reward: np.ndarray,
episode_start: np.ndarray,
value: th.Tensor,
log_prob: th.Tensor,
) -> None:
"""
:param obs: Observation
:param action: Action
:param reward:
:param episode_start: Start of episode signal.
:param value: estimated value of the current state
following the current policy.
:param log_prob: log probability of the action
following the current policy.
"""
if len(log_prob.shape) == 0:
# Reshape 0-d tensor to avoid error
log_prob = log_prob.reshape(-1, 1)
for key in self.observations.keys():
obs_ = np.array(obs[key])
# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
self.observations[key][self.pos] = obs_
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))
self.actions[self.pos] = np.array(action)
self.rewards[self.pos] = np.array(reward)
self.episode_starts[self.pos] = np.array(episode_start)
self.values[self.pos] = value.clone().cpu().numpy().flatten()
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
def get( # type: ignore[override]
self,
batch_size: Optional[int] = None,
) -> Generator[DictRolloutBufferSamples, None, None]:
assert self.full, ""
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
if not self.generator_ready:
for key, obs in self.observations.items():
self.observations[key] = self.swap_and_flatten(obs)
_tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]
for tensor in _tensor_names:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
self.generator_ready = True
# Return everything, don't create minibatches
if batch_size is None:
batch_size = self.buffer_size * self.n_envs
start_idx = 0
while start_idx < self.buffer_size * self.n_envs:
yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size
def _get_samples( # type: ignore[override]
self,
batch_inds: np.ndarray,
env: Optional[VecNormalize] = None,
) -> DictRolloutBufferSamples:
return DictRolloutBufferSamples(
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
actions=self.to_torch(self.actions[batch_inds]),
old_values=self.to_torch(self.values[batch_inds].flatten()),
old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
advantages=self.to_torch(self.advantages[batch_inds].flatten()),
returns=self.to_torch(self.returns[batch_inds].flatten()),
)
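# ---------------------------------------------------------------------------
# Editor's note: the ``get()`` methods above rely on ``BaseBuffer.swap_and_flatten``
# to turn per-step, per-env arrays of shape (buffer_size, n_envs, *shape) into flat
# batches of shape (buffer_size * n_envs, *shape). The toy array below is an
# assumption; the operation shown is equivalent for 3-D inputs.
def _example_swap_and_flatten() -> None:
    import numpy as np

    buffer_size, n_envs, obs_dim = 4, 2, 3
    arr = np.arange(buffer_size * n_envs * obs_dim, dtype=np.float32)
    arr = arr.reshape(buffer_size, n_envs, obs_dim)
    # Swap the step and env axes, then merge them into one batch axis
    flat = arr.swapaxes(0, 1).reshape(buffer_size * n_envs, obs_dim)
    assert flat.shape == (8, 3)
# ---------------------------------------------------------------------------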

View File

@ -0,0 +1,709 @@
import os
import warnings
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import gymnasium as gym
import numpy as np
from stable_baselines3.common.logger import Logger
try:
from tqdm import TqdmExperimentalWarning
# Remove experimental warning
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
from tqdm.rich import tqdm
except ImportError:
# Rich not installed, we only throw an error
# if the progress bar is used
tqdm = None
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
if TYPE_CHECKING:
from stable_baselines3.common import base_class
class BaseCallback(ABC):
"""
Base class for callback.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
# The RL model
# Type hint as string to avoid circular import
model: "base_class.BaseAlgorithm"
def __init__(self, verbose: int = 0):
super().__init__()
# Number of times the callback was called
self.n_calls = 0 # type: int
# Total number of timesteps so far: n_envs * (number of times env.step() was called)
self.num_timesteps = 0 # type: int
self.verbose = verbose
self.locals: Dict[str, Any] = {}
self.globals: Dict[str, Any] = {}
# Sometimes, for event callback, it is useful
# to have access to the parent object
self.parent = None # type: Optional[BaseCallback]
@property
def training_env(self) -> VecEnv:
training_env = self.model.get_env()
assert (
training_env is not None
), "`model.get_env()` returned None, you must initialize the model with an environment to use callbacks"
return training_env
@property
def logger(self) -> Logger:
return self.model.logger
# Type hint as string to avoid circular import
def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
"""
Initialize the callback by saving references to the
RL model and the training environment for convenience.
"""
self.model = model
self._init_callback()
def _init_callback(self) -> None:
pass
def on_training_start(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
# These are references and will be updated automatically
self.locals = locals_
self.globals = globals_
# Update num_timesteps in case training was done before
self.num_timesteps = self.model.num_timesteps
self._on_training_start()
def _on_training_start(self) -> None:
pass
def on_rollout_start(self) -> None:
self._on_rollout_start()
def _on_rollout_start(self) -> None:
pass
@abstractmethod
def _on_step(self) -> bool:
"""
:return: If the callback returns False, training is aborted early.
"""
return True
def on_step(self) -> bool:
"""
This method will be called by the model after each call to ``env.step()``.
For child callback (of an ``EventCallback``), this will be called
when the event is triggered.
:return: If the callback returns False, training is aborted early.
"""
self.n_calls += 1
self.num_timesteps = self.model.num_timesteps
return self._on_step()
def on_training_end(self) -> None:
self._on_training_end()
def _on_training_end(self) -> None:
pass
def on_rollout_end(self) -> None:
self._on_rollout_end()
def _on_rollout_end(self) -> None:
pass
def update_locals(self, locals_: Dict[str, Any]) -> None:
"""
Update the references to the local variables.
:param locals_: the local variables during rollout collection
"""
self.locals.update(locals_)
self.update_child_locals(locals_)
def update_child_locals(self, locals_: Dict[str, Any]) -> None:
"""
Update the references to the local variables on sub callbacks.
:param locals_: the local variables during rollout collection
"""
pass
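# ---------------------------------------------------------------------------
# Editor's note: minimal custom callback sketch illustrating the hooks described
# above. It is not part of the original file; the log key and frequency are
# assumptions. Only ``_on_step`` is mandatory for subclasses.
class _ExampleLoggingCallback(BaseCallback):
    def __init__(self, log_freq: int = 1000, verbose: int = 0):
        super().__init__(verbose)
        self.log_freq = log_freq

    def _on_step(self) -> bool:
        # Record the current timestep count every ``log_freq`` calls
        if self.n_calls % self.log_freq == 0:
            self.logger.record("custom/num_timesteps", self.num_timesteps)
        # Returning False would abort training early
        return True
# ---------------------------------------------------------------------------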
class EventCallback(BaseCallback):
"""
Base class for triggering callback on event.
:param callback: Callback that will be called
when an event is triggered.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
super().__init__(verbose=verbose)
self.callback = callback
# Give access to the parent
if callback is not None:
assert self.callback is not None
self.callback.parent = self
def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
super().init_callback(model)
if self.callback is not None:
self.callback.init_callback(self.model)
def _on_training_start(self) -> None:
if self.callback is not None:
self.callback.on_training_start(self.locals, self.globals)
def _on_event(self) -> bool:
if self.callback is not None:
return self.callback.on_step()
return True
def _on_step(self) -> bool:
return True
def update_child_locals(self, locals_: Dict[str, Any]) -> None:
"""
Update the references to the local variables.
:param locals_: the local variables during rollout collection
"""
if self.callback is not None:
self.callback.update_locals(locals_)
class CallbackList(BaseCallback):
"""
Class for chaining callbacks.
:param callbacks: A list of callbacks that will be called
sequentially.
"""
def __init__(self, callbacks: List[BaseCallback]):
super().__init__()
assert isinstance(callbacks, list)
self.callbacks = callbacks
def _init_callback(self) -> None:
for callback in self.callbacks:
callback.init_callback(self.model)
def _on_training_start(self) -> None:
for callback in self.callbacks:
callback.on_training_start(self.locals, self.globals)
def _on_rollout_start(self) -> None:
for callback in self.callbacks:
callback.on_rollout_start()
def _on_step(self) -> bool:
continue_training = True
for callback in self.callbacks:
# Return False (stop training) if at least one callback returns False
continue_training = callback.on_step() and continue_training
return continue_training
def _on_rollout_end(self) -> None:
for callback in self.callbacks:
callback.on_rollout_end()
def _on_training_end(self) -> None:
for callback in self.callbacks:
callback.on_training_end()
def update_child_locals(self, locals_: Dict[str, Any]) -> None:
"""
Update the references to the local variables.
:param locals_: the local variables during rollout collection
"""
for callback in self.callbacks:
callback.update_locals(locals_)
class CheckpointCallback(BaseCallback):
"""
Callback for saving a model every ``save_freq`` calls
to ``env.step()``.
By default, it only saves model checkpoints,
you need to pass ``save_replay_buffer=True``,
and ``save_vecnormalize=True`` to also save replay buffer checkpoints
and normalization statistics checkpoints.
.. warning::
When using multiple environments, each call to ``env.step()``
will effectively correspond to ``n_envs`` steps.
To account for that, you can use ``save_freq = max(save_freq // n_envs, 1)``
:param save_freq: Save checkpoints every ``save_freq`` call of the callback.
:param save_path: Path to the folder where the model will be saved.
:param name_prefix: Common prefix to the saved models
:param save_replay_buffer: Save the model replay buffer
:param save_vecnormalize: Save the ``VecNormalize`` statistics
:param verbose: Verbosity level: 0 for no output, 2 for indicating when saving model checkpoint
"""
def __init__(
self,
save_freq: int,
save_path: str,
name_prefix: str = "rl_model",
save_replay_buffer: bool = False,
save_vecnormalize: bool = False,
verbose: int = 0,
):
super().__init__(verbose)
self.save_freq = save_freq
self.save_path = save_path
self.name_prefix = name_prefix
self.save_replay_buffer = save_replay_buffer
self.save_vecnormalize = save_vecnormalize
def _init_callback(self) -> None:
# Create folder if needed
if self.save_path is not None:
os.makedirs(self.save_path, exist_ok=True)
def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> str:
"""
Helper to get checkpoint path for each type of checkpoint.
:param checkpoint_type: empty for the model, "replay_buffer_"
or "vecnormalize_" for the other checkpoints.
:param extension: Checkpoint file extension (zip for model, pkl for others)
:return: Path to the checkpoint
"""
return os.path.join(self.save_path, f"{self.name_prefix}_{checkpoint_type}{self.num_timesteps}_steps.{extension}")
def _on_step(self) -> bool:
if self.n_calls % self.save_freq == 0:
model_path = self._checkpoint_path(extension="zip")
self.model.save(model_path)
if self.verbose >= 2:
print(f"Saving model checkpoint to {model_path}")
if self.save_replay_buffer and hasattr(self.model, "replay_buffer") and self.model.replay_buffer is not None:
# If model has a replay buffer, save it too
replay_buffer_path = self._checkpoint_path("replay_buffer_", extension="pkl")
self.model.save_replay_buffer(replay_buffer_path) # type: ignore[attr-defined]
if self.verbose > 1:
print(f"Saving model replay buffer checkpoint to {replay_buffer_path}")
if self.save_vecnormalize and self.model.get_vec_normalize_env() is not None:
# Save the VecNormalize statistics
vec_normalize_path = self._checkpoint_path("vecnormalize_", extension="pkl")
self.model.get_vec_normalize_env().save(vec_normalize_path) # type: ignore[union-attr]
if self.verbose >= 2:
print(f"Saving model VecNormalize to {vec_normalize_path}")
return True
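# ---------------------------------------------------------------------------
# Editor's note: illustrative usage sketch for ``CheckpointCallback`` (not part
# of the original file). The algorithm, env id, frequency and paths are
# assumptions; the import is local to avoid a circular import at module level.
def _example_checkpoint_callback_usage() -> None:
    from stable_baselines3 import PPO

    checkpoint_callback = CheckpointCallback(
        save_freq=10_000,  # divide by n_envs when training with multiple envs
        save_path="./checkpoints/",
        name_prefix="ppo_cartpole",
    )
    model = PPO("MlpPolicy", "CartPole-v1", verbose=0)
    model.learn(total_timesteps=50_000, callback=checkpoint_callback)
# ---------------------------------------------------------------------------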
class ConvertCallback(BaseCallback):
"""
Convert functional callback (old-style) to object.
:param callback:
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
def __init__(self, callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], bool]], verbose: int = 0):
super().__init__(verbose)
self.callback = callback
def _on_step(self) -> bool:
if self.callback is not None:
return self.callback(self.locals, self.globals)
return True
class EvalCallback(EventCallback):
"""
Callback for evaluating an agent.
.. warning::
When using multiple environments, each call to ``env.step()``
will effectively correspond to ``n_envs`` steps.
To account for that, you can use ``eval_freq = max(eval_freq // n_envs, 1)``
:param eval_env: The environment used for evaluation
:param callback_on_new_best: Callback to trigger
when there is a new best model according to the ``mean_reward``
:param callback_after_eval: Callback to trigger after every evaluation
:param n_eval_episodes: The number of episodes to test the agent
:param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback.
:param log_path: Path to a folder where the evaluations (``evaluations.npz``)
will be saved. It will be updated at each evaluation.
:param best_model_save_path: Path to a folder where the best model
according to performance on the eval env will be saved.
:param deterministic: Whether the evaluation should
use deterministic or stochastic actions.
:param render: Whether to render the environment during evaluation
:param verbose: Verbosity level: 0 for no output, 1 for indicating information about evaluation results
:param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been
wrapped with a Monitor wrapper)
"""
def __init__(
self,
eval_env: Union[gym.Env, VecEnv],
callback_on_new_best: Optional[BaseCallback] = None,
callback_after_eval: Optional[BaseCallback] = None,
n_eval_episodes: int = 5,
eval_freq: int = 10000,
log_path: Optional[str] = None,
best_model_save_path: Optional[str] = None,
deterministic: bool = True,
render: bool = False,
verbose: int = 1,
warn: bool = True,
):
super().__init__(callback_after_eval, verbose=verbose)
self.callback_on_new_best = callback_on_new_best
if self.callback_on_new_best is not None:
# Give access to the parent
self.callback_on_new_best.parent = self
self.n_eval_episodes = n_eval_episodes
self.eval_freq = eval_freq
self.best_mean_reward = -np.inf
self.last_mean_reward = -np.inf
self.deterministic = deterministic
self.render = render
self.warn = warn
# Convert to VecEnv for consistency
if not isinstance(eval_env, VecEnv):
eval_env = DummyVecEnv([lambda: eval_env]) # type: ignore[list-item, return-value]
self.eval_env = eval_env
self.best_model_save_path = best_model_save_path
# Logs will be written in ``evaluations.npz``
if log_path is not None:
log_path = os.path.join(log_path, "evaluations")
self.log_path = log_path
self.evaluations_results: List[List[float]] = []
self.evaluations_timesteps: List[int] = []
self.evaluations_length: List[List[int]] = []
# For computing success rate
self._is_success_buffer: List[bool] = []
self.evaluations_successes: List[List[bool]] = []
def _init_callback(self) -> None:
# Does not work in some corner cases, where the wrapper is not the same
if not isinstance(self.training_env, type(self.eval_env)):
warnings.warn(f"Training and eval env are not of the same type: {self.training_env} != {self.eval_env}")
# Create folders if needed
if self.best_model_save_path is not None:
os.makedirs(self.best_model_save_path, exist_ok=True)
if self.log_path is not None:
os.makedirs(os.path.dirname(self.log_path), exist_ok=True)
# Init callback called on new best model
if self.callback_on_new_best is not None:
self.callback_on_new_best.init_callback(self.model)
def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
"""
Callback passed to the ``evaluate_policy`` function
in order to log the success rate (when applicable),
for instance when using HER.
:param locals_:
:param globals_:
"""
info = locals_["info"]
if locals_["done"]:
maybe_is_success = info.get("is_success")
if maybe_is_success is not None:
self._is_success_buffer.append(maybe_is_success)
def _on_step(self) -> bool:
continue_training = True
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
# Sync training and eval env if there is VecNormalize
if self.model.get_vec_normalize_env() is not None:
try:
sync_envs_normalization(self.training_env, self.eval_env)
except AttributeError as e:
raise AssertionError(
"Training and eval env are not wrapped the same way, "
"see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
"and warning above."
) from e
# Reset success rate buffer
self._is_success_buffer = []
episode_rewards, episode_lengths = evaluate_policy(
self.model,
self.eval_env,
n_eval_episodes=self.n_eval_episodes,
render=self.render,
deterministic=self.deterministic,
return_episode_rewards=True,
warn=self.warn,
callback=self._log_success_callback,
)
if self.log_path is not None:
assert isinstance(episode_rewards, list)
assert isinstance(episode_lengths, list)
self.evaluations_timesteps.append(self.num_timesteps)
self.evaluations_results.append(episode_rewards)
self.evaluations_length.append(episode_lengths)
kwargs = {}
# Save success log if present
if len(self._is_success_buffer) > 0:
self.evaluations_successes.append(self._is_success_buffer)
kwargs = dict(successes=self.evaluations_successes)
np.savez(
self.log_path,
timesteps=self.evaluations_timesteps,
results=self.evaluations_results,
ep_lengths=self.evaluations_length,
**kwargs,
)
mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
self.last_mean_reward = float(mean_reward)
if self.verbose >= 1:
print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
# Add to current Logger
self.logger.record("eval/mean_reward", float(mean_reward))
self.logger.record("eval/mean_ep_length", mean_ep_length)
if len(self._is_success_buffer) > 0:
success_rate = np.mean(self._is_success_buffer)
if self.verbose >= 1:
print(f"Success rate: {100 * success_rate:.2f}%")
self.logger.record("eval/success_rate", success_rate)
# Dump log so the evaluation results are printed with the correct timestep
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
self.logger.dump(self.num_timesteps)
if mean_reward > self.best_mean_reward:
if self.verbose >= 1:
print("New best mean reward!")
if self.best_model_save_path is not None:
self.model.save(os.path.join(self.best_model_save_path, "best_model"))
self.best_mean_reward = float(mean_reward)
# Trigger callback on new best model, if needed
if self.callback_on_new_best is not None:
continue_training = self.callback_on_new_best.on_step()
# Trigger callback after every evaluation, if needed
if self.callback is not None:
continue_training = continue_training and self._on_event()
return continue_training
def update_child_locals(self, locals_: Dict[str, Any]) -> None:
"""
Update the references to the local variables.
:param locals_: the local variables during rollout collection
"""
if self.callback:
self.callback.update_locals(locals_)
class StopTrainingOnRewardThreshold(BaseCallback):
"""
Stop the training once a threshold in episodic reward
has been reached (i.e. when the model is good enough).
It must be used with the ``EvalCallback``.
:param reward_threshold: Minimum expected reward per episode
to stop training.
:param verbose: Verbosity level: 0 for no output, 1 for indicating when training ended because episodic reward
threshold reached
"""
parent: EvalCallback
def __init__(self, reward_threshold: float, verbose: int = 0):
super().__init__(verbose=verbose)
self.reward_threshold = reward_threshold
def _on_step(self) -> bool:
assert self.parent is not None, "``StopTrainingOnRewardThreshold`` callback must be used with an ``EvalCallback``"
continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
if self.verbose >= 1 and not continue_training:
print(
f"Stopping training because the mean reward {self.parent.best_mean_reward:.2f} "
f"is above the threshold {self.reward_threshold}"
)
return continue_training
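# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch (not part of the original file) combining
# ``EvalCallback`` with ``StopTrainingOnRewardThreshold``, as the docstring above
# requires. The env id, threshold and frequencies are assumptions.
def _example_stop_on_reward_threshold() -> None:
    import gymnasium as gym
    from stable_baselines3 import PPO

    eval_env = gym.make("CartPole-v1")
    stop_callback = StopTrainingOnRewardThreshold(reward_threshold=475.0, verbose=1)
    eval_callback = EvalCallback(
        eval_env,
        callback_on_new_best=stop_callback,
        eval_freq=1_000,
        n_eval_episodes=5,
        verbose=1,
    )
    model = PPO("MlpPolicy", "CartPole-v1", verbose=0)
    # Training stops as soon as the evaluation mean reward exceeds the threshold
    model.learn(total_timesteps=20_000, callback=eval_callback)
# ---------------------------------------------------------------------------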
class EveryNTimesteps(EventCallback):
"""
Trigger a callback every ``n_steps`` timesteps
:param n_steps: Number of timesteps between two triggers.
:param callback: Callback that will be called
when the event is triggered.
"""
def __init__(self, n_steps: int, callback: BaseCallback):
super().__init__(callback)
self.n_steps = n_steps
self.last_time_trigger = 0
def _on_step(self) -> bool:
if (self.num_timesteps - self.last_time_trigger) >= self.n_steps:
self.last_time_trigger = self.num_timesteps
return self._on_event()
return True
class StopTrainingOnMaxEpisodes(BaseCallback):
"""
Stop the training once a maximum number of episodes are played.
For multiple environments, it is assumed that the desired behavior is for the agent to train on each env for ``max_episodes``
and in total for ``max_episodes * n_envs`` episodes.
:param max_episodes: Maximum number of episodes after which training stops.
:param verbose: Verbosity level: 0 for no output, 1 for indicating information about when training ended by
reaching ``max_episodes``
"""
def __init__(self, max_episodes: int, verbose: int = 0):
super().__init__(verbose=verbose)
self.max_episodes = max_episodes
self._total_max_episodes = max_episodes
self.n_episodes = 0
def _init_callback(self) -> None:
# At start, set the total max according to the number of environments
self._total_max_episodes = self.max_episodes * self.training_env.num_envs
def _on_step(self) -> bool:
# Check that the `dones` local variable is defined
assert "dones" in self.locals, "`dones` variable is not defined, please check your code next to `callback.on_step()`"
self.n_episodes += np.sum(self.locals["dones"]).item()
continue_training = self.n_episodes < self._total_max_episodes
if self.verbose >= 1 and not continue_training:
mean_episodes_per_env = self.n_episodes / self.training_env.num_envs
mean_ep_str = (
f"with an average of {mean_episodes_per_env:.2f} episodes per env" if self.training_env.num_envs > 1 else ""
)
print(
f"Stopping training with a total of {self.num_timesteps} steps because the "
f"{self.locals.get('tb_log_name')} model reached max_episodes={self.max_episodes}, "
f"by playing for {self.n_episodes} episodes "
f"{mean_ep_str}"
)
return continue_training
class StopTrainingOnNoModelImprovement(BaseCallback):
"""
Stop the training early if there is no new best model (new best mean reward) after more than N consecutive evaluations.
It is possible to define a minimum number of evaluations before starting to count evaluations without improvement.
It must be used with the ``EvalCallback``.
:param max_no_improvement_evals: Maximum number of consecutive evaluations without a new best model.
:param min_evals: Number of evaluations before starting to count evaluations without improvement.
:param verbose: Verbosity level: 0 for no output, 1 for indicating when training ended because no new best model
"""
parent: EvalCallback
def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0):
super().__init__(verbose=verbose)
self.max_no_improvement_evals = max_no_improvement_evals
self.min_evals = min_evals
self.last_best_mean_reward = -np.inf
self.no_improvement_evals = 0
def _on_step(self) -> bool:
assert self.parent is not None, "``StopTrainingOnNoModelImprovement`` callback must be used with an ``EvalCallback``"
continue_training = True
if self.n_calls > self.min_evals:
if self.parent.best_mean_reward > self.last_best_mean_reward:
self.no_improvement_evals = 0
else:
self.no_improvement_evals += 1
if self.no_improvement_evals > self.max_no_improvement_evals:
continue_training = False
self.last_best_mean_reward = self.parent.best_mean_reward
if self.verbose >= 1 and not continue_training:
print(
f"Stopping training because there was no new best model in the last {self.no_improvement_evals:d} evaluations"
)
return continue_training
class ProgressBarCallback(BaseCallback):
"""
Display a progress bar when training SB3 agent
using tqdm and rich packages.
"""
pbar: tqdm
def __init__(self) -> None:
super().__init__()
if tqdm is None:
raise ImportError(
"You must install tqdm and rich in order to use the progress bar callback. "
"It is included if you install stable-baselines3 with the extra packages: "
"`pip install stable-baselines3[extra]`"
)
def _on_training_start(self) -> None:
# Initialize progress bar
# Remove timesteps that were done in previous training sessions
self.pbar = tqdm(total=self.locals["total_timesteps"] - self.model.num_timesteps)
def _on_step(self) -> bool:
# Update progress bar, we do num_envs steps per call to `env.step()`
self.pbar.update(self.training_env.num_envs)
return True
def _on_training_end(self) -> None:
# Flush and close progress bar
self.pbar.refresh()
self.pbar.close()
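# ---------------------------------------------------------------------------
# Editor's note: illustrative usage sketch for ``ProgressBarCallback`` (not part
# of the original file). It requires the optional tqdm/rich dependencies; the
# algorithm and env id are assumptions.
def _example_progress_bar_usage() -> None:
    from stable_baselines3 import PPO

    model = PPO("MlpPolicy", "CartPole-v1", verbose=0)
    model.learn(total_timesteps=10_000, callback=ProgressBarCallback())
    # ``model.learn(..., progress_bar=True)`` attaches the same callback automatically.
# ---------------------------------------------------------------------------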

View File

@ -0,0 +1,723 @@
"""Probability distributions."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
import numpy as np
import torch as th
from gymnasium import spaces
from torch import nn
from torch.distributions import Bernoulli, Categorical, Normal
from stable_baselines3.common.preprocessing import get_action_dim
SelfDistribution = TypeVar("SelfDistribution", bound="Distribution")
SelfDiagGaussianDistribution = TypeVar("SelfDiagGaussianDistribution", bound="DiagGaussianDistribution")
SelfSquashedDiagGaussianDistribution = TypeVar(
"SelfSquashedDiagGaussianDistribution", bound="SquashedDiagGaussianDistribution"
)
SelfCategoricalDistribution = TypeVar("SelfCategoricalDistribution", bound="CategoricalDistribution")
SelfMultiCategoricalDistribution = TypeVar("SelfMultiCategoricalDistribution", bound="MultiCategoricalDistribution")
SelfBernoulliDistribution = TypeVar("SelfBernoulliDistribution", bound="BernoulliDistribution")
SelfStateDependentNoiseDistribution = TypeVar("SelfStateDependentNoiseDistribution", bound="StateDependentNoiseDistribution")
class Distribution(ABC):
"""Abstract base class for distributions."""
def __init__(self):
super().__init__()
self.distribution = None
@abstractmethod
def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]:
"""Create the layers and parameters that represent the distribution.
Subclasses must define this, but the arguments and return type vary between
concrete classes."""
@abstractmethod
def proba_distribution(self: SelfDistribution, *args, **kwargs) -> SelfDistribution:
"""Set parameters of the distribution.
:return: self
"""
@abstractmethod
def log_prob(self, x: th.Tensor) -> th.Tensor:
"""
Returns the log likelihood
:param x: the taken action
:return: The log likelihood of the distribution
"""
@abstractmethod
def entropy(self) -> Optional[th.Tensor]:
"""
Returns Shannon's entropy of the probability distribution
:return: the entropy, or None if no analytical form is known
"""
@abstractmethod
def sample(self) -> th.Tensor:
"""
Returns a sample from the probability distribution
:return: the stochastic action
"""
@abstractmethod
def mode(self) -> th.Tensor:
"""
Returns the most likely action (deterministic output)
from the probability distribution
:return: the deterministic action
"""
def get_actions(self, deterministic: bool = False) -> th.Tensor:
"""
Return actions according to the probability distribution.
:param deterministic:
:return:
"""
if deterministic:
return self.mode()
return self.sample()
@abstractmethod
def actions_from_params(self, *args, **kwargs) -> th.Tensor:
"""
Returns samples from the probability distribution
given its parameters.
:return: actions
"""
@abstractmethod
def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]:
"""
Returns samples and the associated log probabilities
from the probability distribution given its parameters.
:return: actions and log prob
"""
def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
"""
Continuous actions are usually considered to be independent,
so we can sum components of the ``log_prob`` or the entropy.
:param tensor: shape: (n_batch, n_actions) or (n_batch,)
:return: shape: (n_batch,) for (n_batch, n_actions) input, scalar for (n_batch,) input
"""
if len(tensor.shape) > 1:
tensor = tensor.sum(dim=1)
else:
tensor = tensor.sum()
return tensor
class DiagGaussianDistribution(Distribution):
"""
Gaussian distribution with diagonal covariance matrix, for continuous actions.
:param action_dim: Dimension of the action space.
"""
def __init__(self, action_dim: int):
super().__init__()
self.action_dim = action_dim
self.mean_actions = None
self.log_std = None
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
"""
Create the layers and parameter that represent the distribution:
one output will be the mean of the Gaussian, the other parameter will be the
standard deviation (log std in fact to allow negative values)
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
:param log_std_init: Initial value for the log standard deviation
:return:
"""
mean_actions = nn.Linear(latent_dim, self.action_dim)
# TODO: allow action dependent std
log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True)
return mean_actions, log_std
def proba_distribution(
self: SelfDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
) -> SelfDiagGaussianDistribution:
"""
Create the distribution given its parameters (mean, std)
:param mean_actions:
:param log_std:
:return:
"""
action_std = th.ones_like(mean_actions) * log_std.exp()
self.distribution = Normal(mean_actions, action_std)
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
"""
Get the log probabilities of actions according to the distribution.
Note that you must first call the ``proba_distribution()`` method.
:param actions:
:return:
"""
log_prob = self.distribution.log_prob(actions)
return sum_independent_dims(log_prob)
def entropy(self) -> Optional[th.Tensor]:
return sum_independent_dims(self.distribution.entropy())
def sample(self) -> th.Tensor:
# Reparametrization trick to pass gradients
return self.distribution.rsample()
def mode(self) -> th.Tensor:
return self.distribution.mean
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(mean_actions, log_std)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
"""
Compute the log probability of taking an action
given the distribution parameters.
:param mean_actions:
:param log_std:
:return:
"""
actions = self.actions_from_params(mean_actions, log_std)
log_prob = self.log_prob(actions)
return actions, log_prob
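# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch (not part of the original file) wiring up
# ``DiagGaussianDistribution`` the way a policy head would. Dimensions and the
# random latent are assumptions.
def _example_diag_gaussian_usage() -> None:
    latent_dim, action_dim, batch_size = 8, 2, 4
    dist = DiagGaussianDistribution(action_dim)
    mean_net, log_std = dist.proba_distribution_net(latent_dim, log_std_init=0.0)
    latent = th.randn(batch_size, latent_dim)
    mean_actions = mean_net(latent)
    # Sample actions and get their log probabilities summed over action dimensions
    actions, log_prob = dist.log_prob_from_params(mean_actions, log_std)
    assert actions.shape == (batch_size, action_dim)
    assert log_prob.shape == (batch_size,)
# ---------------------------------------------------------------------------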
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
"""
Gaussian distribution with diagonal covariance matrix, followed by a squashing function (tanh) to ensure bounds.
:param action_dim: Dimension of the action space.
:param epsilon: small value to avoid NaN due to numerical imprecision.
"""
def __init__(self, action_dim: int, epsilon: float = 1e-6):
super().__init__(action_dim)
# Avoid NaN (prevents division by zero or log of zero)
self.epsilon = epsilon
self.gaussian_actions: Optional[th.Tensor] = None
def proba_distribution(
self: SelfSquashedDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
) -> SelfSquashedDiagGaussianDistribution:
super().proba_distribution(mean_actions, log_std)
return self
def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
# Inverse tanh
# Naive implementation (not stable): 0.5 * torch.log((1 + x) / (1 - x))
# A clipped ``atanh`` (see ``TanhBijector.inverse``) is used instead to avoid numerical instability
if gaussian_actions is None:
# It will be clipped to avoid NaN when inversing tanh
gaussian_actions = TanhBijector.inverse(actions)
# Log likelihood for a Gaussian distribution
log_prob = super().log_prob(gaussian_actions)
# Squash correction (from original SAC implementation)
# this comes from the fact that tanh is bijective and differentiable
log_prob -= th.sum(th.log(1 - actions**2 + self.epsilon), dim=1)
return log_prob
def entropy(self) -> Optional[th.Tensor]:
# No analytical form,
# entropy needs to be estimated using -log_prob.mean()
return None
def sample(self) -> th.Tensor:
# Reparametrization trick to pass gradients
self.gaussian_actions = super().sample()
return th.tanh(self.gaussian_actions)
def mode(self) -> th.Tensor:
self.gaussian_actions = super().mode()
# Squash the output
return th.tanh(self.gaussian_actions)
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
action = self.actions_from_params(mean_actions, log_std)
log_prob = self.log_prob(action, self.gaussian_actions)
return action, log_prob
class CategoricalDistribution(Distribution):
"""
Categorical distribution for discrete actions.
:param action_dim: Number of discrete actions
"""
def __init__(self, action_dim: int):
super().__init__()
self.action_dim = action_dim
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits of the Categorical distribution.
You can then get probabilities using a softmax.
:param latent_dim: Dimension of the last layer
of the policy network (before the action layer)
:return:
"""
action_logits = nn.Linear(latent_dim, self.action_dim)
return action_logits
def proba_distribution(self: SelfCategoricalDistribution, action_logits: th.Tensor) -> SelfCategoricalDistribution:
self.distribution = Categorical(logits=action_logits)
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
return self.distribution.log_prob(actions)
def entropy(self) -> th.Tensor:
return self.distribution.entropy()
def sample(self) -> th.Tensor:
return self.distribution.sample()
def mode(self) -> th.Tensor:
return th.argmax(self.distribution.probs, dim=1)
def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob
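# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch (not part of the original file) for the
# discrete case: build the logits layer, then draw actions and log probabilities.
# Dimensions and the random latent are assumptions.
def _example_categorical_usage() -> None:
    latent_dim, n_actions, batch_size = 8, 4, 5
    dist = CategoricalDistribution(n_actions)
    logits_net = dist.proba_distribution_net(latent_dim)
    action_logits = logits_net(th.randn(batch_size, latent_dim))
    actions = dist.proba_distribution(action_logits).get_actions(deterministic=False)
    log_prob = dist.log_prob(actions)
    assert actions.shape == (batch_size,)
    assert log_prob.shape == (batch_size,)
# ---------------------------------------------------------------------------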
class MultiCategoricalDistribution(Distribution):
"""
MultiCategorical distribution for multi discrete actions.
:param action_dims: List of sizes of discrete action spaces
"""
def __init__(self, action_dims: List[int]):
super().__init__()
self.action_dims = action_dims
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits (flattened) of the MultiCategorical distribution.
You can then get probabilities using a softmax on each sub-space.
:param latent_dim: Dimension of the last layer
of the policy network (before the action layer)
:return:
"""
action_logits = nn.Linear(latent_dim, sum(self.action_dims))
return action_logits
def proba_distribution(
self: SelfMultiCategoricalDistribution, action_logits: th.Tensor
) -> SelfMultiCategoricalDistribution:
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, list(self.action_dims), dim=1)]
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
# Extract each discrete action and compute log prob for their respective distributions
return th.stack(
[dist.log_prob(action) for dist, action in zip(self.distribution, th.unbind(actions, dim=1))], dim=1
).sum(dim=1)
def entropy(self) -> th.Tensor:
return th.stack([dist.entropy() for dist in self.distribution], dim=1).sum(dim=1)
def sample(self) -> th.Tensor:
return th.stack([dist.sample() for dist in self.distribution], dim=1)
def mode(self) -> th.Tensor:
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distribution], dim=1)
def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob
class BernoulliDistribution(Distribution):
"""
Bernoulli distribution for MultiBinary action spaces.
:param action_dims: Number of binary actions
"""
def __init__(self, action_dims: int):
super().__init__()
self.action_dims = action_dims
def proba_distribution_net(self, latent_dim: int) -> nn.Module:
"""
Create the layer that represents the distribution:
it will be the logits of the Bernoulli distribution.
:param latent_dim: Dimension of the last layer
of the policy network (before the action layer)
:return:
"""
action_logits = nn.Linear(latent_dim, self.action_dims)
return action_logits
def proba_distribution(self: SelfBernoulliDistribution, action_logits: th.Tensor) -> SelfBernoulliDistribution:
self.distribution = Bernoulli(logits=action_logits)
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
return self.distribution.log_prob(actions).sum(dim=1)
def entropy(self) -> th.Tensor:
return self.distribution.entropy().sum(dim=1)
def sample(self) -> th.Tensor:
return self.distribution.sample()
def mode(self) -> th.Tensor:
return th.round(self.distribution.probs)
def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(action_logits)
log_prob = self.log_prob(actions)
return actions, log_prob
class StateDependentNoiseDistribution(Distribution):
"""
Distribution class for using generalized State Dependent Exploration (gSDE).
Paper: https://arxiv.org/abs/2005.05719
It is used to create the noise exploration matrix and
compute the log probability of an action with that noise.
:param action_dim: Dimension of the action space.
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,)
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It keeps the variance
above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: Whether to squash the output using a tanh function,
this ensures bounds are satisfied.
:param learn_features: Whether to learn features for gSDE or not.
This will enable gradients to be backpropagated through the features
``latent_sde`` in the code.
:param epsilon: small value to avoid NaN due to numerical imprecision.
"""
bijector: Optional["TanhBijector"]
latent_sde_dim: Optional[int]
weights_dist: Normal
_latent_sde: th.Tensor
exploration_mat: th.Tensor
exploration_matrices: th.Tensor
def __init__(
self,
action_dim: int,
full_std: bool = True,
use_expln: bool = False,
squash_output: bool = False,
learn_features: bool = False,
epsilon: float = 1e-6,
):
super().__init__()
self.action_dim = action_dim
self.latent_sde_dim = None
self.mean_actions = None
self.log_std = None
self.use_expln = use_expln
self.full_std = full_std
self.epsilon = epsilon
self.learn_features = learn_features
if squash_output:
self.bijector = TanhBijector(epsilon)
else:
self.bijector = None
def get_std(self, log_std: th.Tensor) -> th.Tensor:
"""
Get the standard deviation from the learned parameter
(log of it by default). This ensures that the std is positive.
:param log_std:
:return:
"""
if self.use_expln:
# From the gSDE paper: this keeps the variance
# above zero and prevents it from growing too fast
below_threshold = th.exp(log_std) * (log_std <= 0)
# Avoid NaN: zero out values that are below zero
safe_log_std = log_std * (log_std > 0) + self.epsilon
above_threshold = (th.log1p(safe_log_std) + 1.0) * (log_std > 0)
std = below_threshold + above_threshold
else:
# Use normal exponential
std = th.exp(log_std)
if self.full_std:
return std
assert self.latent_sde_dim is not None
# Reduce the number of parameters:
return th.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std
def sample_weights(self, log_std: th.Tensor, batch_size: int = 1) -> None:
"""
Sample weights for the noise exploration matrix,
using a centered Gaussian distribution.
:param log_std:
:param batch_size:
"""
std = self.get_std(log_std)
self.weights_dist = Normal(th.zeros_like(std), std)
# Reparametrization trick to pass gradients
self.exploration_mat = self.weights_dist.rsample()
# Pre-compute matrices in case of parallel exploration
self.exploration_matrices = self.weights_dist.rsample((batch_size,))
def proba_distribution_net(
self, latent_dim: int, log_std_init: float = -2.0, latent_sde_dim: Optional[int] = None
) -> Tuple[nn.Module, nn.Parameter]:
"""
Create the layers and parameter that represent the distribution:
one output will be the deterministic action, the other parameter will be the
standard deviation of the distribution that controls the weights of the noise matrix.
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
:param log_std_init: Initial value for the log standard deviation
:param latent_sde_dim: Dimension of the last layer of the features extractor
for gSDE. By default, it is shared with the policy network.
:return:
"""
# Network for the deterministic action, it represents the mean of the distribution
mean_actions_net = nn.Linear(latent_dim, self.action_dim)
# When we learn features for the noise, the feature dimension
# can be different between the policy and the noise network
self.latent_sde_dim = latent_dim if latent_sde_dim is None else latent_sde_dim
# Reduce the number of parameters if needed
log_std = th.ones(self.latent_sde_dim, self.action_dim) if self.full_std else th.ones(self.latent_sde_dim, 1)
# Transform it to a parameter so it can be optimized
log_std = nn.Parameter(log_std * log_std_init, requires_grad=True)
# Sample an exploration matrix
self.sample_weights(log_std)
return mean_actions_net, log_std
def proba_distribution(
self: SelfStateDependentNoiseDistribution, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
) -> SelfStateDependentNoiseDistribution:
"""
Create the distribution given its parameters (mean, std)
:param mean_actions:
:param log_std:
:param latent_sde:
:return:
"""
# Stop gradient if we don't want to influence the features
self._latent_sde = latent_sde if self.learn_features else latent_sde.detach()
variance = th.mm(self._latent_sde**2, self.get_std(log_std) ** 2)
self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon))
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
if self.bijector is not None:
gaussian_actions = self.bijector.inverse(actions)
else:
gaussian_actions = actions
# log likelihood for a gaussian
log_prob = self.distribution.log_prob(gaussian_actions)
# Sum along action dim
log_prob = sum_independent_dims(log_prob)
if self.bijector is not None:
# Squash correction (from original SAC implementation)
log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
return log_prob
def entropy(self) -> Optional[th.Tensor]:
if self.bijector is not None:
# No analytical form,
# entropy needs to be estimated using -log_prob.mean()
return None
return sum_independent_dims(self.distribution.entropy())
def sample(self) -> th.Tensor:
noise = self.get_noise(self._latent_sde)
actions = self.distribution.mean + noise
if self.bijector is not None:
return self.bijector.forward(actions)
return actions
def mode(self) -> th.Tensor:
actions = self.distribution.mean
if self.bijector is not None:
return self.bijector.forward(actions)
return actions
def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
latent_sde = latent_sde if self.learn_features else latent_sde.detach()
# Default case: only one exploration matrix
if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
return th.mm(latent_sde, self.exploration_mat)
# Use batch matrix multiplication for efficient computation
# (batch_size, n_features) -> (batch_size, 1, n_features)
latent_sde = latent_sde.unsqueeze(dim=1)
# (batch_size, 1, n_actions)
noise = th.bmm(latent_sde, self.exploration_matrices)
return noise.squeeze(dim=1)
def actions_from_params(
self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False
) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(mean_actions, log_std, latent_sde)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(
self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(mean_actions, log_std, latent_sde)
log_prob = self.log_prob(actions)
return actions, log_prob
class TanhBijector:
"""
Bijective transformation of a probability distribution
using a squashing function (tanh)
:param epsilon: small value to avoid NaN due to numerical imprecision.
"""
def __init__(self, epsilon: float = 1e-6):
super().__init__()
self.epsilon = epsilon
@staticmethod
def forward(x: th.Tensor) -> th.Tensor:
return th.tanh(x)
@staticmethod
def atanh(x: th.Tensor) -> th.Tensor:
"""
Inverse of Tanh
Taken from Pyro: https://github.com/pyro-ppl/pyro
0.5 * torch.log((1 + x ) / (1 - x))
"""
return 0.5 * (x.log1p() - (-x).log1p())
@staticmethod
def inverse(y: th.Tensor) -> th.Tensor:
"""
Inverse tanh.
:param y:
:return:
"""
eps = th.finfo(y.dtype).eps
# Clip the action to avoid NaN
return TanhBijector.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))
def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
# Squash correction (from original SAC implementation)
return th.log(1.0 - th.tanh(x) ** 2 + self.epsilon)
def make_proba_distribution(
action_space: spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
) -> Distribution:
"""
Return an instance of Distribution for the correct type of action space
:param action_space: the input action space
:param use_sde: Force the use of StateDependentNoiseDistribution
instead of DiagGaussianDistribution
:param dist_kwargs: Keyword arguments to pass to the probability distribution
:return: the appropriate Distribution object
"""
if dist_kwargs is None:
dist_kwargs = {}
if isinstance(action_space, spaces.Box):
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
return cls(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
return CategoricalDistribution(int(action_space.n), **dist_kwargs)
elif isinstance(action_space, spaces.MultiDiscrete):
return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs)
elif isinstance(action_space, spaces.MultiBinary):
assert isinstance(
action_space.n, int
), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
return BernoulliDistribution(action_space.n, **dist_kwargs)
else:
raise NotImplementedError(
"Error: probability distribution, not implemented for action space"
f"of type {type(action_space)}."
" Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
)
def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor:
"""
Wrapper for the PyTorch implementation of the full form KL Divergence
:param dist_true: the p distribution
:param dist_pred: the q distribution
:return: KL(dist_true||dist_pred)
"""
# KL Divergence for different distribution types is out of scope
assert dist_true.__class__ == dist_pred.__class__, "Error: input distributions should be the same type"
# MultiCategoricalDistribution is not a PyTorch Distribution subclass
# so we need to implement it ourselves!
if isinstance(dist_pred, MultiCategoricalDistribution):
assert isinstance(dist_true, MultiCategoricalDistribution) # already checked above, for mypy
assert np.allclose(
dist_pred.action_dims, dist_true.action_dims
), f"Error: distributions must have the same input space: {dist_pred.action_dims} != {dist_true.action_dims}"
return th.stack(
[th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)],
dim=1,
).sum(dim=1)
# Use the PyTorch kl_divergence implementation
else:
return th.distributions.kl_divergence(dist_true.distribution, dist_pred.distribution)
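# --- Illustrative usage sketch (added for clarity, not part of the original module) ---
# A minimal example of the Distribution API defined above, assuming this module's
# imports (torch as th, gymnasium spaces) are available; shapes and values are arbitrary.
if __name__ == "__main__":
    # Build a categorical distribution for a Discrete(4) action space
    dist = make_proba_distribution(spaces.Discrete(4))
    # The policy network would normally produce these logits; here the latent is random
    action_net = dist.proba_distribution_net(latent_dim=16)
    logits = action_net(th.randn(2, 16))
    # Sample actions, get their log-probabilities and the entropy of the distribution
    actions, log_prob = dist.log_prob_from_params(logits)
    print(actions.shape, log_prob.shape, dist.entropy().shape)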

View File

@ -0,0 +1,485 @@
import warnings
from typing import Any, Dict, Union
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
def _is_numpy_array_space(space: spaces.Space) -> bool:
"""
Returns False if provided space is not representable as a single numpy array
(e.g. Dict and Tuple spaces return False)
"""
return not isinstance(space, (spaces.Dict, spaces.Tuple))
def _starts_at_zero(space: Union[spaces.Discrete, spaces.MultiDiscrete]) -> bool:
"""
Return True if a (Multi)Discrete space starts at zero, False otherwise.
"""
return np.allclose(space.start, np.zeros_like(space.start))
def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", key: str = "") -> None:
"""
:param space: Observation or action space
:param space_type: information about whether it is an observation or action space
(for the warning message)
:param key: When the observation space comes from a Dict space, we pass the
corresponding key to have more precise warning messages. Defaults to "".
"""
if isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)) and not _starts_at_zero(space):
maybe_key = f"(key='{key}')" if key else ""
warnings.warn(
f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) "
"is not supported by Stable-Baselines3. "
f"You can use a wrapper or update your {space_type} space."
)
def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
"""
Check that the input will be compatible with Stable-Baselines
when the observation is apparently an image.
:param observation_space: Observation space
:param key: When the observation space comes from a Dict space, we pass the
corresponding key to have more precise warning messages. Defaults to "".
"""
if observation_space.dtype != np.uint8:
warnings.warn(
f"It seems that your observation {key} is an image but its `dtype` "
f"is ({observation_space.dtype}) whereas it has to be `np.uint8`. "
"If your observation is not an image, we recommend you to flatten the observation "
"to have only a 1D vector"
)
if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
warnings.warn(
f"It seems that your observation space {key} is an image but the "
"upper and lower bounds are not in [0, 255]. "
"Because the CNN policy normalize automatically the observation "
"you may encounter issue if the values are not in that range."
)
non_channel_idx = 0
# Check only if width/height of the image is big enough
if is_image_space_channels_first(observation_space):
non_channel_idx = -1
if observation_space.shape[non_channel_idx] < 36 or observation_space.shape[1] < 36:
warnings.warn(
"The minimal resolution for an image is 36x36 for the default `CnnPolicy`. "
"You might need to use a custom features extractor "
"cf. https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html"
)
def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> None:
"""Emit warnings when the observation space or action space used is not supported by Stable-Baselines."""
if isinstance(observation_space, spaces.Dict):
nested_dict = False
for key, space in observation_space.spaces.items():
if isinstance(space, spaces.Dict):
nested_dict = True
_check_non_zero_start(space, "observation", key)
if nested_dict:
warnings.warn(
"Nested observation spaces are not supported by Stable Baselines3 "
"(Dict spaces inside Dict space). "
"You should flatten it to have only one level of keys."
"For example, `dict(space1=dict(space2=Box(), space3=Box()), spaces4=Discrete())` "
"is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is."
)
if isinstance(observation_space, spaces.Tuple):
warnings.warn(
"The observation space is a Tuple, "
"this is currently not supported by Stable Baselines3. "
"However, you can convert it to a Dict observation space "
"(cf. https://gymnasium.farama.org/api/spaces/composite/#dict). "
"which is supported by SB3."
)
_check_non_zero_start(observation_space, "observation")
if isinstance(observation_space, spaces.Sequence):
warnings.warn(
"Sequence observation space is not supported by Stable-Baselines3. "
"You can pad your observation to have a fixed size instead.\n"
"Note: The checks for returned values are skipped."
)
_check_non_zero_start(action_space, "action")
if not _is_numpy_array_space(action_space):
warnings.warn(
"The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. "
"This type of action space is currently not supported by Stable Baselines 3. You should try to flatten the "
"action using a wrapper."
)
def _check_nan(env: gym.Env) -> None:
"""Check for Inf and NaN using the VecWrapper."""
vec_env = VecCheckNan(DummyVecEnv([lambda: env]))
vec_env.reset()
for _ in range(10):
action = np.array([env.action_space.sample()])
_, _, _, _ = vec_env.step(action)
def _is_goal_env(env: gym.Env) -> bool:
"""
Check if the env uses the convention for goal-conditioned envs (previously, the gym.GoalEnv interface)
"""
# We need to unwrap the env since gym.Wrapper has the compute_reward method
return hasattr(env.unwrapped, "compute_reward")
def _check_goal_env_obs(obs: dict, observation_space: spaces.Dict, method_name: str) -> None:
"""
Check that an environment implementing the `compute_reward()` method
(previously known as GoalEnv in gym) contains at least three elements,
namely `observation`, `achieved_goal`, and `desired_goal`.
"""
assert len(observation_space.spaces) >= 3, (
"A goal conditioned env must contain at least 3 observation keys: `observation`, `achieved_goal`, and `desired_goal`. "
f"The current observation contains {len(observation_space.spaces)} keys: {list(observation_space.spaces.keys())}"
)
for key in ["achieved_goal", "desired_goal"]:
if key not in observation_space.spaces:
raise AssertionError(
f"The observation returned by the `{method_name}()` method of a goal-conditioned env requires the '{key}' "
"key to be part of the observation dictionary. "
f"Current keys are {list(observation_space.spaces.keys())}"
)
def _check_goal_env_compute_reward(
obs: Dict[str, Union[np.ndarray, int]],
env: gym.Env,
reward: float,
info: Dict[str, Any],
) -> None:
"""
Check that reward is computed with `compute_reward`
and that the implementation is vectorized.
"""
achieved_goal, desired_goal = obs["achieved_goal"], obs["desired_goal"]
assert reward == env.compute_reward( # type: ignore[attr-defined]
achieved_goal, desired_goal, info
), "The reward was not computed with `compute_reward()`"
achieved_goal, desired_goal = np.array(achieved_goal), np.array(desired_goal)
batch_achieved_goals = np.array([achieved_goal, achieved_goal])
batch_desired_goals = np.array([desired_goal, desired_goal])
if isinstance(achieved_goal, int) or len(achieved_goal.shape) == 0:
batch_achieved_goals = batch_achieved_goals.reshape(2, 1)
batch_desired_goals = batch_desired_goals.reshape(2, 1)
batch_infos = np.array([info, info])
rewards = env.compute_reward(batch_achieved_goals, batch_desired_goals, batch_infos) # type: ignore[attr-defined]
assert rewards.shape == (2,), f"Unexpected shape for vectorized computation of reward: {rewards.shape} != (2,)"
assert rewards[0] == reward, f"Vectorized computation of reward differs from single computation: {rewards[0]} != {reward}"
def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spaces.Space, method_name: str) -> None:
"""
Check that the observation returned by the environment
corresponds to the declared one.
"""
if not isinstance(observation_space, spaces.Tuple):
assert not isinstance(
obs, tuple
), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple"
# The check for a GoalEnv is done by the base class
if isinstance(observation_space, spaces.Discrete):
# Since https://github.com/Farama-Foundation/Gymnasium/pull/141,
# `sample()` will return a np.int64 instead of an int
assert np.issubdtype(type(obs), np.integer), f"The observation returned by `{method_name}()` method must be an int"
elif _is_numpy_array_space(observation_space):
assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array"
# Additional checks for numpy arrays, so the error message is clearer (see GH#1399)
if isinstance(obs, np.ndarray):
# check obs dimensions, dtype and bounds
assert observation_space.shape == obs.shape, (
f"The observation returned by the `{method_name}()` method does not match the shape "
f"of the given observation space {observation_space}. "
f"Expected: {observation_space.shape}, actual shape: {obs.shape}"
)
assert np.can_cast(obs.dtype, observation_space.dtype), (
f"The observation returned by the `{method_name}()` method does not match the data type (cannot cast) "
f"of the given observation space {observation_space}. "
f"Expected: {observation_space.dtype}, actual dtype: {obs.dtype}"
)
if isinstance(observation_space, spaces.Box):
lower_bounds, upper_bounds = observation_space.low, observation_space.high
# Expose all invalid indices at once
invalid_indices = np.where(np.logical_or(obs < lower_bounds, obs > upper_bounds))
if (obs > upper_bounds).any() or (obs < lower_bounds).any():
message = (
f"The observation returned by the `{method_name}()` method does not match the bounds "
f"of the given observation space {observation_space}. \n"
)
message += f"{len(invalid_indices[0])} invalid indices: \n"
for index in zip(*invalid_indices):
index_str = ",".join(map(str, index))
message += (
f"Expected: {lower_bounds[index]} <= obs[{index_str}] <= {upper_bounds[index]}, "
f"actual value: {obs[index]} \n"
)
raise AssertionError(message)
assert observation_space.contains(obs), (
f"The observation returned by the `{method_name}()` method "
f"does not match the given observation space {observation_space}"
)
def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None:
"""
Check that the observation space is correctly formatted
when dealing with a ``Box()`` space. In particular, it checks:
- that the dimensions are big enough when it is an image, and that the type matches
- that the observation has an expected shape (warn the user if not)
"""
# If image, check the low and high values, the type and the number of channels
# and the shape (minimal value)
if len(observation_space.shape) == 3:
_check_image_input(observation_space, key)
if len(observation_space.shape) not in [1, 3]:
warnings.warn(
f"Your observation {key} has an unconventional shape (neither an image, nor a 1D vector). "
"We recommend you to flatten the observation "
"to have only a 1D vector or use a custom policy to properly process the data."
)
def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> None:
"""
Check the returned values by the env when calling `.reset()` or `.step()` methods.
"""
# because env inherits from gymnasium.Env, we assume that `reset()` and `step()` methods exist
reset_returns = env.reset()
assert isinstance(reset_returns, tuple), "`reset()` must return a tuple (obs, info)"
assert len(reset_returns) == 2, f"`reset()` must return a tuple of size 2 (obs, info), not {len(reset_returns)}"
obs, info = reset_returns
assert isinstance(info, dict), f"The second element of the tuple returned by `reset()` must be a dictionary, not {info}"
if _is_goal_env(env):
# Make mypy happy, already checked
assert isinstance(observation_space, spaces.Dict)
_check_goal_env_obs(obs, observation_space, "reset")
elif isinstance(observation_space, spaces.Dict):
assert isinstance(obs, dict), "The observation returned by `reset()` must be a dictionary"
if not obs.keys() == observation_space.spaces.keys():
raise AssertionError(
"The observation keys returned by `reset()` must match the observation "
f"space keys: {obs.keys()} != {observation_space.spaces.keys()}"
)
for key in observation_space.spaces.keys():
try:
_check_obs(obs[key], observation_space.spaces[key], "reset")
except AssertionError as e:
raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
else:
_check_obs(obs, observation_space, "reset")
# Sample a random action
action = action_space.sample()
data = env.step(action)
assert len(data) == 5, (
"The `step()` method must return five values: "
f"obs, reward, terminated, truncated, info. Actual: {len(data)} values returned."
)
# Unpack
obs, reward, terminated, truncated, info = data
if isinstance(observation_space, spaces.Dict):
assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary"
# Additional checks for GoalEnvs
if _is_goal_env(env):
# Make mypy happy, already checked
assert isinstance(observation_space, spaces.Dict)
_check_goal_env_obs(obs, observation_space, "step")
_check_goal_env_compute_reward(obs, env, float(reward), info)
if not obs.keys() == observation_space.spaces.keys():
raise AssertionError(
"The observation keys returned by `step()` must match the observation "
f"space keys: {obs.keys()} != {observation_space.spaces.keys()}"
)
for key in observation_space.spaces.keys():
try:
_check_obs(obs[key], observation_space.spaces[key], "step")
except AssertionError as e:
raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
else:
_check_obs(obs, observation_space, "step")
# We also allow int because the reward will be cast to float
assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a float"
assert isinstance(terminated, bool), "The `terminated` signal must be a boolean"
assert isinstance(truncated, bool), "The `truncated` signal must be a boolean"
assert isinstance(info, dict), "The `info` returned by `step()` must be a python dictionary"
# Goal conditioned env
if _is_goal_env(env):
# for mypy, env.unwrapped was checked by _is_goal_env()
assert hasattr(env, "compute_reward")
assert reward == env.compute_reward(obs["achieved_goal"], obs["desired_goal"], info)
def _check_spaces(env: gym.Env) -> None:
"""
Check that the observation and action spaces are defined and inherit from spaces.Space. For
envs that follow the goal-conditioned standard (previously, the gym.GoalEnv interface) we check
that the observation space is gymnasium.spaces.Dict
"""
gym_spaces = "cf. https://gymnasium.farama.org/api/spaces/"
assert hasattr(env, "observation_space"), f"You must specify an observation space ({gym_spaces})"
assert hasattr(env, "action_space"), f"You must specify an action space ({gym_spaces})"
assert isinstance(
env.observation_space, spaces.Space
), f"The observation space must inherit from gymnasium.spaces ({gym_spaces})"
assert isinstance(env.action_space, spaces.Space), f"The action space must inherit from gymnasium.spaces ({gym_spaces})"
if _is_goal_env(env):
print(
"We detected your env to be a GoalEnv because `env.compute_reward()` was defined.\n"
"If it's not the case, please rename `env.compute_reward()` to something else to avoid False positives."
)
assert isinstance(env.observation_space, spaces.Dict), (
"Goal conditioned envs (previously gym.GoalEnv) require the observation space to be gymnasium.spaces.Dict.\n"
"Note: if your env is not a GoalEnv, please rename `env.compute_reward()` "
"to something else to avoid False positive."
)
# Check render cannot be covered by CI
def _check_render(env: gym.Env, warn: bool = False) -> None: # pragma: no cover
"""
Check the instantiated render mode (if any) by calling the `render()`/`close()`
method of the environment.
:param env: The environment to check
:param warn: Whether to output additional warnings
"""
render_modes = env.metadata.get("render_modes")
if render_modes is None:
if warn:
warnings.warn(
"No render modes was declared in the environment "
"(env.metadata['render_modes'] is None or not defined), "
"you may have trouble when calling `.render()`"
)
# Only check the current render mode
if env.render_mode:
env.render()
env.close()
def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None:
"""
Check that an environment follows Gym API.
This is particularly useful when using a custom environment.
Please take a look at https://gymnasium.farama.org/api/env/
for more information about the API.
It also optionally checks that the environment is compatible with Stable-Baselines.
:param env: The Gym environment that will be checked
:param warn: Whether to output additional warnings
mainly related to the interaction with Stable Baselines
:param skip_render_check: Whether to skip the checks for the render method.
True by default (useful for the CI)
"""
assert isinstance(
env, gym.Env
), "Your environment must inherit from the gymnasium.Env class cf. https://gymnasium.farama.org/api/env/"
# ============= Check the spaces (observation and action) ================
_check_spaces(env)
# Define aliases for convenience
observation_space = env.observation_space
action_space = env.action_space
try:
env.reset(seed=0)
except TypeError as e:
raise TypeError("The reset() method must accept a `seed` parameter") from e
# Warn the user if needed.
# A warning means that the environment may run but not work properly with Stable Baselines algorithms
if warn:
_check_unsupported_spaces(env, observation_space, action_space)
obs_spaces = observation_space.spaces if isinstance(observation_space, spaces.Dict) else {"": observation_space}
for key, space in obs_spaces.items():
if isinstance(space, spaces.Box):
_check_box_obs(space, key)
# Check for the action space, it may lead to hard-to-debug issues
if isinstance(action_space, spaces.Box) and (
np.any(np.abs(action_space.low) != np.abs(action_space.high))
or np.any(action_space.low != -1)
or np.any(action_space.high != 1)
):
warnings.warn(
"We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "
"cf. https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html"
)
if isinstance(action_space, spaces.Box):
assert np.all(
np.isfinite(np.array([action_space.low, action_space.high]))
), "Continuous action space must have a finite lower and upper bound"
if isinstance(action_space, spaces.Box) and action_space.dtype != np.dtype(np.float32):
warnings.warn(
f"Your action space has dtype {action_space.dtype}, we recommend using np.float32 to avoid cast errors."
)
# If Sequence observation space, do not check the observation any further
if isinstance(observation_space, spaces.Sequence):
return
# ============ Check the returned values ===============
_check_returned_values(env, observation_space, action_space)
# ==== Check the render method and the declared render modes ====
if not skip_render_check:
_check_render(env, warn) # pragma: no cover
try:
check_for_nested_spaces(env.observation_space)
# The check doesn't support nested observations/dict actions
# A warning about it has already been emitted
_check_nan(env)
except NotImplementedError:
pass
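# --- Illustrative usage sketch (added for clarity, not part of the original module) ---
# A minimal example of running the checker above; it assumes the Gymnasium
# "CartPole-v1" environment is available, but any gym.Env instance works.
if __name__ == "__main__":
    sample_env = gym.make("CartPole-v1")
    # Raises an AssertionError or emits warnings if the env violates the expected API
    check_env(sample_env, warn=True)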

View File

@ -0,0 +1,173 @@
import os
from typing import Any, Callable, Dict, Optional, Type, Union
import gymnasium as gym
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
from stable_baselines3.common.vec_env.patch_gym import _patch_env
def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]:
"""
Retrieve a ``gym.Wrapper`` object by recursively searching.
:param env: Environment to unwrap
:param wrapper_class: Wrapper to look for
:return: Environment unwrapped till ``wrapper_class`` if it has been wrapped with it
"""
env_tmp = env
while isinstance(env_tmp, gym.Wrapper):
if isinstance(env_tmp, wrapper_class):
return env_tmp
env_tmp = env_tmp.env
return None
def is_wrapped(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> bool:
"""
Check if a given environment has been wrapped with a given wrapper.
:param env: Environment to check
:param wrapper_class: Wrapper class to look for
:return: True if environment has been wrapped with ``wrapper_class``.
"""
return unwrap_wrapper(env, wrapper_class) is not None
def make_vec_env(
env_id: Union[str, Callable[..., gym.Env]],
n_envs: int = 1,
seed: Optional[int] = None,
start_index: int = 0,
monitor_dir: Optional[str] = None,
wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
env_kwargs: Optional[Dict[str, Any]] = None,
vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
vec_env_kwargs: Optional[Dict[str, Any]] = None,
monitor_kwargs: Optional[Dict[str, Any]] = None,
wrapper_kwargs: Optional[Dict[str, Any]] = None,
) -> VecEnv:
"""
Create a wrapped, monitored ``VecEnv``.
By default it uses a ``DummyVecEnv`` which is usually faster
than a ``SubprocVecEnv``.
:param env_id: either the env ID, the env class or a callable returning an env
:param n_envs: the number of environments you wish to have in parallel
:param seed: the initial seed for the random number generator
:param start_index: start rank index
:param monitor_dir: Path to a folder where the monitor files will be saved.
If None, no file will be written, however, the env will still be wrapped
in a Monitor wrapper to provide additional information about training.
:param wrapper_class: Additional wrapper to use on the environment.
This can also be a function with a single argument that wraps the environment in many things.
Note: the wrapper specified by this parameter will be applied after the ``Monitor`` wrapper.
In some cases (e.g. with a TimeLimit wrapper) this can lead to undesired behavior.
See here for more details: https://github.com/DLR-RM/stable-baselines3/issues/894
:param env_kwargs: Optional keyword argument to pass to the env constructor
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
:param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
:param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor.
:return: The wrapped environment
"""
env_kwargs = env_kwargs or {}
vec_env_kwargs = vec_env_kwargs or {}
monitor_kwargs = monitor_kwargs or {}
wrapper_kwargs = wrapper_kwargs or {}
assert vec_env_kwargs is not None # for mypy
def make_env(rank: int) -> Callable[[], gym.Env]:
def _init() -> gym.Env:
# For type checker:
assert monitor_kwargs is not None
assert wrapper_kwargs is not None
assert env_kwargs is not None
if isinstance(env_id, str):
# if the render mode was not specified, we set it to `rgb_array` as default.
kwargs = {"render_mode": "rgb_array"}
kwargs.update(env_kwargs)
try:
env = gym.make(env_id, **kwargs) # type: ignore[arg-type]
except TypeError:
env = gym.make(env_id, **env_kwargs)
else:
env = env_id(**env_kwargs)
# Patch to support gym 0.21/0.26 and gymnasium
env = _patch_env(env)
if seed is not None:
# Note: here we only seed the action space
# We will seed the env at the next reset
env.action_space.seed(seed + rank)
# Wrap the env in a Monitor wrapper
# to have additional training information
monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None
# Create the monitor folder if needed
if monitor_path is not None and monitor_dir is not None:
os.makedirs(monitor_dir, exist_ok=True)
env = Monitor(env, filename=monitor_path, **monitor_kwargs)
# Optionally, wrap the environment with the provided wrapper
if wrapper_class is not None:
env = wrapper_class(env, **wrapper_kwargs)
return env
return _init
# No custom VecEnv is passed
if vec_env_cls is None:
# Default: use a DummyVecEnv
vec_env_cls = DummyVecEnv
vec_env = vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs)
# Prepare the seeds for the first reset
vec_env.seed(seed)
return vec_env
def make_atari_env(
env_id: Union[str, Callable[..., gym.Env]],
n_envs: int = 1,
seed: Optional[int] = None,
start_index: int = 0,
monitor_dir: Optional[str] = None,
wrapper_kwargs: Optional[Dict[str, Any]] = None,
env_kwargs: Optional[Dict[str, Any]] = None,
vec_env_cls: Optional[Union[Type[DummyVecEnv], Type[SubprocVecEnv]]] = None,
vec_env_kwargs: Optional[Dict[str, Any]] = None,
monitor_kwargs: Optional[Dict[str, Any]] = None,
) -> VecEnv:
"""
Create a wrapped, monitored VecEnv for Atari.
It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games.
:param env_id: either the env ID, the env class or a callable returning an env
:param n_envs: the number of environments you wish to have in parallel
:param seed: the initial seed for the random number generator
:param start_index: start rank index
:param monitor_dir: Path to a folder where the monitor files will be saved.
If None, no file will be written, however, the env will still be wrapped
in a Monitor wrapper to provide additional information about training.
:param wrapper_kwargs: Optional keyword argument to pass to the ``AtariWrapper``
:param env_kwargs: Optional keyword argument to pass to the env constructor
:param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
:param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
:param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
:return: The wrapped environment
"""
return make_vec_env(
env_id,
n_envs=n_envs,
seed=seed,
start_index=start_index,
monitor_dir=monitor_dir,
wrapper_class=AtariWrapper,
env_kwargs=env_kwargs,
vec_env_cls=vec_env_cls,
vec_env_kwargs=vec_env_kwargs,
monitor_kwargs=monitor_kwargs,
wrapper_kwargs=wrapper_kwargs,
)
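# --- Illustrative usage sketch (added for clarity, not part of the original module) ---
# Shows make_vec_env defined above; "CartPole-v1", the seed and the number of
# envs are arbitrary example choices.
if __name__ == "__main__":
    vec_env = make_vec_env("CartPole-v1", n_envs=2, seed=0)
    obs = vec_env.reset()
    # A VecEnv expects one action per sub-environment
    actions = [vec_env.action_space.sample() for _ in range(vec_env.num_envs)]
    obs, rewards, dones, infos = vec_env.step(actions)
    print(obs.shape, rewards.shape, dones.shape)
    vec_env.close()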

View File

@ -0,0 +1,20 @@
from stable_baselines3.common.envs.bit_flipping_env import BitFlippingEnv
from stable_baselines3.common.envs.identity_env import (
FakeImageEnv,
IdentityEnv,
IdentityEnvBox,
IdentityEnvMultiBinary,
IdentityEnvMultiDiscrete,
)
from stable_baselines3.common.envs.multi_input_envs import SimpleMultiObsEnv
__all__ = [
"BitFlippingEnv",
"FakeImageEnv",
"IdentityEnv",
"IdentityEnvBox",
"IdentityEnvMultiBinary",
"IdentityEnvMultiDiscrete",
"SimpleMultiObsEnv",
"SimpleMultiObsEnv",
]

View File

@ -0,0 +1,235 @@
from collections import OrderedDict
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
from gymnasium import Env, spaces
from gymnasium.envs.registration import EnvSpec
from stable_baselines3.common.type_aliases import GymStepReturn
class BitFlippingEnv(Env):
"""
Simple bit flipping env, useful to test HER.
The goal is to flip all the bits to get a vector of ones.
In the continuous variant, if the ith action component has a value > 0,
then the ith bit will be flipped. Uses a ``MultiBinary`` observation space
by default.
:param n_bits: Number of bits to flip
:param continuous: Whether to use the continuous actions version or not,
by default, it uses the discrete one
:param max_steps: Max number of steps, by default, equal to n_bits
:param discrete_obs_space: Whether to use the discrete observation
version or not, i.e. a one-hot encoding of all possible states
:param image_obs_space: Whether to use an image observation version
or not, i.e. a greyscale image of the state
:param channel_first: Whether to use channel-first or last image.
"""
spec = EnvSpec("BitFlippingEnv-v0", "no-entry-point")
state: np.ndarray
def __init__(
self,
n_bits: int = 10,
continuous: bool = False,
max_steps: Optional[int] = None,
discrete_obs_space: bool = False,
image_obs_space: bool = False,
channel_first: bool = True,
render_mode: str = "human",
):
super().__init__()
self.render_mode = render_mode
# Shape of the observation when using image space
self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1)
# The achieved goal is determined by the current state
# here, it is a special case where they are equal
# observation space for observations given to the model
self.observation_space = self._make_observation_space(discrete_obs_space, image_obs_space, n_bits)
# observation space used to update internal state
self._obs_space = spaces.MultiBinary(n_bits)
if continuous:
self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32)
else:
self.action_space = spaces.Discrete(n_bits)
self.continuous = continuous
self.discrete_obs_space = discrete_obs_space
self.image_obs_space = image_obs_space
self.desired_goal = np.ones((n_bits,), dtype=self.observation_space["desired_goal"].dtype)
if max_steps is None:
max_steps = n_bits
self.max_steps = max_steps
self.current_step = 0
def seed(self, seed: int) -> None:
self._obs_space.seed(seed)
def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
"""
Convert to discrete space if needed.
:param state:
:return:
"""
if self.discrete_obs_space:
# The internal state is the binary representation of the
# observed one
return int(sum(state[i] * 2**i for i in range(len(state))))
if self.image_obs_space:
size = np.prod(self.image_shape)
image = np.concatenate((state * 255, np.zeros(size - len(state), dtype=np.uint8)))
return image.reshape(self.image_shape).astype(np.uint8)
return state
def convert_to_bit_vector(self, state: Union[int, np.ndarray], batch_size: int) -> np.ndarray:
"""
Convert to bit vector if needed.
:param state: The state to be converted, which can be either an integer or a numpy array.
:param batch_size: The batch size.
:return: The state converted into a bit vector.
"""
# Convert back to bit vector
if isinstance(state, int):
bit_vector = np.array(state).reshape(batch_size, -1)
# Convert to binary representation
bit_vector = ((bit_vector[:, :] & (1 << np.arange(len(self.state)))) > 0).astype(int)
elif self.image_obs_space:
bit_vector = state.reshape(batch_size, -1)[:, : len(self.state)] / 255
else:
bit_vector = np.array(state).reshape(batch_size, -1)
return bit_vector
def _make_observation_space(self, discrete_obs_space: bool, image_obs_space: bool, n_bits: int) -> spaces.Dict:
"""
Helper to create observation space
:param discrete_obs_space: Whether to use the discrete observation version
:param image_obs_space: Whether to use the image observation version
:param n_bits: The number of bits used to represent the state
:return: the environment observation space
"""
if discrete_obs_space and image_obs_space:
raise ValueError("Cannot use both discrete and image observation spaces")
if discrete_obs_space:
# In the discrete case, the agent acts on the binary
# representation of the observation
return spaces.Dict(
{
"observation": spaces.Discrete(2**n_bits),
"achieved_goal": spaces.Discrete(2**n_bits),
"desired_goal": spaces.Discrete(2**n_bits),
}
)
if image_obs_space:
# When using image as input,
# one image contains the bits 0 -> 0, 1 -> 255
# and the rest is filled with zeros
return spaces.Dict(
{
"observation": spaces.Box(
low=0,
high=255,
shape=self.image_shape,
dtype=np.uint8,
),
"achieved_goal": spaces.Box(
low=0,
high=255,
shape=self.image_shape,
dtype=np.uint8,
),
"desired_goal": spaces.Box(
low=0,
high=255,
shape=self.image_shape,
dtype=np.uint8,
),
}
)
return spaces.Dict(
{
"observation": spaces.MultiBinary(n_bits),
"achieved_goal": spaces.MultiBinary(n_bits),
"desired_goal": spaces.MultiBinary(n_bits),
}
)
def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]:
"""
Helper to create the observation.
:return: The current observation.
"""
return OrderedDict(
[
("observation", self.convert_if_needed(self.state.copy())),
("achieved_goal", self.convert_if_needed(self.state.copy())),
("desired_goal", self.convert_if_needed(self.desired_goal.copy())),
]
)
def reset(
self, *, seed: Optional[int] = None, options: Optional[Dict] = None
) -> Tuple[Dict[str, Union[int, np.ndarray]], Dict]:
if seed is not None:
self._obs_space.seed(seed)
self.current_step = 0
self.state = self._obs_space.sample()
return self._get_obs(), {}
def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
"""
Step into the env.
:param action:
:return:
"""
if self.continuous:
self.state[action > 0] = 1 - self.state[action > 0]
else:
self.state[action] = 1 - self.state[action]
obs = self._get_obs()
reward = float(self.compute_reward(obs["achieved_goal"], obs["desired_goal"], None).item())
terminated = reward == 0
self.current_step += 1
# The episode terminates when we reach the goal or the max number of steps
info = {"is_success": terminated}
truncated = self.current_step >= self.max_steps
return obs, reward, terminated, truncated, info
def compute_reward(
self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]]
) -> np.float32:
# As we are using a vectorized version, we need to keep track of the `batch_size`
if isinstance(achieved_goal, int):
batch_size = 1
elif self.image_obs_space:
batch_size = achieved_goal.shape[0] if len(achieved_goal.shape) > 3 else 1
else:
batch_size = achieved_goal.shape[0] if len(achieved_goal.shape) > 1 else 1
desired_goal = self.convert_to_bit_vector(desired_goal, batch_size)
achieved_goal = self.convert_to_bit_vector(achieved_goal, batch_size)
# Deceptive reward: it is positive only when the goal is achieved
# Here we are using a vectorized version
distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
return -(distance > 0).astype(np.float32)
def render(self) -> Optional[np.ndarray]: # type: ignore[override]
if self.render_mode == "rgb_array":
return self.state.copy()
print(self.state)
return None
def close(self) -> None:
pass
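# --- Illustrative usage sketch (added for clarity, not part of the original module) ---
# One random rollout in the BitFlippingEnv defined above; the number of bits
# below is an arbitrary example value.
if __name__ == "__main__":
    env = BitFlippingEnv(n_bits=4, continuous=False)
    obs, _ = env.reset(seed=0)
    terminated = truncated = False
    while not (terminated or truncated):
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    print("is_success:", info["is_success"], "observation:", obs["observation"])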

View File

@ -0,0 +1,159 @@
from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.type_aliases import GymStepReturn
T = TypeVar("T", int, np.ndarray)
class IdentityEnv(gym.Env, Generic[T]):
def __init__(self, dim: Optional[int] = None, space: Optional[spaces.Space] = None, ep_length: int = 100):
"""
Identity environment for testing purposes
:param dim: the size of the action and observation dimension you want
to learn. Provide at most one of ``dim`` and ``space``. If both are
None, then initialization proceeds with ``dim=1`` and ``space=None``.
:param space: the action and observation space. Provide at most one of
``dim`` and ``space``.
:param ep_length: the length of each episode in timesteps
"""
if space is None:
if dim is None:
dim = 1
space = spaces.Discrete(dim)
else:
assert dim is None, "arguments for both 'dim' and 'space' provided: at most one allowed"
self.action_space = self.observation_space = space
self.ep_length = ep_length
self.current_step = 0
self.num_resets = -1 # Becomes 0 after __init__ exits.
self.reset()
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[T, Dict]:
if seed is not None:
super().reset(seed=seed)
self.current_step = 0
self.num_resets += 1
self._choose_next_state()
return self.state, {}
def step(self, action: T) -> Tuple[T, float, bool, bool, Dict[str, Any]]:
reward = self._get_reward(action)
self._choose_next_state()
self.current_step += 1
terminated = False
truncated = self.current_step >= self.ep_length
return self.state, reward, terminated, truncated, {}
def _choose_next_state(self) -> None:
self.state = self.action_space.sample()
def _get_reward(self, action: T) -> float:
return 1.0 if np.all(self.state == action) else 0.0
def render(self, mode: str = "human") -> None:
pass
class IdentityEnvBox(IdentityEnv[np.ndarray]):
def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_length: int = 100):
"""
Identity environment for testing purposes
:param low: the lower bound of the box dim
:param high: the upper bound of the box dim
:param eps: the epsilon bound for correct value
:param ep_length: the length of each episode in timesteps
"""
space = spaces.Box(low=low, high=high, shape=(1,), dtype=np.float32)
super().__init__(ep_length=ep_length, space=space)
self.eps = eps
def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
reward = self._get_reward(action)
self._choose_next_state()
self.current_step += 1
terminated = False
truncated = self.current_step >= self.ep_length
return self.state, reward, terminated, truncated, {}
def _get_reward(self, action: np.ndarray) -> float:
return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0
class IdentityEnvMultiDiscrete(IdentityEnv[np.ndarray]):
def __init__(self, dim: int = 1, ep_length: int = 100) -> None:
"""
Identity environment for testing purposes
:param dim: the size of the dimensions you want to learn
:param ep_length: the length of each episode in timesteps
"""
space = spaces.MultiDiscrete([dim, dim])
super().__init__(ep_length=ep_length, space=space)
class IdentityEnvMultiBinary(IdentityEnv[np.ndarray]):
def __init__(self, dim: int = 1, ep_length: int = 100) -> None:
"""
Identity environment for testing purposes
:param dim: the size of the dimensions you want to learn
:param ep_length: the length of each episode in timesteps
"""
space = spaces.MultiBinary(dim)
super().__init__(ep_length=ep_length, space=space)
class FakeImageEnv(gym.Env):
"""
Fake image environment for testing purposes, it mimics Atari games.
:param action_dim: Number of discrete actions
:param screen_height: Height of the image
:param screen_width: Width of the image
:param n_channels: Number of color channels
:param discrete: Create discrete action space instead of continuous
:param channel_first: Put channels on first axis instead of last
"""
def __init__(
self,
action_dim: int = 6,
screen_height: int = 84,
screen_width: int = 84,
n_channels: int = 1,
discrete: bool = True,
channel_first: bool = False,
) -> None:
self.observation_shape = (screen_height, screen_width, n_channels)
if channel_first:
self.observation_shape = (n_channels, screen_height, screen_width)
self.observation_space = spaces.Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)
if discrete:
self.action_space = spaces.Discrete(action_dim)
else:
self.action_space = spaces.Box(low=-1, high=1, shape=(5,), dtype=np.float32)
self.ep_length = 10
self.current_step = 0
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
if seed is not None:
super().reset(seed=seed)
self.current_step = 0
return self.observation_space.sample(), {}
def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
reward = 0.0
self.current_step += 1
terminated = False
truncated = self.current_step >= self.ep_length
return self.observation_space.sample(), reward, terminated, truncated, {}
def render(self, mode: str = "human") -> None:
pass
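# --- Illustrative usage sketch (added for clarity, not part of the original module) ---
# A random rollout in IdentityEnv defined above: the reward is 1.0 whenever the
# sampled action matches the current state; dim and ep_length are arbitrary choices.
if __name__ == "__main__":
    env = IdentityEnv(dim=3, ep_length=5)
    obs, _ = env.reset(seed=0)
    episode_return, truncated = 0.0, False
    while not truncated:
        obs, reward, terminated, truncated, _ = env.step(env.action_space.sample())
        episode_return += reward
    print("episode return:", episode_return)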

View File

@ -0,0 +1,183 @@
from typing import Dict, List, Optional, Tuple, Union
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.type_aliases import GymStepReturn
class SimpleMultiObsEnv(gym.Env):
"""
Base class for GridWorld-based MultiObs environments: a 4x4 grid world.
.. code-block:: text
____________
| 0 1 2 3|
| 4|¯5¯¯6¯| 7|
| 8|_9_10_|11|
|12 13 14 15|
¯¯¯¯¯¯¯¯¯¯¯¯¯¯
start is 0
states 5, 6, 9, and 10 are blocked
goal is 15
actions are = [left, down, right, up]
simple linear state env of 15 states but encoded with a vector and an image observation:
each column is represented by a random vector and each row is
represented by a random image, both sampled once at creation time.
:param num_col: Number of columns in the grid
:param num_row: Number of rows in the grid
:param random_start: If true, agent starts in random position
:param channel_last: If true, the image will be channel last, else it will be channel first
"""
def __init__(
self,
num_col: int = 4,
num_row: int = 4,
random_start: bool = True,
discrete_actions: bool = True,
channel_last: bool = True,
):
super().__init__()
self.vector_size = 5
if channel_last:
self.img_size = [64, 64, 1]
else:
self.img_size = [1, 64, 64]
self.random_start = random_start
self.discrete_actions = discrete_actions
if discrete_actions:
self.action_space = spaces.Discrete(4)
else:
self.action_space = spaces.Box(0, 1, (4,))
self.observation_space = spaces.Dict(
spaces={
"vec": spaces.Box(0, 1, (self.vector_size,), dtype=np.float64),
"img": spaces.Box(0, 255, self.img_size, dtype=np.uint8),
}
)
self.count = 0
# Timeout
self.max_count = 100
self.log = ""
self.state = 0
self.action2str = ["left", "down", "right", "up"]
self.init_possible_transitions()
self.num_col = num_col
self.state_mapping: List[Dict[str, np.ndarray]] = []
self.init_state_mapping(num_col, num_row)
self.max_state = len(self.state_mapping) - 1
def init_state_mapping(self, num_col: int, num_row: int) -> None:
"""
Initializes the state_mapping array which holds the observation values for each state
:param num_col: Number of columns.
:param num_row: Number of rows.
"""
# Each column is represented by a random vector
col_vecs = np.random.random((num_col, self.vector_size))
# Each row is represented by a random image
row_imgs = np.random.randint(0, 255, (num_row, 64, 64), dtype=np.uint8)
for i in range(num_col):
for j in range(num_row):
self.state_mapping.append({"vec": col_vecs[i], "img": row_imgs[j].reshape(self.img_size)})
def get_state_mapping(self) -> Dict[str, np.ndarray]:
"""
Uses the state to get the observation mapping.
:return: observation dict {'vec': ..., 'img': ...}
"""
return self.state_mapping[self.state]
def init_possible_transitions(self) -> None:
"""
Initializes the transitions of the environment
The environment exploits the cardinal directions of the grid by noting that
they correspond to simple addition and subtraction from the cell id within the grid
- up => means moving up a row => means subtracting the length of a column
- down => means moving down a row => means adding the length of a column
- left => means moving left by one => means subtracting 1
- right => means moving right by one => means adding 1
Thus one only needs to specify in which states each action is possible
in order to define the transitions of the environment
"""
self.left_possible = [1, 2, 3, 13, 14, 15]
self.down_possible = [0, 4, 8, 3, 7, 11]
self.right_possible = [0, 1, 2, 12, 13, 14]
self.up_possible = [4, 8, 12, 7, 11, 15]
def step(self, action: Union[int, np.ndarray]) -> GymStepReturn:
"""
Run one timestep of the environment's dynamics. When end of
episode is reached, you are responsible for calling `reset()`
to reset this environment's state.
Accepts an action and returns a tuple (observation, reward, terminated, truncated, info).
:param action:
:return: tuple (observation, reward, terminated, truncated, info).
"""
if not self.discrete_actions:
action = np.argmax(action) # type: ignore[assignment]
self.count += 1
prev_state = self.state
reward = -0.1
# define state transition
if self.state in self.left_possible and action == 0: # left
self.state -= 1
elif self.state in self.down_possible and action == 1: # down
self.state += self.num_col
elif self.state in self.right_possible and action == 2: # right
self.state += 1
elif self.state in self.up_possible and action == 3: # up
self.state -= self.num_col
got_to_end = self.state == self.max_state
reward = 1.0 if got_to_end else reward
truncated = self.count > self.max_count
terminated = got_to_end
self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}"
return self.get_state_mapping(), reward, terminated, truncated, {"got_to_end": got_to_end}
def render(self, mode: str = "human") -> None:
"""
Prints the log of the environment.
:param mode:
"""
print(self.log)
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[Dict[str, np.ndarray], Dict]:
"""
Resets the environment state and step count and returns reset observation.
:param seed:
:return: observation dict {'vec': ..., 'img': ...}
"""
if seed is not None:
super().reset(seed=seed)
self.count = 0
if not self.random_start:
self.state = 0
else:
self.state = np.random.randint(0, self.max_state)
return self.state_mapping[self.state], {}
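# --- Illustrative usage sketch (added for clarity, not part of the original module) ---
# Shows the Dict observation (vector + image) returned by SimpleMultiObsEnv
# defined above; the random agent and the number of steps are arbitrary choices.
if __name__ == "__main__":
    env = SimpleMultiObsEnv(random_start=False)
    obs, _ = env.reset(seed=0)
    print("vec obs shape:", obs["vec"].shape, "img obs shape:", obs["img"].shape)
    for _ in range(3):
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        env.render()  # prints the transition log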

View File

@ -0,0 +1,139 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import gymnasium as gym
import numpy as np
from stable_baselines3.common import type_aliases
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
def evaluate_policy(
model: "type_aliases.PolicyPredictor",
env: Union[gym.Env, VecEnv],
n_eval_episodes: int = 10,
deterministic: bool = True,
render: bool = False,
callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
reward_threshold: Optional[float] = None,
return_episode_rewards: bool = False,
warn: bool = True,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
"""
Runs policy for ``n_eval_episodes`` episodes and returns average reward.
If a vector env is passed in, this divides the episodes to evaluate among the
different elements of the vector env. This static division of work is done to
remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more
details and discussion.
.. note::
If the environment has not been wrapped with a ``Monitor`` wrapper, rewards and
episode lengths are counted as they appear in ``env.step`` calls. If
the environment contains wrappers that modify rewards or episode lengths
(e.g. reward scaling, early episode reset), these will affect the evaluation
results as well. You can avoid this by wrapping environment with ``Monitor``
wrapper before anything else.
:param model: The RL agent you want to evaluate. This can be any object
that implements a `predict` method, such as an RL algorithm (``BaseAlgorithm``)
or policy (``BasePolicy``).
:param env: The gym environment or ``VecEnv`` environment.
:param n_eval_episodes: Number of episodes to evaluate the agent
:param deterministic: Whether to use deterministic or stochastic actions
:param render: Whether to render the environment or not
:param callback: callback function to do additional checks,
called after each step. Gets locals() and globals() passed as parameters.
:param reward_threshold: Minimum expected reward per episode,
this will raise an error if the performance is not met
:param return_episode_rewards: If True, a list of rewards and episode lengths
per episode will be returned instead of the mean.
:param warn: If True (default), warns user about lack of a Monitor wrapper in the
evaluation environment.
:return: Mean reward per episode, std of reward per episode.
Returns ([float], [int]) when ``return_episode_rewards`` is True, first
list containing per-episode rewards and second containing per-episode lengths
(in number of steps).
"""
is_monitor_wrapped = False
# Avoid circular import
from stable_baselines3.common.monitor import Monitor
if not isinstance(env, VecEnv):
env = DummyVecEnv([lambda: env]) # type: ignore[list-item, return-value]
is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]
if not is_monitor_wrapped and warn:
warnings.warn(
"Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
"This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. "
"Consider wrapping environment first with ``Monitor`` wrapper.",
UserWarning,
)
n_envs = env.num_envs
episode_rewards = []
episode_lengths = []
episode_counts = np.zeros(n_envs, dtype="int")
# Divides episodes among different sub environments in the vector as evenly as possible
episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int")
current_rewards = np.zeros(n_envs)
current_lengths = np.zeros(n_envs, dtype="int")
observations = env.reset()
states = None
episode_starts = np.ones((env.num_envs,), dtype=bool)
while (episode_counts < episode_count_targets).any():
actions, states = model.predict(
observations, # type: ignore[arg-type]
state=states,
episode_start=episode_starts,
deterministic=deterministic,
)
new_observations, rewards, dones, infos = env.step(actions)
current_rewards += rewards
current_lengths += 1
for i in range(n_envs):
if episode_counts[i] < episode_count_targets[i]:
# unpack values so that the callback can access the local variables
reward = rewards[i]
done = dones[i]
info = infos[i]
episode_starts[i] = done
if callback is not None:
callback(locals(), globals())
if dones[i]:
if is_monitor_wrapped:
# Atari wrapper can send a "done" signal when
# the agent loses a life, but it does not correspond
# to the true end of episode
if "episode" in info.keys():
# Do not trust "done" with episode endings.
# Monitor wrapper includes "episode" key in info if environment
# has been wrapped with it. Use those rewards instead.
episode_rewards.append(info["episode"]["r"])
episode_lengths.append(info["episode"]["l"])
# Only increment at the real end of an episode
episode_counts[i] += 1
else:
episode_rewards.append(current_rewards[i])
episode_lengths.append(current_lengths[i])
episode_counts[i] += 1
current_rewards[i] = 0
current_lengths[i] = 0
observations = new_observations
if render:
env.render()
mean_reward = np.mean(episode_rewards)
std_reward = np.std(episode_rewards)
if reward_threshold is not None:
assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"
if return_episode_rewards:
return episode_rewards, episode_lengths
return mean_reward, std_reward
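A minimal usage sketch of ``evaluate_policy`` (the environment id, training budget and episode count are placeholder choices):

import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

# Wrap with Monitor so episode rewards/lengths are trusted (see the note above)
eval_env = Monitor(gym.make("CartPole-v1"))
model = PPO("MlpPolicy", eval_env, verbose=0).learn(total_timesteps=2_000)
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")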

View File

@ -0,0 +1,694 @@
import datetime
import json
import os
import sys
import tempfile
import warnings
from collections import defaultdict
from io import TextIOBase
from typing import Any, Dict, List, Mapping, Optional, Sequence, TextIO, Tuple, Union
import matplotlib.figure
import numpy as np
import pandas
import torch as th
try:
from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard.summary import hparams
except ImportError:
SummaryWriter = None # type: ignore[misc, assignment]
try:
from tqdm import tqdm
except ImportError:
tqdm = None
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40
DISABLED = 50
class Video:
"""
Video data class storing the video frames and the frames per second
:param frames: frames to create the video from
:param fps: frames per second
"""
def __init__(self, frames: th.Tensor, fps: float):
self.frames = frames
self.fps = fps
class Figure:
"""
Figure data class storing a matplotlib figure and whether to close the figure after logging it
:param figure: figure to log
:param close: if true, close the figure after logging it
"""
def __init__(self, figure: matplotlib.figure.Figure, close: bool):
self.figure = figure
self.close = close
class Image:
"""
Image data class storing an image and data format
:param image: image to log
:param dataformats: Image data format specification of the form NCHW, NHWC, CHW, HWC, HW, WH, etc.
More info in add_image method doc at https://pytorch.org/docs/stable/tensorboard.html
Gym envs normally use 'HWC' (channel last)
"""
def __init__(self, image: Union[th.Tensor, np.ndarray, str], dataformats: str):
self.image = image
self.dataformats = dataformats
class HParam:
"""
Hyperparameter data class storing hyperparameters and metrics in dictionaries
:param hparam_dict: key-value pairs of hyperparameters to log
:param metric_dict: key-value pairs of metrics to log
A non-empty metrics dict is required to display hyperparameters in the corresponding Tensorboard section.
"""
def __init__(self, hparam_dict: Mapping[str, Union[bool, str, float, None]], metric_dict: Mapping[str, float]):
self.hparam_dict = hparam_dict
if not metric_dict:
raise Exception("`metric_dict` must not be empty to display hyperparameters to the HPARAMS tensorboard tab.")
self.metric_dict = metric_dict
class FormatUnsupportedError(NotImplementedError):
"""
Custom error to display informative message when
a value is not supported by some formats.
:param unsupported_formats: A sequence of unsupported formats,
for instance ``["stdout"]``.
:param value_description: Description of the value that cannot be logged by this format.
"""
def __init__(self, unsupported_formats: Sequence[str], value_description: str):
if len(unsupported_formats) > 1:
format_str = f"formats {', '.join(unsupported_formats)} are"
else:
format_str = f"format {unsupported_formats[0]} is"
super().__init__(
f"The {format_str} not supported for the {value_description} value logged.\n"
f"You can exclude formats via the `exclude` parameter of the logger's `record` function."
)
class KVWriter:
"""
Key Value writer
"""
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
"""
Write a dictionary to file
:param key_values:
:param key_excluded:
:param step:
"""
raise NotImplementedError
def close(self) -> None:
"""
Close owned resources
"""
raise NotImplementedError
class SeqWriter:
"""
sequence writer
"""
def write_sequence(self, sequence: List[str]) -> None:
"""
Write a sequence of strings to the output file
:param sequence:
"""
raise NotImplementedError
class HumanOutputFormat(KVWriter, SeqWriter):
"""A human-readable output format producing ASCII tables of key-value pairs.
Set attribute ``max_length`` to change the maximum length of keys and values
to write to output (or specify it when calling ``__init__``).
:param filename_or_file: the file to write the log to
:param max_length: the maximum length of keys and values to write to output.
Outputs longer than this will be truncated. An error will be raised
if multiple keys are truncated to the same value. The maximum output
width will be ``2*max_length + 7``. The default of 36 produces output
no longer than 79 characters wide.
"""
def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36):
self.max_length = max_length
if isinstance(filename_or_file, str):
self.file = open(filename_or_file, "w")
self.own_file = True
elif isinstance(filename_or_file, TextIOBase) or hasattr(filename_or_file, "write"):
# Note: in theory `TextIOBase` check should be sufficient,
# in practice, libraries don't always inherit from it, see GH#1598
self.file = filename_or_file # type: ignore[assignment]
self.own_file = False
else:
raise ValueError(f"Expected file or str, got {filename_or_file}")
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
# Create strings for printing
key2str = {}
tag = ""
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
if excluded is not None and ("stdout" in excluded or "log" in excluded):
continue
elif isinstance(value, Video):
raise FormatUnsupportedError(["stdout", "log"], "video")
elif isinstance(value, Figure):
raise FormatUnsupportedError(["stdout", "log"], "figure")
elif isinstance(value, Image):
raise FormatUnsupportedError(["stdout", "log"], "image")
elif isinstance(value, HParam):
raise FormatUnsupportedError(["stdout", "log"], "hparam")
elif isinstance(value, float):
# Align left
value_str = f"{value:<8.3g}"
else:
value_str = str(value)
if key.find("/") > 0: # Find tag and add it to the dict
tag = key[: key.find("/") + 1]
key2str[(tag, self._truncate(tag))] = ""
# Remove tag from key and indent the key
if len(tag) > 0 and tag in key:
key = f"{'':3}{key[len(tag) :]}"
truncated_key = self._truncate(key)
if (tag, truncated_key) in key2str:
raise ValueError(
f"Key '{key}' truncated to '{truncated_key}' that already exists. Consider increasing `max_length`."
)
key2str[(tag, truncated_key)] = self._truncate(value_str)
# Find max widths
if len(key2str) == 0:
warnings.warn("Tried to write empty key-value dict")
return
else:
tagless_keys = map(lambda x: x[1], key2str.keys())
key_width = max(map(len, tagless_keys))
val_width = max(map(len, key2str.values()))
# Write out the data
dashes = "-" * (key_width + val_width + 7)
lines = [dashes]
for (_, key), value in key2str.items():
key_space = " " * (key_width - len(key))
val_space = " " * (val_width - len(value))
lines.append(f"| {key}{key_space} | {value}{val_space} |")
lines.append(dashes)
if tqdm is not None and hasattr(self.file, "name") and self.file.name == "<stdout>":
# Do not mess up with progress bar
tqdm.write("\n".join(lines) + "\n", file=sys.stdout, end="")
else:
self.file.write("\n".join(lines) + "\n")
# Flush the output to the file
self.file.flush()
def _truncate(self, string: str) -> str:
if len(string) > self.max_length:
string = string[: self.max_length - 3] + "..."
return string
def write_sequence(self, sequence: List[str]) -> None:
for i, elem in enumerate(sequence):
self.file.write(elem)
if i < len(sequence) - 1: # add space unless this is the last one
self.file.write(" ")
self.file.write("\n")
self.file.flush()
def close(self) -> None:
"""
closes the file
"""
if self.own_file:
self.file.close()
def filter_excluded_keys(key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], _format: str) -> Dict[str, Any]:
"""
Filters the keys specified by ``key_exclude`` for the specified format
:param key_values: log dictionary to be filtered
:param key_excluded: keys to be excluded per format
:param _format: format for which this filter is run
:return: dict without the excluded keys
"""
def is_excluded(key: str) -> bool:
return key in key_excluded and key_excluded[key] is not None and _format in key_excluded[key]
return {key: value for key, value in key_values.items() if not is_excluded(key)}
class JSONOutputFormat(KVWriter):
"""
Log to a file, in the JSON format
:param filename: the file to write the log to
"""
def __init__(self, filename: str):
self.file = open(filename, "w")
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
def cast_to_json_serializable(value: Any):
if isinstance(value, Video):
raise FormatUnsupportedError(["json"], "video")
if isinstance(value, Figure):
raise FormatUnsupportedError(["json"], "figure")
if isinstance(value, Image):
raise FormatUnsupportedError(["json"], "image")
if isinstance(value, HParam):
raise FormatUnsupportedError(["json"], "hparam")
if hasattr(value, "dtype"):
if value.shape == () or len(value) == 1:
# if value is a dimensionless numpy array or of length 1, serialize as a float
return float(value.item())
else:
# otherwise, a value is a numpy array, serialize as a list or nested lists
return value.tolist()
return value
key_values = {
key: cast_to_json_serializable(value)
for key, value in filter_excluded_keys(key_values, key_excluded, "json").items()
}
self.file.write(json.dumps(key_values) + "\n")
self.file.flush()
def close(self) -> None:
"""
closes the file
"""
self.file.close()
class CSVOutputFormat(KVWriter):
"""
Log to a file, in a CSV format
:param filename: the file to write the log to
"""
def __init__(self, filename: str):
self.file = open(filename, "w+t")
self.keys: List[str] = []
self.separator = ","
self.quotechar = '"'
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
# Add our current row to the history
key_values = filter_excluded_keys(key_values, key_excluded, "csv")
extra_keys = key_values.keys() - self.keys
if extra_keys:
self.keys.extend(extra_keys)
self.file.seek(0)
lines = self.file.readlines()
self.file.seek(0)
for i, key in enumerate(self.keys):
if i > 0:
self.file.write(",")
self.file.write(key)
self.file.write("\n")
for line in lines[1:]:
self.file.write(line[:-1])
self.file.write(self.separator * len(extra_keys))
self.file.write("\n")
for i, key in enumerate(self.keys):
if i > 0:
self.file.write(",")
value = key_values.get(key)
if isinstance(value, Video):
raise FormatUnsupportedError(["csv"], "video")
elif isinstance(value, Figure):
raise FormatUnsupportedError(["csv"], "figure")
elif isinstance(value, Image):
raise FormatUnsupportedError(["csv"], "image")
elif isinstance(value, HParam):
raise FormatUnsupportedError(["csv"], "hparam")
elif isinstance(value, str):
# escape quotechars by prepending them with another quotechar
value = value.replace(self.quotechar, self.quotechar + self.quotechar)
# additionally wrap text with quotechars so that any delimiters in the text are ignored by csv readers
self.file.write(self.quotechar + value + self.quotechar)
elif value is not None:
self.file.write(str(value))
self.file.write("\n")
self.file.flush()
def close(self) -> None:
"""
closes the file
"""
self.file.close()
class TensorBoardOutputFormat(KVWriter):
"""
Dumps key/value pairs into TensorBoard's numeric format.
:param folder: the folder to write the log to
"""
def __init__(self, folder: str):
assert SummaryWriter is not None, "tensorboard is not installed, you can use `pip install tensorboard` to do so"
self.writer = SummaryWriter(log_dir=folder)
self._is_closed = False
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
assert not self._is_closed, "The SummaryWriter was closed, please re-create one."
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
if excluded is not None and "tensorboard" in excluded:
continue
if isinstance(value, np.ScalarType):
if isinstance(value, str):
# str is considered a np.ScalarType
self.writer.add_text(key, value, step)
else:
self.writer.add_scalar(key, value, step)
if isinstance(value, th.Tensor):
self.writer.add_histogram(key, value, step)
if isinstance(value, Video):
self.writer.add_video(key, value.frames, step, value.fps)
if isinstance(value, Figure):
self.writer.add_figure(key, value.figure, step, close=value.close)
if isinstance(value, Image):
self.writer.add_image(key, value.image, step, dataformats=value.dataformats)
if isinstance(value, HParam):
# we don't use `self.writer.add_hparams` to have control over the log_dir
experiment, session_start_info, session_end_info = hparams(value.hparam_dict, metric_dict=value.metric_dict)
self.writer.file_writer.add_summary(experiment)
self.writer.file_writer.add_summary(session_start_info)
self.writer.file_writer.add_summary(session_end_info)
# Flush the output to the file
self.writer.flush()
def close(self) -> None:
"""
closes the file
"""
if self.writer:
self.writer.close()
self._is_closed = True
def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWriter:
"""
return a logger for the requested format
:param _format: the requested format to log to ('stdout', 'log', 'json' or 'csv' or 'tensorboard')
:param log_dir: the logging directory
:param log_suffix: the suffix for the log file
:return: the logger
"""
os.makedirs(log_dir, exist_ok=True)
if _format == "stdout":
return HumanOutputFormat(sys.stdout)
elif _format == "log":
return HumanOutputFormat(os.path.join(log_dir, f"log{log_suffix}.txt"))
elif _format == "json":
return JSONOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.json"))
elif _format == "csv":
return CSVOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.csv"))
elif _format == "tensorboard":
return TensorBoardOutputFormat(log_dir)
else:
raise ValueError(f"Unknown format specified: {_format}")
# ================================================================
# Backend
# ================================================================
class Logger:
"""
The logger class.
:param folder: the logging location
:param output_formats: the list of output formats
"""
def __init__(self, folder: Optional[str], output_formats: List[KVWriter]):
self.name_to_value: Dict[str, float] = defaultdict(float) # values this iteration
self.name_to_count: Dict[str, int] = defaultdict(int)
self.name_to_excluded: Dict[str, Tuple[str, ...]] = {}
self.level = INFO
self.dir = folder
self.output_formats = output_formats
@staticmethod
def to_tuple(string_or_tuple: Optional[Union[str, Tuple[str, ...]]]) -> Tuple[str, ...]:
"""
Helper function to convert str to tuple of str.
"""
if string_or_tuple is None:
return ("",)
if isinstance(string_or_tuple, tuple):
return string_or_tuple
return (string_or_tuple,)
def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
"""
Log a value of some diagnostic
Call this once for each diagnostic quantity, each iteration
If called many times, last value will be used.
:param key: save to log this key
:param value: save to log this value
:param exclude: outputs to be excluded
"""
self.name_to_value[key] = value
self.name_to_excluded[key] = self.to_tuple(exclude)
def record_mean(self, key: str, value: Optional[float], exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
"""
The same as record(), but if called many times, values averaged.
:param key: save to log this key
:param value: save to log this value
:param exclude: outputs to be excluded
"""
if value is None:
return
old_val, count = self.name_to_value[key], self.name_to_count[key]
self.name_to_value[key] = old_val * count / (count + 1) + value / (count + 1)
self.name_to_count[key] = count + 1
self.name_to_excluded[key] = self.to_tuple(exclude)
def dump(self, step: int = 0) -> None:
"""
Write all of the diagnostics from the current iteration
"""
if self.level == DISABLED:
return
for _format in self.output_formats:
if isinstance(_format, KVWriter):
_format.write(self.name_to_value, self.name_to_excluded, step)
self.name_to_value.clear()
self.name_to_count.clear()
self.name_to_excluded.clear()
def log(self, *args, level: int = INFO) -> None:
"""
Write the sequence of args, with no separators,
to the console and output files (if you've configured an output file).
level: int. (see logger.py docs) If the global logger level is higher than
the level argument here, don't print to stdout.
:param args: log the arguments
:param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
"""
if self.level <= level:
self._do_log(args)
def debug(self, *args) -> None:
"""
Write the sequence of args, with no separators,
to the console and output files (if you've configured an output file).
Using the DEBUG level.
:param args: log the arguments
"""
self.log(*args, level=DEBUG)
def info(self, *args) -> None:
"""
Write the sequence of args, with no separators,
to the console and output files (if you've configured an output file).
Using the INFO level.
:param args: log the arguments
"""
self.log(*args, level=INFO)
def warn(self, *args) -> None:
"""
Write the sequence of args, with no separators,
to the console and output files (if you've configured an output file).
Using the WARN level.
:param args: log the arguments
"""
self.log(*args, level=WARN)
def error(self, *args) -> None:
"""
Write the sequence of args, with no separators,
to the console and output files (if you've configured an output file).
Using the ERROR level.
:param args: log the arguments
"""
self.log(*args, level=ERROR)
# Configuration
# ----------------------------------------
def set_level(self, level: int) -> None:
"""
Set logging threshold on current logger.
:param level: the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
"""
self.level = level
def get_dir(self) -> Optional[str]:
"""
Get directory that log files are being written to.
will be None if there is no output directory (i.e., if the logger was created without a folder)
:return: the logging directory
"""
return self.dir
def close(self) -> None:
"""
closes the file
"""
for _format in self.output_formats:
_format.close()
# Misc
# ----------------------------------------
def _do_log(self, args: Tuple[Any, ...]) -> None:
"""
log to the requested format outputs
:param args: the arguments to log
"""
for _format in self.output_formats:
if isinstance(_format, SeqWriter):
_format.write_sequence(list(map(str, args)))
def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] = None) -> Logger:
"""
Configure the current logger.
:param folder: the save location
(if None, $SB3_LOGDIR, if still None, tempdir/SB3-[date & time])
:param format_strings: the output logging format
(if None, $SB3_LOG_FORMAT, if still None, ['stdout', 'log', 'csv'])
:return: The logger object.
"""
if folder is None:
folder = os.getenv("SB3_LOGDIR")
if folder is None:
folder = os.path.join(tempfile.gettempdir(), datetime.datetime.now().strftime("SB3-%Y-%m-%d-%H-%M-%S-%f"))
assert isinstance(folder, str)
os.makedirs(folder, exist_ok=True)
log_suffix = ""
if format_strings is None:
format_strings = os.getenv("SB3_LOG_FORMAT", "stdout,log,csv").split(",")
format_strings = list(filter(None, format_strings))
output_formats = [make_output_format(f, folder, log_suffix) for f in format_strings]
logger = Logger(folder=folder, output_formats=output_formats)
# Only print when some files will be saved
if len(format_strings) > 0 and format_strings != ["stdout"]:
logger.log(f"Logging to {folder}")
return logger
# ================================================================
# Readers
# ================================================================
def read_json(filename: str) -> pandas.DataFrame:
"""
read a json file using pandas
:param filename: the file path to read
:return: the data in the json
"""
data = []
with open(filename) as file_handler:
for line in file_handler:
data.append(json.loads(line))
return pandas.DataFrame(data)
def read_csv(filename: str) -> pandas.DataFrame:
"""
read a csv file using pandas
:param filename: the file path to read
:return: the data in the csv
"""
return pandas.read_csv(filename, index_col=None, comment="#")
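A minimal sketch of the standalone logger API defined above (the folder path and the recorded keys are placeholders):

from stable_baselines3.common.logger import configure

# Write to stdout, a text log and a CSV file in a temporary folder
new_logger = configure("/tmp/sb3_log/", ["stdout", "log", "csv"])
new_logger.record("train/example_value", 1.0)
new_logger.record_mean("train/example_mean", 2.0)   # averaged over repeated calls
new_logger.dump(step=0)                             # flush everything to the configured outputs
new_logger.close()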

View File

@ -0,0 +1,254 @@
__all__ = ["Monitor", "ResultsWriter", "get_monitor_files", "load_results"]
import csv
import json
import os
import time
from glob import glob
from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union
import gymnasium as gym
import pandas
from gymnasium.core import ActType, ObsType
class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
"""
A monitor wrapper for Gym environments; it is used to record the episode reward, length, time and other data.
:param env: The environment
:param filename: the location to save a log file, can be None for no log
:param allow_early_resets: allows the reset of the environment before it is done
:param reset_keywords: extra keywords for the reset call,
if extra parameters are needed at reset
:param info_keywords: extra information to log, from the information return of env.step()
:param override_existing: if ``True`` (default), overwrite the file if ``filename`` already exists,
otherwise append to it
"""
EXT = "monitor.csv"
def __init__(
self,
env: gym.Env,
filename: Optional[str] = None,
allow_early_resets: bool = True,
reset_keywords: Tuple[str, ...] = (),
info_keywords: Tuple[str, ...] = (),
override_existing: bool = True,
):
super().__init__(env=env)
self.t_start = time.time()
self.results_writer = None
if filename is not None:
env_id = env.spec.id if env.spec is not None else None
self.results_writer = ResultsWriter(
filename,
header={"t_start": self.t_start, "env_id": str(env_id)},
extra_keys=reset_keywords + info_keywords,
override_existing=override_existing,
)
self.reset_keywords = reset_keywords
self.info_keywords = info_keywords
self.allow_early_resets = allow_early_resets
self.rewards: List[float] = []
self.needs_reset = True
self.episode_returns: List[float] = []
self.episode_lengths: List[int] = []
self.episode_times: List[float] = []
self.total_steps = 0
# extra info about the current episode, that was passed in during reset()
self.current_reset_info: Dict[str, Any] = {}
def reset(self, **kwargs) -> Tuple[ObsType, Dict[str, Any]]:
"""
Calls the Gym environment reset. Can only be called once the episode is over, or if ``allow_early_resets`` is True.
:param kwargs: Extra keywords saved for the next episode, only if defined by ``reset_keywords``
:return: the first observation of the environment
"""
if not self.allow_early_resets and not self.needs_reset:
raise RuntimeError(
"Tried to reset an environment before done. If you want to allow early resets, "
"wrap your env with Monitor(env, path, allow_early_resets=True)"
)
self.rewards = []
self.needs_reset = False
for key in self.reset_keywords:
value = kwargs.get(key)
if value is None:
raise ValueError(f"Expected you to pass keyword argument {key} into reset")
self.current_reset_info[key] = value
return self.env.reset(**kwargs)
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""
Step the environment with the given action
:param action: the action
:return: observation, reward, terminated, truncated, information
"""
if self.needs_reset:
raise RuntimeError("Tried to step environment that needs reset")
observation, reward, terminated, truncated, info = self.env.step(action)
self.rewards.append(float(reward))
if terminated or truncated:
self.needs_reset = True
ep_rew = sum(self.rewards)
ep_len = len(self.rewards)
ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(time.time() - self.t_start, 6)}
for key in self.info_keywords:
ep_info[key] = info[key]
self.episode_returns.append(ep_rew)
self.episode_lengths.append(ep_len)
self.episode_times.append(time.time() - self.t_start)
ep_info.update(self.current_reset_info)
if self.results_writer:
self.results_writer.write_row(ep_info)
info["episode"] = ep_info
self.total_steps += 1
return observation, reward, terminated, truncated, info
def close(self) -> None:
"""
Closes the environment
"""
super().close()
if self.results_writer is not None:
self.results_writer.close()
def get_total_steps(self) -> int:
"""
Returns the total number of timesteps
:return:
"""
return self.total_steps
def get_episode_rewards(self) -> List[float]:
"""
Returns the rewards of all the episodes
:return:
"""
return self.episode_returns
def get_episode_lengths(self) -> List[int]:
"""
Returns the number of timesteps of all the episodes
:return:
"""
return self.episode_lengths
def get_episode_times(self) -> List[float]:
"""
Returns the runtime in seconds of all the episodes
:return:
"""
return self.episode_times
class LoadMonitorResultsError(Exception):
"""
Raised when loading the monitor log fails.
"""
pass
class ResultsWriter:
"""
A result writer that saves the data from the `Monitor` class
:param filename: the location to save a log file. When it does not end in
the string ``"monitor.csv"``, this suffix will be appended to it
:param header: the header dictionary object of the saved csv
:param extra_keys: the extra information to log, typically is composed of
``reset_keywords`` and ``info_keywords``
:param override_existing: if ``True`` (default), overwrite the file if ``filename`` already exists,
otherwise append to it
"""
def __init__(
self,
filename: str = "",
header: Optional[Dict[str, Union[float, str]]] = None,
extra_keys: Tuple[str, ...] = (),
override_existing: bool = True,
):
if header is None:
header = {}
if not filename.endswith(Monitor.EXT):
if os.path.isdir(filename):
filename = os.path.join(filename, Monitor.EXT)
else:
filename = filename + "." + Monitor.EXT
filename = os.path.realpath(filename)
# Create (if any) missing filename directories
os.makedirs(os.path.dirname(filename), exist_ok=True)
# Append mode when not overriding the existing file
mode = "w" if override_existing else "a"
# Prevent newline issue on Windows, see GH issue #692
self.file_handler = open(filename, f"{mode}t", newline="\n")
self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t", *extra_keys))
if override_existing:
self.file_handler.write(f"#{json.dumps(header)}\n")
self.logger.writeheader()
self.file_handler.flush()
def write_row(self, epinfo: Dict[str, float]) -> None:
"""
Write row of monitor data to csv log file.
:param epinfo: the information on episodic return, length, and time
"""
if self.logger:
self.logger.writerow(epinfo)
self.file_handler.flush()
def close(self) -> None:
"""
Close the file handler
"""
self.file_handler.close()
def get_monitor_files(path: str) -> List[str]:
"""
get all the monitor files in the given path
:param path: the logging folder
:return: the log files
"""
return glob(os.path.join(path, "*" + Monitor.EXT))
def load_results(path: str) -> pandas.DataFrame:
"""
Load all Monitor logs from a given directory path matching ``*monitor.csv``
:param path: the directory path containing the log file(s)
:return: the logged data
"""
monitor_files = get_monitor_files(path)
if len(monitor_files) == 0:
raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}")
data_frames, headers = [], []
for file_name in monitor_files:
with open(file_name) as file_handler:
first_line = file_handler.readline()
assert first_line[0] == "#"
header = json.loads(first_line[1:])
data_frame = pandas.read_csv(file_handler, index_col=None)
headers.append(header)
data_frame["t"] += header["t_start"]
data_frames.append(data_frame)
data_frame = pandas.concat(data_frames)
data_frame.sort_values("t", inplace=True)
data_frame.reset_index(inplace=True)
data_frame["t"] -= min(header["t_start"] for header in headers)
return data_frame
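A minimal sketch tying ``Monitor`` and ``load_results`` together (the log directory and training budget are placeholder choices):

import os

import gymnasium as gym

from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor, load_results

log_dir = "/tmp/sb3_monitor/"
os.makedirs(log_dir, exist_ok=True)
env = Monitor(gym.make("CartPole-v1"), filename=os.path.join(log_dir, "run"))  # -> run.monitor.csv
A2C("MlpPolicy", env).learn(total_timesteps=1_000)
df = load_results(log_dir)  # DataFrame with columns "r" (return), "l" (length), "t" (time)
print(df.tail())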

View File

@ -0,0 +1,173 @@
import copy
from abc import ABC, abstractmethod
from typing import Iterable, List, Optional
import numpy as np
from numpy.typing import DTypeLike
class ActionNoise(ABC):
"""
The action noise base class
"""
def __init__(self) -> None:
super().__init__()
def reset(self) -> None:
"""
Call end of episode reset for the noise
"""
pass
@abstractmethod
def __call__(self) -> np.ndarray:
raise NotImplementedError()
class NormalActionNoise(ActionNoise):
"""
A Gaussian action noise.
:param mean: Mean value of the noise
:param sigma: Scale of the noise (std here)
:param dtype: Type of the output noise
"""
def __init__(self, mean: np.ndarray, sigma: np.ndarray, dtype: DTypeLike = np.float32) -> None:
self._mu = mean
self._sigma = sigma
self._dtype = dtype
super().__init__()
def __call__(self) -> np.ndarray:
return np.random.normal(self._mu, self._sigma).astype(self._dtype)
def __repr__(self) -> str:
return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"
class OrnsteinUhlenbeckActionNoise(ActionNoise):
"""
An Ornstein-Uhlenbeck action noise; it is designed to approximate Brownian motion with friction.
Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
:param mean: Mean of the noise
:param sigma: Scale of the noise
:param theta: Rate of mean reversion
:param dt: Timestep for the noise
:param initial_noise: Initial value for the noise output, (if None: 0)
:param dtype: Type of the output noise
"""
def __init__(
self,
mean: np.ndarray,
sigma: np.ndarray,
theta: float = 0.15,
dt: float = 1e-2,
initial_noise: Optional[np.ndarray] = None,
dtype: DTypeLike = np.float32,
) -> None:
self._theta = theta
self._mu = mean
self._sigma = sigma
self._dt = dt
self._dtype = dtype
self.initial_noise = initial_noise
self.noise_prev = np.zeros_like(self._mu)
self.reset()
super().__init__()
def __call__(self) -> np.ndarray:
noise = (
self.noise_prev
+ self._theta * (self._mu - self.noise_prev) * self._dt
+ self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
)
self.noise_prev = noise
return noise.astype(self._dtype)
def reset(self) -> None:
"""
reset the Ornstein Uhlenbeck noise, to the initial position
"""
self.noise_prev = self.initial_noise if self.initial_noise is not None else np.zeros_like(self._mu)
def __repr__(self) -> str:
return f"OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})"
class VectorizedActionNoise(ActionNoise):
"""
A Vectorized action noise for parallel environments.
:param base_noise: Noise generator to use
:param n_envs: Number of parallel environments
"""
def __init__(self, base_noise: ActionNoise, n_envs: int) -> None:
try:
self.n_envs = int(n_envs)
assert self.n_envs > 0
except (TypeError, AssertionError) as e:
raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") from e
self.base_noise = base_noise
self.noises = [copy.deepcopy(self.base_noise) for _ in range(n_envs)]
def reset(self, indices: Optional[Iterable[int]] = None) -> None:
"""
Reset all the noise processes, or those listed in indices.
:param indices: The indices to reset. Default: None.
If the parameter is None, then all processes are reset to their initial position.
"""
if indices is None:
indices = range(len(self.noises))
for index in indices:
self.noises[index].reset()
def __repr__(self) -> str:
return f"VecNoise(BaseNoise={self.base_noise!r}), n_envs={len(self.noises)})"
def __call__(self) -> np.ndarray:
"""
Generate and stack the action noise from each noise object.
"""
noise = np.stack([noise() for noise in self.noises])
return noise
@property
def base_noise(self) -> ActionNoise:
return self._base_noise
@base_noise.setter
def base_noise(self, base_noise: ActionNoise) -> None:
if base_noise is None:
raise ValueError("Expected base_noise to be an instance of ActionNoise, not None", ActionNoise)
if not isinstance(base_noise, ActionNoise):
raise TypeError("Expected base_noise to be an instance of type ActionNoise", ActionNoise)
self._base_noise = base_noise
@property
def noises(self) -> List[ActionNoise]:
return self._noises
@noises.setter
def noises(self, noises: List[ActionNoise]) -> None:
noises = list(noises) # raises TypeError if not iterable
assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}."
different_types = [i for i, noise in enumerate(noises) if not isinstance(noise, type(self.base_noise))]
if len(different_types):
raise ValueError(
f"Noise instances at indices {different_types} don't match the type of base_noise", type(self.base_noise)
)
self._noises = noises
for noise in noises:
noise.reset()
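A minimal sketch of action noise in practice (the environment, sigma and training budget are placeholder choices); ``VectorizedActionNoise`` is created automatically by the off-policy algorithms when several envs are used:

import gymnasium as gym
import numpy as np

from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise

env = gym.make("Pendulum-v1")
n_actions = env.action_space.shape[-1]
# Gaussian exploration noise added to the deterministic DDPG actions
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=0)
model.learn(total_timesteps=1_000)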

View File

@ -0,0 +1,600 @@
import io
import pathlib
import sys
import time
import warnings
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="OffPolicyAlgorithm")
class OffPolicyAlgorithm(BaseAlgorithm):
"""
The base for Off-Policy algorithms (ex: SAC/TD3)
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param learning_rate: learning rate for the optimizer,
it can be a function of the current progress remaining (from 1 to 0)
:param buffer_size: size of the replay buffer
:param learning_starts: how many steps of the model to collect transitions for before learning starts
:param batch_size: Minibatch size for each gradient update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1)
:param gamma: the discount factor
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
like ``(5, "step")`` or ``(2, "episode")``.
:param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
Setting it to ``-1`` means doing as many gradient steps as steps done in the environment
during the rollout.
:param action_noise: the action noise type (None by default), this can help
for hard exploration problem. Cf common.noise for the different action noise type.
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
If ``None``, it will be automatically selected.
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param policy_kwargs: Additional arguments to be passed to the policy on creation
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param device: Device on which the code should run.
By default, it will try to use a Cuda compatible device and fallback to cpu
if it is not possible.
:param support_multi_env: Whether the algorithm supports training
with multiple environments (as in A2C)
:param monitor_wrapper: When creating an environment, whether to wrap it
or not in a Monitor wrapper.
:param seed: Seed for the pseudo random generators
:param use_sde: Whether to use State Dependent Exploration (SDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param use_sde_at_warmup: Whether to use gSDE instead of uniform sampling
during the warm up phase (before learning starts)
:param sde_support: Whether the model support gSDE or not
:param supported_action_spaces: The action spaces supported by the algorithm.
"""
actor: th.nn.Module
def __init__(
self,
policy: Union[str, Type[BasePolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule],
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 100,
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = (1, "step"),
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
verbose: int = 0,
device: Union[th.device, str] = "auto",
support_multi_env: bool = False,
monitor_wrapper: bool = True,
seed: Optional[int] = None,
use_sde: bool = False,
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
sde_support: bool = True,
supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
):
super().__init__(
policy=policy,
env=env,
learning_rate=learning_rate,
policy_kwargs=policy_kwargs,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
verbose=verbose,
device=device,
support_multi_env=support_multi_env,
monitor_wrapper=monitor_wrapper,
seed=seed,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
supported_action_spaces=supported_action_spaces,
)
self.buffer_size = buffer_size
self.batch_size = batch_size
self.learning_starts = learning_starts
self.tau = tau
self.gamma = gamma
self.gradient_steps = gradient_steps
self.action_noise = action_noise
self.optimize_memory_usage = optimize_memory_usage
self.replay_buffer: Optional[ReplayBuffer] = None
self.replay_buffer_class = replay_buffer_class
self.replay_buffer_kwargs = replay_buffer_kwargs or {}
self._episode_storage = None
# Save train freq parameter, will be converted later to TrainFreq object
self.train_freq = train_freq
# Update policy keyword arguments
if sde_support:
self.policy_kwargs["use_sde"] = self.use_sde
# For gSDE only
self.use_sde_at_warmup = use_sde_at_warmup
def _convert_train_freq(self) -> None:
"""
Convert `train_freq` parameter (int or tuple)
to a TrainFreq object.
"""
if not isinstance(self.train_freq, TrainFreq):
train_freq = self.train_freq
# The value of the train frequency will be checked later
if not isinstance(train_freq, tuple):
train_freq = (train_freq, "step")
try:
train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1])) # type: ignore[assignment]
except ValueError as e:
raise ValueError(
f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
) from e
if not isinstance(train_freq[0], int):
raise ValueError(f"The frequency of `train_freq` must be an integer and not {train_freq[0]}")
self.train_freq = TrainFreq(*train_freq) # type: ignore[assignment,arg-type]
def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)
if self.replay_buffer_class is None:
if isinstance(self.observation_space, spaces.Dict):
self.replay_buffer_class = DictReplayBuffer
else:
self.replay_buffer_class = ReplayBuffer
if self.replay_buffer is None:
# Make a local copy as we should not pickle
# the environment when using HerReplayBuffer
replay_buffer_kwargs = self.replay_buffer_kwargs.copy()
if issubclass(self.replay_buffer_class, HerReplayBuffer):
assert self.env is not None, "You must pass an environment when using `HerReplayBuffer`"
replay_buffer_kwargs["env"] = self.env
self.replay_buffer = self.replay_buffer_class(
self.buffer_size,
self.observation_space,
self.action_space,
device=self.device,
n_envs=self.n_envs,
optimize_memory_usage=self.optimize_memory_usage,
**replay_buffer_kwargs,
)
self.policy = self.policy_class(
self.observation_space,
self.action_space,
self.lr_schedule,
**self.policy_kwargs,
)
self.policy = self.policy.to(self.device)
# Convert train freq parameter to TrainFreq object
self._convert_train_freq()
def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
"""
Save the replay buffer as a pickle file.
:param path: Path to the file where the replay buffer should be saved.
if path is a str or pathlib.Path, the path is automatically created if necessary.
"""
assert self.replay_buffer is not None, "The replay buffer is not defined"
save_to_pkl(path, self.replay_buffer, self.verbose)
def load_replay_buffer(
self,
path: Union[str, pathlib.Path, io.BufferedIOBase],
truncate_last_traj: bool = True,
) -> None:
"""
Load a replay buffer from a pickle file.
:param path: Path to the pickled replay buffer.
:param truncate_last_traj: When using ``HerReplayBuffer`` with online sampling:
If set to ``True``, we assume that the last trajectory in the replay buffer was finished
(and truncate it).
If set to ``False``, we assume that we continue the same trajectory (same episode).
"""
self.replay_buffer = load_from_pkl(path, self.verbose)
assert isinstance(self.replay_buffer, ReplayBuffer), "The replay buffer must inherit from ReplayBuffer class"
# Backward compatibility with SB3 < 2.1.0 replay buffer
# Keep old behavior: do not handle timeout termination separately
if not hasattr(self.replay_buffer, "handle_timeout_termination"): # pragma: no cover
self.replay_buffer.handle_timeout_termination = False
self.replay_buffer.timeouts = np.zeros_like(self.replay_buffer.dones)
if isinstance(self.replay_buffer, HerReplayBuffer):
assert self.env is not None, "You must pass an environment at load time when using `HerReplayBuffer`"
self.replay_buffer.set_env(self.env)
if truncate_last_traj:
self.replay_buffer.truncate_last_trajectory()
# Update saved replay buffer device to match current setting, see GH#1561
self.replay_buffer.device = self.device
def _setup_learn(
self,
total_timesteps: int,
callback: MaybeCallback = None,
reset_num_timesteps: bool = True,
tb_log_name: str = "run",
progress_bar: bool = False,
) -> Tuple[int, BaseCallback]:
"""
cf `BaseAlgorithm`.
"""
# Prevent continuity issue by truncating trajectory
# when using memory efficient replay buffer
# see https://github.com/DLR-RM/stable-baselines3/issues/46
replay_buffer = self.replay_buffer
truncate_last_traj = (
self.optimize_memory_usage
and reset_num_timesteps
and replay_buffer is not None
and (replay_buffer.full or replay_buffer.pos > 0)
)
if truncate_last_traj:
warnings.warn(
"The last trajectory in the replay buffer will be truncated, "
"see https://github.com/DLR-RM/stable-baselines3/issues/46."
"You should use `reset_num_timesteps=False` or `optimize_memory_usage=False`"
"to avoid that issue."
)
assert replay_buffer is not None # for mypy
# Go to the previous index
pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size
replay_buffer.dones[pos] = True
assert self.env is not None, "You must set the environment before calling _setup_learn()"
# Vectorize action noise if needed
if (
self.action_noise is not None
and self.env.num_envs > 1
and not isinstance(self.action_noise, VectorizedActionNoise)
):
self.action_noise = VectorizedActionNoise(self.action_noise, self.env.num_envs)
return super()._setup_learn(
total_timesteps,
callback,
reset_num_timesteps,
tb_log_name,
progress_bar,
)
def learn(
self: SelfOffPolicyAlgorithm,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "run",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfOffPolicyAlgorithm:
total_timesteps, callback = self._setup_learn(
total_timesteps,
callback,
reset_num_timesteps,
tb_log_name,
progress_bar,
)
callback.on_training_start(locals(), globals())
assert self.env is not None, "You must set the environment before calling learn()"
assert isinstance(self.train_freq, TrainFreq) # check done in _setup_learn()
while self.num_timesteps < total_timesteps:
rollout = self.collect_rollouts(
self.env,
train_freq=self.train_freq,
action_noise=self.action_noise,
callback=callback,
learning_starts=self.learning_starts,
replay_buffer=self.replay_buffer,
log_interval=log_interval,
)
if not rollout.continue_training:
break
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
# If no `gradient_steps` is specified,
# do as many gradient steps as steps performed during the rollout
gradient_steps = self.gradient_steps if self.gradient_steps >= 0 else rollout.episode_timesteps
# Special case when the user passes `gradient_steps=0`
if gradient_steps > 0:
self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
callback.on_training_end()
return self
def train(self, gradient_steps: int, batch_size: int) -> None:
"""
Sample the replay buffer and do the updates
(gradient descent and update target networks)
"""
raise NotImplementedError()
def _sample_action(
self,
learning_starts: int,
action_noise: Optional[ActionNoise] = None,
n_envs: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Sample an action according to the exploration policy.
This is either done by sampling the probability distribution of the policy,
or sampling a random action (from a uniform distribution over the action space)
or by adding noise to the deterministic output.
:param action_noise: Action noise that will be used for exploration
Required for deterministic policy (e.g. TD3). This can also be used
in addition to the stochastic policy for SAC.
:param learning_starts: Number of steps before learning for the warm-up phase.
:param n_envs: Number of parallel environments
:return: action to take in the environment
and scaled action that will be stored in the replay buffer.
The two differ when the action space is not normalized (bounds are not [-1, 1]).
"""
# Select action randomly or according to policy
if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
# Warmup phase
unscaled_action = np.array([self.action_space.sample() for _ in range(n_envs)])
else:
# Note: when using continuous actions,
# we assume that the policy uses tanh to scale the action
# We use a non-deterministic action in the case of SAC; for TD3, it does not matter
assert self._last_obs is not None, "self._last_obs was not set"
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
# Rescale the action from [low, high] to [-1, 1]
if isinstance(self.action_space, spaces.Box):
scaled_action = self.policy.scale_action(unscaled_action)
# Add noise to the action (improve exploration)
if action_noise is not None:
scaled_action = np.clip(scaled_action + action_noise(), -1, 1)
# We store the scaled action in the buffer
buffer_action = scaled_action
action = self.policy.unscale_action(scaled_action)
else:
# Discrete case, no need to normalize or clip
buffer_action = unscaled_action
action = buffer_action
return action, buffer_action
def _dump_logs(self) -> None:
"""
Write log.
"""
assert self.ep_info_buffer is not None
assert self.ep_success_buffer is not None
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
self.logger.record("time/fps", fps)
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
if self.use_sde:
self.logger.record("train/std", (self.actor.get_std()).mean().item())
if len(self.ep_success_buffer) > 0:
self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))
# Pass the number of timesteps for tensorboard
self.logger.dump(step=self.num_timesteps)
def _on_step(self) -> None:
"""
Method called after each step in the environment.
It is meant to trigger DQN target network update
but can be used for other purposes
"""
pass
def _store_transition(
self,
replay_buffer: ReplayBuffer,
buffer_action: np.ndarray,
new_obs: Union[np.ndarray, Dict[str, np.ndarray]],
reward: np.ndarray,
dones: np.ndarray,
infos: List[Dict[str, Any]],
) -> None:
"""
Store transition in the replay buffer.
We store the normalized action and the unnormalized observation.
It also handles terminal observations (because VecEnv resets automatically).
:param replay_buffer: Replay buffer object where to store the transition.
:param buffer_action: normalized action
:param new_obs: next observation in the current episode
or first observation of the episode (when dones is True)
:param reward: reward for the current transition
:param dones: Termination signal
:param infos: List of additional information about the transition.
It may contain the terminal observations and information about timeout.
"""
# Store only the unnormalized version
if self._vec_normalize_env is not None:
new_obs_ = self._vec_normalize_env.get_original_obs()
reward_ = self._vec_normalize_env.get_original_reward()
else:
# Avoid changing the original ones
self._last_original_obs, new_obs_, reward_ = self._last_obs, new_obs, reward
# Avoid modification by reference
next_obs = deepcopy(new_obs_)
# As the VecEnv resets automatically, new_obs is already the
# first observation of the next episode
for i, done in enumerate(dones):
if done and infos[i].get("terminal_observation") is not None:
if isinstance(next_obs, dict):
next_obs_ = infos[i]["terminal_observation"]
# VecNormalize normalizes the terminal observation
if self._vec_normalize_env is not None:
next_obs_ = self._vec_normalize_env.unnormalize_obs(next_obs_)
# Replace next obs for the correct envs
for key in next_obs.keys():
next_obs[key][i] = next_obs_[key]
else:
next_obs[i] = infos[i]["terminal_observation"]
# VecNormalize normalizes the terminal observation
if self._vec_normalize_env is not None:
next_obs[i] = self._vec_normalize_env.unnormalize_obs(next_obs[i, :])
replay_buffer.add(
self._last_original_obs, # type: ignore[arg-type]
next_obs, # type: ignore[arg-type]
buffer_action,
reward_,
dones,
infos,
)
self._last_obs = new_obs
# Save the unnormalized observation
if self._vec_normalize_env is not None:
self._last_original_obs = new_obs_
def collect_rollouts(
self,
env: VecEnv,
callback: BaseCallback,
train_freq: TrainFreq,
replay_buffer: ReplayBuffer,
action_noise: Optional[ActionNoise] = None,
learning_starts: int = 0,
log_interval: Optional[int] = None,
) -> RolloutReturn:
"""
Collect experiences and store them into a ``ReplayBuffer``.
:param env: The training environment
:param callback: Callback that will be called at each step
(and at the beginning and end of the rollout)
:param train_freq: How much experience to collect
by doing rollouts of current policy.
Either ``TrainFreq(<n>, TrainFrequencyUnit.STEP)``
or ``TrainFreq(<n>, TrainFrequencyUnit.EPISODE)``
with ``<n>`` being an integer greater than 0.
:param action_noise: Action noise that will be used for exploration
Required for deterministic policy (e.g. TD3). This can also be used
in addition to the stochastic policy for SAC.
:param learning_starts: Number of steps before learning for the warm-up phase.
:param replay_buffer: Replay buffer where the collected experience is stored
:param log_interval: Log data every ``log_interval`` episodes
:return:
"""
# Switch to eval mode (this affects batch norm / dropout)
self.policy.set_training_mode(False)
num_collected_steps, num_collected_episodes = 0, 0
assert isinstance(env, VecEnv), "You must pass a VecEnv"
assert train_freq.frequency > 0, "Should at least collect one step or episode."
if env.num_envs > 1:
assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training."
if self.use_sde:
self.actor.reset_noise(env.num_envs)
callback.on_rollout_start()
continue_training = True
while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
# Sample a new noise matrix
self.actor.reset_noise(env.num_envs)
# Select action randomly or according to policy
actions, buffer_actions = self._sample_action(learning_starts, action_noise, env.num_envs)
# Rescale and perform action
new_obs, rewards, dones, infos = env.step(actions)
self.num_timesteps += env.num_envs
num_collected_steps += 1
# Give access to local variables
callback.update_locals(locals())
# Only stop training if return value is False, not when it is None.
if not callback.on_step():
return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False)
# Retrieve reward and episode length if using Monitor wrapper
self._update_info_buffer(infos, dones)
# Store data in replay buffer (normalized action and unnormalized observation)
self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos) # type: ignore[arg-type]
self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)
# For DQN, check if the target network should be updated
# and update the exploration schedule
# For SAC/TD3, the update is done at the same time as the gradient update
# see https://github.com/hill-a/stable-baselines/issues/900
self._on_step()
for idx, done in enumerate(dones):
if done:
# Update stats
num_collected_episodes += 1
self._episode_num += 1
if action_noise is not None:
kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
action_noise.reset(**kwargs)
# Log training infos
if log_interval is not None and self._episode_num % log_interval == 0:
self._dump_logs()
callback.on_rollout_end()
return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
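A minimal sketch of how the pieces above fit together for a concrete off-policy algorithm (SAC here; the hyperparameters and file path are placeholder choices):

import gymnasium as gym

from stable_baselines3 import SAC

env = gym.make("Pendulum-v1")
# Collect one full episode per rollout, then do as many gradient steps as steps collected
model = SAC("MlpPolicy", env, train_freq=(1, "episode"), gradient_steps=-1, learning_starts=200)
model.learn(total_timesteps=2_000, log_interval=4)
model.save_replay_buffer("/tmp/sac_pendulum_buffer.pkl")  # pickled via save_to_pkl
# model.load_replay_buffer("/tmp/sac_pendulum_buffer.pkl") would restore it later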

View File

@ -0,0 +1,322 @@
import sys
import time
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="OnPolicyAlgorithm")
class OnPolicyAlgorithm(BaseAlgorithm):
"""
The base for On-Policy algorithms (ex: A2C/PPO).
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
of the current progress remaining (from 1 to 0)
:param n_steps: The number of steps to run for each environment per update
(i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
:param gamma: Discount factor
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
Equivalent to classic advantage when set to 1.
:param ent_coef: Entropy coefficient for the loss calculation
:param vf_coef: Value function coefficient for the loss calculation
:param max_grad_norm: The maximum value for the gradient clipping
:param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param monitor_wrapper: When creating an environment, whether to wrap it
or not in a Monitor wrapper.
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
:param supported_action_spaces: The action spaces supported by the algorithm.
"""
rollout_buffer: RolloutBuffer
policy: ActorCriticPolicy
def __init__(
self,
policy: Union[str, Type[ActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule],
n_steps: int,
gamma: float,
gae_lambda: float,
ent_coef: float,
vf_coef: float,
max_grad_norm: float,
use_sde: bool,
sde_sample_freq: int,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
monitor_wrapper: bool = True,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
):
super().__init__(
policy=policy,
env=env,
learning_rate=learning_rate,
policy_kwargs=policy_kwargs,
verbose=verbose,
device=device,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
support_multi_env=True,
monitor_wrapper=monitor_wrapper,
seed=seed,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
supported_action_spaces=supported_action_spaces,
)
self.n_steps = n_steps
self.gamma = gamma
self.gae_lambda = gae_lambda
self.ent_coef = ent_coef
self.vf_coef = vf_coef
self.max_grad_norm = max_grad_norm
self.rollout_buffer_class = rollout_buffer_class
self.rollout_buffer_kwargs = rollout_buffer_kwargs or {}
if _init_setup_model:
self._setup_model()
def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)
if self.rollout_buffer_class is None:
if isinstance(self.observation_space, spaces.Dict):
self.rollout_buffer_class = DictRolloutBuffer
else:
self.rollout_buffer_class = RolloutBuffer
self.rollout_buffer = self.rollout_buffer_class(
self.n_steps,
self.observation_space, # type: ignore[arg-type]
self.action_space,
device=self.device,
gamma=self.gamma,
gae_lambda=self.gae_lambda,
n_envs=self.n_envs,
**self.rollout_buffer_kwargs,
)
self.policy = self.policy_class( # type: ignore[assignment]
self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
)
self.policy = self.policy.to(self.device)
def collect_rollouts(
self,
env: VecEnv,
callback: BaseCallback,
rollout_buffer: RolloutBuffer,
n_rollout_steps: int,
) -> bool:
"""
Collect experiences using the current policy and fill a ``RolloutBuffer``.
The term rollout here refers to the model-free notion and should not
be used with the concept of rollout used in model-based RL or planning.
:param env: The training environment
:param callback: Callback that will be called at each step
(and at the beginning and end of the rollout)
:param rollout_buffer: Buffer to fill with rollouts
:param n_rollout_steps: Number of experiences to collect per environment
:return: True if function returned with at least `n_rollout_steps`
collected, False if callback terminated rollout prematurely.
"""
assert self._last_obs is not None, "No previous observation was provided"
# Switch to eval mode (this affects batch norm / dropout)
self.policy.set_training_mode(False)
n_steps = 0
rollout_buffer.reset()
# Sample new weights for the state dependent exploration
if self.use_sde:
self.policy.reset_noise(env.num_envs)
callback.on_rollout_start()
while n_steps < n_rollout_steps:
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
# Sample a new noise matrix
self.policy.reset_noise(env.num_envs)
with th.no_grad():
# Convert to pytorch tensor or to TensorDict
obs_tensor = obs_as_tensor(self._last_obs, self.device)
actions, values, log_probs = self.policy(obs_tensor)
actions = actions.cpu().numpy()
# Rescale and perform action
clipped_actions = actions
if isinstance(self.action_space, spaces.Box):
if self.policy.squash_output:
# Unscale the actions to match env bounds
# if they were previously squashed (scaled in [-1, 1])
clipped_actions = self.policy.unscale_action(clipped_actions)
else:
# Otherwise, clip the actions to avoid out of bound error
# as we are sampling from an unbounded Gaussian distribution
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
new_obs, rewards, dones, infos = env.step(clipped_actions)
self.num_timesteps += env.num_envs
# Give access to local variables
callback.update_locals(locals())
if not callback.on_step():
return False
self._update_info_buffer(infos, dones)
n_steps += 1
if isinstance(self.action_space, spaces.Discrete):
# Reshape in case of discrete action
actions = actions.reshape(-1, 1)
# Handle timeout by bootstrapping with the value function
# see GitHub issue #633
for idx, done in enumerate(dones):
if (
done
and infos[idx].get("terminal_observation") is not None
and infos[idx].get("TimeLimit.truncated", False)
):
terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
with th.no_grad():
terminal_value = self.policy.predict_values(terminal_obs)[0] # type: ignore[arg-type]
rewards[idx] += self.gamma * terminal_value
rollout_buffer.add(
self._last_obs, # type: ignore[arg-type]
actions,
rewards,
self._last_episode_starts, # type: ignore[arg-type]
values,
log_probs,
)
self._last_obs = new_obs # type: ignore[assignment]
self._last_episode_starts = dones
with th.no_grad():
# Compute value for the last timestep
values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) # type: ignore[arg-type]
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
callback.update_locals(locals())
callback.on_rollout_end()
return True
def train(self) -> None:
"""
Consume current rollout data and update policy parameters.
Implemented by individual algorithms.
"""
raise NotImplementedError
def _dump_logs(self, iteration: int) -> None:
"""
Write log.
:param iteration: Current logging iteration
"""
assert self.ep_info_buffer is not None
assert self.ep_success_buffer is not None
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/iterations", iteration, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
self.logger.record("time/fps", fps)
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
if len(self.ep_success_buffer) > 0:
self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer))
self.logger.dump(step=self.num_timesteps)
def learn(
self: SelfOnPolicyAlgorithm,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 1,
tb_log_name: str = "OnPolicyAlgorithm",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfOnPolicyAlgorithm:
iteration = 0
total_timesteps, callback = self._setup_learn(
total_timesteps,
callback,
reset_num_timesteps,
tb_log_name,
progress_bar,
)
callback.on_training_start(locals(), globals())
assert self.env is not None
while self.num_timesteps < total_timesteps:
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
if not continue_training:
break
iteration += 1
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
# Display training infos
if log_interval is not None and iteration % log_interval == 0:
assert self.ep_info_buffer is not None
self._dump_logs(iteration)
self.train()
callback.on_training_end()
return self
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "policy.optimizer"]
return state_dicts, []
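A minimal sketch of how the ``learn()`` loop above is typically invoked (names and hyperparameters are illustrative; ``CartPole-v1`` is assumed to be installed via Gymnasium):

from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env

# Each call to collect_rollouts() gathers n_steps * n_envs transitions
# (here 5 * 4 = 20) before a single call to train().
vec_env = make_vec_env("CartPole-v1", n_envs=4)
model = A2C("MlpPolicy", vec_env, n_steps=5, verbose=1)
model.learn(total_timesteps=10_000, log_interval=100)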

View File

@ -0,0 +1,987 @@
"""Policies: abstract base class and concrete implementations."""
import collections
import copy
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch as th
from gymnasium import spaces
from torch import nn
from stable_baselines3.common.distributions import (
BernoulliDistribution,
CategoricalDistribution,
DiagGaussianDistribution,
Distribution,
MultiCategoricalDistribution,
StateDependentNoiseDistribution,
make_proba_distribution,
)
from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, maybe_transpose, preprocess_obs
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
MlpExtractor,
NatureCNN,
create_mlp,
)
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
class BaseModel(nn.Module):
"""
The base model object: makes predictions in response to observations.
In the case of policies, the prediction is an action. In the case of critics, it is the
estimated value of the observation.
:param observation_space: The observation space of the environment
:param action_space: The action space of the environment
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param features_extractor: Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise)
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
optimizer: th.optim.Optimizer
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
features_extractor: Optional[BaseFeaturesExtractor] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__()
if optimizer_kwargs is None:
optimizer_kwargs = {}
if features_extractor_kwargs is None:
features_extractor_kwargs = {}
self.observation_space = observation_space
self.action_space = action_space
self.features_extractor = features_extractor
self.normalize_images = normalize_images
self.optimizer_class = optimizer_class
self.optimizer_kwargs = optimizer_kwargs
self.features_extractor_class = features_extractor_class
self.features_extractor_kwargs = features_extractor_kwargs
# Automatically deactivate dtype and bounds checks
if not normalize_images and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)):
self.features_extractor_kwargs.update(dict(normalized_image=True))
def _update_features_extractor(
self,
net_kwargs: Dict[str, Any],
features_extractor: Optional[BaseFeaturesExtractor] = None,
) -> Dict[str, Any]:
"""
Update the network keyword arguments and create a new features extractor object if needed.
If a ``features_extractor`` object is passed, then it will be shared.
:param net_kwargs: the base network keyword arguments, without the ones
related to features extractor
:param features_extractor: a features extractor object.
If None, a new object will be created.
:return: The updated keyword arguments
"""
net_kwargs = net_kwargs.copy()
if features_extractor is None:
# The features extractor is not shared, create a new one
features_extractor = self.make_features_extractor()
net_kwargs.update(dict(features_extractor=features_extractor, features_dim=features_extractor.features_dim))
return net_kwargs
def make_features_extractor(self) -> BaseFeaturesExtractor:
"""Helper method to create a features extractor."""
return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
def extract_features(self, obs: PyTorchObs, features_extractor: BaseFeaturesExtractor) -> th.Tensor:
"""
Preprocess the observation if needed and extract features.
:param obs: Observation
:param features_extractor: The features extractor to use.
:return: The extracted features
"""
preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
return features_extractor(preprocessed_obs)
def _get_constructor_parameters(self) -> Dict[str, Any]:
"""
Get data that need to be saved in order to re-create the model when loading it from disk.
:return: The dictionary to pass as kwargs to the constructor when reconstructing this model.
"""
return dict(
observation_space=self.observation_space,
action_space=self.action_space,
# Passed to the constructor by child class
# squash_output=self.squash_output,
# features_extractor=self.features_extractor
normalize_images=self.normalize_images,
)
@property
def device(self) -> th.device:
"""Infer which device this policy lives on by inspecting its parameters.
If it has no parameters, the 'cpu' device is used as a fallback.
:return:"""
for param in self.parameters():
return param.device
return get_device("cpu")
def save(self, path: str) -> None:
"""
Save model to a given location.
:param path:
"""
th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)
@classmethod
def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel:
"""
Load model from path.
:param path:
:param device: Device on which the policy should be loaded.
:return:
"""
device = get_device(device)
# Note(antonin): we cannot use `weights_only=True` here because we need to allow
# gymnasium imports for the policy to be loaded successfully
saved_variables = th.load(path, map_location=device, weights_only=False)
# Create policy object
model = cls(**saved_variables["data"])
# Load weights
model.load_state_dict(saved_variables["state_dict"])
model.to(device)
return model
def load_from_vector(self, vector: np.ndarray) -> None:
"""
Load parameters from a 1D vector.
:param vector:
"""
th.nn.utils.vector_to_parameters(th.as_tensor(vector, dtype=th.float, device=self.device), self.parameters())
def parameters_to_vector(self) -> np.ndarray:
"""
Convert the parameters to a 1D vector.
:return:
"""
return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()
def set_training_mode(self, mode: bool) -> None:
"""
Put the policy in either training or evaluation mode.
This affects certain modules, such as batch normalisation and dropout.
:param mode: if true, set to training mode, else set to evaluation mode
"""
self.train(mode)
def is_vectorized_observation(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> bool:
"""
Check whether or not the observation is vectorized,
applying transposition to images (so that they are channel-first) if needed.
This is used in DQN when sampling random action (epsilon-greedy policy)
:param observation: the input observation to check
:return: whether the given observation is vectorized or not
"""
vectorized_env = False
if isinstance(observation, dict):
assert isinstance(
self.observation_space, spaces.Dict
), f"The observation provided is a dict but the obs space is {self.observation_space}"
for key, obs in observation.items():
obs_space = self.observation_space.spaces[key]
vectorized_env = vectorized_env or is_vectorized_observation(maybe_transpose(obs, obs_space), obs_space)
else:
vectorized_env = is_vectorized_observation(
maybe_transpose(observation, self.observation_space), self.observation_space
)
return vectorized_env
def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[PyTorchObs, bool]:
"""
Convert an input observation to a PyTorch tensor that can be fed to a model.
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:return: The observation as PyTorch tensor
and whether the observation is vectorized or not
"""
vectorized_env = False
if isinstance(observation, dict):
assert isinstance(
self.observation_space, spaces.Dict
), f"The observation provided is a dict but the obs space is {self.observation_space}"
# need to copy the dict as the dict in VecFrameStack will become a torch tensor
observation = copy.deepcopy(observation)
for key, obs in observation.items():
obs_space = self.observation_space.spaces[key]
if is_image_space(obs_space):
obs_ = maybe_transpose(obs, obs_space)
else:
obs_ = np.array(obs)
vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space)
# Add batch dimension if needed
observation[key] = obs_.reshape((-1, *self.observation_space[key].shape)) # type: ignore[misc]
elif is_image_space(self.observation_space):
# Handle the different cases for images
# as PyTorch uses the channel-first format
observation = maybe_transpose(observation, self.observation_space)
else:
observation = np.array(observation)
if not isinstance(observation, dict):
# Dict obs need to be handled separately
vectorized_env = is_vectorized_observation(observation, self.observation_space)
# Add batch dimension if needed
observation = observation.reshape((-1, *self.observation_space.shape)) # type: ignore[misc]
obs_tensor = obs_as_tensor(observation, self.device)
return obs_tensor, vectorized_env
class BasePolicy(BaseModel, ABC):
"""The base policy object.
Parameters are mostly the same as `BaseModel`; additions are documented below.
:param args: positional arguments passed through to `BaseModel`.
:param kwargs: keyword arguments passed through to `BaseModel`.
:param squash_output: For continuous actions, whether the output is squashed
or not using a ``tanh()`` function.
"""
features_extractor: BaseFeaturesExtractor
def __init__(self, *args, squash_output: bool = False, **kwargs):
super().__init__(*args, **kwargs)
self._squash_output = squash_output
@staticmethod
def _dummy_schedule(progress_remaining: float) -> float:
"""(float) Useful for pickling policy."""
del progress_remaining
return 0.0
@property
def squash_output(self) -> bool:
"""(bool) Getter for squash_output."""
return self._squash_output
@staticmethod
def init_weights(module: nn.Module, gain: float = 1) -> None:
"""
Orthogonal initialization (used in PPO and A2C)
"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
nn.init.orthogonal_(module.weight, gain=gain)
if module.bias is not None:
module.bias.data.fill_(0.0)
@abstractmethod
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
"""
Get the action according to the policy for a given observation.
By default provides a dummy implementation -- not all BasePolicy classes
implement this, e.g. if they are a Critic in an Actor-Critic method.
:param observation:
:param deterministic: Whether to use stochastic or deterministic actions
:return: Taken action according to the policy
"""
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:param state: The last hidden states (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
this correspond to beginning of episodes,
where the hidden states of the RNN must be reset.
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next hidden state
(used in recurrent policies)
"""
# Switch to eval mode (this affects batch norm / dropout)
self.set_training_mode(False)
# Check for common mistake that the user does not mix Gym/VecEnv API
# Tuple obs are not supported by SB3, so we can safely do that check
if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict):
raise ValueError(
"You have passed a tuple to the predict() function instead of a Numpy array or a Dict. "
"You are probably mixing Gym API with SB3 VecEnv API: `obs, info = env.reset()` (Gym) "
"vs `obs = vec_env.reset()` (SB3 VecEnv). "
"See related issue https://github.com/DLR-RM/stable-baselines3/issues/1694 "
"and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
)
obs_tensor, vectorized_env = self.obs_to_tensor(observation)
with th.no_grad():
actions = self._predict(obs_tensor, deterministic=deterministic)
# Convert to numpy, and reshape to the original action shape
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc]
if isinstance(self.action_space, spaces.Box):
if self.squash_output:
# Rescale to proper domain when using squashing
actions = self.unscale_action(actions) # type: ignore[assignment, arg-type]
else:
# Actions could be on arbitrary scale, so clip the actions to avoid
# out of bound error (e.g. if sampling from a Gaussian distribution)
actions = np.clip(actions, self.action_space.low, self.action_space.high) # type: ignore[assignment, arg-type]
# Remove batch dimension if needed
if not vectorized_env:
assert isinstance(actions, np.ndarray)
actions = actions.squeeze(axis=0)
return actions, state # type: ignore[return-value]
def scale_action(self, action: np.ndarray) -> np.ndarray:
"""
Rescale the action from [low, high] to [-1, 1]
(no need for symmetric action space)
:param action: Action to scale
:return: Scaled action
"""
assert isinstance(
self.action_space, spaces.Box
), f"Trying to scale an action using an action space that is not a Box(): {self.action_space}"
low, high = self.action_space.low, self.action_space.high
return 2.0 * ((action - low) / (high - low)) - 1.0
def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
"""
Rescale the action from [-1, 1] to [low, high]
(no need for symmetric action space)
:param scaled_action: Action to un-scale
"""
assert isinstance(
self.action_space, spaces.Box
), f"Trying to unscale an action using an action space that is not a Box(): {self.action_space}"
low, high = self.action_space.low, self.action_space.high
return low + (0.5 * (scaled_action + 1.0) * (high - low))
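# Worked example (hypothetical Box(low=-2, high=2) action space, not from the
# original source): scale_action maps 1.0 to 2 * ((1 - (-2)) / 4) - 1 = 0.5,
# and unscale_action maps 0.5 back to -2 + 0.5 * (0.5 + 1) * 4 = 1.0,
# so the two transforms are exact inverses of each other.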
class ActorCriticPolicy(BasePolicy):
"""
Policy class for actor-critic algorithms (has both policy and value prediction).
Used by A2C, PPO and the likes.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It keeps the variance
above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: Whether to squash the output using a tanh function,
which ensures action bounds are respected when using gSDE.
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
if optimizer_kwargs is None:
optimizer_kwargs = {}
# Small values to avoid NaN in Adam optimizer
if optimizer_class == th.optim.Adam:
optimizer_kwargs["eps"] = 1e-5
super().__init__(
observation_space,
action_space,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
squash_output=squash_output,
normalize_images=normalize_images,
)
if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict):
warnings.warn(
(
"As shared layers in the mlp_extractor are removed since SB3 v1.8.0, "
"you should now pass directly a dictionary and not a list "
"(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
),
)
net_arch = net_arch[0]
# Default network architecture, from stable-baselines
if net_arch is None:
if features_extractor_class == NatureCNN:
net_arch = []
else:
net_arch = dict(pi=[64, 64], vf=[64, 64])
self.net_arch = net_arch
self.activation_fn = activation_fn
self.ortho_init = ortho_init
self.share_features_extractor = share_features_extractor
self.features_extractor = self.make_features_extractor()
self.features_dim = self.features_extractor.features_dim
if self.share_features_extractor:
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.features_extractor
else:
self.pi_features_extractor = self.features_extractor
self.vf_features_extractor = self.make_features_extractor()
self.log_std_init = log_std_init
dist_kwargs = None
assert not (squash_output and not use_sde), "squash_output=True is only available when using gSDE (use_sde=True)"
# Keyword arguments for gSDE distribution
if use_sde:
dist_kwargs = {
"full_std": full_std,
"squash_output": squash_output,
"use_expln": use_expln,
"learn_features": False,
}
self.use_sde = use_sde
self.dist_kwargs = dist_kwargs
# Action distribution
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
self._build(lr_schedule)
def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()
default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) # type: ignore[arg-type, return-value]
data.update(
dict(
net_arch=self.net_arch,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
squash_output=default_none_kwargs["squash_output"],
full_std=default_none_kwargs["full_std"],
use_expln=default_none_kwargs["use_expln"],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
ortho_init=self.ortho_init,
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs,
)
)
return data
def reset_noise(self, n_envs: int = 1) -> None:
"""
Sample new weights for the exploration matrix.
:param n_envs:
"""
assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE"
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
def _build_mlp_extractor(self) -> None:
"""
Create the policy and value networks.
Part of the layers can be shared.
"""
# Note: If net_arch is None and some features extractor is used,
# net_arch here is an empty list and mlp_extractor does not
# really contain any layers (acts like an identity module).
self.mlp_extractor = MlpExtractor(
self.features_dim,
net_arch=self.net_arch,
activation_fn=self.activation_fn,
device=self.device,
)
def _build(self, lr_schedule: Schedule) -> None:
"""
Create the networks and the optimizer.
:param lr_schedule: Learning rate schedule
lr_schedule(1) is the initial learning rate
"""
self._build_mlp_extractor()
latent_dim_pi = self.mlp_extractor.latent_dim_pi
if isinstance(self.action_dist, DiagGaussianDistribution):
self.action_net, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=latent_dim_pi, log_std_init=self.log_std_init
)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
self.action_net, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init
)
elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)):
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
else:
raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.")
self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
# Init weights: use orthogonal initialization
# with small initial weight for the output
if self.ortho_init:
# TODO: check for features_extractor
# Values from stable-baselines.
# features_extractor/mlp values are
# originally from openai/baselines (default gains/init_scales).
module_gains = {
self.features_extractor: np.sqrt(2),
self.mlp_extractor: np.sqrt(2),
self.action_net: 0.01,
self.value_net: 1,
}
if not self.share_features_extractor:
# Note(antonin): this is to keep SB3 results
# consistent, see GH#1148
del module_gains[self.features_extractor]
module_gains[self.pi_features_extractor] = np.sqrt(2)
module_gains[self.vf_features_extractor] = np.sqrt(2)
for module, gain in module_gains.items():
module.apply(partial(self.init_weights, gain=gain))
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) # type: ignore[call-arg]
def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
Forward pass in all the networks (actor and critic)
:param obs: Observation
:param deterministic: Whether to sample or use deterministic actions
:return: action, value and log probability of the action
"""
# Preprocess the observation if needed
features = self.extract_features(obs)
if self.share_features_extractor:
latent_pi, latent_vf = self.mlp_extractor(features)
else:
pi_features, vf_features = features
latent_pi = self.mlp_extractor.forward_actor(pi_features)
latent_vf = self.mlp_extractor.forward_critic(vf_features)
# Evaluate the values for the given observations
values = self.value_net(latent_vf)
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
actions = actions.reshape((-1, *self.action_space.shape)) # type: ignore[misc]
return actions, values, log_prob
def extract_features( # type: ignore[override]
self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
"""
Preprocess the observation if needed and extract features.
:param obs: Observation
:param features_extractor: The features extractor to use. If None, then ``self.features_extractor`` is used.
:return: The extracted features. If features extractor is not shared, returns a tuple with the
features for the actor and the features for the critic.
"""
if self.share_features_extractor:
return super().extract_features(obs, self.features_extractor if features_extractor is None else features_extractor)
else:
if features_extractor is not None:
warnings.warn(
"Provided features_extractor will be ignored because the features extractor is not shared.",
UserWarning,
)
pi_features = super().extract_features(obs, self.pi_features_extractor)
vf_features = super().extract_features(obs, self.vf_features_extractor)
return pi_features, vf_features
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
"""
Retrieve action distribution given the latent codes.
:param latent_pi: Latent code for the actor
:return: Action distribution
"""
mean_actions = self.action_net(latent_pi)
if isinstance(self.action_dist, DiagGaussianDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std)
elif isinstance(self.action_dist, CategoricalDistribution):
# Here mean_actions are the logits before the softmax
return self.action_dist.proba_distribution(action_logits=mean_actions)
elif isinstance(self.action_dist, MultiCategoricalDistribution):
# Here mean_actions are the flattened logits
return self.action_dist.proba_distribution(action_logits=mean_actions)
elif isinstance(self.action_dist, BernoulliDistribution):
# Here mean_actions are the logits (before rounding to get the binary actions)
return self.action_dist.proba_distribution(action_logits=mean_actions)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
else:
raise ValueError("Invalid action distribution")
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
"""
Get the action according to the policy for a given observation.
:param observation:
:param deterministic: Whether to use stochastic or deterministic actions
:return: Taken action according to the policy
"""
return self.get_distribution(observation).get_actions(deterministic=deterministic)
def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
"""
Evaluate actions according to the current policy,
given the observations.
:param obs: Observation
:param actions: Actions
:return: estimated value, log likelihood of taking those actions
and entropy of the action distribution.
"""
# Preprocess the observation if needed
features = self.extract_features(obs)
if self.share_features_extractor:
latent_pi, latent_vf = self.mlp_extractor(features)
else:
pi_features, vf_features = features
latent_pi = self.mlp_extractor.forward_actor(pi_features)
latent_vf = self.mlp_extractor.forward_critic(vf_features)
distribution = self._get_action_dist_from_latent(latent_pi)
log_prob = distribution.log_prob(actions)
values = self.value_net(latent_vf)
entropy = distribution.entropy()
return values, log_prob, entropy
def get_distribution(self, obs: PyTorchObs) -> Distribution:
"""
Get the current policy distribution given the observations.
:param obs:
:return: the action distribution.
"""
features = super().extract_features(obs, self.pi_features_extractor)
latent_pi = self.mlp_extractor.forward_actor(features)
return self._get_action_dist_from_latent(latent_pi)
def predict_values(self, obs: PyTorchObs) -> th.Tensor:
"""
Get the estimated values according to the current policy given the observations.
:param obs: Observation
:return: the estimated values.
"""
features = super().extract_features(obs, self.vf_features_extractor)
latent_vf = self.mlp_extractor.forward_critic(features)
return self.value_net(latent_vf)
class ActorCriticCnnPolicy(ActorCriticPolicy):
"""
CNN policy class for actor-critic algorithms (has both policy and value prediction).
Used by A2C, PPO and the likes.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It keeps the variance
above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: Whether to squash the output using a tanh function,
which ensures action bounds are respected when using gSDE.
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
ortho_init,
use_sde,
log_std_init,
full_std,
use_expln,
squash_output,
features_extractor_class,
features_extractor_kwargs,
share_features_extractor,
normalize_images,
optimizer_class,
optimizer_kwargs,
)
class MultiInputActorCriticPolicy(ActorCriticPolicy):
"""
Multi-input policy class for actor-critic algorithms (has both policy and value prediction).
Used by A2C, PPO and the likes.
:param observation_space: Observation space (Dict)
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param ortho_init: Whether to use or not orthogonal initialization
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param full_std: Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using gSDE
:param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It keeps the variance
above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: Whether to squash the output using a tanh function,
which ensures action bounds are respected when using gSDE.
:param features_extractor_class: Uses the CombinedExtractor
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: spaces.Dict,
action_space: spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
full_std: bool = True,
use_expln: bool = False,
squash_output: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
share_features_extractor: bool = True,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
ortho_init,
use_sde,
log_std_init,
full_std,
use_expln,
squash_output,
features_extractor_class,
features_extractor_kwargs,
share_features_extractor,
normalize_images,
optimizer_class,
optimizer_kwargs,
)
class ContinuousCritic(BaseModel):
"""
Critic network(s) for DDPG/SAC/TD3.
It represents the action-state value function (Q-value function).
Compared to A2C/PPO critics, this one represents the Q-value
and takes the continuous action as input. It is concatenated with the state
and then fed to the network which outputs a single value: Q(s, a).
For more recent algorithms like SAC/TD3, multiple networks
are created to give different estimates.
By default, it creates two critic networks used to reduce overestimation
thanks to clipped Q-learning (cf TD3 paper).
:param observation_space: Observation space
:param action_space: Action space
:param net_arch: Network architecture
:param features_extractor: Network to extract features
(a CNN when using images, a nn.Flatten() layer otherwise)
:param features_dim: Number of features
:param activation_fn: Activation function
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether the features extractor is shared or not
between the actor and the critic (this saves computation time)
"""
features_extractor: BaseFeaturesExtractor
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Box,
net_arch: List[int],
features_extractor: BaseFeaturesExtractor,
features_dim: int,
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
n_critics: int = 2,
share_features_extractor: bool = True,
):
super().__init__(
observation_space,
action_space,
features_extractor=features_extractor,
normalize_images=normalize_images,
)
action_dim = get_action_dim(self.action_space)
self.share_features_extractor = share_features_extractor
self.n_critics = n_critics
self.q_networks: List[nn.Module] = []
for idx in range(n_critics):
q_net_list = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
q_net = nn.Sequential(*q_net_list)
self.add_module(f"qf{idx}", q_net)
self.q_networks.append(q_net)
def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
# Learn the features extractor using the policy loss only
# when the features_extractor is shared with the actor
with th.set_grad_enabled(not self.share_features_extractor):
features = self.extract_features(obs, self.features_extractor)
qvalue_input = th.cat([features, actions], dim=1)
return tuple(q_net(qvalue_input) for q_net in self.q_networks)
def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
"""
Only predict the Q-value using the first network.
This reduces computation when not all of the estimates are needed
(e.g. when updating the policy in TD3).
"""
with th.no_grad():
features = self.extract_features(obs, self.features_extractor)
return self.q_networks[0](th.cat([features, actions], dim=1))
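A short, hedged sketch of the policy interface in use (the environment name is illustrative only); ``predict()`` handles the numpy-to-tensor conversion described above, while ``evaluate_actions()`` is what the algorithm's ``train()`` step calls on rollout batches:

import gymnasium as gym
from stable_baselines3 import PPO

# predict() takes numpy observations and returns numpy actions.
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env)
obs, _ = env.reset(seed=0)
action, _state = model.predict(obs, deterministic=True)
print(int(action))  # a single discrete action for CartPole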

View File

@ -0,0 +1,227 @@
import warnings
from typing import Dict, Tuple, Union
import numpy as np
import torch as th
from gymnasium import spaces
from torch.nn import functional as F
def is_image_space_channels_first(observation_space: spaces.Box) -> bool:
"""
Check if an image observation space (see ``is_image_space``)
is channels-first (CxHxW, True) or channels-last (HxWxC, False).
Use a heuristic that channel dimension is the smallest of the three.
If the second dimension is the smallest, a warning is emitted and the space is treated as channels-last.
:param observation_space:
:return: True if observation space is channels-first image, False if channels-last.
"""
smallest_dimension = np.argmin(observation_space.shape).item()
if smallest_dimension == 1:
warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.")
return smallest_dimension == 0
def is_image_space(
observation_space: spaces.Space,
check_channels: bool = False,
normalized_image: bool = False,
) -> bool:
"""
Check if an observation space has the shape, limits and dtype
of a valid image.
The check is conservative, so that it returns False if there is a doubt.
Valid images: RGB, RGBD, GrayScale with values in [0, 255]
:param observation_space:
:param check_channels: Whether to do or not the check for the number of channels.
e.g., with frame-stacking, the observation space may have more channels than expected.
:param normalized_image: Whether to assume that the image is already normalized
or not (this disables dtype and bounds checks): when True, it only checks that
the space is a Box and has 3 dimensions.
Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
:return:
"""
check_dtype = check_bounds = not normalized_image
if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3:
# Check the type
if check_dtype and observation_space.dtype != np.uint8:
return False
# Check the value range
incorrect_bounds = np.any(observation_space.low != 0) or np.any(observation_space.high != 255)
if check_bounds and incorrect_bounds:
return False
# Skip channels check
if not check_channels:
return True
# Check the number of channels
if is_image_space_channels_first(observation_space):
n_channels = observation_space.shape[0]
else:
n_channels = observation_space.shape[-1]
# GrayScale, RGB, RGBD
return n_channels in [1, 3, 4]
return False
def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> np.ndarray:
"""
Handle the different cases for images, as PyTorch uses the channel-first format.
:param observation:
:param observation_space:
:return: channel first observation if observation is an image
"""
# Avoid circular import
from stable_baselines3.common.vec_env import VecTransposeImage
if is_image_space(observation_space):
if not (observation.shape == observation_space.shape or observation.shape[1:] == observation_space.shape):
# Try to re-order the channels
transpose_obs = VecTransposeImage.transpose_image(observation)
if transpose_obs.shape == observation_space.shape or transpose_obs.shape[1:] == observation_space.shape:
observation = transpose_obs
return observation
def preprocess_obs(
obs: Union[th.Tensor, Dict[str, th.Tensor]],
observation_space: spaces.Space,
normalize_images: bool = True,
) -> Union[th.Tensor, Dict[str, th.Tensor]]:
"""
Preprocess observation to be fed to a neural network.
For images, it normalizes the values by dividing them by 255 (to have values in [0, 1])
For discrete observations, it creates a one-hot vector.
:param obs: Observation
:param observation_space:
:param normalize_images: Whether to normalize images or not
(True by default)
:return:
"""
if isinstance(observation_space, spaces.Dict):
# Do not modify by reference the original observation
assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
preprocessed_obs = {}
for key, _obs in obs.items():
preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
return preprocessed_obs # type: ignore[return-value]
assert isinstance(obs, th.Tensor), f"Expecting a torch Tensor, but got {type(obs)}"
if isinstance(observation_space, spaces.Box):
if normalize_images and is_image_space(observation_space):
return obs.float() / 255.0
return obs.float()
elif isinstance(observation_space, spaces.Discrete):
# One hot encoding and convert to float to avoid errors
return F.one_hot(obs.long(), num_classes=int(observation_space.n)).float()
elif isinstance(observation_space, spaces.MultiDiscrete):
# Tensor concatenation of one hot encodings of each Categorical sub-space
return th.cat(
[
F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float()
for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1))
],
dim=-1,
).view(obs.shape[0], sum(observation_space.nvec))
elif isinstance(observation_space, spaces.MultiBinary):
return obs.float()
else:
raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")
def get_obs_shape(
observation_space: spaces.Space,
) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]:
"""
Get the shape of the observation (useful for the buffers).
:param observation_space:
:return:
"""
if isinstance(observation_space, spaces.Box):
return observation_space.shape
elif isinstance(observation_space, spaces.Discrete):
# Observation is an int
return (1,)
elif isinstance(observation_space, spaces.MultiDiscrete):
# Number of discrete features
return (int(len(observation_space.nvec)),)
elif isinstance(observation_space, spaces.MultiBinary):
# Number of binary features
return observation_space.shape
elif isinstance(observation_space, spaces.Dict):
return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} # type: ignore[misc]
else:
raise NotImplementedError(f"{observation_space} observation space is not supported")
def get_flattened_obs_dim(observation_space: spaces.Space) -> int:
"""
Get the dimension of the observation space when flattened.
It does not apply to image observation spaces.
Used by the ``FlattenExtractor`` to compute the input shape.
:param observation_space:
:return:
"""
# See issue https://github.com/openai/gym/issues/1915
# it may be a problem for Dict/Tuple spaces too...
if isinstance(observation_space, spaces.MultiDiscrete):
return sum(observation_space.nvec)
else:
# Use Gym internal method
return spaces.utils.flatdim(observation_space)
def get_action_dim(action_space: spaces.Space) -> int:
"""
Get the dimension of the action space.
:param action_space:
:return:
"""
if isinstance(action_space, spaces.Box):
return int(np.prod(action_space.shape))
elif isinstance(action_space, spaces.Discrete):
# Action is an int
return 1
elif isinstance(action_space, spaces.MultiDiscrete):
# Number of discrete actions
return int(len(action_space.nvec))
elif isinstance(action_space, spaces.MultiBinary):
# Number of binary actions
assert isinstance(
action_space.n, int
), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
return int(action_space.n)
else:
raise NotImplementedError(f"{action_space} action space is not supported")
def check_for_nested_spaces(obs_space: spaces.Space) -> None:
"""
Make sure the observation space does not have nested spaces (Dicts/Tuples inside Dicts/Tuples).
If so, raise an Exception informing that there is no support for this.
:param obs_space: an observation space
"""
if isinstance(obs_space, (spaces.Dict, spaces.Tuple)):
sub_spaces = obs_space.spaces.values() if isinstance(obs_space, spaces.Dict) else obs_space.spaces
for sub_space in sub_spaces:
if isinstance(sub_space, (spaces.Dict, spaces.Tuple)):
raise NotImplementedError(
"Nested observation spaces are not supported (Tuple/Dict space inside Tuple/Dict space)."
)
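A small sketch of how these helpers behave on a ``Discrete`` space (the printed values follow from the functions above; the example itself is not part of the source):

import torch as th
from gymnasium import spaces
from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape, preprocess_obs

space = spaces.Discrete(4)
print(get_obs_shape(space))   # (1,) -> the observation is stored as a single int
print(get_action_dim(space))  # 1
# A batch of two discrete observations is one-hot encoded to shape (2, 4)
obs = th.tensor([0, 3])
print(preprocess_obs(obs, space).shape)  # torch.Size([2, 4])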

View File

@ -0,0 +1,122 @@
from typing import Callable, List, Optional, Tuple
import numpy as np
import pandas as pd
# import matplotlib
# matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode
from matplotlib import pyplot as plt
from stable_baselines3.common.monitor import load_results
X_TIMESTEPS = "timesteps"
X_EPISODES = "episodes"
X_WALLTIME = "walltime_hrs"
POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME]
EPISODES_WINDOW = 100
def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
"""
Apply a rolling window to an np.ndarray
:param array: the input array
:param window: length of the rolling window
:return: rolling window on the input array
"""
shape = array.shape[:-1] + (array.shape[-1] - window + 1, window)
strides = (*array.strides, array.strides[-1])
return np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides)
def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]:
"""
Apply a function to the rolling window of 2 arrays
:param var_1: variable 1
:param var_2: variable 2
:param window: length of the rolling window
:param func: function to apply on the rolling window on variable 2 (such as np.mean)
:return: the rolling output with applied function
"""
var_2_window = rolling_window(var_2, window)
function_on_var2 = func(var_2_window, axis=-1)
return var_1[window - 1 :], function_on_var2
def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
"""
Decompose a data frame variable into x and y arrays
:param data_frame: the input data
:param x_axis: the axis for the x and y output
(can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
:return: the x and y output
"""
if x_axis == X_TIMESTEPS:
x_var = np.cumsum(data_frame.l.values)
y_var = data_frame.r.values
elif x_axis == X_EPISODES:
x_var = np.arange(len(data_frame))
y_var = data_frame.r.values
elif x_axis == X_WALLTIME:
# Convert to hours
x_var = data_frame.t.values / 3600.0
y_var = data_frame.r.values
else:
raise NotImplementedError
return x_var, y_var
def plot_curves(
xy_list: List[Tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2)
) -> None:
"""
Plot the curves.
:param xy_list: the x and y coordinates to plot
:param x_axis: the axis for the x and y output
(can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
:param title: the title of the plot
:param figsize: Size of the figure (width, height)
"""
plt.figure(title, figsize=figsize)
max_x = max(xy[0][-1] for xy in xy_list)
min_x = 0
for _, (x, y) in enumerate(xy_list):
plt.scatter(x, y, s=2)
# Do not plot the smoothed curve at all if the timeseries is shorter than window size.
if x.shape[0] >= EPISODES_WINDOW:
# Compute and plot rolling mean with window of size EPISODE_WINDOW
x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean)
plt.plot(x, y_mean)
plt.xlim(min_x, max_x)
plt.title(title)
plt.xlabel(x_axis)
plt.ylabel("Episode Rewards")
plt.tight_layout()
def plot_results(
dirs: List[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2)
) -> None:
"""
Plot the results using csv files from ``Monitor`` wrapper.
:param dirs: the save location of the results to plot
:param num_timesteps: only plot the points below this value
:param x_axis: the axis for the x and y output
(can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
:param task_name: the title of the task to plot
:param figsize: Size of the figure (width, height)
"""
data_frames = []
for folder in dirs:
data_frame = load_results(folder)
if num_timesteps is not None:
data_frame = data_frame[data_frame.l.cumsum() <= num_timesteps]
data_frames.append(data_frame)
xy_list = [ts2xy(data_frame, x_axis) for data_frame in data_frames]
plot_curves(xy_list, x_axis, task_name, figsize)
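# Hedged usage sketch (illustration only, not part of the original module):
# plots the learning curve from a hypothetical ``./logs`` folder containing
# ``Monitor`` csv files; the folder name and timestep limit are assumptions.
def _example_plot_results() -> None:
    plot_results(["./logs"], num_timesteps=1_000_000, x_axis=X_TIMESTEPS, task_name="CartPole-v1")
    plt.show()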

View File

@ -0,0 +1,57 @@
from typing import Tuple
import numpy as np
class RunningMeanStd:
def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
"""
Calculates the running mean and std of a data stream
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
:param epsilon: helps with arithmetic issues
:param shape: the shape of the data stream's output
"""
self.mean = np.zeros(shape, np.float64)
self.var = np.ones(shape, np.float64)
self.count = epsilon
def copy(self) -> "RunningMeanStd":
"""
:return: Return a copy of the current object.
"""
new_object = RunningMeanStd(shape=self.mean.shape)
new_object.mean = self.mean.copy()
new_object.var = self.var.copy()
new_object.count = float(self.count)
return new_object
def combine(self, other: "RunningMeanStd") -> None:
"""
Combine stats from another ``RunningMeanStd`` object.
:param other: The other object to combine with.
"""
self.update_from_moments(other.mean, other.var, other.count)
def update(self, arr: np.ndarray) -> None:
batch_mean = np.mean(arr, axis=0)
batch_var = np.var(arr, axis=0)
batch_count = arr.shape[0]
self.update_from_moments(batch_mean, batch_var, batch_count)
def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: float) -> None:
delta = batch_mean - self.mean
tot_count = self.count + batch_count
new_mean = self.mean + delta * batch_count / tot_count
m_a = self.var * self.count
m_b = batch_var * batch_count
m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
new_var = m_2 / (self.count + batch_count)
new_count = batch_count + self.count
self.mean = new_mean
self.var = new_var
self.count = new_count
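# Hedged usage sketch (illustration only, not part of the original module):
# feeds two random batches to ``RunningMeanStd`` and checks the running mean
# against the mean of the concatenated data. The data is made up.
def _example_running_mean_std() -> None:
    rng = np.random.default_rng(0)
    rms = RunningMeanStd(shape=(3,))
    batch_1 = rng.standard_normal((64, 3))
    batch_2 = rng.standard_normal((128, 3)) + 1.0
    rms.update(batch_1)
    rms.update(batch_2)
    full_data = np.concatenate([batch_1, batch_2], axis=0)
    # The epsilon pseudo-count only introduces a negligible bias
    assert np.allclose(rms.mean, full_data.mean(axis=0), atol=1e-2)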

View File

@ -0,0 +1,466 @@
"""
Save util taken from stable_baselines
used to serialize data (class parameters) of model classes
"""
import base64
import functools
import io
import json
import os
import pathlib
import pickle
import warnings
import zipfile
from typing import Any, Dict, Optional, Tuple, Union
import cloudpickle
import torch as th
import stable_baselines3 as sb3
from stable_baselines3.common.type_aliases import TensorDict
from stable_baselines3.common.utils import get_device, get_system_info
def recursive_getattr(obj: Any, attr: str, *args) -> Any:
"""
Recursive version of getattr
taken from https://stackoverflow.com/questions/31174295
Ex:
> MyObject.sub_object = SubObject(name='test')
> recursive_getattr(MyObject, 'sub_object.name') # return test
:param obj:
:param attr: Attribute to retrieve
:return: The attribute
"""
def _getattr(obj: Any, attr: str) -> Any:
return getattr(obj, attr, *args)
return functools.reduce(_getattr, [obj, *attr.split(".")])
def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
"""
Recursive version of setattr
taken from https://stackoverflow.com/questions/31174295
Ex:
> MyObject.sub_object = SubObject(name='test')
> recursive_setattr(MyObject, 'sub_object.name', 'hello')
:param obj:
:param attr: Attribute to set
:param val: New value of the attribute
"""
pre, _, post = attr.rpartition(".")
return setattr(recursive_getattr(obj, pre) if pre else obj, post, val)
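# Hedged usage sketch (illustration only, not part of the original module):
# round-trips a nested attribute with ``recursive_getattr``/``recursive_setattr``
# on small throwaway classes defined just for the example.
def _example_recursive_attr() -> None:
    class _Sub:
        name = "test"

    class _Obj:
        sub_object = _Sub()

    obj = _Obj()
    assert recursive_getattr(obj, "sub_object.name") == "test"
    recursive_setattr(obj, "sub_object.name", "hello")
    assert recursive_getattr(obj, "sub_object.name") == "hello"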
def is_json_serializable(item: Any) -> bool:
"""
Test if an object is serializable into JSON
:param item: The object to be tested for JSON serialization.
:return: True if object is JSON serializable, false otherwise.
"""
# Check serializability by simply attempting it in a try/except block.
json_serializable = True
try:
_ = json.dumps(item)
except TypeError:
json_serializable = False
return json_serializable
def data_to_json(data: Dict[str, Any]) -> str:
"""
Turn data (class parameters) into a JSON string for storing
:param data: Dictionary of class parameters to be
stored. Items that are not JSON serializable will be
pickled with Cloudpickle and stored as bytearray in
the JSON file
:return: JSON string of the data serialized.
"""
# First, check what elements can not be JSONfied,
# and turn them into byte-strings
serializable_data = {}
for data_key, data_item in data.items():
# See if object is JSON serializable
if is_json_serializable(data_item):
# All good, store as it is
serializable_data[data_key] = data_item
else:
# Not serializable, cloudpickle it into
# bytes and convert to base64 string for storing.
# Also store type of the class for consumption
# from other languages/humans, so we have an
# idea what was being stored.
base64_encoded = base64.b64encode(cloudpickle.dumps(data_item)).decode()
# Use ":" to make sure we do
# not override these keys
# when we include variables of the object later
cloudpickle_serialization = {
":type:": str(type(data_item)),
":serialized:": base64_encoded,
}
# Add first-level JSON-serializable items of the
# object for further details (but not deeper than this to
# avoid deep nesting).
# First we check that object has attributes (not all do,
# e.g. numpy scalars)
if hasattr(data_item, "__dict__") or isinstance(data_item, dict):
# Take elements from __dict__ for custom classes
item_generator = data_item.items if isinstance(data_item, dict) else data_item.__dict__.items
for variable_name, variable_item in item_generator():
# Check if serializable. If not, just include the
# string-representation of the object.
if is_json_serializable(variable_item):
cloudpickle_serialization[variable_name] = variable_item
else:
cloudpickle_serialization[variable_name] = str(variable_item)
serializable_data[data_key] = cloudpickle_serialization
json_string = json.dumps(serializable_data, indent=4)
return json_string
def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Turn JSON serialization of class-parameters back into dictionary.
:param json_string: JSON serialization of the class-parameters
that should be loaded.
:param custom_objects: Dictionary of objects to replace
upon loading. If a variable is present in this dictionary as a
key, it will not be deserialized and the corresponding item
will be used instead. Similar to custom_objects in
``keras.models.load_model``. Useful when you have an object in
file that can not be deserialized.
:return: Loaded class parameters.
"""
if custom_objects is not None and not isinstance(custom_objects, dict):
raise ValueError("custom_objects argument must be a dict or None")
json_dict = json.loads(json_string)
# This will be filled with deserialized data
return_data = {}
for data_key, data_item in json_dict.items():
if custom_objects is not None and data_key in custom_objects.keys():
# If item is provided in custom_objects, replace
# the one from JSON with the one in custom_objects
return_data[data_key] = custom_objects[data_key]
elif isinstance(data_item, dict) and ":serialized:" in data_item.keys():
# If item is dictionary with ":serialized:"
# key, this means it is serialized with cloudpickle.
serialization = data_item[":serialized:"]
# Try-except deserialization in case we run into
# errors. If so, we can give a bit more information to
# the user.
try:
base64_object = base64.b64decode(serialization.encode())
deserialized_object = cloudpickle.loads(base64_object)
except (RuntimeError, TypeError, AttributeError) as e:
warnings.warn(
f"Could not deserialize object {data_key}. "
"Consider using `custom_objects` argument to replace "
"this object.\n"
f"Exception: {e}"
)
else:
return_data[data_key] = deserialized_object
else:
# Read as it is
return_data[data_key] = data_item
return return_data
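# Hedged usage sketch (illustration only, not part of the original module):
# round-trips a mix of JSON-friendly and non-JSON-friendly values (a lambda)
# through ``data_to_json``/``json_to_data``. The values are assumptions.
def _example_json_round_trip() -> None:
    data = {"gamma": 0.99, "lr_schedule": lambda progress_remaining: 3e-4 * progress_remaining}
    json_string = data_to_json(data)
    restored = json_to_data(json_string)
    assert restored["gamma"] == 0.99
    # The lambda was cloudpickled, base64-encoded and restored as a callable
    assert abs(restored["lr_schedule"](0.5) - 1.5e-4) < 1e-12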
@functools.singledispatch
def open_path(
path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose: int = 0, suffix: Optional[str] = None
) -> Union[io.BufferedWriter, io.BufferedReader, io.BytesIO]:
"""
Opens a path for reading or writing with a preferred suffix and raises debug information.
If the provided path is a derivative of io.BufferedIOBase it ensures that the file
matches the provided mode, i.e. if the mode is read ("r", "read") it checks that the path is readable.
If the mode is write ("w", "write") it checks that the file is writable.
If the provided path is a string or a pathlib.Path, it ensures that it exists. If the mode is "read"
it checks that it exists, if it doesn't exist it attempts to read path.suffix if a suffix is provided.
If the mode is "write" and the path does not exist, it creates all the parent folders. If the path
points to a folder, it changes the path to path_2. If the path already exists and verbose >= 2,
it issues a warning.
:param path: the path to open.
if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
path actually exists. If path is a io.BufferedIOBase the path exists.
:param mode: how to open the file. "w"|"write" for writing, "r"|"read" for reading.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
:param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
is not None, we attempt to open the path with the suffix.
:return:
"""
# Note(antonin): the true annotation should be IO[bytes]
# but there is no easy way to check that
allowed_types = (io.BufferedWriter, io.BufferedReader, io.BytesIO, io.BufferedRandom)
if not isinstance(path, allowed_types):
raise TypeError(f"Path {path} parameter has invalid type: expected one of {allowed_types}.")
if path.closed:
raise ValueError(f"File stream {path} is closed.")
mode = mode.lower()
try:
mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode]
except KeyError as e:
raise ValueError("Expected mode to be either 'w' or 'r'.") from e
if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable():
error_msg = "writable" if "w" == mode else "readable"
raise ValueError(f"Expected a {error_msg} file.")
return path
@open_path.register(str)
def open_path_str(path: str, mode: str, verbose: int = 0, suffix: Optional[str] = None) -> io.BufferedIOBase:
"""
Open a path given by a string. If writing to the path, the function ensures
that the path exists.
:param path: the path to open. If mode is "w" then it ensures that the path exists
by creating the necessary folders and renaming path if it points to a folder.
:param mode: how to open the file. "w" for writing, "r" for reading.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
:param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
is not None, we attempt to open the path with the suffix.
:return:
"""
return open_path_pathlib(pathlib.Path(path), mode, verbose, suffix)
@open_path.register(pathlib.Path)
def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: Optional[str] = None) -> io.BufferedIOBase:
"""
Open a path given by a string. If writing to the path, the function ensures
that the path exists.
:param path: the path to check. If mode is "w" then it
ensures that the path exists by creating the necessary folders and
renaming path if it points to a folder.
:param mode: how to open the file. "w" for writing, "r" for reading.
:param verbose: Verbosity level: 0 for no output, 2 for indicating if path without suffix is not found when mode is "r"
:param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
If mode is "r" then we attempt to open the path. If an error is raised and the suffix
is not None, we attempt to open the path with the suffix.
:return:
"""
if mode not in ("w", "r"):
raise ValueError("Expected mode to be either 'w' or 'r'.")
if mode == "r":
try:
return open_path(path.open("rb"), mode, verbose, suffix)
except FileNotFoundError as error:
if suffix is not None and suffix != "":
newpath = pathlib.Path(f"{path}.{suffix}")
if verbose >= 2:
warnings.warn(f"Path '{path}' not found. Attempting {newpath}.")
path, suffix = newpath, None
else:
raise error
else:
try:
if path.suffix == "" and suffix is not None and suffix != "":
path = pathlib.Path(f"{path}.{suffix}")
if path.exists() and path.is_file() and verbose >= 2:
warnings.warn(f"Path '{path}' exists, will overwrite it.")
return open_path(path.open("wb"), mode, verbose, suffix)
except IsADirectoryError:
warnings.warn(f"Path '{path}' is a folder. Will save instead to {path}_2")
path = pathlib.Path(f"{path}_2")
except FileNotFoundError: # Occurs when the parent folder doesn't exist
warnings.warn(f"Path '{path.parent}' does not exist. Will create it.")
path.parent.mkdir(exist_ok=True, parents=True)
# if opening was successful uses the open_path() function
# if opening failed with IsADirectory|FileNotFound, calls open_path_pathlib
# with corrections
# if reading failed with FileNotFoundError, calls open_path_pathlib with suffix
return open_path_pathlib(path, mode, verbose, suffix)
def save_to_zip_file(
save_path: Union[str, pathlib.Path, io.BufferedIOBase],
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
pytorch_variables: Optional[Dict[str, Any]] = None,
verbose: int = 0,
) -> None:
"""
Save model data to a zip archive.
:param save_path: Where to store the model.
if save_path is a str or pathlib.Path ensures that the path actually exists.
:param data: Class parameters being stored (non-PyTorch variables)
:param params: Model parameters being stored expected to contain an entry for every
state_dict with its name and the state_dict.
:param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
file = open_path(save_path, "w", verbose=verbose, suffix="zip")
# data/params can be None, so do not
# try to serialize them blindly
if data is not None:
serialized_data = data_to_json(data)
# Create a zip-archive and write our objects there.
with zipfile.ZipFile(file, mode="w") as archive:
# Do not try to save "None" elements
if data is not None:
archive.writestr("data", serialized_data)
if pytorch_variables is not None:
with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
th.save(pytorch_variables, pytorch_variables_file)
if params is not None:
for file_name, dict_ in params.items():
with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
th.save(dict_, param_file)
# Save metadata: library version when file was saved
archive.writestr("_stable_baselines3_version", sb3.__version__)
# Save system info about the current python env
archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
if isinstance(save_path, (str, pathlib.Path)):
file.close()
def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None:
"""
Save an object to path creating the necessary folders along the way.
If the path exists and is a directory, it will issue a warning and rename the path.
If a suffix is provided in the path, it will use that suffix, otherwise, it will use '.pkl'.
:param path: the path to open.
if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
path actually exists. If path is a io.BufferedIOBase the path exists.
:param obj: The object to save.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
file = open_path(path, "w", verbose=verbose, suffix="pkl")
# Use protocol>=4 to support saving replay buffers >= 4Gb
# See https://docs.python.org/3/library/pickle.html
pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
if isinstance(path, (str, pathlib.Path)):
file.close()
def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: int = 0) -> Any:
"""
Load an object from the path. If a suffix is provided in the path, it will use that suffix.
If the path does not exist, it will attempt to load using the .pkl suffix.
:param path: the path to open.
if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
path actually exists. If path is a io.BufferedIOBase the path exists.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
file = open_path(path, "r", verbose=verbose, suffix="pkl")
obj = pickle.load(file)
if isinstance(path, (str, pathlib.Path)):
file.close()
return obj
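# Hedged usage sketch (illustration only, not part of the original module):
# saves a small dictionary to a temporary ``.pkl`` file and loads it back.
# The temporary path and contents are assumptions made for the example.
def _example_pkl_round_trip() -> None:
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "stats")  # the ".pkl" suffix is added on save
        save_to_pkl(path, {"mean": 0.0, "var": 1.0})
        assert load_from_pkl(path) == {"mean": 0.0, "var": 1.0}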
def load_from_zip_file(
load_path: Union[str, pathlib.Path, io.BufferedIOBase],
load_data: bool = True,
custom_objects: Optional[Dict[str, Any]] = None,
device: Union[th.device, str] = "auto",
verbose: int = 0,
print_system_info: bool = False,
) -> Tuple[Optional[Dict[str, Any]], TensorDict, Optional[TensorDict]]:
"""
Load model data from a .zip archive
:param load_path: Where to load the model from
:param load_data: Whether we should load and return data
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
:param custom_objects: Dictionary of objects to replace
upon loading. If a variable is present in this dictionary as a
key, it will not be deserialized and the corresponding item
will be used instead. Similar to custom_objects in
``keras.models.load_model``. Useful when you have an object in
file that can not be deserialized.
:param device: Device on which the code should run.
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
:param print_system_info: Whether to print or not the system info
about the saved model.
:return: Class parameters, model state_dicts (aka "params", dict of state_dict)
and dict of pytorch variables
"""
file = open_path(load_path, "r", verbose=verbose, suffix="zip")
# set device to cpu if cuda is not available
device = get_device(device=device)
# Open the zip archive and load data
try:
with zipfile.ZipFile(file) as archive:
namelist = archive.namelist()
# If data or parameters is not in the
# zip archive, assume they were stored
# as None (_save_to_file_zip allows this).
data = None
pytorch_variables = None
params = {}
# Debug system info first
if print_system_info:
if "system_info.txt" in namelist:
print("== SAVED MODEL SYSTEM INFO ==")
print(archive.read("system_info.txt").decode())
else:
warnings.warn(
"The model was saved with SB3 <= 1.2.0 and thus cannot print system information.",
UserWarning,
)
if "data" in namelist and load_data:
# Load class parameters that are stored
# with either JSON or pickle (not PyTorch variables).
json_data = archive.read("data").decode()
data = json_to_data(json_data, custom_objects=custom_objects)
# Check for all .pth files and load them using th.load.
# "pytorch_variables.pth" stores PyTorch variables, and any other .pth
# files store state_dicts of variables with custom names (e.g. policy, policy.optimizer)
pth_files = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth"]
for file_path in pth_files:
with archive.open(file_path, mode="r") as param_file:
# File has to be seekable, but param_file is not, so load in BytesIO first
# fixed in python >= 3.7
file_content = io.BytesIO()
file_content.write(param_file.read())
# go to start of file
file_content.seek(0)
# Load the parameters with the right ``map_location``.
# Remove ".pth" ending with splitext
# Note(antonin): we cannot use weights_only=True, as it breaks with PyTorch 1.13, see GH#1911
th_object = th.load(file_content, map_location=device, weights_only=False)
# "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138
if file_path == "pytorch_variables.pth" or file_path == "tensors.pth":
# PyTorch variables (not state_dicts)
pytorch_variables = th_object
else:
# State dicts. Store into params dictionary
# with same name as in .zip file (without .pth)
params[os.path.splitext(file_path)[0]] = th_object
except zipfile.BadZipFile as e:
# load_path wasn't a zip file
raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
finally:
if isinstance(load_path, (str, pathlib.Path)):
file.close()
return data, params, pytorch_variables
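# Hedged usage sketch (illustration only, not part of the original module):
# writes a minimal archive holding class parameters and one state_dict,
# then reads it back. The file name and contents are assumptions.
def _example_zip_round_trip() -> None:
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "model.zip")
        params = {"policy": th.nn.Linear(4, 2).state_dict()}
        save_to_zip_file(path, data={"gamma": 0.99}, params=params)
        data, loaded_params, pytorch_variables = load_from_zip_file(path)
        assert data["gamma"] == 0.99
        assert "policy" in loaded_params
        assert pytorch_variables is None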

View File

@ -0,0 +1,136 @@
from typing import Any, Callable, Dict, Iterable, Optional
import torch
from torch.optim import Optimizer
class RMSpropTFLike(Optimizer):
r"""Implements RMSprop algorithm with closer match to Tensorflow version.
For reproducibility with original stable-baselines. Use this
version with e.g. A2C for stabler learning than with the PyTorch
RMSProp. Based on the PyTorch v1.5.0 implementation of RMSprop.
See a more thorough conversion in the pytorch-image-models repository:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/rmsprop_tf.py
Changes to the original RMSprop:
- Move epsilon inside square root
- Initialize squared gradient to ones rather than zeros
Proposed by G. Hinton in his
`course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
The centered version first appears in `Generating Sequences
With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
The implementation here takes the square root of the gradient average before
adding epsilon (note that TensorFlow interchanges these two operations). The effective
learning rate is thus :math:`\alpha/(\sqrt{v} + \epsilon)` where :math:`\alpha`
is the scheduled learning rate and :math:`v` is the weighted moving average
of the squared gradient.
:param params: iterable of parameters to optimize or dicts defining
parameter groups
:param lr: learning rate (default: 1e-2)
:param momentum: momentum factor (default: 0)
:param alpha: smoothing constant (default: 0.99)
:param eps: term added to the denominator to improve
numerical stability (default: 1e-8)
:param centered: if ``True``, compute the centered RMSProp,
the gradient is normalized by an estimation of its variance
:param weight_decay: weight decay (L2 penalty) (default: 0)
"""
def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-2,
alpha: float = 0.99,
eps: float = 1e-8,
weight_decay: float = 0,
momentum: float = 0,
centered: bool = False,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= momentum:
raise ValueError(f"Invalid momentum value: {momentum}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= alpha:
raise ValueError(f"Invalid alpha value: {alpha}")
defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
super().__init__(params, defaults)
def __setstate__(self, state: Dict[str, Any]) -> None:
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("momentum", 0)
group.setdefault("centered", False)
@torch.no_grad()
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: # type: ignore[override]
"""Performs a single optimization step.
:param closure: A closure that reevaluates the model
and returns the loss.
:return: loss
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError("RMSpropTF does not support sparse gradients")
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# PyTorch initialized to zeros here
state["square_avg"] = torch.ones_like(p, memory_format=torch.preserve_format)
if group["momentum"] > 0:
state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
if group["centered"]:
state["grad_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
square_avg = state["square_avg"]
alpha = group["alpha"]
state["step"] += 1
if group["weight_decay"] != 0:
grad = grad.add(p, alpha=group["weight_decay"])
square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
if group["centered"]:
grad_avg = state["grad_avg"]
grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
# PyTorch added epsilon after square root
# avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(group['eps'])
avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add_(group["eps"]).sqrt_()
else:
# PyTorch added epsilon after square root
# avg = square_avg.sqrt().add_(group['eps'])
avg = square_avg.add(group["eps"]).sqrt_()
if group["momentum"] > 0:
buf = state["momentum_buffer"]
buf.mul_(group["momentum"]).addcdiv_(grad, avg)
p.add_(buf, alpha=-group["lr"])
else:
p.addcdiv_(grad, avg, value=-group["lr"])
return loss
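# Hedged usage sketch (illustration only, not part of the original module):
# one optimization step with ``RMSpropTFLike`` on a tiny linear model and
# random data; the model, learning rate and epsilon are assumptions.
def _example_rmsprop_tf_like() -> None:
    model = torch.nn.Linear(4, 2)
    optimizer = RMSpropTFLike(model.parameters(), lr=7e-4, eps=1e-5)
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()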

View File

@ -0,0 +1,318 @@
from typing import Dict, List, Tuple, Type, Union
import gymnasium as gym
import torch as th
from gymnasium import spaces
from torch import nn
from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
from stable_baselines3.common.type_aliases import TensorDict
from stable_baselines3.common.utils import get_device
class BaseFeaturesExtractor(nn.Module):
"""
Base class that represents a features extractor.
:param observation_space:
:param features_dim: Number of features extracted.
"""
def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None:
super().__init__()
assert features_dim > 0
self._observation_space = observation_space
self._features_dim = features_dim
@property
def features_dim(self) -> int:
return self._features_dim
class FlattenExtractor(BaseFeaturesExtractor):
"""
Feature extractor that flattens the input.
Used as a placeholder when feature extraction is not needed.
:param observation_space:
"""
def __init__(self, observation_space: gym.Space) -> None:
super().__init__(observation_space, get_flattened_obs_dim(observation_space))
self.flatten = nn.Flatten()
def forward(self, observations: th.Tensor) -> th.Tensor:
return self.flatten(observations)
class NatureCNN(BaseFeaturesExtractor):
"""
CNN from DQN Nature paper:
Mnih, Volodymyr, et al.
"Human-level control through deep reinforcement learning."
Nature 518.7540 (2015): 529-533.
:param observation_space:
:param features_dim: Number of features extracted.
This corresponds to the number of units of the last layer.
:param normalized_image: Whether to assume that the image is already normalized
or not (this disables dtype and bounds checks): when True, it only checks that
the space is a Box and has 3 dimensions.
Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
"""
def __init__(
self,
observation_space: gym.Space,
features_dim: int = 512,
normalized_image: bool = False,
) -> None:
assert isinstance(observation_space, spaces.Box), (
"NatureCNN must be used with a gym.spaces.Box "
f"observation space, not {observation_space}"
)
super().__init__(observation_space, features_dim)
# We assume CxHxW images (channels first)
# Re-ordering will be done by pre-preprocessing or wrapper
assert is_image_space(observation_space, check_channels=False, normalized_image=normalized_image), (
"You should use NatureCNN "
f"only with images not with {observation_space}\n"
"(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n"
"If you are using a custom environment,\n"
"please check it using our env checker:\n"
"https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html.\n"
"If you are using `VecNormalize` or already normalized channel-first images "
"you should pass `normalize_images=False`: \n"
"https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html"
)
n_input_channels = observation_space.shape[0]
self.cnn = nn.Sequential(
nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.Flatten(),
)
# Compute shape by doing one forward pass
with th.no_grad():
n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]
self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
def forward(self, observations: th.Tensor) -> th.Tensor:
return self.linear(self.cnn(observations))
def create_mlp(
input_dim: int,
output_dim: int,
net_arch: List[int],
activation_fn: Type[nn.Module] = nn.ReLU,
squash_output: bool = False,
with_bias: bool = True,
) -> List[nn.Module]:
"""
Create a multi layer perceptron (MLP), which is
a collection of fully-connected layers each followed by an activation function.
:param input_dim: Dimension of the input vector
:param output_dim:
:param net_arch: Architecture of the neural net
It represents the number of units per layer.
The length of this list is the number of layers.
:param activation_fn: The activation function
to use after each layer.
:param squash_output: Whether to squash the output using a Tanh
activation function
:param with_bias: If set to False, the layers will not learn an additive bias
:return:
"""
if len(net_arch) > 0:
modules = [nn.Linear(input_dim, net_arch[0], bias=with_bias), activation_fn()]
else:
modules = []
for idx in range(len(net_arch) - 1):
modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1], bias=with_bias))
modules.append(activation_fn())
if output_dim > 0:
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim
modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias))
if squash_output:
modules.append(nn.Tanh())
return modules
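# Hedged usage sketch (illustration only, not part of the original module):
# builds a small two-layer MLP with ``create_mlp`` and wraps it in
# ``nn.Sequential``; the dimensions are assumptions made for the example.
def _example_create_mlp() -> None:
    layers = create_mlp(input_dim=8, output_dim=2, net_arch=[64, 64])
    q_net = nn.Sequential(*layers)
    batch = th.randn(16, 8)
    assert q_net(batch).shape == (16, 2)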
class MlpExtractor(nn.Module):
"""
Constructs an MLP that receives the output from a previous features extractor (i.e. a CNN) or directly
the observations (if no features extractor is applied) as an input and outputs a latent representation
for the policy and a value network.
The ``net_arch`` parameter allows you to specify the number and size of the hidden layers.
It can be in either of the following forms:
1. ``dict(vf=[<list of layer sizes>], pi=[<list of layer sizes>])``: to specify the amount and size of the layers in the
policy and value nets individually. If it is missing any of the keys (pi or vf),
zero layers will be considered for that key.
2. ``[<list of layer sizes>]``: "shortcut" in case the amount and size of the layers
in the policy and value nets are the same. Same as ``dict(vf=int_list, pi=int_list)``
where int_list is the same for the actor and critic.
.. note::
If a key is not specified or an empty list is passed ``[]``, a linear network will be used.
:param feature_dim: Dimension of the feature vector (can be the output of a CNN)
:param net_arch: The specification of the policy and value networks.
See above for details on its formatting.
:param activation_fn: The activation function to use for the networks.
:param device: PyTorch device.
"""
def __init__(
self,
feature_dim: int,
net_arch: Union[List[int], Dict[str, List[int]]],
activation_fn: Type[nn.Module],
device: Union[th.device, str] = "auto",
) -> None:
super().__init__()
device = get_device(device)
policy_net: List[nn.Module] = []
value_net: List[nn.Module] = []
last_layer_dim_pi = feature_dim
last_layer_dim_vf = feature_dim
# save dimensions of layers in policy and value nets
if isinstance(net_arch, dict):
# Note: if a key is not specified, assume a linear network
pi_layers_dims = net_arch.get("pi", []) # Layer sizes of the policy network
vf_layers_dims = net_arch.get("vf", []) # Layer sizes of the value network
else:
pi_layers_dims = vf_layers_dims = net_arch
# Iterate through the policy layers and build the policy net
for curr_layer_dim in pi_layers_dims:
policy_net.append(nn.Linear(last_layer_dim_pi, curr_layer_dim))
policy_net.append(activation_fn())
last_layer_dim_pi = curr_layer_dim
# Iterate through the value layers and build the value net
for curr_layer_dim in vf_layers_dims:
value_net.append(nn.Linear(last_layer_dim_vf, curr_layer_dim))
value_net.append(activation_fn())
last_layer_dim_vf = curr_layer_dim
# Save dim, used to create the distributions
self.latent_dim_pi = last_layer_dim_pi
self.latent_dim_vf = last_layer_dim_vf
# Create networks
# If the list of layers is empty, the network will just act as an Identity module
self.policy_net = nn.Sequential(*policy_net).to(device)
self.value_net = nn.Sequential(*value_net).to(device)
def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
"""
:return: latent_policy, latent_value of the specified network.
If all layers are shared, then ``latent_policy == latent_value``
"""
return self.forward_actor(features), self.forward_critic(features)
def forward_actor(self, features: th.Tensor) -> th.Tensor:
return self.policy_net(features)
def forward_critic(self, features: th.Tensor) -> th.Tensor:
return self.value_net(features)
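# Hedged usage sketch (illustration only, not part of the original module):
# separate policy and value heads built by ``MlpExtractor`` on top of a
# 64-dimensional feature vector; the architecture is an assumption.
def _example_mlp_extractor() -> None:
    extractor = MlpExtractor(
        feature_dim=64,
        net_arch=dict(pi=[64, 64], vf=[64]),
        activation_fn=nn.Tanh,
        device="cpu",
    )
    latent_pi, latent_vf = extractor(th.randn(5, 64))
    assert latent_pi.shape == (5, extractor.latent_dim_pi)
    assert latent_vf.shape == (5, extractor.latent_dim_vf)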
class CombinedExtractor(BaseFeaturesExtractor):
"""
Combined features extractor for Dict observation spaces.
Builds a features extractor for each key of the space. Input from each space
is fed through a separate submodule (CNN or MLP, depending on input shape),
the output features are concatenated and fed through additional MLP network ("combined").
:param observation_space:
:param cnn_output_dim: Number of features to output from each CNN submodule(s). Defaults to
256 to avoid exploding network sizes.
:param normalized_image: Whether to assume that the image is already normalized
or not (this disables dtype and bounds checks): when True, it only checks that
the space is a Box and has 3 dimensions.
Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
"""
def __init__(
self,
observation_space: spaces.Dict,
cnn_output_dim: int = 256,
normalized_image: bool = False,
) -> None:
# TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
super().__init__(observation_space, features_dim=1)
extractors: Dict[str, nn.Module] = {}
total_concat_size = 0
for key, subspace in observation_space.spaces.items():
if is_image_space(subspace, normalized_image=normalized_image):
extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim, normalized_image=normalized_image)
total_concat_size += cnn_output_dim
else:
# The observation key is a vector, flatten it if needed
extractors[key] = nn.Flatten()
total_concat_size += get_flattened_obs_dim(subspace)
self.extractors = nn.ModuleDict(extractors)
# Update the features dim manually
self._features_dim = total_concat_size
def forward(self, observations: TensorDict) -> th.Tensor:
encoded_tensor_list = []
for key, extractor in self.extractors.items():
encoded_tensor_list.append(extractor(observations[key]))
return th.cat(encoded_tensor_list, dim=1)
def get_actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> Tuple[List[int], List[int]]:
"""
Get the actor and critic network architectures for off-policy actor-critic algorithms (SAC, TD3, DDPG).
The ``net_arch`` parameter allows you to specify the number and size of the hidden layers,
which can be different for the actor and the critic.
It is assumed to be a list of ints or a dict.
1. If it is a list, actor and critic networks will have the same architecture.
The architecture is represented by a list of integers (of arbitrary length (zero allowed))
each specifying the number of units per layer.
If the number of ints is zero, the network will be linear.
2. If it is a dict, it should have the following structure:
``dict(qf=[<critic network architecture>], pi=[<actor network architecture>])``.
where the network architecture is a list as described in 1.
For example, to have actor and critic that share the same network architecture,
you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256 units each).
If you want a different architecture for the actor and the critic,
then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``.
.. note::
Compared to their on-policy counterparts, no shared layers (other than the features extractor)
between the actor and the critic are allowed (to prevent issues with target networks).
:param net_arch: The specification of the actor and critic networks.
See above for details on its formatting.
:return: The network architectures for the actor and the critic
"""
if isinstance(net_arch, list):
actor_arch, critic_arch = net_arch, net_arch
else:
assert isinstance(net_arch, dict), "Error: the net_arch can only be a list of ints or a dict"
assert "pi" in net_arch, "Error: no key 'pi' was provided in net_arch for the actor network"
assert "qf" in net_arch, "Error: no key 'qf' was provided in net_arch for the critic network"
actor_arch, critic_arch = net_arch["pi"], net_arch["qf"]
return actor_arch, critic_arch
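# Hedged usage sketch (illustration only, not part of the original module):
# the two accepted ``net_arch`` formats for off-policy actor-critic algorithms.
def _example_get_actor_critic_arch() -> None:
    # A plain list is shared by the actor and the critic
    assert get_actor_critic_arch([256, 256]) == ([256, 256], [256, 256])
    # A dict specifies the actor (pi) and critic (qf) separately
    actor_arch, critic_arch = get_actor_critic_arch(dict(pi=[64, 64], qf=[400, 300]))
    assert actor_arch == [64, 64]
    assert critic_arch == [400, 300]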

View File

@ -0,0 +1,101 @@
"""Common aliases for type hints"""
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union
import gymnasium as gym
import numpy as np
import torch as th
# Avoid circular imports, we use type hint as string to avoid it too
if TYPE_CHECKING:
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import VecEnv
GymEnv = Union[gym.Env, "VecEnv"]
GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
GymResetReturn = Tuple[GymObs, Dict]
AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]]
GymStepReturn = Tuple[GymObs, float, bool, bool, Dict]
AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]]
TensorDict = Dict[str, th.Tensor]
OptimizerStateDict = Dict[str, Any]
MaybeCallback = Union[None, Callable, List["BaseCallback"], "BaseCallback"]
PyTorchObs = Union[th.Tensor, TensorDict]
# A schedule takes the remaining progress as input
# and outputs a scalar (e.g. learning rate, clip range, ...)
Schedule = Callable[[float], float]
class RolloutBufferSamples(NamedTuple):
observations: th.Tensor
actions: th.Tensor
old_values: th.Tensor
old_log_prob: th.Tensor
advantages: th.Tensor
returns: th.Tensor
class DictRolloutBufferSamples(NamedTuple):
observations: TensorDict
actions: th.Tensor
old_values: th.Tensor
old_log_prob: th.Tensor
advantages: th.Tensor
returns: th.Tensor
class ReplayBufferSamples(NamedTuple):
observations: th.Tensor
actions: th.Tensor
next_observations: th.Tensor
dones: th.Tensor
rewards: th.Tensor
class DictReplayBufferSamples(NamedTuple):
observations: TensorDict
actions: th.Tensor
next_observations: TensorDict
dones: th.Tensor
rewards: th.Tensor
class RolloutReturn(NamedTuple):
episode_timesteps: int
n_episodes: int
continue_training: bool
class TrainFrequencyUnit(Enum):
STEP = "step"
EPISODE = "episode"
class TrainFreq(NamedTuple):
frequency: int
unit: TrainFrequencyUnit # either "step" or "episode"
class PolicyPredictor(Protocol):
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:param state: The last hidden states (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
this corresponds to the beginning of episodes,
where the hidden states of the RNN must be reset.
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next hidden state
(used in recurrent policies)
"""

View File

@ -0,0 +1,552 @@
import glob
import os
import platform
import random
import re
from collections import deque
from itertools import zip_longest
from typing import Dict, Iterable, List, Optional, Tuple, Union
import cloudpickle
import gymnasium as gym
import numpy as np
import torch as th
from gymnasium import spaces
import stable_baselines3 as sb3
# Check if tensorboard is available for pytorch
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
SummaryWriter = None # type: ignore[misc, assignment]
from stable_baselines3.common.logger import Logger, configure
from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit
def set_random_seed(seed: int, using_cuda: bool = False) -> None:
"""
Seed the different random generators.
:param seed:
:param using_cuda:
"""
# Seed python RNG
random.seed(seed)
# Seed numpy RNG
np.random.seed(seed)
# seed the RNG for all devices (both CPU and CUDA)
th.manual_seed(seed)
if using_cuda:
# Deterministic operations for CuDNN, it may impact performances
th.backends.cudnn.deterministic = True
th.backends.cudnn.benchmark = False
# From stable baselines
def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
"""
Computes fraction of variance that ypred explains about y.
Returns 1 - Var[y-ypred] / Var[y]
interpretation:
ev=0 => might as well have predicted zero
ev=1 => perfect prediction
ev<0 => worse than just predicting zero
:param y_pred: the prediction
:param y_true: the expected value
:return: explained variance of ypred and y
"""
assert y_true.ndim == 1 and y_pred.ndim == 1
var_y = np.var(y_true)
return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
"""
Update the learning rate for a given optimizer.
Useful when doing linear schedule.
:param optimizer: Pytorch optimizer
:param learning_rate: New learning rate value
"""
for param_group in optimizer.param_groups:
param_group["lr"] = learning_rate
def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule:
"""
Transform (if needed) learning rate and clip range (for PPO)
to callable.
:param value_schedule: Constant value of schedule function
:return: Schedule function (can return constant value)
"""
# If the passed schedule is a float
# create a constant function
if isinstance(value_schedule, (float, int)):
# Cast to float to avoid errors
value_schedule = constant_fn(float(value_schedule))
else:
assert callable(value_schedule)
# Cast to float to avoid unpickling errors to enable weights_only=True, see GH#1900
# Some types have odd behaviors when part of a Schedule, like numpy floats
return lambda progress_remaining: float(value_schedule(progress_remaining))
def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
"""
Create a function that interpolates linearly between start and end
between ``progress_remaining`` = 1 and ``progress_remaining`` = ``end_fraction``.
This is used in DQN for linearly annealing the exploration fraction
(epsilon for the epsilon-greedy strategy).
:param start: value to start with if ``progress_remaining`` = 1
:param end: value to end with if ``progress_remaining`` = 0
:param end_fraction: fraction of ``progress_remaining``
where ``end`` is reached, e.g. 0.1 means ``end`` is reached after 10%
of the complete training process.
:return: Linear schedule function.
"""
def func(progress_remaining: float) -> float:
if (1 - progress_remaining) > end_fraction:
return end
else:
return start + (1 - progress_remaining) * (end - start) / end_fraction
return func
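# Hedged usage sketch (illustration only, not part of the original module):
# a linear decay from 1.0 to 0.05 over the first 10% of training, constant
# afterwards; the start/end values are assumptions made for the example.
def _example_linear_schedule() -> None:
    exploration_schedule = get_linear_fn(start=1.0, end=0.05, end_fraction=0.1)
    assert abs(exploration_schedule(1.0) - 1.0) < 1e-9  # beginning of training
    assert abs(exploration_schedule(0.9) - 0.05) < 1e-9  # 10% of training done
    assert abs(exploration_schedule(0.0) - 0.05) < 1e-9  # end of training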
def constant_fn(val: float) -> Schedule:
"""
Create a function that returns a constant
It is useful for learning rate schedule (to avoid code duplication)
:param val: constant value
:return: Constant schedule function.
"""
def func(_):
return val
return func
def get_device(device: Union[th.device, str] = "auto") -> th.device:
"""
Retrieve PyTorch device.
It checks that the requested device is available first.
For now, it supports only cpu and cuda.
By default, it tries to use the gpu.
:param device: One of 'auto', 'cuda', 'cpu'
:return: Supported Pytorch device
"""
# Cuda by default
if device == "auto":
device = "cuda"
# Force conversion to th.device
device = th.device(device)
# Cuda not available
if device.type == th.device("cuda").type and not th.cuda.is_available():
return th.device("cpu")
return device
def get_latest_run_id(log_path: str = "", log_name: str = "") -> int:
"""
Returns the latest run number for the given log name and log path,
by finding the greatest number in the directories.
:param log_path: Path to the log folder containing several runs.
:param log_name: Name of the experiment. Each run is stored
in a folder named ``log_name_1``, ``log_name_2``, ...
:return: latest run number
"""
max_run_id = 0
for path in glob.glob(os.path.join(log_path, f"{glob.escape(log_name)}_[0-9]*")):
file_name = path.split(os.sep)[-1]
ext = file_name.split("_")[-1]
if log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id:
max_run_id = int(ext)
return max_run_id
def configure_logger(
verbose: int = 0,
tensorboard_log: Optional[str] = None,
tb_log_name: str = "",
reset_num_timesteps: bool = True,
) -> Logger:
"""
Configure the logger's outputs.
:param verbose: Verbosity level: 0 for no output, 1 for the standard output to be part of the logger outputs
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param tb_log_name: tensorboard log
:param reset_num_timesteps: Whether the ``num_timesteps`` attribute is reset or not.
It allows continuing a previous learning curve (``reset_num_timesteps=False``)
or start from t=0 (``reset_num_timesteps=True``, the default).
:return: The logger object
"""
save_path, format_strings = None, ["stdout"]
if tensorboard_log is not None and SummaryWriter is None:
raise ImportError("Trying to log data to tensorboard but tensorboard is not installed.")
if tensorboard_log is not None and SummaryWriter is not None:
latest_run_id = get_latest_run_id(tensorboard_log, tb_log_name)
if not reset_num_timesteps:
# Continue training in the same directory
latest_run_id -= 1
save_path = os.path.join(tensorboard_log, f"{tb_log_name}_{latest_run_id + 1}")
if verbose >= 1:
format_strings = ["stdout", "tensorboard"]
else:
format_strings = ["tensorboard"]
elif verbose == 0:
format_strings = [""]
return configure(save_path, format_strings=format_strings)
def check_for_correct_spaces(env: GymEnv, observation_space: spaces.Space, action_space: spaces.Space) -> None:
"""
Checks that the environment has same spaces as provided ones. Used by BaseAlgorithm to check if
spaces match after loading the model with given env.
Checked parameters:
- observation_space
- action_space
:param env: Environment to check for valid spaces
:param observation_space: Observation space to check against
:param action_space: Action space to check against
"""
if observation_space != env.observation_space:
raise ValueError(f"Observation spaces do not match: {observation_space} != {env.observation_space}")
if action_space != env.action_space:
raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}")
def check_shape_equal(space1: spaces.Space, space2: spaces.Space) -> None:
"""
If the spaces are Box, check that they have the same shape.
If the spaces are Dict, it recursively checks the subspaces.
:param space1: Space
:param space2: Other space
"""
if isinstance(space1, spaces.Dict):
assert isinstance(space2, spaces.Dict), "spaces must be of the same type"
assert space1.spaces.keys() == space2.spaces.keys(), "spaces must have the same keys"
for key in space1.spaces.keys():
check_shape_equal(space1.spaces[key], space2.spaces[key])
elif isinstance(space1, spaces.Box):
assert space1.shape == space2.shape, "spaces must have the same shape"
def is_vectorized_box_observation(observation: np.ndarray, observation_space: spaces.Box) -> bool:
"""
For box observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
:param observation: the input observation to validate
:param observation_space: the observation space
:return: whether the given observation is vectorized or not
"""
if observation.shape == observation_space.shape:
return False
elif observation.shape[1:] == observation_space.shape:
return True
else:
raise ValueError(
f"Error: Unexpected observation shape {observation.shape} for "
+ f"Box environment, please use {observation_space.shape} "
+ "or (n_env, {}) for the observation shape.".format(", ".join(map(str, observation_space.shape)))
)
def is_vectorized_discrete_observation(observation: Union[int, np.ndarray], observation_space: spaces.Discrete) -> bool:
"""
For discrete observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
:param observation: the input observation to validate
:param observation_space: the observation space
:return: whether the given observation is vectorized or not
"""
if isinstance(observation, int) or observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
return False
elif len(observation.shape) == 1:
return True
else:
raise ValueError(
f"Error: Unexpected observation shape {observation.shape} for "
+ "Discrete environment, please use () or (n_env,) for the observation shape."
)
def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation_space: spaces.MultiDiscrete) -> bool:
"""
For multidiscrete observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
:param observation: the input observation to validate
:param observation_space: the observation space
:return: whether the given observation is vectorized or not
"""
if observation.shape == (len(observation_space.nvec),):
return False
elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
return True
else:
raise ValueError(
f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete "
+ f"environment, please use ({len(observation_space.nvec)},) or "
+ f"(n_env, {len(observation_space.nvec)}) for the observation shape."
)
def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: spaces.MultiBinary) -> bool:
"""
For multibinary observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
:param observation: the input observation to validate
:param observation_space: the observation space
:return: whether the given observation is vectorized or not
"""
if observation.shape == observation_space.shape:
return False
elif len(observation.shape) == len(observation_space.shape) + 1 and observation.shape[1:] == observation_space.shape:
return True
else:
raise ValueError(
f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
+ f"environment, please use {observation_space.shape} or "
+ f"(n_env, {observation_space.n}) for the observation shape."
)
def is_vectorized_dict_observation(observation: np.ndarray, observation_space: spaces.Dict) -> bool:
"""
For dict observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
:param observation: the input observation to validate
:param observation_space: the observation space
:return: whether the given observation is vectorized or not
"""
# We first assume that all observations are not vectorized
all_non_vectorized = True
for key, subspace in observation_space.spaces.items():
# This fails when the observation is not vectorized
# or when it has the wrong shape
if observation[key].shape != subspace.shape:
all_non_vectorized = False
break
if all_non_vectorized:
return False
all_vectorized = True
# Now we check that all observation are vectorized and have the correct shape
for key, subspace in observation_space.spaces.items():
if observation[key].shape[1:] != subspace.shape:
all_vectorized = False
break
if all_vectorized:
return True
else:
# Retrieve error message
error_msg = ""
try:
is_vectorized_observation(observation[key], observation_space.spaces[key])
except ValueError as e:
error_msg = f"{e}"
raise ValueError(
f"There seems to be a mix of vectorized and non-vectorized observations. "
f"Unexpected observation shape {observation[key].shape} for key {key} "
f"of type {observation_space.spaces[key]}. {error_msg}"
)
def is_vectorized_observation(observation: Union[int, np.ndarray], observation_space: spaces.Space) -> bool:
"""
For every observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
:param observation: the input observation to validate
:param observation_space: the observation space
:return: whether the given observation is vectorized or not
"""
is_vec_obs_func_dict = {
spaces.Box: is_vectorized_box_observation,
spaces.Discrete: is_vectorized_discrete_observation,
spaces.MultiDiscrete: is_vectorized_multidiscrete_observation,
spaces.MultiBinary: is_vectorized_multibinary_observation,
spaces.Dict: is_vectorized_dict_observation,
}
for space_type, is_vec_obs_func in is_vec_obs_func_dict.items():
if isinstance(observation_space, space_type):
return is_vec_obs_func(observation, observation_space) # type: ignore[operator]
else:
# for-else happens if no break is called
raise ValueError(f"Error: Cannot determine if the observation is vectorized with the space type {observation_space}.")
def safe_mean(arr: Union[np.ndarray, list, deque]) -> float:
"""
Compute the mean of an array if there is at least one element.
For empty array, return NaN. It is used for logging only.
:param arr: Numpy array or list of values
:return:
"""
return np.nan if len(arr) == 0 else float(np.mean(arr)) # type: ignore[arg-type]
def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> List[th.Tensor]:
"""
Extract parameters from the state dict of ``model``
if the name contains one of the strings in ``included_names``.
:param model: the model where the parameters come from.
:param included_names: substrings of names to include.
:return: List of parameters values (Pytorch tensors)
that matches the queried names.
"""
return [param for name, param in model.state_dict().items() if any([key in name for key in included_names])]
def zip_strict(*iterables: Iterable) -> Iterable:
r"""
``zip()`` function but enforces that iterables are of equal length.
Raises ``ValueError`` if iterables not of equal length.
Code inspired by Stackoverflow answer for question #32954486.
:param \*iterables: iterables to ``zip()``
"""
# As in Stackoverflow #32954486, use
# new object for "empty" in case we have
# Nones in iterable.
sentinel = object()
for combo in zip_longest(*iterables, fillvalue=sentinel):
if sentinel in combo:
raise ValueError("Iterables have different lengths")
yield combo
def polyak_update(
params: Iterable[th.Tensor],
target_params: Iterable[th.Tensor],
tau: float,
) -> None:
"""
Perform a Polyak average update on ``target_params`` using ``params``:
target parameters are slowly updated towards the main parameters.
``tau``, the soft update coefficient controls the interpolation:
``tau=1`` corresponds to copying the parameters to the target ones whereas nothing happens when ``tau=0``.
The Polyak update is done in place, with ``no_grad``, and therefore does not create intermediate tensors,
or a computation graph, reducing memory cost and improving performance. We scale the target params
by ``1-tau`` (in-place), add the new weights, scaled by ``tau`` and store the result of the sum in the target
params (in place).
See https://github.com/DLR-RM/stable-baselines3/issues/93
:param params: parameters to use to update the target params
:param target_params: parameters to update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1)
"""
with th.no_grad():
# zip does not raise an exception if length of parameters does not match.
for param, target_param in zip_strict(params, target_params):
target_param.data.mul_(1 - tau)
th.add(target_param.data, param.data, alpha=tau, out=target_param.data)
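# Hedged usage sketch (illustration only, not part of the original module):
# soft update of a target network towards the online network; the networks
# and ``tau=0.005`` are assumptions made for the example.
def _example_polyak_update() -> None:
    online_net = th.nn.Linear(4, 2)
    target_net = th.nn.Linear(4, 2)
    target_net.load_state_dict(online_net.state_dict())
    # ... gradient steps on ``online_net`` would happen here ...
    polyak_update(online_net.parameters(), target_net.parameters(), tau=0.005)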
def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]:
"""
Moves the observation to the given device.
:param obs:
:param device: PyTorch device
:return: PyTorch tensor of the observation on a desired device.
"""
if isinstance(obs, np.ndarray):
return th.as_tensor(obs, device=device)
elif isinstance(obs, dict):
return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()}
else:
raise Exception(f"Unrecognized type of observation {type(obs)}")
def should_collect_more_steps(
train_freq: TrainFreq,
num_collected_steps: int,
num_collected_episodes: int,
) -> bool:
"""
Helper used in ``collect_rollouts()`` of off-policy algorithms
to determine the termination condition.
:param train_freq: How much experience should be collected before updating the policy.
:param num_collected_steps: The number of already collected steps.
:param num_collected_episodes: The number of already collected episodes.
:return: Whether to continue or not collecting experience
by doing rollouts of the current policy.
"""
if train_freq.unit == TrainFrequencyUnit.STEP:
return num_collected_steps < train_freq.frequency
elif train_freq.unit == TrainFrequencyUnit.EPISODE:
return num_collected_episodes < train_freq.frequency
else:
raise ValueError(
"The unit of the `train_freq` must be either TrainFrequencyUnit.STEP "
f"or TrainFrequencyUnit.EPISODE not '{train_freq.unit}'!"
)
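# --- Illustrative usage (editor's sketch, not part of the original module) ---
# With ``train_freq = TrainFreq(2, TrainFrequencyUnit.EPISODE)``, collection
# continues until two full episodes have been gathered, whatever the number of
# steps already taken.
def _example_should_collect_more_steps() -> None:
    train_freq = TrainFreq(2, TrainFrequencyUnit.EPISODE)
    assert should_collect_more_steps(train_freq, num_collected_steps=100, num_collected_episodes=1)
    assert not should_collect_more_steps(train_freq, num_collected_steps=100, num_collected_episodes=2)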
def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
"""
Retrieve system and python env info for the current system.
:param print_info: Whether or not to print the info
:return: Dictionary summarizing the version of each relevant package
and a formatted string.
"""
env_info = {
# In OS, a regex is used to add a space between a "#" and a number to avoid
# wrongly linking to another issue on GitHub. Example: turn "#42" to "# 42".
"OS": re.sub(r"#(\d)", r"# \1", f"{platform.platform()} {platform.version()}"),
"Python": platform.python_version(),
"Stable-Baselines3": sb3.__version__,
"PyTorch": th.__version__,
"GPU Enabled": str(th.cuda.is_available()),
"Numpy": np.__version__,
"Cloudpickle": cloudpickle.__version__,
"Gymnasium": gym.__version__,
}
try:
import gym as openai_gym
env_info.update({"OpenAI Gym": openai_gym.__version__})
except ImportError:
pass
env_info_str = ""
for key, value in env_info.items():
env_info_str += f"- {key}: {value}\n"
if print_info:
print(env_info_str)
return env_info, env_info_str
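# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Printing the environment info is handy when reporting an issue; the same
# information is also stored alongside models saved by SB3 to ease debugging.
if __name__ == "__main__":
    system_info, system_info_str = get_system_info(print_info=True)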

View File

@ -0,0 +1,105 @@
from copy import deepcopy
from typing import Optional, Type, TypeVar
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper)
def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]:
"""
Retrieve a ``VecEnvWrapper`` object by recursively searching.
:param env: The ``VecEnv`` that is going to be unwrapped
:param vec_wrapper_class: The desired ``VecEnvWrapper`` class.
:return: The ``VecEnvWrapper`` object if the ``VecEnv`` is wrapped with the desired wrapper, None otherwise
"""
env_tmp = env
while isinstance(env_tmp, VecEnvWrapper):
if isinstance(env_tmp, vec_wrapper_class):
return env_tmp
env_tmp = env_tmp.venv
return None
def unwrap_vec_normalize(env: VecEnv) -> Optional[VecNormalize]:
"""
Retrieve a ``VecNormalize`` object by recursively searching.
:param env: The VecEnv that is going to be unwrapped
:return: The ``VecNormalize`` object if the ``VecEnv`` is wrapped with ``VecNormalize``, None otherwise
"""
return unwrap_vec_wrapper(env, VecNormalize)
def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
"""
Check if an environment is already wrapped in a given ``VecEnvWrapper``.
:param env: The VecEnv that is going to be checked
:param vec_wrapper_class: The desired ``VecEnvWrapper`` class.
:return: True if the ``VecEnv`` is wrapped with the desired wrapper, False otherwise
"""
return unwrap_vec_wrapper(env, vec_wrapper_class) is not None
def sync_envs_normalization(env: VecEnv, eval_env: VecEnv) -> None:
"""
Synchronize the normalization statistics of an eval environment and train environment
when they are both wrapped in a ``VecNormalize`` wrapper.
:param env: Training env
:param eval_env: Environment used for evaluation.
"""
env_tmp, eval_env_tmp = env, eval_env
while isinstance(env_tmp, VecEnvWrapper):
assert isinstance(eval_env_tmp, VecEnvWrapper), (
"Error while synchronizing normalization stats: expected the eval env to be "
f"a VecEnvWrapper but got {eval_env_tmp} instead. "
"This is probably due to the training env not being wrapped the same way as the evaluation env. "
f"Training env type: {env_tmp}."
)
if isinstance(env_tmp, VecNormalize):
assert isinstance(eval_env_tmp, VecNormalize), (
"Error while synchronizing normalization stats: expected the eval env to be "
f"a VecNormalize but got {eval_env_tmp} instead. "
"This is probably due to the training env not being wrapped the same way as the evaluation env. "
f"Training env type: {env_tmp}."
)
# Only synchronize if observation normalization exists
if hasattr(env_tmp, "obs_rms"):
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
env_tmp = env_tmp.venv
eval_env_tmp = eval_env_tmp.venv
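# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Typical evaluation setup: look up the ``VecNormalize`` layer of the training
# env, freeze its statistics on the eval side, and copy them over before
# running an evaluation.
def _example_sync_normalization(train_env: VecEnv, eval_env: VecEnv) -> None:
    if unwrap_vec_normalize(train_env) is not None:
        eval_normalize = unwrap_vec_normalize(eval_env)
        if eval_normalize is not None:
            # Do not update statistics and do not normalize rewards at eval time.
            eval_normalize.training = False
            eval_normalize.norm_reward = False
        sync_envs_normalization(train_env, eval_env)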
__all__ = [
"CloudpickleWrapper",
"VecEnv",
"VecEnvWrapper",
"DummyVecEnv",
"StackedObservations",
"SubprocVecEnv",
"VecCheckNan",
"VecExtractDictObs",
"VecFrameStack",
"VecMonitor",
"VecNormalize",
"VecTransposeImage",
"VecVideoRecorder",
"unwrap_vec_wrapper",
"unwrap_vec_normalize",
"is_vecenv_wrapped",
"sync_envs_normalization",
]

View File

@ -0,0 +1,482 @@
import inspect
import warnings
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union
import cloudpickle
import gymnasium as gym
import numpy as np
from gymnasium import spaces
# Define type aliases here to avoid circular import
# Used when we want to access one or more VecEnv
VecEnvIndices = Union[None, int, Iterable[int]]
# VecEnvObs is what is returned by the reset() method
# it contains the observation for each env
VecEnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]]
# VecEnvStepReturn is what is returned by the step() method
# it contains the observation, reward, done, info for each env
VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]]
def tile_images(images_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover
"""
Tile N images into one big PxQ image.
(P, Q) are chosen to be as close as possible, and if N
is a perfect square, then P=Q.
:param images_nhwc: list or array of images, ndim=4 once turned into array.
n = batch index, h = height, w = width, c = channel
:return: img_HWc, ndim=3
"""
img_nhwc = np.asarray(images_nhwc)
n_images, height, width, n_channels = img_nhwc.shape
# new_height was named H before
new_height = int(np.ceil(np.sqrt(n_images)))
# new_width was named W before
new_width = int(np.ceil(float(n_images) / new_height))
img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(n_images, new_height * new_width)])
# img_HWhwc
out_image = img_nhwc.reshape((new_height, new_width, height, width, n_channels))
# img_HhWwc
out_image = out_image.transpose(0, 2, 1, 3, 4)
# img_Hh_Ww_c
out_image = out_image.reshape((new_height * height, new_width * width, n_channels))
return out_image
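# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Tiling 5 RGB frames of shape (64, 64, 3): a 3x2 grid is chosen and the
# missing slot is filled with a black image.
def _example_tile_images() -> np.ndarray:
    frames = [np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) for _ in range(5)]
    big_image = tile_images(frames)
    assert big_image.shape == (3 * 64, 2 * 64, 3)
    return big_image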
class VecEnv(ABC):
"""
An abstract asynchronous, vectorized environment.
:param num_envs: Number of environments
:param observation_space: Observation space
:param action_space: Action space
"""
def __init__(
self,
num_envs: int,
observation_space: spaces.Space,
action_space: spaces.Space,
):
self.num_envs = num_envs
self.observation_space = observation_space
self.action_space = action_space
# store info returned by the reset method
self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
# seeds to be used in the next call to env.reset()
self._seeds: List[Optional[int]] = [None for _ in range(num_envs)]
# options to be used in the next call to env.reset()
self._options: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
try:
render_modes = self.get_attr("render_mode")
except AttributeError:
warnings.warn("The `render_mode` attribute is not defined in your environment. It will be set to None.")
render_modes = [None for _ in range(num_envs)]
assert all(
render_mode == render_modes[0] for render_mode in render_modes
), "render_mode mode should be the same for all environments"
self.render_mode = render_modes[0]
render_modes = []
if self.render_mode is not None:
if self.render_mode == "rgb_array":
# SB3 uses OpenCV for the "human" mode
render_modes = ["human", "rgb_array"]
else:
render_modes = [self.render_mode]
self.metadata = {"render_modes": render_modes}
def _reset_seeds(self) -> None:
"""
Reset the seeds that are going to be used at the next reset.
"""
self._seeds = [None for _ in range(self.num_envs)]
def _reset_options(self) -> None:
"""
Reset the options that are going to be used at the next reset.
"""
self._options = [{} for _ in range(self.num_envs)]
@abstractmethod
def reset(self) -> VecEnvObs:
"""
Reset all the environments and return an array of
observations, or a tuple of observation arrays.
If step_async is still doing work, that work will
be cancelled and step_wait() should not be called
until step_async() is invoked again.
:return: observation
"""
raise NotImplementedError()
@abstractmethod
def step_async(self, actions: np.ndarray) -> None:
"""
Tell all the environments to start taking a step
with the given actions.
Call step_wait() to get the results of the step.
You should not call this if a step_async run is
already pending.
"""
raise NotImplementedError()
@abstractmethod
def step_wait(self) -> VecEnvStepReturn:
"""
Wait for the step taken with step_async().
:return: observation, reward, done, information
"""
raise NotImplementedError()
@abstractmethod
def close(self) -> None:
"""
Clean up the environment's resources.
"""
raise NotImplementedError()
@abstractmethod
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
"""
Return attribute from vectorized environment.
:param attr_name: The name of the attribute whose value to return
:param indices: Indices of envs to get attribute from
:return: List of values of 'attr_name' in all environments
"""
raise NotImplementedError()
@abstractmethod
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
"""
Set attribute inside vectorized environments.
:param attr_name: The name of attribute to assign new value
:param value: Value to assign to `attr_name`
:param indices: Indices of envs to assign value
:return:
"""
raise NotImplementedError()
@abstractmethod
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
"""
Call instance methods of vectorized environments.
:param method_name: The name of the environment method to invoke.
:param indices: Indices of envs whose method to call
:param method_args: Any positional arguments to provide in the call
:param method_kwargs: Any keyword arguments to provide in the call
:return: List of items returned by the environment's method call
"""
raise NotImplementedError()
@abstractmethod
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
"""
Check if environments are wrapped with a given wrapper.
:param wrapper_class: The ``gym.Wrapper`` class to look for.
:param indices: Indices of envs to check
:return: True if the env is wrapped, False otherwise, for each env queried.
"""
raise NotImplementedError()
def step(self, actions: np.ndarray) -> VecEnvStepReturn:
"""
Step the environments with the given action
:param actions: the action
:return: observation, reward, done, information
"""
self.step_async(actions)
return self.step_wait()
def get_images(self) -> Sequence[Optional[np.ndarray]]:
"""
Return RGB images from each environment when available
"""
raise NotImplementedError
def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
"""
Gym environment rendering
:param mode: the rendering type
"""
if mode == "human" and self.render_mode != mode:
# Special case, if the render_mode="rgb_array"
# we can still display that image using opencv
if self.render_mode != "rgb_array":
warnings.warn(
f"You tried to render a VecEnv with mode='{mode}' "
"but the render mode defined when initializing the environment must be "
f"'human' or 'rgb_array', not '{self.render_mode}'."
)
return None
elif mode and self.render_mode != mode:
warnings.warn(
f"""Starting from gymnasium v0.26, render modes are determined during the initialization of the environment.
We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode})
has to be the same as the environment render mode ({self.render_mode}) which is not the case."""
)
return None
mode = mode or self.render_mode
if mode is None:
warnings.warn("You tried to call render() but no `render_mode` was passed to the env constructor.")
return None
# mode == self.render_mode == "human"
# In that case, we try to call `self.env.render()` but it might
# crash for subprocesses
if self.render_mode == "human":
self.env_method("render")
return None
if mode == "rgb_array" or mode == "human":
# call the render method of the environments
images = self.get_images()
# Create a big image by tiling images from subprocesses
bigimg = tile_images(images) # type: ignore[arg-type]
if mode == "human":
# Display it using OpenCV
import cv2
cv2.imshow("vecenv", bigimg[:, :, ::-1])
cv2.waitKey(1)
else:
return bigimg
else:
# Other render modes:
# In that case, we try to call `self.env.render()` but it might
# crash for subprocesses
# and we don't return the values
self.env_method("render")
return None
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
"""
Sets the random seeds for all environments, based on a given seed.
Each individual environment will still get its own seed, by incrementing the given seed.
WARNING: since gym 0.26, those seeds will only be passed to the environment
at the next reset.
:param seed: The random seed. May be None for completely random seeding.
:return: Returns a list containing the seeds for each individual env.
Note that all list elements may be None, if the env does not return anything when being seeded.
"""
if seed is None:
# To ensure that subprocesses have different seeds,
# we still populate the seed variable when no argument is passed
seed = int(np.random.randint(0, np.iinfo(np.uint32).max, dtype=np.uint32))
self._seeds = [seed + idx for idx in range(self.num_envs)]
return self._seeds
def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None:
"""
Set environment options for all environments.
If a dict is passed instead of a list, the same options will be used for all environments.
WARNING: Those options will only be passed to the environment at the next reset.
:param options: A dictionary of environment options to pass to each environment at the next reset.
"""
if options is None:
options = {}
# Use deepcopy to avoid side effects
if isinstance(options, dict):
self._options = deepcopy([options] * self.num_envs)
else:
self._options = deepcopy(options)
@property
def unwrapped(self) -> "VecEnv":
if isinstance(self, VecEnvWrapper):
return self.venv.unwrapped
else:
return self
def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]:
"""Check if an attribute reference is being hidden in a recursive call to __getattr__
:param name: name of attribute to check for
:param already_found: whether this attribute has already been found in a wrapper
:return: name of module whose attribute is being shadowed, if any.
"""
if hasattr(self, name) and already_found:
return f"{type(self).__module__}.{type(self).__name__}"
else:
return None
def _get_indices(self, indices: VecEnvIndices) -> Iterable[int]:
"""
Convert a flexibly-typed reference to environment indices to an implied list of indices.
:param indices: refers to indices of envs.
:return: the implied list of indices.
"""
if indices is None:
indices = range(self.num_envs)
elif isinstance(indices, int):
indices = [indices]
return indices
class VecEnvWrapper(VecEnv):
"""
Base class for vectorized environment wrappers
:param venv: the vectorized environment to wrap
:param observation_space: the observation space (can be None to load from venv)
:param action_space: the action space (can be None to load from venv)
"""
def __init__(
self,
venv: VecEnv,
observation_space: Optional[spaces.Space] = None,
action_space: Optional[spaces.Space] = None,
):
self.venv = venv
super().__init__(
num_envs=venv.num_envs,
observation_space=observation_space or venv.observation_space,
action_space=action_space or venv.action_space,
)
self.class_attributes = dict(inspect.getmembers(self.__class__))
def step_async(self, actions: np.ndarray) -> None:
self.venv.step_async(actions)
@abstractmethod
def reset(self) -> VecEnvObs:
pass
@abstractmethod
def step_wait(self) -> VecEnvStepReturn:
pass
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
return self.venv.seed(seed)
def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None:
return self.venv.set_options(options)
def close(self) -> None:
return self.venv.close()
def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
return self.venv.render(mode=mode)
def get_images(self) -> Sequence[Optional[np.ndarray]]:
return self.venv.get_images()
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
return self.venv.get_attr(attr_name, indices)
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
return self.venv.set_attr(attr_name, value, indices)
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
return self.venv.env_is_wrapped(wrapper_class, indices=indices)
def __getattr__(self, name: str) -> Any:
"""Find attribute from wrapped venv(s) if this wrapper does not have it.
Useful for accessing attributes from venvs which are wrapped with multiple wrappers
which have unique attributes of interest.
"""
blocked_class = self.getattr_depth_check(name, already_found=False)
if blocked_class is not None:
own_class = f"{type(self).__module__}.{type(self).__name__}"
error_str = (
f"Error: Recursive attribute lookup for {name} from {own_class} is "
f"ambiguous and hides attribute from {blocked_class}"
)
raise AttributeError(error_str)
return self.getattr_recursive(name)
def _get_all_attributes(self) -> Dict[str, Any]:
"""Get all (inherited) instance and class attributes
:return: all_attributes
"""
all_attributes = self.__dict__.copy()
all_attributes.update(self.class_attributes)
return all_attributes
def getattr_recursive(self, name: str) -> Any:
"""Recursively check wrappers to find attribute.
:param name: name of attribute to look for
:return: attribute
"""
all_attributes = self._get_all_attributes()
if name in all_attributes: # attribute is present in this wrapper
attr = getattr(self, name)
elif hasattr(self.venv, "getattr_recursive"):
# Attribute not present, child is wrapper. Call getattr_recursive rather than getattr
# to avoid a duplicate call to getattr_depth_check.
attr = self.venv.getattr_recursive(name)
else: # attribute not present, child is an unwrapped VecEnv
attr = getattr(self.venv, name)
return attr
def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]:
"""See base class.
:return: name of module whose attribute is being shadowed, if any.
"""
all_attributes = self._get_all_attributes()
if name in all_attributes and already_found:
# this venv's attribute is being hidden because of a higher venv.
shadowed_wrapper_class: Optional[str] = f"{type(self).__module__}.{type(self).__name__}"
elif name in all_attributes and not already_found:
# we have found the first reference to the attribute. Now check for duplicates.
shadowed_wrapper_class = self.venv.getattr_depth_check(name, True)
else:
# this wrapper does not have the attribute. Keep searching.
shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found)
return shadowed_wrapper_class
class CloudpickleWrapper:
"""
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
:param var: the variable you wish to wrap for pickling with cloudpickle
"""
def __init__(self, var: Any):
self.var = var
def __getstate__(self) -> Any:
return cloudpickle.dumps(self.var)
def __setstate__(self, var: Any) -> None:
self.var = cloudpickle.loads(var)

View File

@ -0,0 +1,141 @@
import warnings
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Sequence, Type
import gymnasium as gym
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
from stable_baselines3.common.vec_env.patch_gym import _patch_env
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
class DummyVecEnv(VecEnv):
"""
Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
Python process. This is useful for computationally simple environments such as ``CartPole-v1``,
as the overhead of multiprocessing or multithreading outweighs the environment computation time.
This can also be used for RL methods that
require a vectorized environment, but with which you want to train using a single environment.
:param env_fns: a list of functions
that return environments to vectorize
:raises ValueError: If the same environment instance is passed as the output of two or more different env_fn.
"""
actions: np.ndarray
def __init__(self, env_fns: List[Callable[[], gym.Env]]):
self.envs = [_patch_env(fn()) for fn in env_fns]
if len(set([id(env.unwrapped) for env in self.envs])) != len(self.envs):
raise ValueError(
"You tried to create multiple environments, but the function to create them returned the same instance "
"instead of creating different objects. "
"You are probably using `make_vec_env(lambda: env)` or `DummyVecEnv([lambda: env] * n_envs)`. "
"You should replace `lambda: env` by a `make_env` function that "
"creates a new instance of the environment at every call "
"(using `gym.make()` for instance). You can take a look at the documentation for an example. "
"Please read https://github.com/DLR-RM/stable-baselines3/issues/1151 for more information."
)
env = self.envs[0]
super().__init__(len(env_fns), env.observation_space, env.action_space)
obs_space = env.observation_space
self.keys, shapes, dtypes = obs_space_info(obs_space)
self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys])
self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
self.buf_infos: List[Dict[str, Any]] = [{} for _ in range(self.num_envs)]
self.metadata = env.metadata
def step_async(self, actions: np.ndarray) -> None:
self.actions = actions
def step_wait(self) -> VecEnvStepReturn:
# Avoid circular imports
for env_idx in range(self.num_envs):
obs, self.buf_rews[env_idx], terminated, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step(
self.actions[env_idx]
)
# convert to SB3 VecEnv api
self.buf_dones[env_idx] = terminated or truncated
# See https://github.com/openai/gym/issues/3102
# Gym 0.26 introduces a breaking change
self.buf_infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated
if self.buf_dones[env_idx]:
# save final observation where user can get it, then reset
self.buf_infos[env_idx]["terminal_observation"] = obs
obs, self.reset_infos[env_idx] = self.envs[env_idx].reset()
self._save_obs(env_idx, obs)
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
def reset(self) -> VecEnvObs:
for env_idx in range(self.num_envs):
maybe_options = {"options": self._options[env_idx]} if self._options[env_idx] else {}
obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options)
self._save_obs(env_idx, obs)
# Seeds and options are only used once
self._reset_seeds()
self._reset_options()
return self._obs_from_buf()
def close(self) -> None:
for env in self.envs:
env.close()
def get_images(self) -> Sequence[Optional[np.ndarray]]:
if self.render_mode != "rgb_array":
warnings.warn(
f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images."
)
return [None for _ in self.envs]
return [env.render() for env in self.envs] # type: ignore[misc]
def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
"""
Gym environment rendering. If there are multiple environments then
they are tiled together in one image via ``BaseVecEnv.render()``.
:param mode: The rendering type.
"""
return super().render(mode=mode)
def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
for key in self.keys:
if key is None:
self.buf_obs[key][env_idx] = obs
else:
self.buf_obs[key][env_idx] = obs[key] # type: ignore[call-overload]
def _obs_from_buf(self) -> VecEnvObs:
return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
"""Return attribute from vectorized environment (see base class)."""
target_envs = self._get_target_envs(indices)
return [getattr(env_i, attr_name) for env_i in target_envs]
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
"""Set attribute inside vectorized environments (see base class)."""
target_envs = self._get_target_envs(indices)
for env_i in target_envs:
setattr(env_i, attr_name, value)
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
"""Call instance methods of vectorized environments."""
target_envs = self._get_target_envs(indices)
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
"""Check if worker environments are wrapped with a given wrapper"""
target_envs = self._get_target_envs(indices)
# Import here to avoid a circular import
from stable_baselines3.common import env_util
return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]
def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
indices = self._get_indices(indices)
return [self.envs[i] for i in indices]
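# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Four independent CartPole instances running in the current process; note the
# factory functions, so each call builds a *new* environment instance.
if __name__ == "__main__":
    vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
    observations = vec_env.reset()
    actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)])
    observations, rewards, dones, infos = vec_env.step(actions)
    vec_env.close()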

View File

@ -0,0 +1,100 @@
import warnings
from inspect import signature
from typing import Union
import gymnasium
try:
import gym
gym_installed = True
except ImportError:
gym_installed = False
def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma: no cover
"""
Adapted from https://github.com/thu-ml/tianshou.
Takes an environment and patches it to return Gymnasium env.
This function takes the environment object and returns a patched
env, using shimmy wrapper to convert it to Gymnasium,
if necessary.
:param env: A gym/gymnasium env
:return: Patched env (gymnasium env)
"""
# Gymnasium env, no patching to be done
if isinstance(env, gymnasium.Env):
return env
if not gym_installed or not isinstance(env, gym.Env):
raise ValueError(
f"The environment is of type {type(env)}, not a Gymnasium "
f"environment. In this case, we expect OpenAI Gym to be "
f"installed and the environment to be an OpenAI Gym environment."
)
try:
import shimmy
except ImportError as e:
raise ImportError(
"Missing shimmy installation. You provided an OpenAI Gym environment. "
"Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. "
"In order to use OpenAI Gym environments with SB3, you need to "
"install shimmy (`pip install 'shimmy>=0.2.1'`)."
) from e
warnings.warn(
"You provided an OpenAI Gym environment. "
"We strongly recommend transitioning to Gymnasium environments. "
"Stable-Baselines3 is automatically wrapping your environments in a compatibility "
"layer, which could potentially cause issues."
)
if "seed" in signature(env.unwrapped.reset).parameters:
# Gym 0.26+ env
return shimmy.GymV26CompatibilityV0(env=env)
# Gym 0.21 env
return shimmy.GymV21CompatibilityV0(env=env)
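# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Assuming the legacy ``gym`` package and ``shimmy`` are installed, an OpenAI
# Gym environment can be converted into a Gymnasium one before vectorization.
def _example_patch_env() -> gymnasium.Env:
    import gym as legacy_gym

    legacy_env = legacy_gym.make("CartPole-v1")
    # Returns a shimmy compatibility wrapper exposing the Gymnasium API.
    return _patch_env(legacy_env)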
def _convert_space(space: Union["gym.Space", gymnasium.Space]) -> gymnasium.Space: # pragma: no cover
"""
Takes a space and patches it to return Gymnasium Space.
This function takes the space object and returns a patched
space, using shimmy wrapper to convert it to Gymnasium,
if necessary.
:param space: A gym/gymnasium Space
:return: Patched space (gymnasium Space)
"""
# Gymnasium space, no conversion to be done
if isinstance(space, gymnasium.Space):
return space
if not gym_installed or not isinstance(space, gym.Space):
raise ValueError(
f"The space is of type {type(space)}, not a Gymnasium "
f"space. In this case, we expect OpenAI Gym to be "
f"installed and the space to be an OpenAI Gym space."
)
try:
import shimmy
except ImportError as e:
raise ImportError(
"Missing shimmy installation. You provided an OpenAI Gym space. "
"Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. "
"In order to use OpenAI Gym space with SB3, you need to "
"install shimmy (`pip install 'shimmy>=0.2.1'`)."
) from e
warnings.warn(
"You loaded a model that was trained using OpenAI Gym. "
"We strongly recommend transitioning to Gymnasium by saving that model again."
)
return shimmy.openai_gym_compatibility._convert_space(space)

View File

@ -0,0 +1,176 @@
import warnings
from typing import Any, Dict, Generic, List, Mapping, Optional, Tuple, TypeVar, Union
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
TObs = TypeVar("TObs", np.ndarray, Dict[str, np.ndarray])
class StackedObservations(Generic[TObs]):
"""
Frame stacking wrapper for data.
Dimension to stack over is either first (channels-first) or last (channels-last), which is detected automatically using
``common.preprocessing.is_image_space_channels_first`` if observation is an image space.
:param num_envs: Number of environments
:param n_stack: Number of frames to stack
:param observation_space: Environment observation space
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
If None, automatically detect channel to stack over in case of image observation or default to "last".
For Dict space, channels_order can also be a dictionary.
"""
def __init__(
self,
num_envs: int,
n_stack: int,
observation_space: Union[spaces.Box, spaces.Dict],
channels_order: Optional[Union[str, Mapping[str, Optional[str]]]] = None,
) -> None:
self.n_stack = n_stack
self.observation_space = observation_space
if isinstance(observation_space, spaces.Dict):
if not isinstance(channels_order, Mapping):
channels_order = {key: channels_order for key in observation_space.spaces.keys()}
self.sub_stacked_observations = {
key: StackedObservations(num_envs, n_stack, subspace, channels_order[key]) # type: ignore[arg-type]
for key, subspace in observation_space.spaces.items()
}
self.stacked_observation_space = spaces.Dict(
{key: substack_obs.stacked_observation_space for key, substack_obs in self.sub_stacked_observations.items()}
) # type: Union[spaces.Dict, spaces.Box] # make mypy happy
elif isinstance(observation_space, spaces.Box):
if isinstance(channels_order, Mapping):
raise TypeError("When the observation space is Box, channels_order can't be a dict.")
self.channels_first, self.stack_dimension, self.stacked_shape, self.repeat_axis = self.compute_stacking(
n_stack, observation_space, channels_order
)
low = np.repeat(observation_space.low, n_stack, axis=self.repeat_axis)
high = np.repeat(observation_space.high, n_stack, axis=self.repeat_axis)
self.stacked_observation_space = spaces.Box(
low=low,
high=high,
dtype=observation_space.dtype, # type: ignore[arg-type]
)
self.stacked_obs = np.zeros((num_envs, *self.stacked_shape), dtype=observation_space.dtype)
else:
raise TypeError(
f"StackedObservations only supports Box and Dict as observation spaces. {observation_space} was provided."
)
@staticmethod
def compute_stacking(
n_stack: int, observation_space: spaces.Box, channels_order: Optional[str] = None
) -> Tuple[bool, int, Tuple[int, ...], int]:
"""
Calculates the parameters in order to stack observations
:param n_stack: Number of observations to stack
:param observation_space: Observation space
:param channels_order: Order of the channels
:return: Tuple of channels_first, stack_dimension, stacked_shape, repeat_axis
"""
if channels_order is None:
# Detect channel location automatically for images
if is_image_space(observation_space):
channels_first = is_image_space_channels_first(observation_space)
else:
# Default behavior for non-image space, stack on the last axis
channels_first = False
else:
assert channels_order in {
"last",
"first",
}, "`channels_order` must be one of following: 'last', 'first'"
channels_first = channels_order == "first"
# This includes the vec-env dimension (first)
stack_dimension = 1 if channels_first else -1
repeat_axis = 0 if channels_first else -1
stacked_shape = list(observation_space.shape)
stacked_shape[repeat_axis] *= n_stack
return channels_first, stack_dimension, tuple(stacked_shape), repeat_axis
def reset(self, observation: TObs) -> TObs:
"""
Reset the stacked_obs, add the reset observation to the stack, and return the stack.
:param observation: Reset observation
:return: The stacked reset observation
"""
if isinstance(observation, dict):
return {key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()}
self.stacked_obs[...] = 0
if self.channels_first:
self.stacked_obs[:, -observation.shape[self.stack_dimension] :, ...] = observation
else:
self.stacked_obs[..., -observation.shape[self.stack_dimension] :] = observation
return self.stacked_obs
def update(
self,
observations: TObs,
dones: np.ndarray,
infos: List[Dict[str, Any]],
) -> Tuple[TObs, List[Dict[str, Any]]]:
"""
Add the observations to the stack and use the dones to update the infos.
:param observations: Observations
:param dones: Dones
:param infos: Infos
:return: Tuple of the stacked observations and the updated infos
"""
if isinstance(observations, dict):
# From [{}, {terminal_obs: {key1: ..., key2: ...}}]
# to {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
sub_infos = {
key: [
{"terminal_observation": info["terminal_observation"][key]} if "terminal_observation" in info else {}
for info in infos
]
for key in observations.keys()
}
stacked_obs = {}
stacked_infos = {}
for key, obs in observations.items():
stacked_obs[key], stacked_infos[key] = self.sub_stacked_observations[key].update(obs, dones, sub_infos[key])
# From {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
# to [{}, {terminal_obs: {key1: ..., key2: ...}}]
for key in stacked_infos.keys():
for env_idx in range(len(infos)):
if "terminal_observation" in infos[env_idx]:
infos[env_idx]["terminal_observation"][key] = stacked_infos[key][env_idx]["terminal_observation"]
return stacked_obs, infos
shift = -observations.shape[self.stack_dimension]
self.stacked_obs = np.roll(self.stacked_obs, shift, axis=self.stack_dimension)
for env_idx, done in enumerate(dones):
if done:
if "terminal_observation" in infos[env_idx]:
old_terminal = infos[env_idx]["terminal_observation"]
if self.channels_first:
previous_stack = self.stacked_obs[env_idx, :shift, ...]
else:
previous_stack = self.stacked_obs[env_idx, ..., :shift]
new_terminal = np.concatenate((previous_stack, old_terminal), axis=self.repeat_axis)
infos[env_idx]["terminal_observation"] = new_terminal
else:
warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
self.stacked_obs[env_idx] = 0
if self.channels_first:
self.stacked_obs[:, shift:, ...] = observations
else:
self.stacked_obs[..., shift:] = observations
return self.stacked_obs, infos
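# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Stacking 4 channel-first image frames for 2 parallel environments: the
# channel dimension is detected automatically and grows from 3 to 12.
def _example_stacked_observations() -> None:
    image_space = spaces.Box(low=0, high=255, shape=(3, 84, 84), dtype=np.uint8)
    stacked = StackedObservations(num_envs=2, n_stack=4, observation_space=image_space)
    assert stacked.stacked_observation_space.shape == (12, 84, 84)
    assert stacked.stacked_obs.shape == (2, 12, 84, 84)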

View File

@ -0,0 +1,232 @@
import multiprocessing as mp
import warnings
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env.base_vec_env import (
CloudpickleWrapper,
VecEnv,
VecEnvIndices,
VecEnvObs,
VecEnvStepReturn,
)
from stable_baselines3.common.vec_env.patch_gym import _patch_env
def _worker(
remote: mp.connection.Connection,
parent_remote: mp.connection.Connection,
env_fn_wrapper: CloudpickleWrapper,
) -> None:
# Import here to avoid a circular import
from stable_baselines3.common.env_util import is_wrapped
parent_remote.close()
env = _patch_env(env_fn_wrapper.var())
reset_info: Optional[Dict[str, Any]] = {}
while True:
try:
cmd, data = remote.recv()
if cmd == "step":
observation, reward, terminated, truncated, info = env.step(data)
# convert to SB3 VecEnv api
done = terminated or truncated
info["TimeLimit.truncated"] = truncated and not terminated
if done:
# save final observation where user can get it, then reset
info["terminal_observation"] = observation
observation, reset_info = env.reset()
remote.send((observation, reward, done, info, reset_info))
elif cmd == "reset":
maybe_options = {"options": data[1]} if data[1] else {}
observation, reset_info = env.reset(seed=data[0], **maybe_options)
remote.send((observation, reset_info))
elif cmd == "render":
remote.send(env.render())
elif cmd == "close":
env.close()
remote.close()
break
elif cmd == "get_spaces":
remote.send((env.observation_space, env.action_space))
elif cmd == "env_method":
method = getattr(env, data[0])
remote.send(method(*data[1], **data[2]))
elif cmd == "get_attr":
remote.send(getattr(env, data))
elif cmd == "set_attr":
remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value]
elif cmd == "is_wrapped":
remote.send(is_wrapped(env, data))
else:
raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
except EOFError:
break
class SubprocVecEnv(VecEnv):
"""
Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
process, allowing significant speed up when the environment is computationally complex.
For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
number of logical cores on your CPU.
.. warning::
Only 'forkserver' and 'spawn' start methods are thread-safe,
which is important when TensorFlow sessions or other non thread-safe
libraries are used in the parent (see issue #217). However, compared to
'fork' they incur a small start-up cost and have restrictions on
global variables. With those methods, users must wrap the code in an
``if __name__ == "__main__":`` block.
For more information, see the multiprocessing documentation.
:param env_fns: Environments to run in subprocesses
:param start_method: method used to start the subprocesses.
Must be one of the methods returned by multiprocessing.get_all_start_methods().
Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
"""
def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[str] = None):
self.waiting = False
self.closed = False
n_envs = len(env_fns)
if start_method is None:
# Fork is not a thread safe method (see issue #217)
# but is more user friendly (does not require to wrap the code in
# a `if __name__ == "__main__":`)
forkserver_available = "forkserver" in mp.get_all_start_methods()
start_method = "forkserver" if forkserver_available else "spawn"
ctx = mp.get_context(start_method)
self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)])
self.processes = []
for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns):
args = (work_remote, remote, CloudpickleWrapper(env_fn))
# daemon=True: if the main process crashes, we should not cause things to hang
process = ctx.Process(target=_worker, args=args, daemon=True) # type: ignore[attr-defined]
process.start()
self.processes.append(process)
work_remote.close()
self.remotes[0].send(("get_spaces", None))
observation_space, action_space = self.remotes[0].recv()
super().__init__(len(env_fns), observation_space, action_space)
def step_async(self, actions: np.ndarray) -> None:
for remote, action in zip(self.remotes, actions):
remote.send(("step", action))
self.waiting = True
def step_wait(self) -> VecEnvStepReturn:
results = [remote.recv() for remote in self.remotes]
self.waiting = False
obs, rews, dones, infos, self.reset_infos = zip(*results) # type: ignore[assignment]
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]
def reset(self) -> VecEnvObs:
for env_idx, remote in enumerate(self.remotes):
remote.send(("reset", (self._seeds[env_idx], self._options[env_idx])))
results = [remote.recv() for remote in self.remotes]
obs, self.reset_infos = zip(*results) # type: ignore[assignment]
# Seeds and options are only used once
self._reset_seeds()
self._reset_options()
return _flatten_obs(obs, self.observation_space)
def close(self) -> None:
if self.closed:
return
if self.waiting:
for remote in self.remotes:
remote.recv()
for remote in self.remotes:
remote.send(("close", None))
for process in self.processes:
process.join()
self.closed = True
def get_images(self) -> Sequence[Optional[np.ndarray]]:
if self.render_mode != "rgb_array":
warnings.warn(
f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images."
)
return [None for _ in self.remotes]
for pipe in self.remotes:
# gather render return from subprocesses
pipe.send(("render", None))
outputs = [pipe.recv() for pipe in self.remotes]
return outputs
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
"""Return attribute from vectorized environment (see base class)."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(("get_attr", attr_name))
return [remote.recv() for remote in target_remotes]
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
"""Set attribute inside vectorized environments (see base class)."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(("set_attr", (attr_name, value)))
for remote in target_remotes:
remote.recv()
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
"""Call instance methods of vectorized environments."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(("env_method", (method_name, method_args, method_kwargs)))
return [remote.recv() for remote in target_remotes]
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
"""Check if worker environments are wrapped with a given wrapper"""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(("is_wrapped", wrapper_class))
return [remote.recv() for remote in target_remotes]
def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]:
"""
Get the connection object needed to communicate with the wanted
envs that are in subprocesses.
:param indices: refers to indices of envs.
:return: Connection object to communicate between processes.
"""
indices = self._get_indices(indices)
return [self.remotes[i] for i in indices]
def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
"""
Flatten observations, depending on the observation space.
:param obs: observations.
A list or tuple of observations, one per environment.
Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays.
:return: flattened observations.
A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays.
Each NumPy array has the environment index as its first axis.
"""
assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment"
assert len(obs) > 0, "need observations from at least one environment"
if isinstance(space, spaces.Dict):
assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces"
assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space"
return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])
elif isinstance(space, spaces.Tuple):
assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
obs_len = len(space.spaces)
return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) # type: ignore[index]
else:
return np.stack(obs) # type: ignore[arg-type]
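# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Subprocess-based vectorization: with the 'spawn' or 'forkserver' start
# methods, the creation code must live under an ``if __name__ == "__main__":``
# guard, as warned in the class docstring above.
if __name__ == "__main__":
    def make_env() -> gym.Env:
        # A fresh environment instance is created in each worker process.
        return gym.make("CartPole-v1")

    vec_env = SubprocVecEnv([make_env for _ in range(4)], start_method="spawn")
    observations = vec_env.reset()
    actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)])
    observations, rewards, dones, infos = vec_env.step(actions)
    vec_env.close()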

View File

@ -0,0 +1,77 @@
"""
Helpers for dealing with vectorized environments.
"""
from collections import OrderedDict
from typing import Any, Dict, List, Tuple
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.preprocessing import check_for_nested_spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
"""
Deep-copy a dict of numpy arrays.
:param obs: a dict of numpy arrays.
:return: a dict of copied numpy arrays.
"""
assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'"
return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])
def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
"""
Convert an internal representation raw_obs into the appropriate type
specified by space.
:param obs_space: an observation space.
:param obs_dict: a dict of numpy arrays.
:return: returns an observation of the same type as space.
If space is Dict, function is identity; if space is Tuple, converts dict to Tuple;
otherwise, space is unstructured and returns the value raw_obs[None].
"""
if isinstance(obs_space, spaces.Dict):
return obs_dict
elif isinstance(obs_space, spaces.Tuple):
assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space"
return tuple(obs_dict[i] for i in range(len(obs_space.spaces)))
else:
assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
return obs_dict[None]
def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]:
"""
Get dict-structured information about a gym.Space.
Dict spaces are represented directly by their dict of subspaces.
Tuple spaces are converted into a dict with keys indexing into the tuple.
Unstructured spaces are represented by {None: obs_space}.
:param obs_space: an observation space
:return: A tuple (keys, shapes, dtypes):
keys: a list of dict keys.
shapes: a dict mapping keys to shapes.
dtypes: a dict mapping keys to dtypes.
"""
check_for_nested_spaces(obs_space)
if isinstance(obs_space, spaces.Dict):
assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
subspaces = obs_space.spaces
elif isinstance(obs_space, spaces.Tuple):
subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment]
else:
assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
subspaces = {None: obs_space} # type: ignore[assignment]
keys = []
shapes = {}
dtypes = {}
for key, box in subspaces.items():
keys.append(key)
shapes[key] = box.shape
dtypes[key] = box.dtype
return keys, shapes, dtypes
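# --- Illustrative usage (editor's sketch, not part of the original module) ---
# A Dict observation space is described by its own keys, while an unstructured
# Box space is keyed by ``None``.
def _example_obs_space_info() -> None:
    dict_space = spaces.Dict({"pos": spaces.Box(-1.0, 1.0, (3,)), "vel": spaces.Box(-1.0, 1.0, (3,))})
    keys, shapes, dtypes = obs_space_info(dict_space)
    assert keys == ["pos", "vel"] and shapes["pos"] == (3,)

    box_space = spaces.Box(-1.0, 1.0, (4,))
    keys, shapes, dtypes = obs_space_info(box_space)
    assert keys == [None] and shapes[None] == (4,)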

View File

@ -0,0 +1,108 @@
import warnings
from typing import List, Tuple
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
class VecCheckNan(VecEnvWrapper):
"""
NaN and inf checking wrapper for vectorized environments; it will raise a warning by default,
letting you know where the NaN or inf originated.
:param venv: the vectorized environment to wrap
:param raise_exception: Whether to raise a ValueError, instead of a UserWarning
:param warn_once: Whether to only warn once.
:param check_inf: Whether to check for +inf or -inf as well
"""
def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True) -> None:
super().__init__(venv)
self.raise_exception = raise_exception
self.warn_once = warn_once
self.check_inf = check_inf
self._user_warned = False
self._actions: np.ndarray
self._observations: VecEnvObs
if isinstance(venv.action_space, spaces.Dict):
raise NotImplementedError("VecCheckNan doesn't support dict action spaces")
def step_async(self, actions: np.ndarray) -> None:
self._check_val(event="step_async", actions=actions)
self._actions = actions
self.venv.step_async(actions)
def step_wait(self) -> VecEnvStepReturn:
observations, rewards, dones, infos = self.venv.step_wait()
self._check_val(event="step_wait", observations=observations, rewards=rewards, dones=dones)
self._observations = observations
return observations, rewards, dones, infos
def reset(self) -> VecEnvObs:
observations = self.venv.reset()
self._check_val(event="reset", observations=observations)
self._observations = observations
return observations
def check_array_value(self, name: str, value: np.ndarray) -> List[Tuple[str, str]]:
"""
Check for inf and NaN for a single numpy array.
:param name: Name of the value being checked
:param value: Value (numpy array) to check
:return: A list of issues found.
"""
found = []
has_nan = np.any(np.isnan(value))
has_inf = self.check_inf and np.any(np.isinf(value))
if has_inf:
found.append((name, "inf"))
if has_nan:
found.append((name, "nan"))
return found
def _check_val(self, event: str, **kwargs) -> None:
# if warn and warn once and have warned once: then stop checking
if not self.raise_exception and self.warn_once and self._user_warned:
return
found = []
for name, value in kwargs.items():
if isinstance(value, (np.ndarray, list)):
found += self.check_array_value(name, np.asarray(value))
elif isinstance(value, dict):
for inner_name, inner_val in value.items():
found += self.check_array_value(f"{name}.{inner_name}", inner_val)
elif isinstance(value, tuple):
for idx, inner_val in enumerate(value):
found += self.check_array_value(f"{name}.{idx}", inner_val)
else:
raise TypeError(f"Unsupported observation type {type(value)}.")
if found:
self._user_warned = True
msg = ""
for i, (name, type_val) in enumerate(found):
msg += f"found {type_val} in {name}"
if i != len(found) - 1:
msg += ", "
msg += ".\r\nOriginated from the "
if event == "reset":
msg += "environment observation (at reset)"
elif event == "step_wait":
msg += f"environment, Last given value was: \r\n\taction={self._actions}"
elif event == "step_async":
msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}"
else:
raise ValueError("Internal error.")
if self.raise_exception:
raise ValueError(msg)
else:
warnings.warn(msg, UserWarning)
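# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Turning silent NaN/inf propagation into a hard error while debugging a
# training run.
def _example_vec_check_nan() -> VecCheckNan:
    import gymnasium as gym

    from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv

    venv = DummyVecEnv([lambda: gym.make("Pendulum-v1")])
    # Raise a ValueError instead of a warning as soon as a NaN or inf appears.
    return VecCheckNan(venv, raise_exception=True)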

View File

@ -0,0 +1,33 @@
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
class VecExtractDictObs(VecEnvWrapper):
"""
A vectorized wrapper for extracting dictionary observations.
:param venv: The vectorized environment
:param key: The key of the dictionary observation
"""
def __init__(self, venv: VecEnv, key: str):
self.key = key
assert isinstance(
venv.observation_space, spaces.Dict
), f"VecExtractDictObs can only be used with Dict obs space, not {venv.observation_space}"
super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key])
def reset(self) -> np.ndarray:
obs = self.venv.reset()
assert isinstance(obs, dict)
return obs[self.key]
def step_wait(self) -> VecEnvStepReturn:
obs, reward, done, infos = self.venv.step_wait()
assert isinstance(obs, dict)
for info in infos:
if "terminal_observation" in info:
info["terminal_observation"] = info["terminal_observation"][self.key]
return obs[self.key], reward, done, infos
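# --- Illustrative usage (editor's sketch, not part of the original module) ---
# The typical use case is a VecEnv with Dict observations (e.g. procgen, which
# exposes the image under the "rgb" key); a toy Dict-observation env is used
# here so the sketch stays self-contained.
def _example_vec_extract_dict_obs() -> VecExtractDictObs:
    import gymnasium as gym

    from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv

    class ToyDictObsEnv(gym.Env):
        observation_space = spaces.Dict({"rgb": spaces.Box(0, 255, (8, 8, 3), dtype=np.uint8)})
        action_space = spaces.Discrete(2)

        def reset(self, *, seed=None, options=None):
            super().reset(seed=seed)
            return self.observation_space.sample(), {}

        def step(self, action):
            return self.observation_space.sample(), 0.0, False, False, {}

    venv = DummyVecEnv([ToyDictObsEnv])
    # After wrapping, reset()/step() return only the "rgb" sub-observation.
    return VecExtractDictObs(venv, key="rgb")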

View File

@ -0,0 +1,48 @@
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
class VecFrameStack(VecEnvWrapper):
"""
Frame stacking wrapper for vectorized environment. Designed for image observations.
:param venv: Vectorized environment to wrap
:param n_stack: Number of frames to stack
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces
"""
def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Mapping[str, str]]] = None) -> None:
assert isinstance(
venv.observation_space, (spaces.Box, spaces.Dict)
), "VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces"
self.stacked_obs = StackedObservations(venv.num_envs, n_stack, venv.observation_space, channels_order)
observation_space = self.stacked_obs.stacked_observation_space
super().__init__(venv, observation_space=observation_space)
def step_wait(
self,
) -> Tuple[
Union[np.ndarray, Dict[str, np.ndarray]],
np.ndarray,
np.ndarray,
List[Dict[str, Any]],
]:
observations, rewards, dones, infos = self.venv.step_wait()
observations, infos = self.stacked_obs.update(observations, dones, infos) # type: ignore[arg-type]
return observations, rewards, dones, infos
def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""
Reset all environments
"""
observation = self.venv.reset()
observation = self.stacked_obs.reset(observation) # type: ignore[arg-type]
return observation
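# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Although designed for images, the wrapper also works for flat Box spaces:
# stacking the last 4 CartPole observations turns the 4 features into 16,
# stacked on the last dimension.
def _example_vec_frame_stack() -> VecFrameStack:
    import gymnasium as gym

    from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv

    venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    stacked_venv = VecFrameStack(venv, n_stack=4)
    assert stacked_venv.observation_space.shape == (16,)
    return stacked_venv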

View File

@ -0,0 +1,100 @@
import time
import warnings
from typing import Optional, Tuple
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
class VecMonitor(VecEnvWrapper):
"""
A vectorized monitor wrapper for *vectorized* Gym environments;
it is used to record the episode reward, length, time and other data.
Some environments like `openai/procgen <https://github.com/openai/procgen>`_
or `gym3 <https://github.com/openai/gym3>`_ directly initialize the
vectorized environments, without giving us a chance to use the ``Monitor``
wrapper. So this class simply does the job of the ``Monitor`` wrapper on
a vectorized level.
:param venv: The vectorized environment
:param filename: the location to save a log file, can be None for no log
:param info_keywords: extra information to log, from the information return of env.step()
"""
def __init__(
self,
venv: VecEnv,
filename: Optional[str] = None,
info_keywords: Tuple[str, ...] = (),
):
# Avoid circular import
from stable_baselines3.common.monitor import Monitor, ResultsWriter
# This check is not valid for special `VecEnv`
# like the ones created by Procgen, which do not completely follow
# the `VecEnv` interface
try:
is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0]
except AttributeError:
is_wrapped_with_monitor = False
if is_wrapped_with_monitor:
warnings.warn(
"The environment is already wrapped with a `Monitor` wrapper"
"but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be"
"overwritten by the `VecMonitor` ones.",
UserWarning,
)
VecEnvWrapper.__init__(self, venv)
self.episode_count = 0
self.t_start = time.time()
env_id = None
if hasattr(venv, "spec") and venv.spec is not None:
env_id = venv.spec.id
self.results_writer: Optional[ResultsWriter] = None
if filename:
self.results_writer = ResultsWriter(
filename, header={"t_start": self.t_start, "env_id": str(env_id)}, extra_keys=info_keywords
)
self.info_keywords = info_keywords
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
def reset(self) -> VecEnvObs:
obs = self.venv.reset()
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return obs
def step_wait(self) -> VecEnvStepReturn:
obs, rewards, dones, infos = self.venv.step_wait()
self.episode_returns += rewards
self.episode_lengths += 1
new_infos = list(infos[:])
for i in range(len(dones)):
if dones[i]:
info = infos[i].copy()
episode_return = self.episode_returns[i]
episode_length = self.episode_lengths[i]
episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)}
for key in self.info_keywords:
episode_info[key] = info[key]
info["episode"] = episode_info
self.episode_count += 1
self.episode_returns[i] = 0
self.episode_lengths[i] = 0
if self.results_writer:
self.results_writer.write_row(episode_info)
new_infos[i] = info
return obs, rewards, dones, new_infos
def close(self) -> None:
if self.results_writer:
self.results_writer.close()
return self.venv.close()
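

if __name__ == "__main__":
    # Usage sketch (assumes gymnasium and the "CartPole-v1" env id are available):
    # wrap a DummyVecEnv with VecMonitor so that episode statistics show up in
    # ``info["episode"]``; pass ``filename`` to also write them to a CSV log.
    import gymnasium as gym

    from stable_baselines3.common.vec_env import DummyVecEnv

    venv = VecMonitor(DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)]))
    obs = venv.reset()
    for _ in range(200):
        actions = np.array([venv.action_space.sample() for _ in range(venv.num_envs)])
        obs, rewards, dones, infos = venv.step(actions)
        for info in infos:
            if "episode" in info:
                print("episode return:", info["episode"]["r"], "length:", info["episode"]["l"])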

View File

@ -0,0 +1,330 @@
import inspect
import pickle
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
import numpy as np
from gymnasium import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
class VecNormalize(VecEnvWrapper):
"""
    A moving average, normalizing wrapper for vectorized environments;
    it has support for saving/loading the moving averages.
:param venv: the vectorized environment to wrap
    :param training: Whether or not to update the moving average
:param norm_obs: Whether to normalize observation or not (default: True)
:param norm_reward: Whether to normalize rewards or not (default: True)
:param clip_obs: Max absolute value for observation
    :param clip_reward: Max absolute value for discounted reward
:param gamma: discount factor
:param epsilon: To avoid division by zero
:param norm_obs_keys: Which keys from observation dict to normalize.
If not specified, all keys will be normalized.
"""
obs_spaces: Dict[str, spaces.Space]
old_obs: Union[np.ndarray, Dict[str, np.ndarray]]
def __init__(
self,
venv: VecEnv,
training: bool = True,
norm_obs: bool = True,
norm_reward: bool = True,
clip_obs: float = 10.0,
clip_reward: float = 10.0,
gamma: float = 0.99,
epsilon: float = 1e-8,
norm_obs_keys: Optional[List[str]] = None,
):
VecEnvWrapper.__init__(self, venv)
self.norm_obs = norm_obs
self.norm_obs_keys = norm_obs_keys
# Check observation spaces
if self.norm_obs:
# Note: mypy doesn't take into account the sanity checks, which lead to several type: ignore...
self._sanity_checks()
if isinstance(self.observation_space, spaces.Dict):
self.obs_spaces = self.observation_space.spaces
self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys} # type: ignore[arg-type, union-attr]
# Update observation space when using image
# See explanation below and GH #1214
for key in self.obs_rms.keys():
if is_image_space(self.obs_spaces[key]):
self.observation_space.spaces[key] = spaces.Box(
low=-clip_obs,
high=clip_obs,
shape=self.obs_spaces[key].shape,
dtype=np.float32,
)
else:
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) # type: ignore[assignment, arg-type]
# Update observation space when using image
# See GH #1214
# This is to raise proper error when
# VecNormalize is used with an image-like input and
# normalize_images=True.
# For correctness, we should also update the bounds
# in other cases but this will cause backward-incompatible change
# and break already saved policies.
if is_image_space(self.observation_space):
self.observation_space = spaces.Box(
low=-clip_obs,
high=clip_obs,
shape=self.observation_space.shape,
dtype=np.float32,
)
self.ret_rms = RunningMeanStd(shape=())
self.clip_obs = clip_obs
self.clip_reward = clip_reward
# Returns: discounted rewards
self.returns = np.zeros(self.num_envs)
self.gamma = gamma
self.epsilon = epsilon
self.training = training
self.norm_obs = norm_obs
self.norm_reward = norm_reward
self.old_reward = np.array([])
def _sanity_checks(self) -> None:
"""
Check the observations that are going to be normalized are of the correct type (spaces.Box).
"""
if isinstance(self.observation_space, spaces.Dict):
# By default, we normalize all keys
if self.norm_obs_keys is None:
self.norm_obs_keys = list(self.observation_space.spaces.keys())
# Check that all keys are of type Box
for obs_key in self.norm_obs_keys:
if not isinstance(self.observation_space.spaces[obs_key], spaces.Box):
raise ValueError(
f"VecNormalize only supports `gym.spaces.Box` observation spaces but {obs_key} "
f"is of type {self.observation_space.spaces[obs_key]}. "
"You should probably explicitely pass the observation keys "
" that should be normalized via the `norm_obs_keys` parameter."
)
elif isinstance(self.observation_space, spaces.Box):
if self.norm_obs_keys is not None:
raise ValueError("`norm_obs_keys` param is applicable only with `gym.spaces.Dict` observation spaces")
else:
raise ValueError(
"VecNormalize only supports `gym.spaces.Box` and `gym.spaces.Dict` observation spaces, "
f"not {self.observation_space}"
)
def __getstate__(self) -> Dict[str, Any]:
"""
Gets state for pickling.
        Excludes self.venv, as in general VecEnvs may not be pickleable."""
state = self.__dict__.copy()
# these attributes are not pickleable
del state["venv"]
del state["class_attributes"]
# these attributes depend on the above and so we would prefer not to pickle
del state["returns"]
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
"""
Restores pickled state.
User must call set_venv() after unpickling before using.
:param state:"""
# Backward compatibility
if "norm_obs_keys" not in state and isinstance(state["observation_space"], spaces.Dict):
state["norm_obs_keys"] = list(state["observation_space"].spaces.keys())
self.__dict__.update(state)
assert "venv" not in state
self.venv = None # type: ignore[assignment]
def set_venv(self, venv: VecEnv) -> None:
"""
Sets the vector environment to wrap to venv.
        Also sets attributes derived from this such as `num_envs`.
:param venv:
"""
if self.venv is not None:
raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.")
self.venv = venv
self.num_envs = venv.num_envs
self.class_attributes = dict(inspect.getmembers(self.__class__))
self.render_mode = venv.render_mode
# Check that the observation_space shape match
utils.check_shape_equal(self.observation_space, venv.observation_space)
self.returns = np.zeros(self.num_envs)
def step_wait(self) -> VecEnvStepReturn:
"""
        Apply sequence of actions to sequence of environments
        actions -> (observations, rewards, dones, infos)
        where ``dones`` is a boolean vector indicating whether each episode has ended
        (in which case the returned observation is the first of a new episode).
"""
obs, rewards, dones, infos = self.venv.step_wait()
assert isinstance(obs, (np.ndarray, dict)) # for mypy
self.old_obs = obs
self.old_reward = rewards
if self.training and self.norm_obs:
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
for key in self.obs_rms.keys():
self.obs_rms[key].update(obs[key])
else:
self.obs_rms.update(obs)
obs = self.normalize_obs(obs)
if self.training:
self._update_reward(rewards)
rewards = self.normalize_reward(rewards)
# Normalize the terminal observations
for idx, done in enumerate(dones):
if not done:
continue
if "terminal_observation" in infos[idx]:
infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"])
self.returns[dones] = 0
return obs, rewards, dones, infos
def _update_reward(self, reward: np.ndarray) -> None:
"""Update reward normalization statistics."""
self.returns = self.returns * self.gamma + reward
self.ret_rms.update(self.returns)
def _normalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray:
"""
Helper to normalize observation.
:param obs:
:param obs_rms: associated statistics
:return: normalized observation
"""
return np.clip((obs - obs_rms.mean) / np.sqrt(obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs)
def _unnormalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray:
"""
Helper to unnormalize observation.
:param obs:
:param obs_rms: associated statistics
:return: unnormalized observation
"""
return (obs * np.sqrt(obs_rms.var + self.epsilon)) + obs_rms.mean
def normalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""
Normalize observations using this VecNormalize's observations statistics.
Calling this method does not update statistics.
"""
# Avoid modifying by reference the original object
obs_ = deepcopy(obs)
if self.norm_obs:
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
assert self.norm_obs_keys is not None
# Only normalize the specified keys
for key in self.norm_obs_keys:
obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32)
else:
assert isinstance(self.obs_rms, RunningMeanStd)
obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32)
return obs_
def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
"""
Normalize rewards using this VecNormalize's rewards statistics.
Calling this method does not update statistics.
"""
if self.norm_reward:
reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
return reward
def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
# Avoid modifying by reference the original object
obs_ = deepcopy(obs)
if self.norm_obs:
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
assert self.norm_obs_keys is not None
for key in self.norm_obs_keys:
obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key])
else:
assert isinstance(self.obs_rms, RunningMeanStd)
obs_ = self._unnormalize_obs(obs, self.obs_rms)
return obs_
def unnormalize_reward(self, reward: np.ndarray) -> np.ndarray:
if self.norm_reward:
return reward * np.sqrt(self.ret_rms.var + self.epsilon)
return reward
def get_original_obs(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""
Returns an unnormalized version of the observations from the most recent
step or reset.
"""
return deepcopy(self.old_obs)
def get_original_reward(self) -> np.ndarray:
"""
Returns an unnormalized version of the rewards from the most recent step.
"""
return self.old_reward.copy()
def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""
Reset all environments
:return: first observation of the episode
"""
obs = self.venv.reset()
assert isinstance(obs, (np.ndarray, dict))
self.old_obs = obs
self.returns = np.zeros(self.num_envs)
if self.training and self.norm_obs:
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
for key in self.obs_rms.keys():
self.obs_rms[key].update(obs[key])
else:
assert isinstance(self.obs_rms, RunningMeanStd)
self.obs_rms.update(obs)
return self.normalize_obs(obs)
@staticmethod
def load(load_path: str, venv: VecEnv) -> "VecNormalize":
"""
Loads a saved VecNormalize object.
:param load_path: the path to load from.
:param venv: the VecEnv to wrap.
:return:
"""
with open(load_path, "rb") as file_handler:
vec_normalize = pickle.load(file_handler)
vec_normalize.set_venv(venv)
return vec_normalize
def save(self, save_path: str) -> None:
"""
Save current VecNormalize object with
all running statistics and settings (e.g. clip_obs)
:param save_path: The path to save to
"""
with open(save_path, "wb") as file_handler:
pickle.dump(self, file_handler)
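

if __name__ == "__main__":
    # Usage sketch (assumes gymnasium and the "Pendulum-v1" env id are available):
    # normalize observations and rewards during training, then persist the running
    # statistics with ``save()`` and restore them with ``load()`` + ``set_venv()``.
    import gymnasium as gym

    from stable_baselines3.common.vec_env import DummyVecEnv

    venv = VecNormalize(DummyVecEnv([lambda: gym.make("Pendulum-v1")]), norm_obs=True, norm_reward=True)
    obs = venv.reset()
    # ... interact / train here, then save the statistics:
    venv.save("vec_normalize.pkl")

    # At evaluation time: reload the statistics and freeze them
    eval_env = VecNormalize.load("vec_normalize.pkl", DummyVecEnv([lambda: gym.make("Pendulum-v1")]))
    eval_env.training = False  # do not update the running statistics
    eval_env.norm_reward = False  # report unnormalized rewards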

View File

@ -0,0 +1,118 @@
from copy import deepcopy
from typing import Dict, Union
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
class VecTransposeImage(VecEnvWrapper):
"""
Re-order channels, from HxWxC to CxHxW.
It is required for PyTorch convolution layers.
:param venv:
    :param skip: Skip this wrapper if needed, as we rely on a heuristic to decide whether to apply it,
        which may result in unwanted behavior, see GH issue #671.
"""
def __init__(self, venv: VecEnv, skip: bool = False):
assert is_image_space(venv.observation_space) or isinstance(
venv.observation_space, spaces.Dict
), "The observation space must be an image or dictionary observation space"
self.skip = skip
# Do nothing
if skip:
super().__init__(venv)
return
if isinstance(venv.observation_space, spaces.Dict):
self.image_space_keys = []
observation_space = deepcopy(venv.observation_space)
for key, space in observation_space.spaces.items():
if is_image_space(space):
# Keep track of which keys should be transposed later
self.image_space_keys.append(key)
assert isinstance(space, spaces.Box)
observation_space.spaces[key] = self.transpose_space(space, key)
else:
assert isinstance(venv.observation_space, spaces.Box)
observation_space = self.transpose_space(venv.observation_space) # type: ignore[assignment]
super().__init__(venv, observation_space=observation_space)
@staticmethod
def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box:
"""
Transpose an observation space (re-order channels).
:param observation_space:
:param key: In case of dictionary space, the key of the observation space.
:return:
"""
# Sanity checks
assert is_image_space(observation_space), "The observation space must be an image"
assert not is_image_space_channels_first(
observation_space
), f"The observation space {key} must follow the channel last convention"
height, width, channels = observation_space.shape
new_shape = (channels, height, width)
return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype) # type: ignore[arg-type]
@staticmethod
def transpose_image(image: np.ndarray) -> np.ndarray:
"""
Transpose an image or batch of images (re-order channels).
:param image:
:return:
"""
if len(image.shape) == 3:
return np.transpose(image, (2, 0, 1))
return np.transpose(image, (0, 3, 1, 2))
def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]:
"""
Transpose (if needed) and return new observations.
:param observations:
:return: Transposed observations
"""
# Do nothing
if self.skip:
return observations
if isinstance(observations, dict):
# Avoid modifying the original object in place
observations = deepcopy(observations)
for k in self.image_space_keys:
observations[k] = self.transpose_image(observations[k])
else:
observations = self.transpose_image(observations)
return observations
def step_wait(self) -> VecEnvStepReturn:
observations, rewards, dones, infos = self.venv.step_wait()
# Transpose the terminal observations
for idx, done in enumerate(dones):
if not done:
continue
if "terminal_observation" in infos[idx]:
infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"])
assert isinstance(observations, (np.ndarray, dict))
return self.transpose_observations(observations), rewards, dones, infos
def reset(self) -> Union[np.ndarray, Dict]:
"""
Reset all environments
"""
observations = self.venv.reset()
assert isinstance(observations, (np.ndarray, dict))
return self.transpose_observations(observations)
def close(self) -> None:
self.venv.close()
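

if __name__ == "__main__":
    # Usage sketch (assumes gymnasium[box2d] and the image-based "CarRacing-v2"
    # env id are available): re-order channels from HxWxC to CxHxW for PyTorch.
    import gymnasium as gym

    from stable_baselines3.common.vec_env import DummyVecEnv

    venv = DummyVecEnv([lambda: gym.make("CarRacing-v2")])
    venv = VecTransposeImage(venv)
    # (96, 96, 3) becomes (3, 96, 96)
    print(venv.observation_space.shape)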

View File

@ -0,0 +1,113 @@
import os
from typing import Callable
from gymnasium.wrappers.monitoring import video_recorder
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
class VecVideoRecorder(VecEnvWrapper):
"""
    Wraps a VecEnv or VecEnvWrapper object to record rendered images as an mp4 video.
It requires ffmpeg or avconv to be installed on the machine.
:param venv:
:param video_folder: Where to save videos
:param record_video_trigger: Function that defines when to start recording.
        The function takes the current number of steps
        and returns whether we should start recording or not.
:param video_length: Length of recorded videos
:param name_prefix: Prefix to the video name
"""
video_recorder: video_recorder.VideoRecorder
def __init__(
self,
venv: VecEnv,
video_folder: str,
record_video_trigger: Callable[[int], bool],
video_length: int = 200,
name_prefix: str = "rl-video",
):
VecEnvWrapper.__init__(self, venv)
self.env = venv
# Temp variable to retrieve metadata
temp_env = venv
# Unwrap to retrieve metadata dict
# that will be used by gym recorder
while isinstance(temp_env, VecEnvWrapper):
temp_env = temp_env.venv
        if isinstance(temp_env, (DummyVecEnv, SubprocVecEnv)):
metadata = temp_env.get_attr("metadata")[0]
else:
metadata = temp_env.metadata
self.env.metadata = metadata
assert self.env.render_mode == "rgb_array", f"The render_mode must be 'rgb_array', not {self.env.render_mode}"
self.record_video_trigger = record_video_trigger
self.video_folder = os.path.abspath(video_folder)
# Create output folder if needed
os.makedirs(self.video_folder, exist_ok=True)
self.name_prefix = name_prefix
self.step_id = 0
self.video_length = video_length
self.recording = False
self.recorded_frames = 0
def reset(self) -> VecEnvObs:
obs = self.venv.reset()
self.start_video_recorder()
return obs
def start_video_recorder(self) -> None:
self.close_video_recorder()
video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}"
base_path = os.path.join(self.video_folder, video_name)
self.video_recorder = video_recorder.VideoRecorder(
env=self.env, base_path=base_path, metadata={"step_id": self.step_id}
)
self.video_recorder.capture_frame()
self.recorded_frames = 1
self.recording = True
def _video_enabled(self) -> bool:
return self.record_video_trigger(self.step_id)
def step_wait(self) -> VecEnvStepReturn:
obs, rews, dones, infos = self.venv.step_wait()
self.step_id += 1
if self.recording:
self.video_recorder.capture_frame()
self.recorded_frames += 1
if self.recorded_frames > self.video_length:
print(f"Saving video to {self.video_recorder.path}")
self.close_video_recorder()
elif self._video_enabled():
self.start_video_recorder()
return obs, rews, dones, infos
def close_video_recorder(self) -> None:
if self.recording:
self.video_recorder.close()
self.recording = False
self.recorded_frames = 1
def close(self) -> None:
VecEnvWrapper.close(self)
self.close_video_recorder()
def __del__(self):
self.close_video_recorder()
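

if __name__ == "__main__":
    # Usage sketch (assumes gymnasium, the "CartPole-v1" env id and a working
    # video backend such as ffmpeg/moviepy are available): record the first
    # 100 steps, starting at step 0.
    import gymnasium as gym

    venv = DummyVecEnv([lambda: gym.make("CartPole-v1", render_mode="rgb_array")])
    venv = VecVideoRecorder(venv, "videos/", record_video_trigger=lambda step: step == 0, video_length=100)
    obs = venv.reset()
    for _ in range(100 + 1):
        obs, rewards, dones, infos = venv.step([venv.action_space.sample()])
    venv.close()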

View File

@ -0,0 +1,4 @@
from stable_baselines3.ddpg.ddpg import DDPG
from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "DDPG"]

View File

@ -0,0 +1,130 @@
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
import torch as th
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.td3.td3 import TD3
SelfDDPG = TypeVar("SelfDDPG", bound="DDPG")
class DDPG(TD3):
"""
Deep Deterministic Policy Gradient (DDPG).
Deterministic Policy Gradient: http://proceedings.mlr.press/v32/silver14.pdf
DDPG Paper: https://arxiv.org/abs/1509.02971
Introduction to DDPG: https://spinningup.openai.com/en/latest/algorithms/ddpg.html
Note: we treat DDPG as a special case of its successor TD3.
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: learning rate for the Adam optimizer;
        the same learning rate will be used for all networks (Q-Values, Actor and Value function).
        It can be a function of the current progress remaining (from 1 to 0)
:param buffer_size: size of the replay buffer
:param learning_starts: how many steps of the model to collect transitions for before learning starts
:param batch_size: Minibatch size for each gradient update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1)
:param gamma: the discount factor
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
like ``(5, "step")`` or ``(2, "episode")``.
:param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set it to ``-1`` to do as many gradient steps as steps done in the environment
        during the rollout.
    :param action_noise: the action noise type (None by default), this can help
        for hard exploration problems. Cf common.noise for the different action noise types.
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
If ``None``, it will be automatically selected.
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
def __init__(
self,
policy: Union[str, Type[TD3Policy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 1e-3,
buffer_size: int = 1_000_000, # 1e6
learning_starts: int = 100,
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 1,
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super().__init__(
policy=policy,
env=env,
learning_rate=learning_rate,
buffer_size=buffer_size,
learning_starts=learning_starts,
batch_size=batch_size,
tau=tau,
gamma=gamma,
train_freq=train_freq,
gradient_steps=gradient_steps,
action_noise=action_noise,
replay_buffer_class=replay_buffer_class,
replay_buffer_kwargs=replay_buffer_kwargs,
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
verbose=verbose,
device=device,
seed=seed,
optimize_memory_usage=optimize_memory_usage,
# Remove all tricks from TD3 to obtain DDPG:
# we still need to specify target_policy_noise > 0 to avoid errors
policy_delay=1,
target_noise_clip=0.0,
target_policy_noise=0.1,
_init_setup_model=False,
)
# Use only one critic
if "n_critics" not in self.policy_kwargs:
self.policy_kwargs["n_critics"] = 1
if _init_setup_model:
self._setup_model()
def learn(
self: SelfDDPG,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "DDPG",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfDDPG:
return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
tb_log_name=tb_log_name,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)
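

if __name__ == "__main__":
    # Usage sketch (assumes gymnasium and the "Pendulum-v1" env id are available):
    # DDPG with Gaussian action noise on a continuous-control task.
    import numpy as np

    from stable_baselines3.common.noise import NormalActionNoise

    n_actions = 1  # Pendulum-v1 has a single continuous action
    action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

    model = DDPG("MlpPolicy", "Pendulum-v1", action_noise=action_noise, verbose=1)
    model.learn(total_timesteps=10_000, log_interval=10)
    model.save("ddpg_pendulum")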

View File

@ -0,0 +1,2 @@
# DDPG can be viewed as a special case of TD3
from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy # noqa:F401

View File

@ -0,0 +1,4 @@
from stable_baselines3.dqn.dqn import DQN
from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "DQN"]

Some files were not shown because too many files have changed in this diff