148 lines
4.8 KiB
Python
148 lines
4.8 KiB
Python
"""Utility functions for the wrappers."""
|
|
from collections import OrderedDict
|
|
from functools import singledispatch
|
|
|
|
import numpy as np
|
|
|
|
from gymnasium import Space
|
|
from gymnasium.error import CustomSpaceError
|
|
from gymnasium.spaces import (
|
|
Box,
|
|
Dict,
|
|
Discrete,
|
|
Graph,
|
|
GraphInstance,
|
|
MultiBinary,
|
|
MultiDiscrete,
|
|
Sequence,
|
|
Text,
|
|
Tuple,
|
|
)
|
|
from gymnasium.spaces.space import T_cov
|
|
|
|
|
|
# Names exported by ``from <module> import *``; the underscore-prefixed
# per-space handlers below are reached only through ``create_zero_array``.
__all__ = ["RunningMeanStd", "update_mean_var_count_from_moments", "create_zero_array"]
|
|
|
|
|
|
class RunningMeanStd:
    """Keeps a running estimate of the mean, variance and sample count of a data stream.

    Batches are merged into the stored statistics with the parallel variance
    algorithm referenced below.
    """

    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon=1e-4, shape=()):
        """Starts from a zero mean, unit variance and a count of ``epsilon``.

        Args:
            epsilon: Initial value stored in ``count``
            shape: Shape of the ``mean`` and ``var`` arrays
        """
        self.count = epsilon
        self.var = np.ones(shape, dtype="float64")
        self.mean = np.zeros(shape, dtype="float64")

    def update(self, x):
        """Merges a batch of samples, reduced along axis 0, into the running statistics.

        Args:
            x: Batch of samples; its first dimension is used as the batch count
        """
        self.update_from_moments(np.mean(x, axis=0), np.var(x, axis=0), x.shape[0])

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merges precomputed batch moments into the running statistics.

        Args:
            batch_mean: Mean of the batch
            batch_var: Variance of the batch
            batch_count: Number of samples in the batch
        """
        merged = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )
        self.mean, self.var, self.count = merged
|
|
|
|
|
|
def update_mean_var_count_from_moments(
    mean, var, count, batch_mean, batch_var, batch_count
):
    """Combines running moments with the moments of a new batch.

    Args:
        mean: Current running mean
        var: Current running variance
        count: Current running count
        batch_mean: Mean of the new batch
        batch_var: Variance of the new batch
        batch_count: Number of samples in the new batch

    Returns:
        Tuple ``(new_mean, new_var, new_count)`` with the combined statistics
    """
    total = count + batch_count
    shift = batch_mean - mean

    # Combined sum of squared deviations: both partial sums plus a correction
    # term for the distance between the two means.
    sum_sq = (
        var * count
        + batch_var * batch_count
        + np.square(shift) * count * batch_count / total
    )

    return mean + shift * batch_count / total, sum_sq / total, total
|
|
|
|
|
|
@singledispatch
def create_zero_array(space: Space[T_cov]) -> T_cov:
    """Creates the sample of a space that is as close to zero as possible.

    This is similar to ``create_empty_array``, except that the result is always a
    valid sample of the space. Some ``Box`` spaces have a ``low`` or ``high`` that
    excludes zero, in which case an all-zero array would not be contained in the space.

    This base implementation is the fallback for types without a registered
    handler, so it always raises.

    Args:
        space: The space to create a zero array for

    Returns:
        Valid sample from the space that is as close to zero as possible

    Raises:
        CustomSpaceError: If ``space`` is a ``Space`` with no registered handler
        TypeError: If ``space`` is not a ``Space`` instance
    """
    if not isinstance(space, Space):
        raise TypeError(
            f"The space provided to `create_zero_array` is not a gymnasium Space instance, type: {type(space)}, {space}"
        )

    raise CustomSpaceError(
        f"Space of type `{type(space)}` doesn't have an registered `create_zero_array` function. Register `{type(space)}` for `create_zero_array` to support it."
    )
|
|
|
|
|
|
@create_zero_array.register(Box)
def _create_box_zero_array(space: Box):
    """Returns zeros in the Box's shape and dtype, replaced by a bound wherever zero is out of range."""
    origin = np.zeros(space.shape, dtype=space.dtype)
    # A strictly positive lower bound excludes zero, so take the bound itself.
    raised = np.where(space.low > 0, space.low, origin)
    # Likewise for a strictly negative upper bound.
    return np.where(space.high < 0, space.high, raised)
|
|
|
|
|
|
@create_zero_array.register(Discrete)
def _create_discrete_zero_array(space: Discrete):
    """Returns the ``start`` value of the Discrete space."""
    start_value = space.start
    return start_value
|
|
|
|
|
|
@create_zero_array.register(MultiDiscrete)
def _create_multidiscrete_zero_array(space: MultiDiscrete):
    """Returns a fresh copy of ``space.start`` in the space's dtype."""
    start_copy = np.array(space.start, dtype=space.dtype, copy=True)
    return start_copy
|
|
|
|
|
|
@create_zero_array.register(MultiBinary)
def _create_array_zero_array(space: MultiBinary):
    """Returns an all-zero array with the MultiBinary space's shape and dtype."""
    shape, dtype = space.shape, space.dtype
    return np.zeros(shape, dtype=dtype)
|
|
|
|
|
|
@create_zero_array.register(Tuple)
def _create_tuple_zero_array(space: Tuple):
    """Returns a tuple holding the zero array of each subspace, in order."""
    return tuple(map(create_zero_array, space.spaces))
|
|
|
|
|
|
@create_zero_array.register(Dict)
def _create_dict_zero_array(space: Dict):
    """Returns an ``OrderedDict`` mapping each key to the zero array of its subspace."""
    zero_items = (
        (key, create_zero_array(subspace)) for key, subspace in space.spaces.items()
    )
    return OrderedDict(zero_items)
|
|
|
|
|
|
@create_zero_array.register(Sequence)
def _create_sequence_zero_array(space: Sequence):
    """Returns an empty tuple, or the zero array of the stacked feature space when ``space.stack`` is set."""
    if not space.stack:
        return ()
    return create_zero_array(space.stacked_feature_space)
|
|
|
|
|
|
@create_zero_array.register(Text)
def _create_text_zero_array(space: Text):
    """Returns a string of ``space.min_length`` copies of the first entry of ``space.characters``."""
    # Kept as a lazy generator so ``characters[0]`` is never looked up when
    # ``min_length`` is zero.
    padding = (space.characters[0] for _ in range(space.min_length))
    return "".join(padding)
|
|
|
|
|
|
@create_zero_array.register(Graph)
def _create_graph_zero_array(space: Graph):
    """Returns a ``GraphInstance`` with one zero node and, when an edge space exists, one zero edge.

    The node (and edge) zero arrays get a leading axis of length one. When the
    space has no edge space, ``edges`` and ``edge_links`` are both ``None``;
    otherwise ``edge_links`` is a ``(1, 2)`` array of zeros.
    """
    nodes = np.expand_dims(create_zero_array(space.node_space), axis=0)

    edges = None
    edge_links = None
    if space.edge_space is not None:
        edges = np.expand_dims(create_zero_array(space.edge_space), axis=0)
        edge_links = np.zeros((1, 2), dtype=np.int64)

    return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links)
|