I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

gymnasium/experimental/vector/__init__.py
@@ -0,0 +1,23 @@
"""Experimental vector env API."""
from gymnasium.experimental.vector import utils
from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
from gymnasium.experimental.vector.vector_env import (
VectorActionWrapper,
VectorEnv,
VectorObservationWrapper,
VectorRewardWrapper,
VectorWrapper,
)
__all__ = [
"VectorEnv",
"VectorWrapper",
"VectorObservationWrapper",
"VectorActionWrapper",
"VectorRewardWrapper",
"SyncVectorEnv",
"AsyncVectorEnv",
"utils",
]
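A quick sketch (not part of this commit) of how the re-exported wrapper base class can be used; RewardScaleWrapper and its scale argument are hypothetical names for illustration, relying only on the VectorWrapper forwarding behaviour defined in vector_env.py below.
import gymnasium as gym
from gymnasium.experimental.vector import SyncVectorEnv, VectorWrapper

class RewardScaleWrapper(VectorWrapper):
    """Hypothetical wrapper: scales every batched reward by a constant."""

    def __init__(self, env, scale: float = 0.1):
        super().__init__(env)
        self.scale = scale

    def step(self, actions):
        # Forward to the wrapped vector env, then rescale the reward batch
        obs, rewards, terminateds, truncateds, infos = self.env.step(actions)
        return obs, rewards * self.scale, terminateds, truncateds, infos

envs = RewardScaleWrapper(SyncVectorEnv([lambda: gym.make("CartPole-v1")] * 2))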

gymnasium/experimental/vector/async_vector_env.py
@@ -0,0 +1,685 @@
"""An async vector environment."""
from __future__ import annotations
import multiprocessing
import sys
import time
from copy import deepcopy
from enum import Enum
from multiprocessing import Queue
from multiprocessing.connection import Connection
from typing import Any, Callable, Sequence
import numpy as np
from gymnasium import logger
from gymnasium.core import Env, ObsType
from gymnasium.error import (
AlreadyPendingCallError,
ClosedEnvironmentError,
CustomSpaceError,
NoAsyncCallError,
)
from gymnasium.experimental.vector.utils import (
CloudpickleWrapper,
batch_space,
clear_mpi_env_vars,
concatenate,
create_empty_array,
create_shared_memory,
iterate,
read_from_shared_memory,
write_to_shared_memory,
)
from gymnasium.experimental.vector.vector_env import VectorEnv
__all__ = ["AsyncVectorEnv"]
class AsyncState(Enum):
DEFAULT = "default"
WAITING_RESET = "reset"
WAITING_STEP = "step"
WAITING_CALL = "call"
class AsyncVectorEnv(VectorEnv):
"""Vectorized environment that runs multiple environments in parallel.
It uses ``multiprocessing`` processes, and pipes for communication.
Example:
>>> import gymnasium as gym
>>> env = gym.vector.AsyncVectorEnv([
... lambda: gym.make("Pendulum-v1", g=9.81),
... lambda: gym.make("Pendulum-v1", g=1.62)
... ])
>>> env.reset(seed=42)
(array([[-0.14995256, 0.9886932 , -0.12224312],
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {})
"""
def __init__(
self,
env_fns: Sequence[Callable[[], Env]],
shared_memory: bool = True,
copy: bool = True,
context: str | None = None,
daemon: bool = True,
worker: callable | None = None,
):
"""Vectorized environment that runs multiple environments in parallel.
Args:
env_fns: Functions that create the environments.
shared_memory: If ``True``, then the observations from the worker processes are communicated back through
shared variables. This can improve the efficiency if the observations are large (e.g. images).
copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods
return a copy of the observations.
context: Context for `multiprocessing`_. If ``None``, then the default context is used.
daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if
the head process quits. However, ``daemon=True`` prevents subprocesses to spawn children,
so for some environments you may want to have it set to ``False``.
worker: If set, then use that worker in a subprocess instead of a default one.
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
Warnings:
worker is an advanced mode option. It provides a high degree of flexibility and a high chance
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
from the code for the ``_worker`` function below and add changes.
Raises:
RuntimeError: If the observation space of some sub-environment does not match observation_space
(or, by default, the observation space of the first sub-environment).
ValueError: If observation_space is a custom space (i.e. not a default space in Gym,
such as gymnasium.spaces.Box, gymnasium.spaces.Discrete, or gymnasium.spaces.Dict) and shared_memory is True.
"""
super().__init__()
ctx = multiprocessing.get_context(context)
self.env_fns = env_fns
self.num_envs = len(env_fns)
self.shared_memory = shared_memory
self.copy = copy
# This would be nice to get rid of, but without it there's a deadlock between shared memory and pipes
dummy_env = env_fns[0]()
self.metadata = dummy_env.metadata
self.single_observation_space = dummy_env.observation_space
self.single_action_space = dummy_env.action_space
self.observation_space = batch_space(
self.single_observation_space, self.num_envs
)
self.action_space = batch_space(self.single_action_space, self.num_envs)
dummy_env.close()
del dummy_env
if self.shared_memory:
try:
_obs_buffer = create_shared_memory(
self.single_observation_space, n=self.num_envs, ctx=ctx
)
self.observations = read_from_shared_memory(
self.single_observation_space, _obs_buffer, n=self.num_envs
)
except CustomSpaceError as e:
raise ValueError(
"Using `shared_memory=True` in `AsyncVectorEnv` "
"is incompatible with non-standard Gymnasium observation spaces "
"(i.e. custom spaces inheriting from `gymnasium.Space`), and is "
"only compatible with default Gymnasium spaces (e.g. `Box`, "
"`Tuple`, `Dict`) for batching. Set `shared_memory=False` "
"if you use custom observation spaces."
) from e
else:
_obs_buffer = None
self.observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
self.parent_pipes, self.processes = [], []
self.error_queue = ctx.Queue()
target = worker or _worker
with clear_mpi_env_vars():
for idx, env_fn in enumerate(self.env_fns):
parent_pipe, child_pipe = ctx.Pipe()
process = ctx.Process(
target=target,
name=f"Worker<{type(self).__name__}>-{idx}",
args=(
idx,
CloudpickleWrapper(env_fn),
child_pipe,
parent_pipe,
_obs_buffer,
self.error_queue,
),
)
self.parent_pipes.append(parent_pipe)
self.processes.append(process)
process.daemon = daemon
process.start()
child_pipe.close()
self._state = AsyncState.DEFAULT
self._check_spaces()
def reset_async(
self,
seed: int | list[int] | None = None,
options: dict | None = None,
):
"""Send calls to the :obj:`reset` methods of the sub-environments.
To get the results of these calls, you may invoke :meth:`reset_wait`.
Args:
seed: The seeds for the sub-environments; an ``int`` is expanded to ``[seed, seed + 1, ...]``
options: The reset options passed to each sub-environment
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
AlreadyPendingCallError: If the environment is already waiting for a pending call to another
method (e.g. :meth:`step_async`). This can be caused by two consecutive
calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in between.
"""
self._assert_is_running()
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete",
str(self._state.value),
)
for pipe, single_seed in zip(self.parent_pipes, seed):
single_kwargs = {}
if single_seed is not None:
single_kwargs["seed"] = single_seed
if options is not None:
single_kwargs["options"] = options
pipe.send(("reset", single_kwargs))
self._state = AsyncState.WAITING_RESET
def reset_wait(
self,
timeout: int | float | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
Args:
timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out.
Returns:
A tuple of batched observations and an info dictionary
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
NoAsyncCallError: If :meth:`reset_wait` was called without any prior call to :meth:`reset_async`.
TimeoutError: If :meth:`reset_wait` timed out.
"""
self._assert_is_running()
if self._state != AsyncState.WAITING_RESET:
raise NoAsyncCallError(
"Calling `reset_wait` without any prior " "call to `reset_async`.",
AsyncState.WAITING_RESET.value,
)
if not self._poll(timeout):
self._state = AsyncState.DEFAULT
raise multiprocessing.TimeoutError(
f"The call to `reset_wait` has timed out after {timeout} second(s)."
)
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
infos = {}
results, info_data = zip(*results)
for i, info in enumerate(info_data):
infos = self._add_info(infos, info, i)
if not self.shared_memory:
self.observations = concatenate(
self.single_observation_space, results, self.observations
)
return (deepcopy(self.observations) if self.copy else self.observations), infos
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict | None = None,
):
"""Reset all parallel environments and return a batch of initial observations and info.
Args:
seed: The environment reset seeds
options: The reset options passed to each sub-environment
Returns:
A batch of observations and info from the vectorized environment.
"""
self.reset_async(seed=seed, options=options)
return self.reset_wait()
def step_async(self, actions: np.ndarray):
"""Send the calls to :obj:`step` to each sub-environment.
Args:
actions: Batch of actions, an element of :attr:`~VectorEnv.action_space`
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
AlreadyPendingCallError: If the environment is already waiting for a pending call to another
method (e.g. :meth:`reset_async`). This can be caused by two consecutive
calls to :meth:`step_async`, with no call to :meth:`step_wait` in
between.
"""
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
f"Calling `step_async` while waiting for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)
actions = iterate(self.action_space, actions)
for pipe, action in zip(self.parent_pipes, actions):
pipe.send(("step", action))
self._state = AsyncState.WAITING_STEP
def step_wait(
self, timeout: int | float | None = None
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
Args:
timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.
Returns:
The batched environment step information, (obs, reward, terminated, truncated, info)
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
NoAsyncCallError: If :meth:`step_wait` was called without any prior call to :meth:`step_async`.
TimeoutError: If :meth:`step_wait` timed out.
"""
self._assert_is_running()
if self._state != AsyncState.WAITING_STEP:
raise NoAsyncCallError(
"Calling `step_wait` without any prior call " "to `step_async`.",
AsyncState.WAITING_STEP.value,
)
if not self._poll(timeout):
self._state = AsyncState.DEFAULT
raise multiprocessing.TimeoutError(
f"The call to `step_wait` has timed out after {timeout} second(s)."
)
observations_list, rewards, terminateds, truncateds, infos = [], [], [], [], {}
successes = []
for i, pipe in enumerate(self.parent_pipes):
result, success = pipe.recv()
obs, rew, terminated, truncated, info = result
successes.append(success)
if success:
observations_list.append(obs)
rewards.append(rew)
terminateds.append(terminated)
truncateds.append(truncated)
infos = self._add_info(infos, info, i)
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
if not self.shared_memory:
self.observations = concatenate(
self.single_observation_space,
observations_list,
self.observations,
)
return (
deepcopy(self.observations) if self.copy else self.observations,
np.array(rewards),
np.array(terminateds, dtype=np.bool_),
np.array(truncateds, dtype=np.bool_),
infos,
)
def step(self, actions):
"""Take an action for each parallel environment.
Args:
actions: Batch of actions, an element of :attr:`action_space`
Returns:
Batch of (observations, rewards, terminations, truncations, infos)
"""
self.step_async(actions)
return self.step_wait()
def call_async(self, name: str, *args, **kwargs):
"""Calls the method with name asynchronously and apply args and kwargs to the method.
Args:
name: Name of the method or property to call.
*args: Arguments to apply to the method call.
**kwargs: Keyword arguments to apply to the method call.
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
AlreadyPendingCallError: Calling `call_async` while waiting for a pending call to complete
"""
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `call_async` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)
for pipe in self.parent_pipes:
pipe.send(("_call", (name, args, kwargs)))
self._state = AsyncState.WAITING_CALL
def call_wait(self, timeout: int | float | None = None) -> list:
"""Calls all parent pipes and waits for the results.
Args:
timeout: Number of seconds before the call to `call_wait` times out.
If `None` (default), the call to `call_wait` never times out.
Returns:
List of the results of the individual calls to the method or property for each environment.
Raises:
NoAsyncCallError: Calling `call_wait` without any prior call to `call_async`.
TimeoutError: The call to `call_wait` has timed out after timeout second(s).
"""
self._assert_is_running()
if self._state != AsyncState.WAITING_CALL:
raise NoAsyncCallError(
"Calling `call_wait` without any prior call to `call_async`.",
AsyncState.WAITING_CALL.value,
)
if not self._poll(timeout):
self._state = AsyncState.DEFAULT
raise multiprocessing.TimeoutError(
f"The call to `call_wait` has timed out after {timeout} second(s)."
)
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
return results
def call(self, name: str, *args, **kwargs) -> list[Any]:
"""Call a method, or get a property, from each parallel environment.
Args:
name (str): Name of the method or property to call.
*args: Arguments to apply to the method call.
**kwargs: Keyword arguments to apply to the method call.
Returns:
List of the results of the individual calls to the method or property for each environment.
"""
self.call_async(name, *args, **kwargs)
return self.call_wait()
def get_attr(self, name: str):
"""Get a property from each parallel environment.
Args:
name (str): Name of the property to get from each individual environment.
Returns:
A list with the value of the property from each environment
"""
return self.call(name)
def set_attr(self, name: str, values: list[Any] | tuple[Any] | object):
"""Sets an attribute of the sub-environments.
Args:
name: Name of the property to be set in each individual environment.
values: Values of the property to be set to. If ``values`` is a list or
tuple, then it corresponds to the values for each individual
environment, otherwise a single value is set for all environments.
Raises:
ValueError: Values must be a list or tuple with length equal to the number of environments.
AlreadyPendingCallError: Calling `set_attr` while waiting for a pending call to complete.
"""
self._assert_is_running()
if not isinstance(values, (list, tuple)):
values = [values for _ in range(self.num_envs)]
if len(values) != self.num_envs:
raise ValueError(
"Values must be a list or tuple with length equal to the "
f"number of environments. Got `{len(values)}` values for "
f"{self.num_envs} environments."
)
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `set_attr` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)
for pipe, value in zip(self.parent_pipes, values):
pipe.send(("_setattr", (name, value)))
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
def close_extras(self, timeout: int | float | None = None, terminate: bool = False):
"""Close the environments & clean up the extra resources (processes and pipes).
Args:
timeout: Number of seconds before the call to :meth:`close` times out. If ``None``,
the call to :meth:`close` never times out. If the call to :meth:`close`
times out, then all processes are terminated.
terminate: If ``True``, then the :meth:`close` operation is forced and all processes are terminated.
Raises:
TimeoutError: If :meth:`close` timed out.
"""
timeout = 0 if terminate else timeout
try:
if self._state != AsyncState.DEFAULT:
logger.warn(
f"Calling `close` while waiting for a pending call to `{self._state.value}` to complete."
)
function = getattr(self, f"{self._state.value}_wait")
function(timeout)
except multiprocessing.TimeoutError:
terminate = True
if terminate:
for process in self.processes:
if process.is_alive():
process.terminate()
else:
for pipe in self.parent_pipes:
if (pipe is not None) and (not pipe.closed):
pipe.send(("close", None))
for pipe in self.parent_pipes:
if (pipe is not None) and (not pipe.closed):
pipe.recv()
for pipe in self.parent_pipes:
if pipe is not None:
pipe.close()
for process in self.processes:
process.join()
def _poll(self, timeout=None):
self._assert_is_running()
if timeout is None:
return True
end_time = time.perf_counter() + timeout
delta = None
for pipe in self.parent_pipes:
delta = max(end_time - time.perf_counter(), 0)
if pipe is None:
return False
if pipe.closed or (not pipe.poll(delta)):
return False
return True
def _check_spaces(self):
self._assert_is_running()
spaces = (self.single_observation_space, self.single_action_space)
for pipe in self.parent_pipes:
pipe.send(("_check_spaces", spaces))
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
same_observation_spaces, same_action_spaces = zip(*results)
if not all(same_observation_spaces):
raise RuntimeError(
f"Some environments have an observation space different from `{self.single_observation_space}`. "
"In order to batch observations, the observation spaces from all environments must be equal."
)
if not all(same_action_spaces):
raise RuntimeError(
f"Some environments have an action space different from `{self.single_action_space}`. "
"In order to batch actions, the action spaces from all environments must be equal."
)
def _assert_is_running(self):
if self.closed:
raise ClosedEnvironmentError(
f"Trying to operate on `{type(self).__name__}`, after a call to `close()`."
)
def _raise_if_errors(self, successes: list[bool]):
if all(successes):
return
num_errors = self.num_envs - sum(successes)
assert num_errors > 0
for i in range(num_errors):
index, exctype, value = self.error_queue.get()
logger.error(
f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
)
logger.error(f"Shutting down Worker-{index}.")
self.parent_pipes[index].close()
self.parent_pipes[index] = None
if i == num_errors - 1:
logger.error("Raising the last exception back to the main process.")
raise exctype(value)
def __del__(self):
"""On deleting the object, checks that the vector environment is closed."""
if not getattr(self, "closed", True) and hasattr(self, "_state"):
self.close(terminate=True)
def _worker(
index: int,
env_fn: callable,
pipe: Connection,
parent_pipe: Connection,
shared_memory: Any,  # shared-memory buffer from create_shared_memory, or None if unused
error_queue: Queue,
):
env = env_fn()
observation_space = env.observation_space
action_space = env.action_space
parent_pipe.close()
try:
while True:
command, data = pipe.recv()
if command == "reset":
observation, info = env.reset(**data)
if shared_memory:
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
observation = None
pipe.send(((observation, info), True))
elif command == "step":
(
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
if terminated or truncated:
old_observation, old_info = observation, info
observation, info = env.reset()
info["final_observation"] = old_observation
info["final_info"] = old_info
if shared_memory:
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
observation = None
pipe.send(((observation, reward, terminated, truncated, info), True))
elif command == "seed":
env.seed(data)
pipe.send((None, True))
elif command == "close":
pipe.send((None, True))
break
elif command == "_call":
name, args, kwargs = data
if name in ["reset", "step", "seed", "close"]:
raise ValueError(
f"Trying to call function `{name}` with "
f"`_call`. Use `{name}` directly instead."
)
function = getattr(env, name)
if callable(function):
pipe.send((function(*args, **kwargs), True))
else:
pipe.send((function, True))
elif command == "_setattr":
name, value = data
setattr(env, name, value)
pipe.send((None, True))
elif command == "_check_spaces":
pipe.send(
(
(data[0] == observation_space, data[1] == action_space),
True,
)
)
else:
raise RuntimeError(
f"Received unknown command `{command}`. Must "
"be one of {`reset`, `step`, `seed`, `close`, `_call`, "
"`_setattr`, `_check_spaces`}."
)
except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2])
pipe.send((None, False))
finally:
env.close()
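A minimal sketch (not in this commit) of the split API above: dispatch step to every worker, overlap other work in the parent process, then collect the batched results. It assumes only the methods defined in this file and an installed CartPole-v1 environment.
import gymnasium as gym
from gymnasium.experimental.vector import AsyncVectorEnv

envs = AsyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
obs, infos = envs.reset(seed=42)

actions = envs.action_space.sample()  # batched action, an element of `action_space`
envs.step_async(actions)              # workers start stepping immediately
# ... overlap other work here, e.g. the policy forward pass for the next step ...
obs, rewards, terminateds, truncateds, infos = envs.step_wait(timeout=10)
envs.close()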

gymnasium/experimental/vector/sync_vector_env.py
@@ -0,0 +1,229 @@
"""A synchronous vector environment."""
from __future__ import annotations
from copy import deepcopy
from typing import Any, Callable, Iterator
import numpy as np
from gymnasium import Env
from gymnasium.experimental.vector.utils import (
batch_space,
concatenate,
create_empty_array,
iterate,
)
from gymnasium.experimental.vector.vector_env import VectorEnv
__all__ = ["SyncVectorEnv"]
class SyncVectorEnv(VectorEnv):
"""Vectorized environment that serially runs multiple environments.
Example:
>>> import gymnasium as gym
>>> env = gym.vector.SyncVectorEnv([
... lambda: gym.make("Pendulum-v1", g=9.81),
... lambda: gym.make("Pendulum-v1", g=1.62)
... ])
>>> env.reset(seed=42)
(array([[-0.14995256, 0.9886932 , -0.12224312],
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {})
"""
def __init__(
self,
env_fns: Iterator[Callable[[], Env]],
copy: bool = True,
):
"""Vectorized environment that serially runs multiple environments.
Args:
env_fns: iterable of callable functions that create the environments.
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
Raises:
RuntimeError: If the observation space of some sub-environment does not match observation_space
(or, by default, the observation space of the first sub-environment).
"""
super().__init__()
self.env_fns = env_fns
self.envs = [env_fn() for env_fn in env_fns]
self.num_envs = len(self.envs)
self.copy = copy
self.metadata = self.envs[0].metadata
self.spec = self.envs[0].spec
self.single_observation_space = self.envs[0].observation_space
self.single_action_space = self.envs[0].action_space
self.observation_space = batch_space(
self.single_observation_space, self.num_envs
)
self.action_space = batch_space(self.single_action_space, self.num_envs)
self._check_spaces()
self.observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
self._terminateds = np.zeros((self.num_envs,), dtype=np.bool_)
self._truncateds = np.zeros((self.num_envs,), dtype=np.bool_)
def reset(
self,
seed: int | list[int] | None = None,
options: dict | None = None,
):
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
Args:
seed: The reset environment seed
options: Option information for the environment reset
Returns:
The reset observation of the environment and reset information
"""
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
self._terminateds[:] = False
self._truncateds[:] = False
observations = []
infos = {}
for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
kwargs = {}
if single_seed is not None:
kwargs["seed"] = single_seed
if options is not None:
kwargs["options"] = options
observation, info = env.reset(**kwargs)
observations.append(observation)
infos = self._add_info(infos, info, i)
self.observations = concatenate(
self.single_observation_space, observations, self.observations
)
return (deepcopy(self.observations) if self.copy else self.observations), infos
def step(self, actions):
"""Steps through each of the environments returning the batched results.
Returns:
The batched environment step results
"""
actions = iterate(self.action_space, actions)
observations, infos = [], {}
for i, (env, action) in enumerate(zip(self.envs, actions)):
(
observation,
self._rewards[i],
self._terminateds[i],
self._truncateds[i],
info,
) = env.step(action)
if self._terminateds[i] or self._truncateds[i]:
old_observation, old_info = observation, info
observation, info = env.reset()
info["final_observation"] = old_observation
info["final_info"] = old_info
observations.append(observation)
infos = self._add_info(infos, info, i)
self.observations = concatenate(
self.single_observation_space, observations, self.observations
)
return (
deepcopy(self.observations) if self.copy else self.observations,
np.copy(self._rewards),
np.copy(self._terminateds),
np.copy(self._truncateds),
infos,
)
def call(self, name, *args, **kwargs) -> tuple:
"""Calls the method with name and applies args and kwargs.
Args:
name: The method name
*args: The method args
**kwargs: The method kwargs
Returns:
Tuple of results
"""
results = []
for env in self.envs:
function = getattr(env, name)
if callable(function):
results.append(function(*args, **kwargs))
else:
results.append(function)
return tuple(results)
def get_attr(self, name: str):
"""Get a property from each parallel environment.
Args:
name (str): Name of the property to get from each individual environment.
Returns:
A tuple with the value of the property from each environment
"""
return self.call(name)
def set_attr(self, name: str, values: list | tuple | Any):
"""Sets an attribute of the sub-environments.
Args:
name: The property name to change
values: Values of the property to be set to. If ``values`` is a list or
tuple, then it corresponds to the values for each individual
environment, otherwise, a single value is set for all environments.
Raises:
ValueError: Values must be a list or tuple with length equal to the number of environments.
"""
if not isinstance(values, (list, tuple)):
values = [values for _ in range(self.num_envs)]
if len(values) != self.num_envs:
raise ValueError(
"Values must be a list or tuple with length equal to the "
f"number of environments. Got `{len(values)}` values for "
f"{self.num_envs} environments."
)
for env, value in zip(self.envs, values):
setattr(env, name, value)
def close_extras(self, **kwargs):
"""Close the environments."""
for env in self.envs:
    env.close()
def _check_spaces(self) -> bool:
for env in self.envs:
if not (env.observation_space == self.single_observation_space):
raise RuntimeError(
"Some environments have an observation space different from "
f"`{self.single_observation_space}`. In order to batch observations, "
"the observation spaces from all environments must be equal."
)
if not (env.action_space == self.single_action_space):
raise RuntimeError(
"Some environments have an action space different from "
f"`{self.single_action_space}`. In order to batch actions, the "
"action spaces from all environments must be equal."
)
return True
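A short sketch of the attribute plumbing above; render_mode is a real sub-environment attribute, while my_flag is a hypothetical attribute used purely to illustrate set_attr with per-environment values.
import gymnasium as gym
from gymnasium.experimental.vector import SyncVectorEnv

envs = SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
print(envs.get_attr("render_mode"))      # one value per sub-environment
envs.set_attr("my_flag", [True, False])  # hypothetical attribute, per-env values
specs = envs.call("spec")                # non-callable attributes are returned as-is
envs.close()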

gymnasium/experimental/vector/utils/__init__.py
@@ -0,0 +1,30 @@
"""Module for gymnasium experimental vector utility functions."""
from gymnasium.experimental.vector.utils.misc import (
CloudpickleWrapper,
clear_mpi_env_vars,
)
from gymnasium.experimental.vector.utils.shared_memory import (
create_shared_memory,
read_from_shared_memory,
write_to_shared_memory,
)
from gymnasium.experimental.vector.utils.space_utils import (
batch_space,
concatenate,
create_empty_array,
iterate,
)
__all__ = [
"batch_space",
"iterate",
"concatenate",
"create_empty_array",
"create_shared_memory",
"read_from_shared_memory",
"write_to_shared_memory",
"CloudpickleWrapper",
"clear_mpi_env_vars",
]

gymnasium/experimental/vector/utils/misc.py
@@ -0,0 +1,61 @@
"""Miscellaneous utilities."""
from __future__ import annotations
import contextlib
import os
from collections.abc import Callable
from gymnasium.core import Env
__all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"]
class CloudpickleWrapper:
"""Wrapper that uses cloudpickle to pickle and unpickle the result."""
def __init__(self, fn: Callable[[], Env]):
"""Cloudpickle wrapper for a function."""
self.fn = fn
def __getstate__(self):
"""Get the state using `cloudpickle.dumps(self.fn)`."""
import cloudpickle
return cloudpickle.dumps(self.fn)
def __setstate__(self, ob):
"""Sets the state with obs."""
import pickle
self.fn = pickle.loads(ob)
def __call__(self):
"""Calls the function `self.fn` with no arguments."""
return self.fn()
@contextlib.contextmanager
def clear_mpi_env_vars():
"""Clears the MPI of environment variables.
`from mpi4py import MPI` will call `MPI_Init` by default.
If the child process has MPI environment variables, MPI will think that the child process
is an MPI process just like the parent and do bad things such as hang.
This context manager is a hacky way to clear those environment variables
temporarily such as when we are starting multiprocessing Processes.
Yields:
Yields for the context manager
"""
removed_environment = {}
for k, v in list(os.environ.items()):
for prefix in ["OMPI_", "PMI_"]:
if k.startswith(prefix):
removed_environment[k] = v
del os.environ[k]
try:
yield
finally:
os.environ.update(removed_environment)
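A quick sketch of why CloudpickleWrapper exists: the standard pickle module cannot serialise a lambda, but the wrapper detours through cloudpickle on dump (assuming the cloudpickle package is installed, as the worker path requires).
import pickle
import gymnasium as gym

fn = CloudpickleWrapper(lambda: gym.make("CartPole-v1"))
restored = pickle.loads(pickle.dumps(fn))  # dump uses cloudpickle, load uses pickle
env = restored()                           # invokes the original lambda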

gymnasium/experimental/vector/utils/shared_memory.py
@@ -0,0 +1,255 @@
"""Utility functions for vector environments to share memory between processes."""
from __future__ import annotations
import multiprocessing as mp
from collections import OrderedDict
from ctypes import c_bool
from functools import singledispatch
from typing import Any
import numpy as np
from gymnasium.error import CustomSpaceError
from gymnasium.spaces import (
Box,
Dict,
Discrete,
Graph,
MultiBinary,
MultiDiscrete,
Sequence,
Space,
Text,
Tuple,
flatten,
)
__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]
@singledispatch
def create_shared_memory(
space: Space[Any], n: int = 1, ctx=mp
) -> dict[str, Any] | tuple[Any, ...] | mp.Array:
"""Create a shared memory object, to be shared across processes.
This eventually contains the observations from the vectorized environment.
Args:
space: Observation space of a single environment in the vectorized environment.
n: Number of environments in the vectorized environment (i.e. the number of processes).
ctx: The multiprocessing context (the ``multiprocessing`` module or a context object)
Returns:
shared_memory for the shared object across processes.
Raises:
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
"""
if isinstance(space, Space):
raise CustomSpaceError(
f"Space of type `{type(space)}` doesn't have an registered `create_shared_memory` function. Register `{type(space)}` for `create_shared_memory` to support it."
)
else:
raise TypeError(
f"The space provided to `create_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@create_shared_memory.register(Box)
@create_shared_memory.register(Discrete)
@create_shared_memory.register(MultiDiscrete)
@create_shared_memory.register(MultiBinary)
def _create_base_shared_memory(
space: Box | Discrete | MultiDiscrete | MultiBinary, n: int = 1, ctx=mp
):
assert space.dtype is not None
dtype = space.dtype.char
if dtype == "?":  # numpy bool char maps to ctypes c_bool
dtype = c_bool
return ctx.Array(dtype, n * int(np.prod(space.shape)))
@create_shared_memory.register(Tuple)
def _create_tuple_shared_memory(space: Tuple, n: int = 1, ctx=mp):
return tuple(
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
)
@create_shared_memory.register(Dict)
def _create_dict_shared_memory(space: Dict, n: int = 1, ctx=mp):
return OrderedDict(
[
(key, create_shared_memory(subspace, n=n, ctx=ctx))
for (key, subspace) in space.spaces.items()
]
)
@create_shared_memory.register(Text)
def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
return ctx.Array(np.dtype(np.int32).char, n * space.max_length)
@create_shared_memory.register(Graph)
@create_shared_memory.register(Sequence)
def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
raise TypeError(
f"As {space} has a dynamic shape then it is not possible to make a static shared memory."
)
@singledispatch
def read_from_shared_memory(
space: Space, shared_memory: dict | tuple | mp.Array, n: int = 1
) -> dict[str, Any] | tuple[Any, ...] | np.ndarray:
"""Read the batch of observations from shared memory as a numpy array.
.. note::
The numpy array objects returned by `read_from_shared_memory` share the
memory of `shared_memory`. Any changes to `shared_memory` are forwarded
to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`.
Args:
space: Observation space of a single environment in the vectorized environment.
shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
This object is created with `create_shared_memory`.
n: Number of environments in the vectorized environment (i.e. the number of processes).
Returns:
Batch of observations as a (possibly nested) numpy array.
Raises:
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
"""
if isinstance(space, Space):
raise CustomSpaceError(
f"Space of type `{type(space)}` doesn't have an registered `read_from_shared_memory` function. Register `{type(space)}` for `read_from_shared_memory` to support it."
)
else:
raise TypeError(
f"The space provided to `read_from_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@read_from_shared_memory.register(Box)
@read_from_shared_memory.register(Discrete)
@read_from_shared_memory.register(MultiDiscrete)
@read_from_shared_memory.register(MultiBinary)
def _read_base_from_shared_memory(
space: Box | Discrete | MultiDiscrete | MultiBinary, shared_memory, n: int = 1
):
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
(n,) + space.shape
)
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
return tuple(
read_from_shared_memory(subspace, memory, n=n)
for (memory, subspace) in zip(shared_memory, space.spaces)
)
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
return OrderedDict(
[
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
for (key, subspace) in space.spaces.items()
]
)
@read_from_shared_memory.register(Text)
def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str, ...]:
data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape(
(n, space.max_length)
)
return tuple(
"".join(
[
space.character_list[val]
for val in values
if val < len(space.character_set)
]
)
for values in data
)
@singledispatch
def write_to_shared_memory(
space: Space,
index: int,
value: np.ndarray,
shared_memory: dict[str, Any] | tuple[Any, ...] | mp.Array,
):
"""Write the observation of a single environment into shared memory.
Args:
space: Observation space of a single environment in the vectorized environment.
index: Index of the environment (must be in `[0, num_envs)`).
value: Observation of the single environment to write to shared memory.
shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
This object is created with `create_shared_memory`.
Raises:
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
"""
if isinstance(space, Space):
raise CustomSpaceError(
f"Space of type `{type(space)}` doesn't have an registered `write_to_shared_memory` function. Register `{type(space)}` for `write_to_shared_memory` to support it."
)
else:
raise TypeError(
f"The space provided to `write_to_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@write_to_shared_memory.register(Box)
@write_to_shared_memory.register(Discrete)
@write_to_shared_memory.register(MultiDiscrete)
@write_to_shared_memory.register(MultiBinary)
def _write_base_to_shared_memory(
space: Box | Discrete | MultiDiscrete | MultiBinary,
index: int,
value,
shared_memory,
):
size = int(np.prod(space.shape))
destination = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
np.copyto(
destination[index * size : (index + 1) * size],
np.asarray(value, dtype=space.dtype).flatten(),
)
@write_to_shared_memory.register(Tuple)
def _write_tuple_to_shared_memory(
space: Tuple, index: int, values: tuple[Any, ...], shared_memory
):
for value, memory, subspace in zip(values, shared_memory, space.spaces):
write_to_shared_memory(subspace, index, value, memory)
@write_to_shared_memory.register(Dict)
def _write_dict_to_shared_memory(
space: Dict, index: int, values: dict[str, Any], shared_memory
):
for key, subspace in space.spaces.items():
write_to_shared_memory(subspace, index, values[key], shared_memory[key])
@write_to_shared_memory.register(Text)
def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_memory):
size = space.max_length
destination = np.frombuffer(shared_memory.get_obj(), dtype=np.int32)
np.copyto(
destination[index * size : (index + 1) * size],
flatten(space, values),
)
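A single-process sketch of the three functions above cooperating for a Box space; in the vector env the write happens inside a worker process while the parent holds the read view.
import numpy as np
from gymnasium.spaces import Box

space = Box(low=0.0, high=1.0, shape=(3,), dtype=np.float32)
shm = create_shared_memory(space, n=2)                 # flat multiprocessing Array
batch = read_from_shared_memory(space, shm, n=2)       # (2, 3) numpy view over shm
write_to_shared_memory(space, 0, space.sample(), shm)  # fill the slot of env index 0
print(batch[0])                                        # reflects the write above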

gymnasium/experimental/vector/utils/space_utils.py
@@ -0,0 +1,407 @@
"""Space-based utility functions for vector environments.
- ``batch_space``: Create a (batched) space, containing multiple copies of a single space.
- ``concatenate``: Concatenate multiple samples from (unbatched) space into a single object.
- ``iterate``: Iterate over the elements of a (batched) space and items.
- ``create_empty_array``: Create an empty (possibly nested) (normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``
"""
from __future__ import annotations
from collections import OrderedDict
from copy import deepcopy
from functools import singledispatch
from typing import Any, Iterable, Iterator
import numpy as np
from gymnasium.error import CustomSpaceError
from gymnasium.spaces import (
Box,
Dict,
Discrete,
Graph,
GraphInstance,
MultiBinary,
MultiDiscrete,
Sequence,
Space,
Text,
Tuple,
)
from gymnasium.spaces.space import T_cov
__all__ = ["batch_space", "iterate", "concatenate", "create_empty_array"]
@singledispatch
def batch_space(space: Space[Any], n: int = 1) -> Space[Any]:
"""Create a (batched) space, containing multiple copies of a single space.
Args:
space: Space (e.g. the observation space) for a single environment in the vectorized environment.
n: Number of environments in the vectorized environment.
Returns:
Space (e.g. the observation space) for a batch of environments in the vectorized environment.
Raises:
TypeError: If the space is not a gymnasium Space instance (no batching function is registered for it).
Example:
>>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)
... })
>>> batch_space(space, n=5)
Dict('position': Box(0.0, 1.0, (5, 3), float32), 'velocity': Box(0.0, 1.0, (5, 2), float32))
"""
raise TypeError(
f"The space provided to `batch_space` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@batch_space.register(Box)
def _batch_space_box(space: Box, n: int = 1):
repeats = tuple([n] + [1] * space.low.ndim)
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
return Box(low=low, high=high, dtype=space.dtype, seed=deepcopy(space.np_random))
@batch_space.register(Discrete)
def _batch_space_discrete(space: Discrete, n: int = 1):
return MultiDiscrete(
np.full((n,), space.n, dtype=space.dtype),
dtype=space.dtype,
seed=deepcopy(space.np_random),
start=np.full((n,), space.start, dtype=space.dtype),
)
@batch_space.register(MultiDiscrete)
def _batch_space_multidiscrete(space: MultiDiscrete, n: int = 1):
repeats = tuple([n] + [1] * space.nvec.ndim)
low = np.tile(space.start, repeats)
high = low + np.tile(space.nvec, repeats) - 1
return Box(
low=low,
high=high,
dtype=space.dtype,
seed=deepcopy(space.np_random),
)
@batch_space.register(MultiBinary)
def _batch_space_multibinary(space: MultiBinary, n: int = 1):
return Box(
low=0,
high=1,
shape=(n,) + space.shape,
dtype=space.dtype,
seed=deepcopy(space.np_random),
)
@batch_space.register(Tuple)
def _batch_space_tuple(space: Tuple, n: int = 1):
return Tuple(
tuple(batch_space(subspace, n=n) for subspace in space.spaces),
seed=deepcopy(space.np_random),
)
@batch_space.register(Dict)
def _batch_space_dict(space: Dict, n: int = 1):
return Dict(
{key: batch_space(subspace, n=n) for key, subspace in space.items()},
seed=deepcopy(space.np_random),
)
@batch_space.register(Graph)
@batch_space.register(Text)
@batch_space.register(Sequence)
@batch_space.register(Space)
def _batch_space_custom(space: Graph | Text | Sequence, n: int = 1):
# Without deepcopy, then the space.np_random is batched_space.spaces[0].np_random
# Which is an issue if you are sampling actions of both the original space and the batched space
batched_space = Tuple(
tuple(deepcopy(space) for _ in range(n)), seed=deepcopy(space.np_random)
)
space_rng = deepcopy(space.np_random)
new_seeds = list(map(int, space_rng.integers(0, 1e8, n)))
batched_space.seed(new_seeds)
return batched_space
@singledispatch
def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator:
"""Iterate over the elements of a (batched) space.
Args:
space: (Batched) space (e.g. the observation space of the vector environment).
items: Batched samples to be split into individual samples.
Returns:
An iterator over the individual samples contained in ``items``.
Raises:
CustomSpaceError: If the space does not have a registered ``iterate`` function.
Example:
>>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(2, 3), seed=42, dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2, 2), seed=42, dtype=np.float32)})
>>> items = space.sample()
>>> it = iterate(space, items)
>>> next(it)
OrderedDict([('position', array([0.77395606, 0.43887845, 0.85859793], dtype=float32)), ('velocity', array([0.77395606, 0.43887845], dtype=float32))])
>>> next(it)
OrderedDict([('position', array([0.697368 , 0.09417735, 0.97562236], dtype=float32)), ('velocity', array([0.85859793, 0.697368 ], dtype=float32))])
>>> next(it)
Traceback (most recent call last):
...
StopIteration
"""
if isinstance(space, Space):
raise CustomSpaceError(
f"Space of type `{type(space)}` doesn't have an registered `iterate` function. Register `{type(space)}` for `iterate` to support it."
)
else:
raise TypeError(
f"The space provided to `iterate` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@iterate.register(Discrete)
def _iterate_discrete(space: Discrete, items: Iterable):
raise TypeError("Unable to iterate over a space of type `Discrete`.")
@iterate.register(Box)
@iterate.register(MultiDiscrete)
@iterate.register(MultiBinary)
def _iterate_base(space: Box | MultiDiscrete | MultiBinary, items: np.ndarray):
try:
return iter(items)
except TypeError as e:
raise TypeError(
f"Unable to iterate over the following elements: {items}"
) from e
@iterate.register(Tuple)
def _iterate_tuple(space: Tuple, items: tuple[Any, ...]):
# If every subspace has a registered `iterate` function, iterate over them elementwise; otherwise fall back to `iter(items)` below
if all(type(subspace) in iterate.registry for subspace in space):
return zip(*[iterate(subspace, items[i]) for i, subspace in enumerate(space)])
try:
return iter(items)
except Exception as e:
unregistered_spaces = [
type(subspace)
for subspace in space
if type(subspace) not in iterate.registry
]
raise CustomSpaceError(
f"Could not iterate through {space} as no custom iterate function is registered for {unregistered_spaces} and `iter(items)` raised the following error: {e}."
) from e
@iterate.register(Dict)
def _iterate_dict(space: Dict, items: dict[str, Any]):
keys, values = zip(
*[
(key, iterate(subspace, items[key]))
for key, subspace in space.spaces.items()
]
)
for item in zip(*values):
yield OrderedDict({key: value for key, value in zip(keys, item)})
@singledispatch
def concatenate(
space: Space, items: Iterable, out: tuple[Any, ...] | dict[str, Any] | np.ndarray
) -> tuple[Any, ...] | dict[str, Any] | np.ndarray:
"""Concatenate multiple samples from space into a single object.
Args:
space: Observation space of a single environment in the vectorized environment.
items: Samples to be concatenated.
out: The output object. This object is a (possibly nested) numpy array.
Returns:
The output object. This object is a (possibly nested) numpy array.
Raises:
TypeError: If the space is not a gymnasium Space instance.
Example:
>>> from gymnasium.spaces import Box
>>> import numpy as np
>>> space = Box(low=0, high=1, shape=(3,), seed=42, dtype=np.float32)
>>> out = np.zeros((2, 3), dtype=np.float32)
>>> items = [space.sample() for _ in range(2)]
>>> concatenate(space, items, out)
array([[0.77395606, 0.43887845, 0.85859793],
[0.697368 , 0.09417735, 0.97562236]], dtype=float32)
"""
raise TypeError(
f"The space provided to `concatenate` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
@concatenate.register(Box)
@concatenate.register(Discrete)
@concatenate.register(MultiDiscrete)
@concatenate.register(MultiBinary)
def _concatenate_base(
space: Box | Discrete | MultiDiscrete | MultiBinary,
items: Iterable,
out: np.ndarray,
) -> np.ndarray:
return np.stack(items, axis=0, out=out)
@concatenate.register(Tuple)
def _concatenate_tuple(
space: Tuple, items: Iterable, out: tuple[Any, ...]
) -> tuple[Any, ...]:
return tuple(
concatenate(subspace, [item[i] for item in items], out[i])
for (i, subspace) in enumerate(space.spaces)
)
@concatenate.register(Dict)
def _concatenate_dict(
space: Dict, items: Iterable, out: dict[str, Any]
) -> dict[str, Any]:
return OrderedDict(
{
key: concatenate(subspace, [item[key] for item in items], out[key])
for key, subspace in space.items()
}
)
@concatenate.register(Graph)
@concatenate.register(Text)
@concatenate.register(Sequence)
@concatenate.register(Space)
def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any, ...]:
return tuple(items)
@singledispatch
def create_empty_array(
space: Space, n: int = 1, fn: callable = np.zeros
) -> tuple[Any, ...] | dict[str, Any] | np.ndarray:
"""Create an empty (possibly nested) (normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``.
In most cases, the array will be contained within the batched space; however, this is not guaranteed.
Args:
space: Observation space of a single environment in the vectorized environment.
n: Number of environments in the vectorized environment. If `None`, creates an empty sample from `space`.
fn: Function to apply when creating the empty numpy array. Examples of such functions are `np.empty` or `np.zeros`.
Returns:
The output object. This object is a (possibly nested) numpy array.
Raises:
ValueError: Space is not a valid :class:`gym.Space` instance
Example:
>>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
>>> create_empty_array(space, n=2, fn=np.zeros)
OrderedDict([('position', array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)), ('velocity', array([[0., 0.],
[0., 0.]], dtype=float32))])
"""
raise TypeError(
f"The space provided to `create_empty_array` is not a gymnasium Space instance, type: {type(space)}, {space}"
)
# It is possible for some Box spaces to have low > 0, in which case the zero-filled array is not contained in the space
@create_empty_array.register(Box)
# If Discrete.start > 0 or start + length < 0, the zero-filled array is not contained in the space
@create_empty_array.register(Discrete)
@create_empty_array.register(MultiDiscrete)
@create_empty_array.register(MultiBinary)
def _create_empty_array_multi(space: Box | Discrete | MultiDiscrete | MultiBinary, n: int = 1, fn=np.zeros) -> np.ndarray:
return fn((n,) + space.shape, dtype=space.dtype)
@create_empty_array.register(Tuple)
def _create_empty_array_tuple(space: Tuple, n: int = 1, fn=np.zeros) -> tuple[Any, ...]:
return tuple(create_empty_array(subspace, n=n, fn=fn) for subspace in space.spaces)
@create_empty_array.register(Dict)
def _create_empty_array_dict(space: Dict, n: int = 1, fn=np.zeros) -> dict[str, Any]:
return OrderedDict(
{
key: create_empty_array(subspace, n=n, fn=fn)
for key, subspace in space.items()
}
)
@create_empty_array.register(Graph)
def _create_empty_array_graph(
space: Graph, n: int = 1, fn=np.zeros
) -> tuple[GraphInstance, ...]:
if space.edge_space is not None:
return tuple(
GraphInstance(
nodes=fn((1,) + space.node_space.shape, dtype=space.node_space.dtype),
edges=fn((1,) + space.edge_space.shape, dtype=space.edge_space.dtype),
edge_links=fn((1, 2), dtype=np.int64),
)
for _ in range(n)
)
else:
return tuple(
GraphInstance(
nodes=fn((1,) + space.node_space.shape, dtype=space.node_space.dtype),
edges=None,
edge_links=None,
)
for _ in range(n)
)
@create_empty_array.register(Text)
def _create_empty_array_text(space: Text, n: int = 1, fn=np.zeros) -> tuple[str, ...]:
return tuple(space.characters[0] * space.min_length for _ in range(n))
@create_empty_array.register(Sequence)
def _create_empty_array_sequence(
space: Sequence, n: int = 1, fn=np.zeros
) -> tuple[Any, ...]:
if space.stack:
return tuple(
create_empty_array(space.feature_space, n=1, fn=fn) for _ in range(n)
)
else:
return tuple(tuple() for _ in range(n))
@create_empty_array.register(Space)
def _create_empty_array_custom(space, n=1, fn=np.zeros):
return None
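A sketch tying the helpers together the way the vector environments above use them: pre-allocate with create_empty_array, fill with concatenate, and split a batched sample back up with iterate.
import numpy as np
from gymnasium.spaces import Box

space = Box(low=0, high=1, shape=(3,), dtype=np.float32)
out = create_empty_array(space, n=2, fn=np.zeros)  # (2, 3) zeros
samples = [space.sample() for _ in range(2)]
batch = concatenate(space, samples, out)           # stacks the samples into `out`
for single in iterate(batch_space(space, n=2), batch):
    print(single.shape)                            # -> (3,)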

gymnasium/experimental/vector/vector_env.py
@@ -0,0 +1,486 @@
"""Base class for vectorized environments."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Generic, TypeVar
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.utils import seeding
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
ArrayType = TypeVar("ArrayType")
__all__ = [
"VectorEnv",
"VectorWrapper",
"VectorObservationWrapper",
"VectorActionWrapper",
"VectorRewardWrapper",
"ArrayType",
]
class VectorEnv(Generic[ObsType, ActType, ArrayType]):
"""Base class for vectorized environments to run multiple independent copies of the same environment in parallel.
Vector environments can provide a linear speed-up in the steps taken per second through sampling multiple
sub-environments at the same time. To prevent terminated environments waiting until all sub-environments have
terminated or truncated, the vector environments autoreset sub-environments after they terminate or truncate.
As a result, the final step's observation and info are overwritten by the reset's observation and info.
Therefore, the observation and info for the final step of a sub-environment are stored in the info parameter,
using `"final_observation"` and `"final_info"` respectively. See :meth:`step` for more information.
The vector environments batch `observations`, `rewards`, `terminations`, `truncations` and `info` for each
parallel environment. In addition, :meth:`step` expects to receive a batch of actions for each parallel environment.
Gymnasium contains two types of Vector environments: :class:`AsyncVectorEnv` and :class:`SyncVectorEnv`.
The Vector Environments have the additional attributes for users to understand the implementation
- :attr:`num_envs` - The number of sub-environments in the vector environment
- :attr:`observation_space` - The batched observation space of the vector environment
- :attr:`single_observation_space` - The observation space of a single sub-environment
- :attr:`action_space` - The batched action space of the vector environment
- :attr:`single_action_space` - The action space of a single sub-environment
Note:
Before OpenAI Gym v25, the info parameter of :meth:`reset` and :meth:`step` was a list
of dictionaries, one for each sub-environment. However, this was modified in OpenAI Gym v25+ and in Gymnasium to a
dictionary with a NumPy array for each key. To use the old info style, apply the :class:`VectorListInfo` wrapper.
Note:
To render the sub-environments, use :meth:`call` with ``"render"`` as the method name. Remember to set the `render_mode`
for all the sub-environments during initialization.
Note:
All parallel environments should share the identical observation and action spaces.
In other words, a vector of multiple different environments is not supported.
"""
spec: EnvSpec
observation_space: gym.Space
action_space: gym.Space
single_observation_space: gym.Space
single_action_space: gym.Space
num_envs: int
closed = False
_np_random: np.random.Generator | None = None
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]: # type: ignore
"""Reset all parallel environments and return a batch of initial observations and info.
Args:
seed: The environment reset seeds
options: The reset options passed to each sub-environment
Returns:
A batch of observations and info from the vectorized environment.
Example:
>>> import gymnasium as gym
>>> envs = gym.vector.make("CartPole-v1", num_envs=3)
>>> envs.reset(seed=42)
(array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ],
[ 0.01522993, -0.04562247, -0.04799704, 0.03392126],
[-0.03774345, -0.02418869, -0.00942293, 0.0469184 ]],
dtype=float32), {})
"""
if seed is not None:
self._np_random, seed = seeding.np_random(seed)
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Take an action for each parallel environment.
Args:
actions: Batch of actions, an element of :attr:`action_space`
Returns:
Batch of (observations, rewards, terminations, truncations, infos)
Note:
As the vector environments autoreset terminating and truncating sub-environments,
the returned observation and info are not the final step's observation or info; these are instead stored in
info as `"final_observation"` and `"final_info"`.
Example:
>>> import gymnasium as gym
>>> import numpy as np
>>> envs = gym.vector.make("CartPole-v1", num_envs=3)
>>> _ = envs.reset(seed=42)
>>> actions = np.array([1, 0, 1])
>>> observations, rewards, termination, truncation, infos = envs.step(actions)
>>> observations
array([[ 0.02727336, 0.18847767, 0.03625453, -0.26141977],
[ 0.01431748, -0.24002443, -0.04731862, 0.3110827 ],
[-0.03822722, 0.1710671 , -0.00848456, -0.2487226 ]],
dtype=float32)
>>> rewards
array([1., 1., 1.])
>>> termination
array([False, False, False])
>>> truncation
array([False, False, False])
>>> infos
{}
"""
pass
def close_extras(self, **kwargs):
"""Clean up the extra resources e.g. beyond what's in this base class."""
pass
def close(self, **kwargs):
"""Close all parallel environments and release resources.
It also closes all the existing image viewers, then calls :meth:`close_extras` and sets
:attr:`closed` to ``True``.
Warnings:
This function itself does not close the environments; that should be handled
in :meth:`close_extras`. This is generic for both synchronous and asynchronous
vectorized environments.
Note:
This will be called automatically when the environment is garbage collected or the program exits.
Args:
**kwargs: Keyword arguments passed to :meth:`close_extras`
"""
if self.closed:
return
self.close_extras(**kwargs)
self.closed = True
@property
def np_random(self) -> np.random.Generator:
"""Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.
Returns:
Instances of `np.random.Generator`
"""
if self._np_random is None:
self._np_random, seed = seeding.np_random()
return self._np_random
@np_random.setter
def np_random(self, value: np.random.Generator):
self._np_random = value
@property
def unwrapped(self):
"""Return the base environment."""
return self
def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
"""Add env info to the info dictionary of the vectorized environment.
Given the `info` of a single environment add it to the `infos` dictionary
which represents all the infos of the vectorized environment.
Every `key` of `info` is paired with a boolean mask `_key` representing
whether or not the i-indexed environment has this `info`.
Args:
infos (dict): the infos of the vectorized environment
info (dict): the info coming from the single environment
env_num (int): the index of the single environment
Returns:
infos (dict): the (updated) infos of the vectorized environment
"""
for k in info.keys():
if k not in infos:
info_array, array_mask = self._init_info_arrays(type(info[k]))
else:
info_array, array_mask = infos[k], infos[f"_{k}"]
info_array[env_num], array_mask[env_num] = info[k], True
infos[k], infos[f"_{k}"] = info_array, array_mask
return infos
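# Illustrative sketch (not in the original file): with num_envs == 3, the calls
#   infos = self._add_info({}, {"score": 5}, 0)
#   infos = self._add_info(infos, {"score": 7}, 2)
# produce
#   {"score": array([5, 0, 7]), "_score": array([ True, False,  True])}
# so consumers can use the `_score` mask to tell which sub-environments
# actually reported a value for that key.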
def _init_info_arrays(self, dtype: type) -> tuple[np.ndarray, np.ndarray]:
"""Initialize the info array.
Initialize the info array. If the dtype is numeric,
the info array will have the same dtype; otherwise it
will be an array of `None`. Also, a boolean array
of the same length is returned. It will be used for
assessing which environment has info data.
Args:
dtype (type): data type of the info coming from the env.
Returns:
array (np.ndarray): the initialized info array.
array_mask (np.ndarray): the initialized boolean array.
"""
if dtype in [int, float, bool] or issubclass(dtype, np.number):
array = np.zeros(self.num_envs, dtype=dtype)
else:
array = np.zeros(self.num_envs, dtype=object)
array[:] = None
array_mask = np.zeros(self.num_envs, dtype=bool)
return array, array_mask
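# For example (hypothetical values, assuming num_envs == 3):
#     self._init_info_arrays(float) -> (array([0., 0., 0.]), array([False, False, False]))
#     self._init_info_arrays(str)   -> (array([None, None, None], dtype=object),
#                                       array([False, False, False]))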
def __del__(self):
"""Closes the vector environment."""
if not getattr(self, "closed", True):
self.close()
def __repr__(self) -> str:
"""Returns a string representation of the vector environment.
Returns:
A string containing the class name, number of environments and environment spec id
"""
if getattr(self, "spec", None) is None:
return f"{self.__class__.__name__}({self.num_envs})"
else:
return f"{self.__class__.__name__}({self.spec.id}, {self.num_envs})"
class VectorWrapper(VectorEnv):
"""Wraps the vectorized environment to allow a modular transformation.
This class is the base class for all wrappers of vectorized environments. Subclasses
can override some methods to change the behavior of the original vectorized environment
without touching the original code.
Note:
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
"""
_observation_space: gym.Space | None = None
_action_space: gym.Space | None = None
_single_observation_space: gym.Space | None = None
_single_action_space: gym.Space | None = None
def __init__(self, env: VectorEnv):
"""Initialize the vectorized environment wrapper."""
super().__init__()
assert isinstance(env, VectorEnv)
self.env = env
# explicitly forward the methods defined in VectorEnv
# to self.env (instead of the base class)
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Reset all environment using seed and options."""
return self.env.reset(seed=seed, options=options)
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Step all environments."""
return self.env.step(actions)
def close(self, **kwargs: Any):
"""Close all environments."""
return self.env.close(**kwargs)
def close_extras(self, **kwargs: Any):
"""Close all extra resources."""
return self.env.close_extras(**kwargs)
# implicitly forward all other methods and attributes to self.env
def __getattr__(self, name: str) -> Any:
"""Forward all other attributes to the base environment."""
if name.startswith("_"):
raise AttributeError(f"attempted to get missing private attribute '{name}'")
return getattr(self.env, name)
@property
def unwrapped(self):
"""Return the base non-wrapped environment."""
return self.env.unwrapped
def __repr__(self):
"""Return the string representation of the vectorized environment."""
return f"<{self.__class__.__name__}, {self.env}>"
def __del__(self):
"""Close the vectorized environment."""
self.env.__del__()
@property
def spec(self) -> EnvSpec | None:
"""Gets the specification of the wrapped environment."""
return self.env.spec
@property
def observation_space(self) -> gym.Space:
"""Gets the observation space of the vector environment."""
if self._observation_space is None:
return self.env.observation_space
return self._observation_space
@observation_space.setter
def observation_space(self, space: gym.Space):
"""Sets the observation space of the vector environment."""
self._observation_space = space
@property
def action_space(self) -> gym.Space:
"""Gets the action space of the vector environment."""
if self._action_space is None:
return self.env.action_space
return self._action_space
@action_space.setter
def action_space(self, space: gym.Space):
"""Sets the action space of the vector environment."""
self._action_space = space
@property
def single_observation_space(self) -> gym.Space:
"""Gets the single observation space of the vector environment."""
if self._single_observation_space is None:
return self.env.single_observation_space
return self._single_observation_space
@single_observation_space.setter
def single_observation_space(self, space: gym.Space):
"""Sets the single observation space of the vector environment."""
self._single_observation_space = space
@property
def single_action_space(self) -> gym.Space:
"""Gets the single action space of the vector environment."""
if self._single_action_space is None:
return self.env.single_action_space
return self._single_action_space
@single_action_space.setter
def single_action_space(self, space):
"""Sets the single action space of the vector environment."""
self._single_action_space = space
@property
def num_envs(self) -> int:
"""Gets the wrapped vector environment's num of the sub-environments."""
return self.env.num_envs
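# A minimal sketch (hypothetical, not part of the API) of a VectorWrapper
# subclass; note the required super().__init__(env) call before any other setup:
#
#     class StepCountingWrapper(VectorWrapper):
#         def __init__(self, env: VectorEnv):
#             super().__init__(env)
#             self.step_count = 0
#
#         def step(self, actions):
#             self.step_count += 1
#             return self.env.step(actions)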
class VectorObservationWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the observation. Equivalent to :class:`gym.ObservationWrapper` for vectorized environments."""
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
obs, info = self.env.reset(seed=seed, options=options)
return self.vector_observation(obs), info
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
observation, reward, termination, truncation, info = self.env.step(actions)
return (
self.vector_observation(observation),
reward,
termination,
truncation,
self.update_final_obs(info),
)
def vector_observation(self, observation: ObsType) -> ObsType:
"""Defines the vector observation transformation.
Args:
observation: A vector observation from the environment
Returns:
the transformed observation
"""
raise NotImplementedError
def single_observation(self, observation: ObsType) -> ObsType:
"""Defines the single observation transformation.
Args:
observation: A single observation from the environment
Returns:
The transformed observation
"""
raise NotImplementedError
def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
"""Updates the `final_obs` in the info using `single_observation`."""
if "final_observation" in info:
for i, obs in enumerate(info["final_observation"]):
if obs is not None:
info["final_observation"][i] = self.single_observation(obs)
return info
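# A hedged sketch of a concrete subclass (names are hypothetical); both
# transformations must be provided so the batched observations and the per-env
# "final_observation" entries stay consistent:
#
#     class RescaleObservation(VectorObservationWrapper):
#         def vector_observation(self, observation):
#             return observation / 255.0  # applied to the batched observation
#
#         def single_observation(self, observation):
#             return observation / 255.0  # applied to each final observation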
class VectorActionWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the actions. Equivalent of :class:`~gym.ActionWrapper` for vectorized environments."""
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Steps through the environment using a modified action by :meth:`action`."""
return self.env.step(self.actions(actions))
def actions(self, actions: ActType) -> ActType:
"""Transform the actions before sending them to the environment.
Args:
actions (ActType): the actions to transform
Returns:
ActType: the transformed actions
"""
raise NotImplementedError
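# A minimal sketch (hypothetical) of a VectorActionWrapper subclass that
# scales the batched actions before they reach the sub-environments:
#
#     class ScaleAction(VectorActionWrapper):
#         def __init__(self, env: VectorEnv, scale: float = 0.5):
#             super().__init__(env)
#             self.scale = scale
#
#         def actions(self, actions):
#             return self.scale * actions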
class VectorRewardWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the reward. Equivalent of :class:`~gym.RewardWrapper` for vectorized environments."""
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Steps through the environment returning a reward modified by :meth:`reward`."""
observation, reward, termination, truncation, info = self.env.step(actions)
return observation, self.reward(reward), termination, truncation, info
def reward(self, reward: ArrayType) -> ArrayType:
"""Transform the reward before returning it.
Args:
reward (array): the reward to transform
Returns:
array: the transformed reward
"""
raise NotImplementedError
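# A minimal sketch (hypothetical) of a VectorRewardWrapper subclass that clips
# the batched rewards elementwise:
#
#     class ClipReward(VectorRewardWrapper):
#         def reward(self, reward):
#             return np.clip(reward, -1.0, 1.0)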