I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,105 @@
from copy import deepcopy
from typing import Optional, Type, TypeVar
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper)
def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]:
"""
Retrieve a ``VecEnvWrapper`` object by recursively searching.
:param env: The ``VecEnv`` that is going to be unwrapped
:param vec_wrapper_class: The desired ``VecEnvWrapper`` class.
:return: The ``VecEnvWrapper`` object if the ``VecEnv`` is wrapped with the desired wrapper, None otherwise
"""
env_tmp = env
while isinstance(env_tmp, VecEnvWrapper):
if isinstance(env_tmp, vec_wrapper_class):
return env_tmp
env_tmp = env_tmp.venv
return None
def unwrap_vec_normalize(env: VecEnv) -> Optional[VecNormalize]:
"""
Retrieve a ``VecNormalize`` object by recursively searching.
:param env: The VecEnv that is going to be unwrapped
:return: The ``VecNormalize`` object if the ``VecEnv`` is wrapped with ``VecNormalize``, None otherwise
"""
return unwrap_vec_wrapper(env, VecNormalize)
def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
"""
Check if an environment is already wrapped in a given ``VecEnvWrapper``.
:param env: The VecEnv that is going to be checked
:param vec_wrapper_class: The desired ``VecEnvWrapper`` class.
:return: True if the ``VecEnv`` is wrapped with the desired wrapper, False otherwise
"""
return unwrap_vec_wrapper(env, vec_wrapper_class) is not None
def sync_envs_normalization(env: VecEnv, eval_env: VecEnv) -> None:
"""
Synchronize the normalization statistics of an eval environment and train environment
when they are both wrapped in a ``VecNormalize`` wrapper.
:param env: Training env
:param eval_env: Environment used for evaluation.
"""
env_tmp, eval_env_tmp = env, eval_env
while isinstance(env_tmp, VecEnvWrapper):
assert isinstance(eval_env_tmp, VecEnvWrapper), (
"Error while synchronizing normalization stats: expected the eval env to be "
f"a VecEnvWrapper but got {eval_env_tmp} instead. "
"This is probably due to the training env not being wrapped the same way as the evaluation env. "
f"Training env type: {env_tmp}."
)
if isinstance(env_tmp, VecNormalize):
assert isinstance(eval_env_tmp, VecNormalize), (
"Error while synchronizing normalization stats: expected the eval env to be "
f"a VecNormalize but got {eval_env_tmp} instead. "
"This is probably due to the training env not being wrapped the same way as the evaluation env. "
f"Training env type: {env_tmp}."
)
# Only synchronize if observation normalization exists
if hasattr(env_tmp, "obs_rms"):
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
env_tmp = env_tmp.venv
eval_env_tmp = eval_env_tmp.venv
__all__ = [
"CloudpickleWrapper",
"VecEnv",
"VecEnvWrapper",
"DummyVecEnv",
"StackedObservations",
"SubprocVecEnv",
"VecCheckNan",
"VecExtractDictObs",
"VecFrameStack",
"VecMonitor",
"VecNormalize",
"VecTransposeImage",
"VecVideoRecorder",
"unwrap_vec_wrapper",
"unwrap_vec_normalize",
"is_vecenv_wrapped",
"sync_envs_normalization",
]

View File

