I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,3 @@
"""Module for 2d physics environments with functional and environment implementations."""
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional, CartPoleJaxEnv
from gymnasium.envs.phys2d.pendulum import PendulumFunctional, PendulumJaxEnv

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.8 KiB

View File

@ -0,0 +1,34 @@
[remap]
importer="texture"
type="CompressedTexture2D"
uid="uid://vmira7dplrxn"
path="res://.godot/imported/clockwise.png-438b12da40068f686b2bc4d4bddaecd2.ctex"
metadata={
"vram_texture": false
}
[deps]
source_file="res://rl/Lib/site-packages/gymnasium/envs/phys2d/assets/clockwise.png"
dest_files=["res://.godot/imported/clockwise.png-438b12da40068f686b2bc4d4bddaecd2.ctex"]
[params]
compress/mode=0
compress/high_quality=false
compress/lossy_quality=0.7
compress/hdr_compression=1
compress/normal_map=0
compress/channel_pack=0
mipmaps/generate=false
mipmaps/limit=-1
roughness/mode=0
roughness/src_normal=""
process/fix_alpha_border=true
process/premult_alpha=false
process/normal_map_invert_y=false
process/hdr_as_srgb=false
process/hdr_clamp_exposure=false
process/size_limit=0
detect_3d/compress_to=1

View File

