I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@@ -0,0 +1,113 @@
# mypy: allow-untyped-defs
import copyreg
import os.path as _osp
import weakref
import torch
from torch.utils import (
backcompat as backcompat,
collect_env as collect_env,
data as data,
deterministic as deterministic,
hooks as hooks,
)
from torch.utils.backend_registration import (
generate_methods_for_privateuse1_backend,
rename_privateuse1_backend,
)
from torch.utils.cpp_backtrace import get_cpp_backtrace
from torch.utils.throughput_benchmark import ThroughputBenchmark
def set_module(obj, mod):
"""
Set the ``__module__`` attribute on a Python object for nicer printing.
"""
if not isinstance(mod, str):
raise TypeError("The mod argument should be a string")
obj.__module__ = mod
if torch._running_with_deploy():
# not valid inside torch_deploy interpreter; no paths exist for frozen modules
cmake_prefix_path = None
else:
cmake_prefix_path = _osp.join(
_osp.dirname(_osp.dirname(__file__)), "share", "cmake"
)
def swap_tensors(t1, t2):
"""
This function swaps the content of the two Tensor objects.
At a high level, this will make t1 have the content of t2 while preserving
its identity.
This will not work if t1 and t2 have different slots.
"""
# Ensure there are no weakrefs
if weakref.getweakrefs(t1):
raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
if weakref.getweakrefs(t2):
raise RuntimeError("Cannot swap t2 because it has weakref associated with it")
t1_slots = set(copyreg._slotnames(t1.__class__)) # type: ignore[attr-defined]
t2_slots = set(copyreg._slotnames(t2.__class__)) # type: ignore[attr-defined]
if t1_slots != t2_slots:
raise RuntimeError("Cannot swap t1 and t2 if they have different slots")
def swap_attr(name):
tmp = getattr(t1, name)
setattr(t1, name, (getattr(t2, name)))
setattr(t2, name, tmp)
def error_pre_hook(grad_outputs):
raise RuntimeError(
"Trying to execute AccumulateGrad node that was poisoned by swap_tensors. "
"This can happen when you try to run backward on a tensor that was swapped. "
"For a module m with `torch.__future__.set_swap_module_params_on_conversion(True)` "
"you should not change the device or dtype of the module (e.g. `m.cpu()` or `m.half()`) "
"between running forward and backward. To resolve this, please only change the "
"device/dtype before running forward (or after both forward and backward)."
)
def check_use_count(t, name="t1"):
use_count = t._use_count()
error_str = (
f"Expected use_count of {name} to be 1 or 2 with an AccumulateGrad node but got {use_count}; "
f"make sure you are not holding references to the tensor in other places."
)
if use_count > 1:
if use_count == 2 and t.is_leaf:
accum_grad_node = torch.autograd.graph.get_gradient_edge(t).node
# Make sure that the accumulate_grad node was not lazy_init-ed by get_gradient_edge
if t._use_count() == 2:
accum_grad_node.register_prehook(error_pre_hook)
else:
raise RuntimeError(error_str)
else:
raise RuntimeError(error_str)
check_use_count(t1, "t1")
check_use_count(t2, "t2")
# Swap the types
# Note that this will fail if there are mismatched slots
swap_attr("__class__")
# Swap the dynamic attributes
swap_attr("__dict__")
# Swap the slots
for slot in t1_slots:
if hasattr(t1, slot) and hasattr(t2, slot):
swap_attr(slot)
elif hasattr(t1, slot):
setattr(t2, slot, (getattr(t1, slot)))
delattr(t1, slot)
elif hasattr(t2, slot):
setattr(t1, slot, (getattr(t2, slot)))
delattr(t2, slot)
# Swap the at::Tensor they point to
torch._C._swap_tensor_impl(t1, t2)
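# Illustrative usage sketch: _swap_tensors_example is a hypothetical helper (not
# part of the shipped API) showing how swap_tensors behaves on two plain tensors
# with no outstanding weakrefs and matching slots.
def _swap_tensors_example():
    a, b = torch.ones(2), torch.zeros(3)
    swap_tensors(a, b)
    # Each Python object keeps its identity but now holds the other's payload.
    assert a.shape == (3,) and a.sum().item() == 0
    assert b.shape == (2,) and b.sum().item() == 2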

View File

@@ -0,0 +1,114 @@
# This code is backported from python 3.10 dataclasses. Once 3.10 becomes the
# minimum supported we should use dataclass(slots=True) instead.
from __future__ import annotations
import dataclasses
import itertools
from typing import Generator, List, Type, TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
from _typeshed import DataclassInstance
__all__ = ["dataclass_slots"]
_T = TypeVar("_T", bound="DataclassInstance")
def dataclass_slots(cls: Type[_T]) -> Type[DataclassInstance]:
assert dataclasses.is_dataclass(cls), "Can only be used on dataclasses."
def _get_slots(cls: Type[DataclassInstance]) -> Generator[str, None, None]:
slots = cls.__dict__.get("__slots__")
# `__dictoffset__` and `__weakrefoffset__` can tell us whether
# the base type has dict/weakref slots, in a way that works correctly
# for both Python classes and C extension types. Extension types
# don't use `__slots__` for slot creation
if slots is None:
slots = []
if getattr(cls, "__weakrefoffset__", -1) != 0:
slots.append("__weakref__")
if getattr(cls, "__dictoffset__", -1) != 0:
slots.append("__dict__")
yield from slots
elif isinstance(slots, str):
yield slots
# Slots may be any iterable, but we cannot handle an iterator
# because it will already be (partially) consumed.
elif not hasattr(slots, "__next__"):
yield from slots
else:
raise TypeError(f"Slots of '{cls.__name__}' cannot be determined")
def _add_slots(
cls: Type[DataclassInstance], is_frozen: bool, weakref_slot: bool
) -> Type[DataclassInstance]:
# Need to create a new class, since we can't set __slots__
# after a class has been created.
# Make sure __slots__ isn't already set.
if "__slots__" in cls.__dict__:
raise TypeError(f"{cls.__name__} already specifies __slots__")
# Create a new dict for our new class.
cls_dict = dict(cls.__dict__)
field_names = tuple(f.name for f in dataclasses.fields(cls))
# Make sure slots don't overlap with those in base classes.
inherited_slots = set(
itertools.chain.from_iterable(map(_get_slots, cls.__mro__[1:-1]))
)
# The slots for our class. Remove slots from our base classes. Add
# '__weakref__' if weakref_slot was given, unless it is already present.
cls_dict["__slots__"] = tuple(
itertools.filterfalse(
inherited_slots.__contains__,
itertools.chain(
# gh-93521: '__weakref__' also needs to be filtered out if
# already present in inherited_slots
field_names,
("__weakref__",) if weakref_slot else (),
),
),
)
for field_name in field_names:
# Remove our attributes, if present. They'll still be
# available in _MARKER.
cls_dict.pop(field_name, None)
# Remove __dict__ itself.
cls_dict.pop("__dict__", None)
# Clear existing `__weakref__` descriptor, it belongs to a previous type:
cls_dict.pop("__weakref__", None) # gh-102069
# And finally create the class.
qualname = getattr(cls, "__qualname__", None)
cls = type(cls.__name__, cls.__bases__, cls_dict)
if qualname is not None:
cls.__qualname__ = qualname
def _dataclass_getstate(self: _T) -> object:
fields = dataclasses.fields(self)
return [getattr(self, f.name) for f in fields]
def _dataclass_setstate(self: _T, state: List[object]) -> None:
fields = dataclasses.fields(self)
for field, value in zip(fields, state):
# use setattr because dataclass may be frozen
object.__setattr__(self, field.name, value)
if is_frozen:
# Need this for pickling frozen classes with slots.
if "__getstate__" not in cls_dict:
cls.__getstate__ = _dataclass_getstate # type: ignore[method-assign, assignment]
if "__setstate__" not in cls_dict:
cls.__setstate__ = _dataclass_setstate # type: ignore[attr-defined]
return cls
params = getattr(cls, dataclasses._PARAMS) # type: ignore[attr-defined]
weakref_slot = getattr(params, "weakref_slot", False)
return _add_slots(cls, params.frozen, weakref_slot)
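# Illustrative usage sketch: _dataclass_slots_example is a hypothetical helper
# showing how the backport is meant to be applied on top of @dataclasses.dataclass.
def _dataclass_slots_example():
    @dataclass_slots
    @dataclasses.dataclass
    class Point:
        x: int
        y: int
    p = Point(1, 2)
    assert Point.__slots__ == ("x", "y")
    assert not hasattr(p, "__dict__")  # instances no longer carry a per-object dict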

View File

@@ -0,0 +1,392 @@
import contextlib
import copy
import hashlib
import inspect
import io
import pickle
import tokenize
import unittest
import warnings
from types import FunctionType, ModuleType
from typing import Any, Callable, Dict, NoReturn, Optional, Set, Union
from typing_extensions import deprecated
from unittest import mock
# Types saved/loaded in configs
CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)
def install_config_module(module: ModuleType) -> None:
"""
Converts a module-level config into a `ConfigModule()`.
See _config_typing.pyi for instructions on how to get the converted module to typecheck.
"""
class ConfigModuleInstance(ConfigModule):
_bypass_keys = set({"_is_dirty", "_hash_digest"})
def visit(
source: Union[ModuleType, type],
dest: Union[ModuleType, SubConfigProxy],
prefix: str,
) -> None:
"""Walk the module structure and move everything to module._config"""
for key, value in list(source.__dict__.items()):
if (
key.startswith("__")
or isinstance(value, (ModuleType, FunctionType))
or (hasattr(value, "__module__") and value.__module__ == "typing")
):
continue
name = f"{prefix}{key}"
if isinstance(value, CONFIG_TYPES):
config[name] = value
default[name] = value
if dest is module:
delattr(module, key)
elif isinstance(value, type):
assert value.__module__ == module.__name__
# a subconfig with `class Blah:` syntax
proxy = SubConfigProxy(module, f"{name}.")
visit(value, proxy, f"{name}.")
if dest is module:
setattr(dest, key, proxy)
else:
dest.__dict__[key] = proxy
else:
raise AssertionError(f"Unhandled config {key}={value} ({type(value)})")
config: Dict[str, Any] = {}
default: Dict[str, Any] = {}
compile_ignored_keys = get_assignments_with_compile_ignored_comments(module)
visit(module, module, "")
module._config = config # type: ignore[attr-defined]
module._default = default # type: ignore[attr-defined]
module._allowed_keys = set(config.keys()) # type: ignore[attr-defined]
module._compile_ignored_keys = compile_ignored_keys # type: ignore[attr-defined]
module.__class__ = ConfigModuleInstance
module._is_dirty = True # type: ignore[attr-defined]
module._hash_digest = None # type: ignore[attr-defined]
COMPILE_IGNORED_MARKER = "@compile_ignored"
# Gets all the keys (i.e. assignments) with a @compile_ignored comment
def get_assignments_with_compile_ignored_comments(module: ModuleType) -> Set[str]:
source_code = inspect.getsource(module)
assignments = set()
# Tokenize the source code to retrieve comments
tokens = tokenize.tokenize(io.BytesIO(source_code.encode("utf-8")).readline)
current_comment = "", -1
prev_name = ""
for token in tokens:
if token.type == tokenize.COMMENT:
prev_name = ""
maybe_current = token.string.strip()
if COMPILE_IGNORED_MARKER in maybe_current:
assert current_comment == (
"",
-1,
), f"unconsumed {COMPILE_IGNORED_MARKER}"
current_comment = maybe_current, token.start[0]
elif token.type == tokenize.NAME:
# Only accept the first name token, to handle if you have
# something like foo: Bar = ...
if not prev_name:
prev_name = token.string
elif token.type == tokenize.OP and token.string == "=":
# Check if the current assignment follows a comment
# with COMPILE_IGNORED_MARKER
if (
COMPILE_IGNORED_MARKER in current_comment[0]
and current_comment[1] == token.start[0] - 1
):
assignments.add(prev_name)
current_comment = "", -1 # reset
prev_name = ""
assert current_comment == ("", -1), f"unconsumed {COMPILE_IGNORED_MARKER}"
return assignments
class ConfigModule(ModuleType):
# NOTE: This should be kept in sync with _config_typing.pyi.
# The default values of the configuration settings. This can be used to
# determine if the config has been changed or not.
_default: Dict[str, Any]
# The actual configuration settings. E.g., torch._dynamo.config.debug
# would live as "debug" in the key, and torch._inductor.config.triton.cudagraphs
# maps as "triton.cudagraphs"
_config: Dict[str, Any]
_allowed_keys: Set[str]
_bypass_keys: Set[str]
_compile_ignored_keys: Set[str]
_is_dirty: bool
_hash_digest: Optional[bytes]
def __init__(self) -> None:
raise NotImplementedError(
f"use {__name__}.install_config_module(sys.modules[__name__])"
)
def __setattr__(self, name: str, value: object) -> None:
if name in self._bypass_keys:
super().__setattr__(name, value)
elif name not in self._allowed_keys:
raise AttributeError(f"{self.__name__}.{name} does not exist")
else:
self._config[name] = value
def __getattr__(self, name: str) -> Any:
try:
return self._config[name]
except KeyError as e:
# make hasattr() work properly
raise AttributeError(f"{self.__name__}.{name} does not exist") from e
def __delattr__(self, name: str) -> None:
# must support delete because unittest.mock.patch deletes
# then recreates things
del self._config[name]
def save_config(self) -> bytes:
"""Convert config to a pickled blob"""
config = dict(self._config)
for key in config.get("_save_config_ignore", ()):
config.pop(key)
return pickle.dumps(config, protocol=2)
def save_config_portable(self) -> Dict[str, Any]:
"""Convert config to portable format"""
config: Dict[str, Any] = {}
for key in sorted(self._config):
if key.startswith("_"):
continue
if any(
key.startswith(e) for e in self._config["_cache_config_ignore_prefix"]
):
continue
config[key] = self._config[key]
return config
def codegen_config(self) -> str:
"""Convert config to Python statements that replicate current config.
This does NOT include config settings that are at default values.
"""
lines = []
mod = self.__name__
for k, v in self._config.items():
if k in self._config.get("_save_config_ignore", ()):
if v != self._default[k]:
warnings.warn(f"Skipping serialization of {k} value {v}")
continue
if v == self._default[k]:
continue
lines.append(f"{mod}.{k} = {v!r}")
return "\n".join(lines)
def get_hash(self) -> bytes:
"""Hashes the configs that are not compile_ignored"""
if self._is_dirty or self._hash_digest is None:
dict_to_hash = {
k: v
for k, v in self._config.items()
if k not in self._compile_ignored_keys
}
string_to_hash = repr(sorted(dict_to_hash.items()))
self._hash_digest = hashlib.md5(string_to_hash.encode("utf-8")).digest()
self._is_dirty = False
return self._hash_digest
@deprecated(
"`config.to_dict()` has been deprecated. It may no longer change the underlying config."
" Use `config.shallow_copy_dict()` or `config.get_config_copy()` instead.",
category=FutureWarning,
)
def to_dict(self) -> Dict[str, Any]:
return self.shallow_copy_dict()
def shallow_copy_dict(self) -> Dict[str, Any]:
return {**self._config}
def load_config(self, maybe_pickled_config: Union[bytes, Dict[str, Any]]) -> None:
"""Restore from a prior call to save_config() or shallow_copy_dict()"""
if not isinstance(maybe_pickled_config, dict):
config = pickle.loads(maybe_pickled_config)
else:
config = maybe_pickled_config
self._config.update(config)
def get_config_copy(self) -> Dict[str, Any]:
return copy.deepcopy(self._config)
def patch(
self,
arg1: Optional[Union[str, Dict[str, Any]]] = None,
arg2: Any = None,
**kwargs: Dict[str, Any],
) -> "ContextDecorator":
"""
Decorator and/or context manager to make temporary changes to a config.
As a decorator:
@config.patch("name", val)
@config.patch(name1=val1, name2=val2)
@config.patch({"name1": val1, "name2": val2})
def foo(...):
...
As a context manager:
with config.patch("name", val):
...
"""
changes: Dict[str, Any]
if arg1 is not None:
if arg2 is not None:
assert isinstance(arg1, str)
# patch("key", True) syntax
changes = {arg1: arg2}
else:
assert isinstance(arg1, dict)
# patch({"key": True}) syntax
changes = arg1
assert not kwargs
else:
# patch(key=True) syntax
changes = kwargs
assert arg2 is None
assert isinstance(changes, dict), f"expected `dict` got {type(changes)}"
prior: Dict[str, Any] = {}
config = self
dirty = False
class ConfigPatch(ContextDecorator):
def __enter__(self) -> None:
assert not prior
nonlocal dirty
for key in changes.keys():
# KeyError on invalid entry
prior[key] = config._config[key]
dirty = key not in config._compile_ignored_keys
config._config.update(changes)
config._is_dirty = dirty
def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore[no-untyped-def]
nonlocal dirty
config._config.update(prior)
config._is_dirty = dirty
prior.clear()
return ConfigPatch()
def _make_closure_patcher(self, **changes: Dict[str, Any]) -> Any:
"""
A lower-overhead version of patch() for things on the critical path.
Usage:
# do this off the critical path
change_fn = config._make_closure_patcher(foo=True)
...
revert = change_fn()
try:
...
finally:
revert()
"""
config = self._config
def change() -> Callable[[], None]:
prior = {k: config[k] for k in changes}
config.update(changes)
def revert() -> None:
config.update(prior)
return revert
return change
class ContextDecorator(contextlib.ContextDecorator):
"""
Same as contextlib.ContextDecorator, but with support for
`unittest.TestCase`
"""
def __enter__(self) -> None:
raise NotImplementedError("NYI")
def __exit__(self, exc_type, exc_val, exc_tb) -> NoReturn: # type: ignore[no-untyped-def]
raise NotImplementedError("NYI")
def __call__(self, func: Callable[[Any], Any]) -> Any:
if isinstance(func, type) and issubclass(func, unittest.TestCase):
class _TestCase(func): # type: ignore[valid-type, misc]
@classmethod
def setUpClass(cls) -> None:
self.__enter__()
try:
super().setUpClass()
except Exception:
self.__exit__(None, None, None)
raise
@classmethod
def tearDownClass(cls) -> None:
try:
super().tearDownClass()
finally:
self.__exit__(None, None, None)
_TestCase.__name__ = func.__name__
_TestCase.__qualname__ = func.__qualname__
_TestCase.__module__ = func.__module__
return _TestCase
return super().__call__(func)
class SubConfigProxy:
"""
Shim to redirect to main config.
`config.triton.cudagraphs` maps to _config["triton.cudagraphs"]
"""
def __init__(self, config: object, prefix: str):
# `super().__setattr__` to bypass custom `__setattr__`
super().__setattr__("_config", config)
super().__setattr__("_prefix", prefix)
def __setattr__(self, name: str, value: object) -> None:
return self._config.__setattr__(self._prefix + name, value)
def __getattr__(self, name: str) -> Any:
return self._config.__getattr__(self._prefix + name)
def __delattr__(self, name: str) -> None:
return self._config.__delattr__(self._prefix + name)
def patch_object(obj: object, name: str, value: object) -> object:
"""
Workaround `mock.patch.object` issue with ConfigModule
"""
if isinstance(obj, ConfigModule):
return obj.patch(name, value)
return mock.patch.object(obj, name, value)
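# Illustrative usage sketch: _config_patch_example is a hypothetical helper. A config
# module assigns plain values at module level and then calls
# install_config_module(sys.modules[__name__]); afterwards patch() can temporarily
# override an entry. torch._dynamo.config is assumed here to be such an installed
# config module with a boolean "verbose" entry.
def _config_patch_example():
    import torch._dynamo.config as dynamo_config
    original = dynamo_config.verbose
    with dynamo_config.patch(verbose=not original):
        assert dynamo_config.verbose is (not original)
    assert dynamo_config.verbose is original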

View File

@@ -0,0 +1,34 @@
# mypy: allow-untyped-defs
from typing import Any, TYPE_CHECKING
"""
This was semi-automatically generated by running
stubgen torch.utils._config_module.py
And then manually extracting the methods of ConfigModule and converting them into top-level functions.
This file should be imported into any file that uses install_config_module like so:
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
from torch.utils._config_module import install_config_module
# adds patch, save_config, etc
install_config_module(sys.modules[__name__])
Note that the import should happen before the call to install_config_module(), otherwise runtime errors may occur.
"""
assert TYPE_CHECKING, "Do not use at runtime"
def save_config() -> bytes: ...
def save_config_portable() -> dict[str, Any]: ...
def codegen_config() -> str: ...
def get_hash() -> bytes: ...
def to_dict() -> dict[str, Any]: ...
def shallow_copy_dict() -> dict[str, Any]: ...
def load_config(config: bytes | dict[str, Any]) -> None: ...
def get_config_copy() -> dict[str, Any]: ...
def patch(arg1: str | dict[str, Any] | None = None, arg2: Any = None, **kwargs): ...

View File

@@ -0,0 +1,238 @@
# mypy: allow-untyped-defs
# This module provides a FAST (on GPU) content addressable store for storages
# (and tensors on top of them) with VERY WEAK portability guarantees (e.g.,
# don't expect CPU/CUDA to hash to the same value, don't expect it to be
# portable across devices) that is NOT cryptographically secure. In return,
# we are able to hash 40G of tensor data on GPU in less than a second,
# compared to running SHA-1 on CPU, which would take a minute or so. The primary
# use case is for efficiently snapshotting intermediate tensor data for
# offline debugging, but it's been put in this module in case you think of
# another use case for it. The hash function could be replaced with a
# straight reimplementation of SHA-1, which would give us much stronger
# portability guarantees.
#
# WARNING: THERE IS NO BC/FC GUARANTEE FOR THIS FORMAT! If you need to format
# shift the result, consider packing it into a single torch.save object
# with traditional view sharing.
#
# Because of the weak portability guarantees, you can only write to the
# content store from a single process; we don't provide any capability
# of "reopening" a content store to add more things to it. But we don't
# assume that you can keep all of the tensors you want to add to the store
# in memory at once, because you probably can't! Nor do we assume that
# you know a priori whether or not two storages can be deduplicated or not.
#
# Note: only storages are content-addressed; tensors are name addressed
#
# Note: our padding strategy means that [1, 0] and [1] int16 tensors would
# map to the same (padded) storage. We think this will be immaterial for most
# users.
import ctypes
import functools
import hashlib
import os.path
import struct
from collections import defaultdict
from typing import Dict, Optional, Set
import torch
import torch._prims as prims
import torch._utils
import torch.nn.functional as F
from torch._C import default_generator
from torch.multiprocessing.reductions import StorageWeakRef
def lazy_compile(**compile_kwargs):
"""Lazily wrap a function with torch.compile on the first call
This avoids eagerly importing dynamo.
"""
def decorate_fn(fn):
@functools.wraps(fn)
def compile_hook(*args, **kwargs):
compiled_fn = torch.compile(fn, **compile_kwargs)
globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
return compiled_fn(*args, **kwargs)
return compile_hook
return decorate_fn
# Use of torch.compile is mandatory for (1) good memory usage
# and (2) xor_sum implementation. This is our first instance of
# using PT2 to implement a kernel in PyTorch; if we get AOT capabilities
# it would be good to apply it here.
@lazy_compile(dynamic=True)
def hash_storage_kernel(x):
# The randint calls are carefully written to hit things we
# have lowerings for in inductor. Lack of unsigned 32-bit integer
# is a pain.
a = torch.randint(
-(2**31), 2**31, x.shape, device=x.device, dtype=torch.int32
).abs()
a = ((a % (2**31 - 1)) + 1).long()
b = (
torch.randint(-(2**31), 2**31, x.shape, device=x.device, dtype=torch.int32)
.abs()
.long()
)
# This is a standard shift-multiply universal hash family
# plus xor sum hash, using Philox to generate random numbers.
# Our Philox RNG is not deterministic across devices so
# don't use this for stable hashing.
#
# This assumes fixed length so you're also obligated to bucket
# by the length of tensor as well
return prims.xor_sum((a * x + b).int(), [0])
# Returns a hex digest of the data in the storage. Guaranteed to be
# SHA-1 if stable_hash=True, otherwise it will be consistent for a single
# process run but not necessarily across processes.
def hash_storage(storage: torch.UntypedStorage, *, stable_hash: bool = False) -> str:
import torch._dynamo
from torch._dynamo.utils import is_compile_supported
device_type = storage.device.type
if stable_hash or not is_compile_supported(device_type):
cpu_storage = storage.cpu()
# TODO: make storage support buffer protocol so this isn't
# necessary
buf = (ctypes.c_byte * cpu_storage.nbytes()).from_address(
cpu_storage.data_ptr()
)
sha1 = hashlib.sha1()
sha1.update(buf)
return sha1.hexdigest()
# TODO: factor this into a random utility
if device_type == "cpu":
generator = default_generator
elif device_type == "cuda":
import torch.cuda
generator = torch.cuda.default_generators[storage.device.index]
else:
raise AssertionError(f"unhandled device type {device_type}")
state = generator.get_state()
try:
generator.manual_seed(0)
x = torch.empty(0, dtype=torch.uint8, device=storage.device).set_(storage) # type: ignore[call-overload]
# The dtype-casting view cannot be compiled, and so the
# padding/reshaping also needs to be done externally even
# though it could be profitably fused
pad = -x.numel() % 4
if pad > 0:
x = F.pad(x, (0, pad), "constant", 0)
x = x.view(torch.int32)
# We run the 32-bit hash five times with differing parameters to
# reduce chance of collision
ITER = 5
cs = [hash_storage_kernel(x).item() for _ in range(ITER)]
return struct.pack(">" + "i" * ITER, *cs).hex()
finally:
generator.set_state(state)
class ContentStoreWriter:
# Structure:
# storages/
# 00/
# 0000..00
# tensors/
# name
def __init__(self, loc: str, stable_hash: bool = False) -> None:
self.loc: str = loc
self.seen_storage_hashes: Set[str] = set()
self.stable_hash = stable_hash
# TODO: offer some sort of non-blocking API to speed things up
def write_storage(self, storage: torch.UntypedStorage) -> str:
h = hash_storage(storage, stable_hash=self.stable_hash)
if h in self.seen_storage_hashes:
return h
# TODO: consider not using torch.save for this; we don't actually
# need any metadata for the storage
subfolder = os.path.join(self.loc, "storages")
os.makedirs(subfolder, exist_ok=True)
target = os.path.join(subfolder, h)
if os.path.exists(target):
return h
torch.save(storage, target)
self.seen_storage_hashes.add(h)
return h
def compute_tensor_metadata(self, t: torch.Tensor, h=None):
if h is None:
h = hash_storage(t.untyped_storage(), stable_hash=self.stable_hash)
return (
t.dtype,
h,
t.storage_offset(),
tuple(t.shape),
t.stride(),
torch._utils.get_tensor_metadata(t),
)
def write_tensor(self, name: str, t: torch.Tensor) -> None:
storage = t.untyped_storage()
h = self.write_storage(storage)
# TODO: Support more advanced snapshotting of requires_grad/grad/etc
d, f = os.path.split(name)
payload = self.compute_tensor_metadata(t, h=h)
subfolder = os.path.join(self.loc, "tensors", d)
os.makedirs(subfolder, exist_ok=True)
torch.save(payload, os.path.join(subfolder, f))
class ContentStoreReader:
def __init__(self, loc: str, *, cache=True) -> None:
self.loc = loc
self.storage_cache: Optional[
Dict[Optional[torch.device], Dict[str, StorageWeakRef]]
] = None
if cache:
self.storage_cache = defaultdict(dict)
def read_storage(self, h: str, *, device=None) -> torch.UntypedStorage:
if device is not None:
device = torch.device(device)
ws = (
self.storage_cache[device].get(h)
if self.storage_cache is not None
else None
)
s: Optional[torch.UntypedStorage]
if ws is not None:
s = torch.UntypedStorage._new_with_weak_ptr(ws.cdata)
if s is not None:
return s
s = torch.load(
os.path.join(self.loc, "storages", h),
weights_only=True,
map_location=device,
)._untyped_storage
assert s is not None
if self.storage_cache is not None:
self.storage_cache[device][h] = StorageWeakRef(s)
return s
def read_tensor_metadata(self, name: str):
fn = os.path.join(self.loc, "tensors", name)
if not os.path.exists(fn):
raise FileNotFoundError(fn)
return torch.load(fn, weights_only=True)
def read_tensor(self, name: str, *, device=None) -> torch.Tensor:
dtype, h, storage_offset, size, stride, metadata = self.read_tensor_metadata(
name
)
storage = self.read_storage(h, device=device)
t = torch.tensor([], dtype=dtype, device=storage.device)
t.set_(storage, storage_offset, size, stride)
torch._utils.set_tensor_metadata(t, metadata)
return t
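# Illustrative usage sketch: _content_store_example is a hypothetical helper showing
# a write/read round trip; `loc` is assumed to be any writable directory and
# "activations/x" is an arbitrary tensor name.
def _content_store_example(loc: str) -> None:
    t = torch.arange(16, dtype=torch.float32)
    writer = ContentStoreWriter(loc)
    writer.write_tensor("activations/x", t)
    reader = ContentStoreReader(loc)
    assert torch.equal(reader.read_tensor("activations/x"), t)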

View File

@@ -0,0 +1,157 @@
# mypy: allow-untyped-defs
# Extra utilities for working with context managers that should have been
# in the standard library but are not
import functools
import inspect
import warnings
import sys
from typing import Any, Callable, TypeVar, cast
# Used for annotating the decorator usage of _DecoratorContextManager (e.g.,
# 'no_grad' and 'enable_grad').
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar('F', bound=FuncType)
def _wrap_generator(ctx_factory, func):
"""
Wrap each generator invocation with the context manager factory.
The input should be a function that returns a context manager,
not a context manager itself, to handle one-shot context managers.
"""
@functools.wraps(func)
def generator_context(*args, **kwargs):
gen = func(*args, **kwargs)
# Generators are suspended and unsuspended at `yield`, hence we
# make sure the grad mode is properly set every time the execution
# flow returns into the wrapped generator and restored when it
# returns through our `yield` to our caller (see PR #49017).
try:
# Issuing `None` to a generator fires it up
with ctx_factory():
response = gen.send(None)
while True:
try:
# Forward the response to our caller and get its next request
request = yield response
except GeneratorExit:
# Inform the still active generator about its imminent closure
with ctx_factory():
gen.close()
raise
except BaseException:
# Propagate the exception thrown at us by the caller
with ctx_factory():
response = gen.throw(*sys.exc_info())
else:
# Pass the last request to the generator and get its response
with ctx_factory():
response = gen.send(request)
# We let the exceptions raised above by the generator's `.throw` or
# `.send` methods bubble up to our caller, except for StopIteration
except StopIteration as e:
# The generator informed us that it is done: take whatever its
# returned value (if any) was and indicate that we're done too
# by returning it (see docs for python's return-statement).
return e.value
return generator_context
def context_decorator(ctx, func):
"""
Like contextlib.ContextDecorator.
But with the following differences:
1. Is done by wrapping, rather than inheritance, so it works with context
managers that are implemented from C and thus cannot easily inherit from
Python classes
2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743)
3. Errors out if you try to wrap a class, because it is ambiguous whether
or not you intended to wrap only the constructor
The input argument can either be a context manager (in which case it must
be a multi-shot context manager that can be directly invoked multiple times)
or a callable that produces a context manager.
"""
assert not (callable(ctx) and hasattr(ctx, '__enter__')), (
f"Passed in {ctx} is both callable and also a valid context manager "
"(has __enter__), making it ambiguous which interface to use. If you "
"intended to pass a context manager factory, rewrite your call as "
"context_decorator(lambda: ctx()); if you intended to pass a context "
"manager directly, rewrite your call as context_decorator(lambda: ctx)"
)
if not callable(ctx):
def ctx_factory():
return ctx
else:
ctx_factory = ctx
if inspect.isclass(func):
raise RuntimeError(
"Cannot decorate classes; it is ambiguous whether or not only the "
"constructor or all methods should have the context manager applied; "
"additionally, decorating a class at definition-site will prevent "
"use of the identifier as a conventional type. "
"To specify which methods to decorate, decorate each of them "
"individually."
)
if inspect.isgeneratorfunction(func):
return _wrap_generator(ctx_factory, func)
@functools.wraps(func)
def decorate_context(*args, **kwargs):
with ctx_factory():
return func(*args, **kwargs)
return decorate_context
class _DecoratorContextManager:
"""Allow a context manager to be used as a decorator."""
def __call__(self, orig_func: F) -> F:
if inspect.isclass(orig_func):
warnings.warn(
"Decorating classes is deprecated and will be disabled in "
"future versions. You should only decorate functions or methods. "
"To preserve the current behavior of class decoration, you can "
"directly decorate the `__init__` method and nothing else.",
FutureWarning,
stacklevel=2,
)
func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
else:
func = orig_func
return cast(F, context_decorator(self.clone, func))
def __enter__(self) -> None:
raise NotImplementedError
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
raise NotImplementedError
def clone(self):
# override this method if your child class takes __init__ parameters
return self.__class__()
class _NoParamDecoratorContextManager(_DecoratorContextManager):
"""Allow a context manager to be used as a decorator without parentheses."""
def __new__(cls, orig_func=None):
if orig_func is None:
return super().__new__(cls)
return cls()(orig_func)
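# Illustrative usage sketch: _context_decorator_example is a hypothetical helper.
# Because @contextlib.contextmanager objects are one-shot, the factory itself is
# passed so a fresh context manager is created for every call of the wrapped function.
def _context_decorator_example():
    import contextlib
    log = []
    @contextlib.contextmanager
    def tracing():
        log.append("enter")
        try:
            yield
        finally:
            log.append("exit")
    wrapped = context_decorator(tracing, lambda: 42)
    assert wrapped() == 42
    assert log == ["enter", "exit"]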

View File

@@ -0,0 +1,59 @@
# mypy: allow-untyped-defs
import collections
Entry = collections.namedtuple('Entry', 'version, hash')
def update_hash(seed, value):
# Good old boost::hash_combine
# https://www.boost.org/doc/libs/1_35_0/doc/html/boost/hash_combine_id241013.html
return seed ^ (hash(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2))
def hash_source_files(hash_value, source_files):
for filename in source_files:
with open(filename) as file:
hash_value = update_hash(hash_value, file.read())
return hash_value
def hash_build_arguments(hash_value, build_arguments):
for group in build_arguments:
if group:
for argument in group:
hash_value = update_hash(hash_value, argument)
return hash_value
class ExtensionVersioner:
def __init__(self):
self.entries = {}
def get_version(self, name):
entry = self.entries.get(name)
return None if entry is None else entry.version
def bump_version_if_changed(self,
name,
source_files,
build_arguments,
build_directory,
with_cuda,
is_python_module,
is_standalone):
hash_value = 0
hash_value = hash_source_files(hash_value, source_files)
hash_value = hash_build_arguments(hash_value, build_arguments)
hash_value = update_hash(hash_value, build_directory)
hash_value = update_hash(hash_value, with_cuda)
hash_value = update_hash(hash_value, is_python_module)
hash_value = update_hash(hash_value, is_standalone)
entry = self.entries.get(name)
if entry is None:
self.entries[name] = entry = Entry(0, hash_value)
elif hash_value != entry.hash:
self.entries[name] = entry = Entry(entry.version + 1, hash_value)
return entry.version
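# Illustrative usage sketch: _versioner_example is a hypothetical helper. Identical
# inputs keep the version stable within a process; changing a source file or a build
# argument bumps it. `source_path` is assumed to be an existing readable file, and
# "my_ext" / "/tmp/build" are placeholder values.
def _versioner_example(source_path):
    versioner = ExtensionVersioner()
    args = [["-O2"]]
    v0 = versioner.bump_version_if_changed(
        "my_ext", [source_path], args, "/tmp/build", False, True, False)
    v1 = versioner.bump_version_if_changed(
        "my_ext", [source_path], args, "/tmp/build", False, True, False)
    assert v0 == v1 == 0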

File diff suppressed because it is too large.

View File

@@ -0,0 +1,119 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
from torch.overrides import TorchFunctionMode, _pop_mode, _push_mode
from torch.utils._contextlib import context_decorator
from torch._C import _len_torch_function_stack
import functools
CURRENT_DEVICE: Optional[torch.device] = None
@functools.lru_cache(1)
def _device_constructors():
return {
# standard ones
torch.empty,
torch.empty_permuted,
torch.empty_strided,
torch.empty_quantized,
torch.ones,
torch.arange,
torch.bartlett_window,
torch.blackman_window,
torch.eye,
torch.fft.fftfreq,
torch.fft.rfftfreq,
torch.full,
torch.fill,
torch.hamming_window,
torch.hann_window,
torch.kaiser_window,
torch.linspace,
torch.logspace,
torch.nested.nested_tensor,
# This function doesn't actually take a device argument
# torch.normal,
torch.ones,
torch.rand,
torch.randn,
torch.randint,
torch.randperm,
torch.range,
torch.sparse_coo_tensor,
torch.sparse_compressed_tensor,
torch.sparse_csr_tensor,
torch.sparse_csc_tensor,
torch.sparse_bsr_tensor,
torch.sparse_bsc_tensor,
torch.tril_indices,
torch.triu_indices,
torch.vander,
torch.zeros,
torch.asarray,
# weird ones
torch.tensor,
torch.as_tensor,
torch.scalar_tensor,
torch.asarray,
}
# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
def __init__(self, device):
self.device = torch.device(device)
def __enter__(self):
global CURRENT_DEVICE
self.old_device = CURRENT_DEVICE
CURRENT_DEVICE = self.device
# We need to put the device at the bottom of the stack.
# If we set the default device within a function mode context,
# exiting that context would pop the device function mode off
# of the stack incorrectly.
cur_stack = []
for _ in range(_len_torch_function_stack()):
cur_stack.append(_pop_mode())
_push_mode(self)
for mode in reversed(cur_stack):
_push_mode(mode)
def __exit__(self, exc_type, exc_val, exc_tb):
global CURRENT_DEVICE
CURRENT_DEVICE = self.old_device
cur_stack = []
# Invariant: there should only be one DeviceContext on the stack at any time,
# at the bottom. Pop all modes until we hit the bottom, then assert it's a
# DeviceContext or else someone else has popped it!
for _ in range(_len_torch_function_stack() - 1):
mode = _pop_mode()
assert not isinstance(mode, DeviceContext)
cur_stack.append(mode)
if _len_torch_function_stack() > 0:
mode = _pop_mode()
assert isinstance(mode, DeviceContext)
for mode in reversed(cur_stack):
_push_mode(mode)
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if func in _device_constructors() and kwargs.get('device') is None:
kwargs['device'] = self.device
return func(*args, **kwargs)
# NB: This is directly called from C++ in torch/csrc/Device.cpp
def device_decorator(device, func):
return context_decorator(lambda: device, func)
def set_device(device):
"""
Set the default device inside of the wrapped function by decorating it with this function.
If you would like to use this as a context manager, use device as a
context manager directly, e.g., ``with torch.device(device)``.
"""
return lambda func: device_decorator(torch.device(device), func)
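# Illustrative usage sketch: _device_context_example is a hypothetical helper. Using
# torch.device as a context manager pushes a DeviceContext, so factory calls without
# an explicit device= default to that device.
def _device_context_example():
    with torch.device("meta"):
        t = torch.empty(3)
    assert t.device.type == "meta"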

View File

@@ -0,0 +1,15 @@
# mypy: allow-untyped-defs
# Allows one to expose an API in a private submodule publicly as per the definition
# in PyTorch's public api policy.
#
# It is a temporary solution while we figure out if it should be the long-term solution
# or if we should amend PyTorch's public api policy. The concern is that this approach
# may not be very robust because it's not clear what __module__ is used for.
# However, both numpy and jax overwrite the __module__ attribute of their APIs
# without problem, so it seems fine.
def exposed_in(module):
def wrapper(fn):
fn.__module__ = module
return fn
return wrapper

View File

@@ -0,0 +1,44 @@
from typing import List, Dict, Tuple, Optional
import torch
from torch import Tensor
from torch.autograd.grad_mode import no_grad
from typing_extensions import TypeAlias
def _get_foreach_kernels_supported_devices() -> List[str]:
r"""Return the device type list that supports foreach kernels."""
return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]
def _get_fused_kernels_supported_devices() -> List[str]:
r"""Return the device type list that supports fused kernels in optimizer."""
return ["mps", "cuda", "xpu", "cpu", torch._C._get_privateuse1_backend_name()]
TensorListList: TypeAlias = List[List[Optional[Tensor]]]
Indices: TypeAlias = List[int]
_foreach_supported_types = [torch.Tensor]
# This util function splits tensors into groups by device and dtype, which is useful before sending
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
# If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
# - tensorlists CAN be None
# - all tensors in the first specified list cannot be None
# - given an index i, all specified tensorlist[i]s match in dtype and device
# with_indices (bool, optional): whether to track previous indices as the last list per dictionary entry.
# It comes in handy if there are Nones or literals in the tensorlists that are getting scattered out.
# Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
# original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation
# may be necessary. Check out torch/optim/sgd.py for an example.
@no_grad()
def _group_tensors_by_device_and_dtype(
tensorlistlist: TensorListList,
with_indices: bool = False,
) -> Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]]:
return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
def _device_has_foreach_support(device: torch.device) -> bool:
return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting()
def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
return _device_has_foreach_support(device) and all(t is None or type(t) in _foreach_supported_types for t in tensors)
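# Illustrative usage sketch: _group_tensors_example is a hypothetical helper. Tensors
# at the same index share device/dtype across lists, so grouping yields one bucket per
# (device, dtype) pair, each holding the split-up lists (plus indices when requested).
def _group_tensors_example() -> None:
    params = [torch.ones(2), torch.ones(2, dtype=torch.float64)]
    grads = [torch.zeros(2), torch.zeros(2, dtype=torch.float64)]
    grouped = _group_tensors_by_device_and_dtype([params, grads], with_indices=True)
    assert len(grouped) == 2  # (cpu, float32) and (cpu, float64) buckets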

View File

@@ -0,0 +1,291 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
"""
Freeze Python packages.
Freezing makes it possible to ship arbitrary Python modules as part of a C++
library. The Python source of the module is compiled to bytecode and written
to `.c` files, to be imported by Python's built-in FrozenImporter.
In a normal Python installation, FrozenImporter is only used to bootstrap the
initialization of the import machinery. Python's importers are defined in
Python (see `_bootstrap.py` and `_bootstrap_external.py`) but need to be
retrieved before any importers are available. Freezing the module bytecode
resolves this circular dependency.
This script will freeze the Python standard library. It produces two things:
- Bytecode files: A set of `.c` files that define C variables containing Python bytecode.
- Main file: A `main.c` file listing all of these modules in the right form to be
consumed by FrozenImporter.
A library that wishes to use these modules makes them available to the local
Python instance by extending `PyImport_FrozenModules` appropriately (see
https://docs.python.org/3/c-api/import.html#c.PyImport_FrozenModules).
"""
import argparse
import functools
import itertools
import marshal
import os
import types
from dataclasses import dataclass
from pathlib import Path
from typing import List
PATH_MARKER = "<Generated by torch::deploy>"
MAIN_INCLUDES = """#include <Python.h>
"""
MAIN_PREFIX_TEMPLATE = """
// Compiled standard library modules. These should be appended to the existing
// `PyImport_FrozenModules` that ships with CPython.
struct _frozen {}[] = {{
"""
FAKE_PREFIX = MAIN_PREFIX_TEMPLATE.format("_PyImport_FrozenModules")
MAIN_SUFFIX = """\
{0, 0, 0} /* sentinel */
};
"""
# Exclude some standard library modules to:
# 1. Slim down the final frozen lib.
# 2. Remove functionality we don't want to support.
DENY_LIST = [
# Interface to unix databases
"dbm",
# ncurses bindings (terminal interfaces)
"curses",
# Tcl/Tk GUI
"tkinter",
# Tests for the standard library
"test",
"tests",
"idle_test",
"__phello__.foo.py",
# importlib frozen modules. These are already baked into CPython.
"_bootstrap.py",
"_bootstrap_external.py",
]
NUM_BYTECODE_FILES = 5
def indent_msg(fn):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
args[0].indent += 1
ret = fn(*args, **kwargs)
args[0].indent -= 1
return ret
return wrapper
@dataclass
class FrozenModule:
# The fully qualified module name, e.g. 'foo.bar.baz'
module_name: str
# The name of the C variable that holds the bytecode, e.g. 'M_foo__bar__baz'
c_name: str
# The size of the C variable. Negative if this module is a package.
size: int
# The frozen bytecode
bytecode: bytes
class Freezer:
def __init__(self, verbose: bool):
self.frozen_modules: List[FrozenModule] = []
self.indent: int = 0
self.verbose: bool = verbose
def msg(self, path: Path, code: str):
if not self.verbose:
return
# P: package dir
# F: python file
# S: skipped (not a package dir)
# X: skipped (deny-listed)
# N: skipped (not a python file)
for i in range(self.indent):
print(" ", end="")
print(f"{code} {path}")
def write_bytecode(self, install_root):
"""
Write the `.c` files containing the frozen bytecode.
Spread the frozen modules evenly across the files.
"""
bytecode_file_names = [f"bytecode_{i}.c" for i in range(NUM_BYTECODE_FILES)]
bytecode_files = [
open(os.path.join(install_root, name), "w") for name in bytecode_file_names
]
it = itertools.cycle(bytecode_files)
for m in self.frozen_modules:
self.write_frozen(m, next(it))
for f in bytecode_files:
f.close()
def write_main(self, install_root, oss, symbol_name):
"""Write the `main.c` file containing a table enumerating all the frozen modules."""
with open(os.path.join(install_root, "main.c"), "w") as outfp:
outfp.write(MAIN_INCLUDES)
for m in self.frozen_modules:
outfp.write(f"extern unsigned char {m.c_name}[];\n")
outfp.write(MAIN_PREFIX_TEMPLATE.format(symbol_name))
for m in self.frozen_modules:
outfp.write(f'\t{{"{m.module_name}", {m.c_name}, {m.size}}},\n')
outfp.write(MAIN_SUFFIX)
if oss:
outfp.write(FAKE_PREFIX)
outfp.write(MAIN_SUFFIX)
def write_frozen(self, m: FrozenModule, outfp):
"""Write a single frozen module's bytecode out to a C variable."""
outfp.write(f"unsigned char {m.c_name}[] = {{")
for i in range(0, len(m.bytecode), 16):
outfp.write("\n\t")
for c in bytes(m.bytecode[i : i + 16]):
outfp.write("%d," % c)
outfp.write("\n};\n")
def compile_path(self, path: Path, top_package_path: Path):
"""Entry point for compiling a Path object."""
if path.is_dir():
self.compile_package(path, top_package_path)
else:
self.compile_file(path, top_package_path)
@indent_msg
def compile_package(self, path: Path, top_package_path: Path):
"""Compile all the files within a Python package dir."""
assert path.is_dir()
if path.name in DENY_LIST:
self.msg(path, "X")
return
# Python packages are directories that have __init__.py in them.
is_package_dir = any(child.name == "__init__.py" for child in path.iterdir())
if not is_package_dir:
self.msg(path, "S")
return
self.msg(path, "P")
# Recursively compile all children in this dir
for child in path.iterdir():
self.compile_path(child, top_package_path)
def get_module_qualname(self, file_path: Path, top_package_path: Path) -> List[str]:
# `path` looks like 'Lib/foo/bar/baz.py'
# chop off 'Lib/' to get something that represents a Python module hierarchy.
# e.g. 'foo/bar/baz.py', which maps to 'foo.bar.baz'
normalized_path = file_path.relative_to(top_package_path.parent)
if normalized_path.name == "__init__.py":
# Special handling for `__init__.py`. In this case, this file
# specifies that the containing directory should be treated as a package.
# For 'foo/bar/baz/__init__.py':
# - The module name is 'baz'
module_basename = normalized_path.parent.name
# - The parent is foo.bar (need to shave off the 'baz')
module_parent = normalized_path.parent.parent.parts
else:
module_basename = normalized_path.stem
module_parent = normalized_path.parent.parts
return list(module_parent) + [module_basename]
def compile_string(self, file_content: str) -> types.CodeType:
# instead of passing in the real build time path to 'compile', we
# pass in a marker instead. This prevents the build time path being
# leaked to runtime. That path may not be available at runtime.
# Setting the path to a marker makes sure it's a hard error rather
# than a flaky error when the inspect module tries to retrieve Python source
# code during torchscripting.
path_marker = PATH_MARKER
return compile(file_content, path_marker, "exec")
@indent_msg
def compile_file(self, path: Path, top_package_path: Path):
"""
Compile a Python source file to frozen bytecode.
Append the result to `self.frozen_modules`.
"""
assert path.is_file()
if path.suffix != ".py":
self.msg(path, "N")
return
if path.name in DENY_LIST:
self.msg(path, "X")
return
self.msg(path, "F")
module_qualname = self.get_module_qualname(path, top_package_path)
module_mangled_name = "__".join(module_qualname)
c_name = "M_" + module_mangled_name
with open(path) as src_file:
co = self.compile_string(src_file.read())
bytecode = marshal.dumps(co)
size = len(bytecode)
if path.name == "__init__.py":
# Python packages are signified by negative size.
size = -size
self.frozen_modules.append(
FrozenModule(".".join(module_qualname), c_name, size, bytecode)
)
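# Illustrative usage sketch: _freezer_example is a hypothetical helper that freezes a
# single module the same way compile_file does; `src` is assumed to be the Path of an
# existing non-package `.py` file.
def _freezer_example(src: Path) -> None:
    f = Freezer(verbose=False)
    f.compile_file(src, src)
    (m,) = f.frozen_modules
    assert m.module_name == src.stem
    assert m.size == len(m.bytecode)  # positive, since this is not an __init__.py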
def main() -> None:
parser = argparse.ArgumentParser(description="Compile py source")
parser.add_argument("paths", nargs="*", help="Paths to freeze.")
parser.add_argument("--verbose", action="store_true", help="Print debug logs")
parser.add_argument(
"--install-dir", "--install_dir", help="Root directory for all output files"
)
parser.add_argument(
"--oss",
action="store_true",
help="If it's OSS build, add a fake _PyImport_FrozenModules",
)
parser.add_argument(
"--symbol-name",
"--symbol_name",
help="The name of the frozen module array symbol to generate",
default="_PyImport_FrozenModules_torch",
)
args = parser.parse_args()
f = Freezer(args.verbose)
for p in args.paths:
path = Path(p)
if path.is_dir() and not Path.exists(path / "__init__.py"):
# this 'top level path p' is a standard directory containing modules,
# not a module itself
# each 'mod' could be a dir containing __init__.py or .py file
# NB: sorted to make sure this is deterministic
for mod in sorted(path.glob("*")):
f.compile_path(mod, mod)
else:
f.compile_path(path, path)
f.write_bytecode(args.install_dir)
f.write_main(args.install_dir, args.oss, args.symbol_name)
if __name__ == "__main__":
main() # pragma: no cover

View File

@@ -0,0 +1,152 @@
# mypy: allow-untyped-defs
import argparse
import os
import re
from pathlib import Path
from typing import Dict, List
def remove_triton_function_declaration(source_code: str) -> str:
remove_head = re.sub(r"(\n.+\s\'\'\'\n)", "\n", source_code)
remove_tail = re.sub(r"(\'\'\'\,.+)", "\n", remove_head)
return remove_tail
def remove_async_compile(source_code: str) -> str:
remove_top_level = str.replace(source_code, "async_compile = AsyncCompile()", "")
remove_compile = str.replace(remove_top_level, "async_compile.wait(globals())", "")
remove_del = str.replace(remove_compile, "del async_compile", "")
return remove_del
def rename_kernels(source_code: str) -> str:
pattern = r"(\w+)\s*=\s*async_compile\.triton\('triton_',\s"
triton_kernel_decl = "def triton_"
matches = [
(match.end(), match.group(1))
for match in re.finditer(pattern, source_code, re.DOTALL)
]
# Starting from the last match to avoid issues with shifting indices after replacements
for end_index, captured_string in reversed(matches):
# Find the index of the next triton kernel declaration after the current match
index_of_B = source_code.find(triton_kernel_decl, end_index)
if index_of_B != -1:
# Replace the triton_kernel_decl with the captured string
source_code = (
source_code[:index_of_B]
+ f"def {captured_string}"
+ source_code[index_of_B + len(triton_kernel_decl) :]
)
else:
# If triton_kernel_decl is not found after the current match, continue to the next
continue
return source_code
def merge_params(original_params: List[str], new_params: List[str]) -> List[str]:
assert len(new_params) >= len(original_params)
for idx in range(len(new_params)):
if new_params[idx] == "T":
new_params[idx] = original_params[idx]
return new_params
def add_launch_params(original: str, kernel_to_params: Dict[str, str]) -> str:
# Regex to match the function call in the original string
pattern = r"(\w+)\.run\((.*), grid=(.*\)), [^)]*\)"
def replace(match) -> str:
# Extract parts from the regex match
func_name = match.group(1)
params = match.group(2)
grid = match.group(3)
new_params = kernel_to_params[func_name]
new_params = merge_params(params.split(", "), new_params.split(", "))
# Format the new function call
new_string = f"{func_name}[{grid}]({', '.join(new_params)})"
return new_string
transformed = re.sub(pattern, replace, original)
remove_inductor_wrappers = re.sub(
r"@triton_heuristics[^@]*@triton.jit",
r"@triton.jit",
transformed,
flags=re.DOTALL,
)
return remove_inductor_wrappers
def process_file(input_filename: str, output_filename: str) -> str:
with open(input_filename) as file:
source_code = file.read()
transformed_code = source_code
if "def triton_(" in source_code:
raise RuntimeError(
"Need to run the original PyTorch code that generates the kernels with TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1"
)
# transformed_code = rename_kernels(transformed_code)
transformed_code = remove_triton_function_declaration(transformed_code)
transformed_code = remove_async_compile(transformed_code)
launch_params_filename = f"{input_filename}.launch_params"
if not os.path.exists(launch_params_filename):
raise RuntimeError(
f"Missing {launch_params_filename}. Run `TORCHINDUCTOR_DUMP_LAUNCH_PARAMS=1 python {input_filename}` first."
)
with open(launch_params_filename) as f:
launch_params_meta = f.readlines()
split_params = [i.split("|") for i in launch_params_meta]
strip_params = [[a.strip(), b.strip()] for a, b in split_params]
kernel_to_args: Dict[str, str] = dict(strip_params)
transformed_code = add_launch_params(transformed_code, kernel_to_args)
with open(output_filename, "w") as file:
file.write(transformed_code)
return transformed_code
def get_clean_triton(
input_path: Path, output_path: Path = Path("triton_only_repro.py")
):
"""Clean Inductor-generated code and write the result to a file.
Args:
input_path (Path): Path to the Inductor-generated output code
output_path (Path): Path to write out the new Python file
"""
return process_file(str(input_path), str(output_path))
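# Illustrative usage sketch: _add_launch_params_example is a hypothetical helper. The
# kernel name and the "T, T, 1024" launch-params entry below are made up; "T" keeps
# the original argument, while concrete values are substituted in.
def _add_launch_params_example() -> None:
    src = "triton_poi_fused_add_0.run(buf0, arg0, grid=grid(1024), stream=stream0)"
    out = add_launch_params(src, {"triton_poi_fused_add_0": "T, T, 1024"})
    assert out == "triton_poi_fused_add_0[grid(1024)](buf0, arg0, 1024)"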
if __name__ == "__main__":
"""Sample usage:
python <this_script>.py path/to/inductor_output.py --output_path triton_only_repro.py
"""
parser = argparse.ArgumentParser(
description="Clean Inductor generated code to remove Inductor dependencies"
)
# Add the arguments
parser.add_argument(
"input_path", type=Path, help="Path to inductor generated output code"
)
parser.add_argument(
"--output_path",
type=Path,
default=Path("triton_only_repro.py"),
help="Path to write out the clean triton output",
)
# Parse the arguments
args = parser.parse_args()
# Call the function with parsed arguments
result = get_clean_triton(args.input_path, args.output_path)

View File

@@ -0,0 +1,43 @@
# mypy: allow-untyped-defs
import functools
import importlib.util
import torch
def _check_module_exists(name: str) -> bool:
r"""Returns whether a top-level module with :attr:`name` exists *without*
importing it. This is generally safer than a try-except block around an
`import X`. It avoids third party libraries breaking assumptions of some of
our tests, e.g., setting multiprocessing start method when imported
(see librosa/#747, torchvision/#544).
"""
try:
spec = importlib.util.find_spec(name)
return spec is not None
except ImportError:
return False
@functools.lru_cache
def dill_available():
return (
_check_module_exists("dill")
# dill fails to import under torchdeploy
and not torch._running_with_deploy()
)
@functools.lru_cache
def import_dill():
if not dill_available():
return None
import dill
# XXX: By default, dill writes the Pickler dispatch table to inject its
# own logic there. This globally affects the behavior of the standard library
# pickler for any user who transitively depends on this module!
# Undo this extension to avoid altering the behavior of the pickler globally.
dill.extend(use_dill=False)
return dill
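# Illustrative usage sketch: _check_module_exists_example is a hypothetical helper;
# the second module name is assumed not to exist anywhere on the path.
def _check_module_exists_example() -> None:
    assert _check_module_exists("importlib")
    assert not _check_module_exists("module_that_does_not_exist_anywhere")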

View File

@@ -0,0 +1,11 @@
# mypy: allow-untyped-defs
import torch
from typing import TypeVar
T = TypeVar('T')
# Returns True if every entry in `modes` is the same mode as the first one.
def all_same_mode(modes):
return all(mode == modes[0] for mode in modes)
no_dispatch = torch._C._DisableTorchDispatch

View File

@@ -0,0 +1,180 @@
from __future__ import annotations
from collections.abc import MutableSet, Set as AbstractSet
from typing import (
Any,
cast,
Dict,
Generic,
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
)
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
__all__ = ["OrderedSet"]
# Using Generic[T] because py38 does not support a type-parameterized MutableSet
class OrderedSet(MutableSet, Generic[T]):
"""
Insertion ordered set, similar to OrderedDict.
"""
__slots__ = ("_dict",)
def __init__(self, iterable: Optional[Iterable[T]] = None):
self._dict = dict.fromkeys(iterable, None) if iterable is not None else {}
@staticmethod
def _from_dict(dict_inp: Dict[T, None]) -> OrderedSet[T]:
s: OrderedSet[T] = OrderedSet()
s._dict = dict_inp
return s
#
# Required overridden abstract methods
#
def __contains__(self, elem: object) -> bool:
return elem in self._dict
def __iter__(self) -> Iterator[T]:
return iter(self._dict)
def __len__(self) -> int:
return len(self._dict)
def add(self, elem: T) -> None:
self._dict[elem] = None
def discard(self, elem: T) -> None:
self._dict.pop(elem, None)
def clear(self) -> None:
# overridden because MutableSet impl is slow
self._dict.clear()
# Unimplemented set() methods in _collections_abc.MutableSet
@classmethod
def _wrap_iter_in_set(cls, other: Any) -> Any:
"""
Wrap non-Set iterables in OrderedSets.
Some of the magic methods are stricter about input types than
the public APIs, so we need to wrap inputs in sets.
"""
if not isinstance(other, AbstractSet) and isinstance(other, Iterable):
return cls(other)
else:
return other
def pop(self) -> T:
if not self:
raise KeyError("pop from an empty set")
return self._dict.popitem()[0]
def copy(self) -> OrderedSet[T]:
return OrderedSet._from_dict(self._dict.copy())
def difference(self, *others: Iterable[T]) -> OrderedSet[T]:
res = self.copy()
res.difference_update(*others)
return res
def difference_update(self, *others: Iterable[T]) -> None:
for other in others:
self -= other # type: ignore[operator, arg-type]
def update(self, *others: Iterable[T]) -> None:
for other in others:
self |= other # type: ignore[operator, arg-type]
def intersection(self, *others: Iterable[T]) -> OrderedSet[T]:
res = self.copy()
for other in others:
if other is not self:
res &= other # type: ignore[operator, arg-type]
return res
def intersection_update(self, *others: Iterable[T]) -> None:
for other in others:
self &= other # type: ignore[operator, arg-type]
def issubset(self, other: Iterable[T]) -> bool:
return self <= self._wrap_iter_in_set(other)
def issuperset(self, other: Iterable[T]) -> bool:
return self >= self._wrap_iter_in_set(other)
def symmetric_difference(self, other: Iterable[T]) -> OrderedSet[T]:
return self ^ other # type: ignore[operator, arg-type]
def symmetric_difference_update(self, other: Iterable[T]) -> None:
self ^= other # type: ignore[operator, arg-type]
def union(self, *others: Iterable[T]) -> OrderedSet[T]:
res = self.copy()
for other in others:
if other is self:
continue
res |= other # type: ignore[operator, arg-type]
return res
# Specify here for correct type inference, otherwise would
# return AbstractSet[T]
def __sub__(self, other: AbstractSet[T_co]) -> OrderedSet[T]:
# following cpython set impl optimization
if isinstance(other, OrderedSet) and (len(self) * 4) > len(other):
out = self.copy()
out -= other
return out
return cast(OrderedSet[T], super().__sub__(other))
def __ior__(self, other: Iterable[T]) -> OrderedSet[T]: # type: ignore[misc, override] # noqa: PYI034
if isinstance(other, OrderedSet):
self._dict.update(other._dict)
return self
return super().__ior__(other) # type: ignore[arg-type]
def __eq__(self, other: AbstractSet[T]) -> bool: # type: ignore[misc, override]
if isinstance(other, OrderedSet):
return self._dict == other._dict
return super().__eq__(other) # type: ignore[arg-type]
def __ne__(self, other: AbstractSet[T]) -> bool: # type: ignore[misc, override]
if isinstance(other, OrderedSet):
return self._dict != other._dict
return super().__ne__(other) # type: ignore[arg-type]
def __or__(self, other: AbstractSet[T_co]) -> OrderedSet[T]:
return cast(OrderedSet[T], super().__or__(other))
def __and__(self, other: AbstractSet[T_co]) -> OrderedSet[T]:
# MutableSet impl will iterate over other, iter over smaller of two sets
if isinstance(other, OrderedSet) and len(self) < len(other):
return other & self
return cast(OrderedSet[T], super().__and__(other))
def __xor__(self, other: AbstractSet[T_co]) -> OrderedSet[T]:
return cast(OrderedSet[T], super().__xor__(other))
def __repr__(self) -> str:
return f"{self.__class__.__name__}({list(self)})"
def __getstate__(self) -> List[T]:
return list(self._dict.keys())
def __setstate__(self, state: List[T]) -> None:
self._dict = dict.fromkeys(state, None)
def __reduce__(self) -> Tuple[Type[OrderedSet[T]], Tuple[List[T]]]:
return (OrderedSet, (list(self),))
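# Editor's note: an illustrative sketch, not part of the original file.
# OrderedSet keeps insertion order through mutation and set algebra.
_s = OrderedSet(["b", "a", "c", "a"])
assert list(_s) == ["b", "a", "c"]          # duplicates collapse, order kept
_s.add("d")
_s.discard("a")
assert list(_s) == ["b", "c", "d"]
assert list(_s | OrderedSet(["e", "b"])) == ["b", "c", "d", "e"]
assert _s.issubset(["b", "c", "d", "x"])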

View File

@ -0,0 +1,701 @@
# mypy: allow-untyped-defs
import contextlib
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque, Type
from typing_extensions import TypeGuard
from collections import deque
import torch
import torchgen
import torchgen.model
from torch._C import (
_get_dispatch_stack_at,
_len_torch_dispatch_stack,
_pop_torch_dispatch_stack,
_push_on_torch_dispatch_stack,
DispatchKey,
)
# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
# is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
_is_in_torch_dispatch_mode = False
_is_in_non_infra_torch_dispatch_mode = False
def is_in_torch_dispatch_mode(include_infra_modes=True) -> bool:
return _is_in_torch_dispatch_mode if include_infra_modes else _is_in_non_infra_torch_dispatch_mode
class TorchDispatchMode:
"""
A ``TorchDispatchMode`` allows you to override the meaning of all
``__torch_dispatch__`` overrideable functions within a dynamic scope,
without having to actually create a tensor subclass or manually
monkey-patch functions in the PyTorch API. Some common situations
where you should use a mode:
* You want to override the meaning of factory functions, or other
functions that do not otherwise take a tensor as an argument
(these cannot be overridden with tensor subclasses).
* You want to override the behavior of all functions without needing
to wrap your inputs in tensor subclasses; e.g., if you are just
interested in logging intermediate computations.
* You want to control the order of execution of various tensor
subclasses explicitly, rather than implicitly via the return of
``NotImplemented``.
Independent subclasses of :class:`TorchDispatchMode` are compositional:
modes can be pushed onto a stack using ``with MyMode():``.
When you call functions in the PyTorch API inside your
``__torch_dispatch__`` implementation, by default, they will forward on to
    the next mode on the mode stack. If you want to recursively call back into
your current ``__torch_dispatch__`` implementation, either explicitly
invoke ``self.__torch_dispatch__(...)``, or use the context manager
``__torch_dispatch__(self)`` to make PyTorch
API self-referential (beware of infinite loops, in this case!)
"""
def __init__(self, _dispatch_key=None):
if _dispatch_key is not None:
assert isinstance(_dispatch_key, torch._C.DispatchKey)
self.__dict__["_dispatch_key"] = _dispatch_key
self.old_dispatch_mode_flags: Deque[bool] = deque()
self.old_non_infra_dispatch_mode_flags: Deque[bool] = deque()
def _lazy_init_old_dispatch_mode_flags(self):
if not hasattr(self, "old_dispatch_mode_flags"):
self.old_dispatch_mode_flags: Deque[bool] = deque() # type: ignore[no-redef]
if not hasattr(self, "old_non_infra_dispatch_mode_flags"):
self.old_non_infra_dispatch_mode_flags: Deque[bool] = deque() # type: ignore[no-redef]
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise NotImplementedError
def __enter__(self):
global _is_in_torch_dispatch_mode
global _is_in_non_infra_torch_dispatch_mode
        # Previously this class had no state in its constructor; super().__init__()
        # calls were added to existing modes, but lazily initializing here keeps
        # the old behavior of not strictly requiring new modes to call
        # super().__init__().
self._lazy_init_old_dispatch_mode_flags()
self.old_dispatch_mode_flags.append(_is_in_torch_dispatch_mode)
_is_in_torch_dispatch_mode = True
self.old_non_infra_dispatch_mode_flags.append(_is_in_non_infra_torch_dispatch_mode)
_is_in_non_infra_torch_dispatch_mode = _is_in_non_infra_torch_dispatch_mode or not self.is_infra_mode()
_push_mode(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
mb_dk_or_mode_key = self.__dict__.get("_dispatch_key", None)
if mb_dk_or_mode_key is None:
# Today, mode keys are not used at all in the per-dispatch-key-mode logic (for pre-dispatch)
# We should probably revisit this.
mb_dk_or_mode_key = self.__dict__.get("_mode_key", None)
global _is_in_torch_dispatch_mode
_is_in_torch_dispatch_mode = self.old_dispatch_mode_flags.pop()
global _is_in_non_infra_torch_dispatch_mode
_is_in_non_infra_torch_dispatch_mode = self.old_non_infra_dispatch_mode_flags.pop()
_pop_mode(mb_dk_or_mode_key)
@classmethod
def push(cls, *args, **kwargs):
warnings.warn(
"`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`"
)
instance = cls(*args, **kwargs)
return instance
@classmethod
def is_infra_mode(cls):
return False
def _get_current_dispatch_mode():
stack_len = _len_torch_dispatch_stack()
# Return a user mode on the stack if there are any
if stack_len > 0:
return _get_dispatch_stack_at(stack_len - 1)
return None
def _detect_infra_mode(key):
assert key in [torch._C._TorchDispatchModeKey.FUNCTIONAL, torch._C._TorchDispatchModeKey.PROXY]
from torch._ops import _get_dispatch_mode_pre_dispatch
pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(
key
)
post_dispatch_mode = torch._C._get_dispatch_mode(
key
)
assert (pre_dispatch_mode is None) or (
post_dispatch_mode is None
)
if pre_dispatch_mode is None:
return post_dispatch_mode
return pre_dispatch_mode
def _unset_infra_mode(key):
from torch._ops import _get_dispatch_mode_pre_dispatch, unset_mode_pre_dispatch
pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key)
post_dispatch_mode = torch._C._get_dispatch_mode(key)
if pre_dispatch_mode and post_dispatch_mode:
raise AssertionError(
"Can't have active infra mode on both pre and post dispatch mode stack"
)
if pre_dispatch_mode:
mode = unset_mode_pre_dispatch(key)
return mode
if post_dispatch_mode:
return torch._C._unset_dispatch_mode(key)
def _disable_infra_mode(key):
assert key in (
torch._C._TorchDispatchModeKey.FUNCTIONAL,
torch._C._TorchDispatchModeKey.PROXY,
)
mode_unset = _unset_infra_mode(key)
try:
yield mode_unset
finally:
if mode_unset is not None:
_push_mode(mode_unset)
def _get_current_dispatch_mode_stack():
stack_len = _len_torch_dispatch_stack()
return [_get_dispatch_stack_at(i) for i in range(stack_len)]
def _push_mode(mode: TorchDispatchMode):
k = mode._dispatch_key if hasattr(mode, "_dispatch_key") else None
assert k is None or k == torch._C.DispatchKey.PreDispatch
if k is None:
_push_on_torch_dispatch_stack(mode)
return
from torch._ops import _set_mode_pre_dispatch, get_cached_ops
# See Note [Not Caching Per-Dispatch-Key Mode Handlers]
# Clear the cache of every op that has been used so far, for this particular key.
ks = torch._C._functionality_to_backend_keys(k)
for op in get_cached_ops():
for key in ks:
op._uncache_dispatch(key)
_set_mode_pre_dispatch(mode)
def _pop_mode(k: Optional[Union[DispatchKey, torch._C._TorchDispatchModeKey]] = None):
if k == torch._C.DispatchKey.PreDispatch: # type: ignore[attr-defined]
from torch._ops import _pop_mode_from_pre_dispatch
return _pop_mode_from_pre_dispatch()
if k is None or isinstance(k, torch._C._TorchDispatchModeKey):
return _pop_torch_dispatch_stack(k)
@contextlib.contextmanager
def _pop_mode_temporarily(k: Optional[DispatchKey] = None):
old = _pop_mode(k)
try:
yield old
finally:
_push_mode(old)
@contextlib.contextmanager
def _disable_current_modes():
from torch._ops import (
_len_torch_dispatch_stack_pre_dispatch,
_pop_mode_from_pre_dispatch,
)
from torch._subclasses.functional_tensor import FunctionalTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
from torch._subclasses.schema_check_mode import SchemaCheckMode
mode_len_pre_dispatch = _len_torch_dispatch_stack_pre_dispatch()
old_pre_dispatch_modes = [
_pop_mode_from_pre_dispatch() for _ in range(mode_len_pre_dispatch)
]
has_proxy_mode_in_pre_dispatch = False
has_functional_mode_in_pre_dispatch = False
has_schema_check_mode_in_pre_dispatch = False
for i in old_pre_dispatch_modes:
if isinstance(i, ProxyTorchDispatchMode):
has_proxy_mode_in_pre_dispatch = True
if isinstance(i, FunctionalTensorMode):
has_functional_mode_in_pre_dispatch = True
if isinstance(i, SchemaCheckMode):
has_schema_check_mode_in_pre_dispatch = True
mode_len = _len_torch_dispatch_stack()
old_modes = [_pop_mode() for _ in range(mode_len)]
for old in old_modes:
if (
isinstance(old, FunctionalTensorMode)
and has_functional_mode_in_pre_dispatch
):
raise AssertionError(
"Can't have FunctionalMode available both in PreDispatch and Python Key"
)
if isinstance(old, ProxyTorchDispatchMode) and has_proxy_mode_in_pre_dispatch:
raise AssertionError(
"Can't have ProxyTorchDispatchMode available both in PreDispatch and Python Key"
)
if (
isinstance(old, SchemaCheckMode)
and has_schema_check_mode_in_pre_dispatch
):
raise AssertionError(
"Can't have SchemaCheckMode available both in PreDispatch and Python Key"
)
# Manually disable proxy and fake modes, if any are active
try:
yield old_pre_dispatch_modes + old_modes
finally:
for mode in reversed(old_modes):
_push_mode(mode)
for mode in reversed(old_pre_dispatch_modes):
_push_mode(mode)
class BaseTorchDispatchMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return func(*args, **kwargs)
# Subtypes which have __tensor_flatten__ and __tensor_unflatten__.
class TensorWithFlatten(Protocol):
def __tensor_flatten__(self) -> Tuple[Sequence[str], object]:
...
@staticmethod
def __tensor_unflatten__(inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int) -> torch.Tensor:
...
# It would be really nice to be able to say that the return of
# is_traceable_wrapper_subclass() is Intersection[torch.Tensor,
# TensorWithFlatten] - but that doesn't exist.
shape: torch._C.Size
@overload
def stride(self, dim: None = None) -> Tuple[int, ...]:
...
@overload
def stride(self, dim: int) -> int:
...
def dim(self) -> int:
...
@overload
def to(
self,
dtype: torch.types._dtype,
non_blocking: bool = False,
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None
) -> torch.Tensor:
...
@overload
def to(
self,
device: Optional["torch._prims_common.DeviceLikeType"] = None,
dtype: Optional[torch.types._dtype] = None,
non_blocking: bool = False,
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None
) -> torch.Tensor:
...
@overload
def to(
self,
other: torch.Tensor,
non_blocking: bool = False,
copy: bool = False,
*,
memory_format: Optional[torch.memory_format] = None
) -> torch.Tensor:
...
def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]:
"""
Returns whether or not a tensor subclass that implements __torch_dispatch__
is 'traceable' with torch.compile.
In order for a tensor subclass to support TorchDispatchMode-style tracing in PT2,
    it must implement two magic methods: __tensor_flatten__ and __tensor_unflatten__.
It is also expected to obey some restrictions around traceability and aliasing:
* The subclass's __torch_dispatch__() implementation should desugar into pytorch
dispatcher operations that can be traced into a graph.
* The subclass should use return_and_correct_aliasing(). This is needed today to make
sure that torch.compile does the right thing in a few cases around input mutation
and output aliasing.
Expected magic method signatures:
attrs, ctx = t.__tensor_flatten__()
attrs: list of attribute name strings for inner tensors
ctx: dict containing any other subclass-specific metadata needed for unflattening
t = MySubClass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)
inner_tensors: dict mapping attribute name -> tensor for each inner tensor
ctx: dict with subclass metadata in the form that __tensor_flatten__() produces
outer_size: expected (possibly symbolic) size that the returned subclass
instance should have. Note that this arg is useful for certain subclasses
that require the shape info to be constructed. In most cases, this arg can be
safely ignored.
outer_stride: expected (possibly symbolic) stride that the returned subclass
instance should have. Note that this arg is useful for certain subclasses
that require the stride info to be constructed. In most cases, this arg can be
safely ignored.
"""
is_subclass = isinstance(t, torch.Tensor) and type(t) != torch.Tensor
return (
is_subclass
and hasattr(t, "__tensor_flatten__")
and hasattr(t, "__tensor_unflatten__")
)
def is_traceable_wrapper_subclass_type(t: Type) -> TypeGuard[Type[TensorWithFlatten]]:
"""Same as above, but takes a type argument instead of an instance."""
return (issubclass(t, torch.Tensor) and t != torch.Tensor
and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"))
def transform_subclass(t, callback, outer_size=None, outer_stride=None):
"""
Given a traceable, wrapper tensor subclass ``t`` that implements
``__torch_dispatch__`` and holds some inner tensors,
and a callback of type ``Callable[[str, torch.Tensor], torch.Tensor]``,
`transform_subclass` will construct a fresh instance of the wrapper tensor subclass.
It will do so by grabbing each inner tensor attribute from the wrapper,
passing them into ``callback`` to get a transformed tensor,
and putting each transformed tensor into the fresh tensor subclass instance.
Note: this function will not handle ensuring that the fresh subclass
gets the same (autograd, and aliasing) metadata as the original tensor.
This is generally handled in other subsystems like AOTAutograd.
"""
outer_size = outer_size if outer_size is not None else t.size()
outer_stride = outer_stride if outer_stride is not None else t.stride()
attrs, ctx = t.__tensor_flatten__()
transformed_tensors_dict = {}
for attr in attrs:
transformed_tensors_dict[attr] = callback(attr, getattr(t, attr))
sub = type(t).__tensor_unflatten__(
transformed_tensors_dict, ctx, outer_size, outer_stride
)
# NB: Purposefully guard here to simplify the inner / outer symbols.
# Using sym_eq() for symbolic comparison can result in an expression that's too
# difficult to guard on, so we use == here.
assert sub.shape == outer_size, (
f"Expected return value from {type(t)}__tensor_unflatten__() to have "
f"shape equal to {outer_size}, but got: {sub.shape}"
)
assert sub.stride() == outer_stride, (
f"Expected return value from {type(t)}__tensor_unflatten__() to have "
f"stride equal to {outer_stride}, but got: {sub.stride()}"
)
return sub
def _correct_storage_aliasing(func, schema_info, args, outs):
"""
Given: an OpOverload, a SchemaInfo (cached information from torchgen about schema),
and the inputs/outputs to the OpOverload,
this function checks to see if func is a view operator
(by checking if any of the outputs in the op's schema
are immutable aliases of inputs).
If so, this function manually aliases the storage of the output tensor
with its corresponding input tensor alias.
It does this by unsafely overwriting the storage field of the output tensor
to be the same storage as the input.
"""
assert isinstance(func, torch._ops.OpOverload)
assert isinstance(args, tuple)
assert isinstance(outs, (list, tuple))
flat_outs = torch.utils._pytree.tree_leaves(outs)
def alias_non_inplace_storage(arg, ret):
# This is hopefully a reasonable assert:
# subclasses that rely on this API for output aliasing
# should always return wrapper tensor subclasses for us to manually alias.
# in theory if a subclass that needs this API wants to sometimes return
# plain tensors, we could remove the assert and just not perform the aliasing,
# but it seems safer to learn more about this case first.
if is_traceable_wrapper_subclass(arg) or is_traceable_wrapper_subclass(ret):
ret_list = ret if isinstance(ret, list) else [ret]
for r in ret_list:
assert type(arg) == type(
r
), f"""Called {str(func)} with input of type {type(arg)}
and output of type {type(ret)}. But expected types to match."""
# Need to call a non-dispatcher helper, because we explicitly do **not**
# want our subclass to intercept the set_() call.
# instead, our subclass should directly have its storage swapped out.
# we **explicitly** don't want to reset the sizes on ret, if the storage implies a size change.
# Why?
# The purpose of this API is *not* to change the size/strides of our output- we assume it's already correct.
# We just want to "fix up" the storage aliasing, without modifying or output's metadata.
# Example: out = inp.expand(inp.shape[0], inp.shape[0])
# This requires swapping the storage of out to be the same as inp,
            # but we do *not* want it to change the sizes/strides that were computed for out.
if isinstance(ret, list):
for r in ret:
torch._functionalize_unsafe_set(r, arg)
else:
assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
torch._functionalize_unsafe_set(ret, arg)
def is_read_only_alias_match(arg, ret):
shared_aliases = arg.alias_set & ret.alias_set
return len(shared_aliases) > 0 and not arg.is_write
num_args = len(func._schema.arguments)
num_returns = len(func._schema.returns)
for arg_idx in range(num_args):
for return_idx in range(num_returns):
if is_read_only_alias_match(
schema_info.args[arg_idx], schema_info.outs[return_idx]
):
alias_non_inplace_storage(args[arg_idx], outs[return_idx])
# This abstracts over the fact that in return_and_correct_aliasing,
# we sometimes use torchgen schema parsing (for aten ops, since torchscript's schema parsing is sometimes buggy),
# and sometimes use torchscript schema parsing (for custom ops, for which torchgen parsing is untested).
@dataclass
class AliasInfo:
alias_set: Set[str]
is_write: bool
name: Optional[str]
@dataclass
class SchemaInfo:
args: List[AliasInfo]
outs: List[AliasInfo]
# Can't import torch._ops.OpOverload due to circular reference
parsed_schema_map: Dict[Any, SchemaInfo] = {}
# Given an OpOverload, returns schema information on it.
# This is cached for efficiency, since it can involve running torchgen
def get_alias_info(func) -> SchemaInfo:
if func in parsed_schema_map:
return parsed_schema_map[func]
# For ATen ops: use torchgen (since torchscript parser doesn't handle alias annotations
# properly for some ops that output tensorlists)
if func.namespace == "aten":
torchgen_schema_str = str(func._schema)
assert torchgen_schema_str.startswith("aten::")
# remove the aten:: namespace, which is added by the torchscript parser,
# and torchgen doesn't know how to handle
torchgen_schema_str = torchgen_schema_str[6:]
import re
# the torchscript parser ends up converting int[2]=1 into int[2]=[1, 1],
# which torchgen chokes on.
torchgen_schema_str = re.sub(r"=\[[0, ]+\]", "=0", torchgen_schema_str)
torchgen_schema_str = re.sub(r"=\[[1, ]+\]", "=1", torchgen_schema_str)
# for aten::rot90
torchgen_schema_str = torchgen_schema_str.replace("=[0, 1]", "=[0,1]")
torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str)
arg_schemas = [
AliasInfo(
alias_set=(
set() if a.annotation is None else set(a.annotation.alias_set)
),
is_write=a.annotation is not None and a.annotation.is_write,
name=a.name,
)
for a in torchgen_schema.arguments.flat_all
]
out_schemas = [
AliasInfo(
alias_set=(
set() if a.annotation is None else set(a.annotation.alias_set)
),
is_write=a.annotation is not None and a.annotation.is_write,
name=a.name,
)
for a in torchgen_schema.returns
]
else:
# For non-aten ops, torchgen is untested so we rely on torchscript schema parsing
arg_schemas = [
AliasInfo(
alias_set=(
set() if a.alias_info is None else set(a.alias_info.before_set)
),
is_write=a.alias_info is not None and a.alias_info.is_write,
name=a.name,
)
for a in func._schema.arguments
]
out_schemas = [
AliasInfo(
alias_set=(
set() if a.alias_info is None else set(a.alias_info.before_set)
),
is_write=a.alias_info is not None and a.alias_info.is_write,
name=a.name,
)
for a in func._schema.returns
]
schema_info = SchemaInfo(args=arg_schemas, outs=out_schemas)
parsed_schema_map[func] = schema_info
return schema_info
def return_and_correct_aliasing(func, args, kwargs, out):
"""
This function should be used by wrapper tensor ``__torch_dispatch__`` subclasses
that would like to work with torch.compile. It ensures that the subclass
properly implements the aliasing behavior of every op,
which is needed for correctness in AOTAutograd.
This function will handle:
* When we see a view op, we will alias the storages of any
input and output tensor subclasses
* When we see an inplace or out= op, we will directly
return the corresponding input tensor, instead of returning
a (potentially) fresh output tensor.
"""
# Caching here because torchgen parsing is definitely not fast, and this function is called
# once for every op in the graph during functionalization.
schema_info = get_alias_info(func)
def get_write_alias(x):
if len(x.alias_set) == 0:
return None
alias_set = list(x.alias_set)
# torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing
assert len(alias_set) == 1
if x.is_write:
return alias_set[0]
return None
def get_arg_from_alias(output_alias, schema_info, args, kwargs):
new_args, new_kwargs = torch.fx.operator_schemas.normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs
)
arg_indices = [
i for i, a in enumerate(schema_info.args) if output_alias in a.alias_set
]
# For any dispatcher op with an output alias, we expect it to map to exactly one alias in the schema's input arguments.
assert len(arg_indices) == 1
idx = arg_indices[0]
arg_info = schema_info.args[idx]
if arg_info.name is not None and arg_info.name in new_kwargs:
return new_kwargs[arg_info.name]
return new_args[idx]
# Fix up the storages of any outs so that they point to the same storage as the input,
# if func is a view op.
_correct_storage_aliasing(
func, schema_info, args, (out,) if not isinstance(out, tuple) else out
)
# For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's
# metadata is set correctly.
if torch.Tag.inplace_view in func.tags:
# no_dispatch() to make sure that we secretly change the metadata on the wrapper,
# but don't end up dispatching the op anywhere else.
mutated_args = [
x
for i, x in enumerate(args)
if get_write_alias(schema_info.args[i]) is not None
]
# Assumption: we have a very small number of inplace_view ops that follow a strict schema:
# there is only a single argument that gets its metadata mutated.
assert len(mutated_args) == 1
# This check exists because we generally *do* want to update the metadata of any wrapper subclasses,
# but FunctionalTensor is special: it overrides all size/stride calls to plumb to the inner tensor.
# so we don't actually need to update the metadata (and attempting to do so causes errors)
from torch._subclasses.functional_tensor import FunctionalTensor
if not isinstance(mutated_args[0], FunctionalTensor):
with torch.utils._mode_utils.no_dispatch():
# See Note: [Fake Tensor Dispatch Keys]
# we're borrowing the way it modifies dispatch key TLS.
meta_in_tls = torch._C._meta_in_tls_dispatch_include()
torch._C._set_meta_in_tls_dispatch_include(True)
try:
func(*args, **kwargs)
finally:
torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
# Next: we need to make sure to return inputs directly, if the output is a mutable alias (e.g. add_()).
# simple case: none of our outputs have mutable aliases, so we can return the output as-is
if not any(get_write_alias(r) is not None for r in schema_info.outs):
return out
# simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)"
if not all(get_write_alias(r) is not None for r in schema_info.outs):
raise RuntimeError("Unsupported schema: " + str(func._schema))
if len(func._schema.returns) == 1:
return get_arg_from_alias(
get_write_alias(schema_info.outs[0]), schema_info, args, kwargs
)
# In the multi-return case, all aten ops return a tuple / list, so cast accordingly.
outs_to_return = type(out)(
[
(
get_arg_from_alias(
get_write_alias(schema_info.outs[i]), schema_info, args, kwargs
)
if get_write_alias(r) is not None
else o
)
for ((i, r), o) in zip(enumerate(schema_info.outs), out)
]
)
return outs_to_return
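# Editor's note: an illustrative sketch, not part of the original file. It
# shows a minimal TorchDispatchMode subclass that logs every dispatched op
# before forwarding it; the names below are for demonstration only.
if __name__ == "__main__":

    class _LoggingMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print(f"dispatching {func}")
            return func(*args, **kwargs)

    with _LoggingMode():
        _x = torch.ones(2, 2)
        _y = _x + _x  # each aten op observed by the mode is printed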

File diff suppressed because it is too large

View File

@ -0,0 +1,22 @@
# mypy: allow-untyped-defs
# NOTE! PLEASE KEEP THIS FILE *FREE* OF TORCH DEPS! IT SHOULD BE IMPORTABLE ANYWHERE.
# IF YOU FEEL AN OVERWHELMING URGE TO ADD A TORCH DEP, MAKE A TRAMPOLINE FILE A LA torch._dynamo.utils
# AND SCRUB AWAY TORCH NOTIONS THERE.
import collections
import functools
from typing import OrderedDict
simple_call_counter: OrderedDict[str, int] = collections.OrderedDict()
def count_label(label):
prev = simple_call_counter.setdefault(label, 0)
simple_call_counter[label] = prev + 1
def count(fn):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
if fn.__qualname__ not in simple_call_counter:
simple_call_counter[fn.__qualname__] = 0
simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
return fn(*args, **kwargs)
return wrapper
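# Editor's note: an illustrative sketch, not part of the original file. The
# counters above can be driven through the @count decorator or count_label();
# the names below are for demonstration only.
if __name__ == "__main__":

    @count
    def _demo(x):
        return x * 2

    _demo(1)
    _demo(2)
    count_label("manual-step")
    print(dict(simple_call_counter))  # e.g. {'_demo': 2, 'manual-step': 1}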

View File

@ -0,0 +1,302 @@
# mypy: disallow-untyped-defs
import functools
import logging
import os
import re
import subprocess
import time
from threading import Lock
from typing import Any, List, Optional, Sequence
logger = logging.getLogger("strobelight_function_profiler")
console_handler = logging.StreamHandler()
formatter = logging.Formatter(
"%(name)s, line %(lineno)d, %(asctime)s, %(levelname)s: %(message)s"
)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)
logger.propagate = False
class StrobelightCLIProfilerError(Exception):
"""
Raised when an error happens during strobelight profiling
"""
def _pid_namespace_link(pid: Optional[int] = None) -> str:
"""Returns the link to the process's namespace, example: pid:[4026531836]"""
PID_NAMESPACE_PATH = "/proc/{}/ns/pid"
pid = pid or os.getpid()
return os.readlink(PID_NAMESPACE_PATH.format(pid))
def _pid_namespace(pid: Optional[int] = None) -> int:
"""Returns the process's namespace id"""
pid = pid or os.getpid()
link = _pid_namespace_link(pid)
return int(link[link.find("[") + 1 : -1])
def _command_to_string(command: Sequence[str]) -> str:
return " ".join(command)
class StrobelightCLIFunctionProfiler:
"""
    Note: this is a Meta-only tool.
    StrobelightCLIFunctionProfiler can be used to profile a Python function and
    generate a strobelight link with the results. It works on Meta servers but
    does not require an fbcode target.
    When stop_at_error is False (the default), errors during profiling do not
    prevent the work function from running.
Check function_profiler_example.py for an example.
"""
# This lock is used to make sure only one thread is running the profiler at any point.
_lock = Lock()
def __init__(
self,
*,
stop_at_error: bool = False,
max_profile_duration_sec: int = 60 * 10,
sample_each: float = 1e7, # sample each sample_each cycles.
run_user_name: str = "pytorch-strobelight-ondemand",
timeout_wait_for_running_sec: int = 60,
timeout_wait_for_finished_sec: int = 60,
recorded_env_variables: Optional[List[str]] = None,
sample_tags: Optional[List[str]] = None,
stack_max_len: int = 127,
async_stack_max_len: int = 127,
):
self.stop_at_error = stop_at_error
self.max_profile_duration_sec = max_profile_duration_sec
self.sample_each = sample_each
self.run_user_name = run_user_name
self.timeout_wait_for_running_sec = timeout_wait_for_running_sec
self.timeout_wait_for_finished_sec = timeout_wait_for_finished_sec
# Results of the most recent run.
# Tracks the strobelight run id of the most recent run
self.current_run_id: Optional[int] = None
self.sample_tags = sample_tags
def _run_async(self) -> None:
processId = os.getpid()
namespace = _pid_namespace(processId)
command = [
"strobeclient",
"run",
"--profiler",
"pyperf",
"--event",
"cycles",
"--async",
"--sample-interval",
f"{int(self.sample_each)}",
"--duration-ms",
f"{int(self.max_profile_duration_sec * 1000)}",
"--pid",
f"{namespace}:{processId}",
]
if self.sample_tags:
command.append("--sample-tags")
command.append(",".join(self.sample_tags))
logger.debug("running command: %s", _command_to_string(command))
result = subprocess.run(command, capture_output=True)
output = result.stderr.decode("utf-8")
logger.debug("output:\n{%s}", output)
if result.returncode != 0:
raise StrobelightCLIProfilerError(
f"failed to start strobelight profiling, error in run_async:{output}"
)
if match := re.search(r"INFO Run Id: (-?\d+)", output):
self.current_run_id = int(match.group(1))
return
raise StrobelightCLIProfilerError(
f"failed to start strobelight profiling, unexpected result {output}"
)
def _wait_for_running(self, counter: int = 0) -> None:
if counter > 20:
raise StrobelightCLIProfilerError(
"wait_for_running called more than 20 times"
)
command = ["strobeclient", "getRunStatus", "--run-id", f"{self.current_run_id}"]
logger.debug("running command: %s", _command_to_string(command))
result = subprocess.run(command, capture_output=True)
output = result.stderr.decode("utf-8")
logger.debug("output:\n{%s}", output)
if result.returncode != 0:
raise StrobelightCLIProfilerError(
f"failed to start strobelight profiling, error in wait_for_running:{output}"
)
if match := re.search("Profile run status: (.*)", output):
current_status = match.group(1)
if current_status == "RUNNING":
return
elif current_status == "PREPARING":
time.sleep(10)
self._wait_for_running(counter + 1)
return
else:
raise StrobelightCLIProfilerError(f"unexpected {current_status} phase")
raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ")
def _stop_run(self) -> None:
command = ["strobeclient", "stopRun", "--run-id", str(self.current_run_id)]
logger.debug("running command: %s", _command_to_string(command))
result = subprocess.run(command, capture_output=True)
output = result.stderr.decode("utf-8")
logger.debug("output:\n{%s}", output)
if result.returncode != 0:
raise StrobelightCLIProfilerError(
f"failed to stop strobelight profiling, return code is not 0 :{output}"
)
if match := re.search("INFO ::1:(.*)", output):
current_status = match.group(1)
if current_status.__contains__("Success!"):
return
else:
raise StrobelightCLIProfilerError(
f"failed to stop strobelight profiling, got {current_status} result"
)
raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ")
def _get_results(self) -> None:
command = ["strobeclient", "getRunStatus", "--run-id", str(self.current_run_id)]
logger.debug("running command: %s", _command_to_string(command))
result = subprocess.run(command, capture_output=True)
output = result.stderr.decode("utf-8")
logger.debug("output:\n{%s}", output)
if result.returncode != 0:
raise StrobelightCLIProfilerError(
f"failed to extract profiling results, return code is not 0 : {output}"
)
if match := re.search("INFO ::1:(.*)", output):
current_status = match.group(1)
if current_status.__contains__("Profile run status: PROCESSING"):
time.sleep(10)
self._get_results()
return
elif not current_status.__contains__("Profile run finished with SUCCESS"):
raise StrobelightCLIProfilerError(
f"failed to extract profiling results, unexpected response {output}"
)
for item in re.findall(
r"(Total samples(.*)|GraphProfiler(.*)|Icicle view \(python stack\)(.*))",
output,
):
logger.info(item[0])
def _stop_strobelight_no_throw(
self,
collect_results: bool,
) -> None:
try:
# call stop run
self._stop_run()
logger.info("strobelight profiling stopped")
logger.debug("collection stopped")
if not collect_results:
return
self._get_results()
except Exception as error:
logger.warning("error during stop_strobelight", exc_info=True)
# Return true if strobelight started and is running. Never throw.
def _start_strobelight(self) -> bool:
strobelight_started = False
try:
self._run_async()
strobelight_started = True
logger.info("strobelight run id is: %s", self.current_run_id)
self._wait_for_running()
logger.info("strobelight profiling running")
return True
except Exception as error:
logger.warning("error during start_strobelight:", exc_info=True)
if strobelight_started:
self._stop_strobelight_no_throw(collect_results=False)
return False
    def profile(self, work_function: Any, *args: Any, **kwargs: Any) -> Any:
        self.current_run_id = None
        # Run the work function unprofiled (or raise) if another run holds the lock.
        if not StrobelightCLIFunctionProfiler._lock.acquire(False):
            if self.stop_at_error:
                raise StrobelightCLIProfilerError("concurrent runs not supported")
            logger.warning("concurrent runs not supported")
            return work_function(*args, **kwargs)
        started = self._start_strobelight()
        if not started:
            if self.stop_at_error:
                StrobelightCLIFunctionProfiler._lock.release()
                raise StrobelightCLIProfilerError(
                    "failed to start strobelight profiling"
                )
            result = work_function(*args, **kwargs)
            StrobelightCLIFunctionProfiler._lock.release()
            return result
        try:
            logger.debug("collection started")
            result = work_function(*args, **kwargs)
            self._stop_strobelight_no_throw(collect_results=True)
            StrobelightCLIFunctionProfiler._lock.release()
            return result
        except Exception:
            logger.warning("work function threw an exception", exc_info=True)
            self._stop_strobelight_no_throw(collect_results=False)
            StrobelightCLIFunctionProfiler._lock.release()
            raise
# A function decorator that wraps profile(). If no profiler is provided, one
# with default args is created. A function can be annotated as:
# @strobelight()
# @strobelight(profiler=StrobelightCLIFunctionProfiler(stop_at_error=True, ...))
# @strobelight(stop_at_error=True, ...)
def strobelight(
profiler: Optional[StrobelightCLIFunctionProfiler] = None, **kwargs: Any
) -> Any:
if not profiler:
profiler = StrobelightCLIFunctionProfiler(**kwargs)
def strobelight_inner(work_function: Any) -> Any:
@functools.wraps(work_function)
def wrapper_function(*args: Any, **kwargs: Any) -> Any:
return profiler.profile(work_function, *args, **kwargs)
return wrapper_function
return strobelight_inner
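# Editor's note: an illustrative sketch, not part of the original file. The
# decorator can wrap any callable, but it only produces results on Meta servers
# where the strobeclient CLI exists, so the call itself is left commented out.
@strobelight(stop_at_error=False, sample_tags=["example"])
def _train_step() -> None:
    pass

# _train_step()  # when strobeclient is available, this call would be profiled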

File diff suppressed because it is too large

View File

@ -0,0 +1,190 @@
# mypy: allow-untyped-defs
"""
This is a simple interpreter for Sympy expressions that dispatches to
classes following the torch._inductor.virtualized calling convention.
For directness, the interpreter takes the handler directly rather than
consulting the TLS. It does not use most of the methods on the full
handler; only those with corresponding Sympy expressions. To see an example
of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.
"""
import functools
import logging
from typing import Any, Dict, Union
import sympy
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
import torch
from .functions import (
CeilToInt,
CleanDiv,
FloatPow,
FloatTrueDiv,
FloorDiv,
FloorToInt,
Identity,
IntTrueDiv,
IsNonOverlappingAndDenseIndicator,
Max,
Min,
Mod,
ModularIndexing,
PowByNatural,
PythonMod,
RoundDecimal,
RoundToInt,
ToFloat,
TruncToFloat,
TruncToInt,
Where,
)
log = logging.getLogger(__name__)
# TODO: Dedupe this with SYMPY_INTERP
@functools.lru_cache(None)
def handlers():
# TODO add CeilDiv (it doesn't appear in the index_expr)
# TODO default to some decompositions if the interpreter doesn't have them
# like decomposing ModularIndexing or implementing Le(a,b) as Ge(b, a)
HANDLERS = {
sympy.Or: "or_",
sympy.And: "and_",
sympy.Eq: "eq",
sympy.Ne: "ne",
sympy.Lt: "lt",
sympy.Gt: "gt",
sympy.Le: "le",
sympy.Ge: "ge",
sympy.Not: "not_",
IntTrueDiv: "int_truediv",
FloatTrueDiv: "truediv",
FloorDiv: "floordiv",
CleanDiv: "floordiv", # TODO: hmm?
TruncToFloat: "trunc",
Where: "where",
sympy.Add: "add",
sympy.Mul: "mul",
FloatPow: "pow",
PowByNatural: "pow_by_natural",
# sympy simplifies x * x into Pow(x, 2), so we need to handle this.
# Do NOT use builtin Pow for floats
# TODO: There is a hazard here, if we have float * float it will
# also get turned into Pow(float, 2) but we don't want this because
# pow_by_natural is assumed to only be integers. Probably the fix is
# to add a FloatMul to impede this optimization
sympy.Pow: "pow_by_natural",
Mod: "mod",
PythonMod: "mod", # TODO: this is wrong
# TODO: Inductor can generate these, but it's ill-specified which
# semantics were intended here. Needs to be cleaned up along with
# FloorDiv in a bigger cleanup
sympy.Mod: "mod",
sympy.Abs: "abs",
sympy.log: "log",
sympy.exp: "exp",
sympy.Min: "minimum",
sympy.Max: "maximum",
Min: "minimum",
Max: "maximum",
ModularIndexing: "modular_indexing",
sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
sympy.Piecewise: "piecewise",
Identity: "identity",
IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
RoundDecimal: "round_decimal",
}
for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]:
HANDLERS[getattr(sympy, name)] = name
return HANDLERS
ASSOCIATIVE_OPS = {"minimum", "maximum", "mul", "add", "and_", "or_"}
def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64):
# Special cases
if isinstance(expr, sympy.Pow) and isinstance(
expr.args[1], sympy.core.numbers.Half
):
return analysis.sqrt(args[0])
if isinstance(expr, ToFloat):
return analysis.to_dtype(args[0], torch.float64)
# These handlers are special because they take an extra dtype argument
# specifying what they should convert to, and we need to appropriately set
# this up when we convert from Sympy. A reasonable default when you
# are translating is to conservatively do int64, and then narrow these
# arguments later when you discover you can narrow the index range. But
# if you already know that 32-bit indexing is OK, you can directly do the
# sympy translation with index_dtype=torch.int32
INDEX_DTYPE_HANDLERS = {
TruncToInt: "trunc_to_int",
sympy.floor: "floor_to_int",
sympy.ceiling: "ceil_to_int",
FloorToInt: "floor_to_int",
CeilToInt: "ceil_to_int",
RoundToInt: "round_to_int",
}
if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None:
return getattr(analysis, handler_name)(*args, index_dtype)
if hasattr(expr.func, "_torch_handler_name"):
handler_name = expr.func._torch_handler_name
else:
handler_name = handlers()[expr.func]
handler = getattr(analysis, handler_name)
try:
if handler_name in ASSOCIATIVE_OPS:
assert len(args) > 1
acc = handler(args[0], args[1])
for i in range(2, len(args)):
acc = handler(acc, args[i])
log.debug("%s(%s) -> %s", handler_name, args, acc)
return acc
else:
r = handler(*args)
log.debug("%s(%s) -> %s", handler_name, args, r)
return r
except Exception:
log.warning("failed while executing %s(%s)", handler_name, args)
raise
def sympy_interp(
analysis,
env: Dict[sympy.Symbol, Any],
expr: Union[sympy.Expr, SympyBoolean],
*,
index_dtype=torch.int64,
):
# Handle base cases
dtype = None
if isinstance(expr, BooleanAtom):
dtype = torch.bool
elif isinstance(expr, sympy.Integer):
dtype = torch.int64
elif isinstance(expr, sympy.Number):
dtype = torch.double
if dtype is not None:
return analysis.constant(expr, dtype)
elif isinstance(expr, sympy.Symbol):
return env[expr]
# Recursive case
return _run_sympy_handler(
analysis,
[sympy_interp(analysis, env, arg) for arg in expr.args], # type: ignore[arg-type]
expr,
index_dtype=index_dtype,
) # type: ignore[arg-type]
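# Editor's note: an illustrative sketch, not part of the original file. It
# evaluates a small expression with the reference handler defined elsewhere in
# this commit; the import path is an assumption based on the package layout.
if __name__ == "__main__":
    from torch.utils._sympy.reference import ReferenceAnalysis

    _x, _y = sympy.symbols("x y")
    _env = {_x: sympy.Integer(3), _y: sympy.Integer(4)}
    print(sympy_interp(ReferenceAnalysis(), _env, _x + 2 * _y))  # expected: 11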

View File

@ -0,0 +1,397 @@
# mypy: allow-untyped-defs
import mpmath.libmp as mlib # type: ignore[import-untyped]
import sympy
from sympy import Expr
from sympy.core.decorators import _sympifyit
from sympy.core.expr import AtomicExpr
from sympy.core.numbers import Number
from sympy.core.parameters import global_parameters
from sympy.core.singleton import S, Singleton
class IntInfinity(Number, metaclass=Singleton):
r"""Positive integer infinite quantity.
    Integer infinity is a value in the extended integers which
is greater than all other integers. We distinguish it from
sympy's existing notion of infinity in that it reports that
it is_integer.
Infinity is a singleton, and can be accessed by ``S.IntInfinity``,
or can be imported as ``int_oo``.
"""
# NB: We can't actually mark this as infinite, as integer and infinite are
# inconsistent assumptions in sympy. We also report that we are complex,
# different from sympy.oo
is_integer = True
is_commutative = True
is_number = True
is_extended_real = True
is_comparable = True
is_extended_positive = True
is_prime = False
# Ensure we get dispatched to before plain numbers
_op_priority = 100.0
__slots__ = ()
def __new__(cls):
return AtomicExpr.__new__(cls)
def _sympystr(self, printer):
return "int_oo"
def _eval_subs(self, old, new):
if self == old:
return new
# We could do these, not sure about it
"""
def _eval_evalf(self, prec=None):
return Float('inf')
def evalf(self, prec=None, **options):
return self._eval_evalf(prec)
"""
@_sympifyit("other", NotImplemented)
def __add__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other in (S.Infinity, S.NegativeInfinity):
return other
if other in (S.NegativeIntInfinity, S.NaN):
return S.NaN
return self
return Number.__add__(self, other)
__radd__ = __add__
@_sympifyit("other", NotImplemented)
def __sub__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other is S.Infinity:
return S.NegativeInfinity
if other is S.NegativeInfinity:
return S.Infinity
if other in (S.IntInfinity, S.NaN):
return S.NaN
return self
return Number.__sub__(self, other)
@_sympifyit("other", NotImplemented)
def __rsub__(self, other):
return (-self).__add__(other)
@_sympifyit("other", NotImplemented)
def __mul__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other.is_zero or other is S.NaN:
return S.NaN
if other.is_extended_positive:
return self
return S.NegativeIntInfinity
return Number.__mul__(self, other)
__rmul__ = __mul__
@_sympifyit("other", NotImplemented)
def __truediv__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other in (
S.Infinity,
S.IntInfinity,
S.NegativeInfinity,
S.NegativeIntInfinity,
S.NaN,
):
return S.NaN
if other.is_extended_nonnegative:
return S.Infinity # truediv produces float
return S.NegativeInfinity # truediv produces float
return Number.__truediv__(self, other)
def __abs__(self):
return S.IntInfinity
def __neg__(self):
return S.NegativeIntInfinity
def _eval_power(self, expt):
if expt.is_extended_positive:
return S.IntInfinity
if expt.is_extended_negative:
return S.Zero
if expt is S.NaN:
return S.NaN
if expt is S.ComplexInfinity:
return S.NaN
if expt.is_extended_real is False and expt.is_number:
from sympy.functions.elementary.complexes import re
expt_real = re(expt)
if expt_real.is_positive:
return S.ComplexInfinity
if expt_real.is_negative:
return S.Zero
if expt_real.is_zero:
return S.NaN
return self ** expt.evalf()
def _as_mpf_val(self, prec):
return mlib.finf
def __hash__(self):
return super().__hash__()
def __eq__(self, other):
return other is S.IntInfinity
def __ne__(self, other):
return other is not S.IntInfinity
def __gt__(self, other):
if other is S.Infinity:
return sympy.false # sympy.oo > int_oo
elif other is S.IntInfinity:
return sympy.false # consistency with sympy.oo
else:
return sympy.true
def __ge__(self, other):
if other is S.Infinity:
return sympy.false # sympy.oo > int_oo
elif other is S.IntInfinity:
return sympy.true # consistency with sympy.oo
else:
return sympy.true
def __lt__(self, other):
if other is S.Infinity:
return sympy.true # sympy.oo > int_oo
elif other is S.IntInfinity:
return sympy.false # consistency with sympy.oo
else:
return sympy.false
def __le__(self, other):
if other is S.Infinity:
return sympy.true # sympy.oo > int_oo
elif other is S.IntInfinity:
return sympy.true # consistency with sympy.oo
else:
return sympy.false
@_sympifyit("other", NotImplemented)
def __mod__(self, other):
if not isinstance(other, Expr):
return NotImplemented
return S.NaN
__rmod__ = __mod__
def floor(self):
return self
def ceiling(self):
return self
int_oo = S.IntInfinity
class NegativeIntInfinity(Number, metaclass=Singleton):
"""Negative integer infinite quantity.
    NegativeIntInfinity is a singleton, and can be accessed
    by ``S.NegativeIntInfinity``.
See Also
========
IntInfinity
"""
# Ensure we get dispatched to before plain numbers
_op_priority = 100.0
is_integer = True
is_extended_real = True
is_commutative = True
is_comparable = True
is_extended_negative = True
is_number = True
is_prime = False
__slots__ = ()
def __new__(cls):
return AtomicExpr.__new__(cls)
def _eval_subs(self, old, new):
if self == old:
return new
def _sympystr(self, printer):
return "-int_oo"
"""
def _eval_evalf(self, prec=None):
return Float('-inf')
def evalf(self, prec=None, **options):
return self._eval_evalf(prec)
"""
@_sympifyit("other", NotImplemented)
def __add__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other is S.Infinity:
return S.Infinity
if other in (S.IntInfinity, S.NaN):
return S.NaN
return self
return Number.__add__(self, other)
__radd__ = __add__
@_sympifyit("other", NotImplemented)
def __sub__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other is S.NegativeInfinity:
return S.Infinity
if other in (S.NegativeIntInfinity, S.NaN):
return S.NaN
return self
return Number.__sub__(self, other)
@_sympifyit("other", NotImplemented)
def __rsub__(self, other):
return (-self).__add__(other)
@_sympifyit("other", NotImplemented)
def __mul__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other.is_zero or other is S.NaN:
return S.NaN
if other.is_extended_positive:
return self
return S.IntInfinity
return Number.__mul__(self, other)
__rmul__ = __mul__
@_sympifyit("other", NotImplemented)
def __truediv__(self, other):
if isinstance(other, Number) and global_parameters.evaluate:
if other in (
S.Infinity,
S.IntInfinity,
S.NegativeInfinity,
S.NegativeIntInfinity,
S.NaN,
):
return S.NaN
if other.is_extended_nonnegative:
return self
return S.Infinity # truediv returns float
return Number.__truediv__(self, other)
def __abs__(self):
return S.IntInfinity
def __neg__(self):
return S.IntInfinity
def _eval_power(self, expt):
if expt.is_number:
if expt in (
S.NaN,
S.Infinity,
S.NegativeInfinity,
S.IntInfinity,
S.NegativeIntInfinity,
):
return S.NaN
if isinstance(expt, sympy.Integer) and expt.is_extended_positive:
if expt.is_odd:
return S.NegativeIntInfinity
else:
return S.IntInfinity
inf_part = S.IntInfinity**expt
s_part = S.NegativeOne**expt
if inf_part == 0 and s_part.is_finite:
return inf_part
if (
inf_part is S.ComplexInfinity
and s_part.is_finite
and not s_part.is_zero
):
return S.ComplexInfinity
return s_part * inf_part
def _as_mpf_val(self, prec):
return mlib.fninf
def __hash__(self):
return super().__hash__()
def __eq__(self, other):
return other is S.NegativeIntInfinity
def __ne__(self, other):
return other is not S.NegativeIntInfinity
def __gt__(self, other):
if other is S.NegativeInfinity:
return sympy.true # -sympy.oo < -int_oo
elif other is S.NegativeIntInfinity:
return sympy.false # consistency with sympy.oo
else:
return sympy.false
def __ge__(self, other):
if other is S.NegativeInfinity:
return sympy.true # -sympy.oo < -int_oo
elif other is S.NegativeIntInfinity:
return sympy.true # consistency with sympy.oo
else:
return sympy.false
def __lt__(self, other):
if other is S.NegativeInfinity:
return sympy.false # -sympy.oo < -int_oo
elif other is S.NegativeIntInfinity:
return sympy.false # consistency with sympy.oo
else:
return sympy.true
def __le__(self, other):
if other is S.NegativeInfinity:
return sympy.false # -sympy.oo < -int_oo
elif other is S.NegativeIntInfinity:
return sympy.true # consistency with sympy.oo
else:
return sympy.true
@_sympifyit("other", NotImplemented)
def __mod__(self, other):
if not isinstance(other, Expr):
return NotImplemented
return S.NaN
__rmod__ = __mod__
def floor(self):
return self
def ceiling(self):
return self
def as_powers_dict(self):
return {S.NegativeOne: 1, S.IntInfinity: 1}
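# Editor's note: an illustrative sketch, not part of the original file, showing
# the basic arithmetic and ordering behavior of the endpoints defined above.
assert int_oo + 1 == int_oo                   # finite integers are absorbed
assert (int_oo > sympy.Integer(10**100)) is sympy.true
assert (int_oo > S.Infinity) is sympy.false   # int_oo sorts below sympy.oo
assert int_oo + (-int_oo) is S.NaN            # indeterminate, like oo - oo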

View File

@ -0,0 +1,283 @@
# mypy: allow-untyped-defs
import math
import operator
import sympy
import torch
from torch.utils._sympy.functions import (
_keep_float,
FloatPow,
FloatTrueDiv,
FloorDiv,
IntTrueDiv,
Max,
Min,
Mod,
OpaqueUnaryFn_exp,
OpaqueUnaryFn_log,
OpaqueUnaryFn_sqrt,
PowByNatural,
RoundDecimal,
RoundToInt,
ToFloat,
TruncToInt,
)
# The sympy interpretation of operators. It will also sometimes work with
# plain int/float, but if you do certain operations you will get out a
# sympy.Basic in the end. If you want the Python/FX traceable interpretation,
# check PythonReferenceAnalysis.
# NB: For magic methods this needs to use normal magic methods
# so that test_magic_methods works
class ReferenceAnalysis:
@staticmethod
def constant(c, dtype):
return sympy.sympify(c)
@staticmethod
def or_(a, b):
return a | b
@staticmethod
def and_(a, b):
return a & b
@staticmethod
def eq(a, b):
if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
return sympy.Eq(a, b)
return a == b
@classmethod
def ne(cls, a, b):
return cls.not_(cls.eq(a, b))
@staticmethod
def lt(a, b):
return a < b
@staticmethod
def gt(a, b):
return a > b
@staticmethod
def le(a, b):
return a <= b
@staticmethod
def ge(a, b):
return a >= b
@staticmethod
def not_(a):
assert not isinstance(a, bool)
return ~a
@staticmethod
def reciprocal(x):
return FloatTrueDiv(1.0, x)
@staticmethod
def square(x):
return PowByNatural(x, 2)
@staticmethod
def trunc_to_int(x, dtype):
return TruncToInt(x)
@staticmethod
def ceil_to_int(x, dtype):
return sympy.ceiling(x)
@staticmethod
def floor_to_int(x, dtype):
return sympy.floor(x)
@staticmethod
def floor(x):
return _keep_float(sympy.floor)(x)
@staticmethod
def ceil(x):
return _keep_float(sympy.ceiling)(x)
@staticmethod
def to_dtype(x, dtype):
if dtype == torch.float64:
return ToFloat(x)
raise NotImplementedError(f"to_dtype {dtype} NYI")
@staticmethod
def mod(x, y):
return Mod(x, y)
@staticmethod
def abs(x):
return abs(x)
@staticmethod
def neg(x):
return -x
@staticmethod
def truediv(a, b):
return FloatTrueDiv(a, b)
@staticmethod
def int_truediv(a, b):
return IntTrueDiv(a, b)
@staticmethod
def floordiv(a, b):
return FloorDiv(a, b)
@staticmethod
def truncdiv(a, b):
raise NotImplementedError("TODO: truncdiv")
@staticmethod
def add(a, b):
return _keep_float(operator.add)(a, b)
@staticmethod
def mul(a, b):
return _keep_float(operator.mul)(a, b)
@staticmethod
def sub(a, b):
return _keep_float(operator.sub)(a, b)
@staticmethod
def exp(x):
return OpaqueUnaryFn_exp(x)
@staticmethod
def log(x):
return OpaqueUnaryFn_log(x)
@staticmethod
def sqrt(x):
return OpaqueUnaryFn_sqrt(x)
@staticmethod
def pow(a, b):
return _keep_float(FloatPow)(a, b)
@staticmethod
def pow_by_natural(a, b):
return PowByNatural(a, b)
@staticmethod
def minimum(a, b):
return Min(a, b)
@staticmethod
def maximum(a, b):
return Max(a, b)
@staticmethod
def round_to_int(a, dtype):
return RoundToInt(a)
@staticmethod
def round_decimal(a, b):
return RoundDecimal(a, b)
# Unlike ReferenceAnalysis, this does NOT sympify; instead, it works with plain
# Python types and is FX traceable. Inheritance here is purely for code
# sharing (TODO: consider splitting out a BaseReferenceAnalysis).
class PythonReferenceAnalysis(ReferenceAnalysis):
@staticmethod
def constant(c, dtype):
if dtype is torch.int64:
return int(c)
elif dtype is torch.double:
return float(c)
elif dtype is torch.bool:
return bool(c)
else:
raise AssertionError(f"unrecognized dtype {dtype}")
@staticmethod
def not_(a):
return torch.sym_not(a)
@staticmethod
def floordiv(a, b):
return a // b
@staticmethod
def mod(x, y):
return x % y
@staticmethod
def truncdiv(a, b):
return a / b
@staticmethod
def to_dtype(x, dtype):
if dtype == torch.float64:
return torch.sym_float(x)
raise NotImplementedError(f"to_dtype {dtype} NYI")
@staticmethod
def exp(x):
raise AssertionError("exp is not valid shape sympy expr")
@staticmethod
def log(x):
raise AssertionError("log is not valid shape sympy expr")
@staticmethod
def sqrt(x):
return torch._sym_sqrt(x) # type: ignore[attr-defined]
@staticmethod
def minimum(a, b):
return torch.sym_min(a, b)
@staticmethod
def maximum(a, b):
return torch.sym_max(a, b)
@staticmethod
def floor_to_int(x, dtype):
return math.floor(x)
@staticmethod
def ceil_to_int(x, dtype):
return math.ceil(x)
@staticmethod
def floor(x):
return float(math.floor(x))
@staticmethod
def ceil(x):
return float(math.ceil(x))
@staticmethod
def truediv(a, b):
return a / b
@staticmethod
def pow(a, b):
return a**b
@staticmethod
def pow_by_natural(a, b):
# Pray that safe_pow is not needed here lol. In particular, this
# never participates in VR low/high ranges, so overflow should be
# unlikely
return a**b
@staticmethod
def round_to_int(a, dtype):
return round(a)
@staticmethod
def round_decimal(a, b):
return round(a, ndigits=b)
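# Editor's note: an illustrative sketch, not part of the original file. The two
# analyses expose the same handler names: one stays symbolic, the other works
# on plain Python values.
_x = sympy.Symbol("x")
assert ReferenceAnalysis.floordiv(_x, 2) == FloorDiv(_x, 2)  # symbolic result
assert PythonReferenceAnalysis.floordiv(7, 2) == 3           # plain Python int
assert PythonReferenceAnalysis.minimum(3, 5) == 3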

View File

@ -0,0 +1,96 @@
# mypy: allow-untyped-defs
import sympy
from sympy.multipledispatch import dispatch
__all__ = ["SingletonInt"]
class SingletonInt(sympy.AtomicExpr):
# This is probably not super important unless we are in multiple dispatch
# situations with other more exotic Expr types.
_op_priority = 99999
def __new__(cls, *args, coeff=None, **kwargs):
instance = super().__new__(cls, *args, **kwargs)
return instance
# The semantics of this class should match that of NestedIntSymNodeImpl in
# c10/core/NestedIntSymNodeImpl.h
def __init__(self, val, *, coeff=1):
self._val = val
self._coeff = coeff
super().__init__()
# See NOTE [ Inequalities with nested int ]
def _eval_Eq(self, other):
if (
isinstance(other, SingletonInt)
and other._val == self._val
and self._coeff == other._coeff
):
return sympy.true
else:
return sympy.false
# This is necessary so that calling expr.free_symbols on exprs that contain
# this Singleton does not error
@property
def free_symbols(self):
return set()
def __mul__(self, other):
if isinstance(other, SingletonInt):
raise ValueError(
"SingletonInt cannot be multiplied by another SingletonInt"
)
return SingletonInt(self._val, coeff=self._coeff * other)
def __rmul__(self, other):
if isinstance(other, SingletonInt):
raise ValueError(
"SingletonInt cannot be multiplied by another SingletonInt"
)
return SingletonInt(self._val, coeff=self._coeff * other)
# Make sure we promptly raise an error instead of falling back to building
# an expression tree. There are probably more ops, how can we be exhaustive?
def __add__(self, other):
raise NotImplementedError("NYI")
def __sub__(self, other):
raise NotImplementedError("NYI")
def __truediv__(self, other):
raise NotImplementedError("NYI")
def __floordiv__(self, other):
raise NotImplementedError("NYI")
def __mod__(self, other):
raise NotImplementedError("NYI")
# See NOTE [ Inequalities with nested int ]
@dispatch(sympy.Integer, SingletonInt)
def _eval_is_ge(a, b):
if a < 2:
return sympy.false
raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
@dispatch(SingletonInt, sympy.Integer) # type: ignore[no-redef]
def _eval_is_ge(a, b): # noqa: F811
if b <= 2:
return sympy.true
raise ValueError("Symbolic SingletonInt: Relation is indeterminate")
@dispatch(SingletonInt, SingletonInt) # type: ignore[no-redef]
def _eval_is_ge(a, b): # noqa: F811
if a._val == b._val:
if a._coeff >= b._coeff:
return sympy.true
else:
return sympy.false
raise ValueError("Symbolic SingletonInt: Relation is indeterminate")

View File

@ -0,0 +1,175 @@
import logging
from typing import Dict, Optional, Tuple, Type
import sympy
from torch.utils._sympy.functions import FloorDiv
log = logging.getLogger(__name__)
_MIRROR_REL_OP: Dict[Type[sympy.Basic], Type[sympy.Rel]] = {
sympy.Eq: sympy.Eq,
sympy.Ne: sympy.Ne,
sympy.Ge: sympy.Le,
sympy.Gt: sympy.Lt,
sympy.Le: sympy.Ge,
sympy.Lt: sympy.Gt,
}
INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
def mirror_rel_op(type: Type) -> Optional[Type[sympy.Rel]]:
return _MIRROR_REL_OP.get(type, None)
# Tries to simplify 'expr', so as to leave only 'thing' in the left-hand side.
#
# Returns a tuple of:
# 1. The simplified expression
# 2. The expression on the right-hand side
#
# Returns 'None' if it can't reach a state where the only thing in the left
# hand side is 'thing'.
#
# 'trials': number of times 'try_solve' will try to isolate 'thing' to the
# left-hand side.
#
# 'floordiv_inequality': flag to enable conversion of 'FloorDiv' into
# inequalities.
def try_solve(
expr: sympy.Basic,
thing: sympy.Basic,
trials: int = 5,
floordiv_inequality: bool = True,
) -> Optional[Tuple[sympy.Rel, sympy.Basic]]:
mirror = mirror_rel_op(type(expr))
# Ignore unsupported expressions:
# - Those that are not relational operations
# - Those that don't have a mirror (just avoiding unexpected classes)
if not isinstance(expr, sympy.Rel) or mirror is None:
log.debug("expression with unsupported type: %s", type(expr))
return None
lhs_has_thing = expr.lhs.has(thing)
rhs_has_thing = expr.rhs.has(thing)
# Give up when 'thing' appears on both sides of the relational expression.
# That is because, as is, we assume the thing we are trying to isolate
# appears on only one side of the expression.
if lhs_has_thing and rhs_has_thing:
log.debug("thing (%s) found in both sides of expression: %s", thing, expr)
return None
# Try considering both LHS and RHS by mirroring the original expression:
# a < b ==> b > a
expressions = []
# Add each version of 'expr' if 'thing' is in its left-hand side.
if lhs_has_thing:
expressions.append(expr)
if rhs_has_thing:
expressions.append(mirror(expr.rhs, expr.lhs))
for e in expressions:
if e is None:
continue
assert isinstance(e, sympy.Rel)
for _ in range(trials):
trial = _try_isolate_lhs(e, thing, floordiv_inequality=floordiv_inequality)
# Stop if there was no change in this trial.
if trial == e:
break
e = trial # type: ignore[assignment]
# Return if we were able to isolate 'thing' on the left-hand side.
if isinstance(e, sympy.Rel) and e.lhs == thing:
log.debug("solved: %s ---> %s", expr, e)
return e, e.rhs
return None
def _try_isolate_lhs(
e: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool
) -> sympy.Basic:
op = type(e)
if isinstance(e, sympy.Rel):
# Move any constants in the left-hand side to the right-hand side.
lhs_not_thing = (
sum(a for a in e.lhs.args if not a.has(thing))
if isinstance(e.lhs, sympy.Add)
else 0
)
e = op(e.lhs - lhs_not_thing, e.rhs - lhs_not_thing) # type: ignore[attr-defined]
# Divide both sides by the factors that don't contain thing.
if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul):
lhs, rhs = e.args
other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)])
# If we can't tell whether 'other' is negative or positive, we do nothing.
# That is because we don't know whether we have to mirror the operation or not.
if not (isinstance(e, INEQUALITY_TYPES) and other.is_negative is None):
# Divide both sides by 'other'.
lhs = lhs / other
rhs = rhs / other
# If 'e' is an inequality and 'other' is negative, we have to
# mirror the expression.
if isinstance(e, INEQUALITY_TYPES) and other.is_negative:
op = mirror_rel_op(op) # type: ignore[assignment]
assert op is not None
e = op(lhs, rhs)
################################################################################
# left-hand side is FloorDiv
################################################################################
#
# Given the expression: a // b op c
# where 'op' is a relational operation, these rules only work if:
# - b > 0
# - c is an integer
if (
floordiv_inequality
and isinstance(e, sympy.Rel)
and isinstance(e.lhs, FloorDiv)
and e.lhs.divisor.is_positive
and e.rhs.is_integer
):
# a // b == expr
# => a >= (b * expr) and a < (b * (expr + 1))
if isinstance(e, sympy.Eq):
numerator, denominator = e.lhs.args
return sympy.And(
sympy.Ge(numerator, (e.rhs * denominator)), # type: ignore[arg-type]
sympy.Lt(numerator, ((e.rhs + 1) * denominator)), # type: ignore[arg-type]
)
# a // b != expr
# => a < (b * expr) or a >= (b * (expr + 1))
if isinstance(e, sympy.Ne):
numerator, denominator = e.lhs.args
return sympy.Or(
sympy.Lt(numerator, (e.rhs * denominator)), # type: ignore[arg-type]
sympy.Ge(numerator, ((e.rhs + 1) * denominator)), # type: ignore[arg-type]
)
# The transformations below only work if b is positive.
# Note: we only have this information for constants.
# a // b > expr => a >= b * (expr + 1)
# a // b >= expr => a >= b * expr
if isinstance(e, (sympy.Gt, sympy.Ge)):
quotient = e.rhs if isinstance(e, sympy.Ge) else (e.rhs + 1) # type: ignore[arg-type]
return sympy.Ge(e.lhs.args[0], (quotient * e.lhs.args[1])) # type: ignore[arg-type]
# a // b < expr => a < b * expr
# a // b <= expr => a < b * (expr + 1)
if isinstance(e, (sympy.Lt, sympy.Le)):
quotient = e.rhs if isinstance(e, sympy.Lt) else (e.rhs + 1) # type: ignore[arg-type]
return sympy.Lt(e.lhs.args[0], (quotient * e.lhs.args[1])) # type: ignore[arg-type]
return e
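An illustrative sketch of the behavior described above; both calls use only names defined in this file, and the symbol and constants are illustrative:

x = sympy.Symbol("x")

# Constants are moved to the right-hand side: x + 3 == 10  ->  (Eq(x, 7), 7)
print(try_solve(sympy.Eq(x + 3, 10), x))

# FloorDiv rule for >=: x // 2 >= 3  ->  (x >= 6, 6)
print(try_solve(sympy.Ge(FloorDiv(x, 2), 3), x))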

View File

@ -0,0 +1,96 @@
# mypy: allow-untyped-defs
"""
This file contains canonical definitions for our symbol naming conventions,
across torch.fx.experimental.symbolic_shapes and torch._inductor. The
intention is:
1. To make it easily greppable where all the sites we use a prefix are
2. Make it possible to easily tell if we can introduce a new prefix without
introducing a conflict
You can occasionally test if prefixes have been hardcoded by renaming prefixes
in this file and seeing what breaks.
"""
from enum import auto, Enum
from typing import Sequence, Union
import sympy
class SymT(Enum):
SIZE = auto()
FLOAT = auto()
UNBACKED_INT = auto()
UNBACKED_FLOAT = auto()
# Inductor: The intermediates in inner_fn tmp0, one generated per ops call.
# If one of these shows up in an indexing expression, that means an
# indirect load is happening.
TMP = auto()
# Inductor: Placeholder variable that is later replaced with TMP
INDIRECT = auto()
# Inductor: Some size expressions are replaced with a precomputed size ps0
# which is computed host side, and then directly reused in the kernel, so
# we don't repeatedly recompute it on device.
PRECOMPUTED_SIZE = auto()
# Inductor: An indexing variable i0 in loops IR which ranges over non-reduced
# dim in the loop
INDEX = auto()
# Inductor: A reduction indexing r0 variable in loops IR which ranges over
# reduced dim in the loop
RINDEX = auto()
# Inductor: In templated kernels torch._inductor.kernel, we have a hook to
# store the final output and append epilogue fusions. To do this, we must
# know what the indexes the outputs range over. NB: These will also
# advertise as INDEX, this is... probably OK?
TEMPLATE_INDEX = auto()
# Inductor: iteration domain for blockIdx.x/blockIdx.y
XBLOCK = auto()
YBLOCK = auto()
# Inductor: this is used solely for dynamic_reshape_indexer
VIEW = auto()
# Alternate (non-modular) indexing used in halide kernels
HALIDE = auto()
# Invariant: there must not be a prefix which is a prefix of another string,
# as this introduces ambiguity
prefix_str = {
SymT.SIZE: "s", # integer
SymT.UNBACKED_INT: "u", # integer
# Prefix z here is chosen to avoid false aliasing in symbol_is_type test
# DO NOT add a "z" type. You also need to avoid conflicts on these
# prefixes but this is somewhat easier to manage
SymT.FLOAT: "zf",
SymT.UNBACKED_FLOAT: "zuf",
SymT.TMP: "tmp",
SymT.PRECOMPUTED_SIZE: "ps",
SymT.INDEX: "i",
SymT.RINDEX: "r",
SymT.TEMPLATE_INDEX: "idx",
SymT.XBLOCK: "x",
SymT.YBLOCK: "y",
SymT.INDIRECT: "indirect", # false aliasing?
SymT.VIEW: "view",
SymT.HALIDE: "h",
}
def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:
# TODO: maybe put the assumptions here directly
return sympy.Symbol(f"{prefix_str[prefix]}{idx}", **kwargs)
# This type is a little wider than it should be, because free_symbols says
# that it contains Basic, rather than Symbol
def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Sequence[SymT]]) -> bool:
assert isinstance(sym, sympy.Symbol)
name_str = sym.name.lower() # Match capitalized names like XBLOCK, RBLOCK
if isinstance(prefix, SymT):
return name_str.startswith(prefix_str[prefix])
else:
return name_str.startswith(tuple(prefix_str[p] for p in prefix))
def free_symbol_is_type(e: sympy.Expr, prefix: SymT) -> bool:
return any(symbol_is_type(v, prefix) for v in e.free_symbols)
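A short sketch of how these helpers fit together (the assumptions passed to make_symbol are illustrative):

s0 = make_symbol(SymT.SIZE, 0, integer=True, positive=True)   # Symbol("s0")
assert symbol_is_type(s0, SymT.SIZE)
assert not symbol_is_type(s0, SymT.UNBACKED_INT)
assert free_symbol_is_type(s0 * 2 + 1, SymT.SIZE)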

File diff suppressed because it is too large

View File

@ -0,0 +1,28 @@
from typing import Callable, Generic, Optional, TypeVar
R = TypeVar("R")
class Thunk(Generic[R]):
"""
A simple lazy evaluation implementation that lets you delay
execution of a function. It properly handles releasing the
function once it is forced.
"""
f: Optional[Callable[[], R]]
r: Optional[R]
__slots__ = ["f", "r"]
def __init__(self, f: Callable[[], R]):
self.f = f
self.r = None
def force(self) -> R:
if self.f is None:
return self.r # type: ignore[return-value]
self.r = self.f()
self.f = None
return self.r
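A small usage sketch; the callable runs at most once and is released after the first force():

calls = []
t = Thunk(lambda: calls.append("ran") or 42)
assert t.force() == 42
assert t.force() == 42            # second call returns the cached result
assert calls == ["ran"] and t.f is None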

View File

@ -0,0 +1,255 @@
# mypy: allow-untyped-defs
from types import TracebackType
from typing import List, Optional
import tempfile
import traceback
import contextlib
import inspect
import os.path
# This file contains utilities for ensuring dynamically compile()'d
# code fragments display their line numbers in backtraces.
#
# The constraints:
#
# - We don't have control over the user exception printer (in particular,
# we cannot assume the linecache trick will work, c.f.
# https://stackoverflow.com/q/50515651/23845 )
#
# - We don't want to create temporary files every time we compile()
# some code; file creation should happen lazily only at exception
# time. Arguably, you *should* be willing to write out your
# generated Python code to file system, but in some situations
# (esp. library code) it would violate user expectation to write
# to the file system, so we try to avoid it. In particular, we'd
# like to keep the files around, so users can open up the files
# mentioned in the trace; if the file is invisible, we want to
# avoid clogging up the filesystem.
#
# If this is not a constraint for you, there is a substantially simpler
# way to implement the functionality in this PR: instead of using
# eval/exec directly, just always write a Python file to filesystem
# and compile that.
#
# - You have control over a context where the compiled code will get
# executed, so that we can interpose while the stack is unwinding
# (otherwise, we have no way to interpose on the exception printing
# process.)
#
# There are two things you have to do to make use of the utilities here:
#
# - When you compile your source code, you must save its string source
# in its f_globals under the magic name "__compile_source__"
#
# - Before running the compiled code, enter the
# report_compile_source_on_error() context manager.
@contextlib.contextmanager
def report_compile_source_on_error():
try:
yield
except Exception as exc:
tb = exc.__traceback__
# Walk the traceback, looking for frames that have
# source attached
stack = []
while tb is not None:
filename = tb.tb_frame.f_code.co_filename
source = tb.tb_frame.f_globals.get("__compile_source__")
if filename == "<string>" and source is not None:
# What black magic are we doing here? Intuitively, what
# we would like to do is overwrite the co_filename on any
# frames that were generated from exec/eval so that they
# point to a temporary file that has the actual line
# information, so Python's default error printer can print
# useful line information on it.
#
# Writing out the temporary file is easy. But overwriting
# co_filename is not! You can't modify the code object
# associated with a frame. You can, however, reconstruct
# a traceback with entirely new frames from scratch, so that's
# what we do. But there's another problem, which is how to
# make the frame?
#
# The black magic is we make a frankenstein frame and code
# object which resembles the original frame/code enough so
# that it will print properly under traceback and the default
# error printer, but IT IS NOT THE ORIGINAL FRAME (you
# couldn't, e.g., execute its code with different variables
# and expect it to work.)
# Don't delete the temporary file so the user can inspect it
# TODO: This creates a temporary file for every frame, but we
# technically only need one per distinct __compile_source__
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".py") as f:
f.write(source)
# Create a frame. Python doesn't let you construct
# FrameType directly, so just make one with compile
frame = tb.tb_frame
code = compile('__inspect_currentframe()', f.name, 'eval')
code = code.replace(co_name=frame.f_code.co_name)
# Python 3.11 only
if hasattr(frame.f_code, 'co_linetable'):
# We can't copy ALL of the metadata over, because you
# can cause Python to segfault this way. What exactly
# do we need? We need enough information for
# traceback to be able to print the exception
# correctly. Code reading Lib/traceback.py reveals
# that traceback calls code.co_positions() in order to
# get the augmented line/col numbers. Objects/codeobject.c,
# specifically _PyCode_InitAddressRange, reveals that
# this iterator is initialized from co_linetable and
# co_firstlineno. So copy these we must!
code = code.replace( # type: ignore[call-arg]
co_linetable=frame.f_code.co_linetable, # type: ignore[attr-defined]
co_firstlineno=frame.f_code.co_firstlineno, # type: ignore[attr-defined]
)
fake_frame = eval(
code,
frame.f_globals,
{
**frame.f_locals,
'__inspect_currentframe': inspect.currentframe
}
)
fake_tb = TracebackType(
None, fake_frame, tb.tb_lasti, tb.tb_lineno
)
stack.append(fake_tb)
else:
stack.append(tb)
tb = tb.tb_next
# Reconstruct the linked list
tb_next = None
for tb in reversed(stack):
tb.tb_next = tb_next
tb_next = tb
raise exc.with_traceback(tb_next) # noqa: B904
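A minimal sketch of the two requirements above (the compiled source string is illustrative):

src = "def boom():\n    raise RuntimeError('nope')\n"
g = {"__compile_source__": src}              # requirement 1: stash the source in f_globals
exec(compile(src, "<string>", "exec"), g)
try:
    with report_compile_source_on_error():   # requirement 2: run inside the context manager
        g["boom"]()
except RuntimeError:
    pass  # the re-raised traceback now points at a temp .py file containing `src`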
def shorten_filename(fn, *, base=None):
"""Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user."""
if base is None:
base = os.path.dirname(os.path.dirname(__file__))
# Truncate torch/foo.py to foo.py
try:
prefix = os.path.commonpath([fn, base])
except ValueError:
return fn
else:
return fn[len(prefix) + 1:]
def format_frame(frame, *, base=None, line=False):
"""
Format a FrameSummary in a short way, without printing full absolute path or code.
The idea is the result fits on a single line.
"""
extra_line = ""
if line:
extra_line = f"{frame.line} # "
return f"{extra_line}{shorten_filename(frame.filename, base=base)}:{frame.lineno} in {frame.name}"
def format_traceback_short(tb):
"""Format a TracebackType in a short way, printing only the inner-most frame."""
return format_frame(traceback.extract_tb(tb)[-1])
class CapturedTraceback:
__slots__ = ['tb', 'skip']
def __init__(self, tb, skip=0):
self.tb = tb
self.skip = skip
def cleanup(self):
self.tb = None
def summary(self):
import torch._C._profiler
if self.tb is None:
# TODO: Maybe indicate that the traceback was elided?
return traceback.StackSummary()
return _extract_symbolized_tb(
torch._C._profiler.symbolize_tracebacks([self.tb])[0],
self.skip
)
def __getstate__(self):
return (None, {
'tb': None, # TB is not pickleable
'skip': self.skip,
})
@staticmethod
def extract(*, script=False, cpp=False, skip=0):
"""
Like traceback.extract_stack(), but faster (approximately 20x faster); it
is fast enough that you can unconditionally log stacks this way as part of
normal execution. It returns a torch._C._profiler.CapturedTraceback
object that must be formatted specially with format_captured_tb.
By default, this only reports Python backtraces (like extract_stack). You
can set the script/cpp kwargs to also turn on TorchScript/C++ trace
reporting.
"""
import torch._C._profiler
if script or cpp:
assert skip == 0, "skip with script/cpp NYI"
return CapturedTraceback(
torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
# Elide extract() frame if we don't have script/cpp frames. If
# we do have those frames, it doesn't work so force zero.
0 if script or cpp else skip + 1
)
def format(self):
"""
Formats this CapturedTraceback into a list of strings equivalent to the
output of traceback.format_list. Note that for a CapturedTraceback with C++
traces, it is better not to use this function; use the batch formatting API
CapturedTraceback.format_all instead to amortize the cost of symbolization.
"""
return traceback.format_list(self.summary())
@staticmethod
def format_all(tbs):
"""
Bulk version of CapturedTraceback.format. Returns a list of list of strings.
"""
import torch._C._profiler
# Directly populate tracebacks that already have cached summaries
rs: List[Optional[List[str]]] = []
delayed_idxs = []
for i, tb in enumerate(tbs):
if tb.tb is None:
rs.append([])
else:
rs.append(None)
delayed_idxs.append(i)
stbs = torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs])
for i, stb in zip(delayed_idxs, stbs):
rs[i] = traceback.format_list(tbs[i].summary())
return rs
def _extract_symbolized_tb(tb, skip):
"""
Given a symbolized traceback from symbolize_tracebacks, return a StackSummary object of
pre-processed stack trace entries.
"""
stack = traceback.StackSummary()
for f in reversed(tb[skip:]):
stack.append(traceback.FrameSummary(f['filename'], f['line'], f['name']))
return stack
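A usage sketch based on the docstrings above (the captured contents naturally vary by call site):

tb = CapturedTraceback.extract()
print("".join(tb.format()))                      # same shape as traceback.format_list output

tbs = [CapturedTraceback.extract() for _ in range(3)]
formatted = CapturedTraceback.format_all(tbs)    # symbolizes all tracebacks in one batch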

View File

@ -0,0 +1,77 @@
# mypy: allow-untyped-defs
import functools
import hashlib
@functools.lru_cache(None)
def has_triton_package() -> bool:
try:
from triton.compiler.compiler import triton_key
return triton_key is not None
except ImportError:
return False
except RuntimeError:
return False
@functools.lru_cache(None)
def has_triton() -> bool:
from torch._dynamo.device_interface import get_interface_for_device
def cuda_extra_check(device_interface):
return device_interface.Worker.get_device_properties().major >= 7
def _return_true(device_interface):
return True
triton_supported_devices = {"cuda": cuda_extra_check, "xpu": _return_true}
def is_device_compatible_with_triton():
for device, extra_check in triton_supported_devices.items():
device_interface = get_interface_for_device(device)
if device_interface.is_available() and extra_check(device_interface):
return True
return False
return is_device_compatible_with_triton() and has_triton_package()
@functools.lru_cache(None)
def triton_backend():
from triton.compiler.compiler import make_backend
from triton.runtime.driver import driver
target = driver.active.get_current_target()
return make_backend(target)
@functools.lru_cache(None)
def triton_hash_with_backend():
from triton.compiler.compiler import triton_key
backend = triton_backend()
key = f"{triton_key()}-{backend.hash()}"
# Hash is upper case so that it can't contain any Python keywords.
return hashlib.sha256(key.encode("utf-8")).hexdigest().upper()
def dtype_to_string(dtype):
if dtype.name.startswith("fp"):
suffix = "float" + dtype.name[2:]
elif dtype.name.startswith("bf"):
suffix = "bfloat" + dtype.name[2:]
else:
suffix = dtype.name
return "triton.language." + suffix
def patch_triton_dtype_repr():
import triton
# Hack to get triton dtype repr to produce an evaluatable expression
# triton.language.float32 emits triton.language.fp32 which does not
# exist
# REMOVE when https://github.com/openai/triton/pull/3342 lands
triton.language.dtype.__repr__ = lambda self: dtype_to_string(self)
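A hedged sketch of dtype_to_string; it assumes a Triton installation exposing triton.language.float32 (whose .name is 'fp32'):

if has_triton_package():
    import triton

    # 'fp32' -> 'triton.language.float32'; 'bf16' would map to 'triton.language.bfloat16'
    print(dtype_to_string(triton.language.float32))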

View File

@ -0,0 +1,14 @@
"""Miscellaneous utilities to aid with typing."""
from typing import Optional, TypeVar
# Helper to turn Optional[T] into T when we know None either isn't
# possible or should trigger an exception.
T = TypeVar("T")
def not_none(obj: Optional[T]) -> T:
if obj is None:
raise TypeError("Invariant encountered: value was None when it should not be")
return obj
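For example (get_config_path is a hypothetical Optional-returning helper):

def get_config_path() -> Optional[str]:
    return "/tmp/app.cfg"

path: str = not_none(get_config_path())   # narrows Optional[str] to str for type checkers
# not_none(None) would raise TypeError at runtime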

View File

@ -0,0 +1,86 @@
# mypy: allow-untyped-defs
import argparse
import glob
import os
from pathlib import Path
from zipfile import ZipFile
# Exclude some standard library modules to:
# 1. Slim down the final zipped file size
# 2. Remove functionality we don't want to support.
DENY_LIST = [
# Interface to unix databases
"dbm",
# ncurses bindings (terminal interfaces)
"curses",
# Tcl/Tk GUI
"tkinter",
"tkinter",
# Tests for the standard library
"test",
"tests",
"idle_test",
"__phello__.foo.py",
# importlib frozen modules. These are already baked into CPython.
"_bootstrap.py",
"_bootstrap_external.py",
]
strip_file_dir = ""
def remove_prefix(text, prefix):
if text.startswith(prefix):
return text[len(prefix) :]
return text
def write_to_zip(file_path, strip_file_path, zf, prepend_str=""):
stripped_file_path = prepend_str + remove_prefix(file_path, strip_file_dir + "/")
path = Path(stripped_file_path)
if path.name in DENY_LIST:
return
zf.write(file_path, stripped_file_path)
def main() -> None:
global strip_file_dir
parser = argparse.ArgumentParser(description="Zip py source")
parser.add_argument("paths", nargs="*", help="Paths to zip.")
parser.add_argument(
"--install-dir", "--install_dir", help="Root directory for all output files"
)
parser.add_argument(
"--strip-dir",
"--strip_dir",
help="The absolute directory we want to remove from zip",
)
parser.add_argument(
"--prepend-str",
"--prepend_str",
help="A string to prepend onto all paths of a file in the zip",
default="",
)
parser.add_argument("--zip-name", "--zip_name", help="Output zip name")
args = parser.parse_args()
zip_file_name = args.install_dir + "/" + args.zip_name
strip_file_dir = args.strip_dir
prepend_str = args.prepend_str
zf = ZipFile(zip_file_name, mode="w")
for p in sorted(args.paths):
if os.path.isdir(p):
files = glob.glob(p + "/**/*.py", recursive=True)
for file_path in sorted(files):
# strip the absolute path
write_to_zip(
file_path, strip_file_dir + "/", zf, prepend_str=prepend_str
)
else:
write_to_zip(p, strip_file_dir + "/", zf, prepend_str=prepend_str)
if __name__ == "__main__":
main() # pragma: no cover
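A small sketch of the path handling (the paths are illustrative): remove_prefix() keeps archive entries relative to --strip-dir, and write_to_zip() skips anything whose basename is in DENY_LIST:

assert remove_prefix("/abs/src/pkg/mod.py", "/abs/src/") == "pkg/mod.py"
assert Path("pkg/test").name in DENY_LIST          # would be skipped
assert Path("pkg/mod.py").name not in DENY_LIST    # would be archived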

View File

@ -0,0 +1,22 @@
# mypy: allow-untyped-defs
from torch._C import _set_backcompat_broadcast_warn
from torch._C import _get_backcompat_broadcast_warn
from torch._C import _set_backcompat_keepdim_warn
from torch._C import _get_backcompat_keepdim_warn
class Warning:
def __init__(self, setter, getter):
self.setter = setter
self.getter = getter
def set_enabled(self, value):
self.setter(value)
def get_enabled(self):
return self.getter()
enabled = property(get_enabled, set_enabled)
broadcast_warning = Warning(_set_backcompat_broadcast_warn, _get_backcompat_broadcast_warn)
keepdim_warning = Warning(_set_backcompat_keepdim_warn, _get_backcompat_keepdim_warn)
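Usage sketch; each flag simply toggles the corresponding warning in the C++ core:

broadcast_warning.enabled = True    # warn where code relies on deprecated broadcasting behavior
keepdim_warning.enabled = True      # warn where code relies on the deprecated keepdim default
assert broadcast_warning.enabled is True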

View File

@ -0,0 +1,382 @@
# mypy: allow-untyped-defs
import torch
from torch.overrides import (
handle_torch_function,
has_torch_function_unary,
)
from torch._C import _rename_privateuse1_backend, _get_privateuse1_backend_name
from typing import List, Optional, Union
__all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"]
# TODO: Should use `torch._C._get_privateuse1_backend_name()` to get
# renamed-backend name for `privateuse1`, but the func will cause an
# error with torch.jit.script, so we use the global variable named
# `_privateuse1_backend_name`.
_privateuse1_backend_name = "privateuseone"
def rename_privateuse1_backend(backend_name: str) -> None:
r"""
Rename the privateuse1 backend device to make it more convenient to use as a device name within PyTorch APIs.
The steps are:
(1) (In C++) implement kernels for various torch operations, and register them
to the PrivateUse1 dispatch key.
(2) (In python) call torch.utils.rename_privateuse1_backend("foo")
You can now use "foo" as an ordinary device string in python.
Note: this API can only be called once per process. Attempting to change
the external backend after it's already been set will result in an error.
Note(AMP): If you want to support AMP on your device, you can register a custom backend module.
The backend must register a custom backend module with ``torch._register_device_module("foo", BackendModule)``.
BackendModule needs to have the following API's:
(1) ``get_amp_supported_dtype() -> List[torch.dtype]``
Returns the dtypes supported by AMP on your "foo" device (the device may support additional dtypes).
Note(random): If you want to support setting the seed for your device, BackendModule needs to have the following API's:
(1) ``_is_in_bad_fork() -> bool``
Return ``True`` if now it is in bad_fork, else return ``False``.
(2) ``manual_seed_all(seed: int) -> None``
Sets the seed for generating random numbers for your devices.
(3) ``device_count() -> int``
Returns the number of "foo"s available.
(4) ``get_rng_state(device: Union[int, str, torch.device] = 'foo') -> Tensor``
Returns a ByteTensor representing the random number state of the specified device.
(5) ``set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'foo') -> None``
Sets the random number generator state of the specified "foo" device.
And there are some common funcs:
(1) ``is_available() -> bool``
Returns a bool indicating if "foo" is currently available.
(2) ``current_device() -> int``
Returns the index of a currently selected device.
For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example
Example::
>>> # xdoctest: +SKIP("failing")
>>> torch.utils.rename_privateuse1_backend("foo")
# This will work, assuming that you've implemented the right C++ kernels
# to implement torch.ones.
>>> a = torch.ones(2, device="foo")
"""
_rename_privateuse1_backend(backend_name)
global _privateuse1_backend_name
_privateuse1_backend_name = backend_name
def _check_register_once(module, attr):
if hasattr(module, attr):
raise RuntimeError(f"The custom device module of {module} has already been registered with {attr}")
def _normalization_device(custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None) -> int:
def _get_current_device_index():
_get_device_index = "current_device"
if hasattr(torch, custom_backend_name) and \
hasattr(getattr(torch, custom_backend_name), _get_device_index):
return getattr(getattr(torch, custom_backend_name), _get_device_index)()
else:
# The default device index is 0.
return 0
if device is None:
return _get_current_device_index()
# If device is a str (e.g. "foo:0"), convert it to a torch.device object
# and then process it uniformly.
elif isinstance(device, str):
device = torch.device(device)
# variable device can only be torch.device type or int type
if isinstance(device, torch.device):
if device.type != custom_backend_name:
raise RuntimeError(f"Invalid device, must be {custom_backend_name} device")
elif device.index is None:
device_idx = _get_current_device_index()
else:
device_idx = device.index
# if isinstance(device, int), we can take the index number directly
else:
device_idx = device
return device_idx
def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
@property # type: ignore[misc]
def wrap_tensor_backend(self: torch.Tensor) -> bool:
if has_torch_function_unary(self):
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
return handle_torch_function(wrap_tensor_backend.__get__, (self,), self) # type: ignore[attr-defined]
return self.device.type == custom_backend_name
_check_register_once(torch.Tensor, f'is_{custom_backend_name}')
wrap_tensor_backend.fget.__name__ = f'is_{custom_backend_name}' # type: ignore[attr-defined]
setattr(torch.Tensor, f'is_{custom_backend_name}', wrap_tensor_backend)
def wrap_tensor_to(self: torch.Tensor, device: Optional[Union[int, torch.device]] = None, non_blocking=False,
**kwargs) -> torch.Tensor:
r"""Perform Tensor device conversion. Call the to operator implementation.
.. note::
If the ``self`` Tensor already
has the correct :class:`torch.device`, then ``self`` is returned.
Otherwise, the returned tensor is a copy of ``self`` with the desired :class:`torch.device`.
Args:
device (int, optional): if specified, all parameters will be copied to that device
non_blocking (bool): If ``True`` and the source is in pinned memory,
the copy will be asynchronous with respect to the host. Otherwise,
the argument has no effect.
**kwargs (dict): For compatibility, may contain the ``memory_format`` keyword argument.
"""
if has_torch_function_unary(self):
return handle_torch_function(wrap_tensor_to, (self,), self, device=device, non_blocking=non_blocking, **kwargs)
device_idx = _normalization_device(custom_backend_name, device)
return self.to(device=torch.device(f'{custom_backend_name}:{device_idx}'), non_blocking=non_blocking, **kwargs)
_check_register_once(torch.Tensor, custom_backend_name)
wrap_tensor_to.__name__ = custom_backend_name
setattr(torch.Tensor, custom_backend_name, wrap_tensor_to)
def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
# Generating Module attributes and methods depends on Tensor methods,
# so we need to check whether the Tensor methods are already registered.
if not hasattr(torch.Tensor, custom_backend_name):
raise RuntimeError(
f"Cannot automatically generate the {custom_backend_name}() method for torch.nn.Module "
f"because torch.Tensor does not have the method {custom_backend_name}(). "
f"For this error, you can try setting for_tensor=True.")
def wrap_module_to(self: torch.nn.modules.module.T,
device: Optional[Union[int, torch.device]] = None) -> torch.nn.modules.module.T:
r"""Move all model parameters and buffers to the custom device.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on the device while being optimized.
.. note::
This method modifies the module in-place.
Args:
device (int, optional): if specified, all parameters will be copied to that device
"""
return self._apply(lambda t: getattr(t, custom_backend_name)(device))
_check_register_once(torch.nn.Module, custom_backend_name)
setattr(torch.nn.Module, custom_backend_name, wrap_module_to)
def _generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
# Generating PackedSequence attributes and methods depends on Tensor methods,
# so we need to check whether the Tensor methods are already registered.
if not hasattr(torch.Tensor, f'is_{custom_backend_name}') or \
not hasattr(torch.Tensor, custom_backend_name):
raise RuntimeError(
f"Cannot automatically generate is_{custom_backend_name}() or "
f"{custom_backend_name}() method for torch.nn.utils.rnn.PackedSequence "
f"because torch.Tensor does not have the method is_{custom_backend_name}() "
f"or {custom_backend_name}(). "
f"For this error, you can try setting for_tensor=True.")
@property # type: ignore[misc]
def wrap_tensor_backend(self: torch.nn.utils.rnn.PackedSequence) -> bool:
return self.data.device.type == custom_backend_name
_check_register_once(torch.nn.utils.rnn.PackedSequence, f'is_{custom_backend_name}')
setattr(torch.nn.utils.rnn.PackedSequence, f'is_{custom_backend_name}', wrap_tensor_backend)
def wrap_module_to(self: torch.nn.utils.rnn.PackedSequence,
*args, **kwargs) -> torch.nn.utils.rnn.PackedSequence:
r"""Move all model parameters and buffers to the custom device.
This also makes associated parameters and buffers different objects. So
it should be called before constructing the optimizer if the module will
live on the device while being optimized.
.. note::
This method modifies the module in-place.
Args:
device (int, optional): if specified, all parameters will be copied to that device
"""
ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs)
if ex.device.type == custom_backend_name:
return self.to(*args, **kwargs)
kwargs.update({'device': custom_backend_name})
return self.to(*args, **kwargs)
_check_register_once(torch.nn.utils.rnn.PackedSequence, custom_backend_name)
setattr(torch.nn.utils.rnn.PackedSequence, custom_backend_name, wrap_module_to)
def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
unsupported_dtype: Optional[List[torch.dtype]] = None) -> None:
# The attribute is registered on the _StorageBase class,
# and UntypedStorage obtains it through inheritance.
@property # type: ignore[misc]
def wrap_storage_backend(self: torch.storage._StorageBase) -> bool:
r"""Return the internal :class:`torch.UntypedStorage`."""
return self.device.type == custom_backend_name
_check_register_once(torch.storage._StorageBase, f'is_{custom_backend_name}')
setattr(torch.storage._StorageBase, f'is_{custom_backend_name}', wrap_storage_backend)
def wrap_storage_to(self, device=None, non_blocking=False):
r"""Return a copy of this object in custom device memory.
If this object is already in device memory and on the correct device, then
no copy is performed and the original object is returned.
Args:
device (int): The destination device id. Defaults to the current device.
non_blocking (bool): If ``True`` and the source is in pinned memory,
the copy will be asynchronous with respect to the host. Otherwise,
the argument has no effect.
"""
# Ideally there would also be checks on the storage device and storage type here,
# but those depend on the backend extension, so they are omitted from this generated method for now.
device_idx = _normalization_device(custom_backend_name, device)
if getattr(self, f'is_{custom_backend_name}'):
# The storage is already on the expected device.
if self.get_device() == device_idx:
return self
# For sparse storage, custom backends need to extend the implementation themselves.
if self.is_sparse:
raise RuntimeError(f"Can not support a sparse storage move to {custom_backend_name} backend")
# create untyped_storage and copy data
untyped_storage = torch.UntypedStorage(
self.size(), device=torch.device(f'{custom_backend_name}:{device_idx}')
)
untyped_storage.copy_(self, non_blocking)
return untyped_storage
_check_register_once(torch.storage._StorageBase, custom_backend_name)
setattr(torch.storage._StorageBase, custom_backend_name, wrap_storage_to)
# Register the corresponding attribute for the TypedStorage class.
# When the TypedStorage class is removed, this registration should be removed as well.
@property # type: ignore[misc]
def wrap_typed_storage_backend(self: torch.storage.TypedStorage) -> bool:
torch.storage._warn_typed_storage_removal()
return self._untyped_storage.device.type == custom_backend_name
_check_register_once(torch.TypedStorage, f'is_{custom_backend_name}')
setattr(torch.storage.TypedStorage, f'is_{custom_backend_name}', wrap_typed_storage_backend)
def wrap_typed_storage_to(self: torch.storage.TypedStorage,
device=None, non_blocking=False, **kwargs) -> torch.storage.TypedStorage:
torch.storage._warn_typed_storage_removal()
if unsupported_dtype and self.dtype in unsupported_dtype:
raise RuntimeError(f"Cannot create {custom_backend_name} storage "
f"as {self.dtype} dtype is not supported by this backend")
custom_backend_storage: torch.UntypedStorage = getattr(
self._untyped_storage, custom_backend_name)(device, non_blocking, **kwargs)
return self._new_wrapped_storage(custom_backend_storage)
_check_register_once(torch.TypedStorage, custom_backend_name)
setattr(torch.TypedStorage, custom_backend_name, wrap_typed_storage_to)
def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module: bool = True,
for_packed_sequence: bool = True,
for_storage: bool = False,
unsupported_dtype: Optional[List[torch.dtype]] = None) -> None:
r"""
Automatically generate attributes and methods for the custom backend after renaming the privateuse1 backend.
By default, storage-related methods will not be generated automatically.
Once you have implemented kernels for various torch operations, registered them to the PrivateUse1
dispatch key, and called torch.utils.rename_privateuse1_backend("foo") to rename your backend,
you can easily register backend-specific methods and attributes by calling this function,
such as torch.Tensor.foo(), torch.Tensor.is_foo, torch.Storage.foo(), torch.Storage.is_foo.
Note: we recommend using generic functions (checking that devices are equal, or to(device=...)).
The methods generated here are provided for convenience only; they are "monkey patched" onto the objects
and so will not be properly typed. For the generated Storage methods, if you need to support sparse data
storage, you need to extend the implementation yourself.
Args:
for_tensor (bool): whether register related methods for torch.Tensor class.
for_module (bool): whether register related methods for torch.nn.Module class.
for_storage (bool): whether register related methods for torch.Storage class.
unsupported_dtype (List[torch.dtype]): takes effect only when the storage method needs to be generated,
indicating that the storage does not support the torch.dtype type.
Example::
>>> # xdoctest: +SKIP("failing")
>>> torch.utils.rename_privateuse1_backend("foo")
>>> torch.utils.generate_methods_for_privateuse1_backend()
# Then automatically generate backend-related attributes and methods.
>>> a = torch.tensor(2).foo()
>>> a.is_foo
>>> hasattr(torch.nn.Module, 'foo')
"""
custom_backend_name = _get_privateuse1_backend_name()
if for_tensor:
_generate_tensor_methods_for_privateuse1_backend(custom_backend_name)
if for_module:
_generate_module_methods_for_privateuse1_backend(custom_backend_name)
if for_storage:
_generate_storage_methods_for_privateuse1_backend(custom_backend_name, unsupported_dtype)
if for_packed_sequence:
_generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_name)
def _get_custom_mod_func(func_name: str):
r"""
Return the function named `func_name` defined in the custom device module, which is registered
with `torch.utils.rename_privateuse1_backend('foo')` and
`torch._register_device_module('foo', BackendModule)`.
If the custom device module or the function is not defined, a RuntimeError is raised.
Args:
func_name (str): return the callable func named func_name defined in custom device module.
Example::
class DummyfooModule:
@staticmethod
def is_available():
return True
@staticmethod
def func_name(*args, **kwargs):
....
torch.utils.rename_privateuse1_backend("foo")
torch._register_device_module("foo", DummyfooModule)
foo_is_available_func = torch.utils.backend_registration._get_custom_mod_func("is_available")
if foo_is_available_func:
foo_is_available = foo_is_available_func()
func_ = torch.utils.backend_registration._get_custom_mod_func("func_name")
if func_:
result = func_(*args, **kwargs)
Attention: This function is not meant to be used directly by users, which is why
it is marked as private. It is a convenience function for backend implementers to
more easily call the hooks into their backend extensions.
"""
assert isinstance(func_name, str), f"func_name must be `str`, but got `{type(func_name)}`."
backend_name = _get_privateuse1_backend_name()
custom_device_mod = getattr(torch, backend_name, None) # type: ignore[arg-type]
function = getattr(custom_device_mod, func_name, None) # type: ignore[arg-type]
if custom_device_mod is None or function is None:
message = f'Try to call torch.{backend_name}.{func_name}. The backend must register a custom backend '
message += f"module with `torch._register_device_module('{backend_name}', BackendModule)`. And "
message += f"BackendModule needs to have the following API's:\n `{func_name}(*args, **kwargs)`. \n"
raise RuntimeError(message)
return function
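Pulling the pieces in this file together, a minimal end-to-end sketch (DummyFooModule and the "foo" backend are hypothetical, and real PrivateUse1 kernels are still needed to actually move data):

class DummyFooModule:
    @staticmethod
    def is_available() -> bool:
        return True

rename_privateuse1_backend("foo")
torch._register_device_module("foo", DummyFooModule)
generate_methods_for_privateuse1_backend()

# torch.Tensor.foo() / .is_foo and torch.nn.Module.foo() now exist.
assert hasattr(torch.Tensor, "foo") and hasattr(torch.nn.Module, "foo")
assert _get_custom_mod_func("is_available")() is True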

View File

@ -0,0 +1,6 @@
from torch.utils.benchmark.utils.common import * # noqa: F403
from torch.utils.benchmark.utils.timer import * # noqa: F403
from torch.utils.benchmark.utils.compare import * # noqa: F403
from torch.utils.benchmark.utils.fuzzer import * # noqa: F403
from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import * # noqa: F403
from torch.utils.benchmark.utils.sparse_fuzzer import * # noqa: F403

Some files were not shown because too many files have changed in this diff