@ -0,0 +1,482 @@
import inspect
import warnings
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union
import cloudpickle
import gymnasium as gym
import numpy as np
from gymnasium import spaces
# Define type aliases here to avoid circular import
# Used when we want to access one or more VecEnv
VecEnvIndices = Union[None, int, Iterable[int]]
# VecEnvObs is what is returned by the reset() method
# it contains the observation for each env
VecEnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]]
# VecEnvStepReturn is what is returned by the step() method
# it contains the observation, reward, done, info for each env
VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]]
def tile_images(images_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover
"""
Tile N images into one big PxQ image
(P,Q) are chosen to be as close as possible, and if N
is square, then P=Q.
:param images_nhwc: list or array of images, ndim=4 once turned into array.
n = batch index, h = height, w = width, c = channel
:return: img_HWc, ndim=3
"""
img_nhwc = np.asarray(images_nhwc)
n_images, height, width, n_channels = img_nhwc.shape
# new_height was named H before
new_height = int(np.ceil(np.sqrt(n_images)))
# new_width was named W before
new_width = int(np.ceil(float(n_images) / new_height))
img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(n_images, new_height * new_width)])
# img_HWhwc
out_image = img_nhwc.reshape((new_height, new_width, height, width, n_channels))
# img_HhWwc
out_image = out_image.transpose(0, 2, 1, 3, 4)
# img_Hh_Ww_c
out_image = out_image.reshape((new_height * height, new_width * width, n_channels))
return out_image
class VecEnv(ABC):
"""
An abstract asynchronous, vectorized environment.
:param num_envs: Number of environments
:param observation_space: Observation space
:param action_space: Action space
"""
def __init__(
self,
num_envs: int,
observation_space: spaces.Space,
action_space: spaces.Space,
):
self.num_envs = num_envs
self.observation_space = observation_space
self.action_space = action_space
# store info returned by the reset method
self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
# seeds to be used in the next call to env.reset()
self._seeds: List[Optional[int]] = [None for _ in range(num_envs)]
# options to be used in the next call to env.reset()
self._options: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
try:
render_modes = self.get_attr("render_mode")
except AttributeError:
warnings.warn("The `render_mode` attribute is not defined in your environment. It will be set to None.")
render_modes = [None for _ in range(num_envs)]
assert all(
render_mode == render_modes[0] for render_mode in render_modes
), "render_mode mode should be the same for all environments"
self.render_mode = render_modes[0]
render_modes = []
if self.render_mode is not None:
if self.render_mode == "rgb_array":
# SB3 uses OpenCV for the "human" mode
render_modes = ["human", "rgb_array"]
else:
render_modes = [self.render_mode]
self.metadata = {"render_modes": render_modes}
def _reset_seeds(self) -> None:
"""
Reset the seeds that are going to be used at the next reset.
"""
self._seeds = [None for _ in range(self.num_envs)]
def _reset_options(self) -> None:
"""
Reset the options that are going to be used at the next reset.
"""
self._options = [{} for _ in range(self.num_envs)]
@abstractmethod
def reset(self) -> VecEnvObs:
"""
Reset all the environments and return an array of
observations, or a tuple of observation arrays.
If step_async is still doing work, that work will
be cancelled and step_wait() should not be called
until step_async() is invoked again.
:return: observation
"""
raise NotImplementedError()
@abstractmethod
def step_async(self, actions: np.ndarray) -> None:
"""
Tell all the environments to start taking a step
with the given actions.
Call step_wait() to get the results of the step.
You should not call this if a step_async run is
already pending.
"""
raise NotImplementedError()
@abstractmethod
def step_wait(self) -> VecEnvStepReturn:
"""
Wait for the step taken with step_async().
:return: observation, reward, done, information
"""
raise NotImplementedError()
@abstractmethod
def close(self) -> None:
"""
Clean up the environment's resources.
"""
raise NotImplementedError()
@abstractmethod
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
"""
Return attribute from vectorized environment.
:param attr_name: The name of the attribute whose value to return
:param indices: Indices of envs to get attribute from
:return: List of values of 'attr_name' in all environments
"""
raise NotImplementedError()
@abstractmethod
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
"""
Set attribute inside vectorized environments.
:param attr_name: The name of attribute to assign new value
:param value: Value to assign to `attr_name`
:param indices: Indices of envs to assign value
:return:
"""
raise NotImplementedError()
@abstractmethod
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
"""
Call instance methods of vectorized environments.
:param method_name: The name of the environment method to invoke.
:param indices: Indices of envs whose method to call
:param method_args: Any positional arguments to provide in the call
:param method_kwargs: Any keyword arguments to provide in the call
:return: List of items returned by the environment's method call
"""
raise NotImplementedError()
@abstractmethod
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
"""
Check if environments are wrapped with a given wrapper.
:param method_name: The name of the environment method to invoke.
:param indices: Indices of envs whose method to call
:param method_args: Any positional arguments to provide in the call
:param method_kwargs: Any keyword arguments to provide in the call
:return: True if the env is wrapped, False otherwise, for each env queried.
"""
raise NotImplementedError()
def step(self, actions: np.ndarray) -> VecEnvStepReturn:
"""
Step the environments with the given action
:param actions: the action
:return: observation, reward, done, information
"""
self.step_async(actions)
return self.step_wait()
def get_images(self) -> Sequence[Optional[np.ndarray]]:
"""
Return RGB images from each environment when available
"""
raise NotImplementedError
def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
"""
Gym environment rendering
:param mode: the rendering type
"""
if mode == "human" and self.render_mode != mode:
# Special case, if the render_mode="rgb_array"
# we can still display that image using opencv
if self.render_mode != "rgb_array":
warnings.warn(
f"You tried to render a VecEnv with mode='{mode}' "
"but the render mode defined when initializing the environment must be "
f"'human' or 'rgb_array', not '{self.render_mode}'."
)
return None
elif mode and self.render_mode != mode:
warnings.warn(
f"""Starting from gymnasium v0.26, render modes are determined during the initialization of the environment.
We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode})
has to be the same as the environment render mode ({self.render_mode}) which is not the case."""
)
return None
mode = mode or self.render_mode
if mode is None:
warnings.warn("You tried to call render() but no `render_mode` was passed to the env constructor.")
return None
# mode == self.render_mode == "human"
# In that case, we try to call `self.env.render()` but it might
# crash for subprocesses
if self.render_mode == "human":
self.env_method("render")
return None
if mode == "rgb_array" or mode == "human":
# call the render method of the environments
images = self.get_images()
# Create a big image by tiling images from subprocesses
bigimg = tile_images(images) # type: ignore[arg-type]
if mode == "human":
# Display it using OpenCV
import cv2
cv2.imshow("vecenv", bigimg[:, :, ::-1])
cv2.waitKey(1)
else:
return bigimg
else:
# Other render modes:
# In that case, we try to call `self.env.render()` but it might
# crash for subprocesses
# and we don't return the values
self.env_method("render")
return None
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
"""
Sets the random seeds for all environments, based on a given seed.
Each individual environment will still get its own seed, by incrementing the given seed.
WARNING: since gym 0.26, those seeds will only be passed to the environment
at the next reset.
:param seed: The random seed. May be None for completely random seeding.
:return: Returns a list containing the seeds for each individual env.
Note that all list elements may be None, if the env does not return anything when being seeded.
"""
if seed is None:
# To ensure that subprocesses have different seeds,
# we still populate the seed variable when no argument is passed
seed = int(np.random.randint(0, np.iinfo(np.uint32).max, dtype=np.uint32))
self._seeds = [seed + idx for idx in range(self.num_envs)]
return self._seeds
def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None:
"""
Set environment options for all environments.
If a dict is passed instead of a list, the same options will be used for all environments.
WARNING: Those options will only be passed to the environment at the next reset.
:param options: A dictionary of environment options to pass to each environment at the next reset.
"""
if options is None:
options = {}
# Use deepcopy to avoid side effects
if isinstance(options, dict):
self._options = deepcopy([options] * self.num_envs)
else:
self._options = deepcopy(options)
@property
def unwrapped(self) -> "VecEnv":
if isinstance(self, VecEnvWrapper):
return self.venv.unwrapped
else:
return self
def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]:
"""Check if an attribute reference is being hidden in a recursive call to __getattr__
:param name: name of attribute to check for
:param already_found: whether this attribute has already been found in a wrapper
:return: name of module whose attribute is being shadowed, if any.
"""
if hasattr(self, name) and already_found:
return f"{type(self).__module__}.{type(self).__name__}"
else:
return None
def _get_indices(self, indices: VecEnvIndices) -> Iterable[int]:
"""
Convert a flexibly-typed reference to environment indices to an implied list of indices.
:param indices: refers to indices of envs.
:return: the implied list of indices.
"""
if indices is None:
indices = range(self.num_envs)
elif isinstance(indices, int):
indices = [indices]
return indices
class VecEnvWrapper(VecEnv):
"""
Vectorized environment base class
:param venv: the vectorized environment to wrap
:param observation_space: the observation space (can be None to load from venv)
:param action_space: the action space (can be None to load from venv)
"""
def __init__(
self,
venv: VecEnv,
observation_space: Optional[spaces.Space] = None,
action_space: Optional[spaces.Space] = None,
):
self.venv = venv
super().__init__(
num_envs=venv.num_envs,
observation_space=observation_space or venv.observation_space,
action_space=action_space or venv.action_space,
)
self.class_attributes = dict(inspect.getmembers(self.__class__))
def step_async(self, actions: np.ndarray) -> None:
self.venv.step_async(actions)
@abstractmethod
def reset(self) -> VecEnvObs:
pass
@abstractmethod
def step_wait(self) -> VecEnvStepReturn:
pass
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
return self.venv.seed(seed)
def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None:
return self.venv.set_options(options)
def close(self) -> None:
return self.venv.close()
def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
return self.venv.render(mode=mode)
def get_images(self) -> Sequence[Optional[np.ndarray]]:
return self.venv.get_images()
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
return self.venv.get_attr(attr_name, indices)
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
return self.venv.set_attr(attr_name, value, indices)
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
return self.venv.env_is_wrapped(wrapper_class, indices=indices)
def __getattr__(self, name: str) -> Any:
"""Find attribute from wrapped venv(s) if this wrapper does not have it.
Useful for accessing attributes from venvs which are wrapped with multiple wrappers
which have unique attributes of interest.
"""
blocked_class = self.getattr_depth_check(name, already_found=False)
if blocked_class is not None:
own_class = f"{type(self).__module__}.{type(self).__name__}"
error_str = (
f"Error: Recursive attribute lookup for {name} from {own_class} is "
f"ambiguous and hides attribute from {blocked_class}"
)
raise AttributeError(error_str)
return self.getattr_recursive(name)
def _get_all_attributes(self) -> Dict[str, Any]:
"""Get all (inherited) instance and class attributes
:return: all_attributes
"""
all_attributes = self.__dict__.copy()
all_attributes.update(self.class_attributes)
return all_attributes
def getattr_recursive(self, name: str) -> Any:
"""Recursively check wrappers to find attribute.
:param name: name of attribute to look for
:return: attribute
"""
all_attributes = self._get_all_attributes()
if name in all_attributes: # attribute is present in this wrapper
attr = getattr(self, name)
elif hasattr(self.venv, "getattr_recursive"):
# Attribute not present, child is wrapper. Call getattr_recursive rather than getattr
# to avoid a duplicate call to getattr_depth_check.
attr = self.venv.getattr_recursive(name)
else: # attribute not present, child is an unwrapped VecEnv
attr = getattr(self.venv, name)
return attr
def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]:
"""See base class.
:return: name of module whose attribute is being shadowed, if any.
"""
all_attributes = self._get_all_attributes()
if name in all_attributes and already_found:
# this venv's attribute is being hidden because of a higher venv.
shadowed_wrapper_class: Optional[str] = f"{type(self).__module__}.{type(self).__name__}"
elif name in all_attributes and not already_found:
# we have found the first reference to the attribute. Now check for duplicates.
shadowed_wrapper_class = self.venv.getattr_depth_check(name, True)
else:
# this wrapper does not have the attribute. Keep searching.
shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found)
return shadowed_wrapper_class
class CloudpickleWrapper:
"""
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
:param var: the variable you wish to wrap for pickling with cloudpickle
"""
def __init__(self, var: Any):
self.var = var
def __getstate__(self) -> Any:
return cloudpickle.dumps(self.var)
def __setstate__(self, var: Any) -> None:
self.var = cloudpickle.loads(var)

