I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@@ -0,0 +1,147 @@
"""Utility functions for the wrappers."""
from collections import OrderedDict
from functools import singledispatch
import numpy as np
from gymnasium import Space
from gymnasium.error import CustomSpaceError
from gymnasium.spaces import (
Box,
Dict,
Discrete,
Graph,
GraphInstance,
MultiBinary,
MultiDiscrete,
Sequence,
Text,
Tuple,
)
from gymnasium.spaces.space import T_cov
# Names exported by a star-import of this module.
__all__ = ["RunningMeanStd", "update_mean_var_count_from_moments", "create_zero_array"]
class RunningMeanStd:
    """Keeps a running mean, variance and sample count that can be updated batch by batch.

    Batches are merged with the parallel variance algorithm, see
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    """

    def __init__(self, epsilon=1e-4, shape=()):
        """Starts with a zero mean, a unit variance and ``count`` set to ``epsilon``.

        Args:
            epsilon: Initial value stored in ``count``.
            shape: Shape of the ``mean`` and ``var`` arrays (both ``float64``).
        """
        self.mean = np.zeros(shape, dtype="float64")
        self.var = np.ones(shape, dtype="float64")
        self.count = epsilon

    def update(self, x):
        """Folds a batch of samples into the running statistics.

        Args:
            x: Batch of samples; moments are taken along axis 0 and
                ``x.shape[0]`` is used as the number of samples.
        """
        self.update_from_moments(np.mean(x, axis=0), np.var(x, axis=0), x.shape[0])

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Folds pre-computed batch moments into the running statistics.

        Args:
            batch_mean: Mean of the batch.
            batch_var: Variance of the batch.
            batch_count: Number of samples in the batch.
        """
        merged = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )
        self.mean, self.var, self.count = merged
def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    """Merges running statistics with the statistics of a new batch.

    Args:
        mean: Current running mean.
        var: Current running variance.
        count: Number of samples behind the running statistics.
        batch_mean: Mean of the new batch.
        batch_var: Variance of the new batch.
        batch_count: Number of samples in the new batch.

    Returns:
        A ``(new_mean, new_var, new_count)`` tuple for the combined data.
    """
    total = count + batch_count
    shift = batch_mean - mean

    # Sum of squared deviations of the combined data: the two per-set sums
    # plus a correction for the distance between the two means.
    correction = np.square(shift) * count * batch_count / total
    combined_m2 = var * count + batch_var * batch_count + correction

    return mean + shift * batch_count / total, combined_m2 / total, total
@singledispatch
def create_zero_array(space: Space[T_cov]) -> T_cov:
    """Creates a zero-based array of a space, similar to ``create_empty_array`` except that the result is a valid sample of the space.

    Some ``Box`` spaces have a ``low`` or ``high`` bound that excludes zero, in which case
    ``create_empty_array`` would produce an array that is not contained in the space.

    This is the ``singledispatch`` fallback: it only runs when no implementation
    is registered for the type of ``space`` and therefore always raises.

    Args:
        space: The space to create a zero array for

    Returns:
        Valid sample from the space that is as close to zero as possible

    Raises:
        TypeError: If ``space`` is not a ``Space`` instance.
        CustomSpaceError: If ``space`` is a ``Space`` without a registered implementation.
    """
    if not isinstance(space, Space):
        raise TypeError(
            f"The space provided to `create_zero_array` is not a gymnasium Space instance, type: {type(space)}, {space}"
        )

    raise CustomSpaceError(
        f"Space of type `{type(space)}` doesn't have an registered `create_zero_array` function. Register `{type(space)}` for `create_zero_array` to support it."
    )
@create_zero_array.register(Box)
def _create_box_zero_array(space: Box):
    """Returns zeros of the box's shape and dtype, moved onto a bound wherever zero lies outside it.

    Elements whose ``low`` is above zero take ``low``; elements whose ``high``
    is below zero take ``high`` (the ``high`` rule is applied last).
    """
    zeros = np.zeros(space.shape, dtype=space.dtype)
    raised_to_low = np.where(space.low > 0, space.low, zeros)
    return np.where(space.high < 0, space.high, raised_to_low)
@create_zero_array.register(Discrete)
def _create_discrete_zero_array(space: Discrete):
    """Returns the ``start`` value of the discrete space."""
    first_value = space.start
    return first_value
@create_zero_array.register(MultiDiscrete)
def _create_multidiscrete_zero_array(space: MultiDiscrete):
    """Returns a fresh copy of ``space.start`` converted to the space's dtype."""
    # An explicit copy so callers can never mutate the space's own ``start``.
    start_copy = np.array(space.start, dtype=space.dtype, copy=True)
    return start_copy
@create_zero_array.register(MultiBinary)
def _create_array_zero_array(space: MultiBinary):
    """Returns an all-zero array with the shape and dtype of the multi-binary space."""
    return np.zeros(shape=space.shape, dtype=space.dtype)
@create_zero_array.register(Tuple)
def _create_tuple_zero_array(space: Tuple):
    """Returns a tuple holding the zero array of every subspace, in subspace order."""
    return tuple(map(create_zero_array, space.spaces))
@create_zero_array.register(Dict)
def _create_dict_zero_array(space: Dict):
    """Returns an ``OrderedDict`` mapping every key of the space to its subspace's zero array."""
    zero_items = (
        (key, create_zero_array(subspace)) for key, subspace in space.spaces.items()
    )
    return OrderedDict(zero_items)
@create_zero_array.register(Sequence)
def _create_sequence_zero_array(space: Sequence):
    """Returns the zero value of a sequence space.

    A non-stacked sequence yields an empty tuple; a stacked one yields the
    zero array of its ``stacked_feature_space``.
    """
    if not space.stack:
        return ()
    return create_zero_array(space.stacked_feature_space)
@create_zero_array.register(Text)
def _create_text_zero_array(space: Text):
    """Returns a string of ``min_length`` copies of the first entry of ``space.characters``."""
    # The character is looked up once per position, so nothing is indexed
    # when ``min_length`` is zero.
    filler = [space.characters[0] for _ in range(space.min_length)]
    return "".join(filler)
@create_zero_array.register(Graph)
def _create_graph_zero_array(space: Graph):
    """Returns a ``GraphInstance`` built from the zero arrays of the node and edge spaces.

    The node zero array gets a leading axis of length one. Without an edge
    space, ``edges`` and ``edge_links`` are ``None``; otherwise the edge zero
    array also gets a leading axis of length one and ``edge_links`` is a
    ``(1, 2)`` ``int64`` array of zeros.
    """
    nodes = np.expand_dims(create_zero_array(space.node_space), axis=0)

    if space.edge_space is None:
        return GraphInstance(nodes=nodes, edges=None, edge_links=None)

    # NOTE(review): both link entries are 0, presumably a self-loop on the
    # only node -- confirm against the Graph space's edge_links convention.
    return GraphInstance(
        nodes=nodes,
        edges=np.expand_dims(create_zero_array(space.edge_space), axis=0),
        edge_links=np.zeros((1, 2), dtype=np.int64),
    )