I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,518 @@
import builtins
import copy
import dataclasses
import inspect
import io
import os
import sys
import typing
import warnings
import zipfile
from enum import auto, Enum
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
import torch
import torch.utils._pytree as pytree
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager
from torch.utils._pytree import (
FlattenFunc,
FromDumpableContextFn,
ToDumpableContextFn,
UnflattenFunc,
)
if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# Do not import unconditionally, as they import sympy and importing sympy is very slow
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
__all__ = [
"Constraint",
"Dim",
"ExportBackwardSignature",
"ExportGraphSignature",
"ExportedProgram",
"ModuleCallEntry",
"ModuleCallSignature",
"dims",
"export",
"export_for_training",
"load",
"register_dataclass",
"save",
"unflatten",
"FlatArgsAdapter",
"UnflattenedModule",
]
from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection
from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature
from .graph_signature import ExportBackwardSignature, ExportGraphSignature
from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule
PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
def export_for_training(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
strict: bool = True,
preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram:
"""
:func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion,
which can subsequently be executed with different inputs or serialized. The
traced graph (1) produces normalized operators in the ATen operator set
(as well as any user-specified custom operators), (2) has eliminated all Python control
flow and data structures (with certain exceptions), and (3) records the set of
shape constraints needed to show that this normalization and control-flow elimination
is sound for future inputs. This API is intended for PT2 quantization training use cases
and will become the default IR of torch.export.export in the near future.
**Soundness Guarantee**
See :func:`export()` docstring for more details.
Args:
mod: We will trace the forward method of this module.
args: Example positional inputs.
kwargs: Optional example keyword inputs.
dynamic_shapes:
An optional argument where the type should either be:
1) a dict from argument names of ``f`` to their dynamic shape specifications,
2) a tuple that specifies dynamic shape specifications for each input in original order.
If you are specifying dynamism on keyword args, you will need to pass them in the order that
is defined in the original function signature.
The dynamic shape of a tensor argument can be specified as either
(1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
not required to include static dimension indices in this dict, but when they are,
they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
are denoted by None. Arguments that are dicts or tuples / lists of tensors are
recursively specified by using mappings or sequences of contained specifications.
strict: When enabled (default), the export function will trace the program through
TorchDynamo, which will ensure the soundness of the resulting graph. Otherwise, the
exported program will not validate the implicit assumptions baked into the graph and
may cause behavior divergence between the original model and the exported one. This is
useful when users need to work around bugs in the tracer, or simply want to incrementally
enable safety in their models. Note that this does not change the resulting IR spec,
and the model will be serialized in the same way regardless of what value
is passed here.
WARNING: This option is experimental; use it at your own risk.
Returns:
An :class:`ExportedProgram` containing the traced callable.
**Acceptable input/output types**
Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include:
- Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``.
- Dataclasses, but they must be registered by calling :func:`register_dataclass` first.
- (Nested) Data structures comprising ``dict``, ``list``, ``tuple``, ``namedtuple`` and
``OrderedDict`` containing all of the above types.
"""
from ._trace import _export_for_training
if not isinstance(mod, torch.nn.Module):
raise ValueError(
f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}."
)
if isinstance(mod, torch.jit.ScriptModule):
raise ValueError(
"Exporting a ScriptModule is not supported. "
"Maybe try converting your ScriptModule to an ExportedProgram "
"using `TS2EPConverter(mod, args, kwargs).convert()` instead."
)
return _export_for_training(
mod,
args,
kwargs,
dynamic_shapes,
strict=strict,
preserve_module_call_signature=preserve_module_call_signature,
)
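# A minimal usage sketch, assuming a toy module (`_ToyLinear`) made up purely
# for illustration: export_for_training captures the module's forward pass for
# the given example inputs and returns an ExportedProgram.
def _example_export_for_training() -> ExportedProgram:
    class _ToyLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)

    # Trace the module with a single example input of shape (3, 4).
    return export_for_training(_ToyLinear(), (torch.randn(3, 4),))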
def export(
mod: torch.nn.Module,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
strict: bool = True,
preserve_module_call_signature: Tuple[str, ...] = (),
) -> ExportedProgram:
"""
:func:`export` takes an arbitrary Python callable (an nn.Module, a function or
a method) along with example inputs, and produces a traced graph representing
only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion,
which can subsequently be executed with different inputs or serialized. The
traced graph (1) produces normalized operators in the functional ATen operator set
(as well as any user-specified custom operators), (2) has eliminated all Python control
flow and data structures (with certain exceptions), and (3) records the set of
shape constraints needed to show that this normalization and control-flow elimination
is sound for future inputs.
**Soundness Guarantee**
While tracing, :func:`export()` takes note of shape-related assumptions
made by the user program and the underlying PyTorch operator kernels.
The output :class:`ExportedProgram` is considered valid only when these
assumptions hold true.
Tracing makes assumptions on the shapes (not values) of input tensors.
Such assumptions must be validated at graph capture time for :func:`export`
to succeed. Specifically:
- Assumptions on static shapes of input tensors are automatically validated without additional effort.
- Assumptions on dynamic shape of input tensors require explicit specification
by using the :func:`Dim` API to construct dynamic dimensions and by associating
them with example inputs through the ``dynamic_shapes`` argument.
If any assumption cannot be validated, a fatal error will be raised. When that happens,
the error message will include suggested fixes to the specification that are needed
to validate the assumptions. For example :func:`export` might suggest the
following fix to the definition of a dynamic dimension ``dim0_x``, say appearing in the
shape associated with input ``x``, that was previously defined as ``Dim("dim0_x")``::
dim = Dim("dim0_x", max=5)
This example means the generated code requires dimension 0 of input ``x`` to be less
than or equal to 5 to be valid. You can inspect the suggested fixes to dynamic dimension
definitions and then copy them verbatim into your code without needing to change the
``dynamic_shapes`` argument to your :func:`export` call.
Args:
mod: We will trace the forward method of this module.
args: Example positional inputs.
kwargs: Optional example keyword inputs.
dynamic_shapes:
An optional argument where the type should either be:
1) a dict from argument names of ``f`` to their dynamic shape specifications,
2) a tuple that specifies dynamic shape specifications for each input in original order.
If you are specifying dynamism on keyword args, you will need to pass them in the order that
is defined in the original function signature.
The dynamic shape of a tensor argument can be specified as either
(1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
not required to include static dimension indices in this dict, but when they are,
they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
are denoted by None. Arguments that are dicts or tuples / lists of tensors are
recursively specified by using mappings or sequences of contained specifications.
strict: When enabled (default), the export function will trace the program through
TorchDynamo, which will ensure the soundness of the resulting graph. Otherwise, the
exported program will not validate the implicit assumptions baked into the graph and
may cause behavior divergence between the original model and the exported one. This is
useful when users need to work around bugs in the tracer, or simply want to incrementally
enable safety in their models. Note that this does not change the resulting IR spec,
and the model will be serialized in the same way regardless of what value
is passed here.
WARNING: This option is experimental; use it at your own risk.
Returns:
An :class:`ExportedProgram` containing the traced callable.
**Acceptable input/output types**
Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include:
- Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``.
- Dataclasses, but they must be registered by calling :func:`register_dataclass` first.
- (Nested) Data structures comprising ``dict``, ``list``, ``tuple``, ``namedtuple`` and
``OrderedDict`` containing all of the above types.
"""
from ._trace import _export
if not isinstance(mod, torch.nn.Module):
raise ValueError(
f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}."
)
if isinstance(mod, torch.jit.ScriptModule):
raise ValueError(
"Exporting a ScriptModule is not supported. "
"Maybe try converting your ScriptModule to an ExportedProgram "
"using `TS2EPConverter(mod, args, kwargs).convert()` instead."
)
return _export(
mod,
args,
kwargs,
dynamic_shapes,
strict=strict,
preserve_module_call_signature=preserve_module_call_signature,
pre_dispatch=True,
)
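# A minimal usage sketch of the ``dynamic_shapes`` argument described above,
# assuming a toy module (`_Double`) and bounds made up for illustration:
# dimension 0 of the input is marked dynamic, so the exported program accepts
# varying batch sizes within the stated range.
def _example_export_with_dynamic_batch() -> ExportedProgram:
    class _Double(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2

    batch = Dim("batch", min=2, max=64)
    return export(_Double(), (torch.randn(8, 4),), dynamic_shapes={"x": {0: batch}})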
def save(
ep: ExportedProgram,
f: Union[str, os.PathLike, io.BytesIO],
*,
extra_files: Optional[Dict[str, Any]] = None,
opset_version: Optional[Dict[str, int]] = None,
) -> None:
"""
.. warning::
Under active development, saved files may not be usable in newer versions
of PyTorch.
Saves an :class:`ExportedProgram` to a file-like object. It can then be
loaded using the Python API :func:`torch.export.load <torch.export.load>`.
Args:
ep (ExportedProgram): The exported program to save.
f (Union[str, os.PathLike, io.BytesIO]): A file-like object (has to
implement write and flush) or a string containing a file name.
extra_files (Optional[Dict[str, Any]]): Map from filename to contents
which will be stored as part of f.
opset_version (Optional[Dict[str, int]]): A map of opset names
to the version of this opset
Example::
import torch
import io
class MyModule(torch.nn.Module):
def forward(self, x):
return x + 10
ep = torch.export.export(MyModule(), (torch.randn(5),))
# Save to file
torch.export.save(ep, 'exported_program.pt2')
# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.export.save(ep, buffer)
# Save with extra files
extra_files = {'foo.txt': b'bar'.decode('utf-8')}
torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
"""
if not isinstance(ep, ExportedProgram):
raise TypeError(
f"The 'ep' parameter must be an instance of 'ExportedProgram', got '{type(ep).__name__}' instead."
)
from torch._export.serde.schema import SCHEMA_VERSION
from torch._export.serde.serialize import serialize, SerializedArtifact
artifact: SerializedArtifact = serialize(ep, opset_version)
if isinstance(f, (str, os.PathLike)):
f = os.fspath(f)
with zipfile.ZipFile(f, "w") as zipf:
# Save every field in the SerializedArtifact to a file.
assert isinstance(artifact.exported_program, bytes)
zipf.writestr("serialized_exported_program.json", artifact.exported_program)
zipf.writestr("serialized_state_dict.pt", artifact.state_dict)
zipf.writestr("serialized_constants.pt", artifact.constants)
zipf.writestr("serialized_example_inputs.pt", artifact.example_inputs)
zipf.writestr("version", ".".join(map(str, SCHEMA_VERSION)))
# Add extra files if provided
if extra_files:
for extra_file_name, content in extra_files.items():
encoded_content = content.encode("utf-8")
zipf.writestr(f"extra_files/{extra_file_name}", encoded_content)
def load(
f: Union[str, os.PathLike, io.BytesIO],
*,
extra_files: Optional[Dict[str, Any]] = None,
expected_opset_version: Optional[Dict[str, int]] = None,
) -> ExportedProgram:
"""
.. warning::
Under active development, saved files may not be usable in newer versions
of PyTorch.
Loads an :class:`ExportedProgram` previously saved with
:func:`torch.export.save <torch.export.save>`.
Args:
f (Union[str, os.PathLike, io.BytesIO]): A file-like object (has to
implement read and seek) or a string containing a file name.
extra_files (Optional[Dict[str, Any]]): The extra filenames given in
this map would be loaded and their content would be stored in the
provided map.
expected_opset_version (Optional[Dict[str, int]]): A map of opset names
to expected opset versions
Returns:
An :class:`ExportedProgram` object
Example::
import torch
import io
# Load ExportedProgram from file
ep = torch.export.load('exported_program.pt2')
# Load ExportedProgram from io.BytesIO object
with open('exported_program.pt2', 'rb') as f:
buffer = io.BytesIO(f.read())
buffer.seek(0)
ep = torch.export.load(buffer)
# Load with extra files.
extra_files = {'foo.txt': ''} # values will be replaced with data
ep = torch.export.load('exported_program.pt2', extra_files=extra_files)
print(extra_files['foo.txt'])
print(ep(torch.randn(5)))
"""
if isinstance(f, (str, os.PathLike)):
f = os.fspath(f)
extra_files = extra_files or {}
with zipfile.ZipFile(f, "r") as zipf:
# Check the version
version = zipf.read("version").decode().split(".")
from torch._export.serde.schema import SCHEMA_VERSION
assert len(version) == len(SCHEMA_VERSION)
if version[0] != str(SCHEMA_VERSION[0]):
raise RuntimeError(
f"Serialized version {version} does not match our current "
f"schema version {SCHEMA_VERSION}."
)
from torch._export.serde.serialize import deserialize, SerializedArtifact
# Load serialized_ep and serialized_state_dict from the zip file
serialized_exported_program: Optional[bytes] = None
serialized_state_dict: Optional[bytes] = None
serialized_constants: Optional[bytes] = None
serialized_example_inputs: Optional[bytes] = None
for file_info in zipf.infolist():
file_content = zipf.read(file_info.filename)
if file_info.filename == "serialized_exported_program.json":
serialized_exported_program = file_content
elif file_info.filename == "serialized_state_dict.json":
warnings.warn("This file format version is deprecated")
serialized_state_dict = file_content
elif file_info.filename == "serialized_constants.json":
warnings.warn("This file format version is deprecated")
serialized_constants = file_content
elif file_info.filename == "serialized_state_dict.pt":
serialized_state_dict = file_content
elif file_info.filename == "serialized_constants.pt":
serialized_constants = file_content
elif file_info.filename == "serialized_example_inputs.pt":
serialized_example_inputs = file_content
elif file_info.filename.startswith("extra_files"):
filename = file_info.filename.split("/", 1)[1]
extra_files[filename] = file_content.decode("utf-8")
assert serialized_exported_program is not None
assert serialized_state_dict is not None
assert serialized_constants is not None
assert serialized_example_inputs is not None
artifact: SerializedArtifact = SerializedArtifact(
serialized_exported_program,
serialized_state_dict,
serialized_constants,
serialized_example_inputs,
)
# Deserialize ExportedProgram
ep = deserialize(artifact, expected_opset_version)
return ep
def register_dataclass(
cls: Type[Any],
*,
serialized_type_name: Optional[str] = None,
) -> None:
"""
Registers a dataclass as a valid input/output type for :func:`torch.export.export`.
Args:
cls: the dataclass type to register
serialized_type_name: The serialized name for the dataclass. This is
required if you want to serialize the pytree TreeSpec containing this
dataclass.
Example::
@dataclass
class InputDataClass:
feature: torch.Tensor
bias: int
@dataclass
class OutputDataClass:
res: torch.Tensor
torch.export.register_dataclass(InputDataClass)
torch.export.register_dataclass(OutputDataClass)
def fn(o: InputDataClass) -> torch.Tensor:
res = o.feature + o.bias
return OutputDataClass(res=res)
ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), ))
print(ep)
"""
from torch._export.utils import register_dataclass_as_pytree_node
return register_dataclass_as_pytree_node(
cls, serialized_type_name=serialized_type_name
)

View File

@ -0,0 +1,52 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch._higher_order_ops.auto_functionalize import (
auto_functionalized,
auto_functionalized_v2,
)
from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized
from torch.export import ExportedProgram
def remove_self_clone(graph: torch.fx.Graph):
for node in graph.nodes:
if node.target == torch.ops.aten.copy_.default and node.args[0] == node.args[1]:
node.replace_all_uses_with(node.args[0])
graph.erase_node(node)
def unsafe_remove_auto_functionalized_pass(
ep: ExportedProgram,
) -> ExportedProgram:
"""
This pass removes instances of the higher order op 'auto_functionalized',
and modifies the calling ExportedProgram in place to call the original mutating op.
This pass doesn't perform safety checks to make sure that this in-place mutation is safe.
"""
with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
for module in ep.graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in ep.graph.nodes:
if (
node.op == "call_function" and node.target is auto_functionalized
) or (
node.op == "call_function" and node.target is auto_functionalized_v2
):
func = node.args[0]
assert isinstance(func, torch._ops.OpOverload)
# re-inplace everything
node.meta["only_clone_these_tensors"] = []
decompose_auto_functionalized(ep.graph)
remove_self_clone(ep.graph)
ep.graph.eliminate_dead_code()
return ep
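# A minimal usage sketch with assumed names: ``ModuleCallingMutatingOp`` and
# ``example_args`` stand in for a module whose exported graph contains
# ``auto_functionalized`` nodes.
#
#     ep = torch.export.export(ModuleCallingMutatingOp(), example_args)
#     ep = unsafe_remove_auto_functionalized_pass(ep)
#     # The graph now calls the original mutating op directly in place of
#     # auto_functionalized, with no check that the rewrite is safe.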

View File

@ -0,0 +1,161 @@
# mypy: allow-untyped-defs
import operator
from typing import List
import torch
from torch._higher_order_ops.effects import _get_schema, with_effects
from .exported_program import ExportedProgram
from .graph_signature import (
CustomObjArgument,
InputKind,
InputSpec,
OutputKind,
OutputSpec,
TokenArgument,
)
def _remove_effect_tokens_from_graph_helper(
ep, num_tokens, input_token_names, output_token_names
):
inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs
output_node = None
with_effect_nodes: List[torch.fx.Node] = []
# The output node needs to check its args against output_token_names (collected from output_spec).
# Therefore, we only need to find the top-level output node
output_node = next(reversed(ep.graph_module.graph.find_nodes(op="output")))
for module in ep.graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if not (node.op == "call_function" and node.target is with_effects):
continue
with_effect_nodes.append(node)
# Remove tokens from outputs
assert output_node is not None
output_args = output_node.args[0]
assert len(output_args) >= num_tokens
out_token_nodes = output_args[:num_tokens]
output_node.args = (tuple(output_args[num_tokens:]),)
for out_token in out_token_nodes:
assert out_token.name in output_token_names
out_token.users.clear()
ep.graph.erase_node(out_token)
# Replace with_effects(token, func, args) with just func(args)
for node in reversed(with_effect_nodes):
func = node.args[1]
assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator))
if func == torch.ops.higher_order.call_torchbind:
custom_obj_meta = node.args[2].meta["val"]
assert isinstance(custom_obj_meta, CustomObjArgument)
if custom_obj_meta.fake_val:
custom_obj = custom_obj_meta.fake_val
elif node.args[2].name in inputs_to_lifted_custom_objs:
custom_obj = ep.constants[
inputs_to_lifted_custom_objs[node.args[2].name]
]
else:
raise RuntimeError(f"Unable to find custom obj for node {node}")
schema = _get_schema(func, (custom_obj,) + node.args[3:])
else:
schema = _get_schema(func, node.args[2:])
with ep.graph.inserting_before(node):
new_node = ep.graph.call_function(func, node.args[2:], node.kwargs)
for k, v in node.meta.items():
new_node.meta[k] = v
node.replace_all_uses_with(new_node)
# Update user getitem nodes
for user in list(new_node.users.keys()):
assert user.target == operator.getitem
# getitem(with_effects, 0) == token
if user.args[1] == 0:
ep.graph.erase_node(user)
if len(schema.returns) == 1:
# If the function has 1 return then it will just directly return the
# result -- we don't need a getitem. So we can replace all the
# getitem(with_effects, 1) with just the node itself.
for user in list(new_node.users.keys()):
assert user.args[1] == 1
user.replace_all_uses_with(new_node)
new_node.meta["val"] = node.meta["val"][1]
elif len(schema.returns) > 1:
# If the function has more than 1 return, then since we got rid of
# the 1st return value (the token), we need to shift all the other
# getitem indices down by 1
for user in list(new_node.users.keys()):
assert user.args[1] >= 1
user.args = (user.args[0], user.args[1] - 1)
new_node.meta["val"] = node.meta["val"][1:]
else:
assert len(schema.returns) == 0
assert len(new_node.users) == 0
new_node.meta["val"] = None
ep.graph.erase_node(node)
# Remove tokens from inputs
placeholders = [node for node in ep.graph.nodes if node.op == "placeholder"]
assert len(placeholders) >= num_tokens
inp_token_nodes = placeholders[:num_tokens]
for inp_token in inp_token_nodes:
assert inp_token.name in input_token_names
ep.graph.erase_node(inp_token)
ep.graph.eliminate_dead_code()
def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
"""
Removes tokens from the exported program, including:
- Removes the input and output tokens
- Replaces with_effects(token, func, args) with just func(args)
This function modifies the given ExportedProgram in place.
"""
num_tokens: int = 0
input_token_names: List[str] = []
new_input_specs: List[InputSpec] = []
for inp in ep.graph_signature.input_specs:
if inp.kind == InputKind.TOKEN:
num_tokens += 1
assert isinstance(inp.arg, TokenArgument)
input_token_names.append(inp.arg.name)
else:
new_input_specs.append(inp)
num_out_tokens: int = 0
new_output_specs: List[OutputSpec] = []
output_token_names: List[OutputSpec] = []
for out in ep.graph_signature.output_specs:
if out.kind == OutputKind.TOKEN:
num_out_tokens += 1
output_token_names.append(out.arg.name)
else:
new_output_specs.append(out)
# Update graph signature
ep.graph_signature.input_specs = new_input_specs
ep.graph_signature.output_specs = new_output_specs
assert num_tokens == num_out_tokens
with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
_remove_effect_tokens_from_graph_helper(
ep, num_tokens, input_token_names, output_token_names
)
return ep

View File

@ -0,0 +1,44 @@
# mypy: allow-untyped-defs
import torch
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
from torch.overrides import TorchFunctionMode
class AutogradStateOpsFailSafeguard(TorchFunctionMode):
"""
Detect grad state ops while exporting the graph and fail the process by
raising an error, to avoid unexpected behavior. Those grad mode ops could be:
`torch.no_grad`
`torch.enable_grad`
`torch.set_grad_enabled`
Export with predispatch mode is exempted.
"""
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
unsupported_grad_mode_ops = [
torch._C._set_grad_enabled,
]
# The check is only enabled while tracing, which is confirmed by an active PROXY
# torch dispatch mode. This allows autograd ops to run normally outside of tracing.
current_state = torch._C.is_grad_enabled()
if func in unsupported_grad_mode_ops:
assert len(args) == 1
changed_state = args[0]
mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
# Intended to check that this is not the pre_dispatch mode; autograd ops are
# allowed in pre_dispatch mode, e.g. `torch.no_grad`
if (
mode
and isinstance(mode, ProxyTorchDispatchMode)
and not mode.pre_dispatch
and changed_state != current_state
):
raise RuntimeError(
f"Encountered autograd state manager op {func} trying to change global autograd state "
"while exporting. This is unsafe because we don't capture this op in torch.export "
"today, hence we can't reflect the user intention soundly. You can fix this by "
"adding a torch.no_grad() context around the export call."
)
return func(*args, **kwargs)
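# A minimal usage sketch: because this is a TorchFunctionMode, it is typically
# enabled as a context manager around the tracing step (the body shown here is
# only illustrative).
#
#     with AutogradStateOpsFailSafeguard():
#         ...  # run export tracing here; grad-state ops such as
#              # torch._C._set_grad_enabled raise only while a non-pre_dispatch
#              # proxy tracing mode is active and the call would change state.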

File diff suppressed because it is too large

View File

@ -0,0 +1,64 @@
from typing import Any, Callable, Dict, Optional
from torch.utils._pytree import Context, TreeSpec
def reorder_kwargs(user_kwargs: Dict[str, Any], spec: TreeSpec) -> Dict[str, Any]:
"""Reorder user-provided kwargs to match the order in `spec`. `spec` is
expected to be the in_spec of an exported program, i.e. the spec that
results from flattening `(args, kwargs)`.
We need this to provide consistent input ordering, so that users can
pass in foo(a=a, b=b) OR foo(b=b, a=a) and receive the same result.
"""
# Make sure that the spec is actually shaped like (args, kwargs)
assert spec.type is tuple
assert spec.num_children == 2
kwargs_spec = spec.children_specs[1]
assert kwargs_spec.type is dict
if set(user_kwargs) != set(kwargs_spec.context):
raise ValueError(
f"kwarg key mismatch: "
f"Got {list(user_kwargs)} but expected {kwargs_spec.context}"
)
reordered_kwargs = {}
for kw in kwargs_spec.context:
reordered_kwargs[kw] = user_kwargs[kw]
return reordered_kwargs
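# A minimal usage sketch, assuming argument names "a" and "b" made up for
# illustration: the spec is built by flattening an (args, kwargs) pair, and
# kwargs supplied in a different order are reordered to match it.
def _example_reorder_kwargs() -> Dict[str, Any]:
    from torch.utils._pytree import tree_flatten

    _, spec = tree_flatten(((1,), {"a": 2, "b": 3}))
    # Returns {"a": 20, "b": 30}, following the key order recorded in the spec.
    return reorder_kwargs({"b": 30, "a": 20}, spec)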
def is_equivalent(
spec1: TreeSpec,
spec2: TreeSpec,
equivalence_fn: Callable[[Optional[type], Context, Optional[type], Context], bool],
) -> bool:
"""Customizable equivalence check for two TreeSpecs.
Arguments:
spec1: The first TreeSpec to compare
spec2: The second TreeSpec to compare
equivalence_fn: A function to determine the equivalence of two
TreeSpecs by examining their types and contexts. It will be called like:
equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context)
This function will be applied recursively to all children.
Returns:
True if the two TreeSpecs are equivalent, False otherwise.
"""
if not equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context):
return False
# Recurse on children
if len(spec1.children_specs) != len(spec2.children_specs):
return False
for child_spec1, child_spec2 in zip(spec1.children_specs, spec2.children_specs):
if not is_equivalent(child_spec1, child_spec2, equivalence_fn):
return False
return True
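# A minimal usage sketch: comparing two dict specs while ignoring key order.
# The comparison function below is an assumption for illustration, not the
# equivalence check used by export itself.
def _example_is_equivalent() -> bool:
    from torch.utils._pytree import tree_flatten

    _, spec1 = tree_flatten({"a": 1, "b": 2})
    _, spec2 = tree_flatten({"b": 3, "a": 4})

    def same_type_and_keys(type1, context1, type2, context2) -> bool:
        # Leaf specs have type/context of None; treat a None context as empty.
        return type1 is type2 and sorted(context1 or []) == sorted(context2 or [])

    return is_equivalent(spec1, spec2, same_type_and_keys)  # True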

View File

@ -0,0 +1,361 @@
# mypy: allow-untyped-defs
import copy
import warnings
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.utils._pytree as pytree
from torch._export.utils import _check_input_constraints_for_graph
from torch.export.unflatten import _assign_attr, _AttrKind, _recursive_getattr
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from ._remove_effect_tokens_pass import _remove_effect_tokens
from .exported_program import (
ExportedProgram,
ExportGraphSignature,
InputKind,
OutputKind,
)
@torch._dynamo.disable
def _check_input_constraints_pre_hook(self, *args, **kwargs):
flat_args_with_path, received_spec = pytree.tree_flatten_with_path(args)
if received_spec != self._in_spec:
raise ValueError( # noqa: B904
"Trying to flatten user inputs with exported input tree spec: \n"
f"{self._in_spec}\n"
"but actually got inputs with tree spec of: \n"
f"{received_spec}"
)
return _check_input_constraints_for_graph(
[node for node in self.graph.nodes if node.op == "placeholder"],
flat_args_with_path,
self.range_constraints,
)
def _unlift_inputs_as_getattr(
gm: torch.fx.GraphModule,
lifted_inputs: List[Optional[str]],
) -> Tuple[Dict[str, torch.fx.Node], Dict[str, torch.fx.Node]]:
"""
Unlift inputs referring to params/buffers/constants as getattr nodes in the
graph
"""
unlifted_name_to_node = {}
input_name_to_node = {}
placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
assert len(lifted_inputs) == len(placeholder_nodes)
for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs):
if lifted_node is None:
input_name_to_node[input_node.name] = input_node
else:
with gm.graph.inserting_after(input_node):
getattr_node = gm.graph.get_attr(lifted_node)
input_node.replace_all_uses_with(getattr_node)
metadata = input_node.meta
gm.graph.erase_node(input_node)
getattr_node.meta = metadata
unlifted_name_to_node[lifted_node] = getattr_node
return unlifted_name_to_node, input_name_to_node
def _insert_copy_for_mutations(
gm: torch.fx.GraphModule,
mutated_outputs: List[Optional[str]],
unlifted_name_to_node: Dict[str, torch.fx.Node],
input_name_to_node: Dict[str, torch.fx.Node],
) -> None:
"""
Find all the buffers and inputs that were mutated and insert copy_
operators to reflect the mutations.
"""
output_node = None
for node in gm.graph.nodes:
if node.op == "output":
output_node = node
break
assert output_node is not None
outputs = pytree.tree_flatten(output_node.args)[0]
assert len(outputs) == len(mutated_outputs)
user_output_nodes = []
return_nodes_to_copy = {}
for return_node, mutated_node_name in zip(outputs, mutated_outputs):
if mutated_node_name is None:
user_output_nodes.append(return_node)
continue
if mutated_node_name in unlifted_name_to_node:
mutated_node = unlifted_name_to_node[mutated_node_name]
elif mutated_node_name in input_name_to_node:
mutated_node = input_name_to_node[mutated_node_name]
else:
raise RuntimeError(
f"Could not find {mutated_node_name} in either buffer or input nodes"
)
with gm.graph.inserting_before(output_node):
copy_node = gm.graph.call_function(
torch.ops.aten.copy_.default, (mutated_node, return_node)
)
return_nodes_to_copy[return_node] = copy_node
output_args = [
return_nodes_to_copy[node] if node in return_nodes_to_copy else node
for node in user_output_nodes
]
with gm.graph.inserting_before(output_node):
# Only return user outputs
new_output = gm.graph.output(tuple(output_args))
new_output.meta.update(output_node.meta)
output_node.replace_all_uses_with(new_output)
gm.graph.erase_node(output_node)
def _get_codegen(
in_spec: pytree.TreeSpec,
out_spec: Optional[pytree.TreeSpec],
forward_arg_names: Optional[List[str]] = None,
) -> _PyTreeCodeGen:
"""
Create the codegen for the graph module based on the in/out specs
"""
if forward_arg_names:
names = forward_arg_names
else:
if (
in_spec.type == tuple
and in_spec.num_children == 2
and in_spec.children_specs[0].type == tuple
and in_spec.children_specs[1].type == dict
):
# if in_spec contains the args (tuple) and kwargs (dict)
names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
# add kwarg names
names.extend(in_spec.children_specs[1].context)
else:
names = [f"arg_{i}" for i in range(in_spec.num_children)]
return _PyTreeCodeGen(
_PyTreeInfo(
names,
in_spec,
out_spec,
)
)
def _unlift(
gm: torch.fx.GraphModule,
lifted_inputs: List[Optional[str]],
mutated_outputs: List[Optional[str]],
in_spec: pytree.TreeSpec,
out_spec: Optional[pytree.TreeSpec],
state_dict: Dict[str, Any],
constants: Dict[str, Any],
forward_arg_names: Optional[List[str]] = None,
):
"""
Args:
lifted_inputs: A list matching the graph module's input nodes. For
an input node that is referring to a lifted parameter/buffer, this
list will contain the fqn of the corresponding attribute. Otherwise, this
list will contain None. This is used to unlift the lifted parameters as
get_attr nodes.
mutated_outputs: A list matching the graph module's output nodes. For
an output node that is referring to a mutated buffer or user input, this
list will contain the name of the corresponding buffer or user input
that needs to be mutated. Otherwise, this list will contain None. This
is used to re-insert an inplace copy_ operator to copy the mutated
values back to the original node.
"""
unlifted_name_to_node, input_name_to_node = _unlift_inputs_as_getattr(
gm, lifted_inputs
)
_insert_copy_for_mutations(
gm, mutated_outputs, unlifted_name_to_node, input_name_to_node
)
gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names)
gm.graph.lint()
gm.recompile()
return gm
def _register_attrs_to_new_gm(
new_gm: torch.fx.GraphModule,
graph_signature: ExportGraphSignature,
state_dict: Dict[str, Any],
constants: Dict[str, Any],
) -> None:
non_persistent_buffers = set(graph_signature.non_persistent_buffers)
for name in graph_signature.buffers:
if name in non_persistent_buffers:
persistent = False
value = constants[name]
else:
persistent = True
value = state_dict[name]
_assign_attr(
value, new_gm, name, attr_kind=_AttrKind.BUFFER, persistent=persistent
)
for name in graph_signature.parameters:
value = state_dict[name]
_assign_attr(
value,
new_gm,
name,
attr_kind=_AttrKind.PARAMETER,
)
for name in chain(
graph_signature.lifted_custom_objs, graph_signature.lifted_tensor_constants
):
value = constants[name]
_assign_attr(
value,
new_gm,
name,
attr_kind=_AttrKind.CONSTANT,
)
class _StatefulGraphModuleFactory(type):
"""
Metaclass that ensures a private constructor for _StatefulGraphModule
"""
def __call__(cls, *args, **kwargs):
raise TypeError(
f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
)
def _create(cls, root, graph, range_constraints=None):
return super().__call__(
root,
graph,
range_constraints=range_constraints,
)
class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory):
def __init__(self, root, graph, range_constraints=None):
super().__init__(root, graph)
# Need to fix up non-persistent buffers.
self.range_constraints = range_constraints or []
def _create_stateful_graph_module(
plain_graph_module: torch.fx.GraphModule,
range_constraints,
# TODO(suo) this should not be optional, but is since we still have
# capture_pre_autograd_graph grr
graph_signature: Optional[ExportGraphSignature] = None,
):
stateful_gm = _StatefulGraphModule._create(
plain_graph_module,
plain_graph_module.graph,
range_constraints=range_constraints,
)
stateful_gm.register_forward_pre_hook(
_check_input_constraints_pre_hook, with_kwargs=True
)
if graph_signature is None:
return stateful_gm
# Fix up lifted tensor constants.
# fx.GraphModule() constructor silently turns a constant attribute of plain_graph_module
# into a buffer in stateful_gm and creates an inconsistency with graph_signature.
# We fix this by de-registering these buffers in lifted_tensor_constants
# and call _assign_attr(attr_kind=CONSTANT) to register them as constants.
for constant_fqn in graph_signature.lifted_tensor_constants:
# Sometimes the constant can require gradient; this is probably a bug in user code,
# e.g. `self.const = torch.randn(2, 2, requires_grad=True)`.
# We call detach on the constant_val since they're tensor constants and we don't need to
# compute their gradients anyway.
# Users should properly register it as a parameter if they want it to require gradient.
buffer = stateful_gm.get_buffer(constant_fqn)
if buffer.requires_grad:
warnings.warn(
f"A model attribute `{constant_fqn}` requires gradient. "
f"but it's not properly registered as a parameter. "
f"torch.export will detach it and treat it as a constant tensor "
f"but please register it as parameter instead."
)
buffer = buffer.detach()
*prefix, field = constant_fqn.rsplit(".")
submod = _recursive_getattr(stateful_gm, prefix)
delattr(submod, field)
_assign_attr(buffer, stateful_gm, constant_fqn, attr_kind=_AttrKind.CONSTANT)
# Fix up non-persistent buffers. torch.fx does not distinguish between
# persistent and non-persistent buffers, so we must restore that distinction
# here.
for buffer in graph_signature.non_persistent_buffers:
_assign_attr(
plain_graph_module.get_buffer(buffer),
stateful_gm,
buffer,
attr_kind=_AttrKind.BUFFER,
persistent=False,
)
return stateful_gm
def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module:
ep = _remove_effect_tokens(ep)
new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
_register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
forward_arg_names = ep.graph_module.meta.get("forward_arg_names")
lifted_inputs: List[Optional[str]] = [
(
in_spec.target
if in_spec.kind
in (
InputKind.BUFFER,
InputKind.CONSTANT_TENSOR,
InputKind.PARAMETER,
InputKind.CUSTOM_OBJ,
)
else None
)
for in_spec in ep.graph_signature.input_specs
]
mutated_outputs: List[Optional[str]] = [
(
out_spec.target
if out_spec.kind
in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION)
else None
)
for out_spec in ep.graph_signature.output_specs
]
new_gm = _unlift(
new_gm,
lifted_inputs,
mutated_outputs,
ep.call_spec.in_spec,
ep.call_spec.out_spec,
ep.state_dict,
ep.constants,
forward_arg_names=forward_arg_names,
)
unlift_gm = _create_stateful_graph_module(
new_gm, ep.range_constraints, ep.graph_signature
)
unlift_gm.meta.update(ep.graph_module.meta)
return unlift_gm

View File

@ -0,0 +1,16 @@
from dataclasses import dataclass
__all__ = ["ScriptObjectMeta"]
@dataclass
class ScriptObjectMeta:
"""
Metadata which is stored on nodes representing ScriptObjects.
"""
# Key into constants table to retrieve the real ScriptObject.
constant_name: str
class_fqn: str

File diff suppressed because it is too large

View File

@ -0,0 +1,67 @@
import copy
import typing
import torch
from torch.export.exported_program import _decompose_exported_program
def _copy_graph_module_and_signature(
ep: torch.fx.GraphModule,
) -> typing.Tuple[
torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature
]:
# copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(),
# and this can break placeholder names in some particular cases.
# For example, node copying will avoid Python keywords like 'input', renaming it to 'input_1'.
# So we manually overwrite placeholder names by reading the old graph.
gm = copy.deepcopy(ep.graph_module)
new_graph_signature = copy.deepcopy(ep.graph_signature)
# iterate over old/new graph modules
for old_gm, new_gm in zip(ep.graph_module.modules(), gm.modules()):
old_phs = [node for node in old_gm.graph.nodes if node.op == "placeholder"]
new_phs = [node for node in new_gm.graph.nodes if node.op == "placeholder"]
# iterate over placeholders
assert len(old_phs) == len(new_phs)
for old_node, new_node in zip(old_phs, new_phs):
new_node.name = old_node.name
return gm, new_graph_signature
def _remove_detach_pass(
gm: torch.fx.GraphModule, sig: torch.export.graph_signature.ExportGraphSignature
) -> None:
with gm._set_replace_hook(sig.get_replace_hook()):
for node in list(reversed(gm.graph.nodes)):
if node.op != "call_function":
continue
if (
node.target == torch.ops.aten.detach.default
and len(node.users) == 1
and next(iter(node.users)).target == torch.ops.aten.detach.default
):
next(iter(node.users)).replace_all_uses_with(node)
gm.graph.eliminate_dead_code()
gm.recompile()
def _export_forward_backward(
ep: torch.export.ExportedProgram, joint_loss_index: int = 0
) -> torch.export.ExportedProgram:
"""
WARNING: This API is highly unstable and will be subject to change in the future.
"""
from torch._decomp import core_aten_decompositions
ep = _decompose_exported_program(
ep,
decomp_table=core_aten_decompositions(),
_preserve_ops=(), # type: ignore[arg-type]
joint_loss_index=joint_loss_index,
)
gm, new_graph_signature = _copy_graph_module_and_signature(ep)
_remove_detach_pass(gm, new_graph_signature)
return ep._update(gm, new_graph_signature)

File diff suppressed because it is too large

View File

@ -0,0 +1,593 @@
# mypy: allow-untyped-defs
import dataclasses
from enum import auto, Enum
from typing import Collection, Dict, List, Mapping, Optional, Set, TYPE_CHECKING, Union
from torch._library.fake_class_registry import FakeScriptObject
if TYPE_CHECKING:
import torch
from torch._functorch._aot_autograd.schemas import GraphSignature
__all__ = [
"ConstantArgument",
"CustomObjArgument",
"ExportBackwardSignature",
"ExportGraphSignature",
"InputKind",
"InputSpec",
"OutputKind",
"OutputSpec",
"SymIntArgument",
"TensorArgument",
]
@dataclasses.dataclass
class TensorArgument:
name: str
@dataclasses.dataclass
class TokenArgument:
name: str
@dataclasses.dataclass
class SymIntArgument:
name: str
@dataclasses.dataclass
class CustomObjArgument:
name: str
class_fqn: str
fake_val: Optional[FakeScriptObject] = None
@dataclasses.dataclass
class ConstantArgument:
name: str
value: Union[int, float, bool, str, None]
ArgumentSpec = Union[
TensorArgument,
SymIntArgument,
ConstantArgument,
CustomObjArgument,
TokenArgument,
]
class InputKind(Enum):
USER_INPUT = auto()
PARAMETER = auto()
BUFFER = auto()
CONSTANT_TENSOR = auto()
CUSTOM_OBJ = auto()
TOKEN = auto()
@dataclasses.dataclass
class InputSpec:
kind: InputKind
arg: ArgumentSpec
target: Optional[str]
persistent: Optional[bool] = None
def __post_init__(self):
if self.kind == InputKind.BUFFER:
assert (
self.persistent is not None
), "Failed to specify persistent flag on BUFFER."
assert isinstance(
self.arg,
(
TensorArgument,
SymIntArgument,
ConstantArgument,
CustomObjArgument,
TokenArgument,
),
), f"got {type(self.arg)}"
class OutputKind(Enum):
USER_OUTPUT = auto()
LOSS_OUTPUT = auto()
BUFFER_MUTATION = auto()
GRADIENT_TO_PARAMETER = auto()
GRADIENT_TO_USER_INPUT = auto()
USER_INPUT_MUTATION = auto()
TOKEN = auto()
@dataclasses.dataclass
class OutputSpec:
kind: OutputKind
arg: ArgumentSpec
target: Optional[str]
def __post_init__(self):
assert isinstance(
self.arg,
(
TensorArgument,
SymIntArgument,
ConstantArgument,
TokenArgument,
CustomObjArgument,
),
), self.arg
@dataclasses.dataclass
class ExportBackwardSignature:
gradients_to_parameters: Dict[str, str]
gradients_to_user_inputs: Dict[str, str]
loss_output: str
@dataclasses.dataclass
class ExportGraphSignature:
"""
:class:`ExportGraphSignature` models the input/output signature of Export Graph,
which is an fx.Graph with stronger invariant guarantees.
Export Graph is functional and does not access "states" like parameters
or buffers within the graph via ``getattr`` nodes. Instead, :func:`export`
guarantees that parameters, buffers, and constant tensors are lifted out of
the graph as inputs. Similarly, any mutations to buffers are not included
in the graph either; instead the updated values of mutated buffers are
modeled as additional outputs of Export Graph.
The ordering of all inputs and outputs is::
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]
e.g. If the following module is exported::
class CustomModule(nn.Module):
def __init__(self) -> None:
super(CustomModule, self).__init__()
# Define a parameter
self.my_parameter = nn.Parameter(torch.tensor(2.0))
# Define two buffers
self.register_buffer('my_buffer1', torch.tensor(3.0))
self.register_buffer('my_buffer2', torch.tensor(4.0))
def forward(self, x1, x2):
# Use the parameter, buffers, and both inputs in the forward method
output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2
# Mutate one of the buffers (e.g., increment it by 1)
self.my_buffer2.add_(1.0) # In-place addition
return output
Resulting Graph would be::
graph():
%arg0_1 := placeholder[target=arg0_1]
%arg1_1 := placeholder[target=arg1_1]
%arg2_1 := placeholder[target=arg2_1]
%arg3_1 := placeholder[target=arg3_1]
%arg4_1 := placeholder[target=arg4_1]
%add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
%mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
%mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
%add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
%add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
return (add_tensor_2, add_tensor_1)
Resulting ExportGraphSignature would be::
ExportGraphSignature(
input_specs=[
InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
],
output_specs=[
OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
]
)
"""
input_specs: List[InputSpec]
output_specs: List[OutputSpec]
# A list of parameters uniquely identified by mangled fully qualified name
@property
def parameters(self) -> Collection[str]:
return tuple(
s.target
for s in self.input_specs
if s.kind == InputKind.PARAMETER
if isinstance(s.target, str)
)
# A list of buffers uniquely identified by mangled fully qualified name
@property
def buffers(self) -> Collection[str]:
return tuple(
s.target
for s in self.input_specs
if s.kind == InputKind.BUFFER
if isinstance(s.target, str)
)
@property
def non_persistent_buffers(self) -> Collection[str]:
return tuple(
s.target
for s in self.input_specs
if s.kind == InputKind.BUFFER
if s.persistent is False
if isinstance(s.target, str)
)
# A list of lifted constant tensors
@property
def lifted_tensor_constants(self) -> Collection[str]:
return tuple(
s.target
for s in self.input_specs
if s.kind == InputKind.CONSTANT_TENSOR
if isinstance(s.target, str)
)
@property
def lifted_custom_objs(self) -> Collection[str]:
return tuple(
s.target
for s in self.input_specs
if s.kind == InputKind.CUSTOM_OBJ
if isinstance(s.target, str)
)
# Graph node names of pytree-flattened inputs of original program
@property
def user_inputs(self) -> Collection[Union[int, float, bool, None, str]]:
user_inputs: List[Union[int, float, bool, None, str]] = []
for s in self.input_specs:
if s.kind != InputKind.USER_INPUT:
continue
if isinstance(s.arg, (TensorArgument, SymIntArgument, CustomObjArgument)):
user_inputs.append(s.arg.name)
elif isinstance(s.arg, ConstantArgument):
user_inputs.append(s.arg.value)
else:
raise RuntimeError(f"{s.arg} is not a valid user inputs")
return tuple(user_inputs)
# Graph node names of pytree-flattened outputs of original program
@property
def user_outputs(self) -> Collection[Union[int, float, bool, None, str]]:
user_outputs: List[Union[int, float, bool, None, str]] = []
for s in self.output_specs:
if s.kind != OutputKind.USER_OUTPUT:
continue
if isinstance(s.arg, (TensorArgument, SymIntArgument)):
user_outputs.append(s.arg.name)
elif isinstance(s.arg, ConstantArgument):
user_outputs.append(s.arg.value)
elif isinstance(s.arg, CustomObjArgument):
user_outputs.append(s.arg.name)
else:
raise RuntimeError(f"{s.arg} is not a valid user output")
return tuple(user_outputs)
# A dictionary mapping graph input node names to parameters. If a graph input
# name is found in this dictionary, it is guaranteed to be a lifted parameter.
@property
def inputs_to_parameters(self) -> Mapping[str, str]:
return _immutable_dict(
(s.arg.name, s.target)
for s in self.input_specs
if s.kind == InputKind.PARAMETER
and isinstance(s.arg, TensorArgument)
and isinstance(s.target, str)
)
# A dictionary mapping graph input node names to buffers. If a graph input
# name is found in this dictionary, it is guaranteed to be a lifted buffer.
@property
def inputs_to_buffers(self) -> Mapping[str, str]:
return _immutable_dict(
(s.arg.name, s.target) # type: ignore[union-attr, misc]
for s in self.input_specs
if s.kind == InputKind.BUFFER
and isinstance(s.arg, TensorArgument)
and isinstance(s.target, str)
)
# A dictionary mapping graph output node names to buffers that are mutated in the
# original program. Buffers that are not mutated will not be found in this dictionary.
@property
def buffers_to_mutate(self) -> Mapping[str, str]:
return _immutable_dict(
(s.arg.name, s.target)
for s in self.output_specs
if s.kind == OutputKind.BUFFER_MUTATION
and isinstance(s.arg, TensorArgument)
and isinstance(s.target, str)
)
@property
def user_inputs_to_mutate(self) -> Mapping[str, str]:
return _immutable_dict(
(s.arg.name, s.target)
for s in self.output_specs
if s.kind == OutputKind.USER_INPUT_MUTATION
and isinstance(s.arg, TensorArgument)
and isinstance(s.target, str)
)
# A dictionary mapping graph input node names to lifted tensor constants.
@property
def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]:
return _immutable_dict(
(s.arg.name, s.target)
for s in self.input_specs
if s.kind == InputKind.CONSTANT_TENSOR
and isinstance(s.arg, TensorArgument)
and isinstance(s.target, str)
)
@property
def inputs_to_lifted_custom_objs(self) -> Mapping[str, str]:
return _immutable_dict(
(s.arg.name, s.target)
for s in self.input_specs
if s.kind == InputKind.CUSTOM_OBJ
and isinstance(s.arg, CustomObjArgument)
and isinstance(s.target, str)
)
@property
def backward_signature(self) -> Optional[ExportBackwardSignature]:
loss_output = None
gradients_to_parameters: Dict[str, str] = {}
gradients_to_user_inputs: Dict[str, str] = {}
for spec in self.output_specs:
if spec.kind == OutputKind.LOSS_OUTPUT:
assert loss_output is None
assert isinstance(spec.arg, TensorArgument)
loss_output = spec.arg.name
elif spec.kind == OutputKind.GRADIENT_TO_PARAMETER:
assert isinstance(spec.target, str)
assert isinstance(spec.arg, TensorArgument)
gradients_to_parameters[spec.arg.name] = spec.target
elif spec.kind == OutputKind.GRADIENT_TO_USER_INPUT:
assert isinstance(spec.target, str)
assert isinstance(spec.arg, TensorArgument)
gradients_to_user_inputs[spec.arg.name] = spec.target
if loss_output is None:
return None
return ExportBackwardSignature(
loss_output=loss_output,
gradients_to_parameters=gradients_to_parameters,
gradients_to_user_inputs=gradients_to_user_inputs,
)
# Map from assertion dependency token index to assertion dep token output
# name in output. The shape of output after aot_autograd will be like:
# (updated_inputs, user_outputs, dep_token).
@property
def assertion_dep_token(self) -> Optional[Mapping[int, str]]:
return None
@property
def input_tokens(self) -> Collection[str]:
input_tokens = []
for s in self.input_specs:
if s.kind == InputKind.TOKEN:
assert isinstance(s.arg, TokenArgument)
input_tokens.append(s.arg.name)
return tuple(input_tokens)
@property
def output_tokens(self) -> Collection[str]:
output_tokens = []
for s in self.output_specs:
if s.kind == OutputKind.TOKEN:
assert isinstance(s.arg, TokenArgument)
output_tokens.append(s.arg.name)
return tuple(output_tokens)
def __post_init__(self) -> None:
assertion_dep_token = self.assertion_dep_token
if assertion_dep_token is None:
return
assert len(assertion_dep_token) == 1
assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
assert (
len(self.user_outputs) + len(self.buffers_to_mutate)
== assertion_dep_token_index
)
def replace_all_uses(self, old: str, new: str):
"""
Replace all uses of the old name with new name in the signature.
"""
assert isinstance(old, str)
assert isinstance(new, str)
arg_types = (TensorArgument, SymIntArgument, CustomObjArgument, TokenArgument)
for o in self.output_specs:
if isinstance(o.arg, arg_types):
if o.arg.name == old:
o.arg.name = new
for i in self.input_specs:
if isinstance(i.arg, arg_types):
if i.arg.name == old:
i.arg.name = new
def get_replace_hook(self):
def _(old, new, user):
if user.op in ("output", "input"):
self.replace_all_uses(old.name, new)
return _
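# A minimal usage sketch for inspecting a signature. ``CustomModule`` refers to
# the example in the class docstring above and is assumed for illustration:
#
#     ep = torch.export.export(CustomModule(), (torch.tensor(5.0), torch.tensor(6.0)))
#     sig = ep.graph_signature
#     sig.parameters           # ('my_parameter',)
#     sig.buffers              # ('my_buffer1', 'my_buffer2')
#     sig.buffers_to_mutate    # maps the mutated output node name to 'my_buffer2'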
def _immutable_dict(items):
"""
Creates a mapping where items cannot be added, deleted, or updated.
NOTE: The immutability is shallow (like tuple is an immutable collection).
"""
from types import MappingProxyType
return MappingProxyType(dict(items))
def _make_argument_spec(node, token_names) -> ArgumentSpec:
from torch import ScriptObject, SymInt
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import FakeTensor
if isinstance(node, (int, bool, float, type(None), str)):
# For const outputs we just directly return this
return ConstantArgument(name="", value=node)
assert (
"val" in node.meta
), f"{node} is not a constant or a node with a 'val' metadata field"
val = node.meta["val"]
if node.name in token_names:
return TokenArgument(name=node.name)
elif isinstance(val, FakeTensor):
return TensorArgument(name=node.name)
elif isinstance(val, SymInt):
return SymIntArgument(name=node.name)
elif isinstance(val, ScriptObject):
return CustomObjArgument(name=node.name, class_fqn=val._type().qualified_name()) # type: ignore[attr-defined]
elif isinstance(val, FakeScriptObject):
return CustomObjArgument(
name=node.name, class_fqn=val.script_class_name, fake_val=val
)
elif isinstance(val, (int, bool, str, float, type(None))):
return ConstantArgument(name=node.name, value=val)
else:
raise AssertionError(
f"Encountered an unsupported object of type {type(val)} "
f"while writing the metadata for exported program"
)
def _convert_to_export_graph_signature(
graph_signature: "GraphSignature",
gm: "torch.fx.GraphModule",
non_persistent_buffers: Set[str],
) -> "ExportGraphSignature":
from torch.utils import _pytree as pytree
is_joint = graph_signature.backward_signature is not None
# unpack objects
user_inputs = set(graph_signature.user_inputs)
inputs_to_parameters = graph_signature.inputs_to_parameters
inputs_to_buffers = graph_signature.inputs_to_buffers
user_outputs = set(graph_signature.user_outputs)
buffer_mutations = graph_signature.buffers_to_mutate
user_input_mutations = graph_signature.user_inputs_to_mutate
grad_params = graph_signature.backward_signature.gradients_to_parameter if is_joint else {} # type: ignore[union-attr]
grad_user_inputs = graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {} # type: ignore[union-attr]
loss_output = graph_signature.backward_signature.loss_output if is_joint else None # type: ignore[union-attr]
input_tokens = graph_signature.input_tokens
output_tokens = graph_signature.output_tokens
inputs = [
_make_argument_spec(node, input_tokens)
for node in gm.graph.nodes
if node.op == "placeholder"
]
outputs = [
_make_argument_spec(node, output_tokens)
for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args)
]
def to_input_spec(inp: ArgumentSpec) -> InputSpec:
if isinstance(inp, TokenArgument):
return InputSpec(kind=InputKind.TOKEN, arg=inp, target=None)
if not isinstance(inp, TensorArgument):
return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
name = inp.name
if name in user_inputs:
return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
elif name in inputs_to_parameters:
return InputSpec(
kind=InputKind.PARAMETER,
arg=inp,
target=inputs_to_parameters[name], # type: ignore[index]
)
elif name in inputs_to_buffers:
return InputSpec(
kind=InputKind.BUFFER,
arg=inp,
target=inputs_to_buffers[name], # type: ignore[index]
persistent=(inputs_to_buffers[name] not in non_persistent_buffers), # type: ignore[index]
)
else:
raise AssertionError(f"Unknown tensor input kind: {name}")
def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec:
if isinstance(o, TokenArgument):
return OutputSpec(kind=OutputKind.TOKEN, arg=o, target=None)
if not isinstance(o, TensorArgument):
return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
name = o.name
if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
if name in buffer_mutations:
return OutputSpec(
kind=OutputKind.BUFFER_MUTATION,
arg=o,
target=buffer_mutations[name], # type: ignore[index]
)
elif name in user_input_mutations:
return OutputSpec(
kind=OutputKind.USER_INPUT_MUTATION,
arg=o,
target=user_input_mutations[name], # type: ignore[index]
)
else:
raise AssertionError(f"Unknown tensor mutation kind: {name}")
else:
if name in user_outputs:
return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
elif name in grad_params:
return OutputSpec(
kind=OutputKind.GRADIENT_TO_PARAMETER,
arg=o,
target=grad_params[name],
)
elif name in grad_user_inputs:
return OutputSpec(
kind=OutputKind.GRADIENT_TO_USER_INPUT,
arg=o,
target=grad_user_inputs[name],
)
elif name == loss_output:
return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None)
else:
raise AssertionError(f"Unknown tensor output kind: {name}")
input_specs = [to_input_spec(inp) for inp in inputs]
output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)]
return ExportGraphSignature(input_specs=input_specs, output_specs=output_specs)

View File

@ -0,0 +1,69 @@
from typing import Dict, Union
import torch
import torch.utils._pytree as pytree
from torch.export.exported_program import ExportedProgram
__all__ = ["move_to_device_pass"]
def move_to_device_pass(
ep: ExportedProgram, location: Union[torch.device, str, Dict[str, str]]
) -> ExportedProgram:
"""
Move the exported program to the given device.
Args:
ep (ExportedProgram): The exported program to move.
location (Union[torch.device, str, Dict[str, str]]): The device to move the exported program to.
If a string, it is interpreted as a device name.
If a dict, it is interpreted as a mapping from
the existing device to the intended one
Returns:
ExportedProgram: The moved exported program.
"""
def _get_new_device(
curr_device: torch.device,
location: Union[torch.device, str, Dict[str, str]],
) -> str:
if isinstance(location, dict):
if str(curr_device) in location.keys():
return location[str(curr_device)]
else:
return str(curr_device)
else:
return str(location)
# move all the state_dict
for k, v in ep.state_dict.items():
if isinstance(v, torch.nn.Parameter):
ep._state_dict[k] = torch.nn.Parameter(
v.to(_get_new_device(v.device, location))
)
else:
ep._state_dict[k] = v.to(_get_new_device(v.device, location))
# move all the constants
for k, v in ep.constants.items():
if isinstance(v, torch.Tensor):
ep._constants[k] = v.to(_get_new_device(v.device, location))
for node in ep.graph.nodes:
# move all the nodes kwargs with burnt-in device
if "device" in node.kwargs:
kwargs = node.kwargs.copy()
kwargs["device"] = _get_new_device(kwargs["device"], location)
node.kwargs = kwargs
# move all the tensor metadata
node.meta["val"] = pytree.tree_map(
lambda v: v.to(_get_new_device(v.device, location))
if isinstance(v, torch.Tensor)
else v,
node.meta.get("val"),
)
ep.validate()
return ep
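# A minimal usage sketch, assuming CUDA is available and that ``ep`` came from a
# CPU export: the pass is applied first with a plain device string and then with
# a mapping from the current device to a new one.
def _example_move_to_device(ep: ExportedProgram) -> ExportedProgram:
    ep = move_to_device_pass(ep, "cuda")                # move everything to CUDA
    return move_to_device_pass(ep, {"cuda:0": "cpu"})   # and back, via a mapping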

File diff suppressed because it is too large