I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

gymnasium/experimental/wrappers/vector/__init__.py
@ -0,0 +1,146 @@
"""Wrappers for vector environments."""
# pyright: reportUnsupportedDunderAll=false
import importlib
import re
from gymnasium.error import DeprecatedWrapper
from gymnasium.experimental.wrappers.vector.dict_info_to_list import DictInfoToListV0
from gymnasium.experimental.wrappers.vector.record_episode_statistics import (
RecordEpisodeStatisticsV0,
)
from gymnasium.experimental.wrappers.vector.vectorize_action import (
ClipActionV0,
LambdaActionV0,
RescaleActionV0,
VectorizeLambdaActionV0,
)
from gymnasium.experimental.wrappers.vector.vectorize_observation import (
DtypeObservationV0,
FilterObservationV0,
FlattenObservationV0,
GrayscaleObservationV0,
LambdaObservationV0,
RescaleObservationV0,
ReshapeObservationV0,
ResizeObservationV0,
VectorizeLambdaObservationV0,
)
from gymnasium.experimental.wrappers.vector.vectorize_reward import (
ClipRewardV0,
LambdaRewardV0,
VectorizeLambdaRewardV0,
)
__all__ = [
# --- Vector only wrappers
"VectorizeLambdaObservationV0",
"VectorizeLambdaActionV0",
"VectorizeLambdaRewardV0",
"DictInfoToListV0",
# --- Observation wrappers ---
"LambdaObservationV0",
"FilterObservationV0",
"FlattenObservationV0",
"GrayscaleObservationV0",
"ResizeObservationV0",
"ReshapeObservationV0",
"RescaleObservationV0",
"DtypeObservationV0",
# "PixelObservationV0",
# "NormalizeObservationV0",
# "TimeAwareObservationV0",
# "FrameStackObservationV0",
# "DelayObservationV0",
# --- Action Wrappers ---
"LambdaActionV0",
"ClipActionV0",
"RescaleActionV0",
# --- Reward wrappers ---
"LambdaRewardV0",
"ClipRewardV0",
# "NormalizeRewardV1",
# --- Common ---
"RecordEpisodeStatisticsV0",
# --- Rendering ---
# "RenderCollectionV0",
# "RecordVideoV0",
# "HumanRenderingV0",
# --- Conversion ---
"JaxToNumpyV0",
"JaxToTorchV0",
"NumpyToTorchV0",
]
# As these wrappers require `jax` or `torch`, they are loaded at runtime when users
# try to access them, to avoid `import jax` or `import torch` on `import gymnasium`.
_wrapper_to_class = {
# data converters
"JaxToNumpyV0": "jax_to_numpy",
"JaxToTorchV0": "jax_to_torch",
"NumpyToTorchV0": "numpy_to_torch",
}
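
# For example (an illustrative sketch): accessing one of these names goes through
# ``__getattr__`` below, which imports the submodule only at that point.
#
#     >>> import gymnasium.experimental.wrappers.vector as vector_wrappers  # doctest: +SKIP
#     >>> vector_wrappers.JaxToNumpyV0  # `jax` is imported here, not at `import gymnasium`  # doctest: +SKIP
#     <class 'gymnasium.experimental.wrappers.vector.jax_to_numpy.JaxToNumpyV0'>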
def __getattr__(wrapper_name: str):
"""Load a wrapper by name.
This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used.
Errors will be raised if the wrapper does not exist or if the version is not the latest.
Args:
wrapper_name: The name of a wrapper to load.
Returns:
The specified wrapper.
Raises:
AttributeError: If the wrapper does not exist.
DeprecatedWrapper: If the version is not the latest.
"""
# Check if the requested wrapper is in the _wrapper_to_class dictionary
if wrapper_name in _wrapper_to_class:
import_stmt = (
f"gymnasium.experimental.wrappers.vector.{_wrapper_to_class[wrapper_name]}"
)
module = importlib.import_module(import_stmt)
return getattr(module, wrapper_name)
# Define a regex pattern to match the integer suffix (version number) of the wrapper
int_suffix_pattern = r"(\d+)$"
version_match = re.search(int_suffix_pattern, wrapper_name)
# If a version number is found, extract it and the base wrapper name
if version_match:
version = int(version_match.group())
base_name = wrapper_name[: -len(version_match.group())]
else:
version = float("inf")
base_name = wrapper_name[:-2]
# Filter the list of all wrappers to include only those with the same base name
matching_wrappers = [name for name in __all__ if name.startswith(base_name)]
# If no matching wrappers are found, raise an AttributeError
if not matching_wrappers:
raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}")
# Find the latest version of the matching wrappers
latest_wrapper = max(
matching_wrappers, key=lambda s: int(re.findall(int_suffix_pattern, s)[0])
)
latest_version = int(re.findall(int_suffix_pattern, latest_wrapper)[0])
# If the requested wrapper is an older version, raise a DeprecatedWrapper exception
if version < latest_version:
raise DeprecatedWrapper(
f"{wrapper_name!r} is now deprecated, use {latest_wrapper!r} instead.\n"
f"To see the changes made, go to "
f"https://gymnasium.farama.org/api/experimental/vector-wrappers/#gymnasium.experimental.wrappers.vector.{latest_wrapper}"
)
# If the requested version is invalid, raise an AttributeError
else:
raise AttributeError(
f"module {__name__!r} has no attribute {wrapper_name!r}, did you mean {latest_wrapper!r}"
)

gymnasium/experimental/wrappers/vector/dict_info_to_list.py
@ -0,0 +1,86 @@
"""Wrapper that converts the info format for vec envs into the list format."""
from __future__ import annotations
from typing import Any
from gymnasium.core import ActType, ObsType
from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
__all__ = ["DictInfoToListV0"]
class DictInfoToListV0(VectorWrapper):
"""Converts infos of vectorized environments from dict to List[dict].
This wrapper converts the info format of a
vector environment from a dictionary to a list of dictionaries.
    This wrapper is intended to be used around vectorized environments. If using
    other wrappers that perform operations on info, such as `RecordEpisodeStatistics`,
    this needs to be the outermost wrapper.
i.e. ``DictInfoToListV0(RecordEpisodeStatisticsV0(vector_env))``
Example::
>>> import numpy as np
>>> dict_info = {
... "k": np.array([0., 0., 0.5, 0.3]),
... "_k": np.array([False, False, True, True])
... }
>>> list_info = [{}, {}, {"k": 0.5}, {"k": 0.3}]
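
    A usage sketch (a skipped doctest; assumes a registered ``CartPole-v1``):

        >>> import gymnasium as gym                                        # doctest: +SKIP
        >>> from gymnasium.experimental.vector import SyncVectorEnv        # doctest: +SKIP
        >>> envs = SyncVectorEnv([lambda: gym.make("CartPole-v1")] * 3)    # doctest: +SKIP
        >>> envs = DictInfoToListV0(envs)                                  # doctest: +SKIP
        >>> obs, infos = envs.reset(seed=123)                              # doctest: +SKIP
        >>> infos                                                          # doctest: +SKIP
        [{}, {}, {}]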
"""
def __init__(self, env: VectorEnv):
"""This wrapper will convert the info into the list format.
Args:
env (Env): The environment to apply the wrapper
"""
super().__init__(env)
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]:
"""Steps through the environment, convert dict info to list."""
observation, reward, terminated, truncated, infos = self.env.step(actions)
list_info = self._convert_info_to_list(infos)
return observation, reward, terminated, truncated, list_info
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, list[dict[str, Any]]]:
"""Resets the environment using kwargs."""
obs, infos = self.env.reset(seed=seed, options=options)
list_info = self._convert_info_to_list(infos)
return obs, list_info
def _convert_info_to_list(self, infos: dict) -> list[dict[str, Any]]:
"""Convert the dict info to list.
Convert the dict info of the vectorized environment
into a list of dictionaries where the i-th dictionary
has the info of the i-th environment.
Args:
infos (dict): info dict coming from the env.
Returns:
list_info (list): converted info.
"""
list_info = [{} for _ in range(self.num_envs)]
list_info = self._process_episode_statistics(infos, list_info)
for k in infos:
if k.startswith("_"):
continue
for i, has_info in enumerate(infos[f"_{k}"]):
if has_info:
list_info[i][k] = infos[k][i]
return list_info
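
    def _process_episode_statistics(
        self, infos: dict, list_info: list
    ) -> list[dict[str, Any]]:
        """Process episode statistics.
        The `RecordEpisodeStatistics` wrapper adds extra information to the info
        in the form of a dict of dicts. This method moves that information into
        the per-environment dictionaries of ``list_info``.
        Args:
            infos (dict): infos coming from `RecordEpisodeStatistics`.
            list_info (list): info of the current vectorized environment.
        Returns:
            list_info (list): updated info.
        """
        episode_statistics = infos.pop("episode", False)
        if not episode_statistics:
            return list_info
        episode_statistics_mask = infos.pop("_episode")
        for i, has_info in enumerate(episode_statistics_mask):
            if has_info:
                list_info[i]["episode"] = {}
                list_info[i]["episode"]["r"] = episode_statistics["r"][i]
                list_info[i]["episode"]["l"] = episode_statistics["l"][i]
                list_info[i]["episode"]["t"] = episode_statistics["t"][i]
        return list_info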

gymnasium/experimental/wrappers/vector/jax_to_numpy.py
@ -0,0 +1,79 @@
"""Vector wrapper for converting between NumPy and Jax."""
from __future__ import annotations
from typing import Any
try:
    import jax.numpy as jnp
except ImportError:
    jnp = None  # checked in ``JaxToNumpyV0.__init__``, raising ``DependencyNotInstalled``
from gymnasium.core import ActType, ObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.vector import VectorEnv, VectorWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
__all__ = ["JaxToNumpyV0"]
class JaxToNumpyV0(VectorWrapper):
"""Wraps a jax vector environment so that it can be interacted with through numpy arrays.
Notes:
A vectorized version of ``gymnasium.experimental.wrappers.JaxToNumpyV0``
Actions must be provided as numpy arrays and observations, rewards, terminations and truncations will be returned as numpy arrays.
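
    Example (an illustrative sketch, a skipped doctest; ``make_jax_vector_env`` is a
    hypothetical constructor for a Jax-based vector environment):

        >>> envs = JaxToNumpyV0(make_jax_vector_env(num_envs=4))  # doctest: +SKIP
        >>> obs, info = envs.reset(seed=123)                      # doctest: +SKIP
        >>> type(obs)  # observations are converted to numpy      # doctest: +SKIP
        <class 'numpy.ndarray'>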
"""
def __init__(self, env: VectorEnv):
"""Wraps an environment such that the input and outputs are numpy arrays.
Args:
env: the vector jax environment to wrap
"""
if jnp is None:
raise DependencyNotInstalled(
"jax is not installed, run `pip install gymnasium[jax]`"
)
super().__init__(env)
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Transforms the action to a jax array .
Args:
actions: the action to perform as a numpy array
Returns:
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
"""
jax_actions = numpy_to_jax(actions)
obs, reward, terminated, truncated, info = self.env.step(jax_actions)
return (
jax_to_numpy(obs),
jax_to_numpy(reward),
jax_to_numpy(terminated),
jax_to_numpy(truncated),
jax_to_numpy(info),
)
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment returning numpy-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
Numpy-based observations and info
"""
if options:
options = numpy_to_jax(options)
return jax_to_numpy(self.env.reset(seed=seed, options=options))

gymnasium/experimental/wrappers/vector/jax_to_torch.py
@ -0,0 +1,76 @@
"""Vector wrapper class for converting between PyTorch and Jax."""
from __future__ import annotations
from typing import Any
from gymnasium.core import ActType, ObsType
from gymnasium.experimental.vector import VectorEnv, VectorWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers.jax_to_torch import (
Device,
jax_to_torch,
torch_to_jax,
)
__all__ = ["JaxToTorchV0"]
class JaxToTorchV0(VectorWrapper):
"""Wraps a Jax-based vector environment so that it can be interacted with through PyTorch Tensors.
Actions must be provided as PyTorch Tensors and observations, rewards, terminations and truncations will be returned as PyTorch Tensors.
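
    Example (an illustrative sketch, a skipped doctest; ``make_jax_vector_env`` is a
    hypothetical constructor for a Jax-based vector environment):

        >>> envs = JaxToTorchV0(make_jax_vector_env(num_envs=4), device="cuda")   # doctest: +SKIP
        >>> obs, info = envs.reset(seed=123)                                      # doctest: +SKIP
        >>> obs.device  # observations are torch Tensors on the requested device  # doctest: +SKIP
        device(type='cuda', index=0)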
"""
def __init__(self, env: VectorEnv, device: Device | None = None):
"""Vector wrapper to change inputs and outputs to PyTorch tensors.
Args:
env: The Jax-based vector environment to wrap
device: The device the torch Tensors should be moved to
"""
super().__init__(env)
self.device: Device | None = device
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Performs the given action within the environment.
Args:
actions: The action to perform as a PyTorch Tensor
Returns:
Torch-based Tensors of the next observation, reward, termination, truncation, and extra info
"""
jax_action = torch_to_jax(actions)
obs, reward, terminated, truncated, info = self.env.step(jax_action)
return (
jax_to_torch(obs, self.device),
jax_to_torch(reward, self.device),
jax_to_torch(terminated, self.device),
jax_to_torch(truncated, self.device),
jax_to_torch(info, self.device),
)
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment returning PyTorch-based observation and info.
Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.
Returns:
PyTorch-based observations and info
"""
if options:
options = torch_to_jax(options)
return jax_to_torch(self.env.reset(seed=seed, options=options), self.device)

gymnasium/experimental/wrappers/vector/numpy_to_torch.py
@ -0,0 +1,73 @@
"""Wrapper for converting NumPy environments to PyTorch."""
from __future__ import annotations
from typing import Any
from gymnasium.core import ActType, ObsType
from gymnasium.experimental.vector import VectorEnv, VectorWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers.jax_to_torch import Device
from gymnasium.experimental.wrappers.numpy_to_torch import (
numpy_to_torch,
torch_to_numpy,
)
__all__ = ["NumpyToTorchV0"]
class NumpyToTorchV0(VectorWrapper):
"""Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors."""
def __init__(self, env: VectorEnv, device: Device | None = None):
"""Wrapper class to change inputs and outputs of environment to PyTorch tensors.
Args:
            env: The NumPy-based vector environment to wrap
device: The device the torch Tensors should be moved to
"""
super().__init__(env)
self.device: Device | None = device
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Using a PyTorch based action that is converted to NumPy to be used by the environment.
Args:
action: A PyTorch-based action
Returns:
The PyTorch-based Tensor next observation, reward, termination, truncation, and extra info
"""
jax_action = torch_to_numpy(actions)
obs, reward, terminated, truncated, info = self.env.step(jax_action)
return (
numpy_to_torch(obs, self.device),
numpy_to_torch(reward, self.device),
numpy_to_torch(terminated, self.device),
numpy_to_torch(truncated, self.device),
numpy_to_torch(info, self.device),
)
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Resets the environment returning PyTorch-based observation and info.
Args:
seed: The seed for resetting the environment
            options: The options for resetting the environment, these are converted to NumPy arrays.
Returns:
PyTorch-based observations and info
"""
if options:
options = torch_to_numpy(options)
return numpy_to_torch(self.env.reset(seed=seed, options=options), self.device)

gymnasium/experimental/wrappers/vector/record_episode_statistics.py
@ -0,0 +1,144 @@
"""Wrapper that tracks the cumulative rewards and episode lengths."""
from __future__ import annotations
import time
from collections import deque
import numpy as np
from gymnasium.core import ActType, ObsType
from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
__all__ = ["RecordEpisodeStatisticsV0"]
class RecordEpisodeStatisticsV0(VectorWrapper):
"""This wrapper will keep track of cumulative rewards and episode lengths.
At the end of an episode, the statistics of the episode will be added to ``info``
    using the key ``episode``. For vectorized environments, the key ``_episode`` is
    also used, indicating whether the env at the respective index has episode
    statistics.
After the completion of an episode, ``info`` will look like this::
>>> info = { # doctest: +SKIP
... ...
... "episode": {
... "r": "<cumulative reward>",
... "l": "<episode length>",
... "t": "<elapsed time since beginning of episode>"
... },
... }
    For vectorized environments the output will be in the form of::
>>> infos = { # doctest: +SKIP
... ...
... "episode": {
... "r": "<array of cumulative reward for each done sub-environment>",
... "l": "<array of episode length for each done sub-environment>",
... "t": "<array of elapsed time since beginning of episode for each done sub-environment>"
... },
... "_episode": "<boolean array of length num-envs>"
... }
Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
:attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.
Attributes:
return_queue: The cumulative rewards of the last ``deque_size``-many episodes
length_queue: The lengths of the last ``deque_size``-many episodes
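
    A short usage sketch (a skipped doctest; assumes a registered ``CartPole-v1``):

        >>> import gymnasium as gym                                        # doctest: +SKIP
        >>> from gymnasium.experimental.vector import SyncVectorEnv        # doctest: +SKIP
        >>> envs = SyncVectorEnv([lambda: gym.make("CartPole-v1")] * 3)    # doctest: +SKIP
        >>> envs = RecordEpisodeStatisticsV0(envs)                         # doctest: +SKIP
        >>> obs, info = envs.reset(seed=0)                                 # doctest: +SKIP
        >>> # after stepping until sub-environments finish, the per-episode
        >>> # returns accumulate in ``envs.return_queue``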
"""
def __init__(self, env: VectorEnv, deque_size: int = 100):
"""This wrapper will keep track of cumulative rewards and episode lengths.
Args:
env (Env): The environment to apply the wrapper
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
"""
super().__init__(env)
self.episode_count = 0
self.episode_start_times: np.ndarray = np.zeros(())
self.episode_returns: np.ndarray = np.zeros(())
self.episode_lengths: np.ndarray = np.zeros(())
self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size)
def reset(
self,
seed: int | list[int] | None = None,
options: dict | None = None,
):
"""Resets the environment using kwargs and resets the episode returns and lengths."""
obs, info = super().reset(seed=seed, options=options)
self.episode_start_times = np.full(
self.num_envs, time.perf_counter(), dtype=np.float32
)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return obs, info
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Steps through the environment, recording the episode statistics."""
(
observations,
rewards,
terminations,
truncations,
infos,
) = self.env.step(actions)
assert isinstance(
infos, dict
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
self.episode_returns += rewards
self.episode_lengths += 1
dones = np.logical_or(terminations, truncations)
num_dones = np.sum(dones)
if num_dones:
if "episode" in infos or "_episode" in infos:
raise ValueError(
"Attempted to add episode stats when they already exist"
)
else:
infos["episode"] = {
"r": np.where(dones, self.episode_returns, 0.0),
"l": np.where(dones, self.episode_lengths, 0),
"t": np.where(
dones,
np.round(time.perf_counter() - self.episode_start_times, 6),
0.0,
),
}
infos["_episode"] = dones
self.episode_count += num_dones
            (done_indices,) = np.where(dones)
            self.return_queue.extend(self.episode_returns[done_indices])
            self.length_queue.extend(self.episode_lengths[done_indices])
self.episode_lengths[dones] = 0
self.episode_returns[dones] = 0
self.episode_start_times[dones] = time.perf_counter()
return (
observations,
rewards,
terminations,
truncations,
infos,
)

gymnasium/experimental/wrappers/vector/vectorize_action.py
@ -0,0 +1,143 @@
"""Vectorizes action wrappers to work for `VectorEnv`."""
from __future__ import annotations
from copy import deepcopy
from typing import Any, Callable
import numpy as np
from gymnasium import Space
from gymnasium.core import ActType, Env
from gymnasium.experimental.vector import VectorActionWrapper, VectorEnv
from gymnasium.experimental.wrappers import lambda_action
from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
class LambdaActionV0(VectorActionWrapper):
"""Transforms an action via a function provided to the wrapper.
The function :attr:`func` will be applied to all vector actions.
    If the actions from :attr:`func` are outside the bounds of the ``env``'s action space, provide an :attr:`action_space`.
"""
def __init__(
self,
env: VectorEnv,
func: Callable[[ActType], Any],
action_space: Space | None = None,
):
"""Constructor for the lambda action wrapper.
Args:
env: The vector environment to wrap
func: A function that will transform an action. If this transformed action is outside the action space of ``env.action_space`` then provide an ``action_space``.
action_space: The action spaces of the wrapper, if None, then it is assumed the same as ``env.action_space``.
"""
super().__init__(env)
if action_space is not None:
self.action_space = action_space
self.func = func
def actions(self, actions: ActType) -> ActType:
"""Applies the :attr:`func` to the actions."""
return self.func(actions)
class VectorizeLambdaActionV0(VectorActionWrapper):
"""Vectorizes a single-agent lambda action wrapper for vector environments."""
class VectorizedEnv(Env):
"""Fake single-agent environment uses for the single-agent wrapper."""
def __init__(self, action_space: Space):
"""Constructor for the fake environment."""
self.action_space = action_space
def __init__(
self, env: VectorEnv, wrapper: type[lambda_action.LambdaActionV0], **kwargs: Any
):
"""Constructor for the vectorized lambda action wrapper.
Args:
env: The vector environment to wrap
wrapper: The wrapper to vectorize
**kwargs: Arguments for the LambdaActionV0 wrapper
"""
super().__init__(env)
self.wrapper = wrapper(
self.VectorizedEnv(self.env.single_action_space), **kwargs
)
self.single_action_space = self.wrapper.action_space
self.action_space = batch_space(self.single_action_space, self.num_envs)
self.same_out = self.action_space == self.env.action_space
self.out = create_empty_array(self.single_action_space, self.num_envs)
def actions(self, actions: ActType) -> ActType:
"""Applies the wrapper to each of the action.
Args:
actions: The actions to apply the function to
Returns:
The updated actions using the wrapper func
"""
if self.same_out:
return concatenate(
self.single_action_space,
tuple(
self.wrapper.func(action)
for action in iterate(self.action_space, actions)
),
actions,
)
else:
return deepcopy(
concatenate(
self.single_action_space,
tuple(
self.wrapper.func(action)
for action in iterate(self.action_space, actions)
),
self.out,
)
)
class ClipActionV0(VectorizeLambdaActionV0):
"""Clip the continuous action within the valid :class:`Box` observation space bound."""
def __init__(self, env: VectorEnv):
"""Constructor for the Clip Action wrapper.
Args:
env: The vector environment to wrap
"""
super().__init__(env, lambda_action.ClipActionV0)
class RescaleActionV0(VectorizeLambdaActionV0):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action]."""
def __init__(
self,
env: VectorEnv,
min_action: float | int | np.ndarray,
max_action: float | int | np.ndarray,
):
"""Initializes the :class:`RescaleAction` wrapper.
Args:
env (Env): The vector environment to wrap
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
"""
super().__init__(
env,
lambda_action.RescaleActionV0,
min_action=min_action,
max_action=max_action,
)

gymnasium/experimental/wrappers/vector/vectorize_observation.py
@ -0,0 +1,222 @@
"""Vectorizes observation wrappers to works for `VectorEnv`."""
from __future__ import annotations
from copy import deepcopy
from typing import Any, Callable, Sequence
import numpy as np
from gymnasium import Space
from gymnasium.core import Env, ObsType
from gymnasium.experimental.vector import VectorEnv, VectorObservationWrapper
from gymnasium.experimental.vector.utils import batch_space, concatenate, iterate
from gymnasium.experimental.wrappers import lambda_observation
from gymnasium.vector.utils import create_empty_array
class LambdaObservationV0(VectorObservationWrapper):
"""Transforms an observation via a function provided to the wrapper.
The function :attr:`func` will be applied to all vector observations.
If the observations from :attr:`func` are outside the bounds of the ``env``'s observation space, provide an :attr:`observation_space`.
"""
def __init__(
self,
env: VectorEnv,
vector_func: Callable[[ObsType], Any],
single_func: Callable[[ObsType], Any],
observation_space: Space | None = None,
):
"""Constructor for the lambda observation wrapper.
Args:
env: The vector environment to wrap
vector_func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
single_func: A function that will transform an individual observation.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
"""
super().__init__(env)
if observation_space is not None:
self.observation_space = observation_space
self.vector_func = vector_func
self.single_func = single_func
def vector_observation(self, observation: ObsType) -> ObsType:
"""Apply function to the vector observation."""
return self.vector_func(observation)
def single_observation(self, observation: ObsType) -> ObsType:
"""Apply function to the single observation."""
return self.single_func(observation)
class VectorizeLambdaObservationV0(VectorObservationWrapper):
"""Vectori`es a single-agent lambda observation wrapper for vector environments."""
class VectorizedEnv(Env):
"""Fake single-agent environment uses for the single-agent wrapper."""
def __init__(self, observation_space: Space):
"""Constructor for the fake environment."""
self.observation_space = observation_space
def __init__(
self,
env: VectorEnv,
wrapper: type[lambda_observation.LambdaObservationV0],
**kwargs: Any,
):
"""Constructor for the vectorized lambda observation wrapper.
Args:
env: The vector environment to wrap.
wrapper: The wrapper to vectorize
**kwargs: Keyword argument for the wrapper
"""
super().__init__(env)
self.wrapper = wrapper(
self.VectorizedEnv(self.env.single_observation_space), **kwargs
)
self.single_observation_space = self.wrapper.observation_space
self.observation_space = batch_space(
self.single_observation_space, self.num_envs
)
self.same_out = self.observation_space == self.env.observation_space
self.out = create_empty_array(self.single_observation_space, self.num_envs)
def vector_observation(self, observation: ObsType) -> ObsType:
"""Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again."""
if self.same_out:
return concatenate(
self.single_observation_space,
tuple(
self.wrapper.func(obs)
for obs in iterate(self.observation_space, observation)
),
observation,
)
else:
return deepcopy(
concatenate(
self.single_observation_space,
tuple(
self.wrapper.func(obs)
for obs in iterate(self.observation_space, observation)
),
self.out,
)
)
def single_observation(self, observation: ObsType) -> ObsType:
"""Transforms a single observation using the wrapper transformation function."""
return self.wrapper.func(observation)
class FilterObservationV0(VectorizeLambdaObservationV0):
"""Vector wrapper for filtering dict or tuple observation spaces."""
def __init__(self, env: VectorEnv, filter_keys: Sequence[str | int]):
"""Constructor for the filter observation wrapper.
Args:
env: The vector environment to wrap
            filter_keys: The subspaces to be included, use a list of strings or integers for ``Dict`` and ``Tuple`` spaces respectively
"""
super().__init__(
env, lambda_observation.FilterObservationV0, filter_keys=filter_keys
)
class FlattenObservationV0(VectorizeLambdaObservationV0):
"""Observation wrapper that flattens the observation."""
def __init__(self, env: VectorEnv):
"""Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.
Args:
env: The vector environment to wrap
"""
super().__init__(env, lambda_observation.FlattenObservationV0)
class GrayscaleObservationV0(VectorizeLambdaObservationV0):
"""Observation wrapper that converts an RGB image to grayscale."""
def __init__(self, env: VectorEnv, keep_dim: bool = False):
"""Constructor for an RGB image based environments to make the image grayscale.
Args:
env: The vector environment to wrap
            keep_dim: Whether to keep the channel dimension in the observation; if ``True``, ``len(obs.shape) == 3``, else ``len(obs.shape) == 2``
"""
super().__init__(
env, lambda_observation.GrayscaleObservationV0, keep_dim=keep_dim
)
class ResizeObservationV0(VectorizeLambdaObservationV0):
"""Resizes image observations using OpenCV to shape."""
def __init__(self, env: VectorEnv, shape: tuple[int, ...]):
"""Constructor that requires an image environment observation space with a shape.
Args:
env: The vector environment to wrap
shape: The resized observation shape
"""
super().__init__(env, lambda_observation.ResizeObservationV0, shape=shape)
class ReshapeObservationV0(VectorizeLambdaObservationV0):
"""Reshapes array based observations to shapes."""
def __init__(self, env: VectorEnv, shape: int | tuple[int, ...]):
"""Constructor for env with Box observation space that has a shape product equal to the new shape product.
Args:
env: The vector environment to wrap
shape: The reshaped observation space
"""
super().__init__(env, lambda_observation.ReshapeObservationV0, shape=shape)
class RescaleObservationV0(VectorizeLambdaObservationV0):
"""Linearly rescales observation to between a minimum and maximum value."""
def __init__(
self,
env: VectorEnv,
min_obs: np.floating | np.integer | np.ndarray,
max_obs: np.floating | np.integer | np.ndarray,
):
"""Constructor that requires the env observation spaces to be a :class:`Box`.
Args:
env: The vector environment to wrap
min_obs: The new minimum observation bound
max_obs: The new maximum observation bound
"""
super().__init__(
env,
lambda_observation.RescaleObservationV0,
min_obs=min_obs,
max_obs=max_obs,
)
class DtypeObservationV0(VectorizeLambdaObservationV0):
"""Observation wrapper for transforming the dtype of an observation."""
def __init__(self, env: VectorEnv, dtype: Any):
"""Constructor for Dtype observation wrapper.
Args:
env: The vector environment to wrap
dtype: The new dtype of the observation
"""
super().__init__(env, lambda_observation.DtypeObservationV0, dtype=dtype)

gymnasium/experimental/wrappers/vector/vectorize_reward.py
@ -0,0 +1,78 @@
"""Vectorizes reward function to work with `VectorEnv`."""
from __future__ import annotations
from typing import Any, Callable
import numpy as np
from gymnasium import Env
from gymnasium.experimental.vector import VectorEnv, VectorRewardWrapper
from gymnasium.experimental.vector.vector_env import ArrayType
from gymnasium.experimental.wrappers import lambda_reward
class LambdaRewardV0(VectorRewardWrapper):
"""A reward wrapper that allows a custom function to modify the step reward."""
def __init__(self, env: VectorEnv, func: Callable[[ArrayType], ArrayType]):
"""Initialize LambdaRewardV0 wrapper.
Args:
env (Env): The vector environment to wrap
            func (Callable): The function to apply to the reward
"""
super().__init__(env)
self.func = func
def reward(self, reward: ArrayType) -> ArrayType:
"""Apply function to reward."""
return self.func(reward)
class VectorizeLambdaRewardV0(VectorRewardWrapper):
"""Vectorizes a single-agent lambda reward wrapper for vector environments."""
def __init__(
self, env: VectorEnv, wrapper: type[lambda_reward.LambdaRewardV0], **kwargs: Any
):
"""Constructor for the vectorized lambda reward wrapper.
Args:
env: The vector environment to wrap.
wrapper: The wrapper to vectorize
**kwargs: Keyword argument for the wrapper
"""
super().__init__(env)
self.wrapper = wrapper(Env(), **kwargs)
def reward(self, reward: ArrayType) -> ArrayType:
"""Iterates over the reward updating each with the wrapper func."""
for i, r in enumerate(reward):
reward[i] = self.wrapper.func(r)
return reward
class ClipRewardV0(VectorizeLambdaRewardV0):
"""A wrapper that clips the rewards for an environment between an upper and lower bound."""
def __init__(
self,
env: VectorEnv,
min_reward: float | np.ndarray | None = None,
max_reward: float | np.ndarray | None = None,
):
"""Constructor for ClipReward wrapper.
Args:
env: The vector environment to wrap
min_reward: The min reward for each step
            max_reward: The max reward for each step
"""
super().__init__(
env,
lambda_reward.ClipRewardV0,
min_reward=min_reward,
max_reward=max_reward,
)