90 lines
3.3 KiB
Python
90 lines
3.3 KiB
Python
"""Wrapper for rescaling actions to within a max and min action."""
|
|
from typing import Union
|
|
|
|
import numpy as np
|
|
|
|
import gymnasium as gym
|
|
from gymnasium.spaces import Box
|
|
|
|
|
|
class RescaleAction(gym.ActionWrapper, gym.utils.RecordConstructorArgs):
    """Affinely rescales the continuous action space of the environment to the range [min_action, max_action].

    The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
    or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space.

    If :attr:`min_action` equals :attr:`max_action` in some dimension, the only valid action in that dimension
    is mapped to the lower bound of the base environment's action space.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import RescaleAction
        >>> import numpy as np
        >>> env = gym.make("Hopper-v4")
        >>> _ = env.reset(seed=42)
        >>> obs, _, _, _, _ = env.step(np.array([1,1,1]))
        >>> _ = env.reset(seed=42)
        >>> min_action = -0.5
        >>> max_action = np.array([0.0, 0.5, 0.75])
        >>> wrapped_env = RescaleAction(env, min_action=min_action, max_action=max_action)
        >>> wrapped_env_obs, _, _, _, _ = wrapped_env.step(max_action)
        >>> np.array_equal(obs, wrapped_env_obs)
        True
    """

    def __init__(
        self,
        env: gym.Env,
        min_action: Union[float, int, np.ndarray],
        max_action: Union[float, int, np.ndarray],
    ):
        """Initializes the :class:`RescaleAction` wrapper.

        Args:
            env (Env): The environment to apply the wrapper
            min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
            max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.

        Raises:
            AssertionError: If the action space of ``env`` is not a :class:`Box`, or if any element of
                ``min_action`` is greater than the corresponding element of ``max_action``.
        """
        assert isinstance(
            env.action_space, Box
        ), f"expected Box action space, got {type(env.action_space)}"
        assert np.less_equal(min_action, max_action).all(), (min_action, max_action)

        gym.utils.RecordConstructorArgs.__init__(
            self, min_action=min_action, max_action=max_action
        )
        gym.ActionWrapper.__init__(self, env)

        # Broadcast scalar bounds to the full action shape so that `action` can work element-wise.
        self.min_action = (
            np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
        )
        self.max_action = (
            np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + max_action
        )

        # The wrapper exposes the rescaled range as its own action space.
        self.action_space = Box(
            low=min_action,
            high=max_action,
            shape=env.action_space.shape,
            dtype=env.action_space.dtype,
        )

    def action(self, action):
        """Rescales the action affinely from [:attr:`min_action`, :attr:`max_action`] to the action space of the base environment, :attr:`env`.

        Args:
            action: The action to rescale

        Returns:
            The rescaled action, clipped to the bounds of the base environment's action space

        Raises:
            AssertionError: If any element of ``action`` lies outside [:attr:`min_action`, :attr:`max_action`].
        """
        assert np.all(np.greater_equal(action, self.min_action)), (
            action,
            self.min_action,
        )
        assert np.all(np.less_equal(action, self.max_action)), (action, self.max_action)

        low = self.env.action_space.low
        high = self.env.action_space.high

        # The constructor allows min_action == max_action, which would make the
        # division below 0 / 0 and forward NaN to the base environment. In such a
        # dimension the assertions above force the offset to be 0, so dividing by 1
        # instead yields a fraction of 0 (i.e. `low`) without touching other dimensions.
        span = self.max_action - self.min_action
        safe_span = np.where(span == 0, np.ones_like(span), span)
        fraction = (action - self.min_action) / safe_span

        action = low + (high - low) * fraction
        # Guard against floating-point overshoot past the base environment's bounds.
        action = np.clip(action, low, high)
        return action
|