View File

@ -0,0 +1,141 @@
import warnings
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Sequence, Type
import gymnasium as gym
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
from stable_baselines3.common.vec_env.patch_gym import _patch_env
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
class DummyVecEnv(VecEnv):
"""
Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
Python process. This is useful for computationally simple environment such as ``Cartpole-v1``,
as the overhead of multiprocess or multithread outweighs the environment computation time.
This can also be used for RL methods that
require a vectorized environment, but that you want a single environments to train with.
:param env_fns: a list of functions
that return environments to vectorize
:raises ValueError: If the same environment instance is passed as the output of two or more different env_fn.
"""
actions: np.ndarray
def __init__(self, env_fns: List[Callable[[], gym.Env]]):
self.envs = [_patch_env(fn()) for fn in env_fns]
if len(set([id(env.unwrapped) for env in self.envs])) != len(self.envs):
raise ValueError(
"You tried to create multiple environments, but the function to create them returned the same instance "
"instead of creating different objects. "
"You are probably using `make_vec_env(lambda: env)` or `DummyVecEnv([lambda: env] * n_envs)`. "
"You should replace `lambda: env` by a `make_env` function that "
"creates a new instance of the environment at every call "
"(using `gym.make()` for instance). You can take a look at the documentation for an example. "
"Please read https://github.com/DLR-RM/stable-baselines3/issues/1151 for more information."
)
env = self.envs[0]
super().__init__(len(env_fns), env.observation_space, env.action_space)
obs_space = env.observation_space
self.keys, shapes, dtypes = obs_space_info(obs_space)
self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys])
self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
self.buf_infos: List[Dict[str, Any]] = [{} for _ in range(self.num_envs)]
self.metadata = env.metadata
def step_async(self, actions: np.ndarray) -> None:
self.actions = actions
def step_wait(self) -> VecEnvStepReturn:
# Avoid circular imports
for env_idx in range(self.num_envs):
obs, self.buf_rews[env_idx], terminated, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step(
self.actions[env_idx]
)
# convert to SB3 VecEnv api
self.buf_dones[env_idx] = terminated or truncated
# See https://github.com/openai/gym/issues/3102
# Gym 0.26 introduces a breaking change
self.buf_infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated
if self.buf_dones[env_idx]:
# save final observation where user can get it, then reset
self.buf_infos[env_idx]["terminal_observation"] = obs
obs, self.reset_infos[env_idx] = self.envs[env_idx].reset()
self._save_obs(env_idx, obs)
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
def reset(self) -> VecEnvObs:
for env_idx in range(self.num_envs):
maybe_options = {"options": self._options[env_idx]} if self._options[env_idx] else {}
obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options)
self._save_obs(env_idx, obs)
# Seeds and options are only used once
self._reset_seeds()
self._reset_options()
return self._obs_from_buf()
def close(self) -> None:
for env in self.envs:
env.close()
def get_images(self) -> Sequence[Optional[np.ndarray]]:
if self.render_mode != "rgb_array":
warnings.warn(
f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images."
)
return [None for _ in self.envs]
return [env.render() for env in self.envs] # type: ignore[misc]
def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
"""
Gym environment rendering. If there are multiple environments then
they are tiled together in one image via ``BaseVecEnv.render()``.
:param mode: The rendering type.
"""
return super().render(mode=mode)
def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
for key in self.keys:
if key is None:
self.buf_obs[key][env_idx] = obs
else:
self.buf_obs[key][env_idx] = obs[key] # type: ignore[call-overload]
def _obs_from_buf(self) -> VecEnvObs:
return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
"""Return attribute from vectorized environment (see base class)."""
target_envs = self._get_target_envs(indices)
return [getattr(env_i, attr_name) for env_i in target_envs]
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
"""Set attribute inside vectorized environments (see base class)."""
target_envs = self._get_target_envs(indices)
for env_i in target_envs:
setattr(env_i, attr_name, value)
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
"""Call instance methods of vectorized environments."""
target_envs = self._get_target_envs(indices)
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
"""Check if worker environments are wrapped with a given wrapper"""
target_envs = self._get_target_envs(indices)
# Import here to avoid a circular import
from stable_baselines3.common import env_util
return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]
def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
indices = self._get_indices(indices)
return [self.envs[i] for i in indices]

View File

@ -0,0 +1,100 @@
import warnings
from inspect import signature
from typing import Union
import gymnasium
try:
import gym
gym_installed = True
except ImportError:
gym_installed = False
def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma: no cover
"""
Adapted from https://github.com/thu-ml/tianshou.
Takes an environment and patches it to return Gymnasium env.
This function takes the environment object and returns a patched
env, using shimmy wrapper to convert it to Gymnasium,
if necessary.
:param env: A gym/gymnasium env
:return: Patched env (gymnasium env)
"""
# Gymnasium env, no patching to be done
if isinstance(env, gymnasium.Env):
return env
if not gym_installed or not isinstance(env, gym.Env):
raise ValueError(
f"The environment is of type {type(env)}, not a Gymnasium "
f"environment. In this case, we expect OpenAI Gym to be "
f"installed and the environment to be an OpenAI Gym environment."
)
try:
import shimmy
except ImportError as e:
raise ImportError(
"Missing shimmy installation. You provided an OpenAI Gym environment. "
"Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. "
"In order to use OpenAI Gym environments with SB3, you need to "
"install shimmy (`pip install 'shimmy>=0.2.1'`)."
) from e
warnings.warn(
"You provided an OpenAI Gym environment. "
"We strongly recommend transitioning to Gymnasium environments. "
"Stable-Baselines3 is automatically wrapping your environments in a compatibility "
"layer, which could potentially cause issues."
)
if "seed" in signature(env.unwrapped.reset).parameters:
# Gym 0.26+ env
return shimmy.GymV26CompatibilityV0(env=env)
# Gym 0.21 env
return shimmy.GymV21CompatibilityV0(env=env)
def _convert_space(space: Union["gym.Space", gymnasium.Space]) -> gymnasium.Space: # pragma: no cover
"""
Takes a space and patches it to return Gymnasium Space.
This function takes the space object and returns a patched
space, using shimmy wrapper to convert it to Gymnasium,
if necessary.
:param env: A gym/gymnasium Space
:return: Patched space (gymnasium Space)
"""
# Gymnasium space, no convertion to be done
if isinstance(space, gymnasium.Space):
return space
if not gym_installed or not isinstance(space, gym.Space):
raise ValueError(
f"The space is of type {type(space)}, not a Gymnasium "
f"space. In this case, we expect OpenAI Gym to be "
f"installed and the space to be an OpenAI Gym space."
)
try:
import shimmy
except ImportError as e:
raise ImportError(
"Missing shimmy installation. You provided an OpenAI Gym space. "
"Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. "
"In order to use OpenAI Gym space with SB3, you need to "
"install shimmy (`pip install 'shimmy>=0.2.1'`)."
) from e
warnings.warn(
"You loaded a model that was trained using OpenAI Gym. "
"We strongly recommend transitioning to Gymnasium by saving that model again."
)
return shimmy.openai_gym_compatibility._convert_space(space)

