I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,164 @@
"""`__init__` for experimental wrappers. To avoid importing wrappers (and their optional dependencies) unnecessarily, some are loaded lazily via the module-level ``__getattr__``."""
# pyright: reportUnsupportedDunderAll=false
import importlib
import re
from gymnasium.error import DeprecatedWrapper
from gymnasium.experimental.wrappers import vector
from gymnasium.experimental.wrappers.atari_preprocessing import AtariPreprocessingV0
from gymnasium.experimental.wrappers.common import (
AutoresetV0,
OrderEnforcingV0,
PassiveEnvCheckerV0,
RecordEpisodeStatisticsV0,
)
from gymnasium.experimental.wrappers.lambda_action import (
ClipActionV0,
LambdaActionV0,
RescaleActionV0,
)
from gymnasium.experimental.wrappers.lambda_observation import (
DtypeObservationV0,
FilterObservationV0,
FlattenObservationV0,
GrayscaleObservationV0,
LambdaObservationV0,
PixelObservationV0,
RescaleObservationV0,
ReshapeObservationV0,
ResizeObservationV0,
)
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
from gymnasium.experimental.wrappers.rendering import (
HumanRenderingV0,
RecordVideoV0,
RenderCollectionV0,
)
from gymnasium.experimental.wrappers.stateful_action import StickyActionV0
from gymnasium.experimental.wrappers.stateful_observation import (
DelayObservationV0,
FrameStackObservationV0,
MaxAndSkipObservationV0,
NormalizeObservationV0,
TimeAwareObservationV0,
)
from gymnasium.experimental.wrappers.stateful_reward import NormalizeRewardV1
# Todo - Add legacy wrapper to new wrapper error for users when merged into gymnasium.wrappers
__all__ = [
"vector",
# --- Observation wrappers ---
"AtariPreprocessingV0",
"DelayObservationV0",
"DtypeObservationV0",
"FilterObservationV0",
"FlattenObservationV0",
"FrameStackObservationV0",
"GrayscaleObservationV0",
"LambdaObservationV0",
"MaxAndSkipObservationV0",
"NormalizeObservationV0",
"PixelObservationV0",
"ResizeObservationV0",
"ReshapeObservationV0",
"RescaleObservationV0",
"TimeAwareObservationV0",
# --- Action Wrappers ---
"ClipActionV0",
"LambdaActionV0",
"RescaleActionV0",
# "NanAction",
"StickyActionV0",
# --- Reward wrappers ---
"ClipRewardV0",
"LambdaRewardV0",
"NormalizeRewardV1",
# --- Common ---
"AutoresetV0",
"PassiveEnvCheckerV0",
"OrderEnforcingV0",
"RecordEpisodeStatisticsV0",
# --- Rendering ---
"RenderCollectionV0",
"RecordVideoV0",
"HumanRenderingV0",
# --- Conversion ---
"JaxToNumpyV0",
"JaxToTorchV0",
"NumpyToTorchV0",
]
# As these wrappers require `jax` or `torch`, they are loaded lazily at runtime when users access them,
# to avoid `import jax` or `import torch` on `import gymnasium`.
_wrapper_to_class = {
# data converters
"JaxToNumpyV0": "jax_to_numpy",
"JaxToTorchV0": "jax_to_torch",
"NumpyToTorchV0": "numpy_to_torch",
}
def __getattr__(wrapper_name: str):
"""Load a wrapper by name.
This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used.
Errors will be raised if the wrapper does not exist or if the version is not the latest.
Args:
wrapper_name: The name of a wrapper to load.
Returns:
The specified wrapper.
Raises:
AttributeError: If the wrapper does not exist.
DeprecatedWrapper: If the version is not the latest.
"""
# Check if the requested wrapper is in the _wrapper_to_class dictionary
if wrapper_name in _wrapper_to_class:
import_stmt = (
f"gymnasium.experimental.wrappers.{_wrapper_to_class[wrapper_name]}"
)
module = importlib.import_module(import_stmt)
return getattr(module, wrapper_name)
# Define a regex pattern to match the integer suffix (version number) of the wrapper
int_suffix_pattern = r"(\d+)$"
version_match = re.search(int_suffix_pattern, wrapper_name)
# If a version number is found, extract it and the base wrapper name
if version_match:
version = int(version_match.group())
base_name = wrapper_name[: -len(version_match.group())]
else:
version = float("inf")
base_name = wrapper_name[:-2]
# Filter the list of all wrappers to include only those with the same base name
matching_wrappers = [name for name in __all__ if name.startswith(base_name)]
# If no matching wrappers are found, raise an AttributeError
if not matching_wrappers:
raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}")
# Find the latest version of the matching wrappers
latest_wrapper = max(
matching_wrappers, key=lambda s: int(re.findall(int_suffix_pattern, s)[0])
)
latest_version = int(re.findall(int_suffix_pattern, latest_wrapper)[0])
# If the requested wrapper is an older version, raise a DeprecatedWrapper exception
if version < latest_version:
raise DeprecatedWrapper(
f"{wrapper_name!r} is now deprecated, use {latest_wrapper!r} instead.\n"
f"To see the changes made, go to "
f"https://gymnasium.farama.org/api/experimental/wrappers/#gymnasium.experimental.wrappers.{latest_wrapper}"
)
# If the requested version is invalid, raise an AttributeError
else:
raise AttributeError(
f"module {__name__!r} has no attribute {wrapper_name!r}, did you mean {latest_wrapper!r}"
)
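
A brief usage sketch of the lazy loading and version checking implemented above; only the ``__getattr__`` path is exercised here, so no optional dependency such as `jax` or `torch` is required, and the wrapper name is chosen because ``NormalizeRewardV1`` is the newest version listed in ``__all__``:

import gymnasium.experimental.wrappers as wrappers
from gymnasium.error import DeprecatedWrapper

try:
    # "NormalizeRewardV0" is older than the "NormalizeRewardV1" in `__all__`,
    # so `__getattr__` raises DeprecatedWrapper pointing at the newer name.
    wrappers.NormalizeRewardV0
except DeprecatedWrapper as err:
    print(err)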


@@ -0,0 +1,206 @@
"""Implementation of Atari 2600 Preprocessing following the guidelines of Machado et al., 2018."""
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box
__all__ = ["AtariPreprocessingV0"]
class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Atari 2600 preprocessing wrapper.
This class follows the guidelines in Machado et al. (2018),
"Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents".
Specifically, the following preprocessing stages apply to the Atari environment:
- No-op reset: Obtains the initial state by taking a random number of no-ops on reset, by default a maximum of 30 no-ops.
- Frame skipping: The number of frames skipped between agent steps, 4 by default.
- Max-pooling: Pools over the most recent two observations from the skipped frames.
- Termination signal when a life is lost: When the agent loses a life during an episode, the environment terminates.
Turned off by default. Not recommended by Machado et al. (2018).
- Resize to a square image: Resizes the Atari environment's original observation shape from 210x160 to 84x84 by default.
- Grayscale observation: Whether the observation is returned in colour or greyscale; greyscale by default.
- Scale observation: Whether to scale the observation to the range [0, 1); not scaled by default.
"""
def __init__(
self,
env: gym.Env,
noop_max: int = 30,
frame_skip: int = 4,
screen_size: int = 84,
terminal_on_life_loss: bool = False,
grayscale_obs: bool = True,
grayscale_newaxis: bool = False,
scale_obs: bool = False,
):
"""Wrapper for Atari 2600 preprocessing.
Args:
env (Env): The environment to apply the preprocessing
noop_max (int): For no-op reset, the maximum number of no-op actions taken on reset; set to 0 to turn it off.
frame_skip (int): The number of environment frames skipped between returned observations, i.e. the frequency at which the agent experiences the game.
screen_size (int): resize Atari frame
terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
life is lost.
grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation
is returned.
grayscale_newaxis (bool): `if True and grayscale_obs=True`, then a channel axis is added to
grayscale observations to make them 3-dimensional.
scale_obs (bool): if True, then observation normalized in range [0,1) is returned. It also limits memory
optimization benefits of FrameStack Wrapper.
Raises:
DependencyNotInstalled: opencv-python package not installed
ValueError: Disable frame-skipping in the original env
"""
gym.utils.RecordConstructorArgs.__init__(
self,
noop_max=noop_max,
frame_skip=frame_skip,
screen_size=screen_size,
terminal_on_life_loss=terminal_on_life_loss,
grayscale_obs=grayscale_obs,
grayscale_newaxis=grayscale_newaxis,
scale_obs=scale_obs,
)
gym.Wrapper.__init__(self, env)
try:
import cv2 # noqa: F401
except ImportError:
raise gym.error.DependencyNotInstalled(
"opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
)
assert frame_skip > 0
assert screen_size > 0
assert noop_max >= 0
if frame_skip > 1:
if (
env.spec is not None
and "NoFrameskip" not in env.spec.id
and getattr(env.unwrapped, "_frameskip", None) != 1
):
raise ValueError(
"Disable frame-skipping in the original env. Otherwise, more than one "
"frame-skip will happen, as this wrapper also applies frame-skipping"
)
self.noop_max = noop_max
assert env.unwrapped.get_action_meanings()[0] == "NOOP"
self.frame_skip = frame_skip
self.screen_size = screen_size
self.terminal_on_life_loss = terminal_on_life_loss
self.grayscale_obs = grayscale_obs
self.grayscale_newaxis = grayscale_newaxis
self.scale_obs = scale_obs
# buffer of most recent two observations for max pooling
assert isinstance(env.observation_space, Box)
if grayscale_obs:
self.obs_buffer = [
np.empty(env.observation_space.shape[:2], dtype=np.uint8),
np.empty(env.observation_space.shape[:2], dtype=np.uint8),
]
else:
self.obs_buffer = [
np.empty(env.observation_space.shape, dtype=np.uint8),
np.empty(env.observation_space.shape, dtype=np.uint8),
]
self.lives = 0
self.game_over = False
_low, _high, _obs_dtype = (
(0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
)
_shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
if grayscale_obs and not grayscale_newaxis:
_shape = _shape[:-1] # Remove channel axis
self.observation_space = Box(
low=_low, high=_high, shape=_shape, dtype=_obs_dtype
)
@property
def ale(self):
"""Make ale as a class property to avoid serialization error."""
return self.env.unwrapped.ale
def step(self, action):
"""Applies the preprocessing for an :meth:`env.step`."""
total_reward, terminated, truncated, info = 0.0, False, False, {}
for t in range(self.frame_skip):
_, reward, terminated, truncated, info = self.env.step(action)
total_reward += reward
self.game_over = terminated
if self.terminal_on_life_loss:
new_lives = self.ale.lives()
terminated = terminated or new_lives < self.lives
self.game_over = terminated
self.lives = new_lives
if terminated or truncated:
break
if t == self.frame_skip - 2:
if self.grayscale_obs:
self.ale.getScreenGrayscale(self.obs_buffer[1])
else:
self.ale.getScreenRGB(self.obs_buffer[1])
elif t == self.frame_skip - 1:
if self.grayscale_obs:
self.ale.getScreenGrayscale(self.obs_buffer[0])
else:
self.ale.getScreenRGB(self.obs_buffer[0])
return self._get_obs(), total_reward, terminated, truncated, info
def reset(self, **kwargs):
"""Resets the environment using preprocessing."""
# NoopReset
_, reset_info = self.env.reset(**kwargs)
noops = (
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
if self.noop_max > 0
else 0
)
for _ in range(noops):
_, _, terminated, truncated, step_info = self.env.step(0)
reset_info.update(step_info)
if terminated or truncated:
_, reset_info = self.env.reset(**kwargs)
self.lives = self.ale.lives()
if self.grayscale_obs:
self.ale.getScreenGrayscale(self.obs_buffer[0])
else:
self.ale.getScreenRGB(self.obs_buffer[0])
self.obs_buffer[1].fill(0)
return self._get_obs(), reset_info
def _get_obs(self):
if self.frame_skip > 1: # more efficient in-place pooling
np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])
import cv2
obs = cv2.resize(
self.obs_buffer[0],
(self.screen_size, self.screen_size),
interpolation=cv2.INTER_AREA,
)
if self.scale_obs:
obs = np.asarray(obs, dtype=np.float32) / 255.0
else:
obs = np.asarray(obs, dtype=np.uint8)
if self.grayscale_obs and self.grayscale_newaxis:
obs = np.expand_dims(obs, axis=-1) # Add a channel axis
return obs
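
A minimal usage sketch for the wrapper above. It assumes `ale-py` is installed and registered so that an Atari id such as ``ALE/Breakout-v5`` is available (the id is illustrative), and it disables the environment's built-in frame skipping so the wrapper can perform it instead:

import gymnasium as gym
from gymnasium.experimental.wrappers import AtariPreprocessingV0

# frameskip=1 turns off ALE's own frame skipping; the wrapper then applies
# frame skipping, max-pooling, grayscaling and resizing itself.
env = gym.make("ALE/Breakout-v5", frameskip=1)
env = AtariPreprocessingV0(env, frame_skip=4, screen_size=84, grayscale_obs=True)

obs, info = env.reset(seed=0)
print(obs.shape)  # (84, 84) with the defaults used here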


@@ -0,0 +1,315 @@
"""A collection of common wrappers.
* ``AutoresetV0`` - Auto-resets the environment
* ``PassiveEnvCheckerV0`` - Passive environment checker that does not modify any environment data
* ``OrderEnforcingV0`` - Enforces the order of function calls to environments
* ``RecordEpisodeStatisticsV0`` - Records the episode statistics
"""
from __future__ import annotations
import time
from collections import deque
from typing import Any, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.error import ResetNeeded
from gymnasium.utils.passive_env_checker import (
check_action_space,
check_observation_space,
env_render_passive_checker,
env_reset_passive_checker,
env_step_passive_checker,
)
__all__ = [
"AutoresetV0",
"PassiveEnvCheckerV0",
"OrderEnforcingV0",
"RecordEpisodeStatisticsV0",
]
class AutoresetV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`."""
def __init__(self, env: gym.Env[ObsType, ActType]):
"""A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
Args:
env (gym.Env): The environment to apply the wrapper
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
self._episode_ended: bool = False
self._reset_options: dict[str, Any] | None = None
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered in the previous step.
Args:
action: The action to take
Returns:
The autoreset environment :meth:`step`
"""
if self._episode_ended:
obs, info = self.env.reset(options=self._reset_options)
self._episode_ended = False
return obs, 0, False, False, info
else:
obs, reward, terminated, truncated, info = super().step(action)
self._episode_ended = terminated or truncated
return obs, reward, terminated, truncated, info
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment, saving the options used."""
self._episode_ended = False
self._reset_options = options
return super().reset(seed=seed, options=self._reset_options)
class PassiveEnvCheckerV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
def __init__(self, env: gym.Env[ObsType, ActType]):
"""Initialises the wrapper with the environments, run the observation and action space tests."""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
assert hasattr(
env, "action_space"
), "The environment must specify an action space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
check_action_space(env.action_space)
assert hasattr(
env, "observation_space"
), "The environment must specify an observation space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
check_observation_space(env.observation_space)
self._checked_reset: bool = False
self._checked_step: bool = False
self._checked_render: bool = False
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment that on the first call will run the `passive_env_step_check`."""
if self._checked_step is False:
self._checked_step = True
return env_step_passive_checker(self.env, action)
else:
return self.env.step(action)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment that on the first call will run the `passive_env_reset_check`."""
if self._checked_reset is False:
self._checked_reset = True
return env_reset_passive_checker(self.env, seed=seed, options=options)
else:
return self.env.reset(seed=seed, options=options)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Renders the environment that on the first call will run the `passive_env_render_check`."""
if self._checked_render is False:
self._checked_render = True
return env_render_passive_checker(self.env)
else:
return self.env.render()
class OrderEnforcingV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import OrderEnforcingV0
>>> env = gym.make("CartPole-v1", render_mode="human")
>>> env = OrderEnforcingV0(env)
>>> env.step(0)
Traceback (most recent call last):
...
gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset()
>>> env.render()
Traceback (most recent call last):
...
gymnasium.error.ResetNeeded: Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.
>>> _ = env.reset()
>>> env.render()
>>> _ = env.step(0)
>>> env.close()
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
disable_render_order_enforcing: bool = False,
):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Args:
env: The environment to wrap
disable_render_order_enforcing: If to disable render order enforcing
"""
gym.utils.RecordConstructorArgs.__init__(
self, disable_render_order_enforcing=disable_render_order_enforcing
)
gym.Wrapper.__init__(self, env)
self._has_reset: bool = False
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
"""Steps through the environment."""
if not self._has_reset:
raise ResetNeeded("Cannot call env.step() before calling env.reset()")
return super().step(action)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment with `kwargs`."""
self._has_reset = True
return super().reset(seed=seed, options=options)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Renders the environment with `kwargs`."""
if not self._disable_render_order_enforcing and not self._has_reset:
raise ResetNeeded(
"Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, "
"set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
)
return super().render()
@property
def has_reset(self):
"""Returns if the environment has been reset before."""
return self._has_reset
class RecordEpisodeStatisticsV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""This wrapper will keep track of cumulative rewards and episode lengths.
At the end of an episode, the statistics of the episode will be added to ``info``
using the key ``episode``. If using a vectorized environment also the key
``_episode`` is used which indicates whether the env at the respective index has
the episode statistics.
After the completion of an episode, ``info`` will look like this::
>>> info = {
... "episode": {
... "r": "<cumulative reward>",
... "l": "<episode length>",
... "t": "<elapsed time since beginning of episode>"
... },
... }
For vectorized environments, the output will be in the form of::
>>> infos = {
... "final_observation": "<array of length num-envs>",
... "_final_observation": "<boolean array of length num-envs>",
... "final_info": "<array of length num-envs>",
... "_final_info": "<boolean array of length num-envs>",
... "episode": {
... "r": "<array of cumulative reward>",
... "l": "<array of episode length>",
... "t": "<array of elapsed time since beginning of episode>"
... },
... "_episode": "<boolean array of length num-envs>"
... }
Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
:attr:`wrapped_env.episode_reward_buffer` and :attr:`wrapped_env.episode_length_buffer` respectively.
Attributes:
episode_reward_buffer: The cumulative rewards of the last ``deque_size``-many episodes
episode_length_buffer: The lengths of the last ``deque_size``-many episodes
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
buffer_length: int | None = 100,
stats_key: str = "episode",
):
"""This wrapper will keep track of cumulative rewards and episode lengths.
Args:
env (Env): The environment to apply the wrapper
buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
stats_key: The info key for the episode statistics
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
self._stats_key = stats_key
self.episode_count = 0
self.episode_start_time: float = -1
self.episode_reward: float = -1
self.episode_length: int = -1
self.episode_time_length_buffer: deque[int] = deque(maxlen=buffer_length)
self.episode_reward_buffer: deque[float] = deque(maxlen=buffer_length)
self.episode_length_buffer: deque[int] = deque(maxlen=buffer_length)
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, recording the episode statistics."""
obs, reward, terminated, truncated, info = super().step(action)
self.episode_reward += reward
self.episode_length += 1
if terminated or truncated:
assert self._stats_key not in info
episode_time_length = np.round(
time.perf_counter() - self.episode_start_time, 6
)
info[self._stats_key] = {
"r": self.episode_reward,
"l": self.episode_length,
"t": episode_time_length,
}
self.episode_time_length_buffer.append(episode_time_length)
self.episode_reward_buffer.append(self.episode_reward)
self.episode_length_buffer.append(self.episode_length)
self.episode_count += 1
return obs, reward, terminated, truncated, info
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment using seed and options and resets the episode rewards and lengths."""
obs, info = super().reset(seed=seed, options=options)
self.episode_start_time = time.perf_counter()
self.episode_reward = 0
self.episode_length = 0
return obs, info
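
A short sketch of ``RecordEpisodeStatisticsV0`` on the classic ``CartPole-v1`` environment (assumed to be available), showing where the episode statistics end up after a full episode:

import gymnasium as gym
from gymnasium.experimental.wrappers import RecordEpisodeStatisticsV0

env = RecordEpisodeStatisticsV0(gym.make("CartPole-v1"), buffer_length=100)
obs, info = env.reset(seed=0)
terminated = truncated = False
while not (terminated or truncated):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())

# On the final step the wrapper adds the cumulative reward, length and elapsed time.
print(info["episode"])                # {'r': ..., 'l': ..., 't': ...}
print(env.episode_reward_buffer[-1])  # the same return, kept in the deque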


@@ -0,0 +1,162 @@
"""Helper functions and wrapper class for converting between numpy and Jax."""
from __future__ import annotations
import functools
import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
try:
import jax
import jax.numpy as jnp
except ImportError:
raise DependencyNotInstalled(
"Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install gymnasium[jax]`"
)
__all__ = ["JaxToNumpyV0", "jax_to_numpy", "numpy_to_jax"]
@functools.singledispatch
def numpy_to_jax(value: Any) -> Any:
"""Converts a value to a Jax Array."""
raise Exception(
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
)
@numpy_to_jax.register(numbers.Number)
def _number_to_jax(
value: numbers.Number,
) -> jax.Array:
"""Converts a number (int, float, etc.) to a Jax Array."""
assert jnp is not None
return jnp.array(value)
@numpy_to_jax.register(np.ndarray)
def _numpy_array_to_jax(value: np.ndarray) -> jax.Array:
"""Converts a NumPy Array to a Jax Array with the same dtype (excluding float64 without being enabled)."""
assert jnp is not None
return jnp.array(value, dtype=value.dtype)
@numpy_to_jax.register(abc.Mapping)
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a dictionary of numpy arrays to a mapping of Jax Array."""
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
@numpy_to_jax.register(abc.Iterable)
def _iterable_numpy_to_jax(
value: Iterable[np.ndarray | Any],
) -> Iterable[jax.Array | Any]:
"""Converts an Iterable from Numpy Arrays to an iterable of Jax Array."""
return type(value)(numpy_to_jax(v) for v in value)
@functools.singledispatch
def jax_to_numpy(value: Any) -> Any:
"""Converts a value to a numpy array."""
raise Exception(
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
)
@jax_to_numpy.register(jax.Array)
def _devicearray_jax_to_numpy(value: jax.Array) -> np.ndarray:
"""Converts a Jax Array to a numpy array."""
return np.array(value)
@jax_to_numpy.register(abc.Mapping)
def _mapping_jax_to_numpy(
value: Mapping[str, jax.Array | Any]
) -> Mapping[str, np.ndarray | Any]:
"""Converts a dictionary of Jax Array to a mapping of numpy arrays."""
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
@jax_to_numpy.register(abc.Iterable)
def _iterable_jax_to_numpy(
value: Iterable[jax.Array | Any],
) -> Iterable[np.ndarray | Any]:
"""Converts an Iterable of Jax Arrays to an iterable of numpy arrays."""
return type(value)(jax_to_numpy(v) for v in value)
class JaxToNumpyV0(
gym.Wrapper[WrapperObsType, WrapperActType, ObsType, ActType],
gym.utils.RecordConstructorArgs,
):
"""Wraps a jax environment so that it can be interacted with through numpy arrays.
Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
Notes:
The Jax to NumPy and NumPy to Jax conversions do not guarantee a roundtrip (jax -> numpy -> jax) and vice versa.
The reason is that jax does not support non-array values, therefore numpy ``np.int32(5) -> jnp.array(5, dtype=jnp.int32)``.
"""
def __init__(self, env: gym.Env[ObsType, ActType]):
"""Wraps a jax environment such that the input and outputs are numpy arrays.
Args:
env: the jax environment to wrap
"""
if jnp is None:
raise DependencyNotInstalled(
"jax is not installed, run `pip install gymnasium[jax]`"
)
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Transforms the action to a jax array.
Args:
action: the action to perform as a numpy array
Returns:
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
"""
jax_action = numpy_to_jax(action)
obs, reward, terminated, truncated, info = self.env.step(jax_action)
return (
jax_to_numpy(obs),
float(reward),
bool(terminated),
bool(truncated),
jax_to_numpy(info),
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment returning numpy-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
Numpy-based observations and info
"""
if options:
options = numpy_to_jax(options)
return jax_to_numpy(self.env.reset(seed=seed, options=options))
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Returns the rendered frames as a numpy array."""
return jax_to_numpy(self.env.render())
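
A small sketch of the conversion helpers defined above, assuming `jax` is installed; the dictionary here is an illustrative stand-in for an environment's observation or info structure:

import numpy as np
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax

batch = {"obs": np.zeros((4, 3), dtype=np.float32), "reward": 1.0}
jax_batch = numpy_to_jax(batch)      # values become jax.Array via singledispatch
roundtrip = jax_to_numpy(jax_batch)  # and back to numpy arrays
print(type(jax_batch["obs"]), roundtrip["obs"].dtype)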


@@ -0,0 +1,179 @@
# This wrapper will convert torch inputs for the actions and observations to Jax arrays
# for an underlying Jax environment then convert the return observations from Jax arrays
# back to torch tensors.
#
# Functionality for converting between torch and jax types originally copied from
# https://github.com/google/brax/blob/9d6b7ced2a13da0d074b5e9fbd3aad8311e26997/brax/io/torch.py
# Under the Apache 2.0 license. Copyright is held by the authors
"""Helper functions and wrapper class for converting between PyTorch and Jax."""
from __future__ import annotations
import functools
import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat, Union
import gymnasium as gym
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy
try:
import jax
import jax.numpy as jnp
from jax import dlpack as jax_dlpack
except ImportError:
raise DependencyNotInstalled(
"Jax is not installed therefore cannot call `torch_to_jax`, run `pip install gymnasium[jax]`"
)
try:
import torch
from torch.utils import dlpack as torch_dlpack
Device = Union[str, torch.device]
except ImportError:
raise DependencyNotInstalled(
"Torch is not installed therefore cannot call `torch_to_jax`, run `pip install torch`"
)
__all__ = ["JaxToTorchV0", "jax_to_torch", "torch_to_jax", "Device"]
@functools.singledispatch
def torch_to_jax(value: Any) -> Any:
"""Converts a PyTorch Tensor into a Jax Array."""
raise Exception(
f"No known conversion for Torch type ({type(value)}) to Jax registered. Report as issue on github."
)
@torch_to_jax.register(numbers.Number)
def _number_torch_to_jax(value: numbers.Number) -> Any:
"""Convert a python number (int, float, complex) to a jax array."""
return jnp.array(value)
@torch_to_jax.register(torch.Tensor)
def _tensor_torch_to_jax(value: torch.Tensor) -> jax.Array:
"""Converts a PyTorch Tensor into a Jax Array."""
tensor = torch_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
tensor = jax_dlpack.from_dlpack(tensor) # pyright: ignore[reportPrivateImportUsage]
return tensor
@torch_to_jax.register(abc.Mapping)
def _mapping_torch_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax Array."""
return type(value)(**{k: torch_to_jax(v) for k, v in value.items()})
@torch_to_jax.register(abc.Iterable)
def _iterable_torch_to_jax(value: Iterable[Any]) -> Iterable[Any]:
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax Array."""
return type(value)(torch_to_jax(v) for v in value)
@functools.singledispatch
def jax_to_torch(value: Any, device: Device | None = None) -> Any:
"""Converts a Jax Array into a PyTorch Tensor."""
raise Exception(
f"No known conversion for Jax type ({type(value)}) to PyTorch registered. Report as issue on github."
)
@jax_to_torch.register(jax.Array)
def _devicearray_jax_to_torch(
value: jax.Array, device: Device | None = None
) -> torch.Tensor:
"""Converts a Jax Array into a PyTorch Tensor."""
assert jax_dlpack is not None and torch_dlpack is not None
dlpack = jax_dlpack.to_dlpack(value) # pyright: ignore[reportPrivateImportUsage]
tensor = torch_dlpack.from_dlpack(dlpack)
if device:
return tensor.to(device=device)
return tensor
@jax_to_torch.register(abc.Mapping)
def _jax_mapping_to_torch(
value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
"""Converts a mapping of Jax Array into a Dictionary of PyTorch Tensors."""
return type(value)(**{k: jax_to_torch(v, device) for k, v in value.items()})
@jax_to_torch.register(abc.Iterable)
def _jax_iterable_to_torch(
value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
"""Converts an Iterable from Jax Array to an iterable of PyTorch Tensors."""
return type(value)(jax_to_torch(v, device) for v in value)
class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Wraps a Jax-based environment so that it can be interacted with through PyTorch Tensors.
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
Note:
The ``render`` output is returned as a NumPy array, not a PyTorch Tensor.
"""
def __init__(self, env: gym.Env, device: Device | None = None):
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
Args:
env: The Jax-based environment to wrap
device: The device the torch Tensors should be moved to
"""
gym.utils.RecordConstructorArgs.__init__(self, device=device)
gym.Wrapper.__init__(self, env)
self.device: Device | None = device
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Performs the given action within the environment.
Args:
action: The action to perform as a PyTorch Tensor
Returns:
The next observation, reward, termination, truncation, and extra info
"""
jax_action = torch_to_jax(action)
obs, reward, terminated, truncated, info = self.env.step(jax_action)
return (
jax_to_torch(obs, self.device),
float(reward),
bool(terminated),
bool(truncated),
jax_to_torch(info, self.device),
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment returning PyTorch-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
PyTorch-based observations and info
"""
if options:
options = torch_to_jax(options)
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Returns the rendered frames as a NumPy array."""
return jax_to_numpy(self.env.render())
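
A small sketch of the torch/jax helpers above, assuming both `torch` and `jax` are installed; the tensor is an illustrative stand-in for an action:

import torch
from gymnasium.experimental.wrappers.jax_to_torch import jax_to_torch, torch_to_jax

action = torch.tensor([0.1, -0.2, 0.3])
jax_action = torch_to_jax(action)              # handed over through dlpack
back = jax_to_torch(jax_action, device="cpu")  # optionally moved to a target device
print(back.dtype, back.shape)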


@@ -0,0 +1,178 @@
"""A collection of wrappers that all use the LambdaAction class.
* ``LambdaActionV0`` - Transforms the actions based on a function
* ``ClipActionV0`` - Clips the action to lie within the action space bounds
* ``RescaleActionV0`` - Rescales the action to lie within a minimum and maximum action
"""
from __future__ import annotations
from typing import Callable
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType, WrapperActType
from gymnasium.spaces import Box, Space
__all__ = ["LambdaActionV0", "ClipActionV0", "RescaleActionV0"]
class LambdaActionV0(
gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
):
"""A wrapper that provides a function to modify the action passed to :meth:`step`."""
def __init__(
self,
env: gym.Env[ObsType, ActType],
func: Callable[[WrapperActType], ActType],
action_space: Space[WrapperActType] | None,
):
"""Initialize LambdaAction.
Args:
env: The environment to wrap
func: Function to apply to the :meth:`step`'s ``action``
action_space: The updated action space of the wrapper given the function.
"""
gym.utils.RecordConstructorArgs.__init__(
self, func=func, action_space=action_space
)
gym.Wrapper.__init__(self, env)
if action_space is not None:
self.action_space = action_space
self.func = func
def action(self, action: WrapperActType) -> ActType:
"""Apply function to action."""
return self.func(action)
class ClipActionV0(
LambdaActionV0[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
):
"""Clip the continuous action within the valid :class:`Box` action space bounds.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import ClipActionV0
>>> import numpy as np
>>> env = gym.make("Hopper-v4", disable_env_checker=True)
>>> env = ClipActionV0(env)
>>> env.action_space
Box(-inf, inf, (3,), float32)
>>> _ = env.reset(seed=42)
>>> _ = env.step(np.array([5.0, -2.0, 0.0], dtype=np.float32))
... # Executes the action np.array([1.0, -1.0, 0]) in the base environment
"""
def __init__(self, env: gym.Env[ObsType, ActType]):
"""A wrapper for clipping continuous actions within the valid bound.
Args:
env: The environment to wrap
"""
assert isinstance(env.action_space, Box)
gym.utils.RecordConstructorArgs.__init__(self)
LambdaActionV0.__init__(
self,
env=env,
func=lambda action: np.clip(
action, env.action_space.low, env.action_space.high
),
action_space=Box(
-np.inf,
np.inf,
shape=env.action_space.shape,
dtype=env.action_space.dtype,
),
)
class RescaleActionV0(
LambdaActionV0[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import RescaleActionV0
>>> import numpy as np
>>> env = gym.make("Hopper-v4", disable_env_checker=True)
>>> _ = env.reset(seed=42)
>>> obs, _, _, _, _ = env.step(np.array([1, 1, 1], dtype=np.float32))
>>> _ = env.reset(seed=42)
>>> min_action = -0.5
>>> max_action = np.array([0.0, 0.5, 0.75], dtype=np.float32)
>>> wrapped_env = RescaleActionV0(env, min_action=min_action, max_action=max_action)
>>> wrapped_env_obs, _, _, _, _ = wrapped_env.step(max_action)
>>> np.alltrue(obs == wrapped_env_obs)
True
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
min_action: float | int | np.ndarray,
max_action: float | int | np.ndarray,
):
"""Constructor for the Rescale Action wrapper.
Args:
env (Env): The environment to wrap
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
"""
gym.utils.RecordConstructorArgs.__init__(
self, min_action=min_action, max_action=max_action
)
assert isinstance(env.action_space, Box)
assert not np.any(np.isinf(env.action_space.low)) and not np.any(
np.isinf(env.action_space.high)
)
if not isinstance(min_action, np.ndarray):
assert np.issubdtype(type(min_action), np.integer) or np.issubdtype(
type(min_action), np.floating
)
min_action = np.full(env.action_space.shape, min_action)
assert min_action.shape == env.action_space.shape
assert not np.any(min_action == np.inf)
if not isinstance(max_action, np.ndarray):
assert np.issubdtype(type(max_action), np.integer) or np.issubdtype(
type(max_action), np.floating
)
max_action = np.full(env.action_space.shape, max_action)
assert max_action.shape == env.action_space.shape
assert not np.any(max_action == np.inf)
assert isinstance(env.action_space, Box)
assert np.all(np.less_equal(min_action, max_action))
# Imagine the x-axis between the old Box and the y-axis being the new Box
gradient = (env.action_space.high - env.action_space.low) / (
max_action - min_action
)
intercept = gradient * -min_action + env.action_space.low
LambdaActionV0.__init__(
self,
env=env,
func=lambda action: gradient * action + intercept,
action_space=Box(
low=min_action,
high=max_action,
shape=env.action_space.shape,
dtype=env.action_space.dtype,
),
)
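
The two subclasses above already include doctest examples; as a sketch of the base ``LambdaActionV0`` itself, the following halves every action before it reaches the environment (the env id is illustrative, any ``Box``-action environment works):

import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import LambdaActionV0

env = gym.make("MountainCarContinuous-v0")
# Halve each action; the action space itself is unchanged, so None is passed for it.
env = LambdaActionV0(env, lambda action: 0.5 * action, None)
_ = env.reset(seed=0)
_ = env.step(np.array([1.0], dtype=np.float32))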


@@ -0,0 +1,620 @@
"""A collection of observation wrappers using a lambda function.
* ``LambdaObservationV0`` - Transforms the observation with a function
* ``FilterObservationV0`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys
* ``FlattenObservationV0`` - Flattens the observations
* ``GrayscaleObservationV0`` - Converts an RGB observation to a grayscale observation
* ``ResizeObservationV0`` - Resizes an array-based observation (normally an RGB observation)
* ``ReshapeObservationV0`` - Reshapes an array-based observation
* ``RescaleObservationV0`` - Rescales an observation to between a minimum and maximum value
* ``DtypeObservationV0`` - Converts an observation to a given dtype
* ``PixelObservationV0`` - Adds the rendered frame to the observation
"""
from __future__ import annotations
from typing import Any, Callable, Final, Sequence
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from gymnasium.core import ActType, ObsType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
__all__ = [
"LambdaObservationV0",
"FilterObservationV0",
"FlattenObservationV0",
"GrayscaleObservationV0",
"ResizeObservationV0",
"ReshapeObservationV0",
"RescaleObservationV0",
"DtypeObservationV0",
"PixelObservationV0",
]
class LambdaObservationV0(
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Transforms an observation via a function provided to the wrapper.
The function :attr:`func` will be applied to all observations.
If the observations from :attr:`func` are outside the bounds of the ``env``'s observation space, provide an :attr:`observation_space`.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import LambdaObservationV0
>>> import numpy as np
>>> np.random.seed(0)
>>> env = gym.make("CartPole-v1")
>>> env = LambdaObservationV0(env, lambda obs: obs + 0.1 * np.random.random(obs.shape), env.observation_space)
>>> env.reset(seed=42)
(array([0.08227695, 0.06540678, 0.09613613, 0.07422512]), {})
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
func: Callable[[ObsType], Any],
observation_space: gym.Space[WrapperObsType] | None,
):
"""Constructor for the lambda observation wrapper.
Args:
env: The environment to wrap
func: A function that will transform an observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an `observation_space`.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
"""
gym.utils.RecordConstructorArgs.__init__(
self, func=func, observation_space=observation_space
)
gym.ObservationWrapper.__init__(self, env)
if observation_space is not None:
self.observation_space = observation_space
self.func = func
def observation(self, observation: ObsType) -> Any:
"""Apply function to the observation."""
return self.func(observation)
class FilterObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Filters Dict or Tuple observation space by the keys or indexes.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import TransformObservation
>>> from gymnasium.experimental.wrappers import FilterObservationV0
>>> env = gym.make("CartPole-v1")
>>> env = gym.wrappers.TransformObservation(env, lambda obs: {'obs': obs, 'time': 0})
>>> env.observation_space = gym.spaces.Dict(obs=env.observation_space, time=gym.spaces.Discrete(1))
>>> env.reset(seed=42)
({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': 0}, {})
>>> env = FilterObservationV0(env, filter_keys=['time'])
>>> env.reset(seed=42)
({'time': 0}, {})
>>> env.step(0)
({'time': 0}, 1.0, False, False, {})
"""
def __init__(
self, env: gym.Env[ObsType, ActType], filter_keys: Sequence[str | int]
):
"""Constructor for the filter observation wrapper.
Args:
env: The environment to wrap
filter_keys: The subspaces to be included, use a list of strings or integers for ``Dict`` and ``Tuple`` spaces respectively
"""
assert isinstance(filter_keys, Sequence)
gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
# Filters for dictionary space
if isinstance(env.observation_space, spaces.Dict):
assert all(isinstance(key, str) for key in filter_keys)
if any(
key not in env.observation_space.spaces.keys() for key in filter_keys
):
missing_keys = [
key
for key in filter_keys
if key not in env.observation_space.spaces.keys()
]
raise ValueError(
"All the `filter_keys` must be included in the observation space.\n"
f"Filter keys: {filter_keys}\n"
f"Observation keys: {list(env.observation_space.spaces.keys())}\n"
f"Missing keys: {missing_keys}"
)
new_observation_space = spaces.Dict(
{key: env.observation_space[key] for key in filter_keys}
)
if len(new_observation_space) == 0:
raise ValueError(
"The observation space is empty due to filtering all keys."
)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: {key: obs[key] for key in filter_keys},
observation_space=new_observation_space,
)
# Filter for tuple observation
elif isinstance(env.observation_space, spaces.Tuple):
assert all(isinstance(key, int) for key in filter_keys)
assert len(set(filter_keys)) == len(
filter_keys
), f"Duplicate keys exist, filter_keys: {filter_keys}"
if any(
key < 0 or key >= len(env.observation_space) for key in filter_keys
):
missing_index = [
key
for key in filter_keys
if key < 0 or key >= len(env.observation_space)
]
raise ValueError(
"All the `filter_keys` must be included in the length of the observation space.\n"
f"Filter keys: {filter_keys}, length of observation: {len(env.observation_space)}, "
f"missing indexes: {missing_index}"
)
new_observation_spaces = spaces.Tuple(
env.observation_space[key] for key in filter_keys
)
if len(new_observation_spaces) == 0:
raise ValueError(
"The observation space is empty due to filtering all keys."
)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: tuple(obs[key] for key in filter_keys),
observation_space=new_observation_spaces,
)
else:
raise ValueError(
f"FilterObservation wrapper is only usable with `Dict` and `Tuple` observations, actual type: {type(env.observation_space)}"
)
self.filter_keys: Final[Sequence[str | int]] = filter_keys
class FlattenObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper that flattens the observation.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import FlattenObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
>>> env = FlattenObservationV0(env)
>>> env.observation_space.shape
(27648,)
>>> obs, _ = env.reset()
>>> obs.shape
(27648,)
"""
def __init__(self, env: gym.Env[ObsType, ActType]):
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.
Args:
env: The environment to wrap
"""
gym.utils.RecordConstructorArgs.__init__(self)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: spaces.utils.flatten(env.observation_space, obs),
observation_space=spaces.utils.flatten_space(env.observation_space),
)
class GrayscaleObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper that converts an RGB image to grayscale.
The :attr:`keep_dim` will keep the channel dimension
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import GrayscaleObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
>>> grayscale_env = GrayscaleObservationV0(env)
>>> grayscale_env.observation_space.shape
(96, 96)
>>> grayscale_env = GrayscaleObservationV0(env, keep_dim=True)
>>> grayscale_env.observation_space.shape
(96, 96, 1)
"""
def __init__(self, env: gym.Env[ObsType, ActType], keep_dim: bool = False):
"""Constructor for an RGB image based environments to make the image grayscale.
Args:
env: The environment to wrap
keep_dim: Whether to keep the channel dimension in the observation; if ``True`` the observation is 3-dimensional, otherwise 2-dimensional
"""
assert isinstance(env.observation_space, spaces.Box)
assert (
len(env.observation_space.shape) == 3
and env.observation_space.shape[-1] == 3
)
assert (
np.all(env.observation_space.low == 0)
and np.all(env.observation_space.high == 255)
and env.observation_space.dtype == np.uint8
)
gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)
self.keep_dim: Final[bool] = keep_dim
if keep_dim:
new_observation_space = spaces.Box(
low=0,
high=255,
shape=env.observation_space.shape[:2] + (1,),
dtype=np.uint8,
)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: np.expand_dims(
np.sum(
np.multiply(obs, np.array([0.2125, 0.7154, 0.0721])), axis=-1
).astype(np.uint8),
axis=-1,
),
observation_space=new_observation_space,
)
else:
new_observation_space = spaces.Box(
low=0, high=255, shape=env.observation_space.shape[:2], dtype=np.uint8
)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: np.sum(
np.multiply(obs, np.array([0.2125, 0.7154, 0.0721])), axis=-1
).astype(np.uint8),
observation_space=new_observation_space,
)
class ResizeObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Resizes image observations to a given shape using OpenCV.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import ResizeObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
>>> resized_env = ResizeObservationV0(env, (32, 32))
>>> resized_env.observation_space.shape
(32, 32, 3)
"""
def __init__(self, env: gym.Env[ObsType, ActType], shape: tuple[int, ...]):
"""Constructor that requires an image environment observation space with a shape.
Args:
env: The environment to wrap
shape: The resized observation shape
"""
assert isinstance(env.observation_space, spaces.Box)
assert len(env.observation_space.shape) in [2, 3]
assert np.all(env.observation_space.low == 0) and np.all(
env.observation_space.high == 255
)
assert env.observation_space.dtype == np.uint8
assert isinstance(shape, tuple)
assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
assert all(x > 0 for x in shape)
try:
import cv2
except ImportError as e:
raise DependencyNotInstalled(
"opencv (cv2) is not installed, run `pip install gymnasium[other]`"
) from e
self.shape: Final[tuple[int, ...]] = tuple(shape)
new_observation_space = spaces.Box(
low=0,
high=255,
shape=self.shape + env.observation_space.shape[2:],
dtype=np.uint8,
)
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: cv2.resize(obs, self.shape, interpolation=cv2.INTER_AREA),
observation_space=new_observation_space,
)
class ReshapeObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Reshapes array-based observations to a given shape.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import ReshapeObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape
(96, 96, 3)
>>> reshape_env = ReshapeObservationV0(env, (24, 4, 96, 1, 3))
>>> reshape_env.observation_space.shape
(24, 4, 96, 1, 3)
"""
def __init__(self, env: gym.Env[ObsType, ActType], shape: int | tuple[int, ...]):
"""Constructor for env with ``Box`` observation space that has a shape product equal to the new shape product.
Args:
env: The environment to wrap
shape: The reshaped observation space
"""
assert isinstance(env.observation_space, spaces.Box)
assert np.product(shape) == np.product(env.observation_space.shape)
assert isinstance(shape, tuple)
assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
assert all(x > 0 or x == -1 for x in shape)
new_observation_space = spaces.Box(
low=np.reshape(np.ravel(env.observation_space.low), shape),
high=np.reshape(np.ravel(env.observation_space.high), shape),
shape=shape,
dtype=env.observation_space.dtype,
)
self.shape = shape
gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: np.reshape(obs, shape),
observation_space=new_observation_space,
)
class RescaleObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Linearly rescales observation to between a minimum and maximum value.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import RescaleObservationV0
>>> env = gym.make("Pendulum-v1")
>>> env.observation_space
Box([-1. -1. -8.], [1. 1. 8.], (3,), float32)
>>> env = RescaleObservationV0(env, np.array([-2, -1, -10], dtype=np.float32), np.array([1, 0, 1], dtype=np.float32))
>>> env.observation_space
Box([ -2. -1. -10.], [1. 0. 1.], (3,), float32)
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
min_obs: np.floating | np.integer | np.ndarray,
max_obs: np.floating | np.integer | np.ndarray,
):
"""Constructor that requires the env observation spaces to be a :class:`Box`.
Args:
env: The environment to wrap
min_obs: The new minimum observation bound
max_obs: The new maximum observation bound
"""
assert isinstance(env.observation_space, spaces.Box)
assert not np.any(np.isinf(env.observation_space.low)) and not np.any(
np.isinf(env.observation_space.high)
)
if not isinstance(min_obs, np.ndarray):
assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
type(min_obs), np.floating
)
min_obs = np.full(env.observation_space.shape, min_obs)
assert (
min_obs.shape == env.observation_space.shape
), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}"
assert not np.any(min_obs == np.inf)
if not isinstance(max_obs, np.ndarray):
assert np.issubdtype(type(max_obs), np.integer) or np.issubdtype(
type(max_obs), np.floating
)
max_obs = np.full(env.observation_space.shape, max_obs)
assert max_obs.shape == env.observation_space.shape
assert not np.any(max_obs == np.inf)
self.min_obs = min_obs
self.max_obs = max_obs
# Imagine the x-axis between the old Box and the y-axis being the new Box
gradient = (max_obs - min_obs) / (
env.observation_space.high - env.observation_space.low
)
intercept = gradient * -env.observation_space.low + min_obs
gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: gradient * obs + intercept,
observation_space=spaces.Box(
low=min_obs,
high=max_obs,
shape=env.observation_space.shape,
dtype=env.observation_space.dtype,
),
)
class DtypeObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper for transforming the dtype of an observation.
Note:
This is only compatible with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces
"""
def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
"""Constructor for Dtype observation wrapper.
Args:
env: The environment to wrap
dtype: The new dtype of the observation
"""
assert isinstance(
env.observation_space,
(spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
)
self.dtype = dtype
if isinstance(env.observation_space, spaces.Box):
new_observation_space = spaces.Box(
low=env.observation_space.low,
high=env.observation_space.high,
shape=env.observation_space.shape,
dtype=self.dtype,
)
elif isinstance(env.observation_space, spaces.Discrete):
new_observation_space = spaces.Box(
low=env.observation_space.start,
high=env.observation_space.start + env.observation_space.n,
shape=(),
dtype=self.dtype,
)
elif isinstance(env.observation_space, spaces.MultiDiscrete):
new_observation_space = spaces.MultiDiscrete(
env.observation_space.nvec, dtype=dtype
)
elif isinstance(env.observation_space, spaces.MultiBinary):
new_observation_space = spaces.Box(
low=0,
high=1,
shape=env.observation_space.shape,
dtype=self.dtype,
)
else:
raise TypeError(
"DtypeObservation is only compatible with value / array-based observations."
)
gym.utils.RecordConstructorArgs.__init__(self, dtype=dtype)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: np.asarray(obs, dtype=dtype),
observation_space=new_observation_space,
)
class PixelObservationV0(
LambdaObservationV0[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Adds the rendered frame to the environment's observations.
Observations of this wrapper will be dictionaries of images.
You can also choose to add the observation of the base environment to this dictionary.
In that case, if the base environment has an observation space of type :class:`Dict`, the dictionary
of rendered images will be updated with the base environment's observation. If, however, the observation
space is of type :class:`Box`, the base environment's observation (which will be an element of the :class:`Box`
space) will be added to the dictionary under the key "state".
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
pixels_only: bool = True,
pixels_key: str = "pixels",
obs_key: str = "state",
):
"""Constructor of the pixel observation wrapper.
Args:
env: The environment to wrap.
pixels_only (bool): If ``True`` (default), the original observation returned
by the wrapped environment will be discarded, and a dictionary
observation will only include pixels. If ``False``, the
observation dictionary will contain both the original
observations and the pixel observations.
pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels"
obs_key: Optional custom string specifying the obs key. Defaults to "state"
"""
gym.utils.RecordConstructorArgs.__init__(
self, pixels_only=pixels_only, pixels_key=pixels_key, obs_key=obs_key
)
assert env.render_mode is not None and env.render_mode != "human"
env.reset()
pixels = env.render()
assert pixels is not None and isinstance(pixels, np.ndarray)
pixel_space = spaces.Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8)
if pixels_only:
obs_space = pixel_space
LambdaObservationV0.__init__(
self, env=env, func=lambda _: self.render(), observation_space=obs_space
)
elif isinstance(env.observation_space, spaces.Dict):
assert pixels_key not in env.observation_space.spaces.keys()
obs_space = spaces.Dict(
{pixels_key: pixel_space, **env.observation_space.spaces}
)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: {pixels_key: self.render(), **obs},
observation_space=obs_space,
)
else:
obs_space = spaces.Dict(
{obs_key: env.observation_space, pixels_key: pixel_space}
)
LambdaObservationV0.__init__(
self,
env=env,
func=lambda obs: {obs_key: obs, pixels_key: self.render()},
observation_space=obs_space,
)
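
Most wrappers in this file carry doctest examples; ``DtypeObservationV0`` and ``PixelObservationV0`` do not, so here is a rough sketch for both on ``CartPole-v1`` (the printed frame shape is the environment's usual ``rgb_array`` size and may differ elsewhere):

import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import DtypeObservationV0, PixelObservationV0

# Cast CartPole's float32 observations to float64.
env = DtypeObservationV0(gym.make("CartPole-v1"), dtype=np.float64)
obs, _ = env.reset(seed=0)
print(obs.dtype)  # float64

# Replace the observation with the rendered frame; render_mode must not be "human".
pixel_env = PixelObservationV0(gym.make("CartPole-v1", render_mode="rgb_array"))
pixels, _ = pixel_env.reset(seed=0)
print(pixels.shape)  # typically (400, 600, 3) for CartPole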


@@ -0,0 +1,102 @@
"""A collection of wrappers for modifying the reward.
* ``LambdaRewardV0`` - Transforms the reward by a function
* ``ClipRewardV0`` - Clips the reward between a minimum and maximum value
"""
from __future__ import annotations
from typing import Callable, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.error import InvalidBound
__all__ = ["LambdaRewardV0", "ClipRewardV0"]
class LambdaRewardV0(
gym.RewardWrapper[ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""A reward wrapper that allows a custom function to modify the step reward.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import LambdaRewardV0
>>> env = gym.make("CartPole-v1")
>>> env = LambdaRewardV0(env, lambda r: 2 * r + 1)
>>> _ = env.reset()
>>> _, rew, _, _, _ = env.step(0)
>>> rew
3.0
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
func: Callable[[SupportsFloat], SupportsFloat],
):
"""Initialize LambdaRewardV0 wrapper.
Args:
env (Env): The environment to wrap
func: (Callable): The function to apply to reward
"""
gym.utils.RecordConstructorArgs.__init__(self, func=func)
gym.RewardWrapper.__init__(self, env)
self.func = func
def reward(self, reward: SupportsFloat) -> SupportsFloat:
"""Apply function to reward.
Args:
reward (Union[float, int, np.ndarray]): environment's reward
"""
return self.func(reward)
class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructorArgs):
"""A wrapper that clips the rewards for an environment between an upper and lower bound.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import ClipRewardV0
>>> env = gym.make("CartPole-v1")
>>> env = ClipRewardV0(env, 0, 0.5)
>>> _ = env.reset()
>>> _, rew, _, _, _ = env.step(1)
>>> rew
0.5
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
min_reward: float | np.ndarray | None = None,
max_reward: float | np.ndarray | None = None,
):
"""Initialize ClipRewardsV0 wrapper.
Args:
env (Env): The environment to wrap
min_reward (Union[float, np.ndarray]): lower bound to apply
max_reward (Union[float, np.ndarray]): higher bound to apply
"""
if min_reward is None and max_reward is None:
raise InvalidBound("Both `min_reward` and `max_reward` cannot be None")
elif max_reward is not None and min_reward is not None:
if np.any(max_reward - min_reward < 0):
raise InvalidBound(
f"Min reward ({min_reward}) must be smaller than max reward ({max_reward})"
)
gym.utils.RecordConstructorArgs.__init__(
self, min_reward=min_reward, max_reward=max_reward
)
LambdaRewardV0.__init__(
self, env=env, func=lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)
)
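A brief sketch showing that ``ClipRewardV0`` also accepts ``np.ndarray`` bounds, since it delegates to ``np.clip``; the environment and bounds are placeholders.

import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import ClipRewardV0

env = gym.make("CartPole-v1")
env = ClipRewardV0(env, min_reward=-1.0, max_reward=np.array(0.5))
_ = env.reset(seed=0)
_, reward, _, _, _ = env.step(env.action_space.sample())
assert reward <= 0.5  # CartPole's reward of 1.0 is clipped to the upper bound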

View File

@ -0,0 +1,148 @@
"""Helper functions and wrapper class for converting between PyTorch and NumPy."""
from __future__ import annotations
import functools
import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat, Union
import numpy as np
import gymnasium as gym
from gymnasium.core import WrapperActType, WrapperObsType
from gymnasium.error import DependencyNotInstalled
try:
import torch
Device = Union[str, torch.device]
except ImportError:
raise DependencyNotInstalled(
"Torch is not installed therefore cannot call `torch_to_numpy`, run `pip install torch`"
)
__all__ = ["NumpyToTorchV0", "torch_to_numpy", "numpy_to_torch"]
@functools.singledispatch
def torch_to_numpy(value: Any) -> Any:
"""Converts a PyTorch Tensor into a NumPy Array."""
raise Exception(
f"No known conversion for Torch type ({type(value)}) to NumPy registered. Report as issue on github."
)
@torch_to_numpy.register(numbers.Number)
@torch_to_numpy.register(torch.Tensor)
def _number_torch_to_numpy(value: numbers.Number | torch.Tensor) -> Any:
"""Convert a python number (int, float, complex) and torch.Tensor to a numpy array."""
return np.array(value)
@torch_to_numpy.register(abc.Mapping)
def _mapping_torch_to_numpy(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax Array."""
return type(value)(**{k: torch_to_numpy(v) for k, v in value.items()})
@torch_to_numpy.register(abc.Iterable)
def _iterable_torch_to_numpy(value: Iterable[Any]) -> Iterable[Any]:
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax Array."""
return type(value)(torch_to_numpy(v) for v in value)
@functools.singledispatch
def numpy_to_torch(value: Any, device: Device | None = None) -> Any:
"""Converts a Jax Array into a PyTorch Tensor."""
raise Exception(
f"No known conversion for NumPy type ({type(value)}) to PyTorch registered. Report as issue on github."
)
@numpy_to_torch.register(np.ndarray)
def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Tensor:
"""Converts a Jax Array into a PyTorch Tensor."""
assert torch is not None
tensor = torch.tensor(value)
if device:
return tensor.to(device=device)
return tensor
@numpy_to_torch.register(abc.Mapping)
def _numpy_mapping_to_torch(
value: Mapping[str, Any], device: Device | None = None
) -> Mapping[str, Any]:
"""Converts a mapping of Jax Array into a Dictionary of PyTorch Tensors."""
return type(value)(**{k: numpy_to_torch(v, device) for k, v in value.items()})
@numpy_to_torch.register(abc.Iterable)
def _numpy_iterable_to_torch(
value: Iterable[Any], device: Device | None = None
) -> Iterable[Any]:
"""Converts an Iterable from Jax Array to an iterable of PyTorch Tensors."""
return type(value)(numpy_to_torch(v, device) for v in value)
class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
"""Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors.
Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors.
Note:
        Rendered frames are returned as NumPy arrays, not PyTorch Tensors.
"""
def __init__(self, env: gym.Env, device: Device | None = None):
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
Args:
            env: The NumPy-based environment to wrap
device: The device the torch Tensors should be moved to
"""
gym.utils.RecordConstructorArgs.__init__(self, device=device)
gym.Wrapper.__init__(self, env)
self.device: Device | None = device
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Using a PyTorch based action that is converted to NumPy to be used by the environment.
Args:
action: A PyTorch-based action
Returns:
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
"""
        numpy_action = torch_to_numpy(action)
        obs, reward, terminated, truncated, info = self.env.step(numpy_action)
return (
numpy_to_torch(obs, self.device),
float(reward),
bool(terminated),
bool(truncated),
numpy_to_torch(info, self.device),
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment returning PyTorch-based observation and info.
Args:
seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to NumPy arrays.
Returns:
PyTorch-based observations and info
"""
if options:
options = torch_to_numpy(options)
return numpy_to_torch(self.env.reset(seed=seed, options=options), self.device)
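An illustrative sketch of the wrapper above, assuming ``torch`` is installed; ``CartPole-v1`` is a stand-in environment.

import torch
import gymnasium as gym
from gymnasium.experimental.wrappers import NumpyToTorchV0

env = NumpyToTorchV0(gym.make("CartPole-v1"), device="cpu")
obs, info = env.reset(seed=0)   # observations come back as torch.Tensor
action = torch.tensor(0)        # actions are given as tensors and converted via torch_to_numpy
obs, reward, terminated, truncated, info = env.step(action)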

View File

@ -0,0 +1,437 @@
"""A collections of rendering-based wrappers.
* ``RenderCollectionV0`` - Collects rendered frames into a list
* ``RecordVideoV0`` - Records a video of the environments
* ``HumanRenderingV0`` - Provides human rendering of environments with ``"rgb_array"``
"""
from __future__ import annotations
import os
from copy import deepcopy
from typing import Any, Callable, List, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium import error, logger
from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.error import DependencyNotInstalled
__all__ = ["RenderCollectionV0", "RecordVideoV0", "HumanRenderingV0"]
class RenderCollectionV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""Collect rendered frames of an environment such ``render`` returns a ``list[RenderedFrame]``."""
def __init__(
self,
env: gym.Env[ObsType, ActType],
pop_frames: bool = True,
reset_clean: bool = True,
):
"""Initialize a :class:`RenderCollection` instance.
Args:
env: The environment that is being wrapped
            pop_frames (bool): If true, clear the collection frames after :meth:`render` is called. Default value is ``True``.
            reset_clean (bool): If true, clear the collection frames when :meth:`reset` is called. Default value is ``True``.
"""
gym.utils.RecordConstructorArgs.__init__(
self, pop_frames=pop_frames, reset_clean=reset_clean
)
gym.Wrapper.__init__(self, env)
assert env.render_mode is not None
assert not env.render_mode.endswith("_list")
self.frame_list: list[RenderFrame] = []
self.pop_frames = pop_frames
self.reset_clean = reset_clean
self.metadata = deepcopy(self.env.metadata)
if f"{self.env.render_mode}_list" not in self.metadata["render_modes"]:
self.metadata["render_modes"].append(f"{self.env.render_mode}_list")
@property
def render_mode(self):
"""Returns the collection render_mode name."""
return f"{self.env.render_mode}_list"
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Perform a step in the base environment and collect a frame."""
output = super().step(action)
self.frame_list.append(super().render())
return output
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Reset the base environment, eventually clear the frame_list, and collect a frame."""
output = super().reset(seed=seed, options=options)
if self.reset_clean:
self.frame_list = []
self.frame_list.append(super().render())
return output
def render(self) -> list[RenderFrame]:
"""Returns the collection of frames and, if pop_frames = True, clears it."""
frames = self.frame_list
if self.pop_frames:
self.frame_list = []
return frames
class RecordVideoV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""This wrapper records videos of rollouts.
Usually, you only want to record episodes intermittently, say every hundredth episode.
To do this, you can specify ``episode_trigger`` or ``step_trigger``.
They should be functions returning a boolean that indicates whether a recording should be started at the
current episode or step, respectively.
    If neither :attr:`episode_trigger` nor ``step_trigger`` is passed, a default ``episode_trigger`` will be employed,
    i.e. ``capped_cubic_video_schedule``. This function starts a video at every episode whose index is a perfect cube
    (0, 1, 8, 27, ...) until episode 1000, and then every 1000 episodes.
By default, the recording will be stopped once reset is called. However, you can also create recordings of fixed
length (possibly spanning several episodes) by passing a strictly positive value for ``video_length``.
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
video_folder: str,
episode_trigger: Callable[[int], bool] | None = None,
step_trigger: Callable[[int], bool] | None = None,
video_length: int = 0,
name_prefix: str = "rl-video",
fps: int | None = None,
disable_logger: bool = False,
):
"""Wrapper records videos of rollouts.
Args:
env: The environment that will be wrapped
video_folder (str): The folder where the recordings will be stored
episode_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this episode
step_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this step
video_length (int): The length of recorded episodes. If 0, entire episodes are recorded.
Otherwise, snippets of the specified length are captured
name_prefix (str): Will be prepended to the filename of the recordings
fps (int): The frame per second in the video. The default value is the one specified in the environment metadata.
If the environment metadata doesn't specify `render_fps`, the value 30 is used.
disable_logger (bool): Whether to disable moviepy logger or not
"""
gym.utils.RecordConstructorArgs.__init__(
self,
video_folder=video_folder,
episode_trigger=episode_trigger,
step_trigger=step_trigger,
video_length=video_length,
name_prefix=name_prefix,
disable_logger=disable_logger,
)
gym.Wrapper.__init__(self, env)
if env.render_mode in {None, "human", "ansi"}:
            raise ValueError(
                f"Render mode is {env.render_mode}, which is incompatible with RecordVideo. "
                "Initialize your environment with a render_mode that returns an image, such as rgb_array."
            )
if episode_trigger is None and step_trigger is None:
def capped_cubic_video_schedule(episode_id: int) -> bool:
if episode_id < 1000:
return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
else:
return episode_id % 1000 == 0
episode_trigger = capped_cubic_video_schedule
self.episode_trigger = episode_trigger
self.step_trigger = step_trigger
self.disable_logger = disable_logger
self.video_folder = os.path.abspath(video_folder)
if os.path.isdir(self.video_folder):
logger.warn(
f"Overwriting existing videos at {self.video_folder} folder "
f"(try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)"
)
os.makedirs(self.video_folder, exist_ok=True)
if fps is None:
fps = self.metadata.get("render_fps", 30)
self.frames_per_sec: int = fps
self.name_prefix: str = name_prefix
self._video_name: str | None = None
self.video_length: int = video_length if video_length != 0 else float("inf")
self.recording: bool = False
self.recorded_frames: list[RenderFrame] = []
self.render_history: list[RenderFrame] = []
self.step_id = -1
self.episode_id = -1
try:
import moviepy # noqa: F401
except ImportError as e:
raise error.DependencyNotInstalled(
"MoviePy is not installed, run `pip install moviepy`"
) from e
def _capture_frame(self):
assert self.recording, "Cannot capture a frame, recording wasn't started."
frame = self.env.render()
if isinstance(frame, List):
            if len(frame) == 0:  # the frame buffer is empty, render() was already called
return
self.render_history += frame
frame = frame[-1]
if isinstance(frame, np.ndarray):
self.recorded_frames.append(frame)
else:
self.stop_recording()
            logger.warn(
                "Recording stopped: expected type of frame returned by render "
                f"to be a numpy array, got instead {type(frame)}."
            )
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Reset the environment and eventually starts a new recording."""
obs, info = super().reset(seed=seed, options=options)
self.episode_id += 1
if self.recording and self.video_length == float("inf"):
self.stop_recording()
if self.episode_trigger and self.episode_trigger(self.episode_id):
self.start_recording(f"{self.name_prefix}-episode-{self.episode_id}")
if self.recording:
self._capture_frame()
if len(self.recorded_frames) > self.video_length:
self.stop_recording()
return obs, info
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment using action, recording observations if :attr:`self.recording`."""
obs, rew, terminated, truncated, info = self.env.step(action)
self.step_id += 1
if self.step_trigger and self.step_trigger(self.step_id):
self.start_recording(f"{self.name_prefix}-step-{self.step_id}")
if self.recording:
self._capture_frame()
if len(self.recorded_frames) > self.video_length:
self.stop_recording()
return obs, rew, terminated, truncated, info
def start_recording(self, video_name: str):
"""Start a new recording. If it is already recording, stops the current recording before starting the new one."""
if self.recording:
self.stop_recording()
self.recording = True
self._video_name = video_name
def stop_recording(self):
"""Stop current recording and saves the video."""
assert self.recording, "stop_recording was called, but no recording was started"
if len(self.recorded_frames) == 0:
logger.warn("Ignored saving a video as there were zero frames to save.")
else:
try:
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
except ImportError as e:
raise error.DependencyNotInstalled(
"MoviePy is not installed, run `pip install moviepy`"
) from e
clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec)
moviepy_logger = None if self.disable_logger else "bar"
path = os.path.join(self.video_folder, f"{self._video_name}.mp4")
clip.write_videofile(path, logger=moviepy_logger)
self.recorded_frames = []
self.recording = False
self._video_name = None
def render(self) -> RenderFrame | list[RenderFrame]:
"""Compute the render frames as specified by render_mode attribute during initialization of the environment."""
render_out = super().render()
if self.recording and isinstance(render_out, List):
self.recorded_frames += render_out
if len(self.render_history) > 0:
tmp_history = self.render_history
self.render_history = []
return tmp_history + render_out
else:
return render_out
def close(self):
"""Closes the wrapper then the video recorder."""
super().close()
if self.recording:
self.stop_recording()
def __del__(self):
"""Warn the user in case last video wasn't saved."""
if len(self.recorded_frames) > 0:
logger.warn("Unable to save last video! Did you call close()?")
class HumanRenderingV0(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
"""Performs human rendering for an environment that only supports "rgb_array"rendering.
This wrapper is particularly useful when you have implemented an environment that can produce
RGB images but haven't implemented any code to render the images to the screen.
If you want to use this wrapper with your environments, remember to specify ``"render_fps"``
in the metadata of your environment.
The ``render_mode`` of the wrapped environment must be either ``'rgb_array'`` or ``'rgb_array_list'``.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import HumanRenderingV0
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
>>> wrapped = HumanRenderingV0(env)
>>> obs, _ = wrapped.reset() # This will start rendering to the screen
The wrapper can also be applied directly when the environment is instantiated, simply by passing
``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not
implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``).
>>> env = gym.make("phys2d/CartPole-v1", render_mode="human") # CartPoleJax-v1 doesn't implement human-rendering natively
>>> obs, _ = env.reset() # This will start rendering to the screen
Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method
will always return an empty list:
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list")
>>> wrapped = HumanRenderingV0(env)
>>> obs, _ = wrapped.reset()
>>> env.render() # env.render() will always return an empty list!
[]
"""
def __init__(self, env: gym.Env[ObsType, ActType]):
"""Initialize a :class:`HumanRendering` instance.
Args:
env: The environment that is being wrapped
"""
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
assert env.render_mode in [
"rgb_array",
"rgb_array_list",
], f"Expected env.render_mode to be one of 'rgb_array' or 'rgb_array_list' but got '{env.render_mode}'"
assert (
"render_fps" in env.metadata
), "The base environment must specify 'render_fps' to be used with the HumanRendering wrapper"
self.screen_size = None
self.window = None
self.clock = None
if "human" not in self.metadata["render_modes"]:
self.metadata = deepcopy(self.env.metadata)
self.metadata["render_modes"].append("human")
@property
def render_mode(self):
"""Always returns ``'human'``."""
return "human"
def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
"""Perform a step in the base environment and render a frame to the screen."""
result = super().step(action)
self._render_frame()
return result
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Reset the base environment and render a frame to the screen."""
result = super().reset(seed=seed, options=options)
self._render_frame()
return result
def render(self) -> None:
"""This method doesn't do much, actual rendering is performed in :meth:`step` and :meth:`reset`."""
return None
def _render_frame(self):
"""Fetch the last frame from the base environment and render it to the screen."""
try:
import pygame
except ImportError:
raise DependencyNotInstalled(
"pygame is not installed, run `pip install gymnasium[box2d]`"
)
if self.env.render_mode == "rgb_array_list":
last_rgb_array = self.env.render()
assert isinstance(last_rgb_array, list)
last_rgb_array = last_rgb_array[-1]
elif self.env.render_mode == "rgb_array":
last_rgb_array = self.env.render()
else:
raise Exception(
f"Wrapped environment must have mode 'rgb_array' or 'rgb_array_list', actual render mode: {self.env.render_mode}"
)
assert isinstance(last_rgb_array, np.ndarray)
rgb_array = np.transpose(last_rgb_array, axes=(1, 0, 2))
if self.screen_size is None:
self.screen_size = rgb_array.shape[:2]
assert (
self.screen_size == rgb_array.shape[:2]
), f"The shape of the rgb array has changed from {self.screen_size} to {rgb_array.shape[:2]}"
if self.window is None:
pygame.init()
pygame.display.init()
self.window = pygame.display.set_mode(self.screen_size)
if self.clock is None:
self.clock = pygame.time.Clock()
surf = pygame.surfarray.make_surface(rgb_array)
self.window.blit(surf, (0, 0))
pygame.event.pump()
self.clock.tick(self.metadata["render_fps"])
pygame.display.flip()
def close(self):
"""Close the rendering window."""
if self.window is not None:
import pygame
pygame.display.quit()
pygame.quit()
super().close()
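A usage sketch of the recording wrapper above; the folder, trigger, and environment are illustrative assumptions and ``moviepy`` must be installed.

import gymnasium as gym
from gymnasium.experimental.wrappers import RecordVideoV0

env = gym.make("CartPole-v1", render_mode="rgb_array")
# Record a 100-frame snippet starting every 1000 environment steps into ./videos.
env = RecordVideoV0(
    env, video_folder="videos", step_trigger=lambda step: step % 1000 == 0, video_length=100
)
obs, info = env.reset(seed=0)
for _ in range(200):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        obs, info = env.reset()
env.close()  # closing also flushes any unfinished recording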

View File

@ -0,0 +1,62 @@
"""``StickyAction`` wrapper - There is a probability that the action is taken again."""
from __future__ import annotations
from typing import Any
import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.error import InvalidProbability
__all__ = ["StickyActionV0"]
class StickyActionV0(
gym.ActionWrapper[ObsType, ActType, ActType], gym.utils.RecordConstructorArgs
):
"""Wrapper which adds a probability of repeating the previous action.
This wrapper follows the implementation proposed by `Machado et al., 2018 <https://arxiv.org/pdf/1709.06009.pdf>`_
in Section 5.2 on page 12.
"""
def __init__(
self, env: gym.Env[ObsType, ActType], repeat_action_probability: float
):
"""Initialize StickyAction wrapper.
Args:
env (Env): the wrapped environment
repeat_action_probability (int | float): a probability of repeating the old action.
"""
if not 0 <= repeat_action_probability < 1:
raise InvalidProbability(
f"repeat_action_probability should be in the interval [0,1). Received {repeat_action_probability}"
)
gym.utils.RecordConstructorArgs.__init__(
self, repeat_action_probability=repeat_action_probability
)
gym.ActionWrapper.__init__(self, env)
self.repeat_action_probability = repeat_action_probability
self.last_action: ActType | None = None
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Reset the environment."""
self.last_action = None
return super().reset(seed=seed, options=options)
def action(self, action: ActType) -> ActType:
"""Execute the action."""
if (
self.last_action is not None
and self.np_random.uniform() < self.repeat_action_probability
):
action = self.last_action
self.last_action = action
return action
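A minimal sketch of the sticky-action behaviour; the probability and environment id are placeholders.

import gymnasium as gym
from gymnasium.experimental.wrappers import StickyActionV0

env = StickyActionV0(gym.make("CartPole-v1"), repeat_action_probability=0.25)
obs, info = env.reset(seed=0)
# With probability 0.25, the previously executed action is repeated instead of the one passed here.
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())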

View File

@ -0,0 +1,505 @@
"""A collection of stateful observation wrappers.
* ``DelayObservationV0`` - A wrapper for delaying the returned observation
* ``TimeAwareObservationV0`` - A wrapper for adding time aware observations to environment observation
* ``FrameStackObservationV0`` - Frame stack the observations
* ``NormalizeObservationV0`` - Normalizes the observations to zero mean and unit variance
* ``MaxAndSkipObservationV0`` - Return only every ``skip``-th frame (frameskipping) and return the max between the two last frames.
"""
from __future__ import annotations
from collections import deque
from copy import deepcopy
from typing import Any, Final, SupportsFloat
import numpy as np
import gymnasium as gym
import gymnasium.spaces as spaces
from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType
from gymnasium.experimental.vector.utils import (
batch_space,
concatenate,
create_empty_array,
)
from gymnasium.experimental.wrappers.utils import RunningMeanStd, create_zero_array
from gymnasium.spaces import Box, Dict, Tuple
__all__ = [
"DelayObservationV0",
"TimeAwareObservationV0",
"FrameStackObservationV0",
"NormalizeObservationV0",
]
class DelayObservationV0(
gym.ObservationWrapper[ObsType, ActType, ObsType], gym.utils.RecordConstructorArgs
):
"""Wrapper which adds a delay to the returned observation.
    Before reaching the :attr:`delay` number of timesteps, the returned observation is an array of zeros with
the same shape as the observation space.
Example:
>>> import gymnasium as gym
>>> env = gym.make("CartPole-v1")
>>> env.reset(seed=123)
(array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), {})
>>> env = DelayObservationV0(env, delay=2)
>>> env.reset(seed=123)
(array([0., 0., 0., 0.], dtype=float32), {})
>>> env.step(env.action_space.sample())
(array([0., 0., 0., 0.], dtype=float32), 1.0, False, False, {})
>>> env.step(env.action_space.sample())
(array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), 1.0, False, False, {})
Note:
This does not support random delay values, if users are interested, please raise an issue or pull request to add this feature.
"""
def __init__(self, env: gym.Env[ObsType, ActType], delay: int):
"""Initialises the DelayObservation wrapper with an integer.
Args:
env: The environment to wrap
delay: The number of timesteps to delay observations
"""
if not np.issubdtype(type(delay), np.integer):
raise TypeError(
f"The delay is expected to be an integer, actual type: {type(delay)}"
)
if not 0 <= delay:
raise ValueError(
f"The delay needs to be greater than zero, actual value: {delay}"
)
gym.utils.RecordConstructorArgs.__init__(self, delay=delay)
gym.ObservationWrapper.__init__(self, env)
self.delay: Final[int] = int(delay)
self.observation_queue: Final[deque] = deque()
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment, clearing the observation queue."""
self.observation_queue.clear()
return super().reset(seed=seed, options=options)
def observation(self, observation: ObsType) -> ObsType:
"""Return the delayed observation."""
self.observation_queue.append(observation)
if len(self.observation_queue) > self.delay:
return self.observation_queue.popleft()
else:
return create_zero_array(self.observation_space)
class TimeAwareObservationV0(
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""Augment the observation with time information of the episode.
    If :attr:`normalize_time` is ``True``, time is represented as a normalized value between [0, 1];
    otherwise, if ``False``, the number of timesteps remaining before truncation occurs is used as an integer.
For environments with ``Dict`` observation spaces, the time information is automatically
added in the key `"time"` (can be changed through :attr:`dict_time_key`) and for environments with ``Tuple``
observation space, the time information is added as the final element in the tuple.
Otherwise, the observation space is transformed into a ``Dict`` observation space with two keys,
`"obs"` for the base environment's observation and `"time"` for the time information.
To flatten the observation, use the :attr:`flatten` parameter which will use the
:func:`gymnasium.spaces.utils.flatten` function.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import TimeAwareObservationV0
>>> env = gym.make("CartPole-v1")
>>> env = TimeAwareObservationV0(env)
>>> env.observation_space
Dict('obs': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), 'time': Box(0.0, 1.0, (1,), float32))
>>> env.reset(seed=42)[0]
{'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': array([0.], dtype=float32)}
>>> _ = env.action_space.seed(42)
>>> env.step(env.action_space.sample())[0]
{'obs': array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476], dtype=float32), 'time': array([0.002], dtype=float32)}
Unnormalize time observation space example:
>>> env = gym.make('CartPole-v1')
>>> env = TimeAwareObservationV0(env, normalize_time=False)
>>> env.observation_space
Dict('obs': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), 'time': Box(0, 500, (1,), int32))
>>> env.reset(seed=42)[0]
{'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': array([500], dtype=int32)}
>>> _ = env.action_space.seed(42)[0]
>>> env.step(env.action_space.sample())[0]
{'obs': array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476], dtype=float32), 'time': array([499], dtype=int32)}
Flatten observation space example:
>>> env = gym.make("CartPole-v1")
>>> env = TimeAwareObservationV0(env, flatten=True)
>>> env.observation_space
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38
0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 1.0000000e+00], (5,), float32)
>>> env.reset(seed=42)[0]
array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 , 0. ],
dtype=float32)
>>> _ = env.action_space.seed(42)
>>> env.step(env.action_space.sample())[0]
array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476, 0.002 ],
dtype=float32)
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
flatten: bool = False,
normalize_time: bool = True,
*,
dict_time_key: str = "time",
):
"""Initialize :class:`TimeAwareObservationV0`.
Args:
env: The environment to apply the wrapper
flatten: Flatten the observation to a `Box` of a single dimension
normalize_time: if `True` return time in the range [0,1]
otherwise return time as remaining timesteps before truncation
dict_time_key: For environment with a ``Dict`` observation space, the key for the time space. By default, `"time"`.
"""
gym.utils.RecordConstructorArgs.__init__(
self,
flatten=flatten,
normalize_time=normalize_time,
dict_time_key=dict_time_key,
)
gym.ObservationWrapper.__init__(self, env)
self.flatten: Final[bool] = flatten
self.normalize_time: Final[bool] = normalize_time
# We don't need to keep if a TimeLimit wrapper exists as `spec` will do that work for us now
if env.spec is not None and env.spec.max_episode_steps is not None:
self.max_timesteps = env.spec.max_episode_steps
else:
raise ValueError(
"The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`."
)
self._timesteps: int = 0
# Find the normalized time space
if self.normalize_time:
self._time_preprocess_func = lambda time: np.array(
[time / self.max_timesteps], dtype=np.float32
)
time_space = Box(0.0, 1.0)
else:
self._time_preprocess_func = lambda time: np.array(
[self.max_timesteps - time], dtype=np.int32
)
time_space = Box(0, self.max_timesteps, dtype=np.int32)
# Find the observation space
if isinstance(env.observation_space, Dict):
assert dict_time_key not in env.observation_space.keys()
observation_space = Dict(
{dict_time_key: time_space, **env.observation_space.spaces}
)
self._append_data_func = lambda obs, time: {dict_time_key: time, **obs}
elif isinstance(env.observation_space, Tuple):
observation_space = Tuple(env.observation_space.spaces + (time_space,))
self._append_data_func = lambda obs, time: obs + (time,)
else:
observation_space = Dict(obs=env.observation_space, time=time_space)
self._append_data_func = lambda obs, time: {"obs": obs, "time": time}
# If to flatten the observation space
if self.flatten:
self.observation_space: gym.Space[WrapperObsType] = spaces.flatten_space(
observation_space
)
self._obs_postprocess_func = lambda obs: spaces.flatten(
observation_space, obs
)
else:
self.observation_space: gym.Space[WrapperObsType] = observation_space
self._obs_postprocess_func = lambda obs: obs
def observation(self, observation: ObsType) -> WrapperObsType:
"""Adds to the observation with the current time information.
Args:
observation: The observation to add the time step to
Returns:
The observation with the time information appended to it
"""
return self._obs_postprocess_func(
self._append_data_func(
observation, self._time_preprocess_func(self._timesteps)
)
)
def step(
self, action: ActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, incrementing the time step.
Args:
action: The action to take
Returns:
The environment's step using the action with the next observation containing the timestep info
"""
self._timesteps += 1
return super().step(action)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Reset the environment setting the time to zero.
Args:
seed: The seed to reset the environment
options: The options used to reset the environment
Returns:
            Resets the environment with the initial timestep info added to the observation
"""
self._timesteps = 0
return super().reset(seed=seed, options=options)
class FrameStackObservationV0(
gym.Wrapper[WrapperObsType, ActType, ObsType, ActType],
gym.utils.RecordConstructorArgs,
):
"""Observation wrapper that stacks the observations in a rolling manner.
For example, if the number of stacks is 4, then the returned observation contains
the most recent 4 observations. For environment 'Pendulum-v1', the original observation
is an array with shape [3], so if we stack 4 observations, the processed observation
has shape [4, 3].
Note:
        - After :meth:`reset` is called, the frame buffer is filled with ``stack_size - 1`` padding observations
          (``zeros_obs``, zeros by default) followed by the initial observation.
Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import FrameStackObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env = FrameStackObservationV0(env, 4)
>>> env.observation_space
Box(0, 255, (4, 96, 96, 3), uint8)
>>> obs, _ = env.reset()
>>> obs.shape
(4, 96, 96, 3)
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
stack_size: int,
*,
zeros_obs: ObsType | None = None,
):
"""Observation wrapper that stacks the observations in a rolling manner.
Args:
env: The environment to apply the wrapper
            stack_size: The number of frames to stack; ``zeros_obs`` is used to pad the buffer initially.
zeros_obs: Keyword only parameter that allows a custom padding observation at :meth:`reset`
"""
if not np.issubdtype(type(stack_size), np.integer):
raise TypeError(
f"The stack_size is expected to be an integer, actual type: {type(stack_size)}"
)
if not 1 < stack_size:
raise ValueError(
f"The stack_size needs to be greater than one, actual value: {stack_size}"
)
gym.utils.RecordConstructorArgs.__init__(self, stack_size=stack_size)
gym.Wrapper.__init__(self, env)
self.observation_space = batch_space(env.observation_space, n=stack_size)
self.stack_size: Final[int] = stack_size
self.zero_obs: Final[ObsType] = (
            zeros_obs if zeros_obs is not None else create_zero_array(env.observation_space)
)
self._stacked_obs = deque(
[self.zero_obs for _ in range(self.stack_size)], maxlen=self.stack_size
)
self._stacked_array = create_empty_array(
env.observation_space, n=self.stack_size
)
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, appending the observation to the frame buffer.
Args:
action: The action to step through the environment with
Returns:
Stacked observations, reward, terminated, truncated, and info from the environment
"""
obs, reward, terminated, truncated, info = super().step(action)
self._stacked_obs.append(obs)
updated_obs = deepcopy(
concatenate(
self.env.observation_space, self._stacked_obs, self._stacked_array
)
)
return updated_obs, reward, terminated, truncated, info
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Reset the environment, returning the stacked observation and info.
Args:
seed: The environment seed
options: The reset options
Returns:
The stacked observations and info
"""
obs, info = super().reset(seed=seed, options=options)
for _ in range(self.stack_size - 1):
self._stacked_obs.append(self.zero_obs)
self._stacked_obs.append(obs)
updated_obs = deepcopy(
concatenate(
self.env.observation_space, self._stacked_obs, self._stacked_array
)
)
return updated_obs, info
class NormalizeObservationV0(
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
gym.utils.RecordConstructorArgs,
):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
    The property `_update_running_mean` allows freezing/continuing the running mean calculation of the observation
statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.observation()` is called.
If `False`, the calculated statistics are used but not updated anymore; this may be used during evaluation.
Note:
The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was
newly instantiated or the policy was changed recently.
"""
def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
Args:
env (Env): The environment to apply the wrapper
epsilon: A stability parameter that is used when scaling the observations.
"""
gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
gym.ObservationWrapper.__init__(self, env)
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
self.epsilon = epsilon
self._update_running_mean = True
@property
def update_running_mean(self) -> bool:
"""Property to freeze/continue the running mean calculation of the observation statistics."""
return self._update_running_mean
@update_running_mean.setter
def update_running_mean(self, setting: bool):
"""Sets the property to freeze/continue the running mean calculation of the observation statistics."""
self._update_running_mean = setting
def observation(self, observation: ObsType) -> WrapperObsType:
"""Normalises the observation using the running mean and variance of the observations."""
if self._update_running_mean:
self.obs_rms.update(observation)
return (observation - self.obs_rms.mean) / np.sqrt(
self.obs_rms.var + self.epsilon
)
class MaxAndSkipObservationV0(
gym.Wrapper[WrapperObsType, ActType, ObsType, ActType],
gym.utils.RecordConstructorArgs,
):
"""This wrapper will return only every ``skip``-th frame (frameskipping) and return the max between the two last observations.
Note: This wrapper is based on the wrapper from stable-baselines3: https://stable-baselines3.readthedocs.io/en/master/_modules/stable_baselines3/common/atari_wrappers.html#MaxAndSkipEnv
"""
def __init__(self, env: gym.Env[ObsType, ActType], skip: int = 4):
"""This wrapper will return only every ``skip``-th frame (frameskipping) and return the max between the two last frames.
Args:
env (Env): The environment to apply the wrapper
skip: The number of frames to skip
"""
gym.utils.RecordConstructorArgs.__init__(self, skip=skip)
gym.Wrapper.__init__(self, env)
if not np.issubdtype(type(skip), np.integer):
raise TypeError(
f"The skip is expected to be an integer, actual type: {type(skip)}"
)
if skip < 2:
raise ValueError(
f"The skip value needs to be equal or greater than two, actual value: {skip}"
)
if env.observation_space.shape is None:
raise ValueError("The observation space must have the shape attribute.")
self._skip = skip
self._obs_buffer = np.zeros(
(2, *env.observation_space.shape), dtype=env.observation_space.dtype
)
def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Step the environment with the given action for ``skip`` steps.
Repeat action, sum reward, and max over last observations.
Args:
action: The action to step through the environment with
Returns:
Max of the last two observations, reward, terminated, truncated, and info from the environment
"""
total_reward = 0.0
terminated = truncated = False
info = {}
for i in range(self._skip):
obs, reward, terminated, truncated, info = self.env.step(action)
done = terminated or truncated
if i == self._skip - 2:
self._obs_buffer[0] = obs
if i == self._skip - 1:
self._obs_buffer[1] = obs
total_reward += float(reward)
if done:
break
max_frame = self._obs_buffer.max(axis=0)
return max_frame, total_reward, terminated, truncated, info
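An illustrative sketch chaining two of the wrappers above; ``CartPole-v1`` and ``skip=4`` are assumptions used only to show the API.

import gymnasium as gym
from gymnasium.experimental.wrappers import (
    MaxAndSkipObservationV0,
    NormalizeObservationV0,
)

env = gym.make("CartPole-v1")
env = MaxAndSkipObservationV0(env, skip=4)   # repeat each action 4 times, max over the last two frames
env = NormalizeObservationV0(env)            # normalise observations with a running mean/variance
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.update_running_mean = False              # freeze the statistics, e.g. for evaluation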

View File

@ -0,0 +1,85 @@
"""A collection of wrappers for modifying the reward with an internal state.
* ``NormalizeRewardV1`` - Normalizes the rewards to a mean and standard deviation
"""
from __future__ import annotations
from typing import Any, SupportsFloat
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.experimental.wrappers.utils import RunningMeanStd
__all__ = ["NormalizeRewardV1"]
class NormalizeRewardV1(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
):
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
The exponential moving average will have variance :math:`(1 - \gamma)^2`.
    The property `_update_running_mean` allows freezing/continuing the running mean calculation of the reward
statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.normalize()` is called.
If False, the calculated statistics are used but not updated anymore; this may be used during evaluation.
Note:
        In v0.27, NormalizeReward was updated as the forward discounted reward estimate was incorrectly computed in Gym v0.25+.
For more detail, read [#3154](https://github.com/openai/gym/pull/3152).
Note:
The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
instantiated or the policy was changed recently.
"""
def __init__(
self,
env: gym.Env[ObsType, ActType],
gamma: float = 0.99,
epsilon: float = 1e-8,
):
"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
Args:
env (env): The environment to apply the wrapper
epsilon (float): A stability parameter
gamma (float): The discount factor that is used in the exponential moving average.
"""
gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
gym.Wrapper.__init__(self, env)
self.rewards_running_means = RunningMeanStd(shape=())
        self.discounted_reward: np.ndarray = np.array([0.0])
self.gamma = gamma
self.epsilon = epsilon
self._update_running_mean = True
@property
def update_running_mean(self) -> bool:
"""Property to freeze/continue the running mean calculation of the reward statistics."""
return self._update_running_mean
@update_running_mean.setter
def update_running_mean(self, setting: bool):
"""Sets the property to freeze/continue the running mean calculation of the reward statistics."""
self._update_running_mean = setting
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
"""Steps through the environment, normalizing the reward returned."""
obs, reward, terminated, truncated, info = super().step(action)
self.discounted_reward = self.discounted_reward * self.gamma * (
1 - terminated
) + float(reward)
return obs, self.normalize(float(reward)), terminated, truncated, info
def normalize(self, reward: SupportsFloat):
"""Normalizes the rewards with the running mean rewards and their variance."""
if self._update_running_mean:
self.rewards_running_means.update(self.discounted_reward)
return reward / np.sqrt(self.rewards_running_means.var + self.epsilon)
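A short sketch of reward normalisation; ``gamma`` and the environment are illustrative.

import gymnasium as gym
from gymnasium.experimental.wrappers import NormalizeRewardV1

env = NormalizeRewardV1(gym.make("CartPole-v1"), gamma=0.99)
obs, info = env.reset(seed=0)
# Rewards are divided by the running standard deviation of the discounted return estimate.
obs, norm_reward, terminated, truncated, info = env.step(env.action_space.sample())
env.update_running_mean = False  # freeze the statistics during evaluation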

View File

@ -0,0 +1,147 @@
"""Utility functions for the wrappers."""
from collections import OrderedDict
from functools import singledispatch
import numpy as np
from gymnasium import Space
from gymnasium.error import CustomSpaceError
from gymnasium.spaces import (
Box,
Dict,
Discrete,
Graph,
GraphInstance,
MultiBinary,
MultiDiscrete,
Sequence,
Text,
Tuple,
)
from gymnasium.spaces.space import T_cov
__all__ = ["RunningMeanStd", "update_mean_var_count_from_moments", "create_zero_array"]
class RunningMeanStd:
"""Tracks the mean, variance and count of values."""
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
def __init__(self, epsilon=1e-4, shape=()):
"""Tracks the mean, variance and count of values."""
self.mean = np.zeros(shape, "float64")
self.var = np.ones(shape, "float64")
self.count = epsilon
def update(self, x):
"""Updates the mean, var and count from a batch of samples."""
batch_mean = np.mean(x, axis=0)
batch_var = np.var(x, axis=0)
batch_count = x.shape[0]
self.update_from_moments(batch_mean, batch_var, batch_count)
def update_from_moments(self, batch_mean, batch_var, batch_count):
"""Updates from batch mean, variance and count moments."""
self.mean, self.var, self.count = update_mean_var_count_from_moments(
self.mean, self.var, self.count, batch_mean, batch_var, batch_count
)
def update_mean_var_count_from_moments(
mean, var, count, batch_mean, batch_var, batch_count
):
"""Updates the mean, var and count using the previous mean, var, count and batch values."""
delta = batch_mean - mean
tot_count = count + batch_count
new_mean = mean + delta * batch_count / tot_count
m_a = var * count
m_b = batch_var * batch_count
M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
new_var = M2 / tot_count
new_count = tot_count
return new_mean, new_var, new_count
@singledispatch
def create_zero_array(space: Space[T_cov]) -> T_cov:
"""Creates a zero-based array of a space, this is similar to ``create_empty_array`` except all arrays are valid samples from the space.
    As some ``Box`` spaces have ``high`` or ``low`` bounds that do not contain zero, ``create_empty_array`` could
    create arrays that are not contained in the space.
Args:
space: The space to create a zero array for
Returns:
Valid sample from the space that is as close to zero as possible
"""
if isinstance(space, Space):
raise CustomSpaceError(
f"Space of type `{type(space)}` doesn't have an registered `create_zero_array` function. Register `{type(space)}` for `create_zero_array` to support it."
)
else:
raise TypeError(
f"The space provided to `create_zero_array` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@create_zero_array.register(Box)
def _create_box_zero_array(space: Box):
zero_array = np.zeros(space.shape, dtype=space.dtype)
zero_array = np.where(space.low > 0, space.low, zero_array)
zero_array = np.where(space.high < 0, space.high, zero_array)
return zero_array
@create_zero_array.register(Discrete)
def _create_discrete_zero_array(space: Discrete):
return space.start
@create_zero_array.register(MultiDiscrete)
def _create_multidiscrete_zero_array(space: MultiDiscrete):
return np.array(space.start, copy=True, dtype=space.dtype)
@create_zero_array.register(MultiBinary)
def _create_array_zero_array(space: MultiBinary):
return np.zeros(space.shape, dtype=space.dtype)
@create_zero_array.register(Tuple)
def _create_tuple_zero_array(space: Tuple):
return tuple(create_zero_array(subspace) for subspace in space.spaces)
@create_zero_array.register(Dict)
def _create_dict_zero_array(space: Dict):
return OrderedDict(
{key: create_zero_array(subspace) for key, subspace in space.spaces.items()}
)
@create_zero_array.register(Sequence)
def _create_sequence_zero_array(space: Sequence):
if space.stack:
return create_zero_array(space.stacked_feature_space)
else:
return tuple()
@create_zero_array.register(Text)
def _create_text_zero_array(space: Text):
return "".join(space.characters[0] for _ in range(space.min_length))
@create_zero_array.register(Graph)
def _create_graph_zero_array(space: Graph):
nodes = np.expand_dims(create_zero_array(space.node_space), axis=0)
if space.edge_space is None:
return GraphInstance(nodes=nodes, edges=None, edge_links=None)
else:
edges = np.expand_dims(create_zero_array(space.edge_space), axis=0)
edge_links = np.zeros((1, 2), dtype=np.int64)
return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links)
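A small sketch of ``create_zero_array`` on a composite space, showing that each result is a valid sample of its space even when zero itself is out of bounds; the space definition is illustrative.

import numpy as np
from gymnasium.spaces import Box, Dict, Discrete
from gymnasium.experimental.wrappers.utils import create_zero_array

space = Dict(
    position=Box(low=1.0, high=2.0, shape=(3,), dtype=np.float32),  # low > 0, so "zero" is clipped up to low
    mode=Discrete(4, start=2),                                      # the smallest valid value is start
)
zero = create_zero_array(space)
assert space.contains(zero)
# zero == OrderedDict([('mode', 2), ('position', array([1., 1., 1.], dtype=float32))])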

View File

@ -0,0 +1,146 @@
"""Wrappers for vector environments."""
# pyright: reportUnsupportedDunderAll=false
import importlib
import re
from gymnasium.error import DeprecatedWrapper
from gymnasium.experimental.wrappers.vector.dict_info_to_list import DictInfoToListV0
from gymnasium.experimental.wrappers.vector.record_episode_statistics import (
RecordEpisodeStatisticsV0,
)
from gymnasium.experimental.wrappers.vector.vectorize_action import (
ClipActionV0,
LambdaActionV0,
RescaleActionV0,
VectorizeLambdaActionV0,
)
from gymnasium.experimental.wrappers.vector.vectorize_observation import (
DtypeObservationV0,
FilterObservationV0,
FlattenObservationV0,
GrayscaleObservationV0,
LambdaObservationV0,
RescaleObservationV0,
ReshapeObservationV0,
ResizeObservationV0,
VectorizeLambdaObservationV0,
)
from gymnasium.experimental.wrappers.vector.vectorize_reward import (
ClipRewardV0,
LambdaRewardV0,
VectorizeLambdaRewardV0,
)
__all__ = [
# --- Vector only wrappers
"VectorizeLambdaObservationV0",
"VectorizeLambdaActionV0",
"VectorizeLambdaRewardV0",
"DictInfoToListV0",
# --- Observation wrappers ---
"LambdaObservationV0",
"FilterObservationV0",
"FlattenObservationV0",
"GrayscaleObservationV0",
"ResizeObservationV0",
"ReshapeObservationV0",
"RescaleObservationV0",
"DtypeObservationV0",
# "PixelObservationV0",
# "NormalizeObservationV0",
# "TimeAwareObservationV0",
# "FrameStackObservationV0",
# "DelayObservationV0",
# --- Action Wrappers ---
"LambdaActionV0",
"ClipActionV0",
"RescaleActionV0",
# --- Reward wrappers ---
"LambdaRewardV0",
"ClipRewardV0",
# "NormalizeRewardV1",
# --- Common ---
"RecordEpisodeStatisticsV0",
# --- Rendering ---
# "RenderCollectionV0",
# "RecordVideoV0",
# "HumanRenderingV0",
# --- Conversion ---
"JaxToNumpyV0",
"JaxToTorchV0",
"NumpyToTorchV0",
]
# As these wrappers require `jax` or `torch`, they are loaded at runtime when users try to access them
# to avoid `import jax` or `import torch` on `import gymnasium`.
_wrapper_to_class = {
# data converters
"JaxToNumpyV0": "jax_to_numpy",
"JaxToTorchV0": "jax_to_torch",
"NumpyToTorchV0": "numpy_to_torch",
}
def __getattr__(wrapper_name: str):
"""Load a wrapper by name.
This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used.
Errors will be raised if the wrapper does not exist or if the version is not the latest.
Args:
wrapper_name: The name of a wrapper to load.
Returns:
The specified wrapper.
Raises:
AttributeError: If the wrapper does not exist.
DeprecatedWrapper: If the version is not the latest.
"""
# Check if the requested wrapper is in the _wrapper_to_class dictionary
if wrapper_name in _wrapper_to_class:
import_stmt = (
f"gymnasium.experimental.wrappers.vector.{_wrapper_to_class[wrapper_name]}"
)
module = importlib.import_module(import_stmt)
return getattr(module, wrapper_name)
# Define a regex pattern to match the integer suffix (version number) of the wrapper
int_suffix_pattern = r"(\d+)$"
version_match = re.search(int_suffix_pattern, wrapper_name)
# If a version number is found, extract it and the base wrapper name
if version_match:
version = int(version_match.group())
base_name = wrapper_name[: -len(version_match.group())]
else:
version = float("inf")
base_name = wrapper_name[:-2]
# Filter the list of all wrappers to include only those with the same base name
matching_wrappers = [name for name in __all__ if name.startswith(base_name)]
# If no matching wrappers are found, raise an AttributeError
if not matching_wrappers:
raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}")
# Find the latest version of the matching wrappers
latest_wrapper = max(
matching_wrappers, key=lambda s: int(re.findall(int_suffix_pattern, s)[0])
)
latest_version = int(re.findall(int_suffix_pattern, latest_wrapper)[0])
# If the requested wrapper is an older version, raise a DeprecatedWrapper exception
if version < latest_version:
raise DeprecatedWrapper(
f"{wrapper_name!r} is now deprecated, use {latest_wrapper!r} instead.\n"
f"To see the changes made, go to "
f"https://gymnasium.farama.org/api/experimental/vector-wrappers/#gymnasium.experimental.wrappers.vector.{latest_wrapper}"
)
# If the requested version is invalid, raise an AttributeError
else:
raise AttributeError(
f"module {__name__!r} has no attribute {wrapper_name!r}, did you mean {latest_wrapper!r}"
)
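A sketch of how the lazy ``__getattr__`` above responds to an unknown version suffix; the wrapper name is illustrative.

import gymnasium.experimental.wrappers.vector as vector_wrappers

try:
    vector_wrappers.ClipRewardV1  # no such version exists in this module
except AttributeError as err:
    print(err)  # ... has no attribute 'ClipRewardV1', did you mean 'ClipRewardV0'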

View File

@ -0,0 +1,86 @@
"""Wrapper that converts the info format for vec envs into the list format."""
from __future__ import annotations
from typing import Any
from gymnasium.core import ActType, ObsType
from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
__all__ = ["DictInfoToListV0"]
class DictInfoToListV0(VectorWrapper):
"""Converts infos of vectorized environments from dict to List[dict].
This wrapper converts the info format of a
vector environment from a dictionary to a list of dictionaries.
This wrapper is intended to be used around vectorized
environments. If using other wrappers that perform
operation on info like `RecordEpisodeStatistics` this
need to be the outermost wrapper.
i.e. ``DictInfoToListV0(RecordEpisodeStatisticsV0(vector_env))``
Example::
>>> import numpy as np
>>> dict_info = {
... "k": np.array([0., 0., 0.5, 0.3]),
... "_k": np.array([False, False, True, True])
... }
>>> list_info = [{}, {}, {"k": 0.5}, {"k": 0.3}]
"""
def __init__(self, env: VectorEnv):
"""This wrapper will convert the info into the list format.
Args:
env (Env): The environment to apply the wrapper
"""
super().__init__(env)
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]:
"""Steps through the environment, convert dict info to list."""
observation, reward, terminated, truncated, infos = self.env.step(actions)
list_info = self._convert_info_to_list(infos)
return observation, reward, terminated, truncated, list_info
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, list[dict[str, Any]]]:
"""Resets the environment using kwargs."""
obs, infos = self.env.reset(seed=seed, options=options)
list_info = self._convert_info_to_list(infos)
return obs, list_info
def _convert_info_to_list(self, infos: dict) -> list[dict[str, Any]]:
"""Convert the dict info to list.
Convert the dict info of the vectorized environment
into a list of dictionaries where the i-th dictionary
has the info of the i-th environment.
Args:
infos (dict): info dict coming from the env.
Returns:
list_info (list): converted info.
"""
list_info = [{} for _ in range(self.num_envs)]
list_info = self._process_episode_statistics(infos, list_info)
for k in infos:
if k.startswith("_"):
continue
for i, has_info in enumerate(infos[f"_{k}"]):
if has_info:
list_info[i][k] = infos[k][i]
return list_info
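Note that ``_convert_info_to_list`` above calls ``self._process_episode_statistics``, which is not defined in the file as dumped here; without it, ``reset`` and ``step`` would raise ``AttributeError``. A sketch of such a helper, modelled on the older ``VectorListInfo`` wrapper and offered as an assumption rather than this module's actual code, follows.

    def _process_episode_statistics(self, infos: dict, list_info: list) -> list[dict[str, Any]]:
        """Move vectorized episode statistics into the per-sub-environment dictionaries."""
        episode_statistics = infos.pop("episode", False)
        if not episode_statistics:
            return list_info

        episode_statistics_mask = infos.pop("_episode")
        for i, has_info in enumerate(episode_statistics_mask):
            if has_info:
                list_info[i]["episode"] = {
                    "r": episode_statistics["r"][i],
                    "l": episode_statistics["l"][i],
                    "t": episode_statistics["t"][i],
                }
        return list_info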

View File

@ -0,0 +1,79 @@
"""Vector wrapper for converting between NumPy and Jax."""
from __future__ import annotations
from typing import Any
try:
    import jax.numpy as jnp
except ImportError:
    jnp = None
from gymnasium.core import ActType, ObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.vector import VectorEnv, VectorWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
__all__ = ["JaxToNumpyV0"]
class JaxToNumpyV0(VectorWrapper):
"""Wraps a jax vector environment so that it can be interacted with through numpy arrays.
Notes:
A vectorized version of ``gymnasium.experimental.wrappers.JaxToNumpyV0``
Actions must be provided as numpy arrays and observations, rewards, terminations and truncations will be returned as numpy arrays.
"""
def __init__(self, env: VectorEnv):
"""Wraps an environment such that the input and outputs are numpy arrays.
Args:
env: the vector jax environment to wrap
"""
if jnp is None:
raise DependencyNotInstalled(
"jax is not installed, run `pip install gymnasium[jax]`"
)
super().__init__(env)
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Transforms the action to a jax array .
Args:
actions: the action to perform as a numpy array
Returns:
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
"""
jax_actions = numpy_to_jax(actions)
obs, reward, terminated, truncated, info = self.env.step(jax_actions)
return (
jax_to_numpy(obs),
jax_to_numpy(reward),
jax_to_numpy(terminated),
jax_to_numpy(truncated),
jax_to_numpy(info),
)
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment returning numpy-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
Numpy-based observations and info
"""
if options:
options = numpy_to_jax(options)
return jax_to_numpy(self.env.reset(seed=seed, options=options))

View File

@ -0,0 +1,76 @@
"""Vector wrapper class for converting between PyTorch and Jax."""
from __future__ import annotations
from typing import Any
from gymnasium.core import ActType, ObsType
from gymnasium.experimental.vector import VectorEnv, VectorWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers.jax_to_torch import (
Device,
jax_to_torch,
torch_to_jax,
)
__all__ = ["JaxToTorchV0"]
class JaxToTorchV0(VectorWrapper):
"""Wraps a Jax-based vector environment so that it can be interacted with through PyTorch Tensors.
Actions must be provided as PyTorch Tensors and observations, rewards, terminations and truncations will be returned as PyTorch Tensors.
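Example:
A minimal usage sketch, assuming ``jax_vec_env`` is a Jax-based experimental vector environment
created elsewhere and that PyTorch is installed (the names are illustrative)::
>>> import torch                                                   # doctest: +SKIP
>>> envs = JaxToTorchV0(jax_vec_env, device="cpu")                 # doctest: +SKIP
>>> obs, info = envs.reset(seed=123)                               # doctest: +SKIP
>>> actions = torch.as_tensor(envs.action_space.sample())          # doctest: +SKIP
>>> obs, reward, terminated, truncated, info = envs.step(actions)  # doctest: +SKIP
>>> isinstance(reward, torch.Tensor)                               # doctest: +SKIP
True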
"""
def __init__(self, env: VectorEnv, device: Device | None = None):
"""Vector wrapper to change inputs and outputs to PyTorch tensors.
Args:
env: The Jax-based vector environment to wrap
device: The device the torch Tensors should be moved to
"""
super().__init__(env)
self.device: Device | None = device
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Performs the given action within the environment.
Args:
actions: The action to perform as a PyTorch Tensor
Returns:
Torch-based Tensors of the next observation, reward, termination, truncation, and extra info
"""
jax_action = torch_to_jax(actions)
obs, reward, terminated, truncated, info = self.env.step(jax_action)
return (
jax_to_torch(obs, self.device),
jax_to_torch(reward, self.device),
jax_to_torch(terminated, self.device),
jax_to_torch(truncated, self.device),
jax_to_torch(info, self.device),
)
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment returning PyTorch-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
PyTorch-based observations and info
"""
if options:
options = torch_to_jax(options)
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)

View File

@ -0,0 +1,73 @@
"""Wrapper for converting NumPy environments to PyTorch."""
from __future__ import annotations
from typing import Any
from gymnasium.core import ActType, ObsType
from gymnasium.experimental.vector import VectorEnv, VectorWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers.jax_to_torch import Device
from gymnasium.experimental.wrappers.numpy_to_torch import (
numpy_to_torch,
torch_to_numpy,
)
__all__ = ["NumpyToTorchV0"]
class NumpyToTorchV0(VectorWrapper):
"""Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors."""
def __init__(self, env: VectorEnv, device: Device | None = None):
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
Args:
env: The NumPy-based vector environment to wrap
device: The device the torch Tensors should be moved to
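Example:
A minimal usage sketch, assuming ``numpy_vec_env`` is a NumPy-based experimental vector
environment created elsewhere and that PyTorch is installed (the names are illustrative)::
>>> import torch                                                   # doctest: +SKIP
>>> envs = NumpyToTorchV0(numpy_vec_env, device="cpu")             # doctest: +SKIP
>>> obs, info = envs.reset(seed=123)                               # doctest: +SKIP
>>> actions = torch.as_tensor(envs.action_space.sample())          # doctest: +SKIP
>>> obs, reward, terminated, truncated, info = envs.step(actions)  # doctest: +SKIP
>>> isinstance(obs, torch.Tensor)                                  # doctest: +SKIP
True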
"""
super().__init__(env)
self.device: Device | None = device
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Using a PyTorch based action that is converted to NumPy to be used by the environment.
Args:
actions: A PyTorch-based action
Returns:
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
"""
numpy_action = torch_to_numpy(actions)
obs, reward, terminated, truncated, info = self.env.step(numpy_action)
return (
numpy_to_torch(obs, self.device),
numpy_to_torch(reward, self.device),
numpy_to_torch(terminated, self.device),
numpy_to_torch(truncated, self.device),
numpy_to_torch(info, self.device),
)
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment returning PyTorch-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to numpy arrays.
Returns:
PyTorch-based observations and info
"""
if options:
options = torch_to_numpy(options)
return numpy_to_torch(self.env.reset(seed=seed, options=options), self.device)

View File

@ -0,0 +1,144 @@
"""Wrapper that tracks the cumulative rewards and episode lengths."""
from __future__ import annotations
import time
from collections import deque
import numpy as np
from gymnasium.core import ActType, ObsType
from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
__all__ = ["RecordEpisodeStatisticsV0"]
class RecordEpisodeStatisticsV0(VectorWrapper):
"""This wrapper will keep track of cumulative rewards and episode lengths.
At the end of an episode, the statistics of the episode will be added to ``info``
using the key ``episode``. As this is a vector wrapper, the key ``_episode`` is also
added, indicating whether the env at the respective index has the episode statistics.
After the completion of an episode, ``info`` will look like this::
>>> info = { # doctest: +SKIP
... ...
... "episode": {
... "r": "<cumulative reward>",
... "l": "<episode length>",
... "t": "<elapsed time since beginning of episode>"
... },
... }
For vectorized environments the output will be in the form of::
>>> infos = { # doctest: +SKIP
... ...
... "episode": {
... "r": "<array of cumulative reward for each done sub-environment>",
... "l": "<array of episode length for each done sub-environment>",
... "t": "<array of elapsed time since beginning of episode for each done sub-environment>"
... },
... "_episode": "<boolean array of length num-envs>"
... }
Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
:attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.
Attributes:
return_queue: The cumulative rewards of the last ``deque_size``-many episodes
length_queue: The lengths of the last ``deque_size``-many episodes
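Example:
A usage sketch, assuming ``envs`` is an experimental vector environment created elsewhere
(the name is illustrative)::
>>> envs = RecordEpisodeStatisticsV0(envs, deque_size=50)          # doctest: +SKIP
>>> _ = envs.reset(seed=0)                                         # doctest: +SKIP
>>> for _ in range(100):                                           # doctest: +SKIP
...     _, _, _, _, infos = envs.step(envs.action_space.sample())
>>> returns = list(envs.return_queue)  # cumulative rewards of recently finished episodes  # doctest: +SKIP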
"""
def __init__(self, env: VectorEnv, deque_size: int = 100):
"""This wrapper will keep track of cumulative rewards and episode lengths.
Args:
env (Env): The environment to apply the wrapper
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
"""
super().__init__(env)
self.episode_count = 0
self.episode_start_times: np.ndarray = np.zeros(())
self.episode_returns: np.ndarray = np.zeros(())
self.episode_lengths: np.ndarray = np.zeros(())
self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size)
def reset(
self,
seed: int | list[int] | None = None,
options: dict | None = None,
):
"""Resets the environment using kwargs and resets the episode returns and lengths."""
obs, info = super().reset(seed=seed, options=options)
self.episode_start_times = np.full(
self.num_envs, time.perf_counter(), dtype=np.float32
)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return obs, info
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Steps through the environment, recording the episode statistics."""
(
observations,
rewards,
terminations,
truncations,
infos,
) = self.env.step(actions)
assert isinstance(
infos, dict
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
self.episode_returns += rewards
self.episode_lengths += 1
dones = np.logical_or(terminations, truncations)
num_dones = np.sum(dones)
if num_dones:
if "episode" in infos or "_episode" in infos:
raise ValueError(
"Attempted to add episode stats when they already exist"
)
else:
infos["episode"] = {
"r": np.where(dones, self.episode_returns, 0.0),
"l": np.where(dones, self.episode_lengths, 0),
"t": np.where(
dones,
np.round(time.perf_counter() - self.episode_start_times, 6),
0.0,
),
}
infos["_episode"] = dones
self.episode_count += num_dones
# extend the queues with the returns and lengths of every sub-environment that just finished
self.return_queue.extend(self.episode_returns[dones])
self.length_queue.extend(self.episode_lengths[dones])
self.episode_lengths[dones] = 0
self.episode_returns[dones] = 0
self.episode_start_times[dones] = time.perf_counter()
return (
observations,
rewards,
terminations,
truncations,
infos,
)

View File

@ -0,0 +1,143 @@
"""Vectorizes action wrappers to work for `VectorEnv`."""
from __future__ import annotations
from copy import deepcopy
from typing import Any, Callable
import numpy as np
from gymnasium import Space
from gymnasium.core import ActType, Env
from gymnasium.experimental.vector import VectorActionWrapper, VectorEnv
from gymnasium.experimental.wrappers import lambda_action
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
class LambdaActionV0(VectorActionWrapper):
"""Transforms an action via a function provided to the wrapper.
The function :attr:`func` will be applied to all vector actions.
If the actions from :attr:`func` are outside the bounds of the ``env``'s action space, provide an :attr:`action_space`.
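Example:
A sketch that halves every batched action before it reaches the sub-environments, assuming
``envs`` is an experimental vector environment with a ``Box`` action space containing zero
(the name is illustrative)::
>>> envs = LambdaActionV0(envs, lambda actions: 0.5 * actions)     # doctest: +SKIP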
"""
def __init__(
self,
env: VectorEnv,
func: Callable[[ActType], Any],
action_space: Space | None = None,
):
"""Constructor for the lambda action wrapper.
Args:
env: The vector environment to wrap
func: A function that will transform an action. If this transformed action is outside the action space of ``env.action_space`` then provide an ``action_space``.
action_space: The action spaces of the wrapper, if None, then it is assumed the same as ``env.action_space``.
"""
super().__init__(env)
if action_space is not None:
self.action_space = action_space
self.func = func
def actions(self, actions: ActType) -> ActType:
"""Applies the :attr:`func` to the actions."""
return self.func(actions)
class VectorizeLambdaActionV0(VectorActionWrapper):
"""Vectorizes a single-agent lambda action wrapper for vector environments."""
class VectorizedEnv(Env):
"""Fake single-agent environment uses for the single-agent wrapper."""
def __init__(self, action_space: Space):
"""Constructor for the fake environment."""
self.action_space = action_space
def __init__(
self, env: VectorEnv, wrapper: type[lambda_action.LambdaActionV0], **kwargs: Any
):
"""Constructor for the vectorized lambda action wrapper.
Args:
env: The vector environment to wrap
wrapper: The wrapper to vectorize
**kwargs: Arguments for the LambdaActionV0 wrapper
"""
super().__init__(env)
self.wrapper = wrapper(
self.VectorizedEnv(self.env.single_action_space), **kwargs
)
self.single_action_space = self.wrapper.action_space
self.action_space = batch_space(self.single_action_space, self.num_envs)
self.same_out = self.action_space == self.env.action_space
self.out = create_empty_array(self.single_action_space, self.num_envs)
def actions(self, actions: ActType) -> ActType:
"""Applies the wrapper to each of the action.
Args:
actions: The actions to apply the function to
Returns:
The updated actions using the wrapper func
"""
if self.same_out:
return concatenate(
self.single_action_space,
tuple(
self.wrapper.func(action)
for action in iterate(self.action_space, actions)
),
actions,
)
else:
return deepcopy(
concatenate(
self.single_action_space,
tuple(
self.wrapper.func(action)
for action in iterate(self.action_space, actions)
),
self.out,
)
)
class ClipActionV0(VectorizeLambdaActionV0):
"""Clip the continuous action within the valid :class:`Box` observation space bound."""
def __init__(self, env: VectorEnv):
"""Constructor for the Clip Action wrapper.
Args:
env: The vector environment to wrap
"""
super().__init__(env, lambda_action.ClipActionV0)
class RescaleActionV0(VectorizeLambdaActionV0):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action]."""
def __init__(
self,
env: VectorEnv,
min_action: float | int | np.ndarray,
max_action: float | int | np.ndarray,
):
"""Initializes the :class:`RescaleAction` wrapper.
Args:
env (Env): The vector environment to wrap
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
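Example:
A sketch that lets an agent act in ``[-1, 1]`` while the underlying sub-environments keep
their original ``Box`` bounds, assuming ``envs`` is an experimental vector environment
created elsewhere (the name is illustrative)::
>>> envs = RescaleActionV0(envs, min_action=-1.0, max_action=1.0)  # doctest: +SKIP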
"""
super().__init__(
env,
lambda_action.RescaleActionV0,
min_action=min_action,
max_action=max_action,
)

View File

@ -0,0 +1,222 @@
"""Vectorizes observation wrappers to works for `VectorEnv`."""
from __future__ import annotations
from copy import deepcopy
from typing import Any, Callable, Sequence
import numpy as np
from gymnasium import Space
from gymnasium.core import Env, ObsType
from gymnasium.experimental.vector import VectorEnv, VectorObservationWrapper
from gymnasium.experimental.vector.utils import batch_space, concatenate, iterate
from gymnasium.experimental.wrappers import lambda_observation
from gymnasium.vector.utils import create_empty_array
class LambdaObservationV0(VectorObservationWrapper):
"""Transforms an observation via a function provided to the wrapper.
The function :attr:`vector_func` will be applied to all vector observations.
If the observations from :attr:`vector_func` are outside the bounds of the ``env``'s observation space, provide an :attr:`observation_space`.
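Example:
A sketch that adds small uniform noise to every observation, assuming ``envs`` is an
experimental vector environment with a ``Box`` observation space broad enough that the
shifted values stay in bounds (otherwise pass ``observation_space``); the names are
illustrative::
>>> import numpy as np                                             # doctest: +SKIP
>>> rng = np.random.default_rng(0)                                 # doctest: +SKIP
>>> noisy = lambda obs: obs + rng.uniform(0.0, 0.01, size=obs.shape)  # doctest: +SKIP
>>> envs = LambdaObservationV0(envs, vector_func=noisy, single_func=noisy)  # doctest: +SKIP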
"""
def __init__(
self,
env: VectorEnv,
vector_func: Callable[[ObsType], Any],
single_func: Callable[[ObsType], Any],
observation_space: Space | None = None,
):
"""Constructor for the lambda observation wrapper.
Args:
env: The vector environment to wrap
vector_func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
single_func: A function that will transform an individual observation.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
"""
super().__init__(env)
if observation_space is not None:
self.observation_space = observation_space
self.vector_func = vector_func
self.single_func = single_func
def vector_observation(self, observation: ObsType) -> ObsType:
"""Apply function to the vector observation."""
return self.vector_func(observation)
def single_observation(self, observation: ObsType) -> ObsType:
"""Apply function to the single observation."""
return self.single_func(observation)
class VectorizeLambdaObservationV0(VectorObservationWrapper):
"""Vectori`es a single-agent lambda observation wrapper for vector environments."""
class VectorizedEnv(Env):
"""Fake single-agent environment uses for the single-agent wrapper."""
def __init__(self, observation_space: Space):
"""Constructor for the fake environment."""
self.observation_space = observation_space
def __init__(
self,
env: VectorEnv,
wrapper: type[lambda_observation.LambdaObservationV0],
**kwargs: Any,
):
"""Constructor for the vectorized lambda observation wrapper.
Args:
env: The vector environment to wrap.
wrapper: The wrapper to vectorize
**kwargs: Keyword argument for the wrapper
"""
super().__init__(env)
self.wrapper = wrapper(
self.VectorizedEnv(self.env.single_observation_space), **kwargs
)
self.single_observation_space = self.wrapper.observation_space
self.observation_space = batch_space(
self.single_observation_space, self.num_envs
)
self.same_out = self.observation_space == self.env.observation_space
self.out = create_empty_array(self.single_observation_space, self.num_envs)
def vector_observation(self, observation: ObsType) -> ObsType:
"""Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again."""
if self.same_out:
return concatenate(
self.single_observation_space,
tuple(
self.wrapper.func(obs)
for obs in iterate(self.observation_space, observation)
),
observation,
)
else:
return deepcopy(
concatenate(
self.single_observation_space,
tuple(
self.wrapper.func(obs)
for obs in iterate(self.observation_space, observation)
),
self.out,
)
)
def single_observation(self, observation: ObsType) -> ObsType:
"""Transforms a single observation using the wrapper transformation function."""
return self.wrapper.func(observation)
class FilterObservationV0(VectorizeLambdaObservationV0):
"""Vector wrapper for filtering dict or tuple observation spaces."""
def __init__(self, env: VectorEnv, filter_keys: Sequence[str | int]):
"""Constructor for the filter observation wrapper.
Args:
env: The vector environment to wrap
filter_keys: The subspaces to be included, use a list of strings or integers for ``Dict`` and ``Tuple`` spaces respectively
"""
super().__init__(
env, lambda_observation.FilterObservationV0, filter_keys=filter_keys
)
class FlattenObservationV0(VectorizeLambdaObservationV0):
"""Observation wrapper that flattens the observation."""
def __init__(self, env: VectorEnv):
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.
Args:
env: The vector environment to wrap
"""
super().__init__(env, lambda_observation.FlattenObservationV0)
class GrayscaleObservationV0(VectorizeLambdaObservationV0):
"""Observation wrapper that converts an RGB image to grayscale."""
def __init__(self, env: VectorEnv, keep_dim: bool = False):
"""Constructor for an RGB image based environments to make the image grayscale.
Args:
env: The vector environment to wrap
keep_dim: Whether to keep the channel dimension in the observation; if ``True``, ``len(obs.shape) == 3``, else ``len(obs.shape) == 2``
"""
super().__init__(
env, lambda_observation.GrayscaleObservationV0, keep_dim=keep_dim
)
class ResizeObservationV0(VectorizeLambdaObservationV0):
"""Resizes image observations using OpenCV to shape."""
def __init__(self, env: VectorEnv, shape: tuple[int, ...]):
"""Constructor that requires an image environment observation space with a shape.
Args:
env: The vector environment to wrap
shape: The resized observation shape
"""
super().__init__(env, lambda_observation.ResizeObservationV0, shape=shape)
class ReshapeObservationV0(VectorizeLambdaObservationV0):
"""Reshapes array based observations to shapes."""
def __init__(self, env: VectorEnv, shape: int | tuple[int, ...]):
"""Constructor for env with Box observation space that has a shape product equal to the new shape product.
Args:
env: The vector environment to wrap
shape: The reshaped observation space
"""
super().__init__(env, lambda_observation.ReshapeObservationV0, shape=shape)
class RescaleObservationV0(VectorizeLambdaObservationV0):
"""Linearly rescales observation to between a minimum and maximum value."""
def __init__(
self,
env: VectorEnv,
min_obs: np.floating | np.integer | np.ndarray,
max_obs: np.floating | np.integer | np.ndarray,
):
"""Constructor that requires the env observation spaces to be a :class:`Box`.
Args:
env: The vector environment to wrap
min_obs: The new minimum observation bound
max_obs: The new maximum observation bound
"""
super().__init__(
env,
lambda_observation.RescaleObservationV0,
min_obs=min_obs,
max_obs=max_obs,
)
class DtypeObservationV0(VectorizeLambdaObservationV0):
"""Observation wrapper for transforming the dtype of an observation."""
def __init__(self, env: VectorEnv, dtype: Any):
"""Constructor for Dtype observation wrapper.
Args:
env: The vector environment to wrap
dtype: The new dtype of the observation
"""
super().__init__(env, lambda_observation.DtypeObservationV0, dtype=dtype)

View File

@ -0,0 +1,78 @@
"""Vectorizes reward function to work with `VectorEnv`."""
from __future__ import annotations
from typing import Any, Callable
import numpy as np
from gymnasium import Env
from gymnasium.experimental.vector import VectorEnv, VectorRewardWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers import lambda_reward
class LambdaRewardV0(VectorRewardWrapper):
"""A reward wrapper that allows a custom function to modify the step reward."""
def __init__(self, env: VectorEnv, func: Callable[[ArrayType], ArrayType]):
"""Initialize LambdaRewardV0 wrapper.
Args:
env (Env): The vector environment to wrap
func (Callable): The function to apply to the reward
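Example:
A sketch that scales every reward by 0.01, assuming ``envs`` is an experimental vector
environment created elsewhere (the name is illustrative)::
>>> envs = LambdaRewardV0(envs, func=lambda reward: 0.01 * reward)  # doctest: +SKIP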
"""
super().__init__(env)
self.func = func
def reward(self, reward: ArrayType) -> ArrayType:
"""Apply function to reward."""
return self.func(reward)
class VectorizeLambdaRewardV0(VectorRewardWrapper):
"""Vectorizes a single-agent lambda reward wrapper for vector environments."""
def __init__(
self, env: VectorEnv, wrapper: type[lambda_reward.LambdaRewardV0], **kwargs: Any
):
"""Constructor for the vectorized lambda reward wrapper.
Args:
env: The vector environment to wrap.
wrapper: The wrapper to vectorize
**kwargs: Keyword argument for the wrapper
"""
super().__init__(env)
self.wrapper = wrapper(Env(), **kwargs)
def reward(self, reward: ArrayType) -> ArrayType:
"""Iterates over the reward updating each with the wrapper func."""
for i, r in enumerate(reward):
reward[i] = self.wrapper.func(r)
return reward
class ClipRewardV0(VectorizeLambdaRewardV0):
"""A wrapper that clips the rewards for an environment between an upper and lower bound."""
def __init__(
self,
env: VectorEnv,
min_reward: float | np.ndarray | None = None,
max_reward: float | np.ndarray | None = None,
):
"""Constructor for ClipReward wrapper.
Args:
env: The vector environment to wrap
min_reward: The min reward for each step
max_reward: The max reward for each step
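Example:
A sketch that clips every reward into ``[-1, 1]``, assuming ``envs`` is an experimental
vector environment created elsewhere (the name is illustrative)::
>>> envs = ClipRewardV0(envs, min_reward=-1.0, max_reward=1.0)     # doctest: +SKIP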
"""
super().__init__(
env,
lambda_reward.ClipRewardV0,
min_reward=min_reward,
max_reward=max_reward,
)