130 lines
4.3 KiB
Python
130 lines
4.3 KiB
Python
import os
|
|
import sys
|
|
from typing import Callable, List, Optional
|
|
|
|
import torch
|
|
from torch.types import Storage
|
|
|
|
|
|
# Nothing in this module is exported via ``from ... import *``.
__all__: List[str] = list()
|
|
|
|
|
|
def _dummy_fn(name: str) -> Callable:
|
|
def fn(*args, **kwargs): # type: ignore[no-untyped-def]
|
|
raise RuntimeError(f"torch._C.{name} is not supported on this platform")
|
|
|
|
return fn
|
|
|
|
|
|
# When this torch build has no GDS bindings, install raising stubs under the
# same ``torch._C`` names so the Python wrappers in this module can still be
# imported. The assertions check that the bindings are missing as a group:
# if ``_gds_register_buffer`` is absent, none of the others may be present.
if not hasattr(torch._C, "_gds_register_buffer"):
    for _gds_name in (
        "_gds_deregister_buffer",
        "_gds_register_handle",
        "_gds_deregister_handle",
        "_gds_load_storage",
        "_gds_save_storage",
    ):
        assert not hasattr(torch._C, _gds_name)

    # Define functions
    for _gds_name in (
        "_gds_register_buffer",
        "_gds_deregister_buffer",
        "_gds_register_handle",
        "_gds_deregister_handle",
        "_gds_load_storage",
        "_gds_save_storage",
    ):
        torch._C.__dict__[_gds_name] = _dummy_fn(_gds_name)

    # Keep the loop variable out of the module namespace.
    del _gds_name
|
|
|
|
|
|
def _gds_register_buffer(s: Storage) -> None:
    """Register ``s`` as a GDS buffer.

    Thin forwarder to ``torch._C._gds_register_buffer``. On builds without
    GDS support that binding is a stub that raises ``RuntimeError``.

    Args:
        s (Storage): Buffer to register.
    """
    register = torch._C._gds_register_buffer
    register(s)
|
|
|
|
|
|
def _gds_deregister_buffer(s: Storage) -> None:
    """Deregisters a buffer.

    Thin forwarder to ``torch._C._gds_deregister_buffer``; presumably undoes
    a prior ``_gds_register_buffer(s)`` call (TODO confirm against the C++
    binding). On builds without GDS support the binding is a stub that raises
    ``RuntimeError``.

    Args:
        s (Storage): Buffer to deregister.
    """
    torch._C._gds_deregister_buffer(s)
|
|
|
|
|
|
class _GdsFile:
|
|
r"""Wrapper around cuFile.
|
|
|
|
cuFile is a file-like interface to the GPUDirect Storage (GDS) API.
|
|
|
|
Args:
|
|
filename (str): Name of the file to open.
|
|
flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will
|
|
be added automatically.
|
|
|
|
.. _CUDA GPUDirect Storage Documentation:
|
|
https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api
|
|
"""
|
|
|
|
def __init__(self, filename: str, flags: int):
|
|
if sys.platform == "win32":
|
|
raise RuntimeError("GdsFile is not supported on this platform.")
|
|
self.filename = filename
|
|
self.flags = flags
|
|
self.fd = os.open(filename, flags | os.O_DIRECT)
|
|
self.handle: Optional[int] = None
|
|
self.register_handle()
|
|
|
|
def __del__(self) -> None:
|
|
if self.handle is not None:
|
|
self.deregister_handle()
|
|
os.close(self.fd)
|
|
|
|
def register_handle(self) -> None:
|
|
"""Registers file descriptor to cuFile Driver.
|
|
|
|
This is a wrapper around ``cuFileHandleRegister``.
|
|
"""
|
|
assert (
|
|
self.handle is None
|
|
), "Cannot register a handle that is already registered."
|
|
self.handle = torch._C._gds_register_handle(self.fd)
|
|
|
|
def deregister_handle(self) -> None:
|
|
"""Deregisters file descriptor from cuFile Driver.
|
|
|
|
This is a wrapper around ``cuFileHandleDeregister``.
|
|
"""
|
|
assert (
|
|
self.handle is not None
|
|
), "Cannot deregister a handle that is not registered."
|
|
torch._C._gds_deregister_handle(self.handle)
|
|
self.handle = None
|
|
|
|
def load_storage(self, storage: Storage, offset: int = 0) -> None:
|
|
"""Loads data from the file into the storage.
|
|
|
|
This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
|
|
will be loaded from the file at ``offset`` into the storage.
|
|
|
|
Args:
|
|
storage (Storage): Storage to load data into.
|
|
offset (int, optional): Offset into the file to start loading from. (Default: 0)
|
|
"""
|
|
assert (
|
|
self.handle is not None
|
|
), "Cannot load data from a file that is not registered."
|
|
torch._C._gds_load_storage(self.handle, storage, offset)
|
|
|
|
def save_storage(self, storage: Storage, offset: int = 0) -> None:
|
|
"""Saves data from the storage into the file.
|
|
|
|
This is a wrapper around ``cuFileWrite``. All bytes of the storage
|
|
will be written to the file at ``offset``.
|
|
|
|
Args:
|
|
storage (Storage): Storage to save data from.
|
|
offset (int, optional): Offset into the file to start saving to. (Default: 0)
|
|
"""
|
|
assert (
|
|
self.handle is not None
|
|
), "Cannot save data to a file that is not registered."
|
|
torch._C._gds_save_storage(self.handle, storage, offset)
|