View File

@ -0,0 +1,176 @@
import warnings
from typing import Any, Dict, Generic, List, Mapping, Optional, Tuple, TypeVar, Union
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
TObs = TypeVar("TObs", np.ndarray, Dict[str, np.ndarray])
class StackedObservations(Generic[TObs]):
"""
Frame stacking wrapper for data.
Dimension to stack over is either first (channels-first) or last (channels-last), which is detected automatically using
``common.preprocessing.is_image_space_channels_first`` if observation is an image space.
:param num_envs: Number of environments
:param n_stack: Number of frames to stack
:param observation_space: Environment observation space
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
If None, automatically detect channel to stack over in case of image observation or default to "last".
For Dict space, channels_order can also be a dictionary.
"""
def __init__(
self,
num_envs: int,
n_stack: int,
observation_space: Union[spaces.Box, spaces.Dict],
channels_order: Optional[Union[str, Mapping[str, Optional[str]]]] = None,
) -> None:
self.n_stack = n_stack
self.observation_space = observation_space
if isinstance(observation_space, spaces.Dict):
if not isinstance(channels_order, Mapping):
channels_order = {key: channels_order for key in observation_space.spaces.keys()}
self.sub_stacked_observations = {
key: StackedObservations(num_envs, n_stack, subspace, channels_order[key]) # type: ignore[arg-type]
for key, subspace in observation_space.spaces.items()
}
self.stacked_observation_space = spaces.Dict(
{key: substack_obs.stacked_observation_space for key, substack_obs in self.sub_stacked_observations.items()}
) # type: Union[spaces.Dict, spaces.Box] # make mypy happy
elif isinstance(observation_space, spaces.Box):
if isinstance(channels_order, Mapping):
raise TypeError("When the observation space is Box, channels_order can't be a dict.")
self.channels_first, self.stack_dimension, self.stacked_shape, self.repeat_axis = self.compute_stacking(
n_stack, observation_space, channels_order
)
low = np.repeat(observation_space.low, n_stack, axis=self.repeat_axis)
high = np.repeat(observation_space.high, n_stack, axis=self.repeat_axis)
self.stacked_observation_space = spaces.Box(
low=low,
high=high,
dtype=observation_space.dtype, # type: ignore[arg-type]
)
self.stacked_obs = np.zeros((num_envs, *self.stacked_shape), dtype=observation_space.dtype)
else:
raise TypeError(
f"StackedObservations only supports Box and Dict as observation spaces. {observation_space} was provided."
)
@staticmethod
def compute_stacking(
n_stack: int, observation_space: spaces.Box, channels_order: Optional[str] = None
) -> Tuple[bool, int, Tuple[int, ...], int]:
"""
Calculates the parameters in order to stack observations
:param n_stack: Number of observations to stack
:param observation_space: Observation space
:param channels_order: Order of the channels
:return: Tuple of channels_first, stack_dimension, stackedobs, repeat_axis
"""
if channels_order is None:
# Detect channel location automatically for images
if is_image_space(observation_space):
channels_first = is_image_space_channels_first(observation_space)
else:
# Default behavior for non-image space, stack on the last axis
channels_first = False
else:
assert channels_order in {
"last",
"first",
}, "`channels_order` must be one of following: 'last', 'first'"
channels_first = channels_order == "first"
# This includes the vec-env dimension (first)
stack_dimension = 1 if channels_first else -1
repeat_axis = 0 if channels_first else -1
stacked_shape = list(observation_space.shape)
stacked_shape[repeat_axis] *= n_stack
return channels_first, stack_dimension, tuple(stacked_shape), repeat_axis
def reset(self, observation: TObs) -> TObs:
"""
Reset the stacked_obs, add the reset observation to the stack, and return the stack.
:param observation: Reset observation
:return: The stacked reset observation
"""
if isinstance(observation, dict):
return {key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()}
self.stacked_obs[...] = 0
if self.channels_first:
self.stacked_obs[:, -observation.shape[self.stack_dimension] :, ...] = observation
else:
self.stacked_obs[..., -observation.shape[self.stack_dimension] :] = observation
return self.stacked_obs
def update(
self,
observations: TObs,
dones: np.ndarray,
infos: List[Dict[str, Any]],
) -> Tuple[TObs, List[Dict[str, Any]]]:
"""
Add the observations to the stack and use the dones to update the infos.
:param observations: Observations
:param dones: Dones
:param infos: Infos
:return: Tuple of the stacked observations and the updated infos
"""
if isinstance(observations, dict):
# From [{}, {terminal_obs: {key1: ..., key2: ...}}]
# to {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
sub_infos = {
key: [
{"terminal_observation": info["terminal_observation"][key]} if "terminal_observation" in info else {}
for info in infos
]
for key in observations.keys()
}
stacked_obs = {}
stacked_infos = {}
for key, obs in observations.items():
stacked_obs[key], stacked_infos[key] = self.sub_stacked_observations[key].update(obs, dones, sub_infos[key])
# From {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
# to [{}, {terminal_obs: {key1: ..., key2: ...}}]
for key in stacked_infos.keys():
for env_idx in range(len(infos)):
if "terminal_observation" in infos[env_idx]:
infos[env_idx]["terminal_observation"][key] = stacked_infos[key][env_idx]["terminal_observation"]
return stacked_obs, infos
shift = -observations.shape[self.stack_dimension]
self.stacked_obs = np.roll(self.stacked_obs, shift, axis=self.stack_dimension)
for env_idx, done in enumerate(dones):
if done:
if "terminal_observation" in infos[env_idx]:
old_terminal = infos[env_idx]["terminal_observation"]
if self.channels_first:
previous_stack = self.stacked_obs[env_idx, :shift, ...]
else:
previous_stack = self.stacked_obs[env_idx, ..., :shift]
new_terminal = np.concatenate((previous_stack, old_terminal), axis=self.repeat_axis)
infos[env_idx]["terminal_observation"] = new_terminal
else:
warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
self.stacked_obs[env_idx] = 0
if self.channels_first:
self.stacked_obs[:, shift:, ...] = observations
else:
self.stacked_obs[..., shift:] = observations
return self.stacked_obs, infos

View File

