I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@@ -0,0 +1,179 @@
from .base import VariableTracker
from .builtin import BuiltinVariable
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
CatchWarningsCtxManagerVariable,
ContextWrappingVariable,
CUDADeviceVariable,
DeterministicAlgorithmsVariable,
DisabledSavedTensorsHooksVariable,
DualLevelContextManager,
FSDPParamGroupUseTrainingStateVariable,
GradIncrementNestingCtxManagerVariable,
GradInplaceRequiresGradCtxManagerVariable,
GradModeVariable,
InferenceModeVariable,
JvpIncrementNestingCtxManagerVariable,
SetFwdGradEnabledContextManager,
StreamContextVariable,
StreamVariable,
VmapIncrementNestingCtxManagerVariable,
WithExitFunctionVariable,
)
from .dicts import (
ConstDictVariable,
CustomizedDictVariable,
DefaultDictVariable,
FrozensetVariable,
SetVariable,
)
from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable
from .functions import (
FunctoolsPartialVariable,
NestedUserFunctionVariable,
PolyfilledFunctionVariable,
SkipFunctionVariable,
UserFunctionVariable,
UserMethodVariable,
)
from .higher_order_ops import (
FunctionalCallVariable,
FunctorchHigherOrderVariable,
TorchHigherOrderOperatorVariable,
)
from .iter import (
CountIteratorVariable,
CycleIteratorVariable,
IteratorVariable,
ItertoolsVariable,
MapVariable,
RepeatIteratorVariable,
ZipVariable,
)
from .lazy import LazyVariableTracker
from .lists import (
BaseListVariable,
ListIteratorVariable,
ListVariable,
NamedTupleVariable,
RangeVariable,
RestrictedListSubclassVariable,
SliceVariable,
TupleIteratorVariable,
TupleVariable,
)
from .misc import (
AutogradFunctionContextVariable,
AutogradFunctionVariable,
ClosureVariable,
DeletedVariable,
ExceptionVariable,
GetAttrVariable,
InspectSignatureVariable,
LambdaVariable,
MethodWrapperVariable,
NewCellVariable,
NewGlobalVariable,
NumpyVariable,
PythonModuleVariable,
RandomClassVariable,
RandomVariable,
RegexPatternVariable,
StringFormatVariable,
SuperVariable,
TorchVersionVariable,
TypingVariable,
UnknownVariable,
)
from .nn_module import (
FSDPManagedNNModuleVariable,
NNModuleVariable,
UnspecializedBuiltinNNModuleVariable,
UnspecializedNNModuleVariable,
)
from .optimizer import OptimizerVariable
from .sdpa import SDPAParamsVariable
from .tensor import (
FakeItemVariable,
NumpyNdarrayVariable,
SymNodeVariable,
TensorVariable,
UnspecializedPythonVariable,
UntypedStorageVariable,
)
from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .user_defined import (
MutableMappingVariable,
RemovableHandleVariable,
UserDefinedClassVariable,
UserDefinedObjectVariable,
WeakRefVariable,
)
__all__ = [
"AutogradFunctionContextVariable",
"AutogradFunctionVariable",
"BackwardHookVariable",
"BaseListVariable",
"BuiltinVariable",
"CatchWarningsCtxManagerVariable",
"ClosureVariable",
"ConstantVariable",
"ConstDictVariable",
"ContextWrappingVariable",
"CountIteratorVariable",
"CUDADeviceVariable",
"CustomizedDictVariable",
"CycleIteratorVariable",
"DefaultDictVariable",
"DeletedVariable",
"DeterministicAlgorithmsVariable",
"EnumVariable",
"FakeItemVariable",
"GetAttrVariable",
"GradModeVariable",
"InspectSignatureVariable",
"IteratorVariable",
"ItertoolsVariable",
"LambdaVariable",
"LazyVariableTracker",
"ListIteratorVariable",
"ListVariable",
"NamedTupleVariable",
"NestedUserFunctionVariable",
"NewCellVariable",
"NewGlobalVariable",
"NNModuleVariable",
"NumpyNdarrayVariable",
"NumpyVariable",
"OptimizerVariable",
"PlacementVariable",
"PolyfilledFunctionVariable",
"PythonModuleVariable",
"RangeVariable",
"RegexPatternVariable",
"RemovableHandleVariable",
"RepeatIteratorVariable",
"RestrictedListSubclassVariable",
"SDPAParamsVariable",
"SkipFunctionVariable",
"SliceVariable",
"StringFormatVariable",
"SuperVariable",
"TensorVariable",
"TorchCtxManagerClassVariable",
"TorchInGraphFunctionVariable",
"TorchVersionVariable",
"TupleVariable",
"UnknownVariable",
"UnspecializedNNModuleVariable",
"UnspecializedPythonVariable",
"UntypedStorageVariable",
"UserDefinedClassVariable",
"UserDefinedObjectVariable",
"UserFunctionVariable",
"UserMethodVariable",
"VariableTracker",
"WithExitFunctionVariable",
]
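
The listing above is a re-exporting __init__: every tracker class is imported from its submodule and the alphabetized __all__ list pins the public names. As a minimal standalone sketch of how __all__ gates a star-import (the fake_variables module below is made up for illustration, not part of this commit):

import sys
import types

# Build a throwaway module with one public and one private name, plus __all__.
fake = types.ModuleType("fake_variables")
exec(
    "class VariableTracker: pass\n"
    "class _Hidden: pass\n"
    "__all__ = ['VariableTracker']\n",
    fake.__dict__,
)
sys.modules["fake_variables"] = fake

ns = {}
exec("from fake_variables import *", ns)  # the star-import honors __all__
assert "VariableTracker" in ns and "_Hidden" not in ns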

View File

@@ -0,0 +1,385 @@
# mypy: ignore-errors
import collections
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
from .. import variables
from ..current_scope_id import current_scope_id
from ..exc import unimplemented
from ..source import AttrSource, Source
from ..utils import istype
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
class MutableLocalSource(Enum):
"""
If the VariableTracker.mutable_local represents a Variable that:
- already existed and that Dynamo began tracking during introspection (Existing)
- is a new variable that is created during Dynamo introspection (Local)
"""
Existing = 0
Local = 1
class MutableLocalBase:
"""
Base class for Variable.mutable_local
"""
def __init__(self, typ: MutableLocalSource) -> None:
# In HigherOrderOperator tracing, we need to distinguish
# between MutableLocals inside the HigherOrderOperator and
# ones outside it. For example, it is not safe to mutate
# `a` in the following example because it was constructed
# in a different scope.
#
# def f(x):
# a = 1
# def g(x):
# nonlocal a
# a = 2
# return x
# return wrap(g, x) + a
#
# We use self.scope to distinguish this.
# scope == 0: The object was an existing variable
# scope == 1: The object was created while Dynamo
# was introspecting a function
# (and no HigherOrderOps were involved)
# scope >= 2: The object was created through
# Dynamo introspection of a HigherOrderOp.
# The exact number corresponds to the level
# of nested HigherOrderOps.
if typ is MutableLocalSource.Existing:
self.scope = 0
elif typ is MutableLocalSource.Local:
self.scope = current_scope_id()
else:
unimplemented(f"Unsupported MutableLocalSource: {typ}")
class MutableLocal(MutableLocalBase):
"""
Marker used to indicate this (list, iter, etc) was constructed in
local scope and can be mutated safely in analysis without leaking
state.
"""
def __init__(self) -> None:
super().__init__(MutableLocalSource.Local)
def __hash__(self):
return id(self)
def __eq__(self, other):
return self is other
def _is_top_level_scope(scope_id):
return scope_id == 1
def is_side_effect_safe(m: MutableLocalBase):
scope_id = current_scope_id()
# In the top-level scope (if no HigherOrderOperators are involved),
# we are allowed to modify variables created in this scope as well
# as existing variables.
if _is_top_level_scope(scope_id):
return True
# Otherwise, only allow local mutation of variables created in the current scope
return m.scope == scope_id
class VariableTrackerMeta(type):
all_subclasses = []
def __instancecheck__(cls, instance) -> bool:
"""Make isinstance work with LazyVariableTracker"""
if type.__instancecheck__(
variables.LazyVariableTracker, instance
) and cls not in (
VariableTracker,
variables.LazyVariableTracker,
):
instance = instance.realize()
return type.__instancecheck__(cls, instance)
def __init__(cls, name, bases, attrs) -> None:
super().__init__(name, bases, attrs)
VariableTrackerMeta.all_subclasses.append(cls)
class VariableTracker(metaclass=VariableTrackerMeta):
"""
Base class for tracked locals and stack values
VariableTracker instances are immutable and should be copied in
order to change them.
"""
# fields to leave unmodified in apply()
_nonvar_fields = {
"value",
"guards",
"source",
"mutable_local",
"parents_tracker",
"user_code_variable_name",
}
def clone(self, **kwargs):
"""Shallow copy with some (optional) changes"""
args = dict(self.__dict__)
args.update(kwargs)
return self.__class__(**args)
@classmethod
def visit(
cls,
fn: Callable[["VariableTracker"], None],
value: Any,
cache: Optional[Dict[int, Any]] = None,
) -> None:
"""
Walk value and call fn on all the VariableTracker instances
"""
if cache is None:
cache = {}
idx = id(value)
if idx in cache:
return
# save `value` to keep it alive and ensure id() isn't reused
cache[idx] = value
if isinstance(value, VariableTracker):
value = value.unwrap()
fn(value)
value = value.unwrap() # calling fn() might have realized it
nonvars = value._nonvar_fields
for key, subvalue in value.__dict__.items():
if key not in nonvars:
cls.visit(fn, subvalue, cache)
elif istype(value, (list, tuple)):
for subvalue in value:
cls.visit(fn, subvalue, cache)
elif istype(value, (dict, collections.OrderedDict)):
for subvalue in value.values():
cls.visit(fn, subvalue, cache)
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
def debug_repr(self):
# Intended to be overridden to provide more info
try:
return repr(self.as_python_constant())
except NotImplementedError:
return repr(self)
def python_type(self):
"""
Abstract method to be implemented by subclasses of VariableTracker.
This method should return the type represented by the instance of the subclass.
The purpose is to provide a standardized way to retrieve the Python type information
of the variable being tracked.
Returns:
type: The Python type (such as int, str, list, etc.) of the variable tracked by
the subclass. If the type cannot be determined or is not relevant,
leaving it undefined or invoking super() is always sound.
Note:
This is an abstract method and may be overridden in subclasses.
Example:
class SetVariable(VariableTracker):
def python_type(self):
return set
Raises:
NotImplementedError: If the method is not implemented in a subclass.
"""
try:
return type(self.as_python_constant())
except NotImplementedError:
raise NotImplementedError(f"{self} has no type") from None
def as_python_constant(self):
"""For constants"""
raise NotImplementedError(f"{self} is not a constant")
def guard_as_python_constant(self):
"""Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants"""
try:
return self.as_python_constant()
except NotImplementedError as e:
unimplemented(str(e))
def is_python_constant(self):
try:
self.as_python_constant()
return True
except NotImplementedError:
return False
def make_guard(self, fn):
if self.source:
return self.source.make_guard(fn)
raise NotImplementedError
def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
"""getattr(self, name) returning a python constant"""
raise NotImplementedError
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
"""getattr(self, name) returning a new variable"""
value = self.const_getattr(tx, name)
if not variables.ConstantVariable.is_literal(value):
raise NotImplementedError
source = None
if self.source:
source = AttrSource(self.source, name)
return variables.ConstantVariable.create(value, source=source)
def is_proxy(self):
try:
self.as_proxy()
return True
except NotImplementedError:
return False
def as_proxy(self):
raise NotImplementedError(str(self))
def maybe_fx_node(self):
try:
proxy = self.as_proxy()
import torch.fx
if isinstance(proxy, torch.fx.Proxy):
return proxy.node
return None
except NotImplementedError:
return None
def reconstruct(self, codegen):
raise NotImplementedError
def can_reconstruct(self, tx):
"""If it is possible to reconstruct the Python object this
VariableTracker represents."""
assert tx is tx.output.root_tx, "Only root tx can reconstruct"
try:
from ..codegen import PyCodegen
cg = PyCodegen(tx)
self.reconstruct(cg)
return True
except NotImplementedError:
return False
def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
raise NotImplementedError
def force_unpack_var_sequence(self, tx) -> List["VariableTracker"]:
# like unpack_var_sequence, but should only be used when it is
# safe to eagerly (vs. lazily) unpack this variable.
# e.g. map(f, x) is normally evaluated lazily but sometimes
# we want to force eager unpacking, e.g. when converting to a list.
# NOTE: this method is allowed to mutate the VariableTracker, so
# it should only be called once.
return self.unpack_var_sequence(tx)
def has_unpack_var_sequence(self, tx) -> bool:
try:
self.unpack_var_sequence(tx)
return True
except NotImplementedError:
return False
# NB: don't call force_unpack_var_sequence, especially if it mutates!
def has_force_unpack_var_sequence(self, tx) -> bool:
return self.has_unpack_var_sequence(tx)
def inspect_parameter_names(self) -> List[str]:
unimplemented(f"inspect_parameter_names: {self}")
def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
unimplemented(f"hasattr {self.__class__.__name__} {name}")
def call_function(
self,
tx: "InstructionTranslator",
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
unimplemented(f"call_function {self} {args} {kwargs}")
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "__len__" and self.has_unpack_var_sequence(tx):
assert not (args or kwargs)
return variables.ConstantVariable.create(len(self.unpack_var_sequence(tx)))
elif (
name == "__getattr__"
and len(args) == 1
and args[0].is_python_constant()
and not kwargs
):
return self.var_getattr(tx, args[0].as_python_constant())
unimplemented(f"call_method {self} {name} {args} {kwargs}")
def set_name_hint(self, name):
pass
def realize(self) -> "VariableTracker":
"""Used by LazyVariableTracker to build the real VariableTracker"""
return self
def unwrap(self) -> "VariableTracker":
"""Used by LazyVariableTracker to return the real VariableTracker if it already exists"""
return self
def is_realized(self):
"""Used by LazyVariableTracker to indicate an unrealized node"""
return True
def next_variable(self, tx):
unimplemented(f"next({self})")
def is_strict_mode(self, tx):
return tx.strict_checks_fn and tx.strict_checks_fn(self)
def __init__(
self,
*,
source: Source = None,
mutable_local: MutableLocal = None,
) -> None:
super().__init__()
self.source = source
self.mutable_local = mutable_local
def typestr(*objs):
if len(objs) == 1:
(obj,) = objs
if isinstance(obj, VariableTracker):
return str(obj)
else:
return type(obj).__name__
else:
return " ".join(map(typestr, objs))

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,237 @@
# mypy: ignore-errors
import operator
from typing import Dict, List, TYPE_CHECKING
import torch
from torch._dynamo.source import GetItemSource
from .. import variables
from ..exc import unimplemented, UserError, UserErrorType
from ..guards import GuardBuilder, install_guard
from ..utils import common_constant_types, istype, np
from .base import typestr, VariableTracker
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
_type_to_assert_reason = {
# NB - We CAN have ConstantVariable.create(set) because of how sets interact with guards.
# A locally created set should always become a SetVariable, as the items in the set will already either be sourced
# from somewhere else, or unsourced. An input set would imply sources derived from set contents. For example, an
# input list's contents will have a source like some_list[0], some_list[1][1], etc. For a set, arbitrary access is
# not possible. This is a solvable problem, but one we have not taken on yet. As such, input sets are not allowed to
# become SetVariables. The solution here is to create a ConstantSetVariable that is more like a ConstantVariable.
# As this does not exist, we cannot add sets to this invariant.
list: "List types must use ListVariable.",
dict: "Dict types must use ConstDictVariable.",
torch.Tensor: "Tensor types must use TensorVariable.",
torch.SymInt: "SymInts must use SymNodeVariable. "
"If the underlying value is static, we will create a ConstantVariable and specialize.",
torch.SymFloat: "SymFloats must use SymNodeVariable.",
}
class ConstantVariable(VariableTracker):
@staticmethod
def create(value, **kwargs) -> VariableTracker:
source = kwargs.get("source", None)
is_literal = ConstantVariable.is_literal(value)
if not is_literal:
for disallowed_type, reason in _type_to_assert_reason.items():
assert not isinstance(value, disallowed_type), reason
# Routing for set, frozenset, list, and tuple literals.
if is_literal and isinstance(value, (set, frozenset)):
items = []
for i, x in enumerate(value):
items.append(ConstantVariable.create(x))
return variables.SetVariable(items, **kwargs)
elif is_literal and isinstance(value, (list, tuple)):
items = []
for i, x in enumerate(value):
item_source = GetItemSource(source, i) if source else None
if item_source:
install_guard(item_source.make_guard(GuardBuilder.CONSTANT_MATCH))
items.append(
ConstantVariable.create(
x,
source=item_source,
)
)
return variables.BaseListVariable.cls_for(type(value))(items, **kwargs)
return ConstantVariable(value, **kwargs)
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
if not ConstantVariable.is_literal(value):
for disallowed_type, reason in _type_to_assert_reason.items():
assert not isinstance(value, disallowed_type), reason
assert not isinstance(
value, (list, tuple)
), "ConstantVariable(list) is banned - please create a ListVariable(items)"
if np is not None and isinstance(value, np.number):
self.value = value.item()
else:
self.value = value
def as_proxy(self):
return self.value
def __str__(self) -> str:
return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})"
def as_python_constant(self):
return self.value
def is_python_constant(self):
return True
@property
def items(self):
"""
Need this when adding a BaseListVariable and a ConstantVariable together.
Happens in detectron2.
"""
return self.unpack_var_sequence(tx=None)
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
return ConstantVariable.create(
self.value[arg.as_python_constant()],
)
@staticmethod
def is_literal(obj):
if type(obj) in common_constant_types:
return True
# The structures within is_literal get routed to variables.BaseListVariable
if type(obj) in (list, tuple, set, frozenset, torch.Size):
return all(ConstantVariable.is_literal(x) for x in obj)
return False
def unpack_var_sequence(self, tx):
try:
return [ConstantVariable.create(x) for x in self.as_python_constant()]
except TypeError as e:
raise NotImplementedError from e
def const_getattr(self, tx: "InstructionTranslator", name):
if isinstance(self.value, type):
raise UserError(
UserErrorType.ANTI_PATTERN,
"Can't access members of type(obj) for a generated custom object. "
"Please use __class__ instead",
case_name="type_reflection_method",
)
member = getattr(self.value, name)
if callable(member):
raise NotImplementedError
return member
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from .tensor import SymNodeVariable
if name == "format" and istype(self.value, str):
return variables.BuiltinVariable(str.format).call_function(
tx, [self, *args], kwargs
)
elif name == "join" and istype(self.value, str):
assert len(args) == 1 and len(kwargs) == 0
arg_unpacked = args[0].force_unpack_var_sequence(tx)
try:
arg_const = [x.as_python_constant() for x in arg_unpacked]
return ConstantVariable.create(self.value.join(arg_const))
except NotImplementedError:
return super().call_method(tx, name, args, kwargs)
if any(isinstance(x, SymNodeVariable) for x in args):
# Promote to SymNodeVariable for operations involving dynamic shapes.
return variables.SymNodeVariable(self.as_proxy(), self.value).call_method(
tx, name, args, kwargs
)
try:
const_args = [a.as_python_constant() for a in args]
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
except NotImplementedError:
return super().call_method(tx, name, args, kwargs)
if isinstance(self.value, str) and name in str.__dict__.keys():
method = getattr(self.value, name)
return ConstantVariable.create(method(*const_args, **const_kwargs))
elif isinstance(self.value, (float, int)):
if not (args or kwargs):
return ConstantVariable.create(getattr(self.value, name)())
if (
hasattr(operator, name)
and len(args) == 1
and args[0].is_python_constant()
):
add_target = const_args[0]
op = getattr(operator, name)
if isinstance(
add_target, (torch.SymBool, torch.SymFloat, torch.SymInt)
):
# An operation between a non-sym and a sym produces a sym
proxy = tx.output.create_proxy(
"call_function", op, (self.value, add_target), {}
)
return SymNodeVariable.create(tx, proxy, add_target)
else:
return ConstantVariable.create(op(self.value, add_target))
elif isinstance(self.value, bytes) and name == "decode":
method = getattr(self.value, name)
return ConstantVariable.create(method(*const_args, **const_kwargs))
if name == "__len__" and not (args or kwargs):
return ConstantVariable.create(len(self.value))
elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
assert not kwargs
search = args[0].as_python_constant()
result = search in self.value
return ConstantVariable.create(result)
unimplemented(f"const method call {typestr(self.value)}.{name}")
def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
result = hasattr(self.value, name)
return variables.ConstantVariable.create(result)
class EnumVariable(VariableTracker):
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
@classmethod
def create(cls, cls_type, value_vt, options):
if isinstance(value_vt, variables.ConstantVariable):
for member in list(cls_type):
if member.value == value_vt.as_python_constant():
return cls(member, **options)
unimplemented("Enum variable is constructed with non constant values")
def as_proxy(self):
return self.value
def __str__(self) -> str:
return f"EnumVariable({type(self.value)})"
def as_python_constant(self):
return self.value
def const_getattr(self, tx: "InstructionTranslator", name):
member = getattr(self.value, name)
if callable(member):
raise NotImplementedError
return member
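
ConstantVariable.is_literal above recurses through containers so that only trees of plain constants are treated as literals. A standalone sketch of that check, with common_constant_types approximated by a hand-written set and torch.Size omitted so the snippet needs no torch import:

# Approximation of common_constant_types for illustration only.
COMMON_CONSTANT_TYPES = {int, float, bool, str, bytes, type(None), complex}

def is_literal(obj):
    if type(obj) in COMMON_CONSTANT_TYPES:
        return True
    # Containers are literals only if every element is itself a literal.
    if type(obj) in (list, tuple, set, frozenset):
        return all(is_literal(x) for x in obj)
    return False

assert is_literal((1, "a", (2.0, None)))
assert not is_literal((1, object()))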

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,399 @@
# mypy: ignore-errors
import functools
import inspect
from typing import Dict, List, TYPE_CHECKING
import torch
from torch.fx.experimental._backward_state import BackwardState
from .. import compiled_autograd, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented
from ..external_utils import call_module_hooks_from_backward_state
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource
from ..utils import istype
from .base import VariableTracker
from .constant import ConstantVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
class DistributedVariable(VariableTracker):
"""
The base distributed variable that encapsulates common methods
for the distributed objects (i.e. ProcessGroup, DeviceMesh, etc.).
Concrete distributed objects could inherit this class and add object
specific logic.
i.e. it provides the check on the distributed package existence
and holds the tracked value for the corresponding distributed object.
"""
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
if not DistributedVariable.is_available():
unimplemented("torch.distributed package is not available!")
self.value = value
def python_type(self):
return type(self.value)
@staticmethod
def is_available():
# check if the distributed package is available or not
return torch.distributed.is_available()
def is_from_local(value):
if not DistributedVariable.is_available():
return False
from torch.distributed.tensor import DTensor
return inspect.isfunction(value) and value is DTensor.from_local
def is_constant_pg_functions(value):
if not DistributedVariable.is_available():
return False
from torch.distributed.distributed_c10d import (
_get_group_size_by_name,
_get_group_tag,
_rank_not_in_group,
_resolve_group_name_by_ranks_and_tag,
get_process_group_ranks,
)
constant_processgroup_functions = [
_get_group_size_by_name,
_get_group_tag,
_rank_not_in_group,
get_process_group_ranks,
_resolve_group_name_by_ranks_and_tag,
]
return inspect.isfunction(value) and value in constant_processgroup_functions
class WorldMetaClassVariable(DistributedVariable):
"""
Tracks torch.distributed.GroupMember and torch.distributed.group, which are
instances of the metaclass _WorldMeta.
"""
@classmethod
def is_group_member_type(cls, value):
if not cls.is_available():
return False
from torch.distributed.distributed_c10d import _WorldMeta
return type(value) is _WorldMeta
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name == "WORLD":
source = AttrSource(base=self.source, member="WORLD")
install_guard(source.make_guard(GuardBuilder.ID_MATCH))
return ProcessGroupVariable(self.value.WORLD)
return super().var_getattr(tx, name)
class PlacementClassVariable(DistributedVariable):
@staticmethod
def is_placement_type(value):
# we can't rely on importing/accessing torch distributed, it is not always built.
if not DistributedVariable.is_available():
return False
from torch.distributed.tensor.placement_types import Placement
return type(value) is type and issubclass(value, Placement)
def as_python_constant(self):
return self.value
def call_function(
self,
tx: "InstructionTranslator",
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if (
inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
and self.source
):
# NOTE: we don't need to track mutations to the placement class as they
# are supposed to be immutable.
new_obj = object.__new__(self.value)
var = PlacementVariable(new_obj)
if inspect.getattr_static(self.value, "__init__", None):
var.call_method(tx, "__init__", args, kwargs)
return var
return super().call_function(tx, args, kwargs)
class PlacementVariable(DistributedVariable):
@staticmethod
def is_placement(value):
# we can't rely on importing/accessing torch distributed, it is not always built.
if not DistributedVariable.is_available():
return False
from torch.distributed.tensor.placement_types import Placement
return isinstance(value, Placement)
def as_python_constant(self):
return self.value
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name == "dim":
return ConstantVariable.create(self.value.dim)
return super().var_getattr(tx, name)
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from . import ConstantVariable
# Dynamo only tracks the following methods for placement types;
# __setattr__ is included for cases like `Shard(dim)`.
# Methods in the list must satisfy:
# 1. Input arguments are constants and do not need to be guarded on;
# 2. Output is constant with respect to their inputs
constant_fold_functions = [
"__init__",
"__setattr__",
"is_shard",
"is_partial",
"is_replicate",
]
if name in constant_fold_functions:
try:
value_type = type(self.value)
assert (
inspect.getattr_static(value_type, "__getattr__", None) is None
), "no custom getattr allowed!"
method = inspect.getattr_static(value_type, name)
except AttributeError:
method = None
if method is object.__init__:
return ConstantVariable.create(None)
args = [x.as_python_constant() for x in args]
kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
if name == "__setattr__":
method(self.value, *args, **kwargs)
return self
constant_val = method(self.value, *args, **kwargs)
return ConstantVariable.create(constant_val)
return super().call_method(tx, name, args, kwargs)
class DeviceMeshVariable(DistributedVariable):
@staticmethod
def is_device_mesh(value):
# we can't rely on importing/accessing torch distributed, it is not always built.
if not DistributedVariable.is_available():
return False
from torch.distributed.device_mesh import DeviceMesh
return istype(value, DeviceMesh)
def as_python_constant(self):
return self.value
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name == "ndim":
return ConstantVariable.create(self.value.ndim)
if name == "device_type":
return ConstantVariable.create(self.value.device_type)
return super().var_getattr(tx, name)
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "size":
const_args = [x.as_python_constant() for x in args]
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
return ConstantVariable.create(self.value.size(*const_args, **const_kwargs))
if name == "get_coordinate":
return ConstantVariable.create(self.value.get_coordinate())
if name == "get_group":
return ConstantVariable.create(self.value.get_group())
if name == "_get_or_create_default_group":
return ProcessGroupVariable(self.value._get_or_create_default_group())
return super().call_method(tx, name, args, kwargs)
class ProcessGroupVariable(DistributedVariable):
"""
We don't want a ProcessGroup object to end up in our output graph.
But it's common for dynamo to intercept a PG that is then used to get info like
rank() or world_size(), or passed to utility functions in distributed_c10d
which desugar it into plain types like a rank list and tag.
For convenience and proper guarding, we construct a variable type.
TODO: make it possible to use ProcessGroupVariable as input to simple functions
like _expand_group without dynamo complaining about making a proxy for it.
It is not a tensor-like type, and we don't want a proxy- but dynamo assumes
torch library functions are dealing with tensor-like types and would have proxies
for their args.
TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors
or just graph-break whenever one of our special cases is not hit?
"""
def as_python_constant(self):
return self.value
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "rank":
return variables.ConstantVariable.create(self.value.rank())
if name == "size":
return variables.ConstantVariable.create(self.value.size())
if name == "_get_backend_name":
return variables.ConstantVariable.create(self.value._get_backend_name())
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx: "InstructionTranslator", name):
if name == "group_name":
return variables.ConstantVariable.create(self.value.group_name)
if name in ["rank", "size"]:
return variables.LambdaVariable(
lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
)
# TODO should this just raise unimplemented?
return super().var_getattr(tx, name)
@staticmethod
def is_process_group(value):
# we can't rely on importing/accessing torch distributed, it is not always built.
if not DistributedVariable.is_available():
return False
from torch._C._distributed_c10d import ProcessGroup
from torch.testing._internal.distributed.fake_pg import FakeProcessGroup
return istype(value, (ProcessGroup, FakeProcessGroup))
class BackwardHookVariable(VariableTracker):
"""
Handles torch.utils.hooks.BackwardHook for module-level backward
hooks.
"""
@staticmethod
def create(
tx,
module: VariableTracker,
user_hooks: VariableTracker,
user_pre_hooks: VariableTracker,
):
if not compiled_autograd.compiled_autograd_enabled:
unimplemented("module-level backwards hooks require compiled autograd")
def _in_graph_bw_hooks(bw_state: BackwardState):
"""
Rather than installing the user hooks in the graph (which
don't survive AotAutograd), we install hooks that will call
trace_wrapped in the backward pass that CompiledAutograd
can turn into actual hook calls.
"""
return torch.utils.hooks.BackwardHook(
None,
(
functools.partial(
trace_wrapped,
fn=call_module_hooks_from_backward_state,
bw_state=bw_state,
hooks_name=user_hooks_name,
module_name=module_name,
),
),
(
functools.partial(
trace_wrapped,
fn=call_module_hooks_from_backward_state,
bw_state=bw_state,
hooks_name=user_pre_hooks_name,
module_name=module_name,
),
),
)
module_name, bw_state_proxy = tx.output.add_backward_state_hook(module, "mod")
user_pre_hooks_name, _ = tx.output.add_backward_state_hook(user_pre_hooks)
user_hooks_name, _ = tx.output.add_backward_state_hook(user_hooks)
proxy = tx.output.create_proxy(
"call_function",
_in_graph_bw_hooks,
(bw_state_proxy,),
{},
)
proxy.node.meta["example_value"] = torch.utils.hooks.BackwardHook(None, (), ())
return BackwardHookVariable(proxy, module, user_hooks, user_pre_hooks)
def __init__(
self,
proxy: torch.fx.Proxy,
module: VariableTracker,
user_hooks: VariableTracker,
user_pre_hooks: VariableTracker,
**options,
) -> None:
super().__init__(**options)
self.proxy = proxy
self.module = module
self.user_hooks = user_hooks
self.user_pre_hooks = user_pre_hooks
def as_proxy(self):
return self.proxy
def call_method(
self,
tx,
name,
args: List[VariableTracker],
kwargs: Dict[str, VariableTracker],
) -> VariableTracker:
if name in ("setup_input_hook", "setup_output_hook"):
return self._setup_hook(tx, name, *args, **kwargs)
return super().call_method(tx, name, args, kwargs)
def _setup_hook(self, tx: "InstructionTranslator", hook_method_name, args):
from .builder import wrap_fx_proxy
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
hook_method_name,
(self.as_proxy(), args.as_proxy()),
{},
),
)
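
The _in_graph_bw_hooks helper above relies on functools.partial to bind per-module backward state into a callback before it is installed. A standalone sketch of that binding pattern, with made-up names (call_hooks_from_state, pre_hooks) rather than the actual compiled-autograd plumbing:

import functools

def call_hooks_from_state(grads, *, state, hooks_name):
    # Look up the hooks by name in the bound state and apply them in order.
    for hook in state[hooks_name]:
        grads = hook(grads)
    return grads

state = {"pre_hooks": [lambda g: [x * 2 for x in g]]}
# Specialize the generic callback up front; the installed callable only
# takes the runtime argument (the gradients).
installed = functools.partial(call_hooks_from_state, state=state, hooks_name="pre_hooks")
assert installed([1, 2, 3]) == [2, 4, 6]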

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,475 @@
# mypy: ignore-errors
import itertools
import operator
import sys
from typing import Dict, List, Optional, TYPE_CHECKING, Union
from .. import polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import (
handle_observed_exception,
ObservedUserStopIteration,
raise_observed_exception,
unimplemented,
UserError,
)
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
MAX_ITERATOR_LIMIT = 100 * 1024 # 100k
class ItertoolsVariable(VariableTracker):
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
def __repr__(self) -> str:
return f"ItertoolsVariable({self.value})"
def as_python_constant(self):
return self.value
def call_function(
self,
tx: "InstructionTranslator",
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if (
self.value is itertools.product
and not kwargs
and all(arg.has_unpack_var_sequence(tx) for arg in args)
):
seqs = [arg.unpack_var_sequence(tx) for arg in args]
items = []
for item in itertools.product(*seqs):
items.append(variables.TupleVariable(list(item)))
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
elif self.value is itertools.accumulate:
from .builtin import BuiltinVariable
if any(key not in ["initial", "func"] for key in kwargs.keys()):
unimplemented(
"Unsupported kwargs for itertools.accumulate: "
f"{','.join(set(kwargs.keys()) - {'initial', 'func'})}"
)
acc = kwargs.get("initial")
if len(args) in [1, 2] and args[0].has_unpack_var_sequence(tx):
seq = args[0].unpack_var_sequence(tx)
if "func" in kwargs and len(args) == 1:
func = kwargs["func"].call_function
elif len(args) == 2:
func = args[1].call_function
elif len(args) == 1:
# Default to operator.add
func = BuiltinVariable(operator.add).call_function
else:
unimplemented(
"itertools.accumulate can only accept one of: `func` kwarg, pos 2 arg"
)
else:
unimplemented("Unsupported arguments for itertools.accumulate")
items = []
if acc is not None:
items.append(acc)
for item in seq:
if acc is None:
acc = item
else:
try:
acc = func(tx, [acc, item], {})
except Exception as e:
unimplemented(
f"Unexpected failure in invoking function during accumulate. Failed running func {func}({item}{acc})",
from_exc=e,
)
items.append(acc)
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
elif (
self.value is itertools.combinations
and not kwargs
and len(args) == 2
and args[0].has_unpack_var_sequence(tx)
and args[1].is_python_constant()
):
iterable = args[0].unpack_var_sequence(tx)
r = args[1].as_python_constant()
items = []
for item in itertools.combinations(iterable, r):
items.append(variables.TupleVariable(list(item)))
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
elif self.value is itertools.groupby:
if any(kw != "key" for kw in kwargs.keys()):
unimplemented(
"Unsupported kwargs for itertools.groupby: "
f"{','.join(set(kwargs.keys()) - {'key'})}"
)
def retrieve_const_key(key):
if isinstance(key, variables.SymNodeVariable):
return key.evaluate_expr()
elif isinstance(key, variables.ConstantVariable):
return key.as_python_constant()
else:
unimplemented(
"Unsupported key type for itertools.groupby: " + str(type(key))
)
if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
seq = args[0].unpack_var_sequence(tx)
keyfunc = (
(
lambda x: (
retrieve_const_key(
kwargs.get("key").call_function(tx, [x], {})
)
)
)
if "key" in kwargs
else None
)
else:
unimplemented("Unsupported arguments for itertools.groupby")
result = []
try:
for k, v in itertools.groupby(seq, key=keyfunc):
result.append(
variables.TupleVariable(
[
variables.ConstantVariable.create(k)
if variables.ConstantVariable.is_literal(k)
else k,
variables.ListIteratorVariable(
list(v), mutable_local=MutableLocal()
),
],
mutable_local=MutableLocal(),
)
)
except Exception as e:
unimplemented(
"Unexpected failure when calling itertools.groupby",
from_exc=e,
)
return variables.ListIteratorVariable(result, mutable_local=MutableLocal())
elif self.value is itertools.repeat:
if len(args) < 2:
return variables.RepeatIteratorVariable(
*args, mutable_local=MutableLocal()
)
from .builder import SourcelessBuilder
return tx.inline_user_function_return(
SourcelessBuilder.create(tx, polyfills.repeat), args, kwargs
)
elif self.value is itertools.count:
return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())
elif self.value is itertools.cycle:
return variables.CycleIteratorVariable(*args, mutable_local=MutableLocal())
elif self.value is itertools.dropwhile:
return variables.UserFunctionVariable(polyfills.dropwhile).call_function(
tx, args, kwargs
)
elif self.value is itertools.zip_longest:
return variables.UserFunctionVariable(polyfills.zip_longest).call_function(
tx, args, kwargs
)
else:
return super().call_function(tx, args, kwargs)
class IteratorVariable(VariableTracker):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
def next_variable(self, tx):
unimplemented("abstract method, must implement")
# NOTE: only call this when it is safe to unpack the iterator eagerly!
# Normally, iterators are accessed lazily.
# Example of safe eager unpacking: list(map(f, seq))
# Example of unsafe eager unpacking: list(islice(map(f, seq), 5))
def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
result = []
while True:
try:
result.append(self.next_variable(tx))
except ObservedUserStopIteration:
handle_observed_exception(tx)
break
return result
# don't call force_unpack_var_sequence since it can mutate
# IteratorVariable state!
def has_force_unpack_var_sequence(self, tx) -> bool:
return True
class RepeatIteratorVariable(IteratorVariable):
def __init__(self, item: VariableTracker, **kwargs) -> None:
super().__init__(**kwargs)
self.item = item
# Repeat needs no mutation, clone self
def next_variable(self, tx):
return self.item
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.extend_output(
[
codegen.create_load_python_module(itertools),
codegen.create_load_attr("repeat"),
]
)
)
codegen(self.item)
codegen.extend_output(create_call_function(1, False))
class CountIteratorVariable(IteratorVariable):
def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
super().__init__(**kwargs)
if not isinstance(item, VariableTracker):
item = ConstantVariable.create(item)
if not isinstance(step, VariableTracker):
step = ConstantVariable.create(step)
self.item = item
self.step = step
def next_variable(self, tx):
assert self.mutable_local
old_item = self.item
tx.output.side_effects.mutation(self)
self.item = self.item.call_method(tx, "__add__", [self.step], {})
return old_item
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.extend_output(
[
codegen.create_load_python_module(itertools),
codegen.create_load_attr("count"),
]
)
)
codegen(self.item)
codegen(self.step)
codegen.extend_output(create_call_function(2, False))
class CycleIteratorVariable(IteratorVariable):
def __init__(
self,
iterator: IteratorVariable,
saved: List[VariableTracker] = None,
saved_index: int = 0,
item: Optional[VariableTracker] = None,
**kwargs,
) -> None:
if saved is None:
saved = []
super().__init__(**kwargs)
self.iterator = iterator
self.saved = saved
self.saved_index = saved_index
self.item = item
def next_variable(self, tx):
assert self.mutable_local
if self.iterator is not None:
try:
new_item = self.iterator.next_variable(tx)
if len(self.saved) > MAX_ITERATOR_LIMIT:
unimplemented(
"input iterator to itertools.cycle has too many items"
)
tx.output.side_effects.mutation(self)
self.saved.append(new_item)
self.item = new_item
if self.item is None:
return self.next_variable(tx)
return self.item
except ObservedUserStopIteration:
handle_observed_exception(tx)
self.iterator = None
return self.next_variable(tx)
elif len(self.saved) > 0:
tx.output.side_effects.mutation(self)
self.saved_index = (self.saved_index + 1) % len(self.saved)
return self.item
else:
raise_observed_exception(StopIteration, tx, self)
class ZipVariable(IteratorVariable):
"""
Represents zip(*iterables)
"""
_nonvar_fields = {
"index",
"strict",
*IteratorVariable._nonvar_fields,
}
def __init__(
self,
iterables: List[Union[List[VariableTracker], VariableTracker]],
strict: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
assert isinstance(iterables, list)
# can be list[Variable] or VariableTracker (with next_variable implemented)
self.iterables = iterables
self.index = 0
self.strict = strict
def python_type(self):
return zip
def has_unpack_var_sequence(self, tx) -> bool:
return all(
isinstance(it, list) or it.has_unpack_var_sequence(tx)
for it in self.iterables
)
def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
assert self.has_unpack_var_sequence(tx)
iterables = []
for it in self.iterables:
if isinstance(it, list):
iterables.append(it[self.index :])
else:
iterables.append(it.unpack_var_sequence(tx))
kwargs = {"strict": self.strict} if self.strict else {}
zipped = zip(*iterables, **kwargs)
return [variables.TupleVariable(list(var)) for var in zipped]
def next_variable(self, tx):
assert self.mutable_local
old_index = self.index
args = []
def get_item(it):
if isinstance(it, list):
if old_index >= len(it):
raise_observed_exception(StopIteration, tx, self)
return it[old_index]
else:
return it.next_variable(tx)
try:
for idx, it in enumerate(self.iterables):
args.append(get_item(it))
except ObservedUserStopIteration:
if self.strict:
if idx == 0:
# all other iterables should be exhausted
for it in self.iterables:
try:
get_item(it)
except ObservedUserStopIteration:
handle_observed_exception(tx)
continue
# no ObservedUserStopIteration - fall through to UserError
break
else:
# all iterables exhausted, raise original error
raise
handle_observed_exception(tx)
raise UserError(
ValueError,
"zip() has one argument of len differing from others",
) from None
raise
tx.output.side_effects.mutation(self)
self.index += 1
return variables.TupleVariable(args)
def reconstruct_items(self, codegen):
for it in self.iterables:
if isinstance(it, list):
remaining_items = it[self.index :]
codegen.foreach(remaining_items)
codegen.append_output(
create_instruction("BUILD_TUPLE", arg=len(remaining_items))
)
else:
codegen(it)
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
)
self.reconstruct_items(codegen)
codegen.append_output(
create_instruction("BUILD_TUPLE", arg=len(self.iterables))
)
if sys.version_info >= (3, 10):
codegen.extend_output(
[
codegen.create_load_const("strict"),
codegen.create_load_const(self.strict),
create_instruction("BUILD_MAP", arg=1),
create_instruction("CALL_FUNCTION_EX", arg=1),
]
)
else:
codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0))
class MapVariable(ZipVariable):
"""
Represents map(fn, *iterables)
"""
def __init__(
self,
fn: VariableTracker,
iterables: List[Union[List[VariableTracker], VariableTracker]],
**kwargs,
) -> None:
super().__init__(iterables, **kwargs)
self.fn = fn
def python_type(self):
return map
def has_unpack_var_sequence(self, tx) -> bool:
return False
def next_variable(self, tx):
args = super().next_variable(tx)
return self.fn.call_function(tx, args.items, {})
def reconstruct(self, codegen):
codegen.add_push_null(
lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True
)
codegen(self.fn)
self.reconstruct_items(codegen)
codegen.extend_output(
[
create_instruction("BUILD_TUPLE", arg=len(self.iterables) + 1),
create_instruction("CALL_FUNCTION_EX", arg=0),
]
)
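
CycleIteratorVariable above buffers items from the inner iterator into `saved` and replays them once the source is exhausted, with MAX_ITERATOR_LIMIT as a safety valve. A standalone sketch of the same save-then-replay strategy written as a plain generator (cycle_like is a made-up name, not the Dynamo implementation):

import itertools

def cycle_like(iterable, max_items=100 * 1024):
    saved = []
    for item in iterable:
        if len(saved) >= max_items:
            raise RuntimeError("input iterator to cycle has too many items")
        saved.append(item)  # remember every item while the source is live
        yield item
    while saved:  # source exhausted: replay the buffered items forever
        for item in saved:
            yield item

assert list(itertools.islice(cycle_like([1, 2, 3]), 7)) == [1, 2, 3, 1, 2, 3, 1]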

View File

@@ -0,0 +1,168 @@
# mypy: allow-untyped-defs
import collections
import functools
from typing import Optional
from .base import VariableTracker
from .tensor import SymNodeVariable
class LazyCache:
"""Container to cache the real VariableTracker"""
def __init__(self, value, source) -> None:
if not isinstance(value, LazySymNodeFormatString):
assert source
self.value = value
self.source = source
self.vt: Optional[VariableTracker] = None
def realize(self):
assert self.vt is None
from ..symbolic_convert import InstructionTranslator
from .builder import SourcelessBuilder, VariableBuilder
tx = InstructionTranslator.current_tx()
if isinstance(self.value, LazySymNodeFormatString):
self.vt = SourcelessBuilder.create(tx, self.value)
else:
self.vt = VariableBuilder(tx, self.source)(self.value)
del self.value
del self.source
class LazyVariableTracker(VariableTracker):
"""
A structure that defers the creation of the actual VariableTracker
for a given underlying value until it is accessed.
The `realize` function invokes VariableBuilder to produce the real object.
Once a LazyVariableTracker has been realized, internal bookkeeping will
prevent double realization.
This object should be used for processing containers, or objects that
reference other objects, where we may not want to create all the
VariableTrackers right away.
"""
_nonvar_fields = {"_cache", *VariableTracker._nonvar_fields}
@staticmethod
def create(value, source, **options):
return LazyVariableTracker(LazyCache(value, source), source=source, **options)
def __init__(self, _cache, **kwargs) -> None:
assert isinstance(_cache, LazyCache)
super().__init__(**kwargs)
self._cache = _cache
def realize(self) -> VariableTracker:
"""Force construction of the real VariableTracker"""
if self._cache.vt is None:
self._cache.realize()
assert self._cache.vt is not None
return self._cache.vt
def unwrap(self):
"""Return the real VariableTracker if it already exists"""
if self.is_realized():
return self._cache.vt
return self
def is_realized(self):
return self._cache.vt is not None
def clone(self, **kwargs):
assert kwargs.get("_cache", self._cache) is self._cache
if kwargs.get("source", self.source) is not self.source:
self.realize()
return VariableTracker.clone(self.unwrap(), **kwargs)
def __str__(self) -> str:
if self.is_realized():
return self.unwrap().__str__()
return VariableTracker.__str__(self.unwrap())
def __getattr__(self, item):
return getattr(self.realize(), item)
# most methods are auto-generated below, these are the ones we want to exclude
visit = VariableTracker.visit # type: ignore[assignment]
__repr__ = VariableTracker.__repr__
@classmethod
def realize_all(
cls,
value,
cache=None,
):
"""
Walk an object and realize all LazyVariableTrackers inside it.
"""
if cache is None:
cache = {}
idx = id(value)
if idx in cache:
return cache[idx][0]
value_cls = type(value)
if issubclass(value_cls, LazyVariableTracker):
result = cls.realize_all(value.realize(), cache)
elif issubclass(value_cls, VariableTracker):
# update value in-place
result = value
value_dict = value.__dict__
nonvars = value._nonvar_fields
for key in value_dict:
if key not in nonvars:
value_dict[key] = cls.realize_all(value_dict[key], cache)
elif value_cls is list:
result = [cls.realize_all(v, cache) for v in value]
elif value_cls is tuple:
result = tuple(cls.realize_all(v, cache) for v in value)
elif value_cls in (dict, collections.OrderedDict):
result = {k: cls.realize_all(v, cache) for k, v in list(value.items())}
else:
result = value
# save `value` to keep it alive and ensure id() isn't reused
cache[idx] = (result, value)
return result
class LazySymNodeFormatString:
def __init__(
self, sym_node_variable: SymNodeVariable, fmt_spec_var: VariableTracker
) -> None:
from .constant import ConstantVariable
self.sym_node_var = sym_node_variable
self.fmt_var = ConstantVariable.create(
"{:" + fmt_spec_var.as_python_constant() + "}"
)
def __str__(self) -> str:
return str.format(
self.fmt_var.as_python_constant(),
str(self.sym_node_var.evaluate_expr()),
)
def _create_realize_and_forward(name):
@functools.wraps(getattr(VariableTracker, name))
def realize_and_forward(self, *args, **kwargs):
return getattr(self.realize(), name)(*args, **kwargs)
return realize_and_forward
def _populate():
for name, value in VariableTracker.__dict__.items():
if name not in LazyVariableTracker.__dict__:
if callable(value):
setattr(LazyVariableTracker, name, _create_realize_and_forward(name))
_populate()
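
LazyVariableTracker above defers building the real tracker until realize() is called, and LazyCache drops its inputs afterwards so realization happens exactly once. A standalone sketch of that realize-on-first-access pattern (LazyBox is a made-up class, not part of this commit):

class LazyBox:
    def __init__(self, build):
        self._build = build
        self._value = None
        self._realized = False

    def realize(self):
        if not self._realized:
            self._value = self._build()
            self._build = None  # drop the builder, like LazyCache.realize deletes value/source
            self._realized = True
        return self._value

calls = []
box = LazyBox(lambda: calls.append("built") or 42)
assert box.realize() == 42 and box.realize() == 42
assert calls == ["built"]  # the expensive build ran exactly once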

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,355 @@
# mypy: ignore-errors
import weakref
from typing import Dict, List, TYPE_CHECKING
import torch
from torch.utils._pytree import tree_map_only
from ..guards import GuardBuilder, install_guard
from ..source import (
AttrSource,
ConstDictKeySource,
GetItemSource,
GlobalWeakRefSource,
GradSource,
)
from ..utils import GLOBAL_KEY_PREFIX
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from .base import VariableTracker
class ArgMappingException(Exception):
pass
class GuardInstallException(Exception):
pass
class OptimizerVariable(UserDefinedObjectVariable):
_nonvar_fields = {
"grad_to_source",
"tensor_to_source",
"static_tensor_names",
*UserDefinedObjectVariable._nonvar_fields,
}
def __init__(
self,
value,
grad_to_source=None,
static_tensor_names=None,
tensor_to_source=None,
**kwargs,
) -> None:
super().__init__(value, **kwargs)
self.grad_to_source = grad_to_source or {}
self.tensor_to_source = tensor_to_source or {}
self.static_tensor_names = static_tensor_names or set()
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
"""This is an optimization to avoid tracing the very slow initialization of the optimizer"""
if name == "_init_group":
try:
self.graph_break_if_pending_mutation(tx)
self.move_step_if_cpu()
py_args, py_kwargs = self.get_python_args(*args, **kwargs)
ret_val = self.value._init_group(*py_args, **py_kwargs)
self.map_sources_and_install_guards(tx)
self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
# stash a weak_ptr to optimizer to invalidate code
# if the optimizer object dies
mangled_name = f"__optimizer_{id(self.value)}"
tx.store_global_weakref_by_id(mangled_name, self.value)
self.create_finalizer(tx)
# This is currently safe only because the only actual `ret_val`s returned
# by the `_init_group` of existing optimizers are properties that are invariant
# to the input tensors (e.g. dtype, layout). Changing these would trigger a
# recompilation and hence never result in the wrong specialization of `ret_val`.
return ConstantVariable.create(ret_val)
except (ArgMappingException, GuardInstallException) as _:
# trace normally if we can't map args or install guards correctly
pass
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx: "InstructionTranslator", name):
# Note: this allows us to intercept the call in call_method.
# In the typical case, we return a UserMethodVariable,
# which will be inlined directly.
if name in ("_init_group", "step"):
return GetAttrVariable(self, name, source=AttrSource(self.source, name))
if name == "param_groups":
from ..decorators import mark_static_address
for group in self.value.param_groups:
for p in group["params"]:
mark_static_address(p)
self._set_capturable(tx)
return super().var_getattr(tx, name)
def graph_break_if_pending_mutation(self, tx):
# If there are pending mutations on a parameter (due to using closure)
# then we need to graph break to allow the python version of the parameter
# to update, so that running _init_group will initialize the states with
# the correct values
for g in self.value.param_groups:
for p in g["params"]:
side_effects = tx.output.side_effects
variable = side_effects.id_to_variable.get(id(p), None)
if variable and side_effects.has_pending_mutation(variable):
from ..exc import Unsupported
raise Unsupported("Pending mutation on parameter")
def _set_capturable(self, tx):
from . import LazyVariableTracker
from .builder import VariableBuilder
# We only set capturable if params are on cuda
# and the state is not initialized
def safe_to_set_capturable(group):
all_uninitialized = True
all_gpu = True
for p in group.get("params", []):
all_gpu &= p.is_cuda or p.is_xpu
all_uninitialized &= p not in self.value.state
return "capturable" in group and all_uninitialized and all_gpu
# Track which indices not to set so we don't need to realize the
# whole state in the variable tracker; we handle guarding the
# state specially.
for ind, group in enumerate(self.value.param_groups):
if safe_to_set_capturable(group):
group["capturable"] = True
param_groups_vt = LazyVariableTracker.realize_all(
VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
self.value.param_groups
)
)
for ind, param_group_vt in enumerate(param_groups_vt.items):
key = ConstDictVariable._HashableTracker(
ConstantVariable.create("capturable")
)
param_group_vt.items[key] = ConstantVariable.create(True)
def get_python_args(self, *args, **kwargs):
"""Get python values equivalent to the variable tracker args"""
def map_arg(arg):
if isinstance(arg, ConstantVariable):
return arg.as_python_constant()
elif isinstance(arg, ListVariable) and not arg.items:
return []
elif (
isinstance(arg, ConstDictVariable)
and isinstance(arg.source, GetItemSource)
and isinstance(arg.source.base, AttrSource)
and arg.source.base.member == "param_groups"
):
return self.value.param_groups[arg.source.index]
raise ArgMappingException
new_args = [map_arg(arg) for arg in args]
new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}
return new_args, new_kwargs
# If users load an old state dictionary,
# it's possible that step could be on the CPU.
# If this is the case, move it to the GPU
# corresponding to the parameter.
# In most cases this is a no-op because the state is empty.
def move_step_if_cpu(self):
for p, state in self.value.state.items():
if "step" in state and state["step"].is_cpu:
state["step"] = state["step"].to(p.device)
def map_sources_and_install_guards(self, tx):
from ..decorators import mark_static_address
from .builder import VariableBuilder
from .lazy import LazyVariableTracker
self.grad_to_source = {}
self.tensor_to_source = {}
# Tracing the _init_group is expensive. But we still have to insert the
# necessary guards for _init_group. So, we manually handle insertion of
# guards. We also want to mark all the tensors inside the state dict to
# be static address.
# Mark all the tensors in the state dict to be static address. This has
# to be done first because the variable builder relies on the static
# address annotation.
def mark_static(x):
mark_static_address(x)
tree_map_only(torch.Tensor, mark_static, self.value.state)
# Recursively realize the variable trackers for optim.state and
# optim.param_groups, which recursively install the necessary guards.
param_groups_vt = LazyVariableTracker.realize_all(
VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
self.value.param_groups
)
)
state_vt = VariableBuilder(tx, AttrSource(self.source, "state"))(
self.value.state
)
# We need to realize the top level state dict to populate
# the guard locals
state_vt.realize()
# Populate self.grad_to_source and self.tensor_to_source so that we can
# manually update_list_args
for g_ind, (group, group_vt) in enumerate(
zip(self.value.param_groups, param_groups_vt.items)
):
# we assume here that all params within a param group
# are initialized similarly
if len(group["params"]) > 0:
for param in group["params"]:
if param.grad is not None:
key_index = None
for i, k in enumerate(self.value.state.keys()):
if k is param:
key_index = i
break
                        if key_index is not None:
state_source = AttrSource(self.source, "state")
LazyVariableTracker.realize_all(
VariableBuilder(
tx,
GetItemSource(
state_source,
ConstDictKeySource(state_source, key_index),
),
)(self.value.state[param])
)
break
group_source = group_vt.source
params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params"))
for p_ind, (p, p_vt) in enumerate(
zip(group["params"], params_vt.unpack_var_sequence(tx))
):
param_source = p_vt.source
self.tensor_to_source[p] = param_source
grad_source = GradSource(
param_source,
"grad",
)
if p.grad is not None:
self.grad_to_source[p.grad] = grad_source
else:
install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH))
# We have to again iterate over the state dict to collect the
# tensor_to_source dict. This is used for the finalizer.
state_source = AttrSource(self.source, "state")
for idx, (p, value) in enumerate(self.value.state.items()):
p_state_source = GetItemSource(
state_source, ConstDictKeySource(state_source, idx)
)
for k, v in value.items():
if (
isinstance(v, torch.Tensor)
and v not in self.grad_to_source
and v not in self.tensor_to_source
):
self.tensor_to_source[v] = GetItemSource(p_state_source, k)
def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
"""Wrap state tensor in a TensorVariable"""
from ..decorators import mark_static_address
from .builder import VariableBuilder
        # If we already have a source for the tensor, use it. If we have not
        # seen the tensor before, stash it and use a global weak ref source,
        # since it must be an optimizer tensor that we have missed.
if tensor_value in self.tensor_to_source:
# mark these tensors as static for cudagraphs
mark_static_address(tensor_value)
builder = VariableBuilder(tx, self.tensor_to_source[tensor_value])
self.static_tensor_names.add(tx.output.module_key_name(builder.name))
elif tensor_value in self.grad_to_source:
builder = VariableBuilder(tx, self.grad_to_source[tensor_value])
else:
# mark these tensors as static for cudagraphs
mark_static_address(tensor_value)
global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
builder = VariableBuilder(tx, GlobalWeakRefSource(global_name))
self.static_tensor_names.add(tx.output.module_key_name(builder.name))
result = builder(tensor_value)
return result
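    # _init_group fills Python lists (params, grads, and per-parameter state
    # tensors) as a side effect; mirror those appends onto the corresponding
    # ListVariables so the traced call observes the same contents.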
def update_list_args(
self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs
):
"""Update the args and kwargs to the traced optimizer call"""
for arg, py_arg in zip(args, py_args):
if isinstance(arg, ListVariable):
assert isinstance(
py_arg, list
), "py_arg should be a list in optimizer variable"
for i, val in enumerate(py_arg):
tx.output.side_effects.mutation(arg)
if isinstance(val, torch.Tensor):
arg.items.append(self.wrap_tensor(tx, val))
else:
from .builder import SourcelessBuilder, VariableBuilder
if arg.source:
arg.items.append(
VariableBuilder(tx, GetItemSource(arg.source, i))(val)
)
else:
arg.items.append(SourcelessBuilder.create(tx, val))
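    # Register a weakref finalizer on the real optimizer: once it is garbage
    # collected, drop the static tensor names from the output GraphModule (and
    # clear params_flat) so the cudagraph-static storage can be released.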
def create_finalizer(self, tx):
names_to_delete = self.static_tensor_names
value = self.value
tc = tx.output.tracing_context
def init_finalizer(gm):
def clear_static_tensor_refs():
for name in names_to_delete:
gm._buffers.pop(name, None)
gm._parameters.pop(name, None)
if tc.params_flat:
tc.params_flat.clear()
weakref.finalize(value, clear_static_tensor_refs)
tx.output.add_graph_finalizer(init_finalizer)

View File

@ -0,0 +1,83 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import Dict
import torch
from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported
from .base import VariableTracker
from .user_defined import UserDefinedObjectVariable
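# Convert any Unsupported (graph break) raised inside the wrapped function into
# a hard UnsafeScriptObjectError: script objects cannot be safely split across
# a graph break, so we surface the failure instead of silently falling back.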
def _raise_hard_error_if_graph_break(reason):
def deco(fn):
@functools.wraps(fn)
def graph_break_as_hard_error(*args, **kwargs):
try:
return fn(*args, **kwargs)
except Unsupported as e:
raise UnsafeScriptObjectError(e.msg) from e
return graph_break_as_hard_error
return deco
class TorchScriptObjectVariable(UserDefinedObjectVariable):
_fake_script_object_cache: Dict[int, "TorchScriptObjectVariable"] = {}
@classmethod
def is_matching_cls(cls, user_cls: type):
return issubclass(user_cls, torch.ScriptObject)
@staticmethod
def create(proxy, value, **options):
return TorchScriptObjectVariable(proxy, value, **options)
def __init__(self, proxy, value, source, **kwargs) -> None:
super().__init__(value, **kwargs)
self.proxy = proxy
self.proxy.node.meta["example_value"] = value
self.source = source
def as_proxy(self):
return self.proxy
@_raise_hard_error_if_graph_break(
"Dynamo cannot safely trace script object due to graph break."
)
def var_getattr(self, tx, name: str) -> VariableTracker:
from torch._higher_order_ops.torchbind import call_torchbind
from ..source import AttrSource
from .higher_order_ops import TorchHigherOrderOperatorVariable
method = getattr(self.value, name, None)
if method is None:
unimplemented(
f"FakeScriptObject doesn't define method {name}. Did you forget to implement it in the fake class?"
)
if not callable(method):
unimplemented(
"Only method calls on TorchScript objects can be supported safely."
" Please use method calls instead of attribute access."
)
return TorchHigherOrderOperatorVariable.make(
call_torchbind,
source=AttrSource(self.source, name),
script_obj_var=self,
method_name=name,
)
# We only support method calls on script objects. Interpreting the bytecodes
# should go through var_getattr then call_function instead of call_method.
#
# However, it's possible for call_method to be used directly e.g. for __setattr__.
@_raise_hard_error_if_graph_break(
"Dynamo cannot safely trace script object due to graph break."
)
def call_method(self, tx, name, args, kwargs):
unimplemented(f"call method {name} on script object is not safe.")

View File

@ -0,0 +1,97 @@
# mypy: ignore-errors
from inspect import getattr_static
from typing import TYPE_CHECKING
from ..bytecode_transformation import create_call_function
from ..exc import Unsupported
from .base import VariableTracker
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
class SDPAParamsVariable(VariableTracker):
"""Represents the c++ params struct for scaled dot product attention.
This is a read-only container."""
@staticmethod
def create(tx: "InstructionTranslator", value, source):
from torch.backends.cuda import SDPAParams
from ..source import AttrSource
from .builder import VariableBuilder
from .torch import TorchInGraphFunctionVariable
query_var = VariableBuilder(tx, AttrSource(source, "query"))(value.query)
key_var = VariableBuilder(tx, AttrSource(source, "key"))(value.key)
value_var = VariableBuilder(tx, AttrSource(source, "value"))(value.value)
attn_mask_var = VariableBuilder(tx, AttrSource(source, "attn_mask"))(
value.attn_mask
)
dropout_var = VariableBuilder(tx, AttrSource(source, "dropout"))(value.dropout)
is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))(
value.is_causal
)
enable_gqa_var = VariableBuilder(tx, AttrSource(source, "enable_gqa"))(
value.enable_gqa
)
param_vars = [
query_var,
key_var,
value_var,
attn_mask_var,
dropout_var,
is_causal_var,
enable_gqa_var,
]
return TorchInGraphFunctionVariable(SDPAParams).call_function(
tx, param_vars, {}
)
def __init__(self, proxy, param_vars, **kwargs) -> None:
self.proxy = proxy
self.param_vars = param_vars
super().__init__(**kwargs)
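    # Reconstruction re-creates the params struct at runtime: load
    # torch._C._SDPAParams, push the stored field VariableTrackers, and call it.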
def reconstruct(self, codegen):
assert self.source is None
assert self.param_vars is not None
codegen.add_push_null(
lambda: codegen.load_import_from("torch._C", "_SDPAParams")
)
codegen.foreach(self.param_vars)
codegen.extend_output(create_call_function(len(self.param_vars), False))
def as_proxy(self):
return self.proxy
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
import torch._C
from ..source import AttrSource
from .builder import wrap_fx_proxy
from .misc import GetAttrVariable
try:
getattr_static(torch._C._SDPAParams, name)
except AttributeError:
            # Chaining the original AttributeError via `raise ... from` would
            # be too verbose here
raise Unsupported(
f"Unsupported torch._C._SDPAParams attribute {name}"
) from None
proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
if self.source is not None:
return wrap_fx_proxy(
tx=tx, proxy=proxy, source=AttrSource(self.source, name)
)
else:
return wrap_fx_proxy(tx=tx, proxy=proxy)
@staticmethod
def is_sdpa_params(value):
from torch.backends.cuda import SDPAParams
return value is SDPAParams

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,389 @@
# mypy: ignore-errors
import inspect
from typing import Dict, List, TYPE_CHECKING
import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from torch.utils._device import DeviceContext
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
from .user_defined import UserDefinedObjectVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
# [Note: __torch_function__] This feature is a prototype and has some rough edges (contact mlazos with issues):
# At a high level, a torch function tensor subclass is represented as a TensorWithTFOverrideVariable, which dispatches
# __torch_function__ on attribute accesses, method calls, and torch API calls.
# The following is not supported:
# - triggering __torch_function__ on tensor subclass non-tensor custom attributes
# - graph breaking on mutating guardable tensor properties within a __torch_function__ context, this can cause
# excessive recompiles in certain degenerate cases
# - Matching the exact eager behavior of *ignoring* __torch_function__ objects in non-tensor argument positions of Torch API calls
# The following is supported:
# - static method impls of __torch_function__ on custom objects; this will trigger on torch API calls with the object as
# any argument
# - triggering __torch_function__ on torch API calls with tensor subclass arguments
# - __torch_function__ calls on base tensor attribute access and method calls for tensor subclass instances
# - matches the dispatch ordering behavior of eager __torch_function__ with subclass/object arguments in any argument position
# See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w
# for more information on the design.
# To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses in dynamo/config.py
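# A minimal illustrative sketch of such a subclass (the names below are
# examples only, not part of this module):
#
#     class MyWrapperTensor(torch.Tensor):
#         @classmethod
#         def __torch_function__(cls, func, types, args=(), kwargs=None):
#             kwargs = kwargs or {}
#             # custom handling would go here; fall back to default behavior
#             with torch._C.DisableTorchFunctionSubclass():
#                 return func(*args, **kwargs)
#
#     torch._dynamo.config.traceable_tensor_subclasses.add(MyWrapperTensor)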
banned_attrs = [
fn.__self__.__name__
for fn in get_default_nowrap_functions()
if is_tensor_base_attr_getter(fn)
]
# Today, setting the default device is placed in the graph and guarded on
# separately, so we should not trace through it. In the future we can trace it
# once mode tracing is implemented and avoid putting it in the graph, but this
# is more of a BE project and can be evaluated later.
IGNORED_MODES = {DeviceContext}
class TorchFunctionModeStackVariable(VariableTracker):
"""Fake VT to use as a dummy object, indicating the presence of torch function mode stack mutation"""
# singleton value representing the global torch function mode stack
# singleton (it exists in C++)
stack_value_singleton = object()
    # offset is used to track whether we have inserted or removed a device
    # context, which is always placed at the bottom of the stack.
    # If a device context is inserted, the graph will run this mutation, so
    # when we reconstruct any other modes on the stack their indices should be
    # shifted right by 1 (+1).
    # Conversely, if there was a device context on the stack and the graph
    # mutates the stack to remove that context (set default device to None),
    # each of the indices of the other modes should be shifted left by 1 (-1).
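    # e.g. with a runtime stack of [DeviceContext, ModeA], ModeA lives at
    # index 1; if the traced graph inserted that DeviceContext, ModeA's
    # recorded symbolic index 0 must be shifted by +1 when reconstructing.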
offset = 0
def __init__(self, source, symbolic_stack):
self.source = source
self.symbolic_stack = symbolic_stack
@classmethod
def reset(cls):
cls.offset = 0
@classmethod
def register_mutation(cls, tx: "InstructionTranslator"):
if cls.stack_value_singleton not in tx.output.side_effects:
var = cls(
source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack
)
tx.output.side_effects.track_mutable(cls.stack_value_singleton, var)
tx.output.side_effects.mutation(var)
@classmethod
def register_device_context_insertion(cls, tx: "InstructionTranslator"):
stack = tx.symbolic_torch_function_mode_stack
if stack and cls.is_device_context(stack[0]):
return
else:
cls.offset += 1
tx.symbolic_torch_function_mode_stack.insert(
0,
TorchFunctionModeVariable(
None, source=TorchFunctionModeStackSource(-cls.offset)
),
)
@classmethod
def clear_default_device(cls, tx: "InstructionTranslator"):
stack = tx.symbolic_torch_function_mode_stack
if stack and cls.is_device_context(stack[0]):
stack.popleft()
cls.offset -= 1
@staticmethod
def is_device_context(var):
return isinstance(var.value, DeviceContext) or var.value is None
@classmethod
def get_mode_index(cls, ind):
return ind + cls.offset
class TorchFunctionModeVariable(ContextWrappingVariable):
def __init__(self, value, **kwargs):
super().__init__(value, **kwargs)
self.value = value
@staticmethod
def get_global_mangled_name(tx, val):
return get_safe_global_name(
tx, f"__torch_function_mode_{val.__class__.__name__}", val
)
def reconstruct(self, codegen):
# We don't support locally created torch function modes yet
assert self.source
self.source.reconstruct(codegen)
def _call_func(self, tx, values):
unimplemented("torch function mode context manager is not supported yet")
def _get_all_args(args, kwargs):
return _flatten_vts(pytree.arg_tree_leaves(*args, **kwargs))
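# Flatten nested list/dict VariableTrackers into a flat list of leaf VTs so
# that tensor-like arguments buried inside containers still participate in
# __torch_function__ dispatch.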
def _flatten_vts(vts):
from collections import deque
from .dicts import ConstDictVariable
from .lazy import LazyVariableTracker
from .lists import ListVariable
vts = deque(vts)
output = []
while vts:
vt = vts.pop()
LazyVariableTracker.realize_all(vt)
if isinstance(vt, ListVariable):
vts.extend(vt.items)
elif isinstance(vt, ConstDictVariable):
vts.extend(vt.items.values())
else:
output.append(vt)
return output
def _get_subclass_type(var):
assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable))
return var.python_type()
def _get_subclass_type_var(tx: "InstructionTranslator", var):
assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable))
if isinstance(var, TensorWithTFOverrideVariable):
return var.class_type_var(tx)
elif isinstance(var, UserDefinedObjectVariable):
from .builder import SourcelessBuilder, VariableBuilder
if var.source:
return VariableBuilder(tx, TypeSource(var.source))(var.python_type())
else:
return SourcelessBuilder.create(tx, var.python_type())
def _is_attr_overidden(tx: "InstructionTranslator", var, name):
import torch
overridden = False
try:
attr_val = inspect.getattr_static(var.python_type(), name)
overridden |= attr_val != getattr(torch.Tensor, name)
except AttributeError:
pass
return overridden
def call_torch_function(
tx, torch_function_type, torch_function_var, fn, types, args, kwargs
):
from .builder import SourcelessBuilder
# signature:
# def __torch_function__(cls, func, types, args=(), kwargs=None):
tf_args = (
torch_function_type,
fn,
types,
SourcelessBuilder.create(tx, tuple(args)),
SourcelessBuilder.create(tx, kwargs),
)
return tx.inline_user_function_return(torch_function_var, tf_args, {})
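# Fetch the subclass's __torch_function__ implementation as a VariableTracker,
# building it from its source (and thereby guarding on it) when one is
# available, otherwise building it sourcelessly.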
def build_torch_function_fn(tx: "InstructionTranslator", value, source):
from .builder import SourcelessBuilder, VariableBuilder
if source:
return VariableBuilder(
tx,
AttrSource(AttrSource(source, "__torch_function__"), "__func__"),
)(value.__torch_function__.__func__)
else:
return SourcelessBuilder.create(tx, value.__torch_function__.__func__)
def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
return tx.output.torch_function_enabled and any(
has_torch_function(arg) for arg in _get_all_args(args, kwargs)
)
def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
"""Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args"""
all_args = _get_all_args(args, kwargs)
overloaded_args = _get_overloaded_args(
[arg for arg in all_args if has_torch_function(arg)],
_get_subclass_type,
)
for arg in overloaded_args:
res = arg.call_torch_function(
tx,
fn,
TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]),
args,
kwargs,
)
if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
return res
unimplemented(
f"All __torch_function__ overrides for call {fn} with args {args} and kwargs {kwargs} returned NotImplemented"
)
class TensorWithTFOverrideVariable(TensorVariable):
"""
Represents a tensor subclass instance with a __torch_function__ override.
"""
def __init__(self, *args, **kwargs) -> None:
self.torch_function_fn = kwargs.pop("torch_function_fn")
super().__init__(*args, **kwargs)
@classmethod
def from_tensor_var(cls, tx, tensor_var, class_type, torch_function_fn):
import torch
kwargs = dict(tensor_var.__dict__)
assert (
kwargs.pop("class_type") is torch.Tensor
), "invalid class type in TensorWithTFOverrideVariable.from_tensor_var"
var = cls(torch_function_fn=torch_function_fn, class_type=class_type, **kwargs)
var.install_global(tx)
return var
def install_global(self, tx):
# stash the subclass type to rewrap an output tensor if needed
# this is needed because the actual type needs to be available
# each time the compiled artifact is run and outputs a wrapped tensor.
if self.global_mangled_class_name(tx) not in tx.output.global_scope:
# Safe because global_mangled_class_name figures it out
tx.output.install_global_unsafe(
self.global_mangled_class_name(tx), self.class_type
)
def python_type(self):
return self.class_type
def class_type_var(self, tx):
return TensorSubclassVariable(
self.class_type, source=GlobalSource(self.global_mangled_class_name(tx))
)
def global_mangled_class_name(self, tx):
return get_safe_global_name(
tx, f"__subclass_{self.class_type.__name__}", self.class_type
)
def var_getattr(self, tx: "InstructionTranslator", name):
# [Note: __torch_function__] We currently only support attributes that are defined on
# base tensors, custom attribute accesses will graph break.
import torch
from .builder import SourcelessBuilder
if name in banned_attrs:
unimplemented(
f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
)
if _is_attr_overidden(tx, self, name):
unimplemented(
f"Accessing overridden method/attribute {name} on a tensor"
" subclass with a __torch_function__ override is not supported"
)
if tx.output.torch_function_enabled and hasattr(torch.Tensor, name):
if self.source:
install_guard(
AttrSource(AttrSource(self.source, "__class__"), name).make_guard(
GuardBuilder.FUNCTION_MATCH
)
)
get_fn = SourcelessBuilder.create(tx, getattr(torch.Tensor, name).__get__)
return self.call_torch_function(
tx,
get_fn,
TupleVariable([self.class_type_var(tx)]),
[self],
{},
)
else:
return super().var_getattr(tx, name)
def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
return call_torch_function(
tx,
self.class_type_var(tx),
self.torch_function_fn,
fn,
types,
args,
kwargs,
)
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
# This code block implements inlining the __torch_function__ override
# of `call_method`.
if tx.output.torch_function_enabled:
import torch
from .builder import SourcelessBuilder, VariableBuilder
if _is_attr_overidden(tx, self, name):
unimplemented(
f"Calling overridden method {name} on a tensor"
" subclass with a __torch_function__ override is not supported"
)
            # [Note: __torch_function__] Currently we only support methods that
            # are defined on tensor; we will graph break in other cases.
            # Supporting arbitrary methods would need a bigger overhaul of
            # extracting methods and comparing them for equality.
            # We've established with the above check that the method is not
            # overridden, so we guard that the method is the same as the impl
            # defined on tensor and retrieve it.
if self.source:
func_var = VariableBuilder(
tx, AttrSource(AttrSource(self.source, "__class__"), name)
)(inspect.getattr_static(self.python_type(), name))
else:
func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name))
return dispatch_torch_function(tx, func_var, [self] + args, kwargs)
else:
return super().call_method(tx, name, args, kwargs)

File diff suppressed because it is too large