"""Implementation of a Jax-accelerated pendulum environment.""" from __future__ import annotations from os import path from typing import Any, Optional, Tuple import jax import jax.numpy as jnp import numpy as np from jax.random import PRNGKey import gymnasium as gym from gymnasium.error import DependencyNotInstalled from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.experimental.functional_jax_env import ( FunctionalJaxEnv, FunctionalJaxVectorEnv, ) from gymnasium.utils import EzPickle RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821 class PendulumFunctional( FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType] ): """Pendulum but in jax and functional structure.""" max_speed = 8 max_torque = 2.0 dt = 0.05 g = 10.0 m = 1.0 l = 1.0 high_x = jnp.pi high_y = 1.0 screen_dim = 500 observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32) action_space = gym.spaces.Box(-max_torque, max_torque, shape=(1,), dtype=np.float32) def initial(self, rng: PRNGKey): """Initial state generation.""" high = jnp.array([self.high_x, self.high_y]) return jax.random.uniform(key=rng, minval=-high, maxval=high, shape=high.shape) def transition( self, state: jax.Array, action: int | jax.Array, rng: None = None ) -> jax.Array: """Pendulum transition.""" th, thdot = state # th := theta u = action g = self.g m = self.m l = self.l dt = self.dt u = jnp.clip(u, -self.max_torque, self.max_torque)[0] newthdot = thdot + (3 * g / (2 * l) * jnp.sin(th) + 3.0 / (m * l**2) * u) * dt newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed) newth = th + newthdot * dt new_state = jnp.array([newth, newthdot]) return new_state def observation(self, state: jax.Array) -> jax.Array: """Generates an observation based on the state.""" theta, thetadot = state return jnp.array([jnp.cos(theta), jnp.sin(theta), thetadot]) def reward(self, state: StateType, action: ActType, next_state: StateType) -> float: """Generates the reward based on the state, action and next state.""" th, thdot = state # th := theta u = action u = jnp.clip(u, -self.max_torque, self.max_torque)[0] th_normalized = ((th + jnp.pi) % (2 * jnp.pi)) - jnp.pi costs = th_normalized**2 + 0.1 * thdot**2 + 0.001 * (u**2) return -costs def terminal(self, state: StateType) -> bool: """Determines if the state is a terminal state.""" return False def render_image( self, state: StateType, render_state: tuple[pygame.Surface, pygame.time.Clock, float | None], # type: ignore # noqa: F821 ) -> tuple[RenderStateType, np.ndarray]: """Renders an RGB image.""" try: import pygame from pygame import gfxdraw except ImportError as e: raise DependencyNotInstalled( "pygame is not installed, run `pip install gymnasium[classic-control]`" ) from e screen, clock, last_u = render_state surf = pygame.Surface((self.screen_dim, self.screen_dim)) surf.fill((255, 255, 255)) bound = 2.2 scale = self.screen_dim / (bound * 2) offset = self.screen_dim // 2 rod_length = 1 * scale rod_width = 0.2 * scale l, r, t, b = 0, rod_length, rod_width / 2, -rod_width / 2 coords = [(l, b), (l, t), (r, t), (r, b)] transformed_coords = [] for c in coords: c = pygame.math.Vector2(c).rotate_rad(state[0] + np.pi / 2) c = (c[0] + offset, c[1] + offset) transformed_coords.append(c) gfxdraw.aapolygon(surf, transformed_coords, (204, 77, 77)) gfxdraw.filled_polygon(surf, transformed_coords, (204, 77, 77)) gfxdraw.aacircle(surf, offset, offset, int(rod_width / 2), (204, 77, 77)) gfxdraw.filled_circle(surf, offset, offset, int(rod_width / 2), (204, 77, 77)) rod_end = (rod_length, 0) rod_end = pygame.math.Vector2(rod_end).rotate_rad(state[0] + np.pi / 2) rod_end = (int(rod_end[0] + offset), int(rod_end[1] + offset)) gfxdraw.aacircle( surf, rod_end[0], rod_end[1], int(rod_width / 2), (204, 77, 77) ) gfxdraw.filled_circle( surf, rod_end[0], rod_end[1], int(rod_width / 2), (204, 77, 77) ) fname = path.join(path.dirname(__file__), "assets/clockwise.png") img = pygame.image.load(fname) if last_u is not None: scale_img = pygame.transform.smoothscale( img, (scale * np.abs(last_u) / 2, scale * np.abs(last_u) / 2), ) is_flip = bool(last_u > 0) scale_img = pygame.transform.flip(scale_img, is_flip, True) surf.blit( scale_img, ( offset - scale_img.get_rect().centerx, offset - scale_img.get_rect().centery, ), ) # drawing axle gfxdraw.aacircle(surf, offset, offset, int(0.05 * scale), (0, 0, 0)) gfxdraw.filled_circle(surf, offset, offset, int(0.05 * scale), (0, 0, 0)) surf = pygame.transform.flip(surf, False, True) screen.blit(surf, (0, 0)) return (screen, clock, last_u), np.transpose( np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2) ) def render_init( self, screen_width: int = 600, screen_height: int = 400 ) -> RenderStateType: """Initialises the render state.""" try: import pygame except ImportError as e: raise DependencyNotInstalled( "pygame is not installed, run `pip install gymnasium[classic-control]`" ) from e pygame.init() screen = pygame.Surface((screen_width, screen_height)) clock = pygame.time.Clock() return screen, clock, None def render_close(self, render_state: RenderStateType): """Closes the render state.""" try: import pygame except ImportError as e: raise DependencyNotInstalled( "pygame is not installed, run `pip install gymnasium[classic-control]`" ) from e pygame.display.quit() pygame.quit() class PendulumJaxEnv(FunctionalJaxEnv, EzPickle): """Jax-based pendulum environment using the functional version as base.""" metadata = {"render_modes": ["rgb_array"], "render_fps": 30} def __init__(self, render_mode: str | None = None, **kwargs: Any): """Constructor where the kwargs are passed to the base environment to modify the parameters.""" EzPickle.__init__(self, render_mode=render_mode, **kwargs) env = PendulumFunctional(**kwargs) env.transform(jax.jit) super().__init__( env, metadata=self.metadata, render_mode=render_mode, ) class PendulumJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle): """Jax-based implementation of the vectorized CartPole environment.""" metadata = {"render_modes": ["rgb_array"], "render_fps": 50} def __init__( self, num_envs: int, render_mode: str | None = None, max_episode_steps: int = 200, **kwargs: Any, ): """Constructor for the vectorized CartPole where the kwargs are applied to the functional environment.""" EzPickle.__init__( self, num_envs=num_envs, render_mode=render_mode, max_episode_steps=max_episode_steps, **kwargs, ) env = PendulumFunctional(**kwargs) env.transform(jax.jit) FunctionalJaxVectorEnv.__init__( self, func_env=env, num_envs=num_envs, metadata=self.metadata, render_mode=render_mode, max_episode_steps=max_episode_steps, )