from typing import Dict, SupportsFloat

import gymnasium as gym
import numpy as np
from gymnasium import spaces

from stable_baselines3.common.type_aliases import AtariResetReturn, AtariStepReturn

try:
    import cv2

    cv2.ocl.setUseOpenCL(False)
except ImportError:
    cv2 = None  # type: ignore[assignment]


class StickyActionEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Sticky action.

    Paper: https://arxiv.org/abs/1709.06009
    Official implementation: https://github.com/mgbellemare/Arcade-Learning-Environment

    :param env: Environment to wrap
    :param action_repeat_probability: Probability of repeating the last action
    """

    def __init__(self, env: gym.Env, action_repeat_probability: float) -> None:
        super().__init__(env)
        self.action_repeat_probability = action_repeat_probability
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"  # type: ignore[attr-defined]

    def reset(self, **kwargs) -> AtariResetReturn:
        self._sticky_action = 0  # NOOP
        return self.env.reset(**kwargs)

    def step(self, action: int) -> AtariStepReturn:
        if self.np_random.random() >= self.action_repeat_probability:
            self._sticky_action = action
        return self.env.step(self._sticky_action)


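# Minimal usage sketch (not part of the library API): 0.25 is the action repeat
# probability recommended by Machado et al. (2018, the paper linked above), and
# "PongNoFrameskip-v4" is only an illustrative env id that assumes an
# Atari-capable gymnasium install (ale-py + ROMs).
#
#     env = StickyActionEnv(gym.make("PongNoFrameskip-v4"), action_repeat_probability=0.25)

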
class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Sample initial states by taking a random number of no-ops on reset.
    No-op is assumed to be action 0.

    :param env: Environment to wrap
    :param noop_max: Maximum value of no-ops to run
    """

    def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
        super().__init__(env)
        self.noop_max = noop_max
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"  # type: ignore[attr-defined]

    def reset(self, **kwargs) -> AtariResetReturn:
        self.env.reset(**kwargs)
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
        assert noops > 0
        obs = np.zeros(0)
        info: Dict = {}
        for _ in range(noops):
            obs, _, terminated, truncated, info = self.env.step(self.noop_action)
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        return obs, info


class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Take action on reset for environments that are fixed until firing.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == "FIRE"  # type: ignore[attr-defined]
        assert len(env.unwrapped.get_action_meanings()) >= 3  # type: ignore[attr-defined]

    def reset(self, **kwargs) -> AtariResetReturn:
        self.env.reset(**kwargs)
        obs, _, terminated, truncated, _ = self.env.step(1)
        if terminated or truncated:
            self.env.reset(**kwargs)
        obs, _, terminated, truncated, _ = self.env.step(2)
        if terminated or truncated:
            self.env.reset(**kwargs)
        return obs, {}


class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Make end-of-life == end-of-episode, but only reset on true game over.
    Done by DeepMind for the DQN and co. since it helps value estimation.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action: int) -> AtariStepReturn:
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.was_real_done = terminated or truncated
        # Check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()  # type: ignore[attr-defined]
        if 0 < lives < self.lives:
            # For Qbert, sometimes we stay in lives == 0 condition for a few frames,
            # so it's important to keep lives > 0, so that we only reset once
            # the environment advertises done.
            terminated = True
        self.lives = lives
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs) -> AtariResetReturn:
        """
        Calls the Gym environment reset, only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.

        :param kwargs: Extra keywords passed to env.reset() call
        :return: the first observation of the environment
        """
        if self.was_real_done:
            obs, info = self.env.reset(**kwargs)
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, terminated, truncated, info = self.env.step(0)

            # The no-op step can lead to a game over, so we need to check it again
            # to see if we should reset the environment and avoid the
            # monitor.py `RuntimeError: Tried to step environment that needs reset`
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()  # type: ignore[attr-defined]
        return obs, info


class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Return only every ``skip``-th frame (frameskipping)
    and return the max between the two last frames
    (max-pooling over consecutive frames mitigates sprite flickering on the Atari 2600).

    :param env: Environment to wrap
    :param skip: Number of frames to skip (the same action will be taken ``skip`` times).
    """

    def __init__(self, env: gym.Env, skip: int = 4) -> None:
        super().__init__(env)
        # most recent raw observations (for max pooling across time steps)
        assert env.observation_space.dtype is not None, "No dtype specified for the observation space"
        assert env.observation_space.shape is not None, "No shape defined for the observation space"
        self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=env.observation_space.dtype)
        self._skip = skip

    def step(self, action: int) -> AtariStepReturn:
        """
        Step the environment with the given action.
        Repeat action, sum reward, and max over last observations.

        :param action: the action
        :return: observation, reward, terminated, truncated, information
        """
        total_reward = 0.0
        terminated = truncated = False
        for i in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            done = terminated or truncated
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += float(reward)
            if done:
                break
        # Note that the observation on the done=True frame
        # doesn't matter
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, terminated, truncated, info


class ClipRewardEnv(gym.RewardWrapper):
    """
    Clip the reward to {+1, 0, -1} by its sign.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)

    def reward(self, reward: SupportsFloat) -> float:
        """
        Bin reward to {+1, 0, -1} by its sign.

        :param reward: Reward to clip
        :return: The sign of the reward (+1, 0 or -1)
        """
        return np.sign(float(reward))


class WarpFrame(gym.ObservationWrapper[np.ndarray, int, np.ndarray]):
    """
    Convert to grayscale and warp frames to 84x84 (default)
    as done in the Nature paper and later work.

    :param env: Environment to wrap
    :param width: New frame width
    :param height: New frame height
    """

    def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None:
        super().__init__(env)
        self.width = width
        self.height = height
        assert isinstance(env.observation_space, spaces.Box), f"Expected Box space, got {env.observation_space}"

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.height, self.width, 1),
            dtype=env.observation_space.dtype,  # type: ignore[arg-type]
        )

    def observation(self, frame: np.ndarray) -> np.ndarray:
        """
        Returns the current observation from a frame.

        :param frame: environment frame
        :return: the observation
        """
        assert cv2 is not None, "OpenCV is not installed, you can do `pip install opencv-python`"
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]


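# Minimal sketch of the effect (assumptions: OpenCV installed, a standard 210x160 RGB
# Atari frame, and "PongNoFrameskip-v4" as a purely illustrative env id):
#
#     env = WarpFrame(gym.make("PongNoFrameskip-v4"), width=84, height=84)
#     obs, _ = env.reset()
#     obs.shape, obs.dtype  # (84, 84, 1), uint8

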
class AtariWrapper(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Atari 2600 preprocessing.

    Specifically:

    * Noop reset: obtain initial state by taking a random number of no-ops on reset.
    * Frame skipping: 4 by default
    * Max-pooling: most recent two observations
    * Termination signal when a life is lost.
    * Resize to a square image: 84x84 by default
    * Grayscale observation
    * Clip reward to {-1, 0, 1}
    * Sticky actions: disabled by default

    See https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/
    for a visual explanation.

    .. warning::
        Use this wrapper only with Atari v4 without frame skip: ``env_id = "*NoFrameskip-v4"``.

    :param env: Environment to wrap
    :param noop_max: Max number of no-ops
    :param frame_skip: Frequency at which the agent experiences the game.
        This corresponds to repeating the action ``frame_skip`` times.
    :param screen_size: Resize Atari frame
    :param terminal_on_life_loss: If True, then step() returns done=True whenever a life is lost.
    :param clip_reward: If True (default), the reward is clipped to {-1, 0, 1} depending on its sign.
    :param action_repeat_probability: Probability of repeating the last action
    """

    def __init__(
        self,
        env: gym.Env,
        noop_max: int = 30,
        frame_skip: int = 4,
        screen_size: int = 84,
        terminal_on_life_loss: bool = True,
        clip_reward: bool = True,
        action_repeat_probability: float = 0.0,
    ) -> None:
        if action_repeat_probability > 0.0:
            env = StickyActionEnv(env, action_repeat_probability)
        if noop_max > 0:
            env = NoopResetEnv(env, noop_max=noop_max)
        # frame_skip=1 is the same as no frame-skip (action repeat)
        if frame_skip > 1:
            env = MaxAndSkipEnv(env, skip=frame_skip)
        if terminal_on_life_loss:
            env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():  # type: ignore[attr-defined]
            env = FireResetEnv(env)
        env = WarpFrame(env, width=screen_size, height=screen_size)
        if clip_reward:
            env = ClipRewardEnv(env)

        super().__init__(env)
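

if __name__ == "__main__":
    # Minimal usage sketch (not part of the library API): apply the standard
    # DeepMind-style preprocessing to an Atari environment. Assumes
    # `opencv-python`, ale-py and the Atari ROMs are installed; depending on the
    # gymnasium version you may also need to import and register `ale_py` first.
    # "BreakoutNoFrameskip-v4" is only an illustrative env id.
    env = gym.make("BreakoutNoFrameskip-v4")
    env = AtariWrapper(env)
    obs, info = env.reset(seed=0)
    print(obs.shape, obs.dtype)  # expected: (84, 84, 1) uint8
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    env.close()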