90 lines
3.0 KiB
Python
90 lines
3.0 KiB
Python
"""Wrapper for limiting the time steps of an environment."""
|
|
from __future__ import annotations
|
|
|
|
from copy import deepcopy
|
|
from typing import TYPE_CHECKING
|
|
|
|
import gymnasium as gym
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from gymnasium.envs.registration import EnvSpec
|
|
|
|
|
|
class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|
"""This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded.
|
|
|
|
If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
|
|
Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP.
|
|
|
|
Example:
|
|
>>> import gymnasium as gym
|
|
>>> from gymnasium.wrappers import TimeLimit
|
|
>>> env = gym.make("CartPole-v1")
|
|
>>> env = TimeLimit(env, max_episode_steps=1000)
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
env: gym.Env,
|
|
max_episode_steps: int,
|
|
):
|
|
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
|
|
|
|
Args:
|
|
env: The environment to apply the wrapper
|
|
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
|
|
"""
|
|
gym.utils.RecordConstructorArgs.__init__(
|
|
self, max_episode_steps=max_episode_steps
|
|
)
|
|
gym.Wrapper.__init__(self, env)
|
|
|
|
self._max_episode_steps = max_episode_steps
|
|
self._elapsed_steps = None
|
|
|
|
def step(self, action):
|
|
"""Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate.
|
|
|
|
Args:
|
|
action: The environment step action
|
|
|
|
Returns:
|
|
The environment step ``(observation, reward, terminated, truncated, info)`` with `truncated=True`
|
|
if the number of steps elapsed >= max episode steps
|
|
|
|
"""
|
|
observation, reward, terminated, truncated, info = self.env.step(action)
|
|
self._elapsed_steps += 1
|
|
|
|
if self._elapsed_steps >= self._max_episode_steps:
|
|
truncated = True
|
|
|
|
return observation, reward, terminated, truncated, info
|
|
|
|
def reset(self, **kwargs):
|
|
"""Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.
|
|
|
|
Args:
|
|
**kwargs: The kwargs to reset the environment with
|
|
|
|
Returns:
|
|
The reset environment
|
|
"""
|
|
self._elapsed_steps = 0
|
|
return self.env.reset(**kwargs)
|
|
|
|
@property
|
|
def spec(self) -> EnvSpec | None:
|
|
"""Modifies the environment spec to include the `max_episode_steps=self._max_episode_steps`."""
|
|
if self._cached_spec is not None:
|
|
return self._cached_spec
|
|
|
|
env_spec = self.env.spec
|
|
if env_spec is not None:
|
|
env_spec = deepcopy(env_spec)
|
|
env_spec.max_episode_steps = self._max_episode_steps
|
|
|
|
self._cached_spec = env_spec
|
|
return env_spec
|