Files
2024-10-30 22:14:35 +01:00

246 lines
8.1 KiB
Python

"""Implementation of a Jax-accelerated pendulum environment."""
from __future__ import annotations
from os import path
from typing import Any, Optional, Tuple
import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey
import gymnasium as gym
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional_jax_env import (
FunctionalJaxEnv,
FunctionalJaxVectorEnv,
)
from gymnasium.utils import EzPickle
# Render state passed between the render_* methods: (screen surface, clock, last_u).
# `last_u` is presumably the most recently applied torque; it is None after `render_init`.
# The pygame types are quoted because pygame is only imported lazily inside the render methods.
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821
class PendulumFunctional(
    FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType]
):
    """Pendulum but in jax and functional structure.

    The state is a two-element array ``[theta, theta_dot]``. Every method is a
    pure function of its arguments; nothing is stored on the instance between
    calls. The observation is ``[cos(theta), sin(theta), theta_dot]``.
    """

    # Bounds used for clipping in `transition` / `reward`.
    max_speed = 8  # clip bound for the angular velocity
    max_torque = 2.0  # clip bound for the action

    # Constants of the dynamics in `transition`.
    dt = 0.05  # integration step
    g = 10.0  # gravity term
    m = 1.0  # presumably the rod mass
    l = 1.0  # presumably the rod length

    # Half-widths of the uniform distribution the initial state is drawn from.
    high_x = jnp.pi  # bound for the initial theta
    high_y = 1.0  # bound for the initial theta_dot

    # Side length, in pixels, of the square surface drawn in `render_image`.
    screen_dim = 500

    observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32)
    action_space = gym.spaces.Box(-max_torque, max_torque, shape=(1,), dtype=np.float32)

    def initial(self, rng: PRNGKey):
        """Initial state generation.

        Args:
            rng: JAX random key used to draw the state.

        Returns:
            A two-element array ``[theta, theta_dot]`` sampled uniformly from
            ``[-high_x, high_x] x [-high_y, high_y]``.
        """
        high = jnp.array([self.high_x, self.high_y])
        return jax.random.uniform(key=rng, minval=-high, maxval=high, shape=high.shape)

    def transition(
        self, state: jax.Array, action: int | jax.Array, rng: None = None
    ) -> jax.Array:
        """Pendulum transition.

        Args:
            state: Current ``[theta, theta_dot]``.
            action: Torque; it is indexed with ``[0]`` after clipping, so it is
                expected to be array-like with at least one element (matching
                the ``(1,)``-shaped ``action_space``).
            rng: Unused.

        Returns:
            The next ``[theta, theta_dot]``.
        """
        theta, theta_dot = state
        torque = jnp.clip(action, -self.max_torque, self.max_torque)[0]

        # The velocity is updated and clipped first; the new velocity is then
        # used to advance the angle.
        new_theta_dot = (
            theta_dot
            + (
                3 * self.g / (2 * self.l) * jnp.sin(theta)
                + 3.0 / (self.m * self.l**2) * torque
            )
            * self.dt
        )
        new_theta_dot = jnp.clip(new_theta_dot, -self.max_speed, self.max_speed)
        new_theta = theta + new_theta_dot * self.dt

        return jnp.array([new_theta, new_theta_dot])

    def observation(self, state: jax.Array) -> jax.Array:
        """Generates an observation based on the state.

        Returns:
            ``[cos(theta), sin(theta), theta_dot]``.
        """
        theta, theta_dot = state
        return jnp.array([jnp.cos(theta), jnp.sin(theta), theta_dot])

    def reward(self, state: StateType, action: ActType, next_state: StateType) -> float:
        """Generates the reward based on the state, action and next state.

        The reward is the negated cost
        ``theta_normalized**2 + 0.1 * theta_dot**2 + 0.001 * torque**2``, where
        ``theta`` is wrapped into ``[-pi, pi)`` and the torque is the clipped
        action. ``next_state`` is not used.
        """
        theta, theta_dot = state
        torque = jnp.clip(action, -self.max_torque, self.max_torque)[0]

        theta_normalized = ((theta + jnp.pi) % (2 * jnp.pi)) - jnp.pi
        costs = theta_normalized**2 + 0.1 * theta_dot**2 + 0.001 * (torque**2)

        return -costs

    def terminal(self, state: StateType) -> bool:
        """Determines if the state is a terminal state (always ``False`` here)."""
        return False

    def render_image(
        self,
        state: StateType,
        render_state: tuple[pygame.Surface, pygame.time.Clock, float | None],  # type: ignore # noqa: F821
    ) -> tuple[RenderStateType, np.ndarray]:
        """Renders an RGB image.

        Args:
            state: Current ``[theta, theta_dot]``; only ``state[0]`` is drawn.
            render_state: ``(screen, clock, last_u)`` as produced by
                ``render_init`` or a previous call to this method.

        Returns:
            The render state (passed through unchanged) and the pixels of
            ``screen`` transposed with ``axes=(1, 0, 2)``.

        Raises:
            DependencyNotInstalled: If pygame cannot be imported.
        """
        try:
            import pygame
            from pygame import gfxdraw
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic-control]`"
            ) from e

        screen, clock, last_u = render_state

        surf = pygame.Surface((self.screen_dim, self.screen_dim))
        surf.fill((255, 255, 255))

        bound = 2.2
        scale = self.screen_dim / (bound * 2)
        offset = self.screen_dim // 2
        rod_colour = (204, 77, 77)

        rod_length = 1 * scale
        rod_width = 0.2 * scale
        cap_radius = int(rod_width / 2)

        # Both the rod rectangle and its end cap are rotated by the same angle.
        angle = state[0] + np.pi / 2

        # Rod body: an axis-aligned rectangle rotated about the origin, then
        # shifted to the centre of the surface.
        left, right, top, bottom = 0, rod_length, rod_width / 2, -rod_width / 2
        corners = [(left, bottom), (left, top), (right, top), (right, bottom)]
        rotated_corners = (
            pygame.math.Vector2(corner).rotate_rad(angle) for corner in corners
        )
        transformed_coords = [
            (vec[0] + offset, vec[1] + offset) for vec in rotated_corners
        ]
        gfxdraw.aapolygon(surf, transformed_coords, rod_colour)
        gfxdraw.filled_polygon(surf, transformed_coords, rod_colour)

        # Rounded cap at the pivot end of the rod.
        gfxdraw.aacircle(surf, offset, offset, cap_radius, rod_colour)
        gfxdraw.filled_circle(surf, offset, offset, cap_radius, rod_colour)

        # Rounded cap at the free end of the rod.
        rod_end = pygame.math.Vector2((rod_length, 0)).rotate_rad(angle)
        rod_end = (int(rod_end[0] + offset), int(rod_end[1] + offset))
        gfxdraw.aacircle(surf, rod_end[0], rod_end[1], cap_radius, rod_colour)
        gfxdraw.filled_circle(surf, rod_end[0], rod_end[1], cap_radius, rod_colour)

        if last_u is not None:
            # The arrow image is only needed in this branch, so it is loaded
            # here rather than unconditionally on every frame.
            fname = path.join(path.dirname(__file__), "assets/clockwise.png")
            img = pygame.image.load(fname)
            scale_img = pygame.transform.smoothscale(
                img,
                (scale * np.abs(last_u) / 2, scale * np.abs(last_u) / 2),
            )
            is_flip = bool(last_u > 0)
            scale_img = pygame.transform.flip(scale_img, is_flip, True)
            img_rect = scale_img.get_rect()
            surf.blit(
                scale_img,
                (offset - img_rect.centerx, offset - img_rect.centery),
            )

        # drawing axle
        gfxdraw.aacircle(surf, offset, offset, int(0.05 * scale), (0, 0, 0))
        gfxdraw.filled_circle(surf, offset, offset, int(0.05 * scale), (0, 0, 0))

        # Flip vertically before copying onto the screen surface.
        surf = pygame.transform.flip(surf, False, True)
        screen.blit(surf, (0, 0))

        return (screen, clock, last_u), np.transpose(
            np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2)
        )

    def render_init(
        self, screen_width: int = 600, screen_height: int = 400
    ) -> RenderStateType:
        """Initialises the render state.

        Args:
            screen_width: Width of the off-screen surface.
            screen_height: Height of the off-screen surface.

        Returns:
            ``(screen, clock, None)``; the last element is the ``last_u`` slot.

        Raises:
            DependencyNotInstalled: If pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic-control]`"
            ) from e

        pygame.init()
        screen = pygame.Surface((screen_width, screen_height))
        clock = pygame.time.Clock()

        return screen, clock, None

    def render_close(self, render_state: RenderStateType):
        """Closes the render state.

        ``render_state`` itself is not inspected; pygame's display module and
        pygame as a whole are shut down.

        Raises:
            DependencyNotInstalled: If pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic-control]`"
            ) from e

        pygame.display.quit()
        pygame.quit()
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
    """Jax-based pendulum environment built on top of `PendulumFunctional`."""

    metadata = {"render_modes": ["rgb_array"], "render_fps": 30}

    def __init__(self, render_mode: str | None = None, **kwargs: Any):
        """Create the functional pendulum and wrap it in a `FunctionalJaxEnv`.

        Args:
            render_mode: Render mode forwarded to `FunctionalJaxEnv`.
            **kwargs: Forwarded to `PendulumFunctional` to modify its
                parameters; also recorded by `EzPickle` together with
                ``render_mode``.
        """
        EzPickle.__init__(self, render_mode=render_mode, **kwargs)

        functional_env = PendulumFunctional(**kwargs)
        # Hand `jax.jit` to the functional env's `transform` hook before wrapping.
        functional_env.transform(jax.jit)

        wrapper_options = {"metadata": self.metadata, "render_mode": render_mode}
        super().__init__(functional_env, **wrapper_options)
class PendulumJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
    """Jax-based implementation of the vectorized Pendulum environment."""

    # NOTE(review): render_fps is 50 here but 30 in PendulumJaxEnv — confirm
    # whether the difference is intentional.
    metadata = {"render_modes": ["rgb_array"], "render_fps": 50}

    def __init__(
        self,
        num_envs: int,
        render_mode: str | None = None,
        max_episode_steps: int = 200,
        **kwargs: Any,
    ):
        """Constructor for the vectorized Pendulum.

        Args:
            num_envs: Number of environments, forwarded to
                `FunctionalJaxVectorEnv`.
            render_mode: Render mode forwarded to `FunctionalJaxVectorEnv`.
            max_episode_steps: Episode step limit forwarded to
                `FunctionalJaxVectorEnv`.
            **kwargs: Applied to the functional environment
                (`PendulumFunctional`); also recorded by `EzPickle` together
                with the named arguments.
        """
        vector_options = {
            "num_envs": num_envs,
            "render_mode": render_mode,
            "max_episode_steps": max_episode_steps,
        }
        EzPickle.__init__(self, **vector_options, **kwargs)

        functional_env = PendulumFunctional(**kwargs)
        # Hand `jax.jit` to the functional env's `transform` hook before wrapping.
        functional_env.transform(jax.jit)

        FunctionalJaxVectorEnv.__init__(
            self,
            func_env=functional_env,
            metadata=self.metadata,
            **vector_options,
        )