"""Implementation of a Jax-accelerated cartpole environment.""" from __future__ import annotations from typing import Any, Tuple import jax import jax.numpy as jnp import numpy as np from jax.random import PRNGKey import gymnasium as gym from gymnasium.error import DependencyNotInstalled from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.experimental.functional_jax_env import ( FunctionalJaxEnv, FunctionalJaxVectorEnv, ) from gymnasium.utils import EzPickle RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821 class CartPoleFunctional( FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType] ): """Cartpole but in jax and functional. Example: >>> import jax >>> import jax.numpy as jnp >>> from gymnasium.envs.phys2d.cartpole import CartPoleFunctional >>> key = jax.random.PRNGKey(0) >>> env = CartPoleFunctional({"x_init": 0.5}) >>> state = env.initial(key) >>> print(state) [ 0.46532142 -0.27484107 0.13302994 -0.20361817] >>> print(env.transition(state, 0)) [ 0.4598246 -0.6357784 0.12895757 0.1278053 ] >>> env.transform(jax.jit) >>> state = env.initial(key) >>> print(state) [ 0.46532142 -0.27484107 0.13302994 -0.20361817] >>> print(env.transition(state, 0)) [ 0.4598246 -0.6357784 0.12895757 0.12780523] >>> vkey = jax.random.split(key, 10) >>> env.transform(jax.vmap) >>> vstate = env.initial(vkey) >>> print(vstate) [[ 0.25117755 -0.03159595 0.09428263 0.12404168] [ 0.231457 0.41420317 -0.13484478 0.29151905] [-0.11706758 -0.37130308 0.13587534 0.33141208] [-0.4613737 0.36557996 0.3950702 0.3639989 ] [-0.14707637 -0.34273267 -0.32374108 -0.48110402] [-0.45774353 0.3633288 -0.3157575 -0.03586268] [ 0.37344885 -0.279778 -0.33894253 0.07415426] [-0.20234215 0.39775252 -0.2556088 0.32877135] [-0.2572986 -0.29943776 -0.45600426 -0.35740316] [ 0.05436695 0.35021234 -0.36484408 0.2805779 ]] >>> print(env.transition(vstate, jnp.array([0 for _ in range(10)]))) [[ 0.25054562 -0.38763174 0.09676346 0.4448946 ] [ 0.23974106 0.09849604 -0.1290144 0.5390002 ] [-0.12449364 -0.7323911 0.14250359 0.6634313 ] [-0.45406207 -0.01028753 0.4023502 0.7505522 ] [-0.15393102 -0.6168968 -0.33336315 -0.30407968] [-0.45047694 0.08870795 -0.31647477 0.14311607] [ 0.36785328 -0.54895645 -0.33745944 0.24393772] [-0.19438711 0.10855066 -0.24903338 0.5316877 ] [-0.26328734 -0.5420943 -0.46315232 -0.2344252 ] [ 0.06137119 0.08665388 -0.35923252 0.4403924 ]] """ gravity = 9.8 masscart = 1.0 masspole = 0.1 total_mass = masspole + masscart length = 0.5 polemass_length = masspole + length force_mag = 10.0 tau = 0.02 theta_threshold_radians = 12 * 2 * np.pi / 360 x_threshold = 2.4 x_init = 0.05 screen_width = 600 screen_height = 400 observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32) action_space = gym.spaces.Discrete(2) def initial(self, rng: PRNGKey): """Initial state generation.""" return jax.random.uniform( key=rng, minval=-self.x_init, maxval=self.x_init, shape=(4,) ) def transition( self, state: jax.Array, action: int | jax.Array, rng: None = None ) -> StateType: """Cartpole transition.""" x, x_dot, theta, theta_dot = state force = jnp.sign(action - 0.5) * self.force_mag costheta = jnp.cos(theta) sintheta = jnp.sin(theta) # For the interested reader: # https://coneural.org/florian/papers/05_cart_pole.pdf temp = ( force + self.polemass_length * theta_dot**2 * sintheta ) / self.total_mass thetaacc = (self.gravity * sintheta - costheta * temp) / ( self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass) ) xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass x = x + self.tau * x_dot x_dot = x_dot + self.tau * xacc theta = theta + self.tau * theta_dot theta_dot = theta_dot + self.tau * thetaacc state = jnp.array((x, x_dot, theta, theta_dot), dtype=jnp.float32) return state def observation(self, state: jax.Array) -> jax.Array: """Cartpole observation.""" return state def terminal(self, state: jax.Array) -> jax.Array: """Checks if the state is terminal.""" x, _, theta, _ = state terminated = ( (x < -self.x_threshold) | (x > self.x_threshold) | (theta < -self.theta_threshold_radians) | (theta > self.theta_threshold_radians) ) return terminated def reward( self, state: StateType, action: ActType, next_state: StateType ) -> jax.Array: """Computes the reward for the state transition using the action.""" x, _, theta, _ = state terminated = ( (x < -self.x_threshold) | (x > self.x_threshold) | (theta < -self.theta_threshold_radians) | (theta > self.theta_threshold_radians) ) reward = jax.lax.cond(terminated, lambda: 0.0, lambda: 1.0) return reward def render_image( self, state: StateType, render_state: RenderStateType, ) -> tuple[RenderStateType, np.ndarray]: """Renders an image of the state using the render state.""" try: import pygame from pygame import gfxdraw except ImportError as e: raise DependencyNotInstalled( "pygame is not installed, run `pip install gymnasium[classic-control]`" ) from e screen, clock = render_state world_width = self.x_threshold * 2 scale = self.screen_width / world_width polewidth = 10.0 polelen = scale * (2 * self.length) cartwidth = 50.0 cartheight = 30.0 x = state surf = pygame.Surface((self.screen_width, self.screen_height)) surf.fill((255, 255, 255)) l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2 axleoffset = cartheight / 4.0 cartx = x[0] * scale + self.screen_width / 2.0 # MIDDLE OF CART carty = 100 # TOP OF CART cart_coords = [(l, b), (l, t), (r, t), (r, b)] cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords] gfxdraw.aapolygon(surf, cart_coords, (0, 0, 0)) gfxdraw.filled_polygon(surf, cart_coords, (0, 0, 0)) l, r, t, b = ( -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2, ) pole_coords = [] for coord in [(l, b), (l, t), (r, t), (r, b)]: coord = pygame.math.Vector2(coord).rotate_rad(-x[2]) coord = (coord[0] + cartx, coord[1] + carty + axleoffset) pole_coords.append(coord) gfxdraw.aapolygon(surf, pole_coords, (202, 152, 101)) gfxdraw.filled_polygon(surf, pole_coords, (202, 152, 101)) gfxdraw.aacircle( surf, int(cartx), int(carty + axleoffset), int(polewidth / 2), (129, 132, 203), ) gfxdraw.filled_circle( surf, int(cartx), int(carty + axleoffset), int(polewidth / 2), (129, 132, 203), ) gfxdraw.hline(surf, 0, self.screen_width, carty, (0, 0, 0)) surf = pygame.transform.flip(surf, False, True) screen.blit(surf, (0, 0)) return (screen, clock), np.transpose( np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2) ) def render_init( self, screen_width: int = 600, screen_height: int = 400 ) -> RenderStateType: """Initialises the render state for a screen width and height.""" try: import pygame except ImportError as e: raise DependencyNotInstalled( "pygame is not installed, run `pip install gymnasium[classic-control]`" ) from e pygame.init() screen = pygame.Surface((screen_width, screen_height)) clock = pygame.time.Clock() return screen, clock def render_close(self, render_state: RenderStateType) -> None: """Closes the render state.""" try: import pygame except ImportError as e: raise DependencyNotInstalled( "pygame is not installed, run `pip install gymnasium[classic-control]`" ) from e pygame.display.quit() pygame.quit() class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle): """Jax-based implementation of the CartPole environment.""" metadata = {"render_modes": ["rgb_array"], "render_fps": 50} def __init__(self, render_mode: str | None = None, **kwargs: Any): """Constructor for the CartPole where the kwargs are applied to the functional environment.""" EzPickle.__init__(self, render_mode=render_mode, **kwargs) env = CartPoleFunctional(**kwargs) env.transform(jax.jit) FunctionalJaxEnv.__init__( self, env, metadata=self.metadata, render_mode=render_mode, ) class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle): """Jax-based implementation of the vectorized CartPole environment.""" metadata = {"render_modes": ["rgb_array"], "render_fps": 50} def __init__( self, num_envs: int, render_mode: str | None = None, max_episode_steps: int = 200, **kwargs: Any, ): """Constructor for the vectorized CartPole where the kwargs are applied to the functional environment.""" EzPickle.__init__( self, num_envs=num_envs, render_mode=render_mode, max_episode_steps=max_episode_steps, **kwargs, ) env = CartPoleFunctional(**kwargs) env.transform(jax.jit) FunctionalJaxVectorEnv.__init__( self, func_env=env, num_envs=num_envs, metadata=self.metadata, render_mode=render_mode, max_episode_steps=max_episode_steps, )