I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@ -0,0 +1,72 @@
"""Module of wrapper classes.
Wrappers are a convenient way to modify an existing environment without having to alter the underlying code directly.
Using wrappers will allow you to avoid a lot of boilerplate code and make your environment more modular. Wrappers can
also be chained to combine their effects.
Most environments that are generated via :meth:`gymnasium.make` will already be wrapped by default.
In order to wrap an environment, you must first initialize a base environment. Then you can pass this environment along
with (possibly optional) parameters to the wrapper's constructor.
>>> import gymnasium as gym
>>> from gymnasium.wrappers import RescaleAction
>>> base_env = gym.make("Hopper-v4")
>>> base_env.action_space
Box(-1.0, 1.0, (3,), float32)
>>> wrapped_env = RescaleAction(base_env, min_action=0, max_action=1)
>>> wrapped_env.action_space
Box(0.0, 1.0, (3,), float32)
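Wrappers can also be chained; for example (a minimal sketch combining the two wrappers above):
>>> from gymnasium.wrappers import ClipAction
>>> chained_env = ClipAction(RescaleAction(base_env, min_action=0, max_action=1))
>>> chained_env.action_space
Box(0.0, 1.0, (3,), float32)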
You can access the environment underneath the **first** wrapper by using the :attr:`gymnasium.Wrapper.env` attribute.
As the :class:`gymnasium.Wrapper` class inherits from :class:`gymnasium.Env` then :attr:`gymnasium.Wrapper.env` can be another wrapper.
>>> wrapped_env
<RescaleAction<TimeLimit<OrderEnforcing<PassiveEnvChecker<HopperEnv<Hopper-v4>>>>>>
>>> wrapped_env.env
<TimeLimit<OrderEnforcing<PassiveEnvChecker<HopperEnv<Hopper-v4>>>>>
If you want to get to the environment underneath **all** of the layers of wrappers, you can use the
:attr:`gymnasium.Wrapper.unwrapped` attribute.
If the environment is already a bare environment, the :attr:`gymnasium.Wrapper.unwrapped` attribute will just return itself.
>>> wrapped_env
<RescaleAction<TimeLimit<OrderEnforcing<PassiveEnvChecker<HopperEnv<Hopper-v4>>>>>>
>>> wrapped_env.unwrapped # doctest: +SKIP
<gymnasium.envs.mujoco.hopper_v4.HopperEnv object at 0x7fbb5efd0490>
There are three common things you might want a wrapper to do:
- Transform actions before applying them to the base environment
- Transform observations that are returned by the base environment
- Transform rewards that are returned by the base environment
Such wrappers can be easily implemented by inheriting from :class:`gymnasium.ActionWrapper`,
:class:`gymnasium.ObservationWrapper`, or :class:`gymnasium.RewardWrapper` and implementing the respective transformation.
If you need a wrapper to do more complicated tasks, you can inherit from the :class:`gymnasium.Wrapper` class directly.
If you'd like to implement your own custom wrapper, check out `the corresponding tutorial <../../tutorials/implementing_custom_wrappers>`_.
"""
from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing
from gymnasium.wrappers.autoreset import AutoResetWrapper
from gymnasium.wrappers.clip_action import ClipAction
from gymnasium.wrappers.compatibility import EnvCompatibility
from gymnasium.wrappers.env_checker import PassiveEnvChecker
from gymnasium.wrappers.filter_observation import FilterObservation
from gymnasium.wrappers.flatten_observation import FlattenObservation
from gymnasium.wrappers.frame_stack import FrameStack, LazyFrames
from gymnasium.wrappers.gray_scale_observation import GrayScaleObservation
from gymnasium.wrappers.human_rendering import HumanRendering
from gymnasium.wrappers.normalize import NormalizeObservation, NormalizeReward
from gymnasium.wrappers.order_enforcing import OrderEnforcing
from gymnasium.wrappers.pixel_observation import PixelObservationWrapper
from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics
from gymnasium.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
from gymnasium.wrappers.render_collection import RenderCollection
from gymnasium.wrappers.rescale_action import RescaleAction
from gymnasium.wrappers.resize_observation import ResizeObservation
from gymnasium.wrappers.step_api_compatibility import StepAPICompatibility
from gymnasium.wrappers.time_aware_observation import TimeAwareObservation
from gymnasium.wrappers.time_limit import TimeLimit
from gymnasium.wrappers.transform_observation import TransformObservation
from gymnasium.wrappers.transform_reward import TransformReward
from gymnasium.wrappers.vector_list_info import VectorListInfo


@ -0,0 +1,204 @@
"""Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018."""
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box
try:
import cv2
except ImportError:
cv2 = None
class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Atari 2600 preprocessing wrapper.
This class follows the guidelines in Machado et al. (2018),
"Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents".
Specifically, the following preprocessing stages apply to the Atari environment:
- Noop Reset: Obtains the initial state by taking a random number of no-ops on reset, at most 30 no-ops by default.
- Frame skipping: The number of frames skipped between steps, 4 by default.
- Max-pooling: Pools over the most recent two observations from the skipped frames.
- Termination signal when a life is lost: When the agent loses a life, the episode terminates.
Turned off by default. Not recommended by Machado et al. (2018).
- Resize to a square image: Resizes the Atari environment's original 210x160 observation to 84x84 by default.
- Grayscale observation: Whether the observation is returned in grayscale or RGB; grayscale by default.
- Scale observation: Whether to scale observations to the range [0, 1); unscaled ([0, 255]) by default.
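Example:
A minimal usage sketch (assumes ``ale-py`` is installed and the ``ALE/Pong-v5`` environment is
registered; frame-skipping must be disabled in the base environment):
>>> import gymnasium as gym
>>> env = gym.make("ALE/Pong-v5", frameskip=1)   # doctest: +SKIP
>>> env = AtariPreprocessing(env, frame_skip=4)  # doctest: +SKIP
>>> env.observation_space                        # doctest: +SKIP
Box(0, 255, (84, 84), uint8)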
"""
def __init__(
self,
env: gym.Env,
noop_max: int = 30,
frame_skip: int = 4,
screen_size: int = 84,
terminal_on_life_loss: bool = False,
grayscale_obs: bool = True,
grayscale_newaxis: bool = False,
scale_obs: bool = False,
):
"""Wrapper for Atari 2600 preprocessing.
Args:
env (Env): The environment to apply the preprocessing
noop_max (int): For no-op reset, the maximum number of no-op actions taken on reset; to disable, set to 0.
frame_skip (int): The number of frames between new observations, which affects the frequency at which the agent experiences the game.
screen_size (int): The size to which Atari frames are resized (screen_size x screen_size).
terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
life is lost.
grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
is returned.
grayscale_newaxis (bool): `if True and grayscale_obs=True`, then a channel axis is added to
grayscale observations to make them 3-dimensional.
scale_obs (bool): if True, then observations are normalized to the range [0, 1). This also limits the memory
optimization benefits of the FrameStack wrapper.
Raises:
DependencyNotInstalled: opencv-python package not installed
ValueError: Disable frame-skipping in the original env
"""
gym.utils.RecordConstructorArgs.__init__(
self,
noop_max=noop_max,
frame_skip=frame_skip,
screen_size=screen_size,
terminal_on_life_loss=terminal_on_life_loss,
grayscale_obs=grayscale_obs,
grayscale_newaxis=grayscale_newaxis,
scale_obs=scale_obs,
)
gym.Wrapper.__init__(self, env)
if cv2 is None:
raise gym.error.DependencyNotInstalled(
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
)
assert frame_skip > 0
assert screen_size > 0
assert noop_max >= 0
if frame_skip > 1:
if (
env.spec is not None
and "NoFrameskip" not in env.spec.id
and getattr(env.unwrapped, "_frameskip", None) != 1
):
raise ValueError(
"Disable frame-skipping in the original env. Otherwise, more than one "
"frame-skip will happen as through this wrapper"
)
self.noop_max = noop_max
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
self.frame_skip = frame_skip
self.screen_size = screen_size
self.terminal_on_life_loss = terminal_on_life_loss
self.grayscale_obs = grayscale_obs
self.grayscale_newaxis = grayscale_newaxis
self.scale_obs = scale_obs
# buffer of most recent two observations for max pooling
assert isinstance(env.observation_space, Box)
if grayscale_obs:
self.obs_buffer = [
np.empty(env.observation_space.shape[:2], dtype=np.uint8),
np.empty(env.observation_space.shape[:2], dtype=np.uint8),
]
else:
self.obs_buffer = [
np.empty(env.observation_space.shape, dtype=np.uint8),
np.empty(env.observation_space.shape, dtype=np.uint8),
]
self.lives = 0
self.game_over = False
_low, _high, _obs_dtype = (
(0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
)
_shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
if grayscale_obs and not grayscale_newaxis:
_shape = _shape[:-1] # Remove channel axis
self.observation_space = Box(
low=_low, high=_high, shape=_shape, dtype=_obs_dtype
)
@property
def ale(self):
"""Make ale as a class property to avoid serialization error."""
return self.env.unwrapped.ale
def step(self, action):
"""Applies the preprocessing for an :meth:`env.step`."""
total_reward, terminated, truncated, info = 0.0, False, False, {}
for t in range(self.frame_skip):
_, reward, terminated, truncated, info = self.env.step(action)
total_reward += reward
self.game_over = terminated
if self.terminal_on_life_loss:
new_lives = self.ale.lives()
terminated = terminated or new_lives < self.lives
self.game_over = terminated
self.lives = new_lives
if terminated or truncated:
break
if t == self.frame_skip - 2:
if self.grayscale_obs:
self.ale.getScreenGrayscale(self.obs_buffer[1])
else:
self.ale.getScreenRGB(self.obs_buffer[1])
elif t == self.frame_skip - 1:
if self.grayscale_obs:
self.ale.getScreenGrayscale(self.obs_buffer[0])
else:
self.ale.getScreenRGB(self.obs_buffer[0])
return self._get_obs(), total_reward, terminated, truncated, info
def reset(self, **kwargs):
"""Resets the environment using preprocessing."""
# NoopReset
_, reset_info = self.env.reset(**kwargs)
noops = (
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
if self.noop_max > 0
else 0
)
for _ in range(noops):
_, _, terminated, truncated, step_info = self.env.step(0)
reset_info.update(step_info)
if terminated or truncated:
_, reset_info = self.env.reset(**kwargs)
self.lives = self.ale.lives()
if self.grayscale_obs:
self.ale.getScreenGrayscale(self.obs_buffer[0])
else:
self.ale.getScreenRGB(self.obs_buffer[0])
self.obs_buffer[1].fill(0)
return self._get_obs(), reset_info
def _get_obs(self):
if self.frame_skip > 1: # more efficient in-place pooling
np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])
assert cv2 is not None
obs = cv2.resize(
self.obs_buffer[0],
(self.screen_size, self.screen_size),
interpolation=cv2.INTER_AREA,
)
if self.scale_obs:
obs = np.asarray(obs, dtype=np.float32) / 255.0
else:
obs = np.asarray(obs, dtype=np.uint8)
if self.grayscale_obs and self.grayscale_newaxis:
obs = np.expand_dims(obs, axis=-1) # Add a channel axis
return obs


@ -0,0 +1,86 @@
"""Wrapper that autoreset environments when `terminated=True` or `truncated=True`."""
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING
import gymnasium as gym
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called,
and the return format of :meth:`self.step` is as follows: ``(new_obs, final_reward, final_terminated, final_truncated, info)``
with new step API and ``(new_obs, final_reward, final_done, info)`` with the old step API.
- ``new_obs`` is the first observation after calling :meth:`self.env.reset`
- ``final_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`.
- ``final_terminated`` is the terminated value before calling :meth:`self.env.reset`.
- ``final_truncated`` is the truncated value before calling :meth:`self.env.reset`. ``final_terminated`` and ``final_truncated`` cannot both be False.
- ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`,
with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step`
and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`.
Warning:
When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns `terminated` or `truncated`, a
new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the
final reward, terminated and truncated state from the previous episode.
If you need the final state from the previous episode, you need to retrieve it via the
"final_observation" key in the info dict.
Make sure you know what you're doing if you use this wrapper!
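Example:
A minimal sketch of the autoreset flow (``CartPole-v1`` is used purely for illustration):
>>> import gymnasium as gym
>>> from gymnasium.wrappers import AutoResetWrapper
>>> env = AutoResetWrapper(gym.make("CartPole-v1"))
>>> obs, info = env.reset(seed=42)
>>> terminated = truncated = False
>>> while not (terminated or truncated):
...     obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
>>> sorted(info)  # the pre-reset observation and info are stored under these keys
['final_info', 'final_observation']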
"""
def __init__(self, env: gym.Env):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
Args:
env (gym.Env): The environment to apply the wrapper
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
def step(self, action):
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered.
Args:
action: The action to take
Returns:
The result of :meth:`step` on the autoreset environment
"""
obs, reward, terminated, truncated, info = self.env.step(action)
if terminated or truncated:
new_obs, new_info = self.env.reset()
assert (
"final_observation" not in new_info
), 'info dict cannot contain key "final_observation" '
assert (
"final_info" not in new_info
), 'info dict cannot contain key "final_info" '
new_info["final_observation"] = obs
new_info["final_info"] = info
obs = new_obs
info = new_info
return obs, reward, terminated, truncated, info
@property
def spec(self) -> EnvSpec | None:
"""Modifies the environment spec to specify the `autoreset=True`."""
if self._cached_spec is not None:
return self._cached_spec
env_spec = self.env.spec
if env_spec is not None:
env_spec = deepcopy(env_spec)
env_spec.autoreset = True
self._cached_spec = env_spec
return env_spec


@ -0,0 +1,43 @@
"""Wrapper for clipping actions within a valid bound."""
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box
class ClipAction(gym.ActionWrapper, gym.utils.RecordConstructorArgs):
"""Clip the continuous action within the valid :class:`Box` observation space bound.
Example:
>>> import gymnasium as gym
>>> import numpy as np
>>> from gymnasium.wrappers import ClipAction
>>> env = gym.make("Hopper-v4")
>>> env = ClipAction(env)
>>> env.action_space
Box(-1.0, 1.0, (3,), float32)
>>> _ = env.reset(seed=42)
>>> _ = env.step(np.array([5.0, -2.0, 0.0]))
... # Executes the action np.array([1.0, -1.0, 0]) in the base environment
"""
def __init__(self, env: gym.Env):
"""A wrapper for clipping continuous actions within the valid bound.
Args:
env: The environment to apply the wrapper
"""
assert isinstance(env.action_space, Box)
gym.utils.RecordConstructorArgs.__init__(self)
gym.ActionWrapper.__init__(self, env)
def action(self, action):
"""Clips the action within the valid bounds.
Args:
action: The action to clip
Returns:
The clipped action
"""
return np.clip(action, self.action_space.low, self.action_space.high)


@ -0,0 +1,129 @@
"""A compatibility wrapper converting an old-style environment into a valid environment."""
from typing import Any, Dict, Optional, Protocol, Tuple, runtime_checkable
import gymnasium as gym
from gymnasium import logger
from gymnasium.core import ObsType
from gymnasium.utils.step_api_compatibility import (
convert_to_terminated_truncated_step_api,
)
@runtime_checkable
class LegacyEnv(Protocol):
"""A protocol for environments using the old step API."""
observation_space: gym.Space
action_space: gym.Space
def reset(self) -> Any:
"""Reset the environment and return the initial observation."""
...
def step(self, action: Any) -> Tuple[Any, float, bool, Dict]:
"""Run one timestep of the environment's dynamics."""
...
def render(self, mode: Optional[str] = "human") -> Any:
"""Render the environment."""
...
def close(self):
"""Close the environment."""
...
def seed(self, seed: Optional[int] = None):
"""Set the seed for this env's random number generator(s)."""
...
class EnvCompatibility(gym.Env):
r"""A wrapper which can transform an environment from the old API to the new API.
Old step API refers to the step() method returning (observation, reward, done, info), and reset() returning only the observation.
New step API refers to the step() method returning (observation, reward, terminated, truncated, info) and reset() returning (observation, info).
(Refer to docs for details on the API change)
Known limitations:
- Environments that use `self.np_random` might not work as expected.
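Example:
A minimal sketch (``MyLegacyEnv`` is a hypothetical class implementing the old API,
i.e. ``reset() -> obs`` and ``step(action) -> (obs, reward, done, info)``):
>>> old_env = MyLegacyEnv()                                   # doctest: +SKIP
>>> env = EnvCompatibility(old_env, render_mode="rgb_array")  # doctest: +SKIP
>>> obs, info = env.reset(seed=0)                             # doctest: +SKIP
>>> obs, reward, terminated, truncated, info = env.step(old_env.action_space.sample())  # doctest: +SKIP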
"""
def __init__(self, old_env: LegacyEnv, render_mode: Optional[str] = None):
"""A wrapper which converts old-style envs to valid modern envs.
Some information may be lost in the conversion, so we recommend updating your environment.
Args:
old_env (LegacyEnv): the env to wrap, implemented with the old API
render_mode (str): the render mode to use when rendering the environment, passed automatically to env.render
"""
logger.deprecation(
"The `gymnasium.make(..., apply_api_compatibility=...)` parameter is deprecated and will be removed in v1.0. "
"Instead use `gymnasium.make('GymV21Environment-v0', env_name=...)` or `from shimmy import GymV21CompatibilityV0`"
)
self.env = old_env
self.metadata = getattr(old_env, "metadata", {"render_modes": []})
self.render_mode = render_mode
self.reward_range = getattr(old_env, "reward_range", None)
self.spec = getattr(old_env, "spec", None)
self.observation_space = old_env.observation_space
self.action_space = old_env.action_space
def reset(
self, seed: Optional[int] = None, options: Optional[dict] = None
) -> Tuple[ObsType, dict]:
"""Resets the environment.
Args:
seed: the seed to reset the environment with
options: the options to reset the environment with
Returns:
(observation, info)
"""
if seed is not None:
self.env.seed(seed)
# Options are ignored
if self.render_mode == "human":
self.render()
return self.env.reset(), {}
def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
"""Steps through the environment.
Args:
action: action to step through the environment with
Returns:
(observation, reward, terminated, truncated, info)
"""
obs, reward, done, info = self.env.step(action)
if self.render_mode == "human":
self.render()
return convert_to_terminated_truncated_step_api((obs, reward, done, info))
def render(self) -> Any:
"""Renders the environment.
Returns:
The rendering of the environment, depending on the render mode
"""
return self.env.render(mode=self.render_mode)
def close(self):
"""Closes the environment."""
self.env.close()
def __str__(self):
"""Returns the wrapper name and the unwrapped environment string."""
return f"<{type(self).__name__}{self.env}>"
def __repr__(self):
"""Returns the string representation of the wrapper."""
return str(self)


@ -0,0 +1,95 @@
"""A passive environment checker wrapper for an environment's observation and action space along with the reset, step and render functions."""
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING
import gymnasium as gym
from gymnasium import logger
from gymnasium.core import ActType
from gymnasium.utils.passive_env_checker import (
check_action_space,
check_observation_space,
env_render_passive_checker,
env_reset_passive_checker,
env_step_passive_checker,
)
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
def __init__(self, env):
"""Initialises the wrapper with the environments, run the observation and action space tests."""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
assert hasattr(
env, "action_space"
), "The environment must specify an action space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
check_action_space(env.action_space)
assert hasattr(
env, "observation_space"
), "The environment must specify an observation space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
check_observation_space(env.observation_space)
self.checked_reset = False
self.checked_step = False
self.checked_render = False
self.close_called = False
def step(self, action: ActType):
"""Steps through the environment that on the first call will run the `passive_env_step_check`."""
if not self.checked_step:
self.checked_step = True
return env_step_passive_checker(self.env, action)
else:
return self.env.step(action)
def reset(self, **kwargs):
"""Resets the environment that on the first call will run the `passive_env_reset_check`."""
if not self.checked_reset:
self.checked_reset = True
return env_reset_passive_checker(self.env, **kwargs)
else:
return self.env.reset(**kwargs)
def render(self, *args, **kwargs):
"""Renders the environment that on the first call will run the `passive_env_render_check`."""
if not self.checked_render:
self.checked_render = True
return env_render_passive_checker(self.env, *args, **kwargs)
else:
return self.env.render(*args, **kwargs)
@property
def spec(self) -> EnvSpec | None:
"""Modifies the environment spec to such that `disable_env_checker=False`."""
if self._cached_spec is not None:
return self._cached_spec
env_spec = self.env.spec
if env_spec is not None:
env_spec = deepcopy(env_spec)
env_spec.disable_env_checker = False
self._cached_spec = env_spec
return env_spec
def close(self):
"""Warns if calling close on a closed environment fails."""
if not self.close_called:
self.close_called = True
return self.env.close()
else:
try:
return self.env.close()
except Exception as e:
logger.warn(
"Calling `env.close()` on the closed environment should be allowed, but it raised the following exception."
)
raise e


@ -0,0 +1,92 @@
"""A wrapper for filtering dictionary observations by their keys."""
import copy
from typing import Optional, Sequence
import gymnasium as gym
from gymnasium import spaces
class FilterObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Filter Dict observation space by the keys.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import FilterObservation, TransformObservation
>>> env = gym.make("CartPole-v1")
>>> env = TransformObservation(env, lambda obs: {'obs': obs, 'time': 0})
>>> env.observation_space = gym.spaces.Dict(obs=env.observation_space, time=gym.spaces.Discrete(1))
>>> env.reset(seed=42)
({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': 0}, {})
>>> env = FilterObservation(env, filter_keys=['obs'])
>>> env.reset(seed=42)
({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32)}, {})
>>> env.step(0)
({'obs': array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476], dtype=float32)}, 1.0, False, False, {})
"""
def __init__(self, env: gym.Env, filter_keys: Optional[Sequence[str]] = None):
"""A wrapper that filters dictionary observations by their keys.
Args:
env: The environment to apply the wrapper
filter_keys: List of keys to be included in the observations. If ``None``, observations will not be filtered and this wrapper has no effect
Raises:
ValueError: If the environment's observation space is not :class:`spaces.Dict`
ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space
"""
gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
gym.ObservationWrapper.__init__(self, env)
wrapped_observation_space = env.observation_space
if not isinstance(wrapped_observation_space, spaces.Dict):
raise ValueError(
f"FilterObservationWrapper is only usable with dict observations, "
f"environment observation space is {type(wrapped_observation_space)}"
)
observation_keys = wrapped_observation_space.spaces.keys()
if filter_keys is None:
filter_keys = tuple(observation_keys)
missing_keys = {key for key in filter_keys if key not in observation_keys}
if missing_keys:
raise ValueError(
"All the filter_keys must be included in the original observation space.\n"
f"Filter keys: {filter_keys}\n"
f"Observation keys: {observation_keys}\n"
f"Missing keys: {missing_keys}"
)
self.observation_space = type(wrapped_observation_space)(
[
(name, copy.deepcopy(space))
for name, space in wrapped_observation_space.spaces.items()
if name in filter_keys
]
)
self._env = env
self._filter_keys = tuple(filter_keys)
def observation(self, observation):
"""Filters the observations.
Args:
observation: The observation to filter
Returns:
The filtered observations
"""
filter_observation = self._filter_observation(observation)
return filter_observation
def _filter_observation(self, observation):
observation = type(observation)(
[
(name, value)
for name, value in observation.items()
if name in self._filter_keys
]
)
return observation


@ -0,0 +1,43 @@
"""Wrapper for flattening observations of an environment."""
import gymnasium as gym
from gymnasium import spaces
class FlattenObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Observation wrapper that flattens the observation.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import FlattenObservation
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
>>> env = FlattenObservation(env)
>>> env.observation_space.shape
(27648,)
>>> obs, _ = env.reset()
>>> obs.shape
(27648,)
"""
def __init__(self, env: gym.Env):
"""Flattens the observations of an environment.
Args:
env: The environment to apply the wrapper
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.ObservationWrapper.__init__(self, env)
self.observation_space = spaces.flatten_space(env.observation_space)
def observation(self, observation):
"""Flattens an observation.
Args:
observation: The observation to flatten
Returns:
The flattened observation
"""
return spaces.flatten(self.env.observation_space, observation)


@ -0,0 +1,196 @@
"""Wrapper that stacks frames."""
from collections import deque
from typing import Union
import numpy as np
import gymnasium as gym
from gymnasium.error import DependencyNotInstalled
from gymnasium.spaces import Box
class LazyFrames:
"""Ensures common frames are only stored once to optimize memory use.
To further reduce memory use, lz4 compression of the observations can optionally be enabled.
Note:
This object should only be converted to numpy array just before forward pass.
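Example:
A small sketch (the frames are arbitrary arrays, purely for illustration):
>>> import numpy as np
>>> frames = [np.zeros((84, 84), dtype=np.uint8) for _ in range(4)]
>>> lazy = LazyFrames(frames)
>>> np.asarray(lazy).shape
(4, 84, 84)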
"""
__slots__ = ("frame_shape", "dtype", "shape", "lz4_compress", "_frames")
def __init__(self, frames: list, lz4_compress: bool = False):
"""Lazyframe for a set of frames and if to apply lz4.
Args:
frames (list): The frames to convert to lazy frames
lz4_compress (bool): Use lz4 to compress the frames internally
Raises:
DependencyNotInstalled: lz4 is not installed
"""
self.frame_shape = tuple(frames[0].shape)
self.shape = (len(frames),) + self.frame_shape
self.dtype = frames[0].dtype
if lz4_compress:
try:
from lz4.block import compress
except ImportError as e:
raise DependencyNotInstalled(
"lz4 is not installed, run `pip install gymnasium[other]`"
) from e
frames = [compress(frame) for frame in frames]
self._frames = frames
self.lz4_compress = lz4_compress
def __array__(self, dtype=None):
"""Gets a numpy array of stacked frames with specific dtype.
Args:
dtype: The dtype of the stacked frames
Returns:
The array of stacked frames with dtype
"""
arr = self[:]
if dtype is not None:
return arr.astype(dtype)
return arr
def __len__(self):
"""Returns the number of frame stacks.
Returns:
The number of frame stacks
"""
return self.shape[0]
def __getitem__(self, int_or_slice: Union[int, slice]):
"""Gets the stacked frames for a particular index or slice.
Args:
int_or_slice: Index or slice to get items for
Returns:
The stacked frames for the given index or slice
"""
if isinstance(int_or_slice, int):
return self._check_decompress(self._frames[int_or_slice]) # single frame
return np.stack(
[self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0
)
def __eq__(self, other):
"""Checks that the current frames are equal to the other object."""
return self.__array__() == other
def _check_decompress(self, frame):
if self.lz4_compress:
from lz4.block import decompress
return np.frombuffer(decompress(frame), dtype=self.dtype).reshape(
self.frame_shape
)
return frame
class FrameStack(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Observation wrapper that stacks the observations in a rolling manner.
For example, if the number of stacks is 4, then the returned observation contains
the most recent 4 observations. For environment 'Pendulum-v1', the original observation
is an array with shape [3], so if we stack 4 observations, the processed observation
has shape [4, 3].
Note:
- To be memory efficient, the stacked observations are wrapped by :class:`LazyFrame`.
- The observation space must be :class:`Box` type. If one uses :class:`Dict`
as observation space, it should apply :class:`FlattenObservation` wrapper first.
- After :meth:`reset` is called, the frame buffer will be filled with the initial observation.
I.e. the observation returned by :meth:`reset` will consist of `num_stack` many identical frames.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import FrameStack
>>> env = gym.make("CarRacing-v2")
>>> env = FrameStack(env, 4)
>>> env.observation_space
Box(0, 255, (4, 96, 96, 3), uint8)
>>> obs, _ = env.reset()
>>> obs.shape
(4, 96, 96, 3)
"""
def __init__(
self,
env: gym.Env,
num_stack: int,
lz4_compress: bool = False,
):
"""Observation wrapper that stacks the observations in a rolling manner.
Args:
env (Env): The environment to apply the wrapper
num_stack (int): The number of frames to stack
lz4_compress (bool): Use lz4 to compress the frames internally
"""
gym.utils.RecordConstructorArgs.__init__(
self, num_stack=num_stack, lz4_compress=lz4_compress
)
gym.ObservationWrapper.__init__(self, env)
self.num_stack = num_stack
self.lz4_compress = lz4_compress
self.frames = deque(maxlen=num_stack)
low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
high = np.repeat(
self.observation_space.high[np.newaxis, ...], num_stack, axis=0
)
self.observation_space = Box(
low=low, high=high, dtype=self.observation_space.dtype
)
def observation(self, observation):
"""Converts the wrappers current frames to lazy frames.
Args:
observation: Ignored
Returns:
:class:`LazyFrames` object for the wrapper's frame buffer, :attr:`self.frames`
"""
assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
return LazyFrames(list(self.frames), self.lz4_compress)
def step(self, action):
"""Steps through the environment, appending the observation to the frame buffer.
Args:
action: The action to step through the environment with
Returns:
Stacked observations, reward, terminated, truncated, and information from the environment
"""
observation, reward, terminated, truncated, info = self.env.step(action)
self.frames.append(observation)
return self.observation(None), reward, terminated, truncated, info
def reset(self, **kwargs):
"""Reset the environment with kwargs.
Args:
**kwargs: The kwargs for the environment reset
Returns:
The stacked observations
"""
obs, info = self.env.reset(**kwargs)
for _ in range(self.num_stack):
self.frames.append(obs)
return self.observation(None), info


@ -0,0 +1,68 @@
"""Wrapper that converts a color observation to grayscale."""
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box
class GrayScaleObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Convert the image observation from RGB to gray scale.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import GrayScaleObservation
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space
Box(0, 255, (96, 96, 3), uint8)
>>> env = GrayScaleObservation(gym.make("CarRacing-v2"))
>>> env.observation_space
Box(0, 255, (96, 96), uint8)
>>> env = GrayScaleObservation(gym.make("CarRacing-v2"), keep_dim=True)
>>> env.observation_space
Box(0, 255, (96, 96, 1), uint8)
"""
def __init__(self, env: gym.Env, keep_dim: bool = False):
"""Convert the image observation from RGB to gray scale.
Args:
env (Env): The environment to apply the wrapper
keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1.
Otherwise, they are of shape AxB.
"""
gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)
gym.ObservationWrapper.__init__(self, env)
self.keep_dim = keep_dim
assert (
isinstance(self.observation_space, Box)
and len(self.observation_space.shape) == 3
and self.observation_space.shape[-1] == 3
)
obs_shape = self.observation_space.shape[:2]
if self.keep_dim:
self.observation_space = Box(
low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8
)
else:
self.observation_space = Box(
low=0, high=255, shape=obs_shape, dtype=np.uint8
)
def observation(self, observation):
"""Converts the colour observation to greyscale.
Args:
observation: Color observations
Returns:
Grayscale observations
"""
import cv2
observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
if self.keep_dim:
observation = np.expand_dims(observation, -1)
return observation


@ -0,0 +1,142 @@
"""A wrapper that adds human-renering functionality to an environment."""
import copy
import numpy as np
import gymnasium as gym
from gymnasium.error import DependencyNotInstalled
class HumanRendering(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Performs human rendering for an environment that only supports "rgb_array"rendering.
This wrapper is particularly useful when you have implemented an environment that can produce
RGB images but haven't implemented any code to render the images to the screen.
If you want to use this wrapper with your environments, remember to specify ``"render_fps"``
in the metadata of your environment.
The ``render_mode`` of the wrapped environment must be either ``'rgb_array'`` or ``'rgb_array_list'``.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import HumanRendering
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
>>> wrapped = HumanRendering(env)
>>> obs, _ = wrapped.reset() # This will start rendering to the screen
The wrapper can also be applied directly when the environment is instantiated, simply by passing
``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not
implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``).
>>> env = gym.make("phys2d/CartPole-v1", render_mode="human") # phys2d/CartPole-v1 doesn't implement human-rendering natively
>>> obs, _ = env.reset() # This will start rendering to the screen
Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method
will always return an empty list:
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list")
>>> wrapped = HumanRendering(env)
>>> obs, _ = wrapped.reset()
>>> env.render() # env.render() will always return an empty list!
[]
"""
def __init__(self, env):
"""Initialize a :class:`HumanRendering` instance.
Args:
env: The environment that is being wrapped
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
assert env.render_mode in [
"rgb_array",
"rgb_array_list",
], f"Expected env.render_mode to be one of 'rgb_array' or 'rgb_array_list' but got '{env.render_mode}'"
assert (
"render_fps" in env.metadata
), "The base environment must specify 'render_fps' to be used with the HumanRendering wrapper"
self.screen_size = None
self.window = None
self.clock = None
self.metadata = copy.deepcopy(self.env.metadata)
if "human" not in self.metadata["render_modes"]:
self.metadata["render_modes"].append("human")
@property
def render_mode(self):
"""Always returns ``'human'``."""
return "human"
def step(self, *args, **kwargs):
"""Perform a step in the base environment and render a frame to the screen."""
result = self.env.step(*args, **kwargs)
self._render_frame()
return result
def reset(self, *args, **kwargs):
"""Reset the base environment and render a frame to the screen."""
result = self.env.reset(*args, **kwargs)
self._render_frame()
return result
def render(self):
"""This method doesn't do much, actual rendering is performed in :meth:`step` and :meth:`reset`."""
return None
def _render_frame(self):
"""Fetch the last frame from the base environment and render it to the screen."""
try:
import pygame
except ImportError as e:
raise DependencyNotInstalled(
"pygame is not installed, run `pip install gymnasium[box2d]`"
) from e
if self.env.render_mode == "rgb_array_list":
last_rgb_array = self.env.render()
assert isinstance(last_rgb_array, list)
last_rgb_array = last_rgb_array[-1]
elif self.env.render_mode == "rgb_array":
last_rgb_array = self.env.render()
else:
raise Exception(
f"Wrapped environment must have mode 'rgb_array' or 'rgb_array_list', actual render mode: {self.env.render_mode}"
)
assert isinstance(last_rgb_array, np.ndarray)
rgb_array = np.transpose(last_rgb_array, axes=(1, 0, 2))
if self.screen_size is None:
self.screen_size = rgb_array.shape[:2]
assert (
self.screen_size == rgb_array.shape[:2]
), f"The shape of the rgb array has changed from {self.screen_size} to {rgb_array.shape[:2]}"
if self.window is None:
pygame.init()
pygame.display.init()
self.window = pygame.display.set_mode(self.screen_size)
if self.clock is None:
self.clock = pygame.time.Clock()
surf = pygame.surfarray.make_surface(rgb_array)
self.window.blit(surf, (0, 0))
pygame.event.pump()
self.clock.tick(self.metadata["render_fps"])
pygame.display.flip()
def close(self):
"""Close the rendering window."""
super().close()
if self.window is not None:
import pygame
pygame.display.quit()
pygame.quit()


@ -0,0 +1 @@
"""Module for monitoring.video_recorder."""


@ -0,0 +1,178 @@
"""A wrapper for video recording environments by rolling it out, frame by frame."""
import json
import os
import os.path
import tempfile
from typing import List, Optional
from gymnasium import error, logger
class VideoRecorder:
"""VideoRecorder renders a nice movie of a rollout, frame by frame.
It comes with an ``enabled`` option, so you can still use the same code on episodes where you don't want to record video.
Note:
You are responsible for calling :meth:`close` on a created VideoRecorder, or else you may leak an encoder process.
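Example:
A minimal usage sketch (requires ``moviepy``; the output path is illustrative):
>>> import gymnasium as gym
>>> env = gym.make("CartPole-v1", render_mode="rgb_array")  # doctest: +SKIP
>>> video = VideoRecorder(env, path="/tmp/rollout.mp4")     # doctest: +SKIP
>>> _ = env.reset()                                         # doctest: +SKIP
>>> video.capture_frame()                                   # doctest: +SKIP
>>> video.close()                                           # doctest: +SKIP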
"""
def __init__(
self,
env,
path: Optional[str] = None,
metadata: Optional[dict] = None,
enabled: bool = True,
base_path: Optional[str] = None,
disable_logger: bool = False,
):
"""Video recorder renders a nice movie of a rollout, frame by frame.
Args:
env (Env): Environment to take video of.
path (Optional[str]): Path to the video file; will be randomly chosen if omitted.
metadata (Optional[dict]): Contents to save to the metadata file.
enabled (bool): Whether to actually record video, or just no-op (for convenience)
base_path (Optional[str]): Alternatively, path to the video file without extension, which will be added.
disable_logger (bool): Whether to disable moviepy logger or not.
Raises:
Error: You can pass at most one of `path` or `base_path`
Error: Invalid path given that must have a particular file extension
"""
self._async = env.metadata.get("semantics.async")
self.enabled = enabled
self.disable_logger = disable_logger
self._closed = False
self.render_history = []
self.env = env
self.render_mode = env.render_mode
try:
# check that moviepy is installed
import moviepy # noqa: F401
except ImportError as e:
raise error.DependencyNotInstalled(
"moviepy is not installed, run `pip install moviepy`"
) from e
if self.render_mode in {None, "human", "ansi", "ansi_list"}:
raise ValueError(
f"Render mode is {self.render_mode}, which is incompatible with"
f" RecordVideo. Initialize your environment with a render_mode"
f" that returns an image, such as rgb_array."
)
# Don't bother setting anything else if not enabled
if not self.enabled:
return
if path is not None and base_path is not None:
raise error.Error("You can pass at most one of `path` or `base_path`.")
required_ext = ".mp4"
if path is None:
if base_path is not None:
# Base path given, append ext
path = base_path + required_ext
else:
# Otherwise, just generate a unique filename
with tempfile.NamedTemporaryFile(suffix=required_ext) as f:
path = f.name
self.path = path
path_base, actual_ext = os.path.splitext(self.path)
if actual_ext != required_ext:
raise error.Error(
f"Invalid path given: {self.path} -- must have file extension {required_ext}."
)
self.frames_per_sec = env.metadata.get("render_fps", 30)
self.broken = False
# Dump metadata
self.metadata = metadata or {}
self.metadata["content_type"] = "video/mp4"
self.metadata_path = f"{path_base}.meta.json"
self.write_metadata()
logger.info(f"Starting new video recorder writing to {self.path}")
self.recorded_frames = []
@property
def functional(self):
"""Returns if the video recorder is functional, is enabled and not broken."""
return self.enabled and not self.broken
def capture_frame(self):
"""Render the given `env` and add the resulting frame to the video."""
frame = self.env.render()
if isinstance(frame, List):
self.render_history += frame
frame = frame[-1]
if not self.functional:
return
if self._closed:
logger.warn(
"The video recorder has been closed and no frames will be captured anymore."
)
return
logger.debug("Capturing video frame: path=%s", self.path)
if frame is None:
if self._async:
return
else:
# Indicates a bug in the environment: don't want to raise
# an error here.
logger.warn(
"Env returned None on `render()`. Disabling further rendering for video recorder by marking as "
f"disabled: path={self.path} metadata_path={self.metadata_path}"
)
self.broken = True
else:
self.recorded_frames.append(frame)
def close(self):
"""Flush all data to disk and close any open frame encoders."""
if not self.enabled or self._closed:
return
# Close the encoder
if len(self.recorded_frames) > 0:
try:
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
except ImportError as e:
raise error.DependencyNotInstalled(
"moviepy is not installed, run `pip install moviepy`"
) from e
clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec)
moviepy_logger = None if self.disable_logger else "bar"
clip.write_videofile(self.path, logger=moviepy_logger)
else:
# No frames captured. Set metadata.
if self.metadata is None:
self.metadata = {}
self.metadata["empty"] = True
self.write_metadata()
# Stop tracking this for autoclose
self._closed = True
def write_metadata(self):
"""Writes metadata to metadata path."""
with open(self.metadata_path, "w") as f:
json.dump(self.metadata, f)
def __del__(self):
"""Closes the environment correctly when the recorder is deleted."""
# Make sure we've closed up shop when garbage collecting
if not self._closed:
logger.warn("Unable to save last video! Did you call close()?")


@ -0,0 +1,155 @@
"""Set of wrappers for normalizing actions and observations."""
import numpy as np
import gymnasium as gym
class RunningMeanStd:
"""Tracks the mean, variance and count of values."""
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
def __init__(self, epsilon=1e-4, shape=()):
"""Tracks the mean, variance and count of values."""
self.mean = np.zeros(shape, "float64")
self.var = np.ones(shape, "float64")
self.count = epsilon
def update(self, x):
"""Updates the mean, var and count from a batch of samples."""
batch_mean = np.mean(x, axis=0)
batch_var = np.var(x, axis=0)
batch_count = x.shape[0]
self.update_from_moments(batch_mean, batch_var, batch_count)
def update_from_moments(self, batch_mean, batch_var, batch_count):
"""Updates from batch mean, variance and count moments."""
self.mean, self.var, self.count = update_mean_var_count_from_moments(
self.mean, self.var, self.count, batch_mean, batch_var, batch_count
)
def update_mean_var_count_from_moments(
mean, var, count, batch_mean, batch_var, batch_count
):
"""Updates the mean, var and count using the previous mean, var, count and batch values."""
delta = batch_mean - mean
tot_count = count + batch_count
new_mean = mean + delta * batch_count / tot_count
m_a = var * count
m_b = batch_var * batch_count
M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
new_var = M2 / tot_count
new_count = tot_count
return new_mean, new_var, new_count
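# A quick sanity check of the parallel update (a sketch; the sample values are arbitrary):
# combining the moments of [1, 2] and [3, 4, 5] must match the population mean/variance
# of [1, 2, 3, 4, 5], i.e. mean 3.0 and variance 2.0.
#
#     rms = RunningMeanStd(epsilon=0.0)
#     rms.update(np.array([1.0, 2.0]))
#     rms.update(np.array([3.0, 4.0, 5.0]))
#     assert np.isclose(rms.mean, 3.0) and np.isclose(rms.var, 2.0)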
class NormalizeObservation(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
Note:
The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was
newly instantiated or the policy was changed recently.
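Example:
A minimal sketch (the statistics are running estimates, so early observations are only approximately normalized):
>>> import gymnasium as gym
>>> from gymnasium.wrappers import NormalizeObservation
>>> env = NormalizeObservation(gym.make("CartPole-v1"))
>>> obs, _ = env.reset(seed=42)
>>> obs.shape
(4,)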
"""
def __init__(self, env: gym.Env, epsilon: float = 1e-8):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
Args:
env (Env): The environment to apply the wrapper
epsilon: A stability parameter that is used when scaling the observations.
"""
gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
gym.Wrapper.__init__(self, env)
try:
self.num_envs = self.get_wrapper_attr("num_envs")
self.is_vector_env = self.get_wrapper_attr("is_vector_env")
except AttributeError:
self.num_envs = 1
self.is_vector_env = False
if self.is_vector_env:
self.obs_rms = RunningMeanStd(shape=self.single_observation_space.shape)
else:
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
self.epsilon = epsilon
def step(self, action):
"""Steps through the environment and normalizes the observation."""
obs, rews, terminateds, truncateds, infos = self.env.step(action)
if self.is_vector_env:
obs = self.normalize(obs)
else:
obs = self.normalize(np.array([obs]))[0]
return obs, rews, terminateds, truncateds, infos
def reset(self, **kwargs):
"""Resets the environment and normalizes the observation."""
obs, info = self.env.reset(**kwargs)
if self.is_vector_env:
return self.normalize(obs), info
else:
return self.normalize(np.array([obs]))[0], info
def normalize(self, obs):
"""Normalises the observation using the running mean and variance of the observations."""
self.obs_rms.update(obs)
return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon)
class NormalizeReward(gym.core.Wrapper, gym.utils.RecordConstructorArgs):
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
Note:
The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
instantiated or the policy was changed recently.
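Example:
A minimal sketch (the scale adapts online, so the first rewards are only roughly normalized):
>>> import gymnasium as gym
>>> from gymnasium.wrappers import NormalizeReward
>>> env = NormalizeReward(gym.make("CartPole-v1"), gamma=0.99)
>>> _ = env.reset(seed=42)
>>> _, reward, _, _, _ = env.step(env.action_space.sample())
>>> float(reward)  # a rescaled scalar  # doctest: +SKIP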
"""
def __init__(
self,
env: gym.Env,
gamma: float = 0.99,
epsilon: float = 1e-8,
):
"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
Args:
env (env): The environment to apply the wrapper
epsilon (float): A stability parameter
gamma (float): The discount factor that is used in the exponential moving average.
"""
gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
gym.Wrapper.__init__(self, env)
try:
self.num_envs = self.get_wrapper_attr("num_envs")
self.is_vector_env = self.get_wrapper_attr("is_vector_env")
except AttributeError:
self.num_envs = 1
self.is_vector_env = False
self.return_rms = RunningMeanStd(shape=())
self.returns = np.zeros(self.num_envs)
self.gamma = gamma
self.epsilon = epsilon
def step(self, action):
"""Steps through the environment, normalizing the rewards returned."""
obs, rews, terminateds, truncateds, infos = self.env.step(action)
if not self.is_vector_env:
rews = np.array([rews])
self.returns = self.returns * self.gamma * (1 - terminateds) + rews
rews = self.normalize(rews)
if not self.is_vector_env:
rews = rews[0]
return obs, rews, terminateds, truncateds, infos
def normalize(self, rews):
"""Normalizes the rewards with the running mean rewards and their variance."""
self.return_rms.update(self.returns)
return rews / np.sqrt(self.return_rms.var + self.epsilon)


@ -0,0 +1,89 @@
"""Wrapper to enforce the proper ordering of environment operations."""
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING
import gymnasium as gym
from gymnasium.error import ResetNeeded
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import OrderEnforcing
>>> env = gym.make("CartPole-v1", render_mode="human")
>>> env = OrderEnforcing(env)
>>> env.step(0)
Traceback (most recent call last):
...
gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset()
>>> env.render()
Traceback (most recent call last):
...
gymnasium.error.ResetNeeded: Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.
>>> _ = env.reset()
>>> env.render()
>>> _ = env.step(0)
>>> env.close()
"""
def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Args:
env: The environment to wrap
disable_render_order_enforcing: Whether to disable render order enforcing
"""
gym.utils.RecordConstructorArgs.__init__(
self, disable_render_order_enforcing=disable_render_order_enforcing
)
gym.Wrapper.__init__(self, env)
self._has_reset: bool = False
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
def step(self, action):
"""Steps through the environment with `kwargs`."""
if not self._has_reset:
raise ResetNeeded("Cannot call env.step() before calling env.reset()")
return self.env.step(action)
def reset(self, **kwargs):
"""Resets the environment with `kwargs`."""
self._has_reset = True
return self.env.reset(**kwargs)
def render(self, *args, **kwargs):
"""Renders the environment with `kwargs`."""
if not self._disable_render_order_enforcing and not self._has_reset:
raise ResetNeeded(
"Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, "
"set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
)
return self.env.render(*args, **kwargs)
@property
def has_reset(self):
"""Returns if the environment has been reset before."""
return self._has_reset
@property
def spec(self) -> EnvSpec | None:
"""Modifies the environment spec to add the `order_enforce=True`."""
if self._cached_spec is not None:
return self._cached_spec
env_spec = self.env.spec
if env_spec is not None:
env_spec = deepcopy(env_spec)
env_spec.order_enforce = True
self._cached_spec = env_spec
return env_spec


@ -0,0 +1,215 @@
"""Wrapper for augmenting observations by pixel values."""
import collections
import copy
from collections.abc import MutableMapping
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import gymnasium as gym
from gymnasium import spaces
STATE_KEY = "state"
class PixelObservationWrapper(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Augment observations by pixel values.
Observations of this wrapper will be dictionaries of images.
You can also choose to add the observation of the base environment to this dictionary.
In that case, if the base environment has an observation space of type :class:`Dict`, the dictionary
of rendered images will be updated with the base environment's observation. If, however, the observation
space is of type :class:`Box`, the base environment's observation (which will be an element of the :class:`Box`
space) will be added to the dictionary under the key "state".
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import PixelObservationWrapper
>>> env = PixelObservationWrapper(gym.make("CarRacing-v2", render_mode="rgb_array"))
>>> obs, _ = env.reset()
>>> obs.keys()
odict_keys(['pixels'])
>>> obs['pixels'].shape
(400, 600, 3)
>>> env = PixelObservationWrapper(gym.make("CarRacing-v2", render_mode="rgb_array"), pixels_only=False)
>>> obs, _ = env.reset()
>>> obs.keys()
odict_keys(['state', 'pixels'])
>>> obs['state'].shape
(96, 96, 3)
>>> obs['pixels'].shape
(400, 600, 3)
>>> env = PixelObservationWrapper(gym.make("CarRacing-v2", render_mode="rgb_array"), pixel_keys=('obs',))
>>> obs, _ = env.reset()
>>> obs.keys()
odict_keys(['obs'])
>>> obs['obs'].shape
(400, 600, 3)
"""
def __init__(
self,
env: gym.Env,
pixels_only: bool = True,
render_kwargs: Optional[Dict[str, Dict[str, Any]]] = None,
pixel_keys: Tuple[str, ...] = ("pixels",),
):
"""Initializes a new pixel Wrapper.
Args:
env: The environment to wrap.
pixels_only (bool): If `True` (default), the original observation returned
by the wrapped environment will be discarded, and a dictionary
observation will only include pixels. If `False`, the
observation dictionary will contain both the original
observations and the pixel observations.
render_kwargs (dict): Optional dictionary that maps elements of `pixel_keys` to
keyword arguments passed to the :meth:`self.render` method.
pixel_keys: Optional custom string specifying the pixel
observation's key in the `OrderedDict` of observations.
Defaults to `("pixels",)`.
Raises:
AssertionError: If any of the keys in ``render_kwargs`` do not show up in ``pixel_keys``.
ValueError: If ``env``'s observation space is not compatible with the
wrapper. Supported formats are a single array, or a dict of
arrays.
ValueError: If ``env``'s observation already contains any of the
specified ``pixel_keys``.
TypeError: When an unexpected pixel type is used
"""
gym.utils.RecordConstructorArgs.__init__(
self,
pixels_only=pixels_only,
render_kwargs=render_kwargs,
pixel_keys=pixel_keys,
)
gym.ObservationWrapper.__init__(self, env)
# Avoid side-effects that occur when render_kwargs is manipulated
render_kwargs = copy.deepcopy(render_kwargs)
self.render_history = []
if render_kwargs is None:
render_kwargs = {}
for key in render_kwargs:
assert key in pixel_keys, (
"The argument render_kwargs should map elements of "
"pixel_keys to dictionaries of keyword arguments. "
f"Found key '{key}' in render_kwargs but not in pixel_keys."
)
default_render_kwargs = {}
if not env.render_mode:
raise AttributeError(
"env.render_mode must be specified to use PixelObservationWrapper:"
"`gymnasium.make(env_name, render_mode='rgb_array')`."
)
for key in pixel_keys:
render_kwargs.setdefault(key, default_render_kwargs)
wrapped_observation_space = env.observation_space
if isinstance(wrapped_observation_space, spaces.Box):
self._observation_is_dict = False
invalid_keys = {STATE_KEY}
elif isinstance(wrapped_observation_space, (spaces.Dict, MutableMapping)):
self._observation_is_dict = True
invalid_keys = set(wrapped_observation_space.spaces.keys())
else:
raise ValueError("Unsupported observation space structure.")
if not pixels_only:
# Make sure that no keys in `pixel_keys` overlap with
# `observation_keys`
overlapping_keys = set(pixel_keys) & set(invalid_keys)
if overlapping_keys:
raise ValueError(
f"Duplicate or reserved pixel keys {overlapping_keys!r}."
)
if pixels_only:
self.observation_space = spaces.Dict()
elif self._observation_is_dict:
self.observation_space = copy.deepcopy(wrapped_observation_space)
else:
self.observation_space = spaces.Dict({STATE_KEY: wrapped_observation_space})
# Extend observation space with pixels.
self.env.reset()
pixels_spaces = {}
for pixel_key in pixel_keys:
pixels = self._render(**render_kwargs[pixel_key])
pixels: np.ndarray = pixels[-1] if isinstance(pixels, List) else pixels
if not hasattr(pixels, "dtype") or not hasattr(pixels, "shape"):
raise TypeError(
f"Render method returns a {pixels.__class__.__name__}, but an array with dtype and shape is expected."
"Be sure to specify the correct render_mode."
)
if np.issubdtype(pixels.dtype, np.integer):
low, high = (0, 255)
elif np.issubdtype(pixels.dtype, np.floating):
low, high = (-float("inf"), float("inf"))
else:
raise TypeError(pixels.dtype)
pixels_space = spaces.Box(
shape=pixels.shape, low=low, high=high, dtype=pixels.dtype
)
pixels_spaces[pixel_key] = pixels_space
self.observation_space.spaces.update(pixels_spaces)
self._pixels_only = pixels_only
self._render_kwargs = render_kwargs
self._pixel_keys = pixel_keys
def observation(self, observation):
"""Updates the observations with the pixel observations.
Args:
observation: The observation to add pixel observations for
Returns:
The updated pixel observations
"""
pixel_observation = self._add_pixel_observation(observation)
return pixel_observation
def _add_pixel_observation(self, wrapped_observation):
if self._pixels_only:
observation = collections.OrderedDict()
elif self._observation_is_dict:
observation = type(wrapped_observation)(wrapped_observation)
else:
observation = collections.OrderedDict()
observation[STATE_KEY] = wrapped_observation
pixel_observations = {
pixel_key: self._render(**self._render_kwargs[pixel_key])
for pixel_key in self._pixel_keys
}
observation.update(pixel_observations)
return observation
def render(self, *args, **kwargs):
"""Renders the environment."""
render = self.env.render(*args, **kwargs)
if isinstance(render, list):
render = self.render_history + render
self.render_history = []
return render
def _render(self, *args, **kwargs):
render = self.env.render(*args, **kwargs)
if isinstance(render, list):
self.render_history += render
return render
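# A minimal usage sketch (hedged; assumes CartPole-v1 is available): with
# pixels_only=False the wrapped observation is a dict holding the original
# observation under STATE_KEY ("state") and the rendered frame under "pixels".
if __name__ == "__main__":
    demo_env = gym.make("CartPole-v1", render_mode="rgb_array")
    demo_env = PixelObservationWrapper(demo_env, pixels_only=False)
    demo_obs, _ = demo_env.reset()
    # e.g. (4,) (400, 600, 3)
    print(demo_obs["state"].shape, demo_obs["pixels"].shape)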

View File

@ -0,0 +1,131 @@
"""Wrapper that tracks the cumulative rewards and episode lengths."""
import time
from collections import deque
from typing import Optional
import numpy as np
import gymnasium as gym
class RecordEpisodeStatistics(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""This wrapper will keep track of cumulative rewards and episode lengths.
At the end of an episode, the statistics of the episode will be added to ``info``
using the key ``episode``. If using a vectorized environment, the key
``_episode`` is also used, indicating whether the env at the respective index
has episode statistics.
After the completion of an episode, ``info`` will look like this::
>>> info = {
... "episode": {
... "r": "<cumulative reward>",
... "l": "<episode length>",
... "t": "<elapsed time since beginning of episode>"
... },
... }
For vectorized environments the output will be in the form of::
>>> infos = {
... "final_observation": "<array of length num-envs>",
... "_final_observation": "<boolean array of length num-envs>",
... "final_info": "<array of length num-envs>",
... "_final_info": "<boolean array of length num-envs>",
... "episode": {
... "r": "<array of cumulative reward>",
... "l": "<array of episode length>",
... "t": "<array of elapsed time since beginning of episode>"
... },
... "_episode": "<boolean array of length num-envs>"
... }
Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
:attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.
Attributes:
return_queue: The cumulative rewards of the last ``deque_size``-many episodes
length_queue: The lengths of the last ``deque_size``-many episodes
"""
def __init__(self, env: gym.Env, deque_size: int = 100):
"""This wrapper will keep track of cumulative rewards and episode lengths.
Args:
env (Env): The environment to apply the wrapper
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
"""
gym.utils.RecordConstructorArgs.__init__(self, deque_size=deque_size)
gym.Wrapper.__init__(self, env)
try:
self.num_envs = self.get_wrapper_attr("num_envs")
self.is_vector_env = self.get_wrapper_attr("is_vector_env")
except AttributeError:
self.num_envs = 1
self.is_vector_env = False
self.episode_count = 0
self.episode_start_times: Optional[np.ndarray] = None
self.episode_returns: Optional[np.ndarray] = None
self.episode_lengths: Optional[np.ndarray] = None
self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size)
def reset(self, **kwargs):
"""Resets the environment using kwargs and resets the episode returns and lengths."""
obs, info = super().reset(**kwargs)
self.episode_start_times = np.full(
self.num_envs, time.perf_counter(), dtype=np.float32
)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return obs, info
def step(self, action):
"""Steps through the environment, recording the episode statistics."""
(
observations,
rewards,
terminations,
truncations,
infos,
) = self.env.step(action)
assert isinstance(
infos, dict
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
self.episode_returns += rewards
self.episode_lengths += 1
dones = np.logical_or(terminations, truncations)
num_dones = np.sum(dones)
if num_dones:
if "episode" in infos or "_episode" in infos:
raise ValueError(
"Attempted to add episode stats when they already exist"
)
else:
infos["episode"] = {
"r": np.where(dones, self.episode_returns, 0.0),
"l": np.where(dones, self.episode_lengths, 0),
"t": np.where(
dones,
np.round(time.perf_counter() - self.episode_start_times, 6),
0.0,
),
}
if self.is_vector_env:
infos["_episode"] = np.where(dones, True, False)
self.return_queue.extend(self.episode_returns[dones])
self.length_queue.extend(self.episode_lengths[dones])
self.episode_count += num_dones
self.episode_lengths[dones] = 0
self.episode_returns[dones] = 0
self.episode_start_times[dones] = time.perf_counter()
return (
observations,
rewards,
terminations,
truncations,
infos,
)
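# A short, hedged usage sketch (assumes CartPole-v1 is available): once an
# episode terminates or truncates, the totals appear under info["episode"]
# and the per-episode buffers return_queue/length_queue are extended.
if __name__ == "__main__":
    demo_env = RecordEpisodeStatistics(gym.make("CartPole-v1"))
    demo_env.reset(seed=0)
    done = False
    while not done:
        _, _, terminated, truncated, info = demo_env.step(
            demo_env.action_space.sample()
        )
        done = terminated or truncated
    print(info["episode"])  # {"r": ..., "l": ..., "t": ...}
    print(list(demo_env.return_queue))  # returns of the finished episodes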

View File

@ -0,0 +1,228 @@
"""Wrapper for recording videos."""
import os
from typing import Callable, Optional
import gymnasium as gym
from gymnasium import logger
from gymnasium.wrappers.monitoring import video_recorder
def capped_cubic_video_schedule(episode_id: int) -> bool:
"""The default episode trigger.
This function will trigger recordings at the episode indices 0, 1, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ...
Args:
episode_id: The episode number
Returns:
Whether a recording should be started at this episode
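Example:
    An illustrative check of which early episodes trigger under this schedule:
    >>> [episode for episode in range(30) if capped_cubic_video_schedule(episode)]
    [0, 1, 8, 27]
    >>> capped_cubic_video_schedule(2000)
    True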
"""
if episode_id < 1000:
return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
else:
return episode_id % 1000 == 0
class RecordVideo(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""This wrapper records videos of rollouts.
Usually, you only want to record episodes intermittently, say every hundredth episode.
To do this, you can specify **either** ``episode_trigger`` **or** ``step_trigger`` (not both).
They should be functions returning a boolean that indicates whether a recording should be started at the
current episode or step, respectively.
If neither :attr:`episode_trigger` nor ``step_trigger`` is passed, a default ``episode_trigger`` will be employed.
By default, the recording will be stopped once a `terminated` or `truncated` signal has been emitted by the environment. However, you can
also create recordings of fixed length (possibly spanning several episodes) by passing a strictly positive value for
``video_length``.
"""
def __init__(
self,
env: gym.Env,
video_folder: str,
episode_trigger: Callable[[int], bool] = None,
step_trigger: Callable[[int], bool] = None,
video_length: int = 0,
name_prefix: str = "rl-video",
disable_logger: bool = False,
):
"""Wrapper records videos of rollouts.
Args:
env: The environment that will be wrapped
video_folder (str): The folder where the recordings will be stored
episode_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this episode
step_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this step
video_length (int): The length of recorded episodes. If 0, entire episodes are recorded.
Otherwise, snippets of the specified length are captured
name_prefix (str): Will be prepended to the filename of the recordings
disable_logger (bool): Whether to disable moviepy logger or not.
"""
gym.utils.RecordConstructorArgs.__init__(
self,
video_folder=video_folder,
episode_trigger=episode_trigger,
step_trigger=step_trigger,
video_length=video_length,
name_prefix=name_prefix,
disable_logger=disable_logger,
)
gym.Wrapper.__init__(self, env)
if env.render_mode in {None, "human", "ansi", "ansi_list"}:
raise ValueError(
f"Render mode is {env.render_mode}, which is incompatible with"
f" RecordVideo. Initialize your environment with a render_mode"
f" that returns an image, such as rgb_array."
)
if episode_trigger is None and step_trigger is None:
episode_trigger = capped_cubic_video_schedule
trigger_count = sum(x is not None for x in [episode_trigger, step_trigger])
assert trigger_count == 1, "Must specify exactly one trigger"
self.episode_trigger = episode_trigger
self.step_trigger = step_trigger
self.video_recorder: Optional[video_recorder.VideoRecorder] = None
self.disable_logger = disable_logger
self.video_folder = os.path.abspath(video_folder)
# Create output folder if needed
if os.path.isdir(self.video_folder):
logger.warn(
f"Overwriting existing videos at {self.video_folder} folder "
f"(try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)"
)
os.makedirs(self.video_folder, exist_ok=True)
self.name_prefix = name_prefix
self.step_id = 0
self.video_length = video_length
self.recording = False
self.terminated = False
self.truncated = False
self.recorded_frames = 0
self.episode_id = 0
try:
self.is_vector_env = self.get_wrapper_attr("is_vector_env")
except AttributeError:
self.is_vector_env = False
def reset(self, **kwargs):
"""Reset the environment using kwargs and then starts recording if video enabled."""
observations = super().reset(**kwargs)
self.terminated = False
self.truncated = False
if self.recording:
assert self.video_recorder is not None
self.video_recorder.recorded_frames = []
self.video_recorder.capture_frame()
self.recorded_frames += 1
if self.video_length > 0:
if self.recorded_frames > self.video_length:
self.close_video_recorder()
elif self._video_enabled():
self.start_video_recorder()
return observations
def start_video_recorder(self):
"""Starts video recorder using :class:`video_recorder.VideoRecorder`."""
self.close_video_recorder()
video_name = f"{self.name_prefix}-step-{self.step_id}"
if self.episode_trigger:
video_name = f"{self.name_prefix}-episode-{self.episode_id}"
base_path = os.path.join(self.video_folder, video_name)
self.video_recorder = video_recorder.VideoRecorder(
env=self.env,
base_path=base_path,
metadata={"step_id": self.step_id, "episode_id": self.episode_id},
disable_logger=self.disable_logger,
)
self.video_recorder.capture_frame()
self.recorded_frames = 1
self.recording = True
def _video_enabled(self):
if self.step_trigger:
return self.step_trigger(self.step_id)
else:
return self.episode_trigger(self.episode_id)
def step(self, action):
"""Steps through the environment using action, recording observations if :attr:`self.recording`."""
(
observations,
rewards,
terminateds,
truncateds,
infos,
) = self.env.step(action)
if not (self.terminated or self.truncated):
# increment steps and episodes
self.step_id += 1
if not self.is_vector_env:
if terminateds or truncateds:
self.episode_id += 1
self.terminated = terminateds
self.truncated = truncateds
elif terminateds[0] or truncateds[0]:
self.episode_id += 1
self.terminated = terminateds[0]
self.truncated = truncateds[0]
if self.recording:
assert self.video_recorder is not None
self.video_recorder.capture_frame()
self.recorded_frames += 1
if self.video_length > 0:
if self.recorded_frames > self.video_length:
self.close_video_recorder()
else:
if not self.is_vector_env:
if terminateds or truncateds:
self.close_video_recorder()
elif terminateds[0] or truncateds[0]:
self.close_video_recorder()
elif self._video_enabled():
self.start_video_recorder()
return observations, rewards, terminateds, truncateds, infos
def close_video_recorder(self):
"""Closes the video recorder if currently recording."""
if self.recording:
assert self.video_recorder is not None
self.video_recorder.close()
self.recording = False
self.recorded_frames = 1
def render(self, *args, **kwargs):
"""Compute the render frames as specified by render_mode attribute during initialization of the environment or as specified in kwargs."""
if self.video_recorder is None or not self.video_recorder.enabled:
return super().render(*args, **kwargs)
if len(self.video_recorder.render_history) > 0:
recorded_frames = [
self.video_recorder.render_history.pop()
for _ in range(len(self.video_recorder.render_history))
]
if self.recording:
return recorded_frames
else:
return recorded_frames + super().render(*args, **kwargs)
else:
return super().render(*args, **kwargs)
def close(self):
"""Closes the wrapper then the video recorder."""
super().close()
self.close_video_recorder()
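# A hedged usage sketch (assumes CartPole-v1 and moviepy, which the
# underlying VideoRecorder depends on): record every episode of a short
# random rollout into the "videos" folder.
if __name__ == "__main__":
    demo_env = gym.make("CartPole-v1", render_mode="rgb_array")
    demo_env = RecordVideo(
        demo_env, video_folder="videos", episode_trigger=lambda e: True
    )
    demo_env.reset(seed=0)
    for _ in range(100):
        _, _, terminated, truncated, _ = demo_env.step(demo_env.action_space.sample())
        if terminated or truncated:
            demo_env.reset()
    demo_env.close()  # flushes the last recording to disk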

View File

@ -0,0 +1,62 @@
"""A wrapper that adds render collection mode to an environment."""
import copy
import gymnasium as gym
class RenderCollection(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Save collection of render frames."""
def __init__(self, env: gym.Env, pop_frames: bool = True, reset_clean: bool = True):
"""Initialize a :class:`RenderCollection` instance.
Args:
env: The environment that is being wrapped
pop_frames (bool): If true, clear the collection frames after .render() is called.
Default value is True.
reset_clean (bool): If true, clear the collection frames when .reset() is called.
Default value is True.
"""
gym.utils.RecordConstructorArgs.__init__(
self, pop_frames=pop_frames, reset_clean=reset_clean
)
gym.Wrapper.__init__(self, env)
assert env.render_mode is not None
assert not env.render_mode.endswith("_list")
self.frame_list = []
self.reset_clean = reset_clean
self.pop_frames = pop_frames
self.metadata = copy.deepcopy(self.env.metadata)
if f"{self.env.render_mode}_list" not in self.metadata["render_modes"]:
self.metadata["render_modes"].append(f"{self.env.render_mode}_list")
@property
def render_mode(self):
"""Returns the collection render_mode name."""
return f"{self.env.render_mode}_list"
def step(self, *args, **kwargs):
"""Perform a step in the base environment and collect a frame."""
output = self.env.step(*args, **kwargs)
self.frame_list.append(self.env.render())
return output
def reset(self, *args, **kwargs):
"""Reset the base environment, eventually clear the frame_list, and collect a frame."""
result = self.env.reset(*args, **kwargs)
if self.reset_clean:
self.frame_list = []
self.frame_list.append(self.env.render())
return result
def render(self):
"""Returns the collection of frames and, if pop_frames = True, clears it."""
frames = self.frame_list
if self.pop_frames:
self.frame_list = []
return frames
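# A brief, hedged sketch (assumes CartPole-v1): frames accumulate on reset()
# and step(), and render() hands them back, emptying the list when
# pop_frames=True; render_mode is reported with a "_list" suffix.
if __name__ == "__main__":
    demo_env = RenderCollection(gym.make("CartPole-v1", render_mode="rgb_array"))
    demo_env.reset(seed=0)
    for _ in range(3):
        demo_env.step(demo_env.action_space.sample())
    frames = demo_env.render()
    print(demo_env.render_mode, len(frames))  # rgb_array_list 4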

View File

@ -0,0 +1,89 @@
"""Wrapper for rescaling actions to within a max and min action."""
from typing import Union
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box
class RescaleAction(gym.ActionWrapper, gym.utils.RecordConstructorArgs):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import RescaleAction
>>> import numpy as np
>>> env = gym.make("Hopper-v4")
>>> _ = env.reset(seed=42)
>>> obs, _, _, _, _ = env.step(np.array([1,1,1]))
>>> _ = env.reset(seed=42)
>>> min_action = -0.5
>>> max_action = np.array([0.0, 0.5, 0.75])
>>> wrapped_env = RescaleAction(env, min_action=min_action, max_action=max_action)
>>> wrapped_env_obs, _, _, _, _ = wrapped_env.step(max_action)
>>> bool(np.all(obs == wrapped_env_obs))
True
"""
def __init__(
self,
env: gym.Env,
min_action: Union[float, int, np.ndarray],
max_action: Union[float, int, np.ndarray],
):
"""Initializes the :class:`RescaleAction` wrapper.
Args:
env (Env): The environment to apply the wrapper
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
"""
assert isinstance(
env.action_space, Box
), f"expected Box action space, got {type(env.action_space)}"
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
gym.utils.RecordConstructorArgs.__init__(
self, min_action=min_action, max_action=max_action
)
gym.ActionWrapper.__init__(self, env)
self.min_action = (
np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
)
self.max_action = (
np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + max_action
)
self.action_space = Box(
low=min_action,
high=max_action,
shape=env.action_space.shape,
dtype=env.action_space.dtype,
)
def action(self, action):
"""Rescales the action affinely from [:attr:`min_action`, :attr:`max_action`] to the action space of the base environment, :attr:`env`.
Args:
action: The action to rescale
Returns:
The rescaled action
"""
assert np.all(np.greater_equal(action, self.min_action)), (
action,
self.min_action,
)
assert np.all(np.less_equal(action, self.max_action)), (action, self.max_action)
low = self.env.action_space.low
high = self.env.action_space.high
action = low + (high - low) * (
(action - self.min_action) / (self.max_action - self.min_action)
)
action = np.clip(action, low, high)
return action
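# A worked illustration of the affine map in `action` (hedged; assumes
# Pendulum-v1 with action space Box(-2.0, 2.0, (1,))): an input a maps to
# low + (high - low) * (a - min_action) / (max_action - min_action).
if __name__ == "__main__":
    demo_env = RescaleAction(gym.make("Pendulum-v1"), min_action=0.0, max_action=1.0)
    # With low=-2, high=2, min_action=0, max_action=1: 0.5 -> -2 + 4 * 0.5 = 0.0
    print(demo_env.action(np.array([0.5], dtype=np.float32)))  # [0.]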

View File

@ -0,0 +1,83 @@
"""Wrapper for resizing observations."""
from __future__ import annotations
import numpy as np
import gymnasium as gym
from gymnasium.error import DependencyNotInstalled
from gymnasium.spaces import Box
class ResizeObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Resize the image observation.
This wrapper works on environments with image observations. More generally,
the input can either be two-dimensional (AxB, e.g. grayscale images) or
three-dimensional (AxBxC, e.g. color images). This resizes the observation
to the shape given by the 2-tuple :attr:`shape`.
The argument :attr:`shape` may also be an integer, in which case, the
observation is scaled to a square of side-length :attr:`shape`.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import ResizeObservation
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
>>> env = ResizeObservation(env, 64)
>>> env.observation_space.shape
(64, 64, 3)
"""
def __init__(self, env: gym.Env, shape: tuple[int, int] | int) -> None:
"""Resizes image observations to shape given by :attr:`shape`.
Args:
env: The environment to apply the wrapper
shape: The shape of the resized observations
"""
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
gym.ObservationWrapper.__init__(self, env)
if isinstance(shape, int):
shape = (shape, shape)
assert len(shape) == 2 and all(
x > 0 for x in shape
), f"Expected shape to be a 2-tuple of positive integers, got: {shape}"
self.shape = tuple(shape)
assert isinstance(
env.observation_space, Box
), f"Expected the observation space to be Box, actual type: {type(env.observation_space)}"
dims = len(env.observation_space.shape)
assert (
dims == 2 or dims == 3
), f"Expected the observation space to have 2 or 3 dimensions, got: {dims}"
obs_shape = self.shape + env.observation_space.shape[2:]
self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
def observation(self, observation):
"""Updates the observations by resizing the observation to shape given by :attr:`shape`.
Args:
observation: The observation to reshape
Returns:
The reshaped observations
Raises:
DependencyNotInstalled: opencv-python is not installed
"""
try:
import cv2
except ImportError as e:
raise DependencyNotInstalled(
"opencv (cv2) is not installed, run `pip install gymnasium[other]`"
) from e
observation = cv2.resize(
observation, self.shape[::-1], interpolation=cv2.INTER_AREA
)
return observation.reshape(self.observation_space.shape)
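# A hedged sketch (assumes gymnasium[box2d] and opencv-python are installed):
# CarRacing-v2 frames of shape (96, 96, 3) are shrunk to 32x32 using
# cv2.INTER_AREA, an interpolation well suited to downsampling.
if __name__ == "__main__":
    demo_env = ResizeObservation(gym.make("CarRacing-v2"), 32)
    demo_obs, _ = demo_env.reset(seed=0)
    print(demo_obs.shape)  # (32, 32, 3)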

View File

@ -0,0 +1,56 @@
"""Implementation of StepAPICompatibility wrapper class for transforming envs between new and old step API."""
import gymnasium as gym
from gymnasium.logger import deprecation
from gymnasium.utils.step_api_compatibility import step_api_compatibility
class StepAPICompatibility(gym.Wrapper, gym.utils.RecordConstructorArgs):
r"""A wrapper which can transform an environment from new step API to old and vice-versa.
Old step API refers to step() method returning (observation, reward, done, info)
New step API refers to step() method returning (observation, reward, terminated, truncated, info)
(Refer to docs for details on the API change)
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import StepAPICompatibility
>>> env = gym.make("CartPole-v1")
>>> env # wrapper not applied by default, set to new API
<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
>>> env = StepAPICompatibility(gym.make("CartPole-v1"))
>>> env
<StepAPICompatibility<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>>
"""
def __init__(self, env: gym.Env, output_truncation_bool: bool = True):
"""A wrapper which can transform an environment from new step API to old and vice-versa.
Args:
env (gym.Env): the env to wrap. Can be in old or new API
output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
"""
gym.utils.RecordConstructorArgs.__init__(
self, output_truncation_bool=output_truncation_bool
)
gym.Wrapper.__init__(self, env)
self.is_vector_env = isinstance(env.unwrapped, gym.vector.VectorEnv)
self.output_truncation_bool = output_truncation_bool
if not self.output_truncation_bool:
deprecation(
"Initializing environment in (old) done step API which returns one bool instead of two."
)
def step(self, action):
"""Steps through the environment, returning 5 or 4 items depending on `output_truncation_bool`.
Args:
action: action to step through the environment with
Returns:
(observation, reward, terminated, truncated, info) or (observation, reward, done, info)
"""
step_returns = self.env.step(action)
return step_api_compatibility(
step_returns, self.output_truncation_bool, self.is_vector_env
)
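# A hedged sketch of converting to the old (done) API: with
# output_truncation_bool=False, step() folds terminated/truncated into a
# single done flag and returns a 4-tuple, as in pre-v26 Gym.
if __name__ == "__main__":
    demo_env = StepAPICompatibility(
        gym.make("CartPole-v1"), output_truncation_bool=False
    )
    demo_env.reset(seed=0)
    _, _, done, _ = demo_env.step(demo_env.action_space.sample())
    print(done)  # a single boolean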

View File

@ -0,0 +1,79 @@
"""Wrapper for adding time aware observations to environment observation."""
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box
class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Augment the observation with the current time step in the episode.
The observation space of the wrapped environment is assumed to be a flat :class:`Box`.
In particular, pixel observations are not supported. This wrapper will append the current timestep within the current episode to the observation.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import TimeAwareObservation
>>> env = gym.make("CartPole-v1")
>>> env = TimeAwareObservation(env)
>>> env.reset(seed=42)
(array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 , 0. ]), {})
>>> _ = env.action_space.seed(42)
>>> env.step(env.action_space.sample())[0]
array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476, 1. ])
"""
def __init__(self, env: gym.Env):
"""Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` observation space.
Args:
env: The environment to apply the wrapper
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.ObservationWrapper.__init__(self, env)
assert isinstance(env.observation_space, Box)
assert env.observation_space.dtype == np.float32
low = np.append(self.observation_space.low, 0.0)
high = np.append(self.observation_space.high, np.inf)
self.observation_space = Box(low, high, dtype=np.float32)
try:
self.is_vector_env = self.get_wrapper_attr("is_vector_env")
except AttributeError:
self.is_vector_env = False
def observation(self, observation):
"""Adds to the observation with the current time step.
Args:
observation: The observation to add the time step to
Returns:
The observation with the time step appended
"""
return np.append(observation, self.t)
def step(self, action):
"""Steps through the environment, incrementing the time step.
Args:
action: The action to take
Returns:
The environment's step using the action.
"""
self.t += 1
return super().step(action)
def reset(self, **kwargs):
"""Reset the environment setting the time to zero.
Args:
**kwargs: Kwargs to apply to env.reset()
Returns:
The reset environment
"""
self.t = 0
return super().reset(**kwargs)
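# A minimal sketch (assumes CartPole-v1): the appended last component of the
# observation counts the steps taken since the last reset.
if __name__ == "__main__":
    demo_env = TimeAwareObservation(gym.make("CartPole-v1"))
    demo_obs, _ = demo_env.reset(seed=0)
    print(demo_obs[-1])  # 0.0
    demo_obs, *_ = demo_env.step(demo_env.action_space.sample())
    print(demo_obs[-1])  # 1.0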

View File

@ -0,0 +1,89 @@
"""Wrapper for limiting the time steps of an environment."""
from __future__ import annotations
from copy import deepcopy
from typing import TYPE_CHECKING
import gymnasium as gym
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded.
If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import TimeLimit
>>> env = gym.make("CartPole-v1")
>>> env = TimeLimit(env, max_episode_steps=1000)
"""
def __init__(
self,
env: gym.Env,
max_episode_steps: int,
):
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
Args:
env: The environment to apply the wrapper
max_episode_steps: The maximum number of steps per episode, after which ``truncated=True`` is issued
"""
gym.utils.RecordConstructorArgs.__init__(
self, max_episode_steps=max_episode_steps
)
gym.Wrapper.__init__(self, env)
self._max_episode_steps = max_episode_steps
self._elapsed_steps = None
self._cached_spec: EnvSpec | None = None
def step(self, action):
"""Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate.
Args:
action: The environment step action
Returns:
The environment step ``(observation, reward, terminated, truncated, info)`` with `truncated=True`
if the number of steps elapsed >= max episode steps
"""
observation, reward, terminated, truncated, info = self.env.step(action)
self._elapsed_steps += 1
if self._elapsed_steps >= self._max_episode_steps:
truncated = True
return observation, reward, terminated, truncated, info
def reset(self, **kwargs):
"""Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.
Args:
**kwargs: The kwargs to reset the environment with
Returns:
The reset environment
"""
self._elapsed_steps = 0
return self.env.reset(**kwargs)
@property
def spec(self) -> EnvSpec | None:
"""Modifies the environment spec to include the `max_episode_steps=self._max_episode_steps`."""
if self._cached_spec is not None:
return self._cached_spec
env_spec = self.env.spec
if env_spec is not None:
env_spec = deepcopy(env_spec)
env_spec.max_episode_steps = self._max_episode_steps
self._cached_spec = env_spec
return env_spec
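# A hedged sketch (assumes Pendulum-v1, which never terminates on its own):
# after max_episode_steps steps, step() reports truncated=True while
# terminated stays False.
if __name__ == "__main__":
    demo_env = TimeLimit(gym.make("Pendulum-v1"), max_episode_steps=5)
    demo_env.reset(seed=0)
    for _ in range(5):
        _, _, terminated, truncated, _ = demo_env.step(demo_env.action_space.sample())
    print(terminated, truncated)  # False True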

View File

@ -0,0 +1,47 @@
"""Wrapper for transforming observations."""
from typing import Any, Callable
import gymnasium as gym
class TransformObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Transform the observation via an arbitrary function :attr:`f`.
The function :attr:`f` should be defined on the observation space of the base environment, ``env``, and should, ideally, return values in the same space.
If the transformation you wish to apply to observations returns values in a *different* space, you should subclass :class:`ObservationWrapper`, implement the transformation, and set the new observation space accordingly. If you were to use this wrapper instead, the observation space would be set incorrectly.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import TransformObservation
>>> import numpy as np
>>> np.random.seed(0)
>>> env = gym.make("CartPole-v1")
>>> env = TransformObservation(env, lambda obs: obs + 0.1 * np.random.randn(*obs.shape))
>>> env.reset(seed=42)
(array([0.20380084, 0.03390356, 0.13373359, 0.24382612]), {})
"""
def __init__(self, env: gym.Env, f: Callable[[Any], Any]):
"""Initialize the :class:`TransformObservation` wrapper with an environment and a transform function :attr:`f`.
Args:
env: The environment to apply the wrapper
f: A function that transforms the observation
"""
gym.utils.RecordConstructorArgs.__init__(self, f=f)
gym.ObservationWrapper.__init__(self, env)
assert callable(f)
self.f = f
def observation(self, observation):
"""Transforms the observations with callable :attr:`f`.
Args:
observation: The observation to transform
Returns:
The transformed observation
"""
return self.f(observation)
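# When the transform changes the observation space, the docstring above
# advises subclassing ObservationWrapper instead. A hedged sketch of that
# pattern (the class name and scaling factor are illustrative only):
if __name__ == "__main__":
    import numpy as np
    from gymnasium.spaces import Box

    class ScaledObservation(gym.ObservationWrapper):
        """Scales observations and declares the matching observation space."""

        def __init__(self, env: gym.Env, scale: float):
            super().__init__(env)
            self.scale = scale
            old = env.observation_space
            # Set the new space explicitly so it matches the transform.
            self.observation_space = Box(
                old.low * scale, old.high * scale, dtype=old.dtype
            )

        def observation(self, observation):
            return observation * self.scale

    demo_env = ScaledObservation(gym.make("CartPole-v1"), 0.1)
    demo_obs, _ = demo_env.reset(seed=0)
    print(demo_obs)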

View File

@ -0,0 +1,46 @@
"""Wrapper for transforming the reward."""
from typing import Callable
import gymnasium as gym
class TransformReward(gym.RewardWrapper, gym.utils.RecordConstructorArgs):
"""Transform the reward via an arbitrary function.
Warning:
If the base environment specifies a reward range which is not invariant under :attr:`f`, the :attr:`reward_range` of the wrapped environment will be incorrect.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import TransformReward
>>> env = gym.make("CartPole-v1")
>>> env = TransformReward(env, lambda r: 0.01*r)
>>> _ = env.reset()
>>> observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
>>> reward
0.01
"""
def __init__(self, env: gym.Env, f: Callable[[float], float]):
"""Initialize the :class:`TransformReward` wrapper with an environment and reward transform function :attr:`f`.
Args:
env: The environment to apply the wrapper
f: A function that transforms the reward
"""
gym.utils.RecordConstructorArgs.__init__(self, f=f)
gym.RewardWrapper.__init__(self, env)
assert callable(f)
self.f = f
def reward(self, reward):
"""Transforms the reward using callable :attr:`f`.
Args:
reward: The reward to transform
Returns:
The transformed reward
"""
return self.f(reward)
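# A common use, sketched with hedging (assumes CartPole-v1): clip rewards to
# [-1, 1], as is often done for DQN-style training.
if __name__ == "__main__":
    import numpy as np

    demo_env = TransformReward(
        gym.make("CartPole-v1"), lambda r: float(np.clip(r, -1, 1))
    )
    demo_env.reset(seed=0)
    _, reward, *_ = demo_env.step(demo_env.action_space.sample())
    print(reward)  # 1.0; CartPole's +1 reward is within the clip range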

View File

@ -0,0 +1,126 @@
"""Wrapper that converts the info format for vec envs into the list format."""
from typing import List
import gymnasium as gym
class VectorListInfo(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Converts infos of vectorized environments from dict to List[dict].
This wrapper converts the info format of a
vector environment from a dictionary to a list of dictionaries.
This wrapper is intended to be used around vectorized
environments. If using other wrappers that perform
operations on info, such as `RecordEpisodeStatistics`, this
needs to be the outermost wrapper,
i.e. `VectorListInfo(RecordEpisodeStatistics(envs))`.
Example:
>>> # As dict:
>>> infos = {
... "final_observation": "<array of length num-envs>",
... "_final_observation": "<boolean array of length num-envs>",
... "final_info": "<array of length num-envs>",
... "_final_info": "<boolean array of length num-envs>",
... "episode": {
... "r": "<array of cumulative reward>",
... "l": "<array of episode length>",
... "t": "<array of elapsed time since beginning of episode>"
... },
... "_episode": "<boolean array of length num-envs>"
... }
>>> # As list:
>>> infos = [
... {
... "episode": {"r": "<cumulative reward>", "l": "<episode length>", "t": "<elapsed time since beginning of episode>"},
... "final_observation": "<observation>",
... "final_info": {},
... },
... ...,
... ]
"""
def __init__(self, env):
"""This wrapper will convert the info into the list format.
Args:
env (Env): The environment to apply the wrapper
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
try:
self.get_wrapper_attr("is_vector_env")
except AttributeError:
assert False, "This wrapper can only be used in vectorized environments."
def step(self, action):
"""Steps through the environment, convert dict info to list."""
observation, reward, terminated, truncated, infos = self.env.step(action)
list_info = self._convert_info_to_list(infos)
return observation, reward, terminated, truncated, list_info
def reset(self, **kwargs):
"""Resets the environment using kwargs."""
obs, infos = self.env.reset(**kwargs)
list_info = self._convert_info_to_list(infos)
return obs, list_info
def _convert_info_to_list(self, infos: dict) -> List[dict]:
"""Convert the dict info to list.
Convert the dict info of the vectorized environment
into a list of dictionaries where the i-th dictionary
has the info of the i-th environment.
Args:
infos (dict): info dict coming from the env.
Returns:
list_info (list): converted info.
"""
list_info = [{} for _ in range(self.num_envs)]
list_info = self._process_episode_statistics(infos, list_info)
for k in infos:
if k.startswith("_"):
continue
for i, has_info in enumerate(infos[f"_{k}"]):
if has_info:
list_info[i][k] = infos[k][i]
return list_info
def _process_episode_statistics(self, infos: dict, list_info: list) -> List[dict]:
"""Process episode statistics.
The `RecordEpisodeStatistics` wrapper adds extra
information to the info. This information is in
the form of a dict of dicts. This method processes this
information and adds it to the list of infos.
`RecordEpisodeStatistics` info contains the keys
"r", "l", "t" which represents "cumulative reward",
"episode length", "elapsed time since instantiation of wrapper".
Args:
infos (dict): infos coming from `RecordEpisodeStatistics`.
list_info (list): info of the current vectorized environment.
Returns:
list_info (list): updated info.
"""
episode_statistics = infos.pop("episode", False)
if not episode_statistics:
return list_info
episode_statistics_mask = infos.pop("_episode")
for i, has_info in enumerate(episode_statistics_mask):
if has_info:
list_info[i]["episode"] = {}
list_info[i]["episode"]["r"] = episode_statistics["r"][i]
list_info[i]["episode"]["l"] = episode_statistics["l"][i]
list_info[i]["episode"]["t"] = episode_statistics["t"][i]
return list_info
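# A hedged sketch (assumes CartPole-v1 and that `gym.vector.make` is
# available in this version): wrap a vector env so step/reset return one
# info dict per sub-environment instead of a dict of arrays.
if __name__ == "__main__":
    demo_envs = VectorListInfo(gym.vector.make("CartPole-v1", num_envs=2))
    _, list_info = demo_envs.reset(seed=0)
    print(list_info)  # [{}, {}] until a sub-env reports info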