@ -0,0 +1,232 @@
import multiprocessing as mp
import warnings
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env.base_vec_env import (
CloudpickleWrapper,
VecEnv,
VecEnvIndices,
VecEnvObs,
VecEnvStepReturn,
)
from stable_baselines3.common.vec_env.patch_gym import _patch_env
def _worker(
remote: mp.connection.Connection,
parent_remote: mp.connection.Connection,
env_fn_wrapper: CloudpickleWrapper,
) -> None:
# Import here to avoid a circular import
from stable_baselines3.common.env_util import is_wrapped
parent_remote.close()
env = _patch_env(env_fn_wrapper.var())
reset_info: Optional[Dict[str, Any]] = {}
while True:
try:
cmd, data = remote.recv()
if cmd == "step":
observation, reward, terminated, truncated, info = env.step(data)
# convert to SB3 VecEnv api
done = terminated or truncated
info["TimeLimit.truncated"] = truncated and not terminated
if done:
# save final observation where user can get it, then reset
info["terminal_observation"] = observation
observation, reset_info = env.reset()
remote.send((observation, reward, done, info, reset_info))
elif cmd == "reset":
maybe_options = {"options": data[1]} if data[1] else {}
observation, reset_info = env.reset(seed=data[0], **maybe_options)
remote.send((observation, reset_info))
elif cmd == "render":
remote.send(env.render())
elif cmd == "close":
env.close()
remote.close()
break
elif cmd == "get_spaces":
remote.send((env.observation_space, env.action_space))
elif cmd == "env_method":
method = getattr(env, data[0])
remote.send(method(*data[1], **data[2]))
elif cmd == "get_attr":
remote.send(getattr(env, data))
elif cmd == "set_attr":
remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value]
elif cmd == "is_wrapped":
remote.send(is_wrapped(env, data))
else:
raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
except EOFError:
break
class SubprocVecEnv(VecEnv):
"""
Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
process, allowing significant speed up when the environment is computationally complex.
For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
number of logical cores on your CPU.
.. warning::
Only 'forkserver' and 'spawn' start methods are thread-safe,
which is important when TensorFlow sessions or other non thread-safe
libraries are used in the parent (see issue #217). However, compared to
'fork' they incur a small start-up cost and have restrictions on
global variables. With those methods, users must wrap the code in an
``if __name__ == "__main__":`` block.
For more information, see the multiprocessing documentation.
:param env_fns: Environments to run in subprocesses
:param start_method: method used to start the subprocesses.
Must be one of the methods returned by multiprocessing.get_all_start_methods().
Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
"""
def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[str] = None):
self.waiting = False
self.closed = False
n_envs = len(env_fns)
if start_method is None:
# Fork is not a thread safe method (see issue #217)
# but is more user friendly (does not require to wrap the code in
# a `if __name__ == "__main__":`)
forkserver_available = "forkserver" in mp.get_all_start_methods()
start_method = "forkserver" if forkserver_available else "spawn"
ctx = mp.get_context(start_method)
self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)])
self.processes = []
for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns):
args = (work_remote, remote, CloudpickleWrapper(env_fn))
# daemon=True: if the main process crashes, we should not cause things to hang
process = ctx.Process(target=_worker, args=args, daemon=True) # type: ignore[attr-defined]
process.start()
self.processes.append(process)
work_remote.close()
self.remotes[0].send(("get_spaces", None))
observation_space, action_space = self.remotes[0].recv()
super().__init__(len(env_fns), observation_space, action_space)
def step_async(self, actions: np.ndarray) -> None:
for remote, action in zip(self.remotes, actions):
remote.send(("step", action))
self.waiting = True
def step_wait(self) -> VecEnvStepReturn:
results = [remote.recv() for remote in self.remotes]
self.waiting = False
obs, rews, dones, infos, self.reset_infos = zip(*results) # type: ignore[assignment]
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]
def reset(self) -> VecEnvObs:
for env_idx, remote in enumerate(self.remotes):
remote.send(("reset", (self._seeds[env_idx], self._options[env_idx])))
results = [remote.recv() for remote in self.remotes]
obs, self.reset_infos = zip(*results) # type: ignore[assignment]
# Seeds and options are only used once
self._reset_seeds()
self._reset_options()
return _flatten_obs(obs, self.observation_space)
def close(self) -> None:
if self.closed:
return
if self.waiting:
for remote in self.remotes:
remote.recv()
for remote in self.remotes:
remote.send(("close", None))
for process in self.processes:
process.join()
self.closed = True
def get_images(self) -> Sequence[Optional[np.ndarray]]:
if self.render_mode != "rgb_array":
warnings.warn(
f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images."
)
return [None for _ in self.remotes]
for pipe in self.remotes:
# gather render return from subprocesses
pipe.send(("render", None))
outputs = [pipe.recv() for pipe in self.remotes]
return outputs
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
"""Return attribute from vectorized environment (see base class)."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(("get_attr", attr_name))
return [remote.recv() for remote in target_remotes]
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
"""Set attribute inside vectorized environments (see base class)."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(("set_attr", (attr_name, value)))
for remote in target_remotes:
remote.recv()
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
"""Call instance methods of vectorized environments."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(("env_method", (method_name, method_args, method_kwargs)))
return [remote.recv() for remote in target_remotes]
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
"""Check if worker environments are wrapped with a given wrapper"""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
remote.send(("is_wrapped", wrapper_class))
return [remote.recv() for remote in target_remotes]
def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]:
"""
Get the connection object needed to communicate with the wanted
envs that are in subprocesses.
:param indices: refers to indices of envs.
:return: Connection object to communicate between processes.
"""
indices = self._get_indices(indices)
return [self.remotes[i] for i in indices]
def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
"""
Flatten observations, depending on the observation space.
:param obs: observations.
A list or tuple of observations, one per environment.
Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays.
:return: flattened observations.
A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays.
Each NumPy array has the environment index as its first axis.
"""
assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment"
assert len(obs) > 0, "need observations from at least one environment"
if isinstance(space, spaces.Dict):
assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces"
assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space"
return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])
elif isinstance(space, spaces.Tuple):
assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
obs_len = len(space.spaces)
return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) # type: ignore[index]
else:
return np.stack(obs) # type: ignore[arg-type]

View File

@ -0,0 +1,77 @@
"""
Helpers for dealing with vectorized environments.
"""
from collections import OrderedDict
from typing import Any, Dict, List, Tuple
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.preprocessing import check_for_nested_spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
"""
Deep-copy a dict of numpy arrays.
:param obs: a dict of numpy arrays.
:return: a dict of copied numpy arrays.
"""
assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'"
return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])
def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
"""
Convert an internal representation raw_obs into the appropriate type
specified by space.
:param obs_space: an observation space.
:param obs_dict: a dict of numpy arrays.
:return: returns an observation of the same type as space.
If space is Dict, function is identity; if space is Tuple, converts dict to Tuple;
otherwise, space is unstructured and returns the value raw_obs[None].
"""
if isinstance(obs_space, spaces.Dict):
return obs_dict
elif isinstance(obs_space, spaces.Tuple):
assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space"
return tuple(obs_dict[i] for i in range(len(obs_space.spaces)))
else:
assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space"
return obs_dict[None]
def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]:
"""
Get dict-structured information about a gym.Space.
Dict spaces are represented directly by their dict of subspaces.
Tuple spaces are converted into a dict with keys indexing into the tuple.
Unstructured spaces are represented by {None: obs_space}.
:param obs_space: an observation space
:return: A tuple (keys, shapes, dtypes):
keys: a list of dict keys.
shapes: a dict mapping keys to shapes.
dtypes: a dict mapping keys to dtypes.
"""
check_for_nested_spaces(obs_space)
if isinstance(obs_space, spaces.Dict):
assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
subspaces = obs_space.spaces
elif isinstance(obs_space, spaces.Tuple):
subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment]
else:
assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
subspaces = {None: obs_space} # type: ignore[assignment]
keys = []
shapes = {}
dtypes = {}
for key, box in subspaces.items():
keys.append(key)
shapes[key] = box.shape
dtypes[key] = box.dtype
return keys, shapes, dtypes

View File

@ -0,0 +1,108 @@
import warnings
from typing import List, Tuple
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
class VecCheckNan(VecEnvWrapper):
"""
NaN and inf checking wrapper for vectorized environment, will raise a warning by default,
allowing you to know from what the NaN of inf originated from.
:param venv: the vectorized environment to wrap
:param raise_exception: Whether to raise a ValueError, instead of a UserWarning
:param warn_once: Whether to only warn once.
:param check_inf: Whether to check for +inf or -inf as well
"""
def __init__(self, venv: VecEnv, raise_exception: bool = False, warn_once: bool = True, check_inf: bool = True) -> None:
super().__init__(venv)
self.raise_exception = raise_exception
self.warn_once = warn_once
self.check_inf = check_inf
self._user_warned = False
self._actions: np.ndarray
self._observations: VecEnvObs
if isinstance(venv.action_space, spaces.Dict):
raise NotImplementedError("VecCheckNan doesn't support dict action spaces")
def step_async(self, actions: np.ndarray) -> None:
self._check_val(event="step_async", actions=actions)
self._actions = actions
self.venv.step_async(actions)
def step_wait(self) -> VecEnvStepReturn:
observations, rewards, dones, infos = self.venv.step_wait()
self._check_val(event="step_wait", observations=observations, rewards=rewards, dones=dones)
self._observations = observations
return observations, rewards, dones, infos
def reset(self) -> VecEnvObs:
observations = self.venv.reset()
self._check_val(event="reset", observations=observations)
self._observations = observations
return observations
def check_array_value(self, name: str, value: np.ndarray) -> List[Tuple[str, str]]:
"""
Check for inf and NaN for a single numpy array.
:param name: Name of the value being check
:param value: Value (numpy array) to check
:return: A list of issues found.
"""
found = []
has_nan = np.any(np.isnan(value))
has_inf = self.check_inf and np.any(np.isinf(value))
if has_inf:
found.append((name, "inf"))
if has_nan:
found.append((name, "nan"))
return found
def _check_val(self, event: str, **kwargs) -> None:
# if warn and warn once and have warned once: then stop checking
if not self.raise_exception and self.warn_once and self._user_warned:
return
found = []
for name, value in kwargs.items():
if isinstance(value, (np.ndarray, list)):
found += self.check_array_value(name, np.asarray(value))
elif isinstance(value, dict):
for inner_name, inner_val in value.items():
found += self.check_array_value(f"{name}.{inner_name}", inner_val)
elif isinstance(value, tuple):
for idx, inner_val in enumerate(value):
found += self.check_array_value(f"{name}.{idx}", inner_val)
else:
raise TypeError(f"Unsupported observation type {type(value)}.")
if found:
self._user_warned = True
msg = ""
for i, (name, type_val) in enumerate(found):
msg += f"found {type_val} in {name}"
if i != len(found) - 1:
msg += ", "
msg += ".\r\nOriginated from the "
if event == "reset":
msg += "environment observation (at reset)"
elif event == "step_wait":
msg += f"environment, Last given value was: \r\n\taction={self._actions}"
elif event == "step_async":
msg += f"RL model, Last given value was: \r\n\tobservations={self._observations}"
else:
raise ValueError("Internal error.")
if self.raise_exception:
raise ValueError(msg)
else:
warnings.warn(msg, UserWarning)

View File

@ -0,0 +1,33 @@
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
class VecExtractDictObs(VecEnvWrapper):
"""
A vectorized wrapper for extracting dictionary observations.
:param venv: The vectorized environment
:param key: The key of the dictionary observation
"""
def __init__(self, venv: VecEnv, key: str):
self.key = key
assert isinstance(
venv.observation_space, spaces.Dict
), f"VecExtractDictObs can only be used with Dict obs space, not {venv.observation_space}"
super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key])
def reset(self) -> np.ndarray:
obs = self.venv.reset()
assert isinstance(obs, dict)
return obs[self.key]
def step_wait(self) -> VecEnvStepReturn:
obs, reward, done, infos = self.venv.step_wait()
assert isinstance(obs, dict)
for info in infos:
if "terminal_observation" in info:
info["terminal_observation"] = info["terminal_observation"][self.key]
return obs[self.key], reward, done, infos

View File

@ -0,0 +1,48 @@
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
class VecFrameStack(VecEnvWrapper):
"""
Frame stacking wrapper for vectorized environment. Designed for image observations.
:param venv: Vectorized environment to wrap
:param n_stack: Number of frames to stack
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces
"""
def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Mapping[str, str]]] = None) -> None:
assert isinstance(
venv.observation_space, (spaces.Box, spaces.Dict)
), "VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces"
self.stacked_obs = StackedObservations(venv.num_envs, n_stack, venv.observation_space, channels_order)
observation_space = self.stacked_obs.stacked_observation_space
super().__init__(venv, observation_space=observation_space)
def step_wait(
self,
) -> Tuple[
Union[np.ndarray, Dict[str, np.ndarray]],
np.ndarray,
np.ndarray,
List[Dict[str, Any]],
]:
observations, rewards, dones, infos = self.venv.step_wait()
observations, infos = self.stacked_obs.update(observations, dones, infos) # type: ignore[arg-type]
return observations, rewards, dones, infos
def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""
Reset all environments
"""
observation = self.venv.reset()
observation = self.stacked_obs.reset(observation) # type: ignore[arg-type]
return observation

