Files
2024-10-30 22:14:35 +01:00

75 lines
2.4 KiB
Python

"""
This module contains string subclasses that help to adhere to a uniform
naming scheme for repository ids.
This is especially helpful when pushing or pulling models in an automated fashion, e.g.
when pushing many models from a benchmark such as the RL Baselines3 Zoo.
https://github.com/DLR-RM/rl-baselines3-zoo
"""
# Note: immutable types such as str fix their value at creation time, so a subclass
# has to customize construction in __new__ rather than in __init__.
# Read more here:
# https://docs.python.org/3/reference/datamodel.html#object.__new__
# https://stackoverflow.com/a/2673863
class EnvironmentName(str):
    """
    Normalized name of an environment, derived from a gym id.

    Every slash in the gym id is replaced by a dash, so the name can be used for
    constructing file paths and URLs without accidentally introducing hierarchy.
    If the gym id contains exactly one colon, only the part after the colon is
    kept. The unmodified gym id remains available through `gym_id`.
    """

    def __new__(cls, gym_id: str):
        """
        Create the normalized environment name.

        :param gym_id: the environment id, optionally of the form ``prefix:name``
        :return: the new string instance holding the normalized name
        :raises ValueError: if ``gym_id`` contains more than one colon
        """
        name = gym_id.replace("/", "-")
        colon_count = name.count(":")
        if colon_count > 1:
            raise ValueError(
                f"Environment name {gym_id} contains more than one colon!"
            )
        if colon_count == 1:
            # keep only what follows the colon, i.e. the second part
            name = name.partition(":")[2]
        instance = super().__new__(cls, name)
        # remember the id exactly as it was passed in
        instance._gym_id = gym_id
        return instance

    @property
    def gym_id(self):
        """
        The original, unnormalized gym id this name was created from.
        This is the value to be passed to `gym.make`
        """
        return self._gym_id
class ModelName(str):
    """
    A name of a model, built as ``<algo_name>-<environment_name>``.

    Derived from the used algorithm and the environment that has been trained on.
    When a normalized environment name is passed in, the model name can be used to
    construct file paths and URLs.
    """

    def __new__(cls, algo_name: str, environment_name: EnvironmentName):
        """
        Create the model name.

        :param algo_name: name of the algorithm used for training
        :param environment_name: name of the environment that was trained on
        :return: the new string instance ``f"{algo_name}-{environment_name}"``
        """
        instance = super().__new__(cls, f"{algo_name}-{environment_name}")
        # Keep the constructor arguments so that __getnewargs__ can replay them.
        instance._algo_name = algo_name
        instance._environment_name = environment_name
        return instance

    def __getnewargs__(self):
        """
        Return the arguments that `copy` and `pickle` pass to `__new__`.

        The inherited `str.__getnewargs__` only returns the finished string, which
        does not match the two-argument constructor and makes copying or
        unpickling (protocol 2 and above) fail with a `TypeError`.
        """
        return (self._algo_name, self._environment_name)

    @property
    def filename(self):
        """
        The filename under which the model is stored
        when saving it using `model.save(model_name)`
        """
        return f"{self}.zip"
class ModelRepoId(str):
    """
    The name of a repository, built as ``<org>/<model_name>``.

    Derived from the associated organization and the model name.
    """

    def __new__(cls, org: str, model_name: ModelName):
        """
        Create the repository id.

        :param org: name of the organization that owns the repository
        :param model_name: name of the model stored in the repository
        :return: the new string instance ``f"{org}/{model_name}"``
        """
        instance = super().__new__(cls, f"{org}/{model_name}")
        # Keep the constructor arguments so that __getnewargs__ can replay them.
        instance._org = org
        instance._model_name = model_name
        return instance

    def __getnewargs__(self):
        """
        Return the arguments that `copy` and `pickle` pass to `__new__`.

        The inherited `str.__getnewargs__` only returns the finished string, which
        does not match the two-argument constructor and makes copying or
        unpickling (protocol 2 and above) fail with a `TypeError`.
        """
        return (self._org, self._model_name)