I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File gymnasium/vector/__init__.py

@@ -0,0 +1,85 @@
"""Module for vector environments."""
from typing import Callable, Iterable, List, Optional, Union
import gymnasium as gym
from gymnasium.core import Env
from gymnasium.vector import utils
from gymnasium.vector.async_vector_env import AsyncVectorEnv
from gymnasium.vector.sync_vector_env import SyncVectorEnv
from gymnasium.vector.vector_env import VectorEnv, VectorEnvWrapper
__all__ = [
"AsyncVectorEnv",
"SyncVectorEnv",
"VectorEnv",
"VectorEnvWrapper",
"make",
"utils",
]
def make(
id: str,
num_envs: int = 1,
asynchronous: bool = True,
wrappers: Optional[Union[Callable[[Env], Env], List[Callable[[Env], Env]]]] = None,
disable_env_checker: Optional[bool] = None,
**kwargs,
) -> VectorEnv:
"""Create a vectorized environment from multiple copies of an environment, from its id.
Args:
id: The environment ID. This must be a valid ID from the registry.
num_envs: Number of copies of the environment.
asynchronous: If ``True``, wraps the environments in an :class:`AsyncVectorEnv` (which uses ``multiprocessing`` to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`.
wrappers: If not ``None``, then apply the wrappers to each internal environment during creation.
disable_env_checker: Whether to disable the env checker; the checker only ever runs on the first environment, never on the remaining copies. If ``None``, this defaults to the
environment spec's `disable_env_checker` parameter (``False`` by default); otherwise this argument is used directly (``True`` = do not run the checker, ``False`` = run it on the first environment).
**kwargs: Keyword arguments passed to `gym.make` for every environment copy
Returns:
The vectorized environment.
Example:
>>> import gymnasium as gym
>>> env = gym.vector.make('CartPole-v1', num_envs=3)
>>> env.reset(seed=42)
(array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ],
[ 0.01522993, -0.04562247, -0.04799704, 0.03392126],
[-0.03774345, -0.02418869, -0.00942293, 0.0469184 ]],
dtype=float32), {})
"""
gym.logger.warn(
"`gymnasium.vector.make(...)` is deprecated and will be replaced by `gymnasium.make_vec(...)` in v1.0"
)
def create_env(env_num: int) -> Callable[[], Env]:
"""Creates an environment that can enable or disable the environment checker."""
# If the env_num > 0 then disable the environment checker otherwise use the parameter
_disable_env_checker = True if env_num > 0 else disable_env_checker
def _make_env() -> Env:
env = gym.envs.registration.make(
id,
disable_env_checker=_disable_env_checker,
**kwargs,
)
if wrappers is not None:
if callable(wrappers):
env = wrappers(env)
elif isinstance(wrappers, Iterable) and all(
callable(w) for w in wrappers
):
for wrapper in wrappers:
env = wrapper(env)
else:
raise NotImplementedError
return env
return _make_env
env_fns = [create_env(env_num) for env_num in range(num_envs)]
return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns)
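A minimal sketch of the ``wrappers`` argument, assuming the standard ``ClipAction`` wrapper from ``gymnasium.wrappers`` (each sub-environment is created with ``gym.make`` and then wrapped; the deprecation warning from ``gym.vector.make`` is expected):
>>> import gymnasium as gym
>>> from gymnasium.wrappers import ClipAction
>>> envs = gym.vector.make("Pendulum-v1", num_envs=2, asynchronous=False, wrappers=ClipAction)
>>> envs.num_envs
2
>>> envs.close()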

View File gymnasium/vector/async_vector_env.py

@@ -0,0 +1,687 @@
"""An async vector environment."""
import multiprocessing as mp
import sys
import time
from copy import deepcopy
from enum import Enum
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
import numpy as np
from numpy.typing import NDArray
import gymnasium as gym
from gymnasium import logger
from gymnasium.core import Env, ObsType
from gymnasium.error import (
AlreadyPendingCallError,
ClosedEnvironmentError,
CustomSpaceError,
NoAsyncCallError,
)
from gymnasium.vector.utils import (
CloudpickleWrapper,
clear_mpi_env_vars,
concatenate,
create_empty_array,
create_shared_memory,
iterate,
read_from_shared_memory,
write_to_shared_memory,
)
from gymnasium.vector.vector_env import VectorEnv
__all__ = ["AsyncVectorEnv"]
class AsyncState(Enum):
DEFAULT = "default"
WAITING_RESET = "reset"
WAITING_STEP = "step"
WAITING_CALL = "call"
class AsyncVectorEnv(VectorEnv):
"""Vectorized environment that runs multiple environments in parallel.
It uses ``multiprocessing`` processes, and pipes for communication.
Example:
>>> import gymnasium as gym
>>> env = gym.vector.AsyncVectorEnv([
... lambda: gym.make("Pendulum-v1", g=9.81),
... lambda: gym.make("Pendulum-v1", g=1.62)
... ])
>>> env.reset(seed=42)
(array([[-0.14995256, 0.9886932 , -0.12224312],
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {})
"""
def __init__(
self,
env_fns: Sequence[Callable[[], Env]],
observation_space: Optional[gym.Space] = None,
action_space: Optional[gym.Space] = None,
shared_memory: bool = True,
copy: bool = True,
context: Optional[str] = None,
daemon: bool = True,
worker: Optional[Callable] = None,
):
"""Vectorized environment that runs multiple environments in parallel.
Args:
env_fns: Functions that create the environments.
observation_space: Observation space of a single environment. If ``None``,
then the observation space of the first environment is taken.
action_space: Action space of a single environment. If ``None``,
then the action space of the first environment is taken.
shared_memory: If ``True``, then the observations from the worker processes are communicated back through
shared variables. This can improve the efficiency if the observations are large (e.g. images).
copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods
return a copy of the observations.
context: Context for `multiprocessing`_. If ``None``, then the default context is used.
daemon: If ``True``, then subprocesses have the ``daemon`` flag turned on; that is, they will quit if
the head process quits. However, ``daemon=True`` prevents subprocesses from spawning children,
so for some environments you may want to set it to ``False``.
worker: If set, then use that worker in a subprocess instead of a default one.
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
Warnings:
worker is an advanced mode option. It provides a high degree of flexibility and a high chance
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
from the code of the ``_worker`` (or ``_worker_shared_memory``) function and modify it as needed.
Raises:
RuntimeError: If the observation space of some sub-environment does not match observation_space
(or, by default, the observation space of the first sub-environment).
ValueError: If observation_space is a custom space (i.e. not a default space in Gym,
such as gymnasium.spaces.Box, gymnasium.spaces.Discrete, or gymnasium.spaces.Dict) and shared_memory is True.
"""
ctx = mp.get_context(context)
self.env_fns = env_fns
self.shared_memory = shared_memory
self.copy = copy
dummy_env = env_fns[0]()
self.metadata = dummy_env.metadata
if (observation_space is None) or (action_space is None):
observation_space = observation_space or dummy_env.observation_space
action_space = action_space or dummy_env.action_space
dummy_env.close()
del dummy_env
super().__init__(
num_envs=len(env_fns),
observation_space=observation_space,
action_space=action_space,
)
if self.shared_memory:
try:
_obs_buffer = create_shared_memory(
self.single_observation_space, n=self.num_envs, ctx=ctx
)
self.observations = read_from_shared_memory(
self.single_observation_space, _obs_buffer, n=self.num_envs
)
except CustomSpaceError as e:
raise ValueError(
"Using `shared_memory=True` in `AsyncVectorEnv` "
"is incompatible with non-standard Gymnasium observation spaces "
"(i.e. custom spaces inheriting from `gymnasium.Space`), and is "
"only compatible with default Gymnasium spaces (e.g. `Box`, "
"`Tuple`, `Dict`) for batching. Set `shared_memory=False` "
"if you use custom observation spaces."
) from e
else:
_obs_buffer = None
self.observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
self.parent_pipes, self.processes = [], []
self.error_queue = ctx.Queue()
target = _worker_shared_memory if self.shared_memory else _worker
target = worker or target
with clear_mpi_env_vars():
for idx, env_fn in enumerate(self.env_fns):
parent_pipe, child_pipe = ctx.Pipe()
process = ctx.Process(
target=target,
name=f"Worker<{type(self).__name__}>-{idx}",
args=(
idx,
CloudpickleWrapper(env_fn),
child_pipe,
parent_pipe,
_obs_buffer,
self.error_queue,
),
)
self.parent_pipes.append(parent_pipe)
self.processes.append(process)
process.daemon = daemon
process.start()
child_pipe.close()
self._state = AsyncState.DEFAULT
self._check_spaces()
def reset_async(
self,
seed: Optional[Union[int, List[int]]] = None,
options: Optional[dict] = None,
):
"""Send calls to the :obj:`reset` methods of the sub-environments.
To get the results of these calls, you may invoke :meth:`reset_wait`.
Args:
seed: The reset seed(s); a single int is incremented per sub-environment, a list gives one seed per sub-environment
options: The reset options, passed to every sub-environment
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
AlreadyPendingCallError: If the environment is already waiting for a pending call to another
method (e.g. :meth:`step_async`). This can be caused by two consecutive
calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in between.
"""
self._assert_is_running()
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete",
self._state.value,
)
for pipe, single_seed in zip(self.parent_pipes, seed):
single_kwargs = {}
if single_seed is not None:
single_kwargs["seed"] = single_seed
if options is not None:
single_kwargs["options"] = options
pipe.send(("reset", single_kwargs))
self._state = AsyncState.WAITING_RESET
def reset_wait(
self,
timeout: Optional[Union[int, float]] = None,
seed: Optional[int] = None,
options: Optional[dict] = None,
) -> Union[ObsType, Tuple[ObsType, dict]]:
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
Args:
timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out.
seed: ignored
options: ignored
Returns:
A tuple of batched observations and an info dictionary
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
NoAsyncCallError: If :meth:`reset_wait` was called without any prior call to :meth:`reset_async`.
TimeoutError: If :meth:`reset_wait` timed out.
"""
self._assert_is_running()
if self._state != AsyncState.WAITING_RESET:
raise NoAsyncCallError(
"Calling `reset_wait` without any prior " "call to `reset_async`.",
AsyncState.WAITING_RESET.value,
)
if not self._poll(timeout):
self._state = AsyncState.DEFAULT
raise mp.TimeoutError(
f"The call to `reset_wait` has timed out after {timeout} second(s)."
)
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
infos = {}
results, info_data = zip(*results)
for i, info in enumerate(info_data):
infos = self._add_info(infos, info, i)
if not self.shared_memory:
self.observations = concatenate(
self.single_observation_space, results, self.observations
)
return (deepcopy(self.observations) if self.copy else self.observations), infos
def step_async(self, actions: np.ndarray):
"""Send the calls to :obj:`step` to each sub-environment.
Args:
actions: Batch of actions, an element of :attr:`~VectorEnv.action_space`
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
AlreadyPendingCallError: If the environment is already waiting for a pending call to another
method (e.g. :meth:`reset_async`). This can be caused by two consecutive
calls to :meth:`step_async`, with no call to :meth:`step_wait` in
between.
"""
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
f"Calling `step_async` while waiting for a pending call to `{self._state.value}` to complete.",
self._state.value,
)
actions = iterate(self.action_space, actions)
for pipe, action in zip(self.parent_pipes, actions):
pipe.send(("step", action))
self._state = AsyncState.WAITING_STEP
def step_wait(
self, timeout: Optional[Union[int, float]] = None
) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
Args:
timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.
Returns:
The batched environment step information, (obs, reward, terminated, truncated, info)
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
NoAsyncCallError: If :meth:`step_wait` was called without any prior call to :meth:`step_async`.
TimeoutError: If :meth:`step_wait` timed out.
"""
self._assert_is_running()
if self._state != AsyncState.WAITING_STEP:
raise NoAsyncCallError(
"Calling `step_wait` without any prior call " "to `step_async`.",
AsyncState.WAITING_STEP.value,
)
if not self._poll(timeout):
self._state = AsyncState.DEFAULT
raise mp.TimeoutError(
f"The call to `step_wait` has timed out after {timeout} second(s)."
)
observations_list, rewards, terminateds, truncateds, infos = [], [], [], [], {}
successes = []
for i, pipe in enumerate(self.parent_pipes):
result, success = pipe.recv()
successes.append(success)
if success:
obs, rew, terminated, truncated, info = result
observations_list.append(obs)
rewards.append(rew)
terminateds.append(terminated)
truncateds.append(truncated)
infos = self._add_info(infos, info, i)
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
if not self.shared_memory:
self.observations = concatenate(
self.single_observation_space,
observations_list,
self.observations,
)
return (
deepcopy(self.observations) if self.copy else self.observations,
np.array(rewards),
np.array(terminateds, dtype=np.bool_),
np.array(truncateds, dtype=np.bool_),
infos,
)
def call_async(self, name: str, *args, **kwargs):
"""Calls the method with name asynchronously and apply args and kwargs to the method.
Args:
name: Name of the method or property to call.
*args: Arguments to apply to the method call.
**kwargs: Keyword arguments to apply to the method call.
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
AlreadyPendingCallError: Calling `call_async` while waiting for a pending call to complete
"""
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `call_async` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
self._state.value,
)
for pipe in self.parent_pipes:
pipe.send(("_call", (name, args, kwargs)))
self._state = AsyncState.WAITING_CALL
def call_wait(self, timeout: Optional[Union[int, float]] = None) -> list:
"""Calls all parent pipes and waits for the results.
Args:
timeout: Number of seconds before the call to `step_wait` times out.
If `None` (default), the call to `step_wait` never times out.
Returns:
List of the results of the individual calls to the method or property for each environment.
Raises:
NoAsyncCallError: Calling `call_wait` without any prior call to `call_async`.
TimeoutError: The call to `call_wait` has timed out after timeout second(s).
"""
self._assert_is_running()
if self._state != AsyncState.WAITING_CALL:
raise NoAsyncCallError(
"Calling `call_wait` without any prior call to `call_async`.",
AsyncState.WAITING_CALL.value,
)
if not self._poll(timeout):
self._state = AsyncState.DEFAULT
raise mp.TimeoutError(
f"The call to `call_wait` has timed out after {timeout} second(s)."
)
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
return results
def set_attr(self, name: str, values: Union[list, tuple, object]):
"""Sets an attribute of the sub-environments.
Args:
name: Name of the property to be set in each individual environment.
values: Values of the property to be set to. If ``values`` is a list or
tuple, then it corresponds to the values for each individual
environment, otherwise a single value is set for all environments.
Raises:
ValueError: Values must be a list or tuple with length equal to the number of environments.
AlreadyPendingCallError: Calling `set_attr` while waiting for a pending call to complete.
"""
self._assert_is_running()
if not isinstance(values, (list, tuple)):
values = [values for _ in range(self.num_envs)]
if len(values) != self.num_envs:
raise ValueError(
"Values must be a list or tuple with length equal to the "
f"number of environments. Got `{len(values)}` values for "
f"{self.num_envs} environments."
)
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `set_attr` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
self._state.value,
)
for pipe, value in zip(self.parent_pipes, values):
pipe.send(("_setattr", (name, value)))
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
def close_extras(
self, timeout: Optional[Union[int, float]] = None, terminate: bool = False
):
"""Close the environments & clean up the extra resources (processes and pipes).
Args:
timeout: Number of seconds before the call to :meth:`close` times out. If ``None``,
the call to :meth:`close` never times out. If the call to :meth:`close`
times out, then all processes are terminated.
terminate: If ``True``, then the :meth:`close` operation is forced and all processes are terminated.
Raises:
TimeoutError: If :meth:`close` timed out.
"""
timeout = 0 if terminate else timeout
try:
if self._state != AsyncState.DEFAULT:
logger.warn(
f"Calling `close` while waiting for a pending call to `{self._state.value}` to complete."
)
function = getattr(self, f"{self._state.value}_wait")
function(timeout)
except mp.TimeoutError:
terminate = True
if terminate:
for process in self.processes:
if process.is_alive():
process.terminate()
else:
for pipe in self.parent_pipes:
if (pipe is not None) and (not pipe.closed):
pipe.send(("close", None))
for pipe in self.parent_pipes:
if (pipe is not None) and (not pipe.closed):
pipe.recv()
for pipe in self.parent_pipes:
if pipe is not None:
pipe.close()
for process in self.processes:
process.join()
def _poll(self, timeout=None):
self._assert_is_running()
if timeout is None:
return True
end_time = time.perf_counter() + timeout
delta = None
for pipe in self.parent_pipes:
delta = max(end_time - time.perf_counter(), 0)
if pipe is None:
return False
if pipe.closed or (not pipe.poll(delta)):
return False
return True
def _check_spaces(self):
self._assert_is_running()
spaces = (self.single_observation_space, self.single_action_space)
for pipe in self.parent_pipes:
pipe.send(("_check_spaces", spaces))
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
same_observation_spaces, same_action_spaces = zip(*results)
if not all(same_observation_spaces):
raise RuntimeError(
"Some environments have an observation space different from "
f"`{self.single_observation_space}`. In order to batch observations, "
"the observation spaces from all environments must be equal."
)
if not all(same_action_spaces):
raise RuntimeError(
"Some environments have an action space different from "
f"`{self.single_action_space}`. In order to batch actions, the "
"action spaces from all environments must be equal."
)
def _assert_is_running(self):
if self.closed:
raise ClosedEnvironmentError(
f"Trying to operate on `{type(self).__name__}`, after a call to `close()`."
)
def _raise_if_errors(self, successes):
if all(successes):
return
num_errors = self.num_envs - sum(successes)
assert num_errors > 0
for i in range(num_errors):
index, exctype, value = self.error_queue.get()
logger.error(
f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
)
logger.error(f"Shutting down Worker-{index}.")
self.parent_pipes[index].close()
self.parent_pipes[index] = None
if i == num_errors - 1:
logger.error("Raising the last exception back to the main process.")
raise exctype(value)
def __del__(self):
"""On deleting the object, checks that the vector environment is closed."""
if not getattr(self, "closed", True) and hasattr(self, "_state"):
self.close(terminate=True)
def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
assert shared_memory is None
env = env_fn()
parent_pipe.close()
try:
while True:
command, data = pipe.recv()
if command == "reset":
observation, info = env.reset(**data)
pipe.send(((observation, info), True))
elif command == "step":
(
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
if terminated or truncated:
old_observation, old_info = observation, info
observation, info = env.reset()
info["final_observation"] = old_observation
info["final_info"] = old_info
pipe.send(((observation, reward, terminated, truncated, info), True))
elif command == "seed":
env.seed(data)
pipe.send((None, True))
elif command == "close":
pipe.send((None, True))
break
elif command == "_call":
name, args, kwargs = data
if name in ["reset", "step", "seed", "close"]:
raise ValueError(
f"Trying to call function `{name}` with "
f"`_call`. Use `{name}` directly instead."
)
function = getattr(env, name)
if callable(function):
pipe.send((function(*args, **kwargs), True))
else:
pipe.send((function, True))
elif command == "_setattr":
name, value = data
setattr(env, name, value)
pipe.send((None, True))
elif command == "_check_spaces":
pipe.send(
(
(data[0] == env.observation_space, data[1] == env.action_space),
True,
)
)
else:
raise RuntimeError(
f"Received unknown command `{command}`. Must "
"be one of {`reset`, `step`, `seed`, `close`, `_call`, "
"`_setattr`, `_check_spaces`}."
)
except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2])
pipe.send((None, False))
finally:
env.close()
def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
assert shared_memory is not None
env = env_fn()
observation_space = env.observation_space
parent_pipe.close()
try:
while True:
command, data = pipe.recv()
if command == "reset":
observation, info = env.reset(**data)
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
pipe.send(((None, info), True))
elif command == "step":
(
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
if terminated or truncated:
old_observation, old_info = observation, info
observation, info = env.reset()
info["final_observation"] = old_observation
info["final_info"] = old_info
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
pipe.send(((None, reward, terminated, truncated, info), True))
elif command == "seed":
env.seed(data)
pipe.send((None, True))
elif command == "close":
pipe.send((None, True))
break
elif command == "_call":
name, args, kwargs = data
if name in ["reset", "step", "seed", "close"]:
raise ValueError(
f"Trying to call function `{name}` with "
f"`_call`. Use `{name}` directly instead."
)
function = getattr(env, name)
if callable(function):
pipe.send((function(*args, **kwargs), True))
else:
pipe.send((function, True))
elif command == "_setattr":
name, value = data
setattr(env, name, value)
pipe.send((None, True))
elif command == "_check_spaces":
pipe.send(
((data[0] == observation_space, data[1] == env.action_space), True)
)
else:
raise RuntimeError(
f"Received unknown command `{command}`. Must "
"be one of {`reset`, `step`, `seed`, `close`, `_call`, "
"`_setattr`, `_check_spaces`}."
)
except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2])
pipe.send((None, False))
finally:
env.close()
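A sketch of the autoreset behaviour implemented by the workers above: when a sub-environment terminates or truncates, it is reset immediately and the pre-reset observation and info are moved into the ``info`` dictionary (assuming ``CartPole-v1`` is installed):
>>> import gymnasium as gym
>>> import numpy as np
>>> envs = gym.vector.AsyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
>>> _ = envs.reset(seed=123)
>>> terminated = truncated = np.array([False, False])
>>> while not (terminated | truncated).any():
...     _, _, terminated, truncated, infos = envs.step(envs.action_space.sample())
>>> sorted(infos)  # the final step's data is kept under these keys
['_final_info', '_final_observation', 'final_info', 'final_observation']
>>> envs.close()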

View File gymnasium/vector/sync_vector_env.py

@@ -0,0 +1,235 @@
"""A synchronous vector environment."""
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np
from numpy.typing import NDArray
from gymnasium import Env
from gymnasium.spaces import Space
from gymnasium.vector.utils import concatenate, create_empty_array, iterate
from gymnasium.vector.vector_env import VectorEnv
__all__ = ["SyncVectorEnv"]
class SyncVectorEnv(VectorEnv):
"""Vectorized environment that serially runs multiple environments.
Example:
>>> import gymnasium as gym
>>> env = gym.vector.SyncVectorEnv([
... lambda: gym.make("Pendulum-v1", g=9.81),
... lambda: gym.make("Pendulum-v1", g=1.62)
... ])
>>> env.reset(seed=42)
(array([[-0.14995256, 0.9886932 , -0.12224312],
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {})
"""
def __init__(
self,
env_fns: Iterable[Callable[[], Env]],
observation_space: Space = None,
action_space: Space = None,
copy: bool = True,
):
"""Vectorized environment that serially runs multiple environments.
Args:
env_fns: iterable of callable functions that create the environments.
observation_space: Observation space of a single environment. If ``None``,
then the observation space of the first environment is taken.
action_space: Action space of a single environment. If ``None``,
then the action space of the first environment is taken.
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
Raises:
RuntimeError: If the observation space of some sub-environment does not match observation_space
(or, by default, the observation space of the first sub-environment).
"""
self.env_fns = env_fns
self.envs = [env_fn() for env_fn in env_fns]
self.copy = copy
self.metadata = self.envs[0].metadata
if (observation_space is None) or (action_space is None):
observation_space = observation_space or self.envs[0].observation_space
action_space = action_space or self.envs[0].action_space
super().__init__(
num_envs=len(self.envs),
observation_space=observation_space,
action_space=action_space,
)
self._check_spaces()
self.observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
self._terminateds = np.zeros((self.num_envs,), dtype=np.bool_)
self._truncateds = np.zeros((self.num_envs,), dtype=np.bool_)
self._actions = None
def seed(self, seed: Optional[Union[int, Sequence[int]]] = None):
"""Sets the seed in all sub-environments.
Args:
seed: The seed
"""
super().seed(seed=seed)
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
for env, single_seed in zip(self.envs, seed):
env.seed(single_seed)
def reset_wait(
self,
seed: Optional[Union[int, List[int]]] = None,
options: Optional[dict] = None,
):
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
Args:
seed: The reset environment seed
options: Option information for the environment reset
Returns:
The reset observation of the environment and reset information
"""
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
self._terminateds[:] = False
self._truncateds[:] = False
observations = []
infos = {}
for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
kwargs = {}
if single_seed is not None:
kwargs["seed"] = single_seed
if options is not None:
kwargs["options"] = options
observation, info = env.reset(**kwargs)
observations.append(observation)
infos = self._add_info(infos, info, i)
self.observations = concatenate(
self.single_observation_space, observations, self.observations
)
return (deepcopy(self.observations) if self.copy else self.observations), infos
def step_async(self, actions):
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
self._actions = iterate(self.action_space, actions)
def step_wait(self) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
"""Steps through each of the environments returning the batched results.
Returns:
The batched environment step results
"""
observations, infos = [], {}
for i, (env, action) in enumerate(zip(self.envs, self._actions)):
(
observation,
self._rewards[i],
self._terminateds[i],
self._truncateds[i],
info,
) = env.step(action)
if self._terminateds[i] or self._truncateds[i]:
old_observation, old_info = observation, info
observation, info = env.reset()
info["final_observation"] = old_observation
info["final_info"] = old_info
observations.append(observation)
infos = self._add_info(infos, info, i)
self.observations = concatenate(
self.single_observation_space, observations, self.observations
)
return (
deepcopy(self.observations) if self.copy else self.observations,
np.copy(self._rewards),
np.copy(self._terminateds),
np.copy(self._truncateds),
infos,
)
def call(self, name, *args, **kwargs) -> tuple:
"""Calls the method with name and applies args and kwargs.
Args:
name: The method name
*args: The method args
**kwargs: The method kwargs
Returns:
Tuple of results
"""
results = []
for env in self.envs:
function = getattr(env, name)
if callable(function):
results.append(function(*args, **kwargs))
else:
results.append(function)
return tuple(results)
def set_attr(self, name: str, values: Union[list, tuple, Any]):
"""Sets an attribute of the sub-environments.
Args:
name: The property name to change
values: Values of the property to be set to. If ``values`` is a list or
tuple, then it corresponds to the values for each individual
environment, otherwise, a single value is set for all environments.
Raises:
ValueError: Values must be a list or tuple with length equal to the number of environments.
"""
if not isinstance(values, (list, tuple)):
values = [values for _ in range(self.num_envs)]
if len(values) != self.num_envs:
raise ValueError(
"Values must be a list or tuple with length equal to the "
f"number of environments. Got `{len(values)}` values for "
f"{self.num_envs} environments."
)
for env, value in zip(self.envs, values):
setattr(env, name, value)
def close_extras(self, **kwargs):
"""Close the environments."""
for env in self.envs:
env.close()
def _check_spaces(self) -> bool:
for env in self.envs:
if not (env.observation_space == self.single_observation_space):
raise RuntimeError(
"Some environments have an observation space different from "
f"`{self.single_observation_space}`. In order to batch observations, "
"the observation spaces from all environments must be equal."
)
if not (env.action_space == self.single_action_space):
raise RuntimeError(
"Some environments have an action space different from "
f"`{self.single_action_space}`. In order to batch actions, the "
"action spaces from all environments must be equal."
)
return True
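A sketch of :meth:`call` and :meth:`set_attr` on the class above; ``spec`` is the standard environment attribute being queried, and ``dummy_attribute`` is just an illustrative name:
>>> import gymnasium as gym
>>> envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
>>> [s.id for s in envs.call("spec")]  # one result per sub-environment
['CartPole-v1', 'CartPole-v1']
>>> envs.set_attr("dummy_attribute", [1, 2])  # one value per sub-environment
>>> envs.get_attr("dummy_attribute")
(1, 2)
>>> envs.close()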

View File gymnasium/vector/utils/__init__.py

@@ -0,0 +1,26 @@
"""Module for gymnasium vector utils."""
from gymnasium.vector.utils.misc import CloudpickleWrapper, clear_mpi_env_vars
from gymnasium.vector.utils.numpy_utils import concatenate, create_empty_array
from gymnasium.vector.utils.shared_memory import (
create_shared_memory,
read_from_shared_memory,
write_to_shared_memory,
)
from gymnasium.vector.utils.spaces import (
_BaseGymSpaces, # pyright: ignore[reportPrivateUsage]
)
from gymnasium.vector.utils.spaces import BaseGymSpaces, batch_space, iterate
__all__ = [
"CloudpickleWrapper",
"clear_mpi_env_vars",
"concatenate",
"create_empty_array",
"create_shared_memory",
"read_from_shared_memory",
"write_to_shared_memory",
"BaseGymSpaces",
"batch_space",
"iterate",
]
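A quick sketch of the public ``gymnasium.vector.utils`` API re-exported above:
>>> from gymnasium.spaces import Discrete
>>> from gymnasium.vector.utils import batch_space, create_empty_array
>>> batched = batch_space(Discrete(2), n=4)  # a MultiDiscrete([2 2 2 2]) space
>>> create_empty_array(Discrete(2), n=4).shape
(4,)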

View File gymnasium/vector/utils/misc.py

@@ -0,0 +1,61 @@
"""Miscellaneous utilities."""
from __future__ import annotations
import contextlib
import os
from collections.abc import Callable
from gymnasium.core import Env
__all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"]
class CloudpickleWrapper:
"""Wrapper that uses cloudpickle to pickle and unpickle the result."""
def __init__(self, fn: Callable[[], Env]):
"""Cloudpickle wrapper for a function."""
self.fn = fn
def __getstate__(self):
"""Get the state using `cloudpickle.dumps(self.fn)`."""
import cloudpickle
return cloudpickle.dumps(self.fn)
def __setstate__(self, ob):
"""Sets the state with obs."""
import pickle
self.fn = pickle.loads(ob)
def __call__(self):
"""Calls the function `self.fn` with no arguments."""
return self.fn()
@contextlib.contextmanager
def clear_mpi_env_vars():
"""Clears the MPI of environment variables.
`from mpi4py import MPI` will call `MPI_Init` by default.
If the child process has MPI environment variables, MPI will think that the child process
is an MPI process just like the parent and do bad things such as hang.
This context manager is a hacky way to clear those environment variables
temporarily, such as when we are starting multiprocessing Processes.
Yields:
Yields for the context manager
"""
removed_environment = {}
for k, v in list(os.environ.items()):
for prefix in ["OMPI_", "PMI_"]:
if k.startswith(prefix):
removed_environment[k] = v
del os.environ[k]
try:
yield
finally:
os.environ.update(removed_environment)
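A sketch of why ``CloudpickleWrapper`` exists: plain ``pickle`` cannot serialise a lambda env factory, but routing the payload through cloudpickle can (assuming the ``cloudpickle`` package is installed):
>>> import pickle
>>> from gymnasium.vector.utils import CloudpickleWrapper
>>> make_fn = lambda: "an environment"  # stand-in for an env factory
>>> restored = pickle.loads(pickle.dumps(CloudpickleWrapper(make_fn)))
>>> restored()
'an environment'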

View File gymnasium/vector/utils/numpy_utils.py

@@ -0,0 +1,146 @@
"""Numpy utility functions: concatenate space samples and create empty array."""
from collections import OrderedDict
from functools import singledispatch
from typing import Callable, Iterable, Union
import numpy as np
from gymnasium.spaces import (
Box,
Dict,
Discrete,
MultiBinary,
MultiDiscrete,
Space,
Tuple,
)
__all__ = ["concatenate", "create_empty_array"]
@singledispatch
def concatenate(
space: Space, items: Iterable, out: Union[tuple, dict, np.ndarray]
) -> Union[tuple, dict, np.ndarray]:
"""Concatenate multiple samples from space into a single object.
Args:
space: Observation space of a single environment in the vectorized environment.
items: Samples to be concatenated.
out: The output object. This object is a (possibly nested) numpy array.
Returns:
The output object. This object is a (possibly nested) numpy array.
Raises:
ValueError: Space is not a valid :class:`gym.Space` instance
Example:
>>> from gymnasium.spaces import Box
>>> import numpy as np
>>> space = Box(low=0, high=1, shape=(3,), seed=42, dtype=np.float32)
>>> out = np.zeros((2, 3), dtype=np.float32)
>>> items = [space.sample() for _ in range(2)]
>>> concatenate(space, items, out)
array([[0.77395606, 0.43887845, 0.85859793],
[0.697368 , 0.09417735, 0.97562236]], dtype=float32)
"""
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
)
@concatenate.register(Box)
@concatenate.register(Discrete)
@concatenate.register(MultiDiscrete)
@concatenate.register(MultiBinary)
def _concatenate_base(space, items, out):
return np.stack(items, axis=0, out=out)
@concatenate.register(Tuple)
def _concatenate_tuple(space, items, out):
return tuple(
concatenate(subspace, [item[i] for item in items], out[i])
for (i, subspace) in enumerate(space.spaces)
)
@concatenate.register(Dict)
def _concatenate_dict(space, items, out):
return OrderedDict(
[
(key, concatenate(subspace, [item[key] for item in items], out[key]))
for (key, subspace) in space.spaces.items()
]
)
@concatenate.register(Space)
def _concatenate_custom(space, items, out):
return tuple(items)
@singledispatch
def create_empty_array(
space: Space, n: int = 1, fn: Callable[..., np.ndarray] = np.zeros
) -> Union[tuple, dict, np.ndarray]:
"""Create an empty (possibly nested) numpy array.
Args:
space: Observation space of a single environment in the vectorized environment.
n: Number of environments in the vectorized environment. If `None`, creates an empty sample from `space`.
fn: Function to apply when creating the empty numpy array. Examples of such functions are `np.empty` or `np.zeros`.
Returns:
The output object. This object is a (possibly nested) numpy array.
Raises:
ValueError: Space is not a valid :class:`gym.Space` instance
Example:
>>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
>>> create_empty_array(space, n=2, fn=np.zeros)
OrderedDict([('position', array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)), ('velocity', array([[0., 0.],
[0., 0.]], dtype=float32))])
"""
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
)
# It is possible for some of the Box low values to be greater than 0, in which case the zero-filled array is not contained in the space
@create_empty_array.register(Box)
# If the Discrete start > 0 or start + n <= 0, the zero-filled array is not contained in the space
@create_empty_array.register(Discrete)
@create_empty_array.register(MultiDiscrete)
@create_empty_array.register(MultiBinary)
def _create_empty_array_base(space, n=1, fn=np.zeros):
shape = space.shape if (n is None) else (n,) + space.shape
return fn(shape, dtype=space.dtype)
@create_empty_array.register(Tuple)
def _create_empty_array_tuple(space, n=1, fn=np.zeros):
return tuple(create_empty_array(subspace, n=n, fn=fn) for subspace in space.spaces)
@create_empty_array.register(Dict)
def _create_empty_array_dict(space, n=1, fn=np.zeros):
return OrderedDict(
[
(key, create_empty_array(subspace, n=n, fn=fn))
for (key, subspace) in space.spaces.items()
]
)
@create_empty_array.register(Space)
def _create_empty_array_custom(space, n=1, fn=np.zeros):
return None
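The two helpers above are designed to be used together: pre-allocate a buffer with ``create_empty_array``, then fill it in place with ``concatenate``. A minimal sketch:
>>> import numpy as np
>>> from gymnasium.spaces import Box
>>> from gymnasium.vector.utils import concatenate, create_empty_array
>>> space = Box(low=0, high=1, shape=(2,), seed=0, dtype=np.float32)
>>> out = create_empty_array(space, n=3, fn=np.zeros)  # shape (3, 2)
>>> result = concatenate(space, [space.sample() for _ in range(3)], out)
>>> result.shape
(3, 2)
>>> result is out  # the samples are written into the pre-allocated buffer
True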

View File gymnasium/vector/utils/shared_memory.py

@@ -0,0 +1,191 @@
"""Utility functions for vector environments to share memory between processes."""
import multiprocessing as mp
from collections import OrderedDict
from ctypes import c_bool
from functools import singledispatch
from typing import Union
import numpy as np
from gymnasium.error import CustomSpaceError
from gymnasium.spaces import (
Box,
Dict,
Discrete,
MultiBinary,
MultiDiscrete,
Space,
Tuple,
)
__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]
@singledispatch
def create_shared_memory(
space: Space, n: int = 1, ctx=mp
) -> Union[dict, tuple, mp.Array]:
"""Create a shared memory object, to be shared across processes.
This eventually contains the observations from the vectorized environment.
Args:
space: Observation space of a single environment in the vectorized environment.
n: Number of environments in the vectorized environment (i.e. the number of processes).
ctx: The multiprocessing context (or the `multiprocessing` module itself)
Returns:
shared_memory for the shared object across processes.
Raises:
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
"""
raise CustomSpaceError(
"Cannot create a shared memory for space with "
f"type `{type(space)}`. Shared memory only supports "
"default Gymnasium spaces (e.g. `Box`, `Tuple`, "
"`Dict`, etc...), and does not support custom "
"Gymnasium spaces."
)
@create_shared_memory.register(Box)
@create_shared_memory.register(Discrete)
@create_shared_memory.register(MultiDiscrete)
@create_shared_memory.register(MultiBinary)
def _create_base_shared_memory(space, n: int = 1, ctx=mp):
dtype = space.dtype.char
if dtype in "?":
dtype = c_bool
return ctx.Array(dtype, n * int(np.prod(space.shape)))
@create_shared_memory.register(Tuple)
def _create_tuple_shared_memory(space, n: int = 1, ctx=mp):
return tuple(
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
)
@create_shared_memory.register(Dict)
def _create_dict_shared_memory(space, n=1, ctx=mp):
return OrderedDict(
[
(key, create_shared_memory(subspace, n=n, ctx=ctx))
for (key, subspace) in space.spaces.items()
]
)
@singledispatch
def read_from_shared_memory(
space: Space, shared_memory: Union[dict, tuple, mp.Array], n: int = 1
) -> Union[dict, tuple, np.ndarray]:
"""Read the batch of observations from shared memory as a numpy array.
.. note::
The numpy array objects returned by `read_from_shared_memory` share the
memory of `shared_memory`. Any changes to `shared_memory` are forwarded
to the observations, and vice-versa. To avoid any side-effect, use `np.copy`.
Args:
space: Observation space of a single environment in the vectorized environment.
shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
This object is created with `create_shared_memory`.
n: Number of environments in the vectorized environment (i.e. the number of processes).
Returns:
Batch of observations as a (possibly nested) numpy array.
Raises:
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
"""
raise CustomSpaceError(
"Cannot read from a shared memory for space with "
f"type `{type(space)}`. Shared memory only supports "
"default Gymnasium spaces (e.g. `Box`, `Tuple`, "
"`Dict`, etc...), and does not support custom "
"Gymnasium spaces."
)
@read_from_shared_memory.register(Box)
@read_from_shared_memory.register(Discrete)
@read_from_shared_memory.register(MultiDiscrete)
@read_from_shared_memory.register(MultiBinary)
def _read_base_from_shared_memory(space, shared_memory, n: int = 1):
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
(n,) + space.shape
)
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(space, shared_memory, n: int = 1):
return tuple(
read_from_shared_memory(subspace, memory, n=n)
for (memory, subspace) in zip(shared_memory, space.spaces)
)
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space, shared_memory, n: int = 1):
return OrderedDict(
[
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
for (key, subspace) in space.spaces.items()
]
)
@singledispatch
def write_to_shared_memory(
space: Space,
index: int,
value: np.ndarray,
shared_memory: Union[dict, tuple, mp.Array],
):
"""Write the observation of a single environment into shared memory.
Args:
space: Observation space of a single environment in the vectorized environment.
index: Index of the environment (must be in `[0, num_envs)`).
value: Observation of the single environment to write to shared memory.
shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
This object is created with `create_shared_memory`.
Raises:
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
"""
raise CustomSpaceError(
"Cannot write to a shared memory for space with "
f"type `{type(space)}`. Shared memory only supports "
"default Gymnasium spaces (e.g. `Box`, `Tuple`, "
"`Dict`, etc...), and does not support custom "
"Gymnasium spaces."
)
@write_to_shared_memory.register(Box)
@write_to_shared_memory.register(Discrete)
@write_to_shared_memory.register(MultiDiscrete)
@write_to_shared_memory.register(MultiBinary)
def _write_base_to_shared_memory(space, index, value, shared_memory):
size = int(np.prod(space.shape))
destination = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
np.copyto(
destination[index * size : (index + 1) * size],
np.asarray(value, dtype=space.dtype).flatten(),
)
@write_to_shared_memory.register(Tuple)
def _write_tuple_to_shared_memory(space, index, values, shared_memory):
for value, memory, subspace in zip(values, shared_memory, space.spaces):
write_to_shared_memory(subspace, index, value, memory)
@write_to_shared_memory.register(Dict)
def _write_dict_to_shared_memory(space, index, values, shared_memory):
for key, subspace in space.spaces.items():
write_to_shared_memory(subspace, index, values[key], shared_memory[key])
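A sketch of the round trip through the three helpers above, in a single process for clarity (in :class:`AsyncVectorEnv` the write happens inside a worker process):
>>> import numpy as np
>>> from gymnasium.spaces import Box
>>> from gymnasium.vector.utils import (
...     create_shared_memory, read_from_shared_memory, write_to_shared_memory
... )
>>> space = Box(low=0, high=1, shape=(3,), dtype=np.float32)
>>> shm = create_shared_memory(space, n=2)
>>> view = read_from_shared_memory(space, shm, n=2)  # zero-copy view, shape (2, 3)
>>> write_to_shared_memory(space, 1, np.ones(3, dtype=np.float32), shm)
>>> view[1]
array([1., 1., 1.], dtype=float32)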

View File gymnasium/vector/utils/spaces.py

@@ -0,0 +1,215 @@
"""Utility functions for gymnasium spaces: batch space and iterator."""
from collections import OrderedDict
from copy import deepcopy
from functools import singledispatch
from typing import Iterator
import numpy as np
from gymnasium.error import CustomSpaceError
from gymnasium.spaces import (
Box,
Dict,
Discrete,
MultiBinary,
MultiDiscrete,
Space,
Tuple,
)
BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
_BaseGymSpaces = BaseGymSpaces
__all__ = ["BaseGymSpaces", "_BaseGymSpaces", "batch_space", "iterate"]
@singledispatch
def batch_space(space: Space, n: int = 1) -> Space:
"""Create a (batched) space, containing multiple copies of a single space.
Args:
space: Space (e.g. the observation space) for a single environment in the vectorized environment.
n: Number of environments in the vectorized environment.
Returns:
Space (e.g. the observation space) for a batch of environments in the vectorized environment.
Raises:
ValueError: Cannot batch space that is not a valid :class:`gym.Space` instance
Example:
>>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)
... })
>>> batch_space(space, n=5)
Dict('position': Box(0.0, 1.0, (5, 3), float32), 'velocity': Box(0.0, 1.0, (5, 2), float32))
"""
raise ValueError(
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gymnasium.Space` instance."
)
@batch_space.register(Box)
def _batch_space_box(space, n=1):
repeats = tuple([n] + [1] * space.low.ndim)
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
return Box(low=low, high=high, dtype=space.dtype, seed=deepcopy(space.np_random))
@batch_space.register(Discrete)
def _batch_space_discrete(space, n=1):
return MultiDiscrete(
np.full((n,), space.n, dtype=space.dtype),
dtype=space.dtype,
seed=deepcopy(space.np_random),
start=np.full((n,), space.start, dtype=space.dtype),
)
@batch_space.register(MultiDiscrete)
def _batch_space_multidiscrete(space, n=1):
repeats = tuple([n] + [1] * space.nvec.ndim)
low = np.tile(space.start, repeats)
high = low + np.tile(space.nvec, repeats) - 1
return Box(
low=low,
high=high,
dtype=space.dtype,
seed=deepcopy(space.np_random),
)
@batch_space.register(MultiBinary)
def _batch_space_multibinary(space, n=1):
return Box(
low=0,
high=1,
shape=(n,) + space.shape,
dtype=space.dtype,
seed=deepcopy(space.np_random),
)
@batch_space.register(Tuple)
def _batch_space_tuple(space, n=1):
return Tuple(
tuple(batch_space(subspace, n=n) for subspace in space.spaces),
seed=deepcopy(space.np_random),
)
@batch_space.register(Dict)
def _batch_space_dict(space, n=1):
return Dict(
OrderedDict(
[
(key, batch_space(subspace, n=n))
for (key, subspace) in space.spaces.items()
]
),
seed=deepcopy(space.np_random),
)
@batch_space.register(Space)
def _batch_space_custom(space, n=1):
# Without deepcopy, space.np_random would be the same object as batched_space.spaces[0].np_random,
# which is an issue if you sample from both the original space and the batched space
batched_space = Tuple(
tuple(deepcopy(space) for _ in range(n)), seed=deepcopy(space.np_random)
)
new_seeds = list(map(int, batched_space.np_random.integers(0, 1e8, n)))
batched_space.seed(new_seeds)
return batched_space
@singledispatch
def iterate(space: Space, items) -> Iterator:
"""Iterate over the elements of a (batched) space.
Args:
space: Space to which `items` belong to.
items: Items to be iterated over.
Returns:
Iterator over the elements in `items`.
Raises:
ValueError: Space is not an instance of :class:`gym.Space`
Example:
>>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(2, 3), seed=42, dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2, 2), seed=42, dtype=np.float32)})
>>> items = space.sample()
>>> it = iterate(space, items)
>>> next(it)
OrderedDict([('position', array([0.77395606, 0.43887845, 0.85859793], dtype=float32)), ('velocity', array([0.77395606, 0.43887845], dtype=float32))])
>>> next(it)
OrderedDict([('position', array([0.697368 , 0.09417735, 0.97562236], dtype=float32)), ('velocity', array([0.85859793, 0.697368 ], dtype=float32))])
>>> next(it)
Traceback (most recent call last):
...
StopIteration
"""
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
)
@iterate.register(Discrete)
def _iterate_discrete(space, items):
raise TypeError("Unable to iterate over a space of type `Discrete`.")
@iterate.register(Box)
@iterate.register(MultiDiscrete)
@iterate.register(MultiBinary)
def _iterate_base(space, items):
try:
return iter(items)
except TypeError as e:
raise TypeError(
f"Unable to iterate over the following elements: {items}"
) from e
@iterate.register(Tuple)
def _iterate_tuple(space, items):
# If this is a tuple of custom subspaces only, then simply iterate over items
if all(
isinstance(subspace, Space)
and (not isinstance(subspace, BaseGymSpaces + (Tuple, Dict)))
for subspace in space.spaces
):
return iter(items)
return zip(
*[iterate(subspace, items[i]) for i, subspace in enumerate(space.spaces)]
)
@iterate.register(Dict)
def _iterate_dict(space, items):
keys, values = zip(
*[
(key, iterate(subspace, items[key]))
for key, subspace in space.spaces.items()
]
)
for item in zip(*values):
yield OrderedDict([(key, value) for (key, value) in zip(keys, item)])
@iterate.register(Space)
def _iterate_custom(space, items):
raise CustomSpaceError(
f"Unable to iterate over {items}, since {space} "
"is a custom `gymnasium.Space` instance (i.e. not one of "
"`Box`, `Dict`, etc...)."
)
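A sketch showing how ``batch_space`` and ``iterate`` above complement each other: batching a ``Discrete`` yields a ``MultiDiscrete``, whose samples can then be split back into per-environment actions:
>>> from gymnasium.spaces import Discrete
>>> from gymnasium.vector.utils import batch_space, iterate
>>> batched = batch_space(Discrete(4, seed=7), n=3)  # MultiDiscrete([4 4 4])
>>> sample = batched.sample()  # one batched sample holds 3 scalar actions
>>> len(list(iterate(batched, sample)))
3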

View File gymnasium/vector/vector_env.py

@@ -0,0 +1,403 @@
"""Base class for vectorized environments."""
from typing import Any, List, Optional, Tuple, Union
import numpy as np
from numpy.typing import NDArray
import gymnasium as gym
from gymnasium import logger
from gymnasium.vector.utils.spaces import batch_space
__all__ = ["VectorEnv"]
class VectorEnv(gym.Env):
"""Base class for vectorized environments to run multiple independent copies of the same environment in parallel.
Vector environments can provide a linear speed-up in the steps taken per second through sampling multiple
sub-environments at the same time. To prevent terminated environments from waiting until all sub-environments have
terminated or truncated, the vector environments autoreset sub-environments after they terminate or truncate.
As a result, the final step's observation and info are overwritten by the reset's observation and info.
Therefore, the observation and info for the final step of a sub-environment are stored in the info parameter,
using `"final_observation"` and `"final_info"` respectively. See :meth:`step` for more information.
The vector environments batch `observations`, `rewards`, `terminations`, `truncations` and `info` for each
parallel environment. In addition, :meth:`step` expects to receive a batch of actions for each parallel environment.
Gymnasium contains two types of Vector environments: :class:`AsyncVectorEnv` and :class:`SyncVectorEnv`.
Vector environments have additional attributes that help users understand the implementation:
- :attr:`num_envs` - The number of sub-environments in the vector environment
- :attr:`observation_space` - The batched observation space of the vector environment
- :attr:`single_observation_space` - The observation space of a single sub-environment
- :attr:`action_space` - The batched action space of the vector environment
- :attr:`single_action_space` - The action space of a single sub-environment
Note:
Before OpenAI Gym v25, the info parameter of :meth:`reset` and :meth:`step` was a list of
dictionaries, one for each sub-environment. In OpenAI Gym v25+ and in Gymnasium this was changed to a single
dictionary with a NumPy array for each key. To use the old info style, use the :class:`VectorListInfo` wrapper.
Note:
To render the sub-environments, use :meth:`call` with the ``"render"`` method name. Remember to set the `render_mode`
for all the sub-environments during initialization.
Note:
All parallel environments should share the identical observation and action spaces.
In other words, a vector of multiple different environments is not supported.
"""
def __init__(
self,
num_envs: int,
observation_space: gym.Space,
action_space: gym.Space,
):
"""Base class for vectorized environments.
Args:
num_envs: Number of environments in the vectorized environment.
observation_space: Observation space of a single environment.
action_space: Action space of a single environment.
"""
self.num_envs = num_envs
self.is_vector_env = True
self.observation_space = batch_space(observation_space, n=num_envs)
self.action_space = batch_space(action_space, n=num_envs)
self.closed = False
self.viewer = None
# The observation and action spaces of a single environment are
# kept in separate properties
self.single_observation_space = observation_space
self.single_action_space = action_space
def reset_async(
self,
seed: Optional[Union[int, List[int]]] = None,
options: Optional[dict] = None,
):
"""Reset the sub-environments asynchronously.
This method will return ``None``. A call to :meth:`reset_async` should be followed
by a call to :meth:`reset_wait` to retrieve the results.
Args:
seed: The reset seed
options: Reset options
"""
pass
def reset_wait(
self,
seed: Optional[Union[int, List[int]]] = None,
options: Optional[dict] = None,
):
"""Retrieves the results of a :meth:`reset_async` call.
A call to this method must always be preceded by a call to :meth:`reset_async`.
Args:
seed: The reset seed
options: Reset options
Returns:
The results from :meth:`reset_async`
Raises:
NotImplementedError: VectorEnv does not implement function
"""
raise NotImplementedError("VectorEnv does not implement function")
def reset(
self,
*,
seed: Optional[Union[int, List[int]]] = None,
options: Optional[dict] = None,
):
"""Reset all parallel environments and return a batch of initial observations and info.
Args:
seed: The environment reset seeds
options: The reset options, passed to every sub-environment
Returns:
A batch of observations and info from the vectorized environment.
Example:
>>> import gymnasium as gym
>>> envs = gym.vector.make("CartPole-v1", num_envs=3)
>>> envs.reset(seed=42)
(array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ],
[ 0.01522993, -0.04562247, -0.04799704, 0.03392126],
[-0.03774345, -0.02418869, -0.00942293, 0.0469184 ]],
dtype=float32), {})
"""
self.reset_async(seed=seed, options=options)
return self.reset_wait(seed=seed, options=options)
def step_async(self, actions):
"""Asynchronously performs steps in the sub-environments.
The results can be retrieved via a call to :meth:`step_wait`.
Args:
actions: The actions to take asynchronously
"""
def step_wait(
self, **kwargs
) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
"""Retrieves the results of a :meth:`step_async` call.
A call to this method must always be preceded by a call to :meth:`step_async`.
Args:
**kwargs: Additional keywords for vector implementation
Returns:
The results from the :meth:`step_async` call
"""
raise NotImplementedError()
def step(
self, actions
) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
"""Take an action for each parallel environment.
Args:
actions: Batch of actions, an element of :attr:`action_space`.
Returns:
Batch of (observations, rewards, terminations, truncations, infos)
Note:
As the vector environments autoreset terminating and truncating sub-environments,
the returned observation and info are not the final step's observation or info; those are instead stored in
the info dictionary under `"final_observation"` and `"final_info"`.
Example:
>>> import gymnasium as gym
>>> import numpy as np
>>> envs = gym.vector.make("CartPole-v1", num_envs=3)
>>> _ = envs.reset(seed=42)
>>> actions = np.array([1, 0, 1])
>>> observations, rewards, termination, truncation, infos = envs.step(actions)
>>> observations
array([[ 0.02727336, 0.18847767, 0.03625453, -0.26141977],
[ 0.01431748, -0.24002443, -0.04731862, 0.3110827 ],
[-0.03822722, 0.1710671 , -0.00848456, -0.2487226 ]],
dtype=float32)
>>> rewards
array([1., 1., 1.])
>>> termination
array([False, False, False])
>>> truncation
array([False, False, False])
>>> infos
{}
"""
self.step_async(actions)
return self.step_wait()
def call_async(self, name, *args, **kwargs):
"""Calls a method name for each parallel environment asynchronously."""
def call_wait(self, **kwargs) -> List[Any]: # type: ignore
"""After calling a method in :meth:`call_async`, this function collects the results."""
def call(self, name: str, *args, **kwargs) -> List[Any]:
"""Call a method, or get a property, from each parallel environment.
Args:
name (str): Name of the method or property to call.
*args: Arguments to apply to the method call.
**kwargs: Keyword arguments to apply to the method call.
Returns:
List of the results of the individual calls to the method or property for each environment.
"""
self.call_async(name, *args, **kwargs)
return self.call_wait()
def get_attr(self, name: str):
"""Get a property from each parallel environment.
Args:
name (str): Name of the property to be get from each individual environment.
Returns:
The property with name
"""
return self.call(name)
def set_attr(self, name: str, values: Union[list, tuple, object]):
"""Set a property in each sub-environment.
Args:
name (str): Name of the property to be set in each individual environment.
values (list, tuple, or object): Values of the property to be set to. If `values` is a list or
tuple, then it corresponds to the values for each individual environment, otherwise a single value
is set for all environments.
"""
def close_extras(self, **kwargs):
"""Clean up the extra resources e.g. beyond what's in this base class."""
pass
def close(self, **kwargs):
"""Close all parallel environments and release resources.
It also closes all the existing image viewers, then calls :meth:`close_extras` and sets
:attr:`closed` to ``True``.
Warnings:
This function itself does not close the environments; that should be handled
in :meth:`close_extras`. This is generic for both synchronous and asynchronous
vectorized environments.
Note:
This will be automatically called when garbage collected or program exited.
Args:
**kwargs: Keyword arguments passed to :meth:`close_extras`
"""
if self.closed:
return
if self.viewer is not None:
self.viewer.close()
self.close_extras(**kwargs)
self.closed = True
def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
"""Add env info to the info dictionary of the vectorized environment.
Given the `info` of a single environment add it to the `infos` dictionary
which represents all the infos of the vectorized environment.
Every `key` of `info` is paired with a boolean mask `_key` representing
whether or not the i-indexed environment has this `info`.
Args:
infos (dict): the infos of the vectorized environment
info (dict): the info coming from the single environment
env_num (int): the index of the single environment
Returns:
infos (dict): the (updated) infos of the vectorized environment
"""
for k in info.keys():
if k not in infos:
info_array, array_mask = self._init_info_arrays(type(info[k]))
else:
info_array, array_mask = infos[k], infos[f"_{k}"]
info_array[env_num], array_mask[env_num] = info[k], True
infos[k], infos[f"_{k}"] = info_array, array_mask
return infos
def _init_info_arrays(self, dtype: type) -> Tuple[np.ndarray, np.ndarray]:
"""Initialize the info array.
Initialize the info array. If the dtype is numeric
the info array will have the same dtype, otherwise
will be an array of `None`. Also, a boolean array
of the same length is returned. It will be used for
assessing which environment has info data.
Args:
dtype (type): data type of the info coming from the env.
Returns:
array (np.ndarray): the initialized info array.
array_mask (np.ndarray): the initialized boolean array.
"""
if dtype in [int, float, bool] or issubclass(dtype, np.number):
array = np.zeros(self.num_envs, dtype=dtype)
else:
array = np.zeros(self.num_envs, dtype=object)
array[:] = None
array_mask = np.zeros(self.num_envs, dtype=bool)
return array, array_mask
def __del__(self):
"""Closes the vector environment."""
if not getattr(self, "closed", True):
self.close()
def __repr__(self) -> str:
"""Returns a string representation of the vector environment.
Returns:
A string containing the class name, number of environments and environment spec id
"""
if self.spec is None:
return f"{self.__class__.__name__}({self.num_envs})"
else:
return f"{self.__class__.__name__}({self.spec.id}, {self.num_envs})"
class VectorEnvWrapper(VectorEnv):
"""Wraps the vectorized environment to allow a modular transformation.
This class is the base class for all wrappers for vectorized environments. The subclass
could override some methods to change the behavior of the original vectorized environment
without touching the original code.
Note:
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
"""
def __init__(self, env: VectorEnv):
assert isinstance(env, VectorEnv)
self.env = env
# explicitly forward the methods defined in VectorEnv
# to self.env (instead of the base class)
def reset_async(self, **kwargs):
return self.env.reset_async(**kwargs)
def reset_wait(self, **kwargs):
return self.env.reset_wait(**kwargs)
def step_async(self, actions):
return self.env.step_async(actions)
def step_wait(self):
return self.env.step_wait()
def close(self, **kwargs):
return self.env.close(**kwargs)
def close_extras(self, **kwargs):
return self.env.close_extras(**kwargs)
def call(self, name, *args, **kwargs):
return self.env.call(name, *args, **kwargs)
def set_attr(self, name, values):
return self.env.set_attr(name, values)
# implicitly forward all other methods and attributes to self.env
def __getattr__(self, name):
if name.startswith("_"):
raise AttributeError(f"attempted to get missing private attribute '{name}'")
logger.warn(
f"env.{name} to get variables from other wrappers is deprecated and will be removed in v1.0, "
f"to get this variable you can do `env.unwrapped.{name}` for environment variables."
)
return getattr(self.env, name)
@property
def unwrapped(self):
return self.env.unwrapped
def __repr__(self):
return f"<{self.__class__.__name__}, {self.env}>"
def __del__(self):
self.env.__del__()
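A minimal sketch of a :class:`VectorEnvWrapper` subclass built on the class above; ``RewardScaler`` and its ``scale`` parameter are illustrative names, and the deprecation warnings from ``gym.vector.make`` and attribute forwarding are expected:
>>> import gymnasium as gym
>>> from gymnasium.vector import VectorEnvWrapper
>>> class RewardScaler(VectorEnvWrapper):
...     """Scale every reward by a constant factor."""
...     def __init__(self, env, scale=0.1):
...         super().__init__(env)
...         self.scale = scale
...     def step_wait(self):
...         obs, rewards, terminated, truncated, infos = self.env.step_wait()
...         return obs, rewards * self.scale, terminated, truncated, infos
>>> envs = RewardScaler(gym.vector.make("CartPole-v1", num_envs=2))
>>> _ = envs.reset(seed=0)
>>> _, rewards, _, _, _ = envs.step(envs.action_space.sample())
>>> rewards
array([0.1, 0.1])
>>> envs.close()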