View File

@ -0,0 +1,100 @@
import time
import warnings
from typing import Optional, Tuple
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
class VecMonitor(VecEnvWrapper):
"""
A vectorized monitor wrapper for *vectorized* Gym environments,
it is used to record the episode reward, length, time and other data.
Some environments like `openai/procgen <https://github.com/openai/procgen>`_
or `gym3 <https://github.com/openai/gym3>`_ directly initialize the
vectorized environments, without giving us a chance to use the ``Monitor``
wrapper. So this class simply does the job of the ``Monitor`` wrapper on
a vectorized level.
:param venv: The vectorized environment
:param filename: the location to save a log file, can be None for no log
:param info_keywords: extra information to log, from the information return of env.step()
"""
def __init__(
self,
venv: VecEnv,
filename: Optional[str] = None,
info_keywords: Tuple[str, ...] = (),
):
# Avoid circular import
from stable_baselines3.common.monitor import Monitor, ResultsWriter
# This check is not valid for special `VecEnv`
# like the ones created by Procgen, that does follow completely
# the `VecEnv` interface
try:
is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0]
except AttributeError:
is_wrapped_with_monitor = False
if is_wrapped_with_monitor:
warnings.warn(
"The environment is already wrapped with a `Monitor` wrapper"
"but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be"
"overwritten by the `VecMonitor` ones.",
UserWarning,
)
VecEnvWrapper.__init__(self, venv)
self.episode_count = 0
self.t_start = time.time()
env_id = None
if hasattr(venv, "spec") and venv.spec is not None:
env_id = venv.spec.id
self.results_writer: Optional[ResultsWriter] = None
if filename:
self.results_writer = ResultsWriter(
filename, header={"t_start": self.t_start, "env_id": str(env_id)}, extra_keys=info_keywords
)
self.info_keywords = info_keywords
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
def reset(self) -> VecEnvObs:
obs = self.venv.reset()
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return obs
def step_wait(self) -> VecEnvStepReturn:
obs, rewards, dones, infos = self.venv.step_wait()
self.episode_returns += rewards
self.episode_lengths += 1
new_infos = list(infos[:])
for i in range(len(dones)):
if dones[i]:
info = infos[i].copy()
episode_return = self.episode_returns[i]
episode_length = self.episode_lengths[i]
episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)}
for key in self.info_keywords:
episode_info[key] = info[key]
info["episode"] = episode_info
self.episode_count += 1
self.episode_returns[i] = 0
self.episode_lengths[i] = 0
if self.results_writer:
self.results_writer.write_row(episode_info)
new_infos[i] = info
return obs, rewards, dones, new_infos
def close(self) -> None:
if self.results_writer:
self.results_writer.close()
return self.venv.close()

View File

