467 lines
21 KiB
Python
467 lines
21 KiB
Python
"""
|
|
Save utilities adapted from stable_baselines,
used to serialize the data (class parameters) of model classes.
|
|
"""
|
|
|
|
import base64
|
|
import functools
|
|
import io
|
|
import json
|
|
import os
|
|
import pathlib
|
|
import pickle
|
|
import warnings
|
|
import zipfile
|
|
from typing import Any, Dict, Optional, Tuple, Union
|
|
|
|
import cloudpickle
|
|
import torch as th
|
|
|
|
import stable_baselines3 as sb3
|
|
from stable_baselines3.common.type_aliases import TensorDict
|
|
from stable_baselines3.common.utils import get_device, get_system_info
|
|
|
|
|
|
def recursive_getattr(obj: Any, attr: str, *args) -> Any:
    """
    Resolve a dotted attribute path by applying ``getattr`` one segment at a time.

    Adapted from https://stackoverflow.com/questions/31174295

    Ex:
        > MyObject.sub_object = SubObject(name='test')
        > recursive_getattr(MyObject, 'sub_object.name')  # return test

    :param obj: Object the lookup starts from.
    :param attr: Attribute to retrieve; dots separate the nested attribute names.
    :param args: Optional default, forwarded unchanged to every ``getattr`` call.
    :return: The attribute
    """
    current = obj
    # Walk the path left to right; each step looks the next name up on the
    # result of the previous step.
    for name in attr.split("."):
        current = getattr(current, name, *args)
    return current
|
|
|
|
|
|
def recursive_setattr(obj: Any, attr: str, val: Any) -> None:
    """
    Assign a value to an attribute addressed by a dotted path.

    Adapted from https://stackoverflow.com/questions/31174295

    Ex:
        > MyObject.sub_object = SubObject(name='test')
        > recursive_setattr(MyObject, 'sub_object.name', 'hello')

    :param obj: Object the path is resolved from.
    :param attr: Attribute to set; everything before the last dot is looked up
        with ``recursive_getattr``, the last segment is the name that gets assigned.
    :param val: New value of the attribute
    """
    parent_path, _, leaf_name = attr.rpartition(".")
    target = obj
    if parent_path:
        # Only descend when there is something in front of the last dot.
        target = recursive_getattr(obj, parent_path)
    setattr(target, leaf_name, val)
|
|
|
|
|
|
def is_json_serializable(item: Any) -> bool:
    """
    Test if an object is serializable into JSON

    The check simply attempts ``json.dumps(item)`` and reports whether it
    succeeded; the produced string is discarded.

    :param item: The object to be tested for JSON serialization.
    :return: True if object is JSON serializable, false otherwise.
    """
    try:
        json.dumps(item)
    except (TypeError, ValueError, RecursionError):
        # TypeError: unsupported value or key type.
        # ValueError: circular reference detected by the encoder.
        # RecursionError: nesting too deep for the encoder.
        # In every case the object cannot be stored as plain JSON.
        return False
    return True
|
|
|
|
|
|
def data_to_json(data: Dict[str, Any]) -> str:
    """
    Turn data (class parameters) into a JSON string for storing

    Entries that ``is_json_serializable`` accepts are stored unchanged. Every
    other entry is replaced by a dictionary holding:

    - ``":type:"``: ``str(type(item))``, for humans / other languages,
    - ``":serialized:"``: the cloudpickle dump of the item, base64-encoded,
    - one informational entry per first-level member of the item (dict items,
      or ``__dict__`` items for objects that have one). Members that are not
      JSON serializable are stored as their ``str()`` representation.

    :param data: Dictionary of class parameters to be
        stored. Items that are not JSON serializable will be
        pickled with Cloudpickle and stored as bytearray in
        the JSON file
    :return: JSON string of the data serialized.
    """
    # Keys carrying the pickled payload; member names must never replace them.
    reserved_keys = (":type:", ":serialized:")
    # Key types json.dumps can write natively; anything else raises TypeError.
    json_key_types = (str, int, float, bool, type(None))

    serializable_data = {}
    for data_key, data_item in data.items():
        if is_json_serializable(data_item):
            # All good, store as it is
            serializable_data[data_key] = data_item
            continue

        # Not serializable: cloudpickle it into bytes and convert to a base64
        # string for storing. The type is stored as well so that a reader has
        # an idea of what was being stored.
        base64_encoded = base64.b64encode(cloudpickle.dumps(data_item)).decode()
        cloudpickle_serialization = {
            ":type:": str(type(data_item)),
            ":serialized:": base64_encoded,
        }

        # Add first-level members for further details (but not deeper than
        # this to avoid deep nesting). Not every object has attributes
        # (e.g. numpy scalars), hence the empty fallback.
        if isinstance(data_item, dict):
            members = data_item.items()
        elif hasattr(data_item, "__dict__"):
            members = data_item.__dict__.items()
        else:
            members = ()

        for variable_name, variable_item in members:
            # A dict may be non-serializable precisely because of its keys
            # (e.g. tuples); copying such a key verbatim would make the final
            # json.dumps below fail, so fall back to its string form.
            if not isinstance(variable_name, json_key_types):
                variable_name = str(variable_name)
            # Do not let a member overwrite the payload or the type tag.
            if variable_name in reserved_keys:
                continue
            if is_json_serializable(variable_item):
                cloudpickle_serialization[variable_name] = variable_item
            else:
                # Just include the string-representation of the object.
                cloudpickle_serialization[variable_name] = str(variable_item)

        serializable_data[data_key] = cloudpickle_serialization

    return json.dumps(serializable_data, indent=4)