@ -0,0 +1,313 @@
"""Implementation of a Jax-accelerated cartpole environment."""
from __future__ import annotations
from typing import Any, Tuple
import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey
import gymnasium as gym
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional_jax_env import (
FunctionalJaxEnv,
FunctionalJaxVectorEnv,
)
from gymnasium.utils import EzPickle
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821
class CartPoleFunctional(
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType]
):
"""Cartpole but in jax and functional.
Example:
>>> import jax
>>> import jax.numpy as jnp
>>> from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
>>> key = jax.random.PRNGKey(0)
>>> env = CartPoleFunctional({"x_init": 0.5})
>>> state = env.initial(key)
>>> print(state)
[ 0.46532142 -0.27484107 0.13302994 -0.20361817]
>>> print(env.transition(state, 0))
[ 0.4598246 -0.6357784 0.12895757 0.1278053 ]
>>> env.transform(jax.jit)
>>> state = env.initial(key)
>>> print(state)
[ 0.46532142 -0.27484107 0.13302994 -0.20361817]
>>> print(env.transition(state, 0))
[ 0.4598246 -0.6357784 0.12895757 0.12780523]
>>> vkey = jax.random.split(key, 10)
>>> env.transform(jax.vmap)
>>> vstate = env.initial(vkey)
>>> print(vstate)
[[ 0.25117755 -0.03159595 0.09428263 0.12404168]
[ 0.231457 0.41420317 -0.13484478 0.29151905]
[-0.11706758 -0.37130308 0.13587534 0.33141208]
[-0.4613737 0.36557996 0.3950702 0.3639989 ]
[-0.14707637 -0.34273267 -0.32374108 -0.48110402]
[-0.45774353 0.3633288 -0.3157575 -0.03586268]
[ 0.37344885 -0.279778 -0.33894253 0.07415426]
[-0.20234215 0.39775252 -0.2556088 0.32877135]
[-0.2572986 -0.29943776 -0.45600426 -0.35740316]
[ 0.05436695 0.35021234 -0.36484408 0.2805779 ]]
>>> print(env.transition(vstate, jnp.array([0 for _ in range(10)])))
[[ 0.25054562 -0.38763174 0.09676346 0.4448946 ]
[ 0.23974106 0.09849604 -0.1290144 0.5390002 ]
[-0.12449364 -0.7323911 0.14250359 0.6634313 ]
[-0.45406207 -0.01028753 0.4023502 0.7505522 ]
[-0.15393102 -0.6168968 -0.33336315 -0.30407968]
[-0.45047694 0.08870795 -0.31647477 0.14311607]
[ 0.36785328 -0.54895645 -0.33745944 0.24393772]
[-0.19438711 0.10855066 -0.24903338 0.5316877 ]
[-0.26328734 -0.5420943 -0.46315232 -0.2344252 ]
[ 0.06137119 0.08665388 -0.35923252 0.4403924 ]]
"""
gravity = 9.8
masscart = 1.0
masspole = 0.1
total_mass = masspole + masscart
length = 0.5
polemass_length = masspole + length
force_mag = 10.0
tau = 0.02
theta_threshold_radians = 12 * 2 * np.pi / 360
x_threshold = 2.4
x_init = 0.05
screen_width = 600
screen_height = 400
observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32)
action_space = gym.spaces.Discrete(2)
def initial(self, rng: PRNGKey):
"""Initial state generation."""
return jax.random.uniform(
key=rng, minval=-self.x_init, maxval=self.x_init, shape=(4,)
)
def transition(
self, state: jax.Array, action: int | jax.Array, rng: None = None
) -> StateType:
"""Cartpole transition."""
x, x_dot, theta, theta_dot = state
force = jnp.sign(action - 0.5) * self.force_mag
costheta = jnp.cos(theta)
sintheta = jnp.sin(theta)
# For the interested reader:
# https://coneural.org/florian/papers/05_cart_pole.pdf
temp = (
force + self.polemass_length * theta_dot**2 * sintheta
) / self.total_mass
thetaacc = (self.gravity * sintheta - costheta * temp) / (
self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass)
)
xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass
x = x + self.tau * x_dot
x_dot = x_dot + self.tau * xacc
theta = theta + self.tau * theta_dot
theta_dot = theta_dot + self.tau * thetaacc
state = jnp.array((x, x_dot, theta, theta_dot), dtype=jnp.float32)
return state
def observation(self, state: jax.Array) -> jax.Array:
"""Cartpole observation."""
return state
def terminal(self, state: jax.Array) -> jax.Array:
"""Checks if the state is terminal."""
x, _, theta, _ = state
terminated = (
(x < -self.x_threshold)
| (x > self.x_threshold)
| (theta < -self.theta_threshold_radians)
| (theta > self.theta_threshold_radians)
)
return terminated
def reward(
self, state: StateType, action: ActType, next_state: StateType
) -> jax.Array:
"""Computes the reward for the state transition using the action."""
x, _, theta, _ = state
terminated = (
(x < -self.x_threshold)
| (x > self.x_threshold)
| (theta < -self.theta_threshold_radians)
| (theta > self.theta_threshold_radians)
)
reward = jax.lax.cond(terminated, lambda: 0.0, lambda: 1.0)
return reward
def render_image(
self,
state: StateType,
render_state: RenderStateType,
) -> tuple[RenderStateType, np.ndarray]:
"""Renders an image of the state using the render state."""
try:
import pygame
from pygame import gfxdraw
except ImportError as e:
raise DependencyNotInstalled(
"pygame is not installed, run `pip install gymnasium[classic-control]`"
) from e
screen, clock = render_state
world_width = self.x_threshold * 2
scale = self.screen_width / world_width
polewidth = 10.0
polelen = scale * (2 * self.length)
cartwidth = 50.0
cartheight = 30.0
x = state
surf = pygame.Surface((self.screen_width, self.screen_height))
surf.fill((255, 255, 255))
l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
axleoffset = cartheight / 4.0
cartx = x[0] * scale + self.screen_width / 2.0 # MIDDLE OF CART
carty = 100 # TOP OF CART
cart_coords = [(l, b), (l, t), (r, t), (r, b)]
cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
gfxdraw.aapolygon(surf, cart_coords, (0, 0, 0))
gfxdraw.filled_polygon(surf, cart_coords, (0, 0, 0))
l, r, t, b = (
-polewidth / 2,
polewidth / 2,
polelen - polewidth / 2,
-polewidth / 2,
)
pole_coords = []
for coord in [(l, b), (l, t), (r, t), (r, b)]:
coord = pygame.math.Vector2(coord).rotate_rad(-x[2])
coord = (coord[0] + cartx, coord[1] + carty + axleoffset)
pole_coords.append(coord)
gfxdraw.aapolygon(surf, pole_coords, (202, 152, 101))
gfxdraw.filled_polygon(surf, pole_coords, (202, 152, 101))
gfxdraw.aacircle(
surf,
int(cartx),
int(carty + axleoffset),
int(polewidth / 2),
(129, 132, 203),
)
gfxdraw.filled_circle(
surf,
int(cartx),
int(carty + axleoffset),
int(polewidth / 2),
(129, 132, 203),
)
gfxdraw.hline(surf, 0, self.screen_width, carty, (0, 0, 0))
surf = pygame.transform.flip(surf, False, True)
screen.blit(surf, (0, 0))
return (screen, clock), np.transpose(
np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2)
)
def render_init(
self, screen_width: int = 600, screen_height: int = 400
) -> RenderStateType:
"""Initialises the render state for a screen width and height."""
try:
import pygame
except ImportError as e:
raise DependencyNotInstalled(
"pygame is not installed, run `pip install gymnasium[classic-control]`"
) from e
pygame.init()
screen = pygame.Surface((screen_width, screen_height))
clock = pygame.time.Clock()
return screen, clock
def render_close(self, render_state: RenderStateType) -> None:
"""Closes the render state."""
try:
import pygame
except ImportError as e:
raise DependencyNotInstalled(
"pygame is not installed, run `pip install gymnasium[classic-control]`"
) from e
pygame.display.quit()
pygame.quit()
class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based implementation of the CartPole environment."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
def __init__(self, render_mode: str | None = None, **kwargs: Any):
"""Constructor for the CartPole where the kwargs are applied to the functional environment."""
EzPickle.__init__(self, render_mode=render_mode, **kwargs)
env = CartPoleFunctional(**kwargs)
env.transform(jax.jit)
FunctionalJaxEnv.__init__(
self,
env,
metadata=self.metadata,
render_mode=render_mode,
)
class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
"""Jax-based implementation of the vectorized CartPole environment."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
def __init__(
self,
num_envs: int,
render_mode: str | None = None,
max_episode_steps: int = 200,
**kwargs: Any,
):
"""Constructor for the vectorized CartPole where the kwargs are applied to the functional environment."""
EzPickle.__init__(
self,
num_envs=num_envs,
render_mode=render_mode,
max_episode_steps=max_episode_steps,
**kwargs,
)
env = CartPoleFunctional(**kwargs)
env.transform(jax.jit)
FunctionalJaxVectorEnv.__init__(
self,
func_env=env,
num_envs=num_envs,
metadata=self.metadata,
render_mode=render_mode,
max_episode_steps=max_episode_steps,
)

View File

@ -0,0 +1,245 @@
"""Implementation of a Jax-accelerated pendulum environment."""
from __future__ import annotations
from os import path
from typing import Any, Optional, Tuple
import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey
import gymnasium as gym
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional_jax_env import (
FunctionalJaxEnv,
FunctionalJaxVectorEnv,
)
from gymnasium.utils import EzPickle
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821
class PendulumFunctional(
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType]
):
"""Pendulum but in jax and functional structure."""
max_speed = 8
max_torque = 2.0
dt = 0.05
g = 10.0
m = 1.0
l = 1.0
high_x = jnp.pi
high_y = 1.0
screen_dim = 500
observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32)
action_space = gym.spaces.Box(-max_torque, max_torque, shape=(1,), dtype=np.float32)
def initial(self, rng: PRNGKey):
"""Initial state generation."""
high = jnp.array([self.high_x, self.high_y])
return jax.random.uniform(key=rng, minval=-high, maxval=high, shape=high.shape)
def transition(
self, state: jax.Array, action: int | jax.Array, rng: None = None
) -> jax.Array:
"""Pendulum transition."""
th, thdot = state # th := theta
u = action
g = self.g
m = self.m
l = self.l
dt = self.dt
u = jnp.clip(u, -self.max_torque, self.max_torque)[0]
newthdot = thdot + (3 * g / (2 * l) * jnp.sin(th) + 3.0 / (m * l**2) * u) * dt
newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)
newth = th + newthdot * dt
new_state = jnp.array([newth, newthdot])
return new_state
def observation(self, state: jax.Array) -> jax.Array:
"""Generates an observation based on the state."""
theta, thetadot = state
return jnp.array([jnp.cos(theta), jnp.sin(theta), thetadot])
def reward(self, state: StateType, action: ActType, next_state: StateType) -> float:
"""Generates the reward based on the state, action and next state."""
th, thdot = state # th := theta
u = action
u = jnp.clip(u, -self.max_torque, self.max_torque)[0]
th_normalized = ((th + jnp.pi) % (2 * jnp.pi)) - jnp.pi
costs = th_normalized**2 + 0.1 * thdot**2 + 0.001 * (u**2)
return -costs
def terminal(self, state: StateType) -> bool:
"""Determines if the state is a terminal state."""
return False
def render_image(
self,
state: StateType,
render_state: tuple[pygame.Surface, pygame.time.Clock, float | None], # type: ignore # noqa: F821
) -> tuple[RenderStateType, np.ndarray]:
"""Renders an RGB image."""
try:
import pygame
from pygame import gfxdraw
except ImportError as e:
raise DependencyNotInstalled(
"pygame is not installed, run `pip install gymnasium[classic-control]`"
) from e
screen, clock, last_u = render_state
surf = pygame.Surface((self.screen_dim, self.screen_dim))
surf.fill((255, 255, 255))
bound = 2.2
scale = self.screen_dim / (bound * 2)
offset = self.screen_dim // 2
rod_length = 1 * scale
rod_width = 0.2 * scale
l, r, t, b = 0, rod_length, rod_width / 2, -rod_width / 2
coords = [(l, b), (l, t), (r, t), (r, b)]
transformed_coords = []
for c in coords:
c = pygame.math.Vector2(c).rotate_rad(state[0] + np.pi / 2)
c = (c[0] + offset, c[1] + offset)
transformed_coords.append(c)
gfxdraw.aapolygon(surf, transformed_coords, (204, 77, 77))
gfxdraw.filled_polygon(surf, transformed_coords, (204, 77, 77))
gfxdraw.aacircle(surf, offset, offset, int(rod_width / 2), (204, 77, 77))
gfxdraw.filled_circle(surf, offset, offset, int(rod_width / 2), (204, 77, 77))
rod_end = (rod_length, 0)
rod_end = pygame.math.Vector2(rod_end).rotate_rad(state[0] + np.pi / 2)
rod_end = (int(rod_end[0] + offset), int(rod_end[1] + offset))
gfxdraw.aacircle(
surf, rod_end[0], rod_end[1], int(rod_width / 2), (204, 77, 77)
)
gfxdraw.filled_circle(
surf, rod_end[0], rod_end[1], int(rod_width / 2), (204, 77, 77)
)
fname = path.join(path.dirname(__file__), "assets/clockwise.png")
img = pygame.image.load(fname)
if last_u is not None:
scale_img = pygame.transform.smoothscale(
img,
(scale * np.abs(last_u) / 2, scale * np.abs(last_u) / 2),
)
is_flip = bool(last_u > 0)
scale_img = pygame.transform.flip(scale_img, is_flip, True)
surf.blit(
scale_img,
(
offset - scale_img.get_rect().centerx,
offset - scale_img.get_rect().centery,
),
)
# drawing axle
gfxdraw.aacircle(surf, offset, offset, int(0.05 * scale), (0, 0, 0))
gfxdraw.filled_circle(surf, offset, offset, int(0.05 * scale), (0, 0, 0))
surf = pygame.transform.flip(surf, False, True)
screen.blit(surf, (0, 0))
return (screen, clock, last_u), np.transpose(
np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2)
)
def render_init(
self, screen_width: int = 600, screen_height: int = 400
) -> RenderStateType:
"""Initialises the render state."""
try:
import pygame
except ImportError as e:
raise DependencyNotInstalled(
"pygame is not installed, run `pip install gymnasium[classic-control]`"
) from e
pygame.init()
screen = pygame.Surface((screen_width, screen_height))
clock = pygame.time.Clock()
return screen, clock, None
def render_close(self, render_state: RenderStateType):
"""Closes the render state."""
try:
import pygame
except ImportError as e:
raise DependencyNotInstalled(
"pygame is not installed, run `pip install gymnasium[classic-control]`"
) from e
pygame.display.quit()
pygame.quit()
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based pendulum environment using the functional version as base."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
def __init__(self, render_mode: str | None = None, **kwargs: Any):
"""Constructor where the kwargs are passed to the base environment to modify the parameters."""
EzPickle.__init__(self, render_mode=render_mode, **kwargs)
env = PendulumFunctional(**kwargs)
env.transform(jax.jit)
super().__init__(
env,
metadata=self.metadata,
render_mode=render_mode,
)
class PendulumJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
"""Jax-based implementation of the vectorized CartPole environment."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
def __init__(
self,
num_envs: int,
render_mode: str | None = None,
max_episode_steps: int = 200,
**kwargs: Any,
):
"""Constructor for the vectorized CartPole where the kwargs are applied to the functional environment."""
EzPickle.__init__(
self,
num_envs=num_envs,
render_mode=render_mode,
max_episode_steps=max_episode_steps,
**kwargs,
)
env = PendulumFunctional(**kwargs)
env.transform(jax.jit)
FunctionalJaxVectorEnv.__init__(
self,
func_env=env,
num_envs=num_envs,
metadata=self.metadata,
render_mode=render_mode,
max_episode_steps=max_episode_steps,
)