# mypy: allow-untyped-defs
import copyreg
import difflib
import functools
import io
import os
import pickle
import re
import shutil
import struct
import sys
import tarfile
import tempfile
import threading
import warnings
from contextlib import closing, contextmanager
from enum import Enum
from typing import (
    Any,
    BinaryIO,
    Callable,
    cast,
    Dict,
    IO,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)
from typing_extensions import TypeAlias, TypeGuard  # Python 3.10+

import torch
import torch._weights_only_unpickler as _weights_only_unpickler
from torch._sources import get_source_lines_and_file
from torch._utils import _import_dotted_name
from torch.storage import _get_dtype_from_pickle_storage_type
from torch.types import Storage


__all__ = [
    "SourceChangeWarning",
    "mkdtemp",
    "register_package",
    "check_module_version_greater_or_equal",
    "validate_cuda_device",
    "validate_hpu_device",
    "location_tag",
    "default_restore_location",
    "normalize_storage_type",
    "storage_to_tensor_type",
    "save",
    "load",
    "StorageType",
    "LoadEndianness",
    "get_default_load_endianness",
    "set_default_load_endianness",
    "get_default_mmap_options",
    "set_default_mmap_options",
    "clear_safe_globals",
    "get_safe_globals",
    "add_safe_globals",
    "safe_globals",
    "skip_data",
]


DEFAULT_PROTOCOL = 2

LONG_SIZE = struct.Struct("=l").size
INT_SIZE = struct.Struct("=i").size
SHORT_SIZE = struct.Struct("=h").size

MAGIC_NUMBER = 0x1950A86A20F9469CFC6C
PROTOCOL_VERSION = 1001
STORAGE_KEY_SEPARATOR = ","

FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
MAP_LOCATION: TypeAlias = Optional[
    Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]
]
STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]

IS_WINDOWS = sys.platform == "win32"

if not IS_WINDOWS:
    from mmap import MAP_PRIVATE, MAP_SHARED
else:
    MAP_SHARED, MAP_PRIVATE = None, None  # type: ignore[assignment]


# _serialization_tls is used to store thread local state specific to serialization
# that needs to be propagated to other files, in particular we use this for
# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils)
# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
class _SerializationLocal(threading.local):
    def __init__(self):
        super().__init__()
        self.map_location: Optional[MAP_LOCATION] = None
        self.skip_data: bool = False
        self.materialize_fake_tensors: bool = False


_serialization_tls = _SerializationLocal()


class SourceChangeWarning(Warning):
    pass


@contextmanager
def mkdtemp():
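    """Context manager that creates a temporary directory and removes it on exit.

    A minimal usage sketch (the directory name is system-chosen):

    Example:
        >>> import os
        >>> with torch.serialization.mkdtemp() as path:
        ...     assert os.path.isdir(path)
        >>> os.path.isdir(path)
        False
    """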
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        shutil.rmtree(path)


_package_registry: List[
    Tuple[
        int,
        Callable[[STORAGE], Optional[str]],
        Callable[[STORAGE, str], Optional[STORAGE]],
    ]
] = []


class LoadEndianness(Enum):
    NATIVE = 1
    LITTLE = 2
    BIG = 3


_default_load_endian: Optional[LoadEndianness] = None


def get_default_load_endianness() -> Optional[LoadEndianness]:
    """
    Get fallback byte order for loading files

    If byteorder mark is not present in saved checkpoint,
    this byte order is used as fallback.
    By default, it's "native" byte order.

    Returns:
        default_load_endian: Optional[LoadEndianness]
    """
    return _default_load_endian


def set_default_load_endianness(endianness):
    """
    Set fallback byte order for loading files

    If byteorder mark is not present in saved checkpoint,
    this byte order is used as fallback.
    By default, it's "native" byte order.

    Args:
        endianness: the new fallback byte order
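
    Example:
        >>> # A minimal sketch: set a little-endian fallback, then restore the default
        >>> from torch.serialization import LoadEndianness
        >>> torch.serialization.set_default_load_endianness(LoadEndianness.LITTLE)
        >>> torch.serialization.get_default_load_endianness()
        <LoadEndianness.LITTLE: 2>
        >>> torch.serialization.set_default_load_endianness(None)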
    """
    global _default_load_endian
    if not isinstance(endianness, LoadEndianness) and endianness is not None:
        raise TypeError("Invalid argument type in function set_default_load_endianness")
    _default_load_endian = endianness


_default_mmap_options: int = MAP_PRIVATE


def get_default_mmap_options() -> int:
    """
    Get default mmap options for :func:`torch.load` with ``mmap=True``.

    Defaults to ``mmap.MAP_PRIVATE``.

    Returns:
        default_mmap_options: int
    """
    return _default_mmap_options


class set_default_mmap_options:
    """
    Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.

    For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported.
    Please open an issue if you need any other option to be added here.

    .. note::
        This feature is currently not supported for Windows.

    Args:
        flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``
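
    Example:
        >>> # xdoctest: +SKIP("requires a checkpoint saved on disk")
        >>> # A minimal sketch; "checkpoint.pt" is a hypothetical existing file
        >>> import mmap
        >>> with torch.serialization.set_default_mmap_options(mmap.MAP_SHARED):
        ...     state = torch.load("checkpoint.pt", mmap=True, weights_only=True)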
    """

    def __init__(self, flags: int) -> None:
        if IS_WINDOWS:
            raise RuntimeError(
                "Changing the default mmap options is currently not supported for Windows"
            )
        if flags != MAP_PRIVATE and flags != MAP_SHARED:
            raise ValueError(
                "Invalid argument in function set_default_mmap_options, "
                f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}"
            )
        global _default_mmap_options
        self.prev = _default_mmap_options
        _default_mmap_options = flags

    def __enter__(self) -> None:
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        global _default_mmap_options
        _default_mmap_options = self.prev


def clear_safe_globals() -> None:
    """
    Clears the list of globals that are safe for ``weights_only`` load.
    """
    _weights_only_unpickler._clear_safe_globals()


def get_safe_globals() -> List[Any]:
    """
    Returns the list of user-added globals that are safe for ``weights_only`` load.
    """
    return _weights_only_unpickler._get_safe_globals()


def add_safe_globals(safe_globals: List[Any]) -> None:
    """
    Marks the given globals as safe for ``weights_only`` load. For example, functions
    added to this list can be called during unpickling, classes could be instantiated
    and have state set.

    Args:
        safe_globals (List[Any]): list of globals to mark as safe

    Example:
        >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
        >>> import tempfile
        >>> class MyTensor(torch.Tensor):
        ...     pass
        >>> t = MyTensor(torch.randn(2, 3))
        >>> with tempfile.NamedTemporaryFile() as f:
        ...     torch.save(t, f.name)
        # Running `torch.load(f.name, weights_only=True)` will fail with
        # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
        # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
        ...     torch.serialization.add_safe_globals([MyTensor])
        ...     torch.load(f.name, weights_only=True)
        # MyTensor([[-0.5024, -1.8152, -0.5455],
        #          [-0.8234,  2.0500, -0.3657]])
    """
    _weights_only_unpickler._add_safe_globals(safe_globals)


class safe_globals(_weights_only_unpickler._safe_globals):
    r"""Context-manager that adds certain globals as safe for ``weights_only`` load.

    Args:
        safe_globals: List of globals for weights_only load.

    Example:
        >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
        >>> import tempfile
        >>> class MyTensor(torch.Tensor):
        ...     pass
        >>> t = MyTensor(torch.randn(2, 3))
        >>> with tempfile.NamedTemporaryFile() as f:
        ...     torch.save(t, f.name)
        # Running `torch.load(f.name, weights_only=True)` will fail with
        # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
        # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
        ...     with torch.serialization.safe_globals([MyTensor]):
        ...         torch.load(f.name, weights_only=True)
        # MyTensor([[-0.5024, -1.8152, -0.5455],
        #          [-0.8234,  2.0500, -0.3657]])
        >>> assert torch.serialization.get_safe_globals() == []
    """


class skip_data:
    """
    Context-manager that skips writing storage bytes for ``torch.save`` calls.

    Storages will still be saved, but the space that their bytes would usually be written to
    will be empty space. The storage bytes can then be populated in a separate pass.

    .. warning::
        The ``skip_data`` context manager is an early prototype and is subject to change.

    Args:
        materialize_fake_tensors: Whether to materialize FakeTensors.

    Example:
        >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows")
        >>> import tempfile
        >>> t = torch.randn(2, 3)
        >>> with tempfile.NamedTemporaryFile() as f:
        ...     with torch.serialization.skip_data():
        ...         torch.save(t, f.name)
        ...     torch.load(f.name, weights_only=True)
        tensor([[0., 0., 0.],
                [0., 0., 0.]])
    """

    def __init__(self, materialize_fake_tensors: bool = False):
        self.materialize_fake_tensors = materialize_fake_tensors

    def __enter__(self):
        global _serialization_tls
        self._old_skip_data = _serialization_tls.skip_data
        self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors
        _serialization_tls.skip_data = True
        _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors

    def __exit__(self, type, value, tb):
        global _serialization_tls
        _serialization_tls.skip_data = self._old_skip_data
        _serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors


def _is_zipfile(f) -> bool:
    # This is a stricter implementation than zipfile.is_zipfile().
    # zipfile.is_zipfile() is True if the magic number appears anywhere in the
    # binary. Since we expect the files here to be generated by torch.save or
    # torch.jit.save, it's safe to check only the start bytes; this avoids
    # collisions (and we assume the zip contains only one file).
    # See bugs.python.org/issue28494.

    start = f.tell()
    # Read the first few bytes and match against the ZIP file signature
    local_header_magic_number = b"PK\x03\x04"
    read_bytes = f.read(len(local_header_magic_number))
    f.seek(start)
    return read_bytes == local_header_magic_number


def register_package(
    priority: int,
    tagger: Callable[[STORAGE], Optional[str]],
    deserializer: Callable[[STORAGE, str], Optional[STORAGE]],
):
    """
    Registers callables for tagging and deserializing storage objects with an associated priority.
    Tagging associates a device with a storage object at save time while deserializing moves a
    storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
    are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
    value that is not `None`.

    To override the deserialization behavior for a device in the global registry, one can register a
    tagger with a higher priority than the existing tagger.

    This function can also be used to register a tagger and deserializer for new devices.

    Args:
        priority: Indicates the priority associated with the tagger and deserializer, where a lower
            value indicates higher priority.
        tagger: Callable that takes in a storage object and returns its tagged device as a string
            or None.
        deserializer: Callable that takes in storage object and a device string and returns a storage
            object on the appropriate device or None.

    Returns:
        `None`

    Example:
        >>> def ipu_tag(obj):
        >>>     if obj.device.type == 'ipu':
        >>>         return 'ipu'
        >>> def ipu_deserialize(obj, location):
        >>>     if location.startswith('ipu'):
        >>>         ipu = getattr(torch, "ipu", None)
        >>>         assert ipu is not None, "IPU device module is not loaded"
        >>>         assert torch.ipu.is_available(), "ipu is not available"
        >>>         return obj.ipu(location)
        >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
    """
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort()


def check_module_version_greater_or_equal(
    module,
    req_version_tuple,
    error_if_malformed=True,
):
    """
    Check if a module's version satisfies requirements

    Usually, a module's version string will be like 'x.y.z', which would be represented
    as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
    string does not match the given tuple's format up to the length of the tuple,
    raise an error or emit a warning, depending on ``error_if_malformed``.

    Args:
        module: the module to check the version of
        req_version_tuple: tuple (usually of ints) representing the required version
        error_if_malformed: whether to raise an error if the module version string is malformed

    Returns:
        requirement_is_met: bool
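
    Example:
        >>> # A minimal sketch: torch's own version is well past 1.0
        >>> torch.serialization.check_module_version_greater_or_equal(torch, (1, 0))
        True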
    """
    try:
        version_strs = module.__version__.split(".")
        # Cast module version fields to match the types of the required version
        module_version = tuple(
            type(req_field)(version_strs[idx])
            for idx, req_field in enumerate(req_version_tuple)
        )
        requirement_is_met = module_version >= req_version_tuple

    except Exception as e:
        message = (
            f"'{module.__name__}' module version string is malformed '{module.__version__}' and cannot be compared"
            f" with tuple {str(req_version_tuple)}"
        )
        if error_if_malformed:
            raise RuntimeError(message) from e
        else:
            warnings.warn(message + ", but continuing assuming that requirement is met")
            requirement_is_met = True

    return requirement_is_met


def _cpu_tag(obj):
    if obj.device.type == "cpu":
        return "cpu"


def _mps_tag(obj):
    if obj.device.type == "mps":
        return "mps"


def _meta_tag(obj):
    if obj.device.type == "meta":
        return "meta"


def _backend_tag(backend_name, obj):
    if backend_name == "privateuse1":
        backend_name = torch._C._get_privateuse1_backend_name()
    if obj.device.type == backend_name:
        if obj.device.index is None:
            return backend_name
        else:
            return backend_name + ":" + str(obj.device.index)


def _cpu_deserialize(obj, location):
    if location == "cpu":
        return obj


def _mps_deserialize(obj, location):
    if location.startswith("mps"):
        return obj.mps()


def _meta_deserialize(obj, location):
    if location == "meta":
        return torch.UntypedStorage(obj.nbytes(), device="meta")


def _validate_device(location, backend_name):
    """
    Check whether the device index of specified backend is valid

    In case of the privateuse1 backend, you must first register a device_module for
    privateuse1 using torch._register_device_module. Implement the following
    methods in device_module like cuda: device_module._utils._get_device_index(location, True),
    device_module.device_count().

    Args:
        location: string of device
        backend_name: the backend name or the name of privateuse1, which can be renamed

    Returns:
        device_index: int
    """
    if not hasattr(torch, backend_name):
        raise RuntimeError(
            f"The {backend_name.upper()} device module is not registered. "
            "If you are running on a CPU-only machine, "
            "please use torch.load with map_location=torch.device('cpu') "
            "to map your storages to the CPU."
        )
    device_module = getattr(torch, backend_name)
    if hasattr(device_module, "_utils") and hasattr(
        device_module._utils, "_get_device_index"
    ):
        device_index = device_module._utils._get_device_index(location, True)
        device = torch.device(backend_name, device_index)
    else:
        device = torch.device(location)
        device_index = device.index if device.index else 0
    if hasattr(device_module, "is_available") and not device_module.is_available():
        raise RuntimeError(
            f"Attempting to deserialize object on a {backend_name.upper()} "
            f"device but torch.{backend_name}.is_available() is False. "
            "If you are running on a CPU-only machine, "
            "please use torch.load with map_location=torch.device('cpu') "
            "to map your storages to the CPU."
        )
    if hasattr(device_module, "device_count"):
        device_count = device_module.device_count()
        if device_index >= device_count:
            raise RuntimeError(
                f"Attempting to deserialize object on {backend_name.upper()} device "
                f"{device_index} but torch.{backend_name}.device_count() is {device_count}. "
                "Please use torch.load with map_location to map your storages "
                "to an existing device."
            )
    return device


def validate_cuda_device(location):
    return _validate_device(location, "cuda").index


def validate_hpu_device(location):
    return _validate_device(location, "hpu").index


def _deserialize(backend_name, obj, location):
    if backend_name == "privateuse1":
        backend_name = torch._C._get_privateuse1_backend_name()
    if location.startswith(backend_name):
        device = _validate_device(location, backend_name)
        return obj.to(device=device)


register_package(10, _cpu_tag, _cpu_deserialize)
register_package(
    20,
    functools.partial(_backend_tag, "cuda"),
    functools.partial(_deserialize, "cuda"),
)
register_package(21, _mps_tag, _mps_deserialize)
register_package(22, _meta_tag, _meta_deserialize)
register_package(
    23,
    functools.partial(_backend_tag, "privateuse1"),
    functools.partial(_deserialize, "privateuse1"),
)
register_package(
    24,
    functools.partial(_backend_tag, "hpu"),
    functools.partial(_deserialize, "hpu"),
)
register_package(
    25,
    functools.partial(_backend_tag, "xpu"),
    functools.partial(_deserialize, "xpu"),
)


def location_tag(
    storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage],
):
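    """Return the location tag for ``storage`` (e.g. ``"cpu"`` or ``"cuda:0"``).

    Registered taggers are tried in priority order; the first non-``None`` tag wins.

    Example:
        >>> # A minimal sketch: a freshly allocated storage lives on the CPU
        >>> torch.serialization.location_tag(torch.UntypedStorage(4))
        'cpu'
    """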
    for _, tagger, _ in _package_registry:
        location = tagger(storage)
        if location:
            return location
    raise RuntimeError(
        "don't know how to determine data location of " + torch.typename(storage)
    )


def default_restore_location(storage, location):
    """
    Restores `storage` using a deserializer function registered for the `location`.

    This function looks in the registry for deserializer functions that match the `location`.
    If found, it attempts to use them, in priority order, to restore `storage` until one
    returns a result that is not `None`. If no deserializer can be found in the registry,
    or if all matching deserializers return `None`, it raises a `RuntimeError`.

    Args:
        storage (STORAGE): the storage object to restore
        location (str): the location tag associated with the storage object

    Returns:
        storage: Optional[STORAGE]

    Raises:
        RuntimeError: If no deserializer matching `location` is found in the registry or if
        all matching ones return `None`.
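
    Example:
        >>> # A minimal sketch: a storage tagged "cpu" is returned unchanged
        >>> s = torch.UntypedStorage(4)
        >>> torch.serialization.default_restore_location(s, "cpu") is s
        True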
    """
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError(
        "don't know how to restore data location of "
        + torch.typename(storage)
        + " (tagged with "
        + location
        + ")"
    )


def normalize_storage_type(storage_type):
    return getattr(torch, storage_type.__name__)


def storage_to_tensor_type(storage):
    storage_type = type(storage)
    module = _import_dotted_name(storage_type.__module__)
    return getattr(module, storage_type.__name__.replace("Storage", "Tensor"))


def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
    return isinstance(name_or_buffer, (str, os.PathLike))


class _opener:
    def __init__(self, file_like):
        self.file_like = file_like

    def __enter__(self):
        return self.file_like

    def __exit__(self, *args):
        pass


class _open_file(_opener):
    def __init__(self, name, mode):
        super().__init__(open(name, mode))

    def __exit__(self, *args):
        self.file_like.close()


class _open_buffer_reader(_opener):
    def __init__(self, buffer):
        super().__init__(buffer)
        _check_seekable(buffer)


class _open_buffer_writer(_opener):
    def __exit__(self, *args):
        self.file_like.flush()


def _open_file_like(name_or_buffer, mode):
    if _is_path(name_or_buffer):
        return _open_file(name_or_buffer, mode)
    else:
        if "w" in mode:
            return _open_buffer_writer(name_or_buffer)
        elif "r" in mode:
            return _open_buffer_reader(name_or_buffer)
        else:
            raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")


class _open_zipfile_reader(_opener):
    def __init__(self, name_or_buffer) -> None:
        super().__init__(torch._C.PyTorchFileReader(name_or_buffer))


class _open_zipfile_writer_file(_opener):
    def __init__(self, name) -> None:
        self.file_stream = None
        self.name = str(name)
        try:
            self.name.encode("ascii")
        except UnicodeEncodeError:
            # PyTorchFileWriter only supports ASCII filenames.
            # For filenames with non-ASCII characters, we rely on Python
            # for writing out the file.
            self.file_stream = io.FileIO(self.name, mode="w")
            super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
        else:
            super().__init__(torch._C.PyTorchFileWriter(self.name))

    def __exit__(self, *args) -> None:
        self.file_like.write_end_of_file()
        if self.file_stream is not None:
            self.file_stream.close()


class _open_zipfile_writer_buffer(_opener):
    def __init__(self, buffer) -> None:
        if not callable(getattr(buffer, "write", None)):
            msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
            if not hasattr(buffer, "write"):
                raise AttributeError(msg)
            raise TypeError(msg)
        self.buffer = buffer
        super().__init__(torch._C.PyTorchFileWriter(buffer))

    def __exit__(self, *args) -> None:
        self.file_like.write_end_of_file()
        self.buffer.flush()


def _open_zipfile_writer(name_or_buffer):
    container: Type[_opener]
    if _is_path(name_or_buffer):
        container = _open_zipfile_writer_file
    else:
        container = _open_zipfile_writer_buffer
    return container(name_or_buffer)


def _is_compressed_file(f) -> bool:
    compress_modules = ["gzip"]
    try:
        return f.__module__ in compress_modules
    except AttributeError:
        return False


def _should_read_directly(f):
    """
    Checks if f is a file that should be read directly. It should be read
    directly if it is backed by a real file (has a fileno) and is not a
    compressed file (e.g. gzip)
    """
    if _is_compressed_file(f):
        return False
    try:
        return f.fileno() >= 0
    except io.UnsupportedOperation:
        return False
    except AttributeError:
        return False


def _check_seekable(f) -> bool:
    def raise_err_msg(patterns, e):
        for p in patterns:
            if p in str(e):
                msg = (
                    str(e)
                    + ". You can only torch.load from a file that is seekable."
                    + " Please pre-load the data into a buffer like io.BytesIO and"
                    + " try to load from it instead."
                )
                raise type(e)(msg)
        raise e

    try:
        f.seek(f.tell())
        return True
    except (io.UnsupportedOperation, AttributeError) as e:
        raise_err_msg(["seek", "tell"], e)
    return False


def _check_dill_version(pickle_module) -> None:
    """Checks if using dill as the pickle module, and if so, checks if it is the correct version.
    If dill version is lower than 0.3.1, a ValueError is raised.

    Args:
        pickle_module: module used for pickling metadata and objects

    """
    if pickle_module is not None and pickle_module.__name__ == "dill":
        required_dill_version = (0, 3, 1)
        if not check_module_version_greater_or_equal(
            pickle_module, required_dill_version, False
        ):
            raise ValueError(
                (
                    "'torch' supports dill >= {}, but you have dill {}."
                    " Please upgrade dill or switch to 'pickle'"
                ).format(
                    ".".join([str(num) for num in required_dill_version]),
                    pickle_module.__version__,
                )
            )


def _check_save_filelike(f):
    if not _is_path(f) and not hasattr(f, "write"):
        raise AttributeError(
            "expected 'f' to be string, path, or a file-like object with "
            "a 'write' attribute"
        )


def save(
    obj: object,
    f: FILE_LIKE,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
    _use_new_zipfile_serialization: bool = True,
    _disable_byteorder_record: bool = False,
) -> None:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path'>`).

    """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

    Saves an object to a disk file.

    See also: :ref:`saving-loading-tensors`

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string or
           os.PathLike object containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol

    .. note::
        A common PyTorch convention is to save tensors using .pt file extension.

    .. note::
        PyTorch preserves storage sharing across serialization. See
        :ref:`preserve-storage-sharing` for more details.

    .. note::
        The 1.6 release of PyTorch switched ``torch.save`` to use a new
        zipfile-based file format. ``torch.load`` still retains the ability to
        load files in the old format. If for any reason you want ``torch.save``
        to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.

    Example:
        >>> # xdoctest: +SKIP("makes cwd dirty")
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, "tensor.pt")
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
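        >>> # Save in the legacy (pre-1.6) format mentioned in the note above;
        >>> # "tensor_legacy.pt" is just an illustrative file name
        >>> torch.save(x, "tensor_legacy.pt", _use_new_zipfile_serialization=False)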
    """
    torch._C._log_api_usage_once("torch.save")
    _check_dill_version(pickle_module)
    _check_save_filelike(f)

    if _use_new_zipfile_serialization:
        with _open_zipfile_writer(f) as opened_zipfile:
            _save(
                obj,
                opened_zipfile,
                pickle_module,
                pickle_protocol,
                _disable_byteorder_record,
            )
            return
    else:
        global _serialization_tls
        if _serialization_tls.skip_data:
            raise RuntimeError(
                "Cannot use skip_data=True with _use_new_zipfile_serialization=False"
            )
        with _open_file_like(f, "wb") as opened_file:
            _legacy_save(obj, opened_file, pickle_module, pickle_protocol)


def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    import torch.nn as nn

    serialized_container_types = {}
    serialized_storages: Dict[str, Tuple[torch.UntypedStorage, torch.dtype]] = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj: Any) -> Optional[Tuple]:
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)
                source = "".join(source_lines)
            except (
                Exception
            ):  # saving the source is optional, so we can ignore any errors
                warnings.warn(
                    "Couldn't retrieve source code for container of "
                    "type " + obj.__name__ + ". It won't be checked "
                    "for correctness upon loading."
                )
            return ("module", obj, source_file, source)

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            storage: torch.UntypedStorage

            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                dtype = obj.dtype
                storage_numel = obj._size()

            elif isinstance(obj, torch.UntypedStorage):
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                dtype = torch.uint8
                storage_numel = storage.nbytes()
            else:
                raise TypeError(f"type not recognized: {type(obj)}")

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that "
                            "view the same data as different types"
                        )
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            view_metadata: Optional[Tuple[str, int, int]]

            # Offset is always 0, but we keep it for backwards compatibility
            # with the old serialization format (which supported storage views)
            offset = 0
            storage_key = str(storage._cdata)
            location = location_tag(storage)

            # TODO: There's an issue here with FC. It might be impossible to
            # solve, but it's worth noting. Imagine we save a list `[storage,
            # tensor]`, where `tensor.storage()` is the same as `storage`, and
            # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
            # torch.float`. The storage will be serialized with element size
            # of 1, since we're choosing to serialize the first occurrence of
            # a duplicate storage. Since this legacy serialization format saves
            # the numel of the storage, rather than nbytes directly, we'll be
            # effectively saving nbytes in this case. We'll be able to load it
            # and the tensor back up with no problems in _this_ and future
            # versions of pytorch, but in older versions, here's the problem:
            # the storage will be loaded up as an UntypedStorage, and then the
            # FloatTensor will be loaded and the UntypedStorage will be assigned
            # to it. Since the storage dtype does not match the tensor dtype,
            # this will cause an error. If we reverse the list, like `[tensor,
            # storage]`, then we will save the `tensor.storage()` as a faked
            # `FloatStorage`, and the saved size will be the correct
            # dtype-specific numel count that old versions expect. `tensor`
            # will be able to load up properly in old versions, pointing to
            # a FloatStorage. However, `storage` is still being translated to
            # an UntypedStorage, and it will try to resolve to the same
            # FloatStorage that `tensor` contains. This will also cause an
            # error. It doesn't seem like there's any way around this.
            # Probably, we just cannot maintain FC for the legacy format if the
            # saved list contains both a tensor and a storage that point to the
            # same data. We should still be able to maintain FC for lists of
            # just tensors, as long as all views share the same dtype as the
            # tensor they are viewing.

            if storage_key not in serialized_storages:
                serialized_storages[storage_key] = (storage, dtype)
            # Storage views are no longer produced at save time, so this
            # comparison is intentionally always False and view_metadata
            # stays None; the slot is kept for format compatibility.
            is_view = storage._cdata != storage._cdata
            if is_view:
                view_metadata = (str(storage._cdata), offset, storage.nbytes())
            else:
                view_metadata = None

            res = (
                "storage",
                storage_type,
                storage_key,
                location,
                storage_numel,
                view_metadata,
            )
            return res
        return None

    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == "little",
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    f.flush()
    for key in serialized_storage_keys:
        storage, dtype = serialized_storages[key]
        storage._write_file(
            f, _should_read_directly(f), True, torch._utils._element_size(dtype)
        )


def _save(
    obj,
    zip_file,
    pickle_module,
    pickle_protocol,
    _disable_byteorder_record,
):
    serialized_storages = {}
    id_map: Dict[int, str] = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if str(storage.device) != "meta" and storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that "
                            "view the same data as different types"
                        )
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
            if hasattr(obj, "_fake_device") and obj._fake_device is not None:
                location = str(obj._fake_device)
            else:
                location = location_tag(storage)
            serialized_storages[storage_key] = storage

            return ("storage", storage_type, storage_key, location, storage_numel)

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    zip_file.write_record("data.pkl", data_value, len(data_value))

    # Write byte order marker
    if not _disable_byteorder_record:
        if sys.byteorder not in ["little", "big"]:
            raise ValueError("Unknown endianness type: " + sys.byteorder)

        zip_file.write_record("byteorder", sys.byteorder, len(sys.byteorder))

    # Write each tensor to a file named tensor/the_tensor_key in the zip archive
    for key in sorted(serialized_storages.keys()):
        name = f"data/{key}"
        storage = serialized_storages[key]
        num_bytes = storage.nbytes()
        global _serialization_tls
        if _serialization_tls.skip_data:
            zip_file.write_record_metadata(name, num_bytes)
        else:
            # given that we copy things around anyway, we might use storage.cpu()
            # this means that to get tensors serialized, you need to implement
            # .cpu() on the underlying Storage
            if storage.device.type != "cpu":
                storage = storage.cpu()
            # Now that it is on the CPU we can directly copy it into the zip file
            zip_file.write_record(name, storage, num_bytes)


def load(
    f: FILE_LIKE,
    map_location: MAP_LOCATION = None,
    pickle_module: Any = None,
    *,
    weights_only: Optional[bool] = None,
    mmap: Optional[bool] = None,
    **pickle_load_args: Any,
) -> Any:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path'>`).

    """load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)

    Loads an object saved with :func:`torch.save` from a file.

    :func:`torch.load` uses Python's unpickling facilities but treats storages,
    which underlie tensors, specially. They are first deserialized on the
    CPU and are then moved to the device they were saved from. If this fails
    (e.g. because the run time system doesn't have certain devices), an exception
    is raised. However, storages can be dynamically remapped to an alternative
    set of devices using the :attr:`map_location` argument.

    If :attr:`map_location` is a callable, it will be called once for each serialized
    storage with two arguments: storage and location. The storage argument
    will be the initial deserialization of the storage, residing on the CPU.
    Each serialized storage has a location tag associated with it which
    identifies the device it was saved from, and this tag is the second
    argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
    for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
    :attr:`map_location` should return either ``None`` or a storage. If
    :attr:`map_location` returns a storage, it will be used as the final deserialized
    object, already moved to the right device. Otherwise, :func:`torch.load` will
    fall back to the default behavior, as if :attr:`map_location` wasn't specified.

    If :attr:`map_location` is a :class:`torch.device` object or a string containing
    a device tag, it indicates the location where all tensors should be loaded.

    Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
    appearing in the file (keys), to ones that specify where to put the
    storages (values).

    User extensions can register their own location tags and tagging and
    deserialization methods using :func:`torch.serialization.register_package`.

    Args:
        f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
            or a string or os.PathLike object containing a file name
        map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
            locations
        pickle_module: module used for unpickling metadata and objects (has to
            match the :attr:`pickle_module` used to serialize file)
        weights_only: Indicates whether unpickler should be restricted to
            loading only tensors, primitive types, dictionaries
            and any types added via :func:`torch.serialization.add_safe_globals`.
        mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory.
            Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
            are moved to the location that they were tagged with when saving, or specified by ``map_location``. This
            second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the
            tensor storages from disk to CPU memory in the first step, ``f`` is mmaped.
        pickle_load_args: (Python 3 only) optional keyword arguments passed over to
            :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
            :attr:`errors=...`.

    .. warning::
        Unless the ``weights_only`` parameter is set to ``True``, :func:`torch.load()`
        implicitly uses the ``pickle`` module, which is known to be insecure.
        It is possible to construct malicious pickle data which will execute arbitrary code
        during unpickling. Never load data that could have come from an untrusted
        source in an unsafe mode, or that could have been tampered with. **Only load data you trust**.

    .. note::
        When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
        will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
        and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.

    .. note::
        By default, we decode byte strings as ``utf-8``. This is to avoid a common error
        case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
        when loading files saved by Python 2 in Python 3. If this default
        is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
        these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
        to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
        as byte arrays which can be decoded later with ``byte_array.decode(...)``.

    Example:
        >>> # xdoctest: +SKIP("undefined filepaths")
        >>> torch.load("tensors.pt", weights_only=True)
        # Load all tensors onto the CPU
        >>> torch.load("tensors.pt", map_location=torch.device("cpu"), weights_only=True)
        # Load all tensors onto the CPU, using a function
        >>> torch.load(
        ...     "tensors.pt", map_location=lambda storage, loc: storage, weights_only=True
        ... )
        # Load all tensors onto GPU 1
        >>> torch.load(
        ...     "tensors.pt",
        ...     map_location=lambda storage, loc: storage.cuda(1),
        ...     weights_only=True,
        ... )  # type: ignore[attr-defined]
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load("tensors.pt", map_location={"cuda:1": "cuda:0"}, weights_only=True)
        # Load tensor from io.BytesIO object
        # Loading from a buffer setting weights_only=False, warning this can be unsafe
        >>> with open("tensor.pt", "rb") as f:
        ...     buffer = io.BytesIO(f.read())
        >>> torch.load(buffer, weights_only=False)
        # Load a module with 'ascii' encoding for unpickling
        # Loading from a module setting weights_only=False, warning this can be unsafe
        >>> torch.load("module.pt", encoding="ascii", weights_only=False)
    """
    torch._C._log_api_usage_once("torch.load")
    UNSAFE_MESSAGE = (
        "Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
        "but it can result in arbitrary code execution. Do it only if you got the file from a "
        "trusted source."
    )
    DOCS_MESSAGE = (
        "\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
        "weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
    )

    def _get_wo_message(message: str) -> str:
        unsafe_global_pattern = r"GLOBAL (\S+) was not an allowed global by default."
        has_unsafe_global = re.search(unsafe_global_pattern, message) is not None
        blocklist_pattern = r"whose module (\S+) is blocked"
        has_blocklist = re.search(blocklist_pattern, message) is not None
        if has_unsafe_global:
            updated_message = (
                "Weights only load failed. This file can still be loaded, to do so you have two options, "
                "\033[1mdo those steps only if you trust the source of the checkpoint\033[0m. "
                f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
                "the recommended steps in the following error message.\n\tWeightsUnpickler error: "
                + message
            )
        else:
            updated_message = f"Weights only load failed. {UNSAFE_MESSAGE}\n"
            if not has_blocklist:
                updated_message += (
                    "Please file an issue with the following so that we can make "
                    "`weights_only=True` compatible with your use case: WeightsUnpickler error: "
                )
            updated_message += message
        return updated_message + DOCS_MESSAGE

    global _serialization_tls
    skip_data = _serialization_tls.skip_data
    if skip_data:
        raise RuntimeError(
            "`torch.load` called within a torch.serialization.skip_data context manager "
            "is not supported yet. Please call torch.load outside the skip_data context manager."
        )

    if weights_only is None:
        weights_only, warn_weights_only = False, True
    else:
        warn_weights_only = False

    # Add ability to force safe only weight loads via environment variable
    if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in [
        "1",
        "y",
        "yes",
        "true",
    ]:
        weights_only = True

    if weights_only:
        if pickle_module is not None:
            raise RuntimeError(
                "Can not safely load weights when explicit pickle_module is specified"
            )
    else:
        if pickle_module is None:
            if warn_weights_only:
                warnings.warn(
                    "You are using `torch.load` with `weights_only=False` (the current default value), which uses "
                    "the default pickle module implicitly. It is possible to construct malicious pickle data "
                    "which will execute arbitrary code during unpickling (See "
                    "https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
                    "In a future release, the default value for `weights_only` will be flipped to `True`. This "
                    "limits the functions that could be executed during unpickling. Arbitrary objects will no "
                    "longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
                    "user via `torch.serialization.add_safe_globals`. We recommend you start setting "
                    "`weights_only=True` for any use case where you don't have full control of the loaded file. "
                    "Please open an issue on GitHub for any issues related to this experimental feature.",
                    FutureWarning,
                    stacklevel=2,
                )
            pickle_module = pickle

    # make flipping the default BC-compatible
    if mmap is None:
        mmap = False

    _check_dill_version(pickle_module)

    if "encoding" not in pickle_load_args.keys():
        pickle_load_args["encoding"] = "utf-8"

    with _open_file_like(f, "rb") as opened_file:
        if _is_zipfile(opened_file):
            # The zipfile reader is going to advance the current file position.
            # If we want to actually tail call to torch.jit.load, we need to
            # reset back to the original position.
            orig_position = opened_file.tell()
            overall_storage = None
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                if _is_torchscript_zip(opened_zipfile):
                    warnings.warn(
                        "'torch.load' received a zip file that looks like a TorchScript archive"
                        " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
                        " silence this warning)",
                        UserWarning,
                    )
                    opened_file.seek(orig_position)
                    return torch.jit.load(opened_file, map_location=map_location)
                if mmap:
                    if not _is_path(f):
                        raise ValueError(
                            "f must be a file path in order to use the mmap argument"
                        )
                    size = os.path.getsize(f)
                    if not IS_WINDOWS:
                        shared = get_default_mmap_options() == MAP_SHARED
                    else:
                        shared = False
                    overall_storage = torch.UntypedStorage.from_file(
                        os.fspath(f), shared, size
                    )
                if weights_only:
                    try:
                        return _load(
                            opened_zipfile,
                            map_location,
                            _weights_only_unpickler,
                            overall_storage=overall_storage,
                            **pickle_load_args,
                        )
                    except pickle.UnpicklingError as e:
                        raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
                return _load(
                    opened_zipfile,
                    map_location,
                    pickle_module,
                    overall_storage=overall_storage,
                    **pickle_load_args,
                )
        if mmap:
            f_name = "" if not isinstance(f, str) else f"{f}, "
            raise RuntimeError(
                "mmap can only be used with files saved with "
                f"`torch.save({f_name}_use_new_zipfile_serialization=True)`, "
                "please torch.save your checkpoint with this option in order to use mmap."
            )
        if weights_only:
            try:
                return _legacy_load(
                    opened_file,
                    map_location,
                    _weights_only_unpickler,
                    **pickle_load_args,
                )
            except pickle.UnpicklingError as e:
                raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
        return _legacy_load(
            opened_file, map_location, pickle_module, **pickle_load_args
        )


# Register pickling support for layout instances such as
# torch.sparse_coo, etc
def _get_layout(name):
    """Get layout extension object from its string representation."""
    cache = _get_layout.cache  # type: ignore[attr-defined]
    if not cache:
        for v in torch.__dict__.values():
            if isinstance(v, torch.layout):
                cache[str(v)] = v
    return cache[name]


# There is yet no good way to type annotate function attributes https://github.com/python/mypy/issues/2087
_get_layout.cache = {}  # type: ignore[attr-defined]
copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
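
# A round-trip sketch of the registration above: layouts pickle by name and
# unpickle back to the same singleton object, e.g.
#
#   import pickle
#   assert pickle.loads(pickle.dumps(torch.sparse_coo)) is torch.sparse_coo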
|
|
|
|
|
|
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
|
deserialized_objects: Dict[int, Any] = {}
|
|
|
|
restore_location = _get_restore_location(map_location)
|
|
|
|
class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
|
|
def find_class(self, mod_name, name):
|
|
if type(name) is str and "Storage" in name:
|
|
try:
|
|
return StorageType(name)
|
|
except KeyError:
|
|
pass
|
|
return super().find_class(mod_name, name)
|
|
|
|
def _check_container_source(container_type, source_file, original_source):
|
|
try:
|
|
current_source = "".join(get_source_lines_and_file(container_type)[0])
|
|
except Exception: # saving the source is optional, so we can ignore any errors
|
|
warnings.warn(
|
|
"Couldn't retrieve source code for container of "
|
|
"type " + container_type.__name__ + ". It won't be checked "
|
|
"for correctness upon loading."
|
|
)
|
|
return
|
|
if original_source != current_source:
|
|
if container_type.dump_patches:
|
|
file_name = container_type.__name__ + ".patch"
|
|
diff = difflib.unified_diff(
|
|
current_source.split("\n"),
|
|
original_source.split("\n"),
|
|
source_file,
|
|
source_file,
|
|
lineterm="",
|
|
)
|
|
lines = "\n".join(diff)
|
|
try:
|
|
with open(file_name, "a+") as f:
|
|
file_size = f.seek(0, 2)
|
|
f.seek(0)
|
|
if file_size == 0:
|
|
f.write(lines)
|
|
elif file_size != len(lines) or f.read() != lines:
|
|
raise OSError
|
|
msg = (
|
|
"Saved a reverse patch to " + file_name + ". "
|
|
"Run `patch -p0 < " + file_name + "` to revert your "
|
|
"changes."
|
|
)
|
|
except OSError:
|
|
msg = (
|
|
"Tried to save a patch, but couldn't create a "
|
|
"writable file " + file_name + ". Make sure it "
|
|
"doesn't exist and your working directory is "
|
|
"writable."
|
|
)
|
|
else:
|
|
msg = (
|
|
"you can retrieve the original source code by "
|
|
"accessing the object's source attribute or set "
|
|
"`torch.nn.Module.dump_patches = True` and use the "
|
|
"patch tool to revert the changes."
|
|
)
|
|
msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
|
|
warnings.warn(msg, SourceChangeWarning)
|
|
|
|
def legacy_load(f):
|
|
deserialized_objects: Dict[int, Any] = {}
|
|
|
|
def persistent_load(saved_id):
|
|
if isinstance(saved_id, tuple):
|
|
# Ignore containers that don't have any sources saved
|
|
if all(saved_id[1:]):
|
|
_check_container_source(*saved_id)
|
|
return saved_id[0]
|
|
return deserialized_objects[int(saved_id)]
|
|
|
|
with closing(
|
|
tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
|
|
) as tar, mkdtemp() as tmpdir:
|
|
tar.extract("storages", path=tmpdir)
|
|
with open(os.path.join(tmpdir, "storages"), "rb", 0) as f:
|
|
num_storages = pickle_module.load(f, **pickle_load_args)
|
|
for i in range(num_storages):
|
|
args = pickle_module.load(f, **pickle_load_args)
|
|
key, location, storage_type = args
|
|
dtype = storage_type._dtype
|
|
obj = cast(Storage, torch.UntypedStorage)._new_with_file(
|
|
f, torch._utils._element_size(dtype)
|
|
)
|
|
obj = restore_location(obj, location)
|
|
# TODO: Once we decide to break serialization FC, we can
|
|
# stop wrapping with TypedStorage
|
|
deserialized_objects[key] = torch.storage.TypedStorage(
|
|
wrap_storage=obj, dtype=dtype, _internal=True
|
|
)
|
|
|
|
storage_views = pickle_module.load(f, **pickle_load_args)
|
|
for target_cdata, root_cdata, offset, numel in storage_views:
|
|
root = deserialized_objects[root_cdata]
|
|
element_size = torch._utils._element_size(root.dtype)
|
|
offset_bytes = offset * element_size
|
|
# TODO: Once we decide to break serialization FC, we can
|
|
# stop wrapping with TypedStorage
|
|
deserialized_objects[target_cdata] = torch.storage.TypedStorage(
|
|
wrap_storage=root._untyped_storage[
|
|
offset_bytes : offset_bytes + numel * element_size
|
|
],
|
|
dtype=root.dtype,
|
|
_internal=True,
|
|
)
|
|
|
|
tar.extract("tensors", path=tmpdir)
|
|
with open(os.path.join(tmpdir, "tensors"), "rb", 0) as f:
|
|
num_tensors = pickle_module.load(f, **pickle_load_args)
|
|
for _ in range(num_tensors):
|
|
args = pickle_module.load(f, **pickle_load_args)
|
|
key, storage_id, original_tensor_type = args
|
|
storage = deserialized_objects[storage_id]
|
|
(ndim,) = struct.unpack("<i", f.read(4))
|
|
# skip next 4 bytes; legacy encoding treated ndim as 8 bytes
|
|
f.read(4)
|
|
numel = struct.unpack(f"<{ndim}q", f.read(8 * ndim))
|
|
stride = struct.unpack(f"<{ndim}q", f.read(8 * ndim))
|
|
(storage_offset,) = struct.unpack("<q", f.read(8))
|
|
tensor = torch.empty((0,), dtype=storage.dtype).set_(
|
|
storage._untyped_storage, storage_offset, numel, stride
|
|
)
|
|
deserialized_objects[key] = tensor
|
|
|
|
pickle_file = tar.extractfile("pickle")
|
|
unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
|
|
unpickler.persistent_load = persistent_load
|
|
result = unpickler.load()
|
|
return result
|
|
|
|
deserialized_objects = {}
|
|
|
|
    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == "module":
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == "storage":
            storage_type, root_key, location, numel, view_metadata = data
            location = _maybe_decode_ascii(location)
            dtype = storage_type.dtype

            nbytes = numel * torch._utils._element_size(dtype)

            if root_key not in deserialized_objects:
                if torch._guards.active_fake_mode() is not None:
                    obj = cast(Storage, torch.UntypedStorage(nbytes, device="meta"))
                else:
                    obj = cast(Storage, torch.UntypedStorage(nbytes))
                    obj._torch_load_uninitialized = True
                    obj = restore_location(obj, location)
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with TypedStorage
                typed_storage = torch.storage.TypedStorage(
                    wrap_storage=obj, dtype=dtype, _internal=True
                )
                deserialized_objects[root_key] = typed_storage
            else:
                typed_storage = deserialized_objects[root_key]
                if typed_storage._data_ptr() == 0:
                    typed_storage = torch.storage.TypedStorage(
                        device=typed_storage._untyped_storage.device,
                        dtype=dtype,
                        _internal=True,
                    )

            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                offset_bytes = offset * torch._utils._element_size(dtype)
                view_size_bytes = view_size * torch._utils._element_size(dtype)
                if view_key not in deserialized_objects:
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[view_key] = torch.storage.TypedStorage(
                        wrap_storage=typed_storage._untyped_storage[
                            offset_bytes : offset_bytes + view_size_bytes
                        ],
                        dtype=dtype,
                        _internal=True,
                    )
                res = deserialized_objects[view_key]

            else:
                res = typed_storage
            return res
        else:
            raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno()
        # only if offset is zero we can attempt the legacy tar file loader
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if _is_zipfile(f):
                # .zip is used for torch.jit.save and will throw an un-pickling error here
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)"
                ) from None
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    if not hasattr(f, "readinto") and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
            f'Received object of type "{type(f)}". Please update to Python 3.8.2 or newer to restore this '
            "functionality."
        )

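    # Legacy non-tar format: a stream of pickles (magic number, protocol
    # version, system info, the object itself, then the storage keys),
    # followed by the raw storage bytes.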
    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError(f"Invalid protocol version: {protocol_version}")

    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = UnpicklerWrapper(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

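    # Read the storage payloads that follow, in the order given by
    # deserialized_storage_keys; under fake-tensor mode the bytes are not
    # materialized, so reading is skipped.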
    if torch._guards.active_fake_mode() is None:
        offset = f.tell() if f_should_read_directly else None
        for key in deserialized_storage_keys:
            assert key in deserialized_objects
            typed_storage = deserialized_objects[key]
            typed_storage._untyped_storage._set_from_file(
                f,
                offset,
                f_should_read_directly,
                torch._utils._element_size(typed_storage.dtype),
            )
            if offset is not None:
                offset = f.tell()

    torch._utils._validate_loaded_sparse_tensors()

    return result


def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
    # When using encoding='bytes' in Py3, some **internal** keys stored as
    # strings in Py2 are loaded as bytes. This function decodes them with
    # ascii encoding, one that Py3 uses by default.
    #
    # NOTE: This should only be used on internal keys (e.g., `typename` and
    #       `location` in `persistent_load` below)!
    if isinstance(bytes_str, bytes):
        return bytes_str.decode("ascii")
    return bytes_str


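# Build a restore_location callable from the user-supplied map_location, which
# may be None, a dict remapping location tags, a device string, a torch.device,
# or a callable. For example, all of the following are accepted by torch.load:
#   torch.load(f, map_location="cpu")
#   torch.load(f, map_location={"cuda:0": "cpu"})
#   torch.load(f, map_location=torch.device("cpu"))
#   torch.load(f, map_location=lambda storage, location: storage)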
def _get_restore_location(map_location):
    if map_location is None:
        restore_location = default_restore_location
    elif isinstance(map_location, dict):

        def restore_location(storage, location):
            location = map_location.get(location, location)
            return default_restore_location(storage, location)

    elif isinstance(map_location, (str, bytes)):

        def restore_location(storage, location):
            return default_restore_location(storage, map_location)

    elif isinstance(map_location, torch.device):

        def restore_location(storage, location):
            return default_restore_location(storage, str(map_location))

    else:

        def restore_location(storage, location):
            result = map_location(storage, location)
            if result is None:
                result = default_restore_location(storage, location)
            return result

    return restore_location


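# Stand-in for legacy typed-storage classes (e.g. "FloatStorage") referenced in
# old checkpoints: UnpicklerWrapper.find_class below maps such names to this
# class, which records only the dtype the storage should be viewed with.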
class StorageType:
    def __init__(self, name):
        self._dtype = _get_dtype_from_pickle_storage_type(name)

    @property
    def dtype(self):
        return self._dtype

    def __str__(self):
        return f"StorageType(dtype={self.dtype})"


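# Deserialize a checkpoint in the zipfile-based format produced by current
# torch.save, reading the pickled object tree from `pickle_file` and resolving
# each persistent id to a lazily loaded storage record.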
def _load(
    zip_file,
    map_location,
    pickle_module,
    pickle_file="data.pkl",
    overall_storage=None,
    **pickle_load_args,
):
    restore_location = _get_restore_location(map_location)

    loaded_storages = {}

    # check if byteswapping is needed
    byteordername = "byteorder"
    byteorderdata = None
    if zip_file.has_record(byteordername):
        byteorderdata = zip_file.get_record(byteordername)
        if byteorderdata not in [b"little", b"big"]:
            raise ValueError("Unknown endianness type: " + byteorderdata.decode())
    elif (
        get_default_load_endianness() == LoadEndianness.LITTLE
        or get_default_load_endianness() is None
    ):
        byteorderdata = b"little"
    elif get_default_load_endianness() == LoadEndianness.BIG:
        byteorderdata = b"big"
    elif get_default_load_endianness() == LoadEndianness.NATIVE:
        pass
    else:
        raise ValueError("Invalid load endianness type")

    if (
        not zip_file.has_record(byteordername)
        and get_default_load_endianness() is None
        and sys.byteorder == "big"
    ):
        # Default behaviour was changed
        # See https://github.com/pytorch/pytorch/issues/101688
        warnings.warn(
            "The default load endianness for checkpoints without a byteorder mark "
            "on big endian machines was changed from 'native' to 'little' endian, "
            "to avoid this behavior please use "
            "torch.serialization.set_default_load_endianness to set "
            "the desired default load endianness",
            UserWarning,
        )

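    # Materialize the bytes for a single storage record. Three paths: a meta
    # storage under fake-tensor mode, a slice of overall_storage when the whole
    # file is already mapped into memory (e.g. torch.load(..., mmap=True)), or
    # a fresh read of the zip record.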
    def load_tensor(dtype, numel, key, location):
        name = f"data/{key}"
        if torch._guards.detect_fake_mode(None) is not None:
            nbytes = numel * torch._utils._element_size(dtype)
            storage = torch.UntypedStorage(nbytes, device="meta")
        elif overall_storage is not None:
            storage_offset = zip_file.get_record_offset(name)
            storage = overall_storage[storage_offset : storage_offset + numel]
        else:
            storage = (
                zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)
                ._typed_storage()
                ._untyped_storage
            )
        # swap here if byteswapping is needed
        if byteorderdata is not None:
            if byteorderdata.decode() != sys.byteorder:
                storage.byteswap(dtype)

        # TODO: Once we decide to break serialization FC, we can
        # stop wrapping with TypedStorage
        typed_storage = torch.storage.TypedStorage(
            wrap_storage=restore_location(storage, location),
            dtype=dtype,
            _internal=True,
        )

        if typed_storage._data_ptr() != 0:
            loaded_storages[key] = typed_storage

        return typed_storage

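    # Resolve a persistent id of the form (typename, storage_type, key,
    # location, numel) to a TypedStorage, loading the underlying record at
    # most once per key (untyped storages are viewed as torch.uint8).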
    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        assert (
            typename == "storage"
        ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        storage_type, key, location, numel = data
        if storage_type is torch.UntypedStorage:
            dtype = torch.uint8
        else:
            dtype = storage_type.dtype

        if key in loaded_storages:
            typed_storage = loaded_storages[key]
        else:
            nbytes = numel * torch._utils._element_size(dtype)
            typed_storage = load_tensor(
                dtype, nbytes, key, _maybe_decode_ascii(location)
            )

        return typed_storage

    load_module_mapping: Dict[str, str] = {
        # See https://github.com/pytorch/pytorch/pull/51633
        "torch.tensor": "torch._tensor"
    }

    # Need to subclass Unpickler instead of directly monkey-patching the find_class method
    # because it's marked readonly in pickle.
    # The type: ignore is because mypy can't statically determine the type of this class.
    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
        # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
        # Lets us override the imports that pickle uses when unpickling an object.
        # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
        def find_class(self, mod_name, name):
            if type(name) is str and "Storage" in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            mod_name = load_module_mapping.get(mod_name, mod_name)
            return super().find_class(mod_name, name)

    # Load the data (which may in turn use `persistent_load` to load tensors)
    data_file = io.BytesIO(zip_file.get_record(pickle_file))

    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    # Needed for tensors where storage device and rebuild tensor device are
    # not connected (wrapper subclasses and tensors rebuilt using numpy)
    global _serialization_tls
    _serialization_tls.map_location = map_location
    result = unpickler.load()
    _serialization_tls.map_location = None

    torch._utils._validate_loaded_sparse_tensors()
    torch._C._log_api_usage_metadata(
        "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
    )
    return result


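# Archives produced by torch.jit.save contain a 'constants.pkl' record, while
# regular torch.save checkpoints do not, so its presence identifies a
# TorchScript zip that should be opened with torch.jit.load instead.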
def _is_torchscript_zip(zip_file):
    return "constants.pkl" in zip_file.get_all_records()