I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,515 @@
"""This module provides a Blackjack functional environment and Gymnasium environment wrapper BlackJackJaxEnv."""
import math
import os
from typing import NamedTuple, Optional, Tuple, Union
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.random import PRNGKey
from gymnasium import spaces
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.experimental.functional_jax_env import FunctionalJaxEnv
from gymnasium.utils import EzPickle, seeding
from gymnasium.wrappers import HumanRendering
RenderStateType = Tuple["pygame.Surface", str, int] # type: ignore # noqa: F821
deck = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10])
class EnvState(NamedTuple):
"""A named tuple which contains the full state of the blackjack game."""
dealer_hand: jax.Array
player_hand: jax.Array
dealer_cards: int
player_cards: int
done: int
def cmp(a, b):
"""Returns 1 if a > b, otherwise returns -1."""
return (a > b).astype(int) - (a < b).astype(int)
def random_card(key):
"""Draws a randowm card (with replacement)."""
key = random.split(key)[0]
choice = random.choice(key, deck, shape=(1,))
return choice[0].astype(int), key
def draw_hand(key, hand):
"""Draws a starting hand of two random cards."""
new_card, key = random_card(key)
hand = hand.at[0].set(new_card)
new_card, key = random_card(key)
hand = hand.at[1].set(new_card)
return hand, key
def draw_card(key, hand, index):
"""Draws a new card and adds it to a hand."""
new_card, key = random_card(key)
hand = hand.at[index].set(new_card)
return key, hand, index + 1
def usable_ace(hand):
"""Checks to se if a hand has a usable ace."""
return jnp.logical_and((jnp.count_nonzero(hand == 1) > 0), (sum(hand) + 10 <= 21))
def take(env_state):
"""This function is called if the player has decided to take a card."""
state, key = env_state
dealer_hand = state.dealer_hand
player_hand = state.player_hand
dealer_cards = state.dealer_cards
player_cards = state.player_cards
key, new_player_hand, _ = draw_card(key, player_hand, player_cards)
new_player_cards = player_cards + 1
# done is set to zero here because it is determined later whether the player is bust
return (
EnvState(
dealer_hand=dealer_hand,
player_hand=new_player_hand,
dealer_cards=dealer_cards,
player_cards=new_player_cards,
done=0,
),
key,
)
def dealer_stop(val):
"""This function determines if the dealer should stop drawing."""
return sum_hand(val[1]) < 17
def draw_card_wrapper(val):
"""Wrapper function for draw_card."""
return draw_card(*val)
def notake(env_state):
"""This function is called if the player has decided to not take a card.
Calling this function ends the active portion
of the game and turns control over to the dealer.
"""
state, key = env_state
dealer_hand = state.dealer_hand
player_hand = state.player_hand
dealer_cards = state.dealer_cards
player_cards = state.player_cards
key, dealer_hand, dealer_cards = jax.lax.while_loop(
dealer_stop,
draw_card_wrapper,
(key, dealer_hand, dealer_cards),
)
return (
EnvState(
dealer_hand=dealer_hand,
player_hand=player_hand,
dealer_cards=dealer_cards,
player_cards=player_cards,
done=1,
),
key,
)
def sum_hand(hand):
"""Returns the total points in a hand."""
return sum(hand) + (10 * usable_ace(hand))
def is_bust(hand):
"""Returns whether or not the hand is a bust."""
return sum_hand(hand) > 21
def score(hand):
"""Returns the score for a hand(0 if a bust)."""
return (jnp.logical_not(is_bust(hand))) * sum_hand(hand)
def is_natural(hand):
"""Returns if the hand is a natural blackjack."""
return jnp.logical_and(
jnp.logical_and(
jnp.count_nonzero(hand) == 2, (jnp.count_nonzero(hand == 1) > 0)
),
(jnp.count_nonzero(hand == 10) > 0),
)
class BlackjackFunctional(
FuncEnv[jax.Array, jax.Array, int, float, bool, RenderStateType]
):
"""Blackjack is a card game where the goal is to beat the dealer by obtaining cards that sum to closer to 21 (without going over 21) than the dealers cards.
### Description
Card Values:
- Face cards (Jack, Queen, King) have a point value of 10.
- Aces can either count as 11 (called a 'usable ace') or 1.
- Numerical cards (2-9) have a value equal to their number.
This game is played with an infinite deck (or with replacement).
The game starts with the dealer having one face up and one face down card,
while the player has two face up cards.
The player can request additional cards (hit, action=1) until they decide to stop (stick, action=0)
or exceed 21 (bust, immediate loss).
After the player sticks, the dealer reveals their facedown card, and draws
until their sum is 17 or greater. If the dealer goes bust, the player wins.
If neither the player nor the dealer busts, the outcome (win, lose, draw) is
decided by whose sum is closer to 21.
### Action Space
There are two actions: stick (0), and hit (1).
### Observation Space
The observation consists of a 3-tuple containing: the player's current sum,
the value of the dealer's one showing card (1-10 where 1 is ace),
and whether the player holds a usable ace (0 or 1).
This environment corresponds to the version of the blackjack problem
described in Example 5.1 in Reinforcement Learning: An Introduction
by Sutton and Barto (http://incompleteideas.net/book/the-book-2nd.html).
### Rewards
- win game: +1
- lose game: -1
- draw game: 0
- win game with natural blackjack:
+1.5 (if <a href="#nat">natural</a> is True)
+1 (if <a href="#nat">natural</a> is False)
### Arguments
```
gym.make('Jax-Blackjack-v0', natural=False, sutton_and_barto=False)
```
<a id="nat">`natural=False`</a>: Whether to give an additional reward for
starting with a natural blackjack, i.e. starting with an ace and ten (sum is 21).
<a id="sutton_and_barto">`sutton_and_barto=False`</a>: Whether to follow the exact rules outlined in the book by
Sutton and Barto. If `sutton_and_barto` is `True`, the keyword argument `natural` will be ignored.
If the player achieves a natural blackjack and the dealer does not, the player
will win (i.e. get a reward of +1). The reverse rule does not apply.
If both the player and the dealer get a natural, it will be a draw (i.e. reward 0).
### Version History
* v0: Initial version release (0.0.0), adapted from original gym blackjack v1
"""
action_space = spaces.Discrete(2)
observation_space = spaces.Box(
low=np.array([1, 1, 0]), high=np.array([32, 11, 1]), shape=(3,), dtype=np.int32
)
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 4,
}
def __init__(self, natural: bool = False, sutton_and_barto: bool = True):
"""Initializes Blackjack functional env."""
# Flag to payout 1.5 on a "natural" blackjack win, like casino rules
# Ref: http://www.bicyclecards.com/how-to-play/blackjack/
self.natural = natural
# Flag for full agreement with the (Sutton and Barto, 2018) definition. Overrides self.natural
self.sutton_and_barto = sutton_and_barto
def transition(self, state: EnvState, action: Union[int, jax.Array], key: PRNGKey):
"""The blackjack environment's state transition function."""
env_state = jax.lax.cond(action, take, notake, (state, key))
hand_state, key = env_state
dealer_hand = hand_state.dealer_hand
player_hand = hand_state.player_hand
dealer_cards = hand_state.dealer_cards
player_cards = hand_state.player_cards
# note that only a bust or player action ends the round, the player
# can still request another card with 21 cards
done = (is_bust(player_hand) * action) + ((jnp.logical_not(action)) * 1)
new_state = EnvState(
dealer_hand=dealer_hand,
player_hand=player_hand,
dealer_cards=dealer_cards,
player_cards=player_cards,
done=done,
)
return new_state
def initial(self, rng: PRNGKey):
"""Blackjack initial observataion function."""
player_hand = jnp.zeros(21)
dealer_hand = jnp.zeros(21)
player_hand, rng = draw_hand(rng, player_hand)
dealer_hand, rng = draw_hand(rng, dealer_hand)
dealer_cards = 2
player_cards = 2
state = EnvState(
dealer_hand=dealer_hand,
player_hand=player_hand,
dealer_cards=dealer_cards,
player_cards=player_cards,
done=0,
)
return state
def observation(self, state: EnvState) -> jax.Array:
"""Blackjack observation."""
return jnp.array(
[
sum_hand(state.player_hand),
state.dealer_hand[0],
usable_ace(state.player_hand) * 1.0,
],
dtype=np.int32,
)
def terminal(self, state: EnvState) -> jax.Array:
"""Determines if a particular Blackjack observation is terminal."""
return (state.done) > 0
def reward(
self, state: EnvState, action: ActType, next_state: StateType
) -> jax.Array:
"""Calculates reward from a state."""
state = next_state
dealer_hand = state.dealer_hand
player_hand = state.player_hand
# -1 reward if the player busts, otherwise +1 if better than dealer, 0 if tie, -1 if loss.
reward = (
0.0
+ (is_bust(player_hand) * -1 * action)
+ ((jnp.logical_not(action)) * cmp(score(player_hand), score(dealer_hand)))
)
# in the natural setting, if the player wins with a natural blackjack, then reward is 1.5
if self.natural and not self.sutton_and_barto:
condition = jnp.logical_and(is_natural(player_hand), (reward == 1))
reward = reward * jnp.logical_not(condition) + 1.5 * condition
# in the sutton and barto setting, if the player gets a natural blackjack and the dealer gets
# a non-natural blackjack, the player wins. A dealer natural blackjack and a player
# non-natural blackjack should result in a tie.
if self.sutton_and_barto:
condition = jnp.logical_and(
is_natural(player_hand), jnp.logical_not(is_natural(dealer_hand))
)
reward = reward * jnp.logical_not(condition) + 1 * condition
return reward
def render_init(
self, screen_width: int = 600, screen_height: int = 500
) -> RenderStateType:
"""Returns an initial render state."""
try:
import pygame
except ImportError:
raise DependencyNotInstalled(
"pygame is not installed, run `pip install gymnasium[classic_control]`"
)
rng = seeding.np_random(0)[0]
suits = ["C", "D", "H", "S"]
dealer_top_card_suit = rng.choice(suits)
dealer_top_card_value_str = rng.choice(["J", "Q", "K"])
pygame.init()
screen = pygame.Surface((screen_width, screen_height))
return screen, dealer_top_card_value_str, dealer_top_card_suit
def render_image(
self,
state: StateType,
render_state: RenderStateType,
) -> Tuple[RenderStateType, np.ndarray]:
"""Renders an image from a state."""
try:
import pygame
except ImportError:
raise DependencyNotInstalled(
"pygame is not installed, run `pip install gymnasium[toy_text]`"
)
screen, dealer_top_card_value_str, dealer_top_card_suit = render_state
player_sum, dealer_card_value, usable_ace = self.observation(state)
screen_width, screen_height = 600, 500
card_img_height = screen_height // 3
card_img_width = int(card_img_height * 142 / 197)
spacing = screen_height // 20
bg_color = (7, 99, 36)
white = (255, 255, 255)
if dealer_card_value == 1:
display_card_value = "A"
elif dealer_card_value == 10:
display_card_value = dealer_top_card_value_str
else:
display_card_value = str(math.floor(dealer_card_value))
screen.fill(bg_color)
def get_image(path):
cwd = os.path.dirname(__file__)
cwd = os.path.join(cwd, "..")
cwd = os.path.join(cwd, "toy_text")
image = pygame.image.load(os.path.join(cwd, path))
return image
def get_font(path, size):
cwd = os.path.dirname(__file__)
cwd = os.path.join(cwd, "..")
cwd = os.path.join(cwd, "toy_text")
font = pygame.font.Font(os.path.join(cwd, path), size)
return font
small_font = get_font(
os.path.join("font", "Minecraft.ttf"), screen_height // 15
)
dealer_text = small_font.render(
"Dealer: " + str(dealer_card_value), True, white
)
dealer_text_rect = screen.blit(dealer_text, (spacing, spacing))
def scale_card_img(card_img):
return pygame.transform.scale(card_img, (card_img_width, card_img_height))
dealer_card_img = scale_card_img(
get_image(
os.path.join(
"img",
f"{dealer_top_card_suit}{display_card_value}.png",
)
)
)
dealer_card_rect = screen.blit(
dealer_card_img,
(
screen_width // 2 - card_img_width - spacing // 2,
dealer_text_rect.bottom + spacing,
),
)
hidden_card_img = scale_card_img(get_image(os.path.join("img", "Card.png")))
screen.blit(
hidden_card_img,
(
screen_width // 2 + spacing // 2,
dealer_text_rect.bottom + spacing,
),
)
player_text = small_font.render("Player", True, white)
player_text_rect = screen.blit(
player_text, (spacing, dealer_card_rect.bottom + 1.5 * spacing)
)
large_font = get_font(os.path.join("font", "Minecraft.ttf"), screen_height // 6)
player_sum_text = large_font.render(str(player_sum), True, white)
player_sum_text_rect = screen.blit(
player_sum_text,
(
screen_width // 2 - player_sum_text.get_width() // 2,
player_text_rect.bottom + spacing,
),
)
if usable_ace:
usable_ace_text = small_font.render("usable ace", True, white)
screen.blit(
usable_ace_text,
(
screen_width // 2 - usable_ace_text.get_width() // 2,
player_sum_text_rect.bottom + spacing // 2,
),
)
return render_state, np.transpose(
np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2)
)
def render_close(self, render_state: RenderStateType) -> None:
"""Closes the render state."""
try:
import pygame
except ImportError as e:
raise DependencyNotInstalled(
"pygame is not installed, run `pip install gymnasium[classic_control]`"
) from e
pygame.display.quit()
pygame.quit()
class BlackJackJaxEnv(FunctionalJaxEnv, EzPickle):
"""A Gymnasium Env wrapper for the functional blackjack env."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
def __init__(self, render_mode: Optional[str] = None, **kwargs):
"""Initializes Gym wrapper for blackjack functional env."""
EzPickle.__init__(self, render_mode=render_mode, **kwargs)
env = BlackjackFunctional(**kwargs)
env.transform(jax.jit)
super().__init__(
env,
metadata=self.metadata,
render_mode=render_mode,
)
# Pixel art from Mariia Khmelnytska (https://www.123rf.com/photo_104453049_stock-vector-pixel-art-playing-cards-standart-deck-vector-set.html)
# Jax structure inspired by https://medium.com/@ngoodger_7766/writing-an-rl-environment-in-jax-9f74338898ba
if __name__ == "__main__":
"""
Temporary environment tester function.
"""
env = HumanRendering(BlackJackJaxEnv(render_mode="rgb_array"))
obs, info = env.reset()
print(obs, info)
terminal = False
while not terminal:
action = int(input("Please input an action\n"))
obs, reward, terminal, truncated, info = env.step(action)
print(obs, reward, terminal, truncated, info)
exit()