Files
Reinforced-Learning-Godot/rl/Lib/site-packages/stable_baselines3/common/save_util.py
2024-10-30 22:14:35 +01:00

467 lines
21 KiB
Python

"""
Save util taken from stable_baselines
used to serialize data (class parameters) of model classes
"""
import base64
import functools
import io
import json
import os
import pathlib
import pickle
import warnings
import zipfile
from typing import Any, Dict, Optional, Tuple, Union
import cloudpickle
import torch as th
import stable_baselines3 as sb3
from stable_baselines3.common.type_aliases import TensorDict
from stable_baselines3.common.utils import get_device, get_system_info
def recursive_getattr(obj: Any, attr: str, *args) -> Any:
    """
    Resolve a dotted attribute path, one ``getattr`` call per path segment.

    Idea taken from https://stackoverflow.com/questions/31174295

    Ex:
        > MyObject.sub_object = SubObject(name='test')
        > recursive_getattr(MyObject, 'sub_object.name')  # return test

    :param obj: Object on which the lookup starts.
    :param attr: Attribute to retrieve, segments separated by ``"."``.
    :param args: Optional default, forwarded unchanged to every ``getattr`` call
        along the path.
    :return: The attribute
    """
    current = obj
    for name in attr.split("."):
        # The same optional default is applied at every level of the path.
        current = getattr(current, name, *args)
    return current
def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
    """
    Assign a value to an attribute addressed by a dotted path.

    Idea taken from https://stackoverflow.com/questions/31174295

    Ex:
        > MyObject.sub_object = SubObject(name='test')
        > recursive_setattr(MyObject, 'sub_object.name', 'hello')

    :param obj: Object on which the lookup starts.
    :param attr: Attribute to set, segments separated by ``"."``.
    :param val: New value of the attribute
    """
    # Everything before the last dot names the owner, the rest is the attribute.
    owner_path, _, leaf_name = attr.rpartition(".")
    owner = obj
    if owner_path:
        owner = recursive_getattr(obj, owner_path)
    setattr(owner, leaf_name, val)
def is_json_serializable(item: Any) -> bool:
    """
    Test if an object is serializable into JSON

    The check is done by actually attempting ``json.dumps`` on the object.

    :param item: The object to be tested for JSON serialization.
    :return: True if object is JSON serializable, false otherwise.
    """
    try:
        json.dumps(item)
    except (TypeError, ValueError):
        # TypeError: unsupported value type or dictionary key type.
        # ValueError: e.g. a container that (indirectly) contains itself
        # ("Circular reference detected"); such objects cannot be stored as
        # plain JSON either, so report them as non-serializable instead of
        # letting the probe crash.
        return False
    return True
def data_to_json(data: Dict[str, Any]) -> str:
    """
    Turn data (class parameters) into a JSON string for storing

    Values that are JSON serializable are stored as they are. Every other value
    is replaced by a dictionary holding:

    - ``":type:"``: the string representation of the value's type,
    - ``":serialized:"``: the cloudpickle dump of the value, base64 encoded,
    - a first-level, human-readable preview of the value's items (for a dict)
      or of its ``__dict__`` (for other objects). The preview is informative
      only; the pickled payload is what carries the value.

    :param data: Dictionary of class parameters to be
        stored. Items that are not JSON serializable will be
        pickled with Cloudpickle and stored as bytearray in
        the JSON file
    :return: JSON string of the data serialized.
    """
    # Keys carrying the pickled payload. The ":" is meant to keep them apart
    # from ordinary attribute names; the preview must never overwrite them.
    reserved_keys = (":type:", ":serialized:")
    # Key types json.dumps accepts natively; anything else would make the
    # final json.dumps call raise TypeError.
    json_key_types = (str, int, float, bool, type(None))

    serializable_data = {}
    for data_key, data_item in data.items():
        if is_json_serializable(data_item):
            # All good, store as it is
            serializable_data[data_key] = data_item
            continue

        # Not serializable: cloudpickle it into bytes and convert to a base64
        # string for storing. The type is stored as well so that other
        # languages/humans have an idea of what was being stored.
        base64_encoded = base64.b64encode(cloudpickle.dumps(data_item)).decode()
        cloudpickle_serialization = {
            ":type:": str(type(data_item)),
            ":serialized:": base64_encoded,
        }

        # Add first-level items of the object for further details (but not
        # deeper than this to avoid deep nesting). Not all objects have
        # attributes (e.g. numpy scalars), hence the checks.
        if isinstance(data_item, dict):
            preview_items = data_item.items()
        elif hasattr(data_item, "__dict__"):
            preview_items = data_item.__dict__.items()
        else:
            preview_items = ()

        for variable_name, variable_item in preview_items:
            # A dict may use keys JSON cannot represent (e.g. tuples): fall
            # back to their string representation, as done for values below.
            if not isinstance(variable_name, json_key_types):
                variable_name = str(variable_name)
            # Never let the preview clobber the serialized payload or its type.
            if variable_name in reserved_keys:
                continue
            # Check if serializable. If not, just include the
            # string-representation of the object.
            if is_json_serializable(variable_item):
                cloudpickle_serialization[variable_name] = variable_item
            else:
                cloudpickle_serialization[variable_name] = str(variable_item)

        serializable_data[data_key] = cloudpickle_serialization

    return json.dumps(serializable_data, indent=4)
