69 lines
2.3 KiB
Python
69 lines
2.3 KiB
Python
"""Wrapper that converts a color observation to grayscale."""
|
|
import numpy as np
|
|
|
|
import gymnasium as gym
|
|
from gymnasium.spaces import Box
|
|
|
|
|
|
class GrayScaleObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
    """Convert the image observation from RGB to gray scale.

    The wrapped environment must have a 3-dimensional :class:`Box` observation
    space whose last axis has size 3 (``A x B x 3``). The new observation space
    is always ``Box(0, 255, ..., uint8)``, independent of the original space's
    bounds and dtype. It therefore only describes the converted observations
    correctly when the wrapped environment produces ``uint8`` RGB images.

    Example:
        >>> import gymnasium as gym
        >>> from gymnasium.wrappers import GrayScaleObservation
        >>> env = gym.make("CarRacing-v2")
        >>> env.observation_space
        Box(0, 255, (96, 96, 3), uint8)
        >>> env = GrayScaleObservation(gym.make("CarRacing-v2"))
        >>> env.observation_space
        Box(0, 255, (96, 96), uint8)
        >>> env = GrayScaleObservation(gym.make("CarRacing-v2"), keep_dim=True)
        >>> env.observation_space
        Box(0, 255, (96, 96, 1), uint8)
    """

    def __init__(self, env: gym.Env, keep_dim: bool = False):
        """Convert the image observation from RGB to gray scale.

        Args:
            env (Env): The environment to apply the wrapper
            keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1.
                Otherwise, they are of shape AxB.

        Raises:
            AssertionError: If the environment's observation space is not a
                :class:`Box` of shape ``(A, B, 3)``.
        """
        gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim)
        gym.ObservationWrapper.__init__(self, env)

        self.keep_dim = keep_dim

        # Check each requirement separately so that a failure reports which
        # one was violated and what the offending space actually looked like.
        assert isinstance(
            self.observation_space, Box
        ), f"Expected a Box observation space, got {self.observation_space!r}"
        assert (
            len(self.observation_space.shape) == 3
            and self.observation_space.shape[-1] == 3
        ), (
            "Expected an RGB observation space of shape (A, B, 3), "
            f"got shape {self.observation_space.shape}"
        )

        height, width = self.observation_space.shape[:2]
        # The RGB -> gray conversion drops the channel axis; with keep_dim a
        # size-1 channel axis is re-added so observations stay 3-dimensional.
        new_shape = (height, width, 1) if self.keep_dim else (height, width)
        self.observation_space = Box(low=0, high=255, shape=new_shape, dtype=np.uint8)

    def observation(self, observation: np.ndarray) -> np.ndarray:
        """Converts the colour observation to greyscale.

        Args:
            observation: Color observation, an RGB image with channels on the last axis

        Returns:
            Grayscale observation, with a trailing singleton axis if ``keep_dim`` is set
        """
        # Imported here rather than at module level, so importing this module
        # and constructing the wrapper do not require OpenCV; it is only needed
        # once observations are actually converted.
        import cv2

        gray = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        if self.keep_dim:
            gray = np.expand_dims(gray, -1)
        return gray
|