"""
This module contains string subclasses that help to adhere to a uniform
naming scheme for repository ids.

This is especially helpful when pushing or pulling models in an automated fashion,
e.g. when pushing many models from a benchmark such as the RL Baselines3 Zoo:
https://github.com/DLR-RM/rl-baselines3-zoo
"""

# Note: it is best practice to implement __new__ when overriding immutable types.
# Read more here:
# https://docs.python.org/3/reference/datamodel.html#object.__new__
# https://stackoverflow.com/a/2673863


class EnvironmentName(str):
    """
    A name of an environment. Slashes are replaced by dashes so the name can be used
    for constructing file paths and URLs without accidentally introducing hierarchy.

    If the given gym id contains exactly one colon (e.g. ``"my_pkg:MyEnv-v0"``),
    only the part after the colon is kept in the normalized name. The original,
    unmodified id stays available via :attr:`gym_id`.

    Raises:
        ValueError: if the gym id contains more than one colon.
    """

    def __new__(cls, gym_id: str) -> "EnvironmentName":
        normalized_str = gym_id.replace("/", "-")
        if ":" in normalized_str:
            if normalized_str.count(":") > 1:
                raise ValueError(
                    f"Environment name {gym_id} contains more than one colon!"
                )
            # Drop everything up to and including the colon and keep the part
            # AFTER it (the previous comment wrongly claimed the first part was kept).
            _, _, normalized_str = normalized_str.partition(":")
        normalized_name = super().__new__(cls, normalized_str)
        # str subclasses get an instance __dict__, so the original id can be
        # attached here and exposed through the read-only `gym_id` property.
        normalized_name._gym_id = gym_id
        return normalized_name

    @property
    def gym_id(self) -> str:
        """
        The gym id corresponding to the environment name.

        This is the value to be passed to `gym.make`.
        """
        return self._gym_id


class ModelName(str):
    """
    Name of a trained model, built as ``"<algo_name>-<environment_name>"``.

    Because the environment part is expected to be a normalized environment
    name (no slashes), the resulting model name is safe to use when building
    file paths and URLs.
    """

    def __new__(cls, algo_name: str, environment_name: EnvironmentName):
        combined = "{}-{}".format(algo_name, environment_name)
        return super().__new__(cls, combined)

    @property
    def filename(self):
        """
        Filename of the stored model, i.e. the model name with a ``.zip`` suffix,
        as produced when saving it via `model.save(model_name)`.
        """
        return "{}.zip".format(self)


class ModelRepoId(str):
    """
    Identifier of a model repository, formed as ``"<org>/<model_name>"`` from
    the owning organization and the model name.
    """

    def __new__(cls, org: str, model_name: ModelName):
        repo_id = "{}/{}".format(org, model_name)
        return super().__new__(cls, repo_id)