@ -0,0 +1,330 @@
import inspect
import pickle
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
import numpy as np
from gymnasium import spaces
from stable_baselines3.common import utils
from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.common.running_mean_std import RunningMeanStd
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
class VecNormalize(VecEnvWrapper):
"""
A moving average, normalizing wrapper for vectorized environment.
has support for saving/loading moving average,
:param venv: the vectorized environment to wrap
:param training: Whether to update or not the moving average
:param norm_obs: Whether to normalize observation or not (default: True)
:param norm_reward: Whether to normalize rewards or not (default: True)
:param clip_obs: Max absolute value for observation
:param clip_reward: Max value absolute for discounted reward
:param gamma: discount factor
:param epsilon: To avoid division by zero
:param norm_obs_keys: Which keys from observation dict to normalize.
If not specified, all keys will be normalized.
"""
obs_spaces: Dict[str, spaces.Space]
old_obs: Union[np.ndarray, Dict[str, np.ndarray]]
def __init__(
self,
venv: VecEnv,
training: bool = True,
norm_obs: bool = True,
norm_reward: bool = True,
clip_obs: float = 10.0,
clip_reward: float = 10.0,
gamma: float = 0.99,
epsilon: float = 1e-8,
norm_obs_keys: Optional[List[str]] = None,
):
VecEnvWrapper.__init__(self, venv)
self.norm_obs = norm_obs
self.norm_obs_keys = norm_obs_keys
# Check observation spaces
if self.norm_obs:
# Note: mypy doesn't take into account the sanity checks, which lead to several type: ignore...
self._sanity_checks()
if isinstance(self.observation_space, spaces.Dict):
self.obs_spaces = self.observation_space.spaces
self.obs_rms = {key: RunningMeanStd(shape=self.obs_spaces[key].shape) for key in self.norm_obs_keys} # type: ignore[arg-type, union-attr]
# Update observation space when using image
# See explanation below and GH #1214
for key in self.obs_rms.keys():
if is_image_space(self.obs_spaces[key]):
self.observation_space.spaces[key] = spaces.Box(
low=-clip_obs,
high=clip_obs,
shape=self.obs_spaces[key].shape,
dtype=np.float32,
)
else:
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) # type: ignore[assignment, arg-type]
# Update observation space when using image
# See GH #1214
# This is to raise proper error when
# VecNormalize is used with an image-like input and
# normalize_images=True.
# For correctness, we should also update the bounds
# in other cases but this will cause backward-incompatible change
# and break already saved policies.
if is_image_space(self.observation_space):
self.observation_space = spaces.Box(
low=-clip_obs,
high=clip_obs,
shape=self.observation_space.shape,
dtype=np.float32,
)
self.ret_rms = RunningMeanStd(shape=())
self.clip_obs = clip_obs
self.clip_reward = clip_reward
# Returns: discounted rewards
self.returns = np.zeros(self.num_envs)
self.gamma = gamma
self.epsilon = epsilon
self.training = training
self.norm_obs = norm_obs
self.norm_reward = norm_reward
self.old_reward = np.array([])
def _sanity_checks(self) -> None:
"""
Check the observations that are going to be normalized are of the correct type (spaces.Box).
"""
if isinstance(self.observation_space, spaces.Dict):
# By default, we normalize all keys
if self.norm_obs_keys is None:
self.norm_obs_keys = list(self.observation_space.spaces.keys())
# Check that all keys are of type Box
for obs_key in self.norm_obs_keys:
if not isinstance(self.observation_space.spaces[obs_key], spaces.Box):
raise ValueError(
f"VecNormalize only supports `gym.spaces.Box` observation spaces but {obs_key} "
f"is of type {self.observation_space.spaces[obs_key]}. "
"You should probably explicitely pass the observation keys "
" that should be normalized via the `norm_obs_keys` parameter."
)
elif isinstance(self.observation_space, spaces.Box):
if self.norm_obs_keys is not None:
raise ValueError("`norm_obs_keys` param is applicable only with `gym.spaces.Dict` observation spaces")
else:
raise ValueError(
"VecNormalize only supports `gym.spaces.Box` and `gym.spaces.Dict` observation spaces, "
f"not {self.observation_space}"
)
def __getstate__(self) -> Dict[str, Any]:
"""
Gets state for pickling.
Excludes self.venv, as in general VecEnv's may not be pickleable."""
state = self.__dict__.copy()
# these attributes are not pickleable
del state["venv"]
del state["class_attributes"]
# these attributes depend on the above and so we would prefer not to pickle
del state["returns"]
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
"""
Restores pickled state.
User must call set_venv() after unpickling before using.
:param state:"""
# Backward compatibility
if "norm_obs_keys" not in state and isinstance(state["observation_space"], spaces.Dict):
state["norm_obs_keys"] = list(state["observation_space"].spaces.keys())
self.__dict__.update(state)
assert "venv" not in state
self.venv = None # type: ignore[assignment]
def set_venv(self, venv: VecEnv) -> None:
"""
Sets the vector environment to wrap to venv.
Also sets attributes derived from this such as `num_env`.
:param venv:
"""
if self.venv is not None:
raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.")
self.venv = venv
self.num_envs = venv.num_envs
self.class_attributes = dict(inspect.getmembers(self.__class__))
self.render_mode = venv.render_mode
# Check that the observation_space shape match
utils.check_shape_equal(self.observation_space, venv.observation_space)
self.returns = np.zeros(self.num_envs)
def step_wait(self) -> VecEnvStepReturn:
"""
Apply sequence of actions to sequence of environments
actions -> (observations, rewards, dones)
where ``dones`` is a boolean vector indicating whether each element is new.
"""
obs, rewards, dones, infos = self.venv.step_wait()
assert isinstance(obs, (np.ndarray, dict)) # for mypy
self.old_obs = obs
self.old_reward = rewards
if self.training and self.norm_obs:
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
for key in self.obs_rms.keys():
self.obs_rms[key].update(obs[key])
else:
self.obs_rms.update(obs)
obs = self.normalize_obs(obs)
if self.training:
self._update_reward(rewards)
rewards = self.normalize_reward(rewards)
# Normalize the terminal observations
for idx, done in enumerate(dones):
if not done:
continue
if "terminal_observation" in infos[idx]:
infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"])
self.returns[dones] = 0
return obs, rewards, dones, infos
def _update_reward(self, reward: np.ndarray) -> None:
"""Update reward normalization statistics."""
self.returns = self.returns * self.gamma + reward
self.ret_rms.update(self.returns)
def _normalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray:
"""
Helper to normalize observation.
:param obs:
:param obs_rms: associated statistics
:return: normalized observation
"""
return np.clip((obs - obs_rms.mean) / np.sqrt(obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs)
def _unnormalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarray:
"""
Helper to unnormalize observation.
:param obs:
:param obs_rms: associated statistics
:return: unnormalized observation
"""
return (obs * np.sqrt(obs_rms.var + self.epsilon)) + obs_rms.mean
def normalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""
Normalize observations using this VecNormalize's observations statistics.
Calling this method does not update statistics.
"""
# Avoid modifying by reference the original object
obs_ = deepcopy(obs)
if self.norm_obs:
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
assert self.norm_obs_keys is not None
# Only normalize the specified keys
for key in self.norm_obs_keys:
obs_[key] = self._normalize_obs(obs[key], self.obs_rms[key]).astype(np.float32)
else:
assert isinstance(self.obs_rms, RunningMeanStd)
obs_ = self._normalize_obs(obs, self.obs_rms).astype(np.float32)
return obs_
def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
"""
Normalize rewards using this VecNormalize's rewards statistics.
Calling this method does not update statistics.
"""
if self.norm_reward:
reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
return reward
def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
# Avoid modifying by reference the original object
obs_ = deepcopy(obs)
if self.norm_obs:
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
assert self.norm_obs_keys is not None
for key in self.norm_obs_keys:
obs_[key] = self._unnormalize_obs(obs[key], self.obs_rms[key])
else:
assert isinstance(self.obs_rms, RunningMeanStd)
obs_ = self._unnormalize_obs(obs, self.obs_rms)
return obs_
def unnormalize_reward(self, reward: np.ndarray) -> np.ndarray:
if self.norm_reward:
return reward * np.sqrt(self.ret_rms.var + self.epsilon)
return reward
def get_original_obs(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""
Returns an unnormalized version of the observations from the most recent
step or reset.
"""
return deepcopy(self.old_obs)
def get_original_reward(self) -> np.ndarray:
"""
Returns an unnormalized version of the rewards from the most recent step.
"""
return self.old_reward.copy()
def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
"""
Reset all environments
:return: first observation of the episode
"""
obs = self.venv.reset()
assert isinstance(obs, (np.ndarray, dict))
self.old_obs = obs
self.returns = np.zeros(self.num_envs)
if self.training and self.norm_obs:
if isinstance(obs, dict) and isinstance(self.obs_rms, dict):
for key in self.obs_rms.keys():
self.obs_rms[key].update(obs[key])
else:
assert isinstance(self.obs_rms, RunningMeanStd)
self.obs_rms.update(obs)
return self.normalize_obs(obs)
@staticmethod
def load(load_path: str, venv: VecEnv) -> "VecNormalize":
"""
Loads a saved VecNormalize object.
:param load_path: the path to load from.
:param venv: the VecEnv to wrap.
:return:
"""
with open(load_path, "rb") as file_handler:
vec_normalize = pickle.load(file_handler)
vec_normalize.set_venv(venv)
return vec_normalize
def save(self, save_path: str) -> None:
"""
Save current VecNormalize object with
all running statistics and settings (e.g. clip_obs)
:param save_path: The path to save to
"""
with open(save_path, "wb") as file_handler:
pickle.dump(self, file_handler)

View File

@ -0,0 +1,118 @@
from copy import deepcopy
from typing import Dict, Union
import numpy as np
from gymnasium import spaces
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper
class VecTransposeImage(VecEnvWrapper):
"""
Re-order channels, from HxWxC to CxHxW.
It is required for PyTorch convolution layers.
:param venv:
:param skip: Skip this wrapper if needed as we rely on heuristic to apply it or not,
which may result in unwanted behavior, see GH issue #671.
"""
def __init__(self, venv: VecEnv, skip: bool = False):
assert is_image_space(venv.observation_space) or isinstance(
venv.observation_space, spaces.Dict
), "The observation space must be an image or dictionary observation space"
self.skip = skip
# Do nothing
if skip:
super().__init__(venv)
return
if isinstance(venv.observation_space, spaces.Dict):
self.image_space_keys = []
observation_space = deepcopy(venv.observation_space)
for key, space in observation_space.spaces.items():
if is_image_space(space):
# Keep track of which keys should be transposed later
self.image_space_keys.append(key)
assert isinstance(space, spaces.Box)
observation_space.spaces[key] = self.transpose_space(space, key)
else:
assert isinstance(venv.observation_space, spaces.Box)
observation_space = self.transpose_space(venv.observation_space) # type: ignore[assignment]
super().__init__(venv, observation_space=observation_space)
@staticmethod
def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box:
"""
Transpose an observation space (re-order channels).
:param observation_space:
:param key: In case of dictionary space, the key of the observation space.
:return:
"""
# Sanity checks
assert is_image_space(observation_space), "The observation space must be an image"
assert not is_image_space_channels_first(
observation_space
), f"The observation space {key} must follow the channel last convention"
height, width, channels = observation_space.shape
new_shape = (channels, height, width)
return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype) # type: ignore[arg-type]
@staticmethod
def transpose_image(image: np.ndarray) -> np.ndarray:
"""
Transpose an image or batch of images (re-order channels).
:param image:
:return:
"""
if len(image.shape) == 3:
return np.transpose(image, (2, 0, 1))
return np.transpose(image, (0, 3, 1, 2))
def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]:
"""
Transpose (if needed) and return new observations.
:param observations:
:return: Transposed observations
"""
# Do nothing
if self.skip:
return observations
if isinstance(observations, dict):
# Avoid modifying the original object in place
observations = deepcopy(observations)
for k in self.image_space_keys:
observations[k] = self.transpose_image(observations[k])
else:
observations = self.transpose_image(observations)
return observations
def step_wait(self) -> VecEnvStepReturn:
observations, rewards, dones, infos = self.venv.step_wait()
# Transpose the terminal observations
for idx, done in enumerate(dones):
if not done:
continue
if "terminal_observation" in infos[idx]:
infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"])
assert isinstance(observations, (np.ndarray, dict))
return self.transpose_observations(observations), rewards, dones, infos
def reset(self) -> Union[np.ndarray, Dict]:
"""
Reset all environments
"""
observations = self.venv.reset()
assert isinstance(observations, (np.ndarray, dict))
return self.transpose_observations(observations)
def close(self) -> None:
self.venv.close()