def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Rebuild the dictionary of class parameters from its JSON serialization.

    NOTE(security): entries carrying a ``":serialized:"`` payload are restored
    with ``cloudpickle.loads``. Unpickling can execute arbitrary code, so only
    load data coming from a trusted source.

    :param json_string: JSON serialization of the class-parameters
        that should be loaded.
    :param custom_objects: Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        ``keras.models.load_model``. Useful when you have an object in
        file that can not be deserialized.
    :return: Loaded class parameters. An entry whose deserialization fails with
        one of the handled exceptions is reported with a warning and left out.
    """
    if not (custom_objects is None or isinstance(custom_objects, dict)):
        raise ValueError("custom_objects argument must be a dict or None")

    overrides = {} if custom_objects is None else custom_objects
    loaded = {}
    for data_key, stored_value in json.loads(json_string).items():
        # A user supplied replacement always wins over the stored value.
        if data_key in overrides:
            loaded[data_key] = overrides[data_key]
            continue

        # Anything that is not a cloudpickle wrapper is read as it is.
        if not (isinstance(stored_value, dict) and ":serialized:" in stored_value):
            loaded[data_key] = stored_value
            continue

        payload = stored_value[":serialized:"]
        try:
            restored = cloudpickle.loads(base64.b64decode(payload.encode()))
        except (RuntimeError, TypeError, AttributeError) as e:
            # Give the user a hint on how to work around the failure.
            warnings.warn(
                f"Could not deserialize object {data_key}. "
                "Consider using `custom_objects` argument to replace "
                "this object.\n"
                f"Exception: {e}"
            )
            continue
        loaded[data_key] = restored
    return loaded
@functools.singledispatch
def open_path(
    path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose: int = 0, suffix: Optional[str] = None
) -> Union[io.BufferedWriter, io.BufferedReader, io.BytesIO]:
    """
    Opens a path for reading or writing with a preferred suffix and raises debug information.

    This is the single-dispatch entry point. The implementation below is the
    fallback used for already opened binary streams: it only validates the
    stream against the requested mode and returns it unchanged (``verbose``
    and ``suffix`` are not used here). Other argument types are handled by the
    implementations registered on this function.

    :param path: the stream (or path, for the registered implementations) to open.
    :param mode: how to open the file. "w"|"write" for writing, "r"|"read" for reading
        (case-insensitive).
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    :param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
        If mode is "r" then we attempt to open the path. If an error is raised and the suffix
        is not None, we attempt to open the path with the suffix.
    :return: The validated stream.
    :raises TypeError: if ``path`` is not one of the supported stream types.
    :raises ValueError: if the stream is closed, the mode is unknown, or the
        stream is not writable/readable as required by the mode.
    """
    # Note(antonin): the true annotation should be IO[bytes]
    # but there is not easy way to check that
    allowed_types = (io.BufferedWriter, io.BufferedReader, io.BytesIO, io.BufferedRandom)
    if not isinstance(path, allowed_types):
        raise TypeError(f"Path {path} parameter has invalid type: expected one of {allowed_types}.")
    if path.closed:
        raise ValueError(f"File stream {path} is closed.")

    # Map the long and short spellings onto the single letter form.
    spellings = {"write": "w", "read": "r", "w": "w", "r": "r"}
    try:
        mode = spellings[mode.lower()]
    except KeyError as e:
        raise ValueError("Expected mode to be either 'w' or 'r'.") from e

    # Pick the capability the mode requires, then test it.
    if mode == "w":
        has_capability, error_msg = path.writable, "writable"
    else:
        has_capability, error_msg = path.readable, "readable"
    if not has_capability():
        raise ValueError(f"Expected a {error_msg} file.")
    return path
@open_path.register(str)
def open_path_str(path: str, mode: str, verbose: int = 0, suffix: Optional[str] = None) -> io.BufferedIOBase:
    """
    ``open_path`` implementation for string paths.

    The string is wrapped in a ``pathlib.Path`` and all the work is delegated
    to ``open_path_pathlib``.

    :param path: the path to open. If mode is "w" then it ensures that the path exists
        by creating the necessary folders and renaming path if it points to a folder.
    :param mode: how to open the file. "w" for writing, "r" for reading.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    :param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
        If mode is "r" then we attempt to open the path. If an error is raised and the suffix
        is not None, we attempt to open the path with the suffix.
    :return: The opened file object.
    """
    as_path = pathlib.Path(path)
    return open_path_pathlib(as_path, mode, verbose, suffix)
@open_path.register(pathlib.Path)
def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: Optional[str] = None) -> io.BufferedIOBase:
    """
    Open a path given by a ``pathlib.Path``. If writing to the path, the function ensures
    that the path exists.

    Reading: the path is opened as given; if it does not exist and a non-empty
    ``suffix`` was provided, ``<path>.<suffix>`` is tried once.

    Writing: ``.<suffix>`` is appended when the path has no suffix of its own.
    If the target is a folder, ``<path>_2`` is used instead; if the parent
    folder is missing, it is created. Both situations emit a warning and the
    opening is retried.

    :param path: the path to check. If mode is "w" then it
        ensures that the path exists by creating the necessary folders and
        renaming path if it points to a folder.
    :param mode: how to open the file. "w"|"write" for writing, "r"|"read" for reading
        (case-insensitive, same spellings as accepted by ``open_path`` for streams).
    :param verbose: Verbosity level: 0 for no output, 2 for indicating if path without suffix is not found when mode is "r"
    :param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
        If mode is "r" then we attempt to open the path. If an error is raised and the suffix
        is not None, we attempt to open the path with the suffix.
    :return: The opened binary file object.
    :raises ValueError: if ``mode`` is not one of the accepted spellings.
    :raises FileNotFoundError: in read mode, when neither the path nor its
        suffixed variant exists.
    """
    # Accept the same mode spellings as the stream implementation of
    # ``open_path`` so that str/Path and stream inputs behave consistently.
    normalized = mode.lower() if isinstance(mode, str) else mode
    if normalized in ("w", "write"):
        mode = "w"
    elif normalized in ("r", "read"):
        mode = "r"
    else:
        raise ValueError("Expected mode to be either 'w' or 'r'.")

    if mode == "r":
        try:
            return open_path(path.open("rb"), mode, verbose, suffix)
        except FileNotFoundError:
            if suffix is None or suffix == "":
                # Nothing else to try: propagate the original error.
                raise
            newpath = pathlib.Path(f"{path}.{suffix}")
            if verbose >= 2:
                warnings.warn(f"Path '{path}' not found. Attempting {newpath}.")
            # Drop the suffix so that the retry below cannot loop forever.
            path, suffix = newpath, None
    else:
        try:
            if path.suffix == "" and suffix is not None and suffix != "":
                path = pathlib.Path(f"{path}.{suffix}")
            # is_file() is False for a missing path, no separate exists() needed
            if path.is_file() and verbose >= 2:
                warnings.warn(f"Path '{path}' exists, will overwrite it.")
            return open_path(path.open("wb"), mode, verbose, suffix)
        except IsADirectoryError:
            warnings.warn(f"Path '{path}' is a folder. Will save instead to {path}_2")
            path = pathlib.Path(f"{path}_2")
        except FileNotFoundError:  # Occurs when the parent folder doesn't exist
            warnings.warn(f"Path '{path.parent}' does not exist. Will create it.")
            path.parent.mkdir(exist_ok=True, parents=True)

    # Reaching this point means the first attempt failed and `path` (and
    # possibly `suffix`) was corrected above: retry with the corrections.
    return open_path_pathlib(path, mode, verbose, suffix)
def save_to_zip_file(
    save_path: Union[str, pathlib.Path, io.BufferedIOBase],
    data: Optional[Dict[str, Any]] = None,
    params: Optional[Dict[str, Any]] = None,
    pytorch_variables: Optional[Dict[str, Any]] = None,
    verbose: int = 0,
) -> None:
    """
    Save model data to a zip archive.

    The archive contains (when the corresponding argument is not None):

    - ``data``: the JSON serialization of ``data``,
    - ``pytorch_variables.pth``: ``pytorch_variables`` saved with ``th.save``,
    - ``<name>.pth``: one file per entry of ``params``, saved with ``th.save``,

    plus ``_stable_baselines3_version`` and ``system_info.txt`` metadata.

    :param save_path: Where to store the model.
        if save_path is a str or pathlib.Path ensures that the path actually exists.
        A stream passed by the caller is written to but left open.
    :param data: Class parameters being stored (non-PyTorch variables)
    :param params: Model parameters being stored expected to contain an entry for every
        state_dict with its name and the state_dict.
    :param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """
    # data/params can be None, so do not try to serialize them blindly.
    # Serialize *before* opening the destination: opening for writing
    # truncates an existing file, so a serialization error raised afterwards
    # would destroy a previously saved model.
    serialized_data = data_to_json(data) if data is not None else None

    file = open_path(save_path, "w", verbose=verbose, suffix="zip")
    try:
        # Create a zip-archive and write our objects there.
        with zipfile.ZipFile(file, mode="w") as archive:
            # Do not try to save "None" elements
            if serialized_data is not None:
                archive.writestr("data", serialized_data)
            if pytorch_variables is not None:
                with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
                    th.save(pytorch_variables, pytorch_variables_file)
            if params is not None:
                for file_name, dict_ in params.items():
                    with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
                        th.save(dict_, param_file)
            # Save metadata: library version when file was saved
            archive.writestr("_stable_baselines3_version", sb3.__version__)
            # Save system info about the current python env
            archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
    finally:
        # Close only the handle opened here (also when writing failed);
        # a stream provided by the caller stays under the caller's control.
        if isinstance(save_path, (str, pathlib.Path)):
            file.close()
def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None:
    """
    Save an object to path creating the necessary folders along the way.
    If the path exists and is a directory, it will raise a warning and rename the path.
    If a suffix is provided in the path, it will use that suffix, otherwise, it will use '.pkl'.

    :param path: the path to open.
        if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
        path actually exists. If path is a io.BufferedIOBase the path exists.
        A stream passed by the caller is written to but left open.
    :param obj: The object to save.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """
    file = open_path(path, "w", verbose=verbose, suffix="pkl")
    try:
        # Use protocol>=4 to support saving replay buffers >= 4Gb
        # See https://docs.python.org/3/library/pickle.html
        pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
    finally:
        # Close the handle opened here even if pickling failed; a stream
        # provided by the caller stays under the caller's control.
        if isinstance(path, (str, pathlib.Path)):
            file.close()
def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: int = 0) -> Any:
    """
    Load an object from the path. If a suffix is provided in the path, it will use that suffix.
    If the path does not exist, it will attempt to load using the .pkl suffix.

    NOTE(security): the content is restored with ``pickle.load``, which can
    execute arbitrary code. Only load files coming from a trusted source.

    :param path: the path to open.
        if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
        path actually exists. If path is a io.BufferedIOBase the path exists.
        A stream passed by the caller is read from but left open.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    :return: The unpickled object.
    """
    file = open_path(path, "r", verbose=verbose, suffix="pkl")
    try:
        return pickle.load(file)
    finally:
        # Close the handle opened here even if unpickling failed; a stream
        # provided by the caller stays under the caller's control.
        if isinstance(path, (str, pathlib.Path)):
            file.close()
def load_from_zip_file(
    load_path: Union[str, pathlib.Path, io.BufferedIOBase],
    load_data: bool = True,
    custom_objects: Optional[Dict[str, Any]] = None,
    device: Union[th.device, str] = "auto",
    verbose: int = 0,
    print_system_info: bool = False,
) -> Tuple[Optional[Dict[str, Any]], TensorDict, Optional[TensorDict]]:
    """
    Load model data from a .zip archive

    NOTE(security): the archive content is restored with ``json_to_data``
    (cloudpickle) and ``th.load(..., weights_only=False)``; both can execute
    arbitrary code. Only load archives coming from a trusted source.

    :param load_path: Where to load the model from.
        A stream passed by the caller is read from but left open.
    :param load_data: Whether we should load and return data
        (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
    :param custom_objects: Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        ``keras.models.load_model``. Useful when you have an object in
        file that can not be deserialized.
    :param device: Device on which the code should run.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    :param print_system_info: Whether to print or not the system info
        about the saved model.
    :return: Class parameters, model state_dicts (aka "params", dict of state_dict)
        and dict of pytorch variables
    :raises ValueError: if the file is not a valid zip archive.
    """
    file = open_path(load_path, "r", verbose=verbose, suffix="zip")

    # If data or parameters is not in the zip archive, assume they were
    # stored as None (save_to_zip_file allows this).
    data = None
    pytorch_variables = None
    params = {}

    # Open the zip archive and load data
    try:
        # Resolved inside the try block so that the file opened above is
        # closed by the `finally` clause if the device is rejected.
        device = get_device(device=device)

        with zipfile.ZipFile(file) as archive:
            namelist = archive.namelist()

            # Debug system info first
            if print_system_info:
                if "system_info.txt" in namelist:
                    print("== SAVED MODEL SYSTEM INFO ==")
                    print(archive.read("system_info.txt").decode())
                else:
                    warnings.warn(
                        "The model was saved with SB3 <= 1.2.0 and thus cannot print system information.",
                        UserWarning,
                    )

            if "data" in namelist and load_data:
                # Load class parameters that are stored
                # with either JSON or pickle (not PyTorch variables).
                json_data = archive.read("data").decode()
                data = json_to_data(json_data, custom_objects=custom_objects)

            # Check for all .pth files and load them using th.load.
            # "pytorch_variables.pth" stores PyTorch variables, and any other .pth
            # files store state_dicts of variables with custom names (e.g. policy, policy.optimizer)
            pth_files = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth"]
            for file_path in pth_files:
                with archive.open(file_path, mode="r") as param_file:
                    # File has to be seekable, but param_file is not, so load in BytesIO first
                    # fixed in python >= 3.7
                    file_content = io.BytesIO(param_file.read())
                    # Load the parameters with the right ``map_location``.
                    # Note(antonin): we cannot use weights_only=True, as it breaks with PyTorch 1.13, see GH#1911
                    th_object = th.load(file_content, map_location=device, weights_only=False)
                    # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138
                    if file_path in ("pytorch_variables.pth", "tensors.pth"):
                        # PyTorch variables (not state_dicts)
                        pytorch_variables = th_object
                    else:
                        # State dicts. Store into params dictionary
                        # with same name as in .zip file (without .pth)
                        params[os.path.splitext(file_path)[0]] = th_object
    except zipfile.BadZipFile as e:
        # load_path wasn't a zip file
        raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
    finally:
        # Close only the handle opened here; a stream provided by the caller
        # stays under the caller's control.
        if isinstance(load_path, (str, pathlib.Path)):
            file.close()
    return data, params, pytorch_variables