|
|
|
|
|
|
def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Rebuild the class-parameter dictionary from its JSON serialization.

    :param json_string: JSON serialization of the class-parameters
        that should be loaded.
    :param custom_objects: Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        ``keras.models.load_model``. Useful when you have an object in
        file that can not be deserialized.
    :return: Loaded class parameters. An entry whose pickled payload fails to
        load with RuntimeError, TypeError or AttributeError is left out and a
        warning is emitted instead.
    """
    if not (custom_objects is None or isinstance(custom_objects, dict)):
        raise ValueError("custom_objects argument must be a dict or None")

    loaded = {}
    for data_key, data_item in json.loads(json_string).items():
        # 1) The caller supplied a replacement: use it and ignore the file content.
        if custom_objects is not None and data_key in custom_objects:
            loaded[data_key] = custom_objects[data_key]
            continue

        # 2) Plain JSON value: read as it is.
        if not (isinstance(data_item, dict) and ":serialized:" in data_item):
            loaded[data_key] = data_item
            continue

        # 3) Dictionary with a ":serialized:" key: base64-encoded cloudpickle payload.
        # NOTE: unpickling executes arbitrary code; only load files from trusted sources.
        payload = data_item[":serialized:"]
        try:
            restored = cloudpickle.loads(base64.b64decode(payload.encode()))
        except (RuntimeError, TypeError, AttributeError) as e:
            # Give the user a hint on how to work around the failure.
            warnings.warn(
                f"Could not deserialize object {data_key}. "
                "Consider using `custom_objects` argument to replace "
                "this object.\n"
                f"Exception: {e}"
            )
            continue
        loaded[data_key] = restored

    return loaded
|
|
|
|
|
|
@functools.singledispatch
def open_path(
    path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose: int = 0, suffix: Optional[str] = None
) -> Union[io.BufferedWriter, io.BufferedReader, io.BytesIO]:
    """
    Generic entry point to obtain a binary stream for reading or writing.

    This is the fallback of a ``functools.singledispatch`` family: it handles
    objects that are already binary streams. It only validates the stream
    (supported type, not closed, compatible with ``mode``) and hands it back
    unchanged. ``str`` and ``pathlib.Path`` inputs are handled by the
    implementations registered on this function.

    :param path: the path to open. In this implementation it must be an
        ``io.BufferedWriter``, ``io.BufferedReader``, ``io.BytesIO`` or
        ``io.BufferedRandom`` instance.
    :param mode: how to open the file. "w"|"write" for writing, "r"|"read" for reading
        (case-insensitive).
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages.
        Not used by this implementation.
    :param suffix: The preferred suffix. Not used by this implementation.
    :return: ``path`` itself, once validated.
    :raises TypeError: if ``path`` is not one of the supported stream types.
    :raises ValueError: if the stream is closed, ``mode`` is unknown, or the
        stream is not writable/readable as required by ``mode``.
    """
    # Note(antonin): the true annotation should be IO[bytes]
    # but there is not easy way to check that
    accepted_streams = (io.BufferedWriter, io.BufferedReader, io.BytesIO, io.BufferedRandom)
    if not isinstance(path, accepted_streams):
        raise TypeError(f"Path {path} parameter has invalid type: expected one of {accepted_streams}.")
    if path.closed:
        raise ValueError(f"File stream {path} is closed.")

    # Long and short spellings map to the single-letter form.
    mode_aliases = {"write": "w", "read": "r", "w": "w", "r": "r"}
    requested = mode.lower()
    try:
        mode = mode_aliases[requested]
    except KeyError as e:
        raise ValueError("Expected mode to be either 'w' or 'r'.") from e

    # The stream has to support the requested direction.
    if mode == "w":
        if not path.writable():
            raise ValueError("Expected a writable file.")
    elif not path.readable():
        raise ValueError("Expected a readable file.")
    return path
|
|
|
|
|
|
@open_path.register(str)
def open_path_str(path: str, mode: str, verbose: int = 0, suffix: Optional[str] = None) -> io.BufferedIOBase:
    """
    ``open_path`` implementation for paths given as a string.

    The string is wrapped in a ``pathlib.Path`` and all the work is delegated
    to ``open_path_pathlib``.

    :param path: the path to open, as a string.
    :param mode: how to open the file. "w" for writing, "r" for reading.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    :param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
        If mode is "r" then we attempt to open the path. If an error is raised and the suffix
        is not None, we attempt to open the path with the suffix.
    :return: Whatever ``open_path_pathlib`` returns for the converted path.
    """
    as_path = pathlib.Path(path)
    return open_path_pathlib(as_path, mode, verbose, suffix)
|
|
|
|
|
|
@open_path.register(pathlib.Path)
def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: Optional[str] = None) -> io.BufferedIOBase:
    """
    ``open_path`` implementation for ``pathlib.Path`` objects.

    Reading: the path is opened in binary mode. If it does not exist and a
    non-empty ``suffix`` is given, ``"{path}.{suffix}"`` is tried once;
    otherwise the ``FileNotFoundError`` is propagated.

    Writing: if the path has no suffix and a non-empty ``suffix`` is given, the
    suffix is appended. If the target is a folder, ``"{path}_2"`` is used
    instead; if the parent folder is missing, it is created. Both situations
    emit a warning and the open is retried.

    :param path: the path to check. If mode is "w" then it
        ensures that the path exists by creating the necessary folders and
        renaming path if it points to a folder.
    :param mode: how to open the file. "w"|"write" for writing, "r"|"read" for reading
        (case-insensitive, same spellings as the generic ``open_path``).
    :param verbose: Verbosity level: 0 for no output, 2 for indicating if path without suffix
        is not found when mode is "r", or that an existing file is overwritten when mode is "w".
    :param suffix: The preferred suffix. If mode is "w" then the opened file has the suffix.
        If mode is "r" then we attempt to open the path. If an error is raised and the suffix
        is not None, we attempt to open the path with the suffix.
    :return: The opened binary stream, validated by ``open_path``.
    :raises ValueError: if ``mode`` is not one of the accepted spellings.
    :raises FileNotFoundError: when reading a missing path and no suffix fallback is left.
    """
    # Accept the same spellings as the stream implementation of ``open_path``,
    # so open_path("file.zip", "write") behaves like open_path(stream, "write").
    mode_aliases = {"write": "w", "read": "r", "w": "w", "r": "r"}
    normalized_mode = mode_aliases.get(mode.lower()) if isinstance(mode, str) else None
    if normalized_mode is None:
        raise ValueError("Expected mode to be either 'w' or 'r'.")
    mode = normalized_mode

    if mode == "r":
        try:
            return open_path(path.open("rb"), mode, verbose, suffix)
        except FileNotFoundError:
            if suffix is None or suffix == "":
                # Nothing left to try
                raise
            newpath = pathlib.Path(f"{path}.{suffix}")
            if verbose >= 2:
                warnings.warn(f"Path '{path}' not found. Attempting {newpath}.")
            # Clear the suffix so the retry below cannot loop again
            path, suffix = newpath, None
    else:
        try:
            if path.suffix == "" and suffix is not None and suffix != "":
                path = pathlib.Path(f"{path}.{suffix}")
            if path.exists() and path.is_file() and verbose >= 2:
                warnings.warn(f"Path '{path}' exists, will overwrite it.")
            return open_path(path.open("wb"), mode, verbose, suffix)
        except IsADirectoryError:
            warnings.warn(f"Path '{path}' is a folder. Will save instead to {path}_2")
            path = pathlib.Path(f"{path}_2")
        except FileNotFoundError:  # Occurs when the parent folder doesn't exist
            warnings.warn(f"Path '{path.parent}' does not exist. Will create it.")
            path.parent.mkdir(exist_ok=True, parents=True)

    # Reaching this point means the first attempt failed in a recoverable way
    # (successful opens returned above): retry with the corrected path/suffix.
    return open_path_pathlib(path, mode, verbose, suffix)
|
|
|
|
|
|
def save_to_zip_file(
    save_path: Union[str, pathlib.Path, io.BufferedIOBase],
    data: Optional[Dict[str, Any]] = None,
    params: Optional[Dict[str, Any]] = None,
    pytorch_variables: Optional[Dict[str, Any]] = None,
    verbose: int = 0,
) -> None:
    """
    Save model data to a zip archive.

    Archive layout, as written below:

    - ``data``: JSON produced by ``data_to_json`` (only if ``data`` is given),
    - ``pytorch_variables.pth``: ``th.save`` of ``pytorch_variables`` (only if given),
    - ``<name>.pth``: one ``th.save`` per entry of ``params`` (only if given),
    - ``_stable_baselines3_version``: the library version,
    - ``system_info.txt``: second element returned by ``get_system_info(print_info=False)``.

    :param save_path: Where to store the model.
        if save_path is a str or pathlib.Path ensures that the path actually exists.
        A stream passed by the caller is written to but not closed.
    :param data: Class parameters being stored (non-PyTorch variables)
    :param params: Model parameters being stored expected to contain an entry for every
        state_dict with its name and the state_dict.
    :param pytorch_variables: Other PyTorch variables expected to contain name and value of the variable.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages.
        Forwarded to ``open_path``.
    """
    # data/params can be None, so do not try to serialize them blindly.
    # Serialize before opening the destination: a serialization failure then
    # leaves an existing file untouched instead of truncated.
    serialized_data = data_to_json(data) if data is not None else None

    file = open_path(save_path, "w", verbose=verbose, suffix="zip")
    try:
        # Create a zip-archive and write our objects there.
        with zipfile.ZipFile(file, mode="w") as archive:
            # Do not try to save "None" elements
            if serialized_data is not None:
                archive.writestr("data", serialized_data)
            if pytorch_variables is not None:
                with archive.open("pytorch_variables.pth", mode="w", force_zip64=True) as pytorch_variables_file:
                    th.save(pytorch_variables, pytorch_variables_file)
            if params is not None:
                for file_name, dict_ in params.items():
                    with archive.open(file_name + ".pth", mode="w", force_zip64=True) as param_file:
                        th.save(dict_, param_file)
            # Save metadata: library version when file was saved
            archive.writestr("_stable_baselines3_version", sb3.__version__)
            # Save system info about the current python env
            archive.writestr("system_info.txt", get_system_info(print_info=False)[1])
    finally:
        # Only close what this function opened, and do so even when writing failed.
        if isinstance(save_path, (str, pathlib.Path)):
            file.close()
|
|
|
|
|
|
def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj: Any, verbose: int = 0) -> None:
    """
    Save an object to path creating the necessary folders along the way.
    If the path exists and is a directory, it will raise a warning and rename the path.
    If a suffix is provided in the path, it will use that suffix, otherwise, it will use '.pkl'.

    :param path: the path to open.
        if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
        path actually exists. If path is a io.BufferedIOBase the path exists.
        A stream passed by the caller is written to but not closed.
    :param obj: The object to save.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    """
    file = open_path(path, "w", verbose=verbose, suffix="pkl")
    try:
        # Use protocol>=4 to support saving replay buffers >= 4Gb
        # See https://docs.python.org/3/library/pickle.html
        pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
    finally:
        # Close the handle we opened even if pickling fails part-way.
        if isinstance(path, (str, pathlib.Path)):
            file.close()
|
|
|
|
|
|
def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: int = 0) -> Any:
    """
    Load an object from the path. If a suffix is provided in the path, it will use that suffix.
    If the path does not exist, it will attempt to load using the .pkl suffix.

    :param path: the path to open.
        if save_path is a str or pathlib.Path and mode is "w", single dispatch ensures that the
        path actually exists. If path is a io.BufferedIOBase the path exists.
        A stream passed by the caller is read from but not closed.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    :return: The unpickled object.
    """
    file = open_path(path, "r", verbose=verbose, suffix="pkl")
    try:
        # NOTE: unpickling executes arbitrary code; only load files from trusted sources.
        return pickle.load(file)
    finally:
        # Close the handle we opened even if unpickling fails.
        if isinstance(path, (str, pathlib.Path)):
            file.close()
|
|
|
|
|
|
def load_from_zip_file(
    load_path: Union[str, pathlib.Path, io.BufferedIOBase],
    load_data: bool = True,
    custom_objects: Optional[Dict[str, Any]] = None,
    device: Union[th.device, str] = "auto",
    verbose: int = 0,
    print_system_info: bool = False,
) -> Tuple[Optional[Dict[str, Any]], TensorDict, Optional[TensorDict]]:
    """
    Load model data from a .zip archive

    :param load_path: Where to load the model from
    :param load_data: Whether we should load and return data
        (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
    :param custom_objects: Dictionary of objects to replace
        upon loading. If a variable is present in this dictionary as a
        key, it will not be deserialized and the corresponding item
        will be used instead. Similar to custom_objects in
        ``keras.models.load_model``. Useful when you have an object in
        file that can not be deserialized.
    :param device: Device on which the code should run.
    :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
    :param print_system_info: Whether to print or not the system info
        about the saved model.
    :return: Class parameters, model state_dicts (aka "params", dict of state_dict)
        and dict of pytorch variables. Class parameters and pytorch variables
        are ``None`` when not loaded / not present in the archive.
    :raises ValueError: if the file is not a valid zip archive.
    """
    file = open_path(load_path, "r", verbose=verbose, suffix="zip")

    # If data or parameters is not in the zip archive, assume they were
    # stored as None (_save_to_file_zip allows this).
    data = None
    pytorch_variables = None
    params = {}

    # Open the zip archive and load data
    try:
        # Resolved inside the try block so that the file opened above is
        # released by the ``finally`` clause if the device is rejected.
        device = get_device(device=device)

        with zipfile.ZipFile(file) as archive:
            namelist = archive.namelist()

            # Debug system info first
            if print_system_info:
                if "system_info.txt" in namelist:
                    print("== SAVED MODEL SYSTEM INFO ==")
                    print(archive.read("system_info.txt").decode())
                else:
                    warnings.warn(
                        "The model was saved with SB3 <= 1.2.0 and thus cannot print system information.",
                        UserWarning,
                    )

            if "data" in namelist and load_data:
                # Load class parameters that are stored
                # with either JSON or pickle (not PyTorch variables).
                json_data = archive.read("data").decode()
                data = json_to_data(json_data, custom_objects=custom_objects)

            # Check for all .pth files and load them using th.load.
            # "pytorch_variables.pth" stores PyTorch variables, and any other .pth
            # files store state_dicts of variables with custom names (e.g. policy, policy.optimizer)
            pth_files = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth"]
            for file_path in pth_files:
                with archive.open(file_path, mode="r") as param_file:
                    # Stage the archive member in an in-memory buffer before
                    # handing it to th.load (the original code did so because
                    # the loader needs a seekable file).
                    file_content = io.BytesIO(param_file.read())
                    # Load the parameters with the right ``map_location``.
                    # Note(antonin): we cannot use weights_only=True, as it breaks with PyTorch 1.13, see GH#1911
                    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
                    th_object = th.load(file_content, map_location=device, weights_only=False)
                    # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138
                    if file_path == "pytorch_variables.pth" or file_path == "tensors.pth":
                        # PyTorch variables (not state_dicts)
                        pytorch_variables = th_object
                    else:
                        # State dicts. Store into params dictionary
                        # with same name as in .zip file (without .pth)
                        params[os.path.splitext(file_path)[0]] = th_object
    except zipfile.BadZipFile as e:
        # load_path wasn't a zip file
        raise ValueError(f"Error: the file {load_path} wasn't a zip-file") from e
    finally:
        # Only close what this function opened.
        if isinstance(load_path, (str, pathlib.Path)):
            file.close()
    return data, params, pytorch_variables
|