View File

@ -0,0 +1,113 @@
import os
from typing import Callable
from gymnasium.wrappers.monitoring import video_recorder
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
class VecVideoRecorder(VecEnvWrapper):
"""
Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video.
It requires ffmpeg or avconv to be installed on the machine.
:param venv:
:param video_folder: Where to save videos
:param record_video_trigger: Function that defines when to start recording.
The function takes the current number of step,
and returns whether we should start recording or not.
:param video_length: Length of recorded videos
:param name_prefix: Prefix to the video name
"""
video_recorder: video_recorder.VideoRecorder
def __init__(
self,
venv: VecEnv,
video_folder: str,
record_video_trigger: Callable[[int], bool],
video_length: int = 200,
name_prefix: str = "rl-video",
):
VecEnvWrapper.__init__(self, venv)
self.env = venv
# Temp variable to retrieve metadata
temp_env = venv
# Unwrap to retrieve metadata dict
# that will be used by gym recorder
while isinstance(temp_env, VecEnvWrapper):
temp_env = temp_env.venv
if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv):
metadata = temp_env.get_attr("metadata")[0]
else:
metadata = temp_env.metadata
self.env.metadata = metadata
assert self.env.render_mode == "rgb_array", f"The render_mode must be 'rgb_array', not {self.env.render_mode}"
self.record_video_trigger = record_video_trigger
self.video_folder = os.path.abspath(video_folder)
# Create output folder if needed
os.makedirs(self.video_folder, exist_ok=True)
self.name_prefix = name_prefix
self.step_id = 0
self.video_length = video_length
self.recording = False
self.recorded_frames = 0
def reset(self) -> VecEnvObs:
obs = self.venv.reset()
self.start_video_recorder()
return obs
def start_video_recorder(self) -> None:
self.close_video_recorder()
video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}"
base_path = os.path.join(self.video_folder, video_name)
self.video_recorder = video_recorder.VideoRecorder(
env=self.env, base_path=base_path, metadata={"step_id": self.step_id}
)
self.video_recorder.capture_frame()
self.recorded_frames = 1
self.recording = True
def _video_enabled(self) -> bool:
return self.record_video_trigger(self.step_id)
def step_wait(self) -> VecEnvStepReturn:
obs, rews, dones, infos = self.venv.step_wait()
self.step_id += 1
if self.recording:
self.video_recorder.capture_frame()
self.recorded_frames += 1
if self.recorded_frames > self.video_length:
print(f"Saving video to {self.video_recorder.path}")
self.close_video_recorder()
elif self._video_enabled():
self.start_video_recorder()
return obs, rews, dones, infos
def close_video_recorder(self) -> None:
if self.recording:
self.video_recorder.close()
self.recording = False
self.recorded_frames = 1
def close(self) -> None:
VecEnvWrapper.close(self)
self.close_video_recorder()
def __del__(self):
self.close_video_recorder()