I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


@@ -0,0 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


@@ -0,0 +1,559 @@
# mypy: allow-untyped-defs
"""
Utils for caching the outputs of AOTAutograd
"""
from __future__ import annotations
import json
import logging
import os
import pickle
import shutil
import time
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch._dynamo.utils import counters, get_chromium_event_logger
from torch._functorch import config
from torch._inductor.codecache import (
_ident,
BypassFxGraphCache,
CompiledFxGraph,
extract_tensor_metadata_for_cache_key,
FxGraphCache,
FxGraphCachePickler,
FxGraphHashDetails,
write_atomic,
)
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._logging import LazyString
from .runtime_wrappers import (
AOTDispatchAutograd,
AOTDispatchSubclassWrapper,
CompilerWrapper,
FunctionalizedRngRuntimeWrapper,
post_compile,
RuntimeWrapper,
SubclassMeta,
)
from .schemas import AOTConfig, ViewAndMutationMeta # noqa: F401
if TYPE_CHECKING:
from torch._inductor.utils import BoxedBool
from torch.fx.node import Node
log = logging.getLogger(__name__)
class BypassAOTAutogradCache(Exception):
pass
# Used to signify that FXGraphCache missed while AOTAutogradCache was using it
class FXGraphCacheMiss(BypassAOTAutogradCache):
pass
def check_node_safe(node: Node):
"""
Checks that the node only uses supported operators. We are starting with very
conservative cacheability constraints, and incrementally adding more support as we expand.
[Note: AOTAutograd Cacheability checks]
- Our cache key is computed from the FX graph produced by Dynamo and the input example values
- A node is "safe" if the same cache key results in a compiled artifact that has the same behavior
(i.e., the set of inputs that go into our cache key is sufficient to distinguish its behavior)
To accomplish this safety check, we consider the following functions to be safe:
- Public functions under modules torch, torch.functional, and torch.nn.functional: these are
allowed in the graph by dynamo, so we can assume they are safe to cache.
- method calls on base tensor types
- Any call_module that dynamo deemed safe to allow AOTAutograd to trace
- Non callable nodes, such as placeholder, output, get_attr
The test suite test_aot_autograd_cache.py::AOTAutogradCachePicklerTests tries its best to fully cover/specify this behavior.
"""
SAFE_TORCH_MODULES = ("torch.functional", "torch.nn.functional")
def is_public_torch_api(target):
# Don't blindly allow private functions in the torch namespace
is_private = target.__name__.startswith("_")
return (
getattr(target, "__module__", None) in SAFE_TORCH_MODULES and not is_private
)
def is_torch_function(target):
if isinstance(target, torch._ops.OpOverload):
return True
if is_public_torch_api(target):
return True
is_builtin_fun_or_type = type(target).__name__ == "builtin_function_or_method"
return is_builtin_fun_or_type
def is_tensor(target):
# Tensors always have example values in meta field
return "example_value" in target.meta
# I'd love to use a match statement here, but it wasn't introduced until py3.10
if node.op == "call_function":
# We support only torch.* functions for now
# We can probably add an allowlist of safe non-torch implementations as well
if not is_torch_function(node.target):
raise BypassAOTAutogradCache(
f"Unsupported call_function target {node.target}"
)
elif node.op == "call_method":
method_name = node.target
method_target = node.args[0]
# Only support method calls on base tensors
if not is_tensor(method_target):
raise BypassAOTAutogradCache(
f"Unsupported call_method target {method_target}"
)
if (
type(method_name) != str
and type(method_name).__name__ != "method_descriptor"
):
raise BypassAOTAutogradCache(
f"Unsupported call_method method {node.target}: {method_name}"
)
# Cache safe
elif node.op in ("placeholder", "get_attr", "call_module", "output"):
# Assumption today for call_module being a safe op:
# (1) today the only call_module ops that can show up in a graph come from "built-in-nn-modules"
# that dynamo assumes are safe to trace. If dynamo assumes they are safe to blindly trace, then
# they should be safe to cache as well.
# (2) in the steady-state (some time in H2?) we shouldn't see these anymore, once we inline builtin nn modules by default
# (3) We do not allow user-made nn modules in the graph today, only function calls.
pass
else:
raise BypassAOTAutogradCache(f"Unsupported node op {node.op}")
def check_cacheable(gm: torch.fx.GraphModule):
"""
Checks that the graph module only uses supported operators
"""
nodes = gm.graph.nodes
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
raise BypassAOTAutogradCache(
"Cannot cache a graph with compiled autograd enabled"
)
if not torch._inductor.config.fx_graph_cache:
raise BypassAOTAutogradCache("FX graph cache is not enabled")
tracing_context = torch._guards.TracingContext.try_get()
if tracing_context and tracing_context.fakify_first_call:
raise BypassAOTAutogradCache(
"Won't cache a graph with fakify_first_call enabled"
)
for node in nodes:
check_node_safe(node)
class AOTAutogradCacheDetails(FxGraphHashDetails):
"""
Object to capture all the details for a dynamo graph module relevant to computing
a safe and stable cache key for AOTAutograd.
"""
def __init__(
self,
gm: torch.fx.GraphModule,
example_inputs,
aot_config: AOTConfig,
fx_config: Dict[str, BoxedBool],
):
# FxGraphHashDetails contains all the keys related to inductor. Also includes some system info
self.aot_config = aot_config
self.grad_enabled = torch.is_grad_enabled()
self.disable_amp = torch._C._is_any_autocast_enabled()
self.deterministic_algorithms = torch.are_deterministic_algorithms_enabled()
self.autograd_config = config.save_config()
try:
# TODO: example_inputs causes more cache misses than necessary
# with dynamic shapes, because this is before we add
# symints to tensor metadata. Improve this later.
super().__init__(gm, example_inputs, fx_config, [])
except BypassFxGraphCache as e:
# Sometimes inductor configs are unpickleable and can fail
raise BypassAOTAutogradCache from e
def debug_lines(self) -> List[str]:
return AOTAutogradCachePickler.debug_lines(self)
def _reduce_aot_config(aot_config: AOTConfig):
"""
Reduce the config to a stable key for caching.
"""
return (
_ident,
(
aot_config.num_params_buffers,
aot_config.keep_inference_input_mutations,
aot_config.is_export,
aot_config.no_tangents,
aot_config.dynamic_shapes,
aot_config.aot_autograd_arg_pos_to_source,
aot_config.enable_log,
aot_config.pre_dispatch,
),
)
def _reduce_tensor(tensor):
"""
Reduce the tensor to a stable key for caching.
"""
return (
_ident,
(
extract_tensor_metadata_for_cache_key(
FxGraphCachePickler._device_map, tensor
),
),
)
class AOTAutogradCachePickler(FxGraphCachePickler):
dispatch_table = FxGraphCachePickler.dispatch_table.copy()
dispatch_table[AOTConfig] = _reduce_aot_config
dispatch_table[torch.Tensor] = _reduce_tensor
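# Hedged sketch (not part of the original module): because torch.Tensor dispatches to
# _reduce_tensor, tensors are pickled by their metadata (dtype, shape, strides, device, ...)
# rather than their values, so two same-shaped tensors with different contents are
# expected to hash identically. The helper name is hypothetical.
def _example_tensor_metadata_hashing() -> None:
    a = AOTAutogradCachePickler.get_hash(torch.randn(2, 3))
    b = AOTAutogradCachePickler.get_hash(torch.randn(2, 3))
    assert a == b  # values differ, cache-relevant metadata does not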
def autograd_cache_key(
gm: torch.fx.GraphModule,
example_inputs,
config: AOTConfig,
fx_config: Dict[str, BoxedBool],
# TODO: add args and parameters
) -> Tuple[str, List[str]]:
"""
Generate a unique hash of the FX graph for caching.
"""
check_cacheable(gm)
details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config)
# The prefix distinguishes among the other kinds of objects we cache
key = "a" + AOTAutogradCachePickler.get_hash(details)
debug_lines = details.debug_lines()
log.debug(
"Autograd graph cache hash details for key %s:\n%s",
key,
LazyString(lambda: "\n".join(debug_lines)),
)
return key, debug_lines
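# Hedged usage sketch (not part of the original module): a caller that already has a
# dynamo-produced GraphModule, its example inputs, and an AOTConfig could compute the
# key like this. All argument names below are hypothetical placeholders.
def _example_compute_cache_key(gm, example_inputs, aot_config, cudagraphs):
    fx_config = {"cudagraphs": cudagraphs}
    key, debug_lines = autograd_cache_key(gm, example_inputs, aot_config, fx_config)
    # Keys are opaque strings prefixed with "a"; debug_lines describe what was hashed.
    assert key.startswith("a")
    return key, debug_lines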
@dataclass
class FXGraphCacheLoadable:
fx_graph_cache_key: str
def load(self, example_inputs, fx_config: Dict[str, BoxedBool]) -> CompiledFxGraph:
# [Note: AOTAutogradCache and FXGraphCache Guard interactions]
# As mentioned, AOTAutograd takes in the symint inputs from dynamo's list of arguments.
# FXGraphCache serializes guards that are needed in the shape_env based on these symint inputs to the graph.
# The invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly
# the same as the ones it passes to inductor, for both the forward and backward passes.
# (This does not mean that the tensor values passed in are the same: only that their symints are).
# That is, AOTAutograd and Inductor never create new guards based on symints with different sources
# than those passed to it by dynamo.
result = FxGraphCache._lookup_graph(
self.fx_graph_cache_key, example_inputs, local=True, remote_cache=None
)
if result is None:
log.info("FXGraphCache cache miss for key %s", self.fx_graph_cache_key)
counters["inductor"]["fxgraph_cache_miss"] += 1
raise FXGraphCacheMiss
FxGraphCache.post_compile(result, example_inputs, fx_config["cudagraphs"])
counters["inductor"]["fxgraph_cache_hit"] += 1
result._boxed_call = True
return result
@dataclass
class CompiledForward(FXGraphCacheLoadable):
"""
Cacheable entry for a forward function
"""
@dataclass
class CompiledBackward(FXGraphCacheLoadable):
"""
Cacheable entry for a backward function
"""
# Used by AOTDispatchAutograd.post_compile
backward_state_indices: List[int]
num_symints_saved_for_bw_: int
@dataclass
class AOTAutogradCacheEntry:
"""A single entry into the cache."""
# Forward and Backward info
compiled_fw: CompiledForward
compiled_bw: Optional[CompiledBackward]
# Runtime metadata saved right before compilation
runtime_metadata: ViewAndMutationMeta
# Wrappers that run after each aot_dispatch_* function
dispatch_wrappers: List[CompilerWrapper]
# Used by AOTSubclassWrapper
maybe_subclass_meta: Optional[SubclassMeta]
num_fw_outs_saved_for_bw: Optional[int]
# Used by RuntimeWrapper
indices_of_inps_to_detach: List[int]
# Turn cache entry into the original callable
def wrap_post_compile(
self,
args: List[torch.Tensor],
aot_config: AOTConfig,
fx_config: Dict[str, BoxedBool],
) -> Callable:
"""
This function takes a cache entry and carefully reconstructs the original callable
that AOTAutograd returned the first time it was run. It does this by running the various
post compile steps that AOTAutograd runs on its compiled artifact after running the fw/bw compilers.
In the inference path, this consists of the Subclass, FunctionalizedRngRuntime, and RuntimeWrappers.
In the autograd path, this consists of AOTAutogradDispatch.post_compile.
The steps here should match exactly the steps that are run in aot_dispatch_base and aot_dispatch_autograd.
Notably absent from the cached path are:
- DebugAssertWrapper
- FakifiedOutWrapper
Which we'll handle separately later on, if necessary.
"""
compiled_fw_func = self.compiled_fw.load(args, fx_config)
compiled_bw_func = None
if self.compiled_bw is not None:
compiled_bw_func = self.compiled_bw.load(args, fx_config)
needs_autograd = True
else:
needs_autograd = False
# Wrap the forward function in post compile wrappers
compiled_fw_func = AOTDispatchSubclassWrapper(
trace_joint=needs_autograd,
fw_only=None,
maybe_subclass_meta=self.maybe_subclass_meta,
num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw,
).post_compile(
compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
)
# In the autograd case, FunctionalizedRngRuntimeWrapper should not modify outs
return_new_outs = not needs_autograd
compiled_fw_func = FunctionalizedRngRuntimeWrapper(
return_new_outs=return_new_outs
).post_compile(
compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
)
disable_amp = torch._C._is_any_autocast_enabled()
if needs_autograd:
assert self.compiled_bw is not None
# This function is run on both cache miss and cache hit, either here
# or in aot_dispatch_autograd. On a cache hit,
# 1. the bw is already compiled
# 2. we don't need to save to the cache again
# so those corresponding arguments are set to None.
compiled_function = AOTDispatchAutograd.post_compile(
compiled_fw_func,
compiled_bw_func,
self.maybe_subclass_meta,
self.compiled_bw.num_symints_saved_for_bw_,
self.compiled_bw.backward_state_indices,
disable_amp,
self.indices_of_inps_to_detach,
None, # lazy_backward_info
aot_config,
fw_metadata=self.runtime_metadata,
try_save_cache_entry=None,
)
else:
compiled_function = RuntimeWrapper(
indices_of_inps_to_detach=self.indices_of_inps_to_detach,
trace_joint=False,
disable_amp=disable_amp,
).post_compile(
compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
)
compiled_function, _ = post_compile(
self.dispatch_wrappers,
compiled_function,
aot_config,
runtime_metadata=self.runtime_metadata,
)
return compiled_function
class AOTAutogradCache:
"""
Caches the results of running AOTAutograd. This class mostly handles the save and load logic, whereas
AOTAutogradCacheEntry handles the wrapping/unwrapping logic.
Cache Inputs (AOTAutogradCacheDetails)
- AOTAutogradCache takes in the following inputs, which are analogous to inputs given
to AOTAutograd by dynamo:
- A fx graph module generated by dynamo
- A list of args, which consists of:
- Symint inputs to the graph, generated by dynamo
- The **real tensor** inputs, which inductor uses for cudagraphs
- Notably, the real tensor inputs don't have symints in their metadata.
AOTAutograd then retraces those real tensor arguments into FakeTensors later during execution.
- A set of global configurations that affect AOTAutograd or Inductor behavior.
It then generates a cache key given these values. Notably, this means AOTAutogradCache currently
specializes on the sizes and strides of the real tensor inputs when dynamic shapes are turned on.
In a later PR, we'll likely generate the cache key based on the FakeTensors AOTAutograd generates
based on the real tensor inputs, which can contain symints.
Cache Outputs (AOTAutogradCacheEntry)
- AOTAutogradCache caches the following values:
- The compiled forward and backward functions from inductor, via keys to the FXGraphCache
- Metadata to reconstruct the AOTModule from the compiled inductor artifacts
- See AOTAutogradCacheEntry for more info
[Note: Caching guards generated by AOTAutograd and Inductor]
AOTAutograd and inductor both can introduce new guards to the shape environment. FXGraphCache saves guards with each
compiled graph inductor generates. On a cache hit, AOTAutograd reloads the compiled forward and backward functions
from FXGraphCache, giving it new symint arguments from the input args.
FXGraphCache uses those symints and its saved guards to repopulate the ShapeEnv with guards.
**No new guards are generated into the shape env after inductor finishes compiling**, so the guards
saved by inductor are sufficient for correctness for both AOTAutograd and Inductor's caches.
"""
@staticmethod
def clear():
"""Clear the cache"""
try:
shutil.rmtree(AOTAutogradCache._get_tmp_dir())
except FileNotFoundError:
pass
@staticmethod
def load(
dispatch_and_compile: Callable,
mod: Union[torch.fx.GraphModule, torch._dynamo.utils.GmWrapper],
args,
aot_config: AOTConfig,
cudagraphs: BoxedBool,
) -> Callable:
"""
Load a result from the cache, and reconstruct a runtime wrapper around the object
"""
gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod
compiled_fn = None
cache_key = None
debug_lines: List[str] = []
cache_event_time = time.time_ns()
cache_state = None
fx_config = {"cudagraphs": cudagraphs}
try:
cache_key, debug_lines = autograd_cache_key(gm, args, aot_config, fx_config)
entry: Optional[AOTAutogradCacheEntry] = AOTAutogradCache._lookup(cache_key)
if entry is not None:
compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config)
log.info("AOTAutograd cache hit for key %s", cache_key)
counters["aot_autograd"]["autograd_cache_hit"] += 1
cache_state = "hit"
cache_event_time = time.time_ns()
if compiled_fn is None:
log.info("AOTAutograd cache miss for key %s", cache_key)
counters["aot_autograd"]["autograd_cache_miss"] += 1
cache_state = "miss"
cache_event_time = time.time_ns()
# Count a miss in the FXGraphCache as a cache miss, not a bypass
except FXGraphCacheMiss as e:
counters["aot_autograd"]["autograd_cache_miss"] += 1
# Special counter for when we hit the autograd cache but
# fail on inductor guards
counters["aot_autograd"]["autograd_cache_guard_miss"] += 1
if config.strict_autograd_cache:
raise e
except BypassAOTAutogradCache as e:
cache_key = None
counters["aot_autograd"]["autograd_cache_bypass"] += 1
cache_state = "bypass"
cache_event_time = time.time_ns()
if config.strict_autograd_cache:
raise e
if compiled_fn is None:
# Set the cache key so we can save a cache result later
aot_config.cache_key = cache_key
compiled_fn = dispatch_and_compile()
cache_args = {
"key": cache_key,
"cache_state": cache_state,
"components": debug_lines,
}
chromium_log = get_chromium_event_logger()
chromium_log.log_instant_event(
f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_args
)
torch._logging.trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "aotautograd_cache_hash",
"encoding": "json",
},
payload_fn=lambda: json.dumps(cache_args),
)
return compiled_fn
@staticmethod
def _get_tmp_dir() -> str:
"""
Get the toplevel temporary directory for storing compiled graphs.
"""
return os.path.join(cache_dir(), "aotautograd")
@staticmethod
def _lookup(key: str) -> Optional[AOTAutogradCacheEntry]:
"""Given a key generated by AOTAutogradCachePickler, look up its location in the cache."""
subdir = os.path.join(AOTAutogradCache._get_tmp_dir(), key)
if not os.path.exists(subdir):
return None
path = os.path.join(subdir, "entry")
try:
with open(path, "rb") as f:
entry: AOTAutogradCacheEntry = pickle.load(f)
return entry
except Exception as e:
log.warning("AOTAutograd cache unable to load compiled graph: %s", e)
if config.strict_autograd_cache:
raise e
return None
@staticmethod
def save(key: str, entry: AOTAutogradCacheEntry):
"""Save a single entry into the cache."""
try:
content = pickle.dumps(entry)
except Exception as e:
log.warning("AOTAutograd cache unable to serialize compiled graph: %s", e)
if config.strict_autograd_cache:
raise e
return None
subdir = os.path.join(AOTAutogradCache._get_tmp_dir(), key)
if not os.path.exists(subdir):
os.makedirs(subdir, exist_ok=True)
path = os.path.join(subdir, "entry")
log.info("Writing AOTAutograd cache entry to %s", path)
write_atomic(path, content)
counters["aot_autograd"]["autograd_cache_saved"] += 1


@@ -0,0 +1,749 @@
# mypy: allow-untyped-defs
"""
This module is one of the analysis modules - it takes as input a function or graph
and some preexisting properties, and returns some data that is useful for deciding
how to further proceed with compilation or construct runtime wrappers.
In particular, the analysis here constructs view and mutation metadata from running
a functionalized version of the graph under compilation.
"""
import collections
import contextlib
import logging
from functools import wraps
from typing import Callable, DefaultDict, Dict, List, Optional
import torch
import torch.utils._pytree as pytree
from torch import Tensor
from torch._guards import detect_fake_mode
from torch._logging import getArtifactLogger
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch._subclasses.meta_utils import safe_is_leaf
from torch.fx.experimental.symbolic_shapes import is_concrete_int
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
is_traceable_wrapper_subclass,
transform_subclass,
)
from .functional_utils import (
are_all_mutations_hidden_from_autograd,
are_all_mutations_under_no_grad_or_inference_mode,
from_fun,
has_data_mutation,
has_metadata_mutation,
has_same_metadata,
to_fun,
was_inductor_storage_resized,
)
from .schemas import (
FunctionalTensorMetadataEq,
InputAliasInfo,
MutationType,
OutputAliasInfo,
OutputType,
ViewAndMutationMeta,
)
from .subclass_utils import create_subclass_meta
from .utils import _get_autocast_states, KNOWN_TYPES, strict_zip
zip = strict_zip
log = logging.getLogger(__name__)
static_input_logger = getArtifactLogger("torch._dynamo", "cudagraph_static_inputs")
# Note [Tangents must be contiguous]
# We force tangents to be contiguous today.
# The idea is that we are technically making a guess about the strides of our tangents,
# while we trace out the joint.
# Today, we force this guess to be correct by additionally calling contiguous()
# on all tangents at runtime.
# In the future, you could imagine lifting this restriction, since these contiguous()
# calls can have noticeable perf overhead depending on the model.
def coerce_tangent(x):
if not isinstance(x, Tensor):
return x
out = x.detach().contiguous()
# Note [Tangents must be contiguous, Part 2]
# In the same way that "what strides do we assign to our tangents" is a question
# that we cannot answer (and therefore have to guess) as we trace the backward ahead-of-time,
# the same applies to any tensor subclass metadata, when we have tangents that are subclasses.
# To handle this situation, we have two new methods that a tensor subclass can implement:
# (1) __coerce_tangent_metadata__(self)
# Given a subclass with "non-standard" metadata, turn it into a new subclass with "normal" metadata.
# The main example here is a DTensor with the "_Partial" placement.
# If we have a forward output with a _Partial placement, and corresponding tangent
# with a Replicate/Shard placement, we have no way to convert the tangent "back" to a _Partial placement.
# This method lets us avoid the problem entirely by allowing subclasses to ensure that we can never
# have a tangent with "problematic" metadata, that we cannot convert to.
# (2) __coerce_same_metadata_as_tangent__(self, metadata)
# Given a subclass, and a target differing metadata,
# convert self to have the same metadata as the target.
# With DTensor being the main example, we can use this to convert a DTensor with a Replicate()
# placement into one with a Shard() placement, in the case that we "guessed wrong",
# and traced tangents with a Shard() placement at compile time.
#
if is_traceable_wrapper_subclass(out) and hasattr(
out, "__coerce_tangent_metadata__"
):
out = out.__coerce_tangent_metadata__() # type: ignore[attr-defined]
# It's possible to have a subclass that advertises as contiguous,
# but has noncontiguous inner tensors.
# Force these to be contiguous too
if is_traceable_wrapper_subclass(out):
for attr in out.__tensor_flatten__()[0]: # type: ignore[attr-defined]
elem = getattr(out, attr)
if not elem.is_contiguous():
elem_contig = elem.contiguous()
setattr(out, attr, elem_contig)
return out
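# Hedged illustration (not part of the original module): for a plain dense tensor,
# coerce_tangent just detaches and makes it contiguous, per Note [Tangents must be
# contiguous] above. The helper name is hypothetical.
def _example_coerce_tangent() -> None:
    t = torch.randn(4, 4, requires_grad=True).t()  # transposed view: non-contiguous
    out = coerce_tangent(t)
    assert out.is_contiguous() and not out.requires_grad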
# This is a version of functionalization that is specifically designed
# for the AOTAutograd use case.
#
# Unlike functorch's variant, this doesn't use the functorch level system,
# instead it directly uses PyTorch's conventional dispatcher to hit the
# functionalization key. In particular, this means that FunctionalTensorWrapper
# can have autograd data stored directly on it.
#
# In typical AOTAutograd usage, the dispatch key order will look like:
#
# Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor
# outer tensor inner tensor
#
# Returns:
# - ViewAndMutationMeta, telling us metadata about the inputs and outputs, and
# The list of outputs from the forward, but **only** the outputs that we need
# to pass in as tangents into the backward.
# Specifically, aliased outputs from the forward get regenerated, and don't participate
# in the compiled backward function.
def run_functionalized_fw_and_collect_metadata(
f,
*,
keep_input_mutations: bool,
# TODO: refactor to kill this flag
is_train: bool = False,
# Note: this is guaranteed to be set when running under dynamo
static_input_indices: Optional[List[int]] = None,
pre_dispatch: bool = False,
) -> Callable[..., ViewAndMutationMeta]:
memo: Dict[Tensor, Tensor] = {}
def _to_fun(t):
if isinstance(t, Tensor):
if t in memo:
return memo[t]
r = to_fun(t)
memo[t] = r
return r
else:
return t
@wraps(f)
def inner(*flat_args):
# This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args.
assert all(isinstance(a, tuple(KNOWN_TYPES)) for a in flat_args)
input_info: List[InputAliasInfo] = []
output_info: List[OutputAliasInfo] = []
prior_grad_enabled = torch.is_grad_enabled()
prior_autocast_states = _get_autocast_states()
# See Note [Disabling Functionalize TLS Above Python Functionalization]
disable_above = torch._C._ExcludeDispatchKeyGuard(
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
)
# It doesn't matter if we run this under predispatch or not because it is
# only for figuring out metadata
mode = FunctionalTensorMode(_allow_token_discovery=True)
suppress_pending = contextlib.nullcontext()
fake_mode = detect_fake_mode()
if fake_mode and (shape_env := fake_mode.shape_env):
suppress_pending = shape_env.ignore_fresh_unbacked_symbols()
with disable_above, mode, suppress_pending:
# precondition: The passed in function already handles unflattening inputs + flattening outputs
flat_f_args = pytree.tree_map(_to_fun, flat_args)
flat_f_outs = f(*flat_f_args)
# We didn't do any tracing, so we don't need to process the
# unbacked symbols, they will just disappear into the ether.
# Also, prevent memoization from applying.
if fake_mode:
fake_mode.epoch += 1
fake_mode.reset_nt_tensor_id_counter()
if prior_autocast_states != _get_autocast_states():
raise RuntimeError(
"AOTAutograd does not support tracing graphs that mutate the autocast state. "
"Dynamo will only insert autocast context managers (e.g. with torch.autocast(..)) into the graph, "
"which will unwind all of their mutations to autocast state before the graph exits. "
"If you encounter this error while using torch.compile, please file a bug."
)
# Inspect the state of the input tensor functional wrapper to detect input mutation info
# If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)):
# NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in
# strides between the functionalized arg inner tensors and non-functionalized arg inner
# tensors. This is a problem as the inner tensor stride change may not be reflected
# correctly in the outer tensor, so disallow this for now.
mutates_data = has_data_mutation(f_arg)
if (
mutates_data
and not arg.is_contiguous()
and is_traceable_wrapper_subclass(arg)
):
raise RuntimeError(
"Mutations on non-contiguous inputs are currently not allowed on "
"tensor subclasses"
)
if not isinstance(arg, Tensor):
new_arg = arg
else:
new_arg = from_fun(f_arg)
mutates_metadata = has_metadata_mutation(
f_arg, arg, check_only_storage_mutation=False
)
if mutates_metadata and is_traceable_wrapper_subclass(arg):
raise RuntimeError(
"Metadata mutations are currently not allowed on tensor subclasses"
)
mutates_storage_metadata = has_metadata_mutation(
f_arg, arg, check_only_storage_mutation=True
)
mutations_hidden_from_autograd = are_all_mutations_hidden_from_autograd(
f_arg
)
mutations_under_no_grad_or_inference_mode = (
mutates_data
and are_all_mutations_under_no_grad_or_inference_mode(f_arg)
)
mutation_inductor_storage_resize = was_inductor_storage_resized(f_arg)
if mutates_storage_metadata:
mutates_data = False
requires_grad = isinstance(f_arg, torch.Tensor) and f_arg.requires_grad
input_info.append(
InputAliasInfo(
is_leaf=isinstance(arg, Tensor) and safe_is_leaf(arg),
mutates_data=mutates_data,
mutates_metadata=mutates_metadata,
mutations_hidden_from_autograd=mutations_hidden_from_autograd,
mutates_storage_metadata=mutates_storage_metadata,
mutations_under_no_grad_or_inference_mode=mutations_under_no_grad_or_inference_mode,
mutation_inductor_storage_resize=mutation_inductor_storage_resize,
requires_grad=requires_grad,
keep_input_mutations=keep_input_mutations,
)
)
# If a function involves creating a tensor, and returning a view of it, such that its _base is the intermediate,
# we need to make sure our graph returns the _base as a graph output, and we manually recreate the view
# to return to the user. Why? The backend compiler is free to (incorrectly) not set requires_grad
# on the base tensor, but we are obligated to properly set requires-gradness on the real output.
inp_storage_refs = {
StorageWeakRef(inpt.untyped_storage()): idx
for idx, inpt in enumerate(flat_f_args)
if isinstance(inpt, Tensor)
}
# We need inp tensor ids to be able to tell if any outputs **are** inputs.
inp_tensor_ids = {id(inpt) for inpt in flat_f_args if isinstance(inpt, Tensor)}
# We need output tensor ids to tell if any `output._base` attributes **are** other outputs.
# (This is also a dict because we need to know that output's index, so we can regenerate
# the alias from it).
out_tensor_ids = {id(o): i for i, o in enumerate(flat_f_outs)}
# Keep track of which outputs alias other outputs
out_tensor_alias_counts: DefaultDict = collections.defaultdict(int)
# This tells us, for a given group of outputs that alias each other,
# whether they e.g. all came from an unbind call
num_aliased_tensors_that_are_multi_output_views: DefaultDict = (
collections.defaultdict(int)
)
out_storage_to_tensors: DefaultDict = collections.defaultdict(set)
curr_storage = None
for o in flat_f_outs:
if isinstance(o, torch.Tensor):
curr_storage = StorageWeakRef(o.untyped_storage())
out_tensor_alias_counts[curr_storage] += 1
# Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
# This is an optimization on top of the "alias of intermediates" logic,
# which you can read more about under Note [AOT Autograd: outputs aliasing inputs or intermediates!]
#
# Before describing the optimization: this is important for AOTAutograd to have good
# perf around multi-output views. HOWEVER:
# - There is a more generic change to AOTAutograd that we'd like to make, that subsumes this case,
# around using pre-dispatch tracing to partition out a graph so we can faithfully replay all
# views without having to regenerate them at runtime.
# - It's loosely described in this doc (more details will be added soon):
# https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit
# - Once that change lands, we should just rip out this "optimization", since:
# (1) It will be fully unnecessary
# (2) Although it is only a few lines of code, it is a bit difficult to reason about
# its correctness with the autograd engine in all cases.
#
#
# What is this optimization? Consider the below case:
# def f(x):
# intermediate = x.mul(2)
# # x and intermediate here require grad
# o1, o2, ... o10 = intermediate.unbind(-1)
# return intermediate, o1, o2, ... o10
# Now, the "intermediate base" handling in AOTAutograd implies that we must do the following:
# (1) return "intermediate as an extra output of the compiled graph
# (2) regenerate each aliased output off of "intermediate", **outside** of the autograd.Function.
# The reason AOTAutograd ordinarily does this is for safety: the autograd engine needs to know
# that o1 through o10 are all aliased, and if we blindly return o1 through o10 from the autograd.Function,
# this information will be hidden.
# In particular, mutating one alias might require autograd to update autograd metadata on the other aliases
# (like their grad_fn, for example, when the autograd engine needs to do view-replay).
#
# However, intermediate_base logic can be bad for backward performance (we sometimes generate
# as_strided calls during the intermediate base logic, which can have a slow backward formula).
# Is it possible to find a set of conditions where it is **safe** to hide the output aliasing from autograd?
#
# For a set of outputs of the graph that alias each other, o_1...o_k, consider:
# (1) They came from the same multi-output view op, e.g. o_1, ..., o_k = intermediate.unbind(0)
# (2) If there are any other aliases of o_1 through o_k (in the example above, intermediate),
# **at most** 1 can escape from the graph (e.g. there is not some other graph input/output
# o_other, that aliases these outputs)
# (3) o_1...o_k all require_grad, they all share the same ._base, and their ._base requires grad.
# This condition is important because it's what causes slowness in the intermediate_base
# codepath of aot_autograd. Ordinarily, o_1...o_k would all get a grad_fn, and
# aot_autograd's view-replay might give each output an AsStridedBackward as its grad_fn.
# "K" AsStridedBackward calls will be *much* slower than a single UnbindBackward.
# In this setup, is it possible to mutate one of the outputs o_i in a way that would affect the autograd meta
# of the other aliases?
#
# Claim: No! Consider a few examples (which I'm pretty sure cover all cases of mutation w.r.t. autograd);
# a small runnable illustration of case (a) is also sketched after run_functionalized_fw_and_collect_metadata:
# (a) What happens if we mutate any of o_1 through o_k directly?
# Autograd raises an error:
# "RuntimeError: Output 0 of UnbindBackward0 is a view and is being modified inplace. This view is
# the output of a function that returns multiple views. Such functions do not allow the output
# views to be modified inplace. You should replace the inplace operation by an out-of-place one."
# (b) What if we take a view of o_k and mutate it, o_k.view(o_k.shape).mul_(2)?
# Autograd raises the same error- the "multi-output-view"ness of an alias propagates to future views.
# (c) What if we mutate o_k under no_grad?
# Autograd raises the same error
# (d) What if we detach and mutate, e.g. o_k.detach().mul_(2)?
# Autograd allows this, *but* autograd updates all aliases' grad_fns to be error functions when accessed.
# (e) What if we try to mutate another alias of o_1...o_k, that was **not** created from a multi-output view?
# We promised that there is at most **one** such alias, e.g. intermediate in the example above.
# You can mutate intermediate, but in eager mode this will change the grad_fn of o_1...o_k
# to be error fn's.
# Since intermediate was the *only* non-multi-output-alias, there are no other aliases
# of `intermediate` around that were produced by the compiled fn and have a valid grad_fn.
#
# Coming back to this optimization:
# Given that it is not possible for mutating one of these aliases to affect the autograd metadata of another alias
# without causing an error in eager mode, we will simply hide the aliasing from autograd during torch.compile
# if all of the above conditions are met.
# This has the slight downside that it's possible to write some "bad" code that autograd will raise an error on
# in eager mode but not under torch.compile, but it has the benefit that this code has much better performance.
# NOTE: if and when we eventually update AOTAutograd to do the "view graph slicing" defined here:
# https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit,
# then this optimization will probably matter less and might be ok to remove.
is_cur_tensor_multi_out_view = isinstance(
o, FunctionalTensor
) and torch._functionalize_is_multi_output_view( # type: ignore[attr-defined]
o.elem
)
if is_cur_tensor_multi_out_view:
num_aliased_tensors_that_are_multi_output_views[curr_storage] += 1
out_storage_to_tensors[curr_storage].add(o)
# maps the id of an intermediate base to its index in the output of the compiled forward
intermediate_base_tensor_id_to_output_idx: Dict[int, int] = {}
intermediate_bases: List[torch.Tensor] = []
# Why Do We Care If Storage Changed?
# It's important to understand the implications of storage changes in complex scenarios. Take this example:
#
# def f(x):
# x_storage = x.untyped_storage()
# non_leaf_tensor = torch.ones(4, requires_grad=True).clone()
#
# # Using no_grad() and _unsafe_preserve_version_counter to simulate the .data = operation
# with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x):
# x.set_(non_leaf_tensor.untyped_storage())
#
# out = x.view(-1)
#
# # Restoring x to its original storage, again simulating .data = operation
# with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(x):
# x.set_(x_storage)
#
# return out
#
# In this scenario, 'x' and 'out' have different shapes and are stored at different memory addresses, aka no aliasing.
# However, due to how set_() (and more specifically, how set_() is functionalized) is defined to preserve eager semantics,
# the autograd engine mistakenly assumes that 'x' and 'out' are aliased, treating 'x' as 'out._base'.
# This misinterpretation leads to an 'alias_of_input' flag, causing an unnecessary as_strided() call to be generated,
# which could lead to issues later in the code.
for o in flat_f_outs:
functional_tensor_storage_changed = isinstance(
o, FunctionalTensor
) and torch._functionalize_was_storage_changed( # type: ignore[attr-defined]
o.elem
)
curr_storage = (
None
if not isinstance(o, torch.Tensor)
else StorageWeakRef(o.untyped_storage())
)
outs_with_identical_metadata_that_require_grad = (
[]
if not isinstance(o, Tensor)
else [
curr
for curr in out_storage_to_tensors[curr_storage]
if has_same_metadata(o, curr)
and curr.requires_grad
and o is not curr
]
)
# See Note [Accessing .grad_fn on FunctionalTensor]
# In-place operations on views will trigger a lazy rebase of the autograd graph;
# this runs during access to the .grad_fn. The rebase logic will invoke view ops
# on FunctionalTensors, so we must enable a FunctionalTensorMode here to ensure
# these op calls succeed.
grad_fn = None
if isinstance(o, Tensor):
with FunctionalTensorMode():
grad_fn = o.grad_fn
is_result_of_custom_autograd_fn = False
# Need to check for both custom cpp (CppFunction) and python (BackwardCFunction)
# autograd fns
if type(grad_fn).__name__ == "CppFunction":
is_result_of_custom_autograd_fn = True
if isinstance(grad_fn, torch.autograd.function.BackwardCFunction):
is_result_of_custom_autograd_fn = True
if not isinstance(o, Tensor):
output_type = OutputType.non_alias
base_idx = None
elif (
curr_storage in inp_storage_refs
and grad_fn is not None
and is_result_of_custom_autograd_fn
):
output_type = OutputType.custom_function_view
base_idx = None
elif (
curr_storage in inp_storage_refs
and not functional_tensor_storage_changed
):
base_idx = inp_storage_refs[curr_storage]
is_input_tensor = id(o) in inp_tensor_ids
num_aliased_outs = out_tensor_alias_counts[curr_storage]
num_multi_output_view_outs = (
num_aliased_tensors_that_are_multi_output_views[curr_storage]
)
num_aliased_outs_that_are_not_multi_output_views = (
num_aliased_outs - num_multi_output_view_outs
)
if (
grad_fn is not None
and num_aliased_outs_that_are_not_multi_output_views == 0
):
# See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
# In particular, given:
# def f(x):
# return list(x.unbind(0))
# The main reason we ordinarily try to regenerate these output aliases outside of the
# compiled autograd.Function is because if any of the outputs are later mutated,
# autograd needs to perform view-replay to regenerate them.
# However, autograd does not allow users to mutate multi-output views
# in any way that can change the autograd metadata of other aliases.
# So we hide this aliasing from autograd here.
log.debug(
"Encountered AOTAutograd case: differentiable outputs that \
alias each other from a multi-output view call"
)
output_type = OutputType.non_alias
elif is_input_tensor:
output_type = OutputType.is_input
else:
output_type = OutputType.alias_of_input
elif functional_tensor_storage_changed and id(o) in inp_tensor_ids:
# When there is a set_() on an input, we cannot rely on checking storages
# to detect if we are returning an input (since the input's storage is different)
assert curr_storage is not None
base_idx = inp_storage_refs[curr_storage]
output_type = OutputType.is_input
# We only need to handle the intermediate base case when both
# the intermediate base and the output require gradients.
# See Note [AOT Autograd: outputs aliasing inputs or intermediates!]
elif o._base is not None and o.requires_grad and o._base.requires_grad:
num_aliased_outs = out_tensor_alias_counts[curr_storage]
num_multi_output_view_outs = (
num_aliased_tensors_that_are_multi_output_views[curr_storage]
)
num_aliased_outs_that_are_not_multi_output_views = (
num_aliased_outs - num_multi_output_view_outs
)
# Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
if (
out_tensor_alias_counts[curr_storage] == 1
or num_aliased_outs_that_are_not_multi_output_views <= 1
):
# Note [Intermediate Bases Optimization]
# Normally if we have an output that aliases an intermediate,
# we need to add the extra "intermediate base" logic further down
# to prevent autograd from yelling at us if the user later tries to
# mutate that output.
# However, the common case here is if we have an output that aliases an intermediate,
# but doesn't alias any other outputs.
# In that case, autograd shouldn't have to worry about the aliasing at all
# (if that output is mutated, there are no other live aliases for autograd to worry about).
# The "intermediate bases" can hurt inductor perf by forcing more variables to become outputs.
# So as an optimization, we won't do intermediate base handling in this case.
# Instead, we'll hide the aliasing from autograd using aten._unsafe_view().
if (
out_tensor_alias_counts[curr_storage] != 1
and num_aliased_outs_that_are_not_multi_output_views <= 1
):
log.debug(
"Encountered AOTAutograd case: differentiable outputs that alias each other \
from a multi-output view call"
)
output_type = OutputType.unsafe_view_alias
base_idx = None
else:
# First, check if o's ._base is an existing output
maybe_existing_out_idx = out_tensor_ids.get(id(o._base), None)
if maybe_existing_out_idx is not None:
# Special case where the output is an alias of a graph intermediate, but that intermediate
# is itself also a user output.
output_type = (
OutputType.alias_of_intermediate_base_is_user_output
)
base_idx = maybe_existing_out_idx
else:
# Next, check if o's ._base is an intermediate base that we already returned
maybe_existing_base_output_idx = (
intermediate_base_tensor_id_to_output_idx.get(
id(o._base), None
)
)
if maybe_existing_base_output_idx is not None:
output_type = OutputType.alias_of_intermediate
base_idx = maybe_existing_base_output_idx
else:
# Otherwise, take o._base and explicitly return it as an output in the compiled graph
new_out_idx = len(intermediate_bases)
base_idx = new_out_idx
# Indicate to the logic later on (when we trace the joint)
# that this particular output should get its ._base appended to the forward graph outputs
output_type = (
OutputType.alias_of_intermediate_save_as_output
)
intermediate_base_tensor_id_to_output_idx[
id(o._base)
] = new_out_idx
intermediate_bases.append(o._base)
elif (
# See https://github.com/pytorch/pytorch/issues/100348 for this case.
# This protects against the specific case where a user fn returns (output, output.detach())
out_tensor_alias_counts[curr_storage] > 1
and len(outs_with_identical_metadata_that_require_grad) > 0
and not o.requires_grad
):
# In theory we could use any of these tensors to regenerate the aliased outputs from,
# since they all alias each other and have identical metadata
out_alias = outs_with_identical_metadata_that_require_grad[0]
existing_out_idx = out_tensor_ids[id(out_alias)]
output_type = OutputType.alias_of_intermediate_base_is_user_output
base_idx = existing_out_idx
else:
output_type = OutputType.non_alias
base_idx = None
if isinstance(o, torch.Tensor):
dynamic_dims = {
i for i, s in enumerate(o.shape) if not is_concrete_int(s)
}
else:
dynamic_dims = None
# Save the current FunctionalTensor output.
#
# This will be used at runtime for reconstructing output views from
# their respective base tensors.
#
# The FunctionalTensor will be saved if one of the 2 conditions below
# is true:
functional_tensor = None
if (
# 1. If the output_type is either of:
# (i) alias_of_intermediate;
# (ii) alias_of_intermediate_save_as_output; or
# (iii) alias_of_intermediate_base_is_user_output.
#
# No need to worry about in-place view operations here, since
# this functionalization step eliminates mutations.
#
# i.e. we have access to the actual base tensor, before the
# in-place operation was applied.
output_type
in (
OutputType.alias_of_intermediate,
OutputType.alias_of_intermediate_save_as_output,
OutputType.alias_of_intermediate_base_is_user_output,
)
) or (
# 2. If the output_type is alias_of_input, and no in-place view
# operation was run on the input (base tensor).
#
# In this case, we need to check for metadata mutation because
# the runtime explicitly reconstructs the inputs, before actually
# reconstructing the outputs. Due to in-place view operations, the
# fully reconstructed input may not be this output base tensor
# anymore.
output_type == OutputType.alias_of_input
and base_idx is not None
and not input_info[base_idx].mutates_metadata
):
if isinstance(o, FunctionalTensor):
functional_tensor = FunctionalTensorMetadataEq(o.elem)
out_info = OutputAliasInfo(
output_type=output_type,
raw_type=type(o),
base_idx=base_idx,
dynamic_dims=dynamic_dims,
requires_grad=isinstance(o, torch.Tensor) and o.requires_grad,
functional_tensor=functional_tensor,
)
output_info.append(out_info)
# See Note [AOT Autograd: Views to avoid tangents aliasing inputs]
def view_avoid_dupes_with_primals(t):
if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t):
return transform_subclass(
t, lambda _, inner_t: view_avoid_dupes_with_primals(inner_t)
)
if isinstance(t, Tensor):
return t.view(t.shape)
return t
# This analysis function returns *only* the outputs that are meant to be tangents to the backwards.
# Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates)
# is *regenerated* later, and not used directly in the autograd graph
f_input_tangents = [
inp
for inp, info in zip(flat_f_args, input_info)
if info.mutation_type == MutationType.MUTATED_OUT_GRAPH
and info.mutates_data
and info.requires_grad
]
f_output_tangents = [
o
for o, info in zip(flat_f_outs, output_info)
if info.output_type
in [
OutputType.non_alias,
OutputType.unsafe_view_alias,
OutputType.custom_function_view,
]
and issubclass(info.raw_type, torch.Tensor)
and info.requires_grad
]
# intermediate bases are also included in the backward graph
f_tangents = f_input_tangents + f_output_tangents + intermediate_bases
traced_tangents = pytree.tree_map(from_fun, f_tangents)
traced_tangents = pytree.tree_map(
view_avoid_dupes_with_primals, traced_tangents
)
# See Note [Tangents must be contiguous]
traced_tangents = pytree.tree_map(
coerce_tangent,
traced_tangents,
)
user_outs = pytree.tree_map(from_fun, f_output_tangents)
nonlocal static_input_indices
static_input_indices = static_input_indices or []
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
passed_indices = set(static_input_indices)
static_input_indices = [
i
for i, arg in enumerate(flat_args)
if (isinstance(arg, torch.nn.Parameter) or i in passed_indices)
]
static_input_logger.debug(
"static input indices metadata analysis: %s", static_input_indices
)
f_mutated_inputs = [
inp
for inp, info in zip(flat_f_args, input_info)
if info.mutation_type == MutationType.MUTATED_OUT_GRAPH
]
f_metadata_mutated_inputs = [
inp for inp, info in zip(flat_f_args, input_info) if info.mutates_metadata
]
# This logic (annoyingly) re-figures out exactly what the outputs to the compiled fw graph will be.
# When handling subclasses, we need info about **all** outputs of compiled forward graph,
# so we know precisely which graph outputs to wrap back into tensor subclasses
# Ideally we would refactor this so as not to have an is_train flag, and have the separate
# inference and training paths decide which inputs/outputs to ask for subclass info on.
# However, we currently stash indexing information on each SubclassMeta about its order
# in the graph outputs list.
f_fw_graph_outs = list(flat_f_outs)
if is_train or not keep_input_mutations:
f_fw_graph_outs = f_mutated_inputs + f_fw_graph_outs
else:
# even when "keep_input_mutations" is True,
# we never keep metadata-only mutations in the fw graph
f_fw_graph_outs = f_metadata_mutated_inputs + f_fw_graph_outs
if is_train:
f_fw_graph_outs = f_fw_graph_outs + intermediate_bases
fw_graph_outs = pytree.tree_map(from_fun, f_fw_graph_outs)
grad_enabled_mutation = None
if torch.is_grad_enabled() != prior_grad_enabled:
grad_enabled_mutation = torch.is_grad_enabled()
torch.set_grad_enabled(
prior_grad_enabled
) # Restore the prior state after tracing it
log.debug(
(
"grad_mode mutation encountered in graph. "
"Will emit mutation epilogue, to set grad_mode=%s"
),
grad_enabled_mutation,
)
metadata = ViewAndMutationMeta(
input_info=input_info,
output_info=output_info,
num_intermediate_bases=len(intermediate_bases),
keep_input_mutations=keep_input_mutations,
traced_tangents=traced_tangents,
subclass_inp_meta=create_subclass_meta(flat_args),
subclass_fw_graph_out_meta=create_subclass_meta(fw_graph_outs),
subclass_tangent_meta=create_subclass_meta(traced_tangents),
is_train=is_train,
grad_enabled_mutation=grad_enabled_mutation,
static_input_indices=static_input_indices,
tokens=mode._tokens,
)
return metadata
return inner
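# Hedged illustration (not part of the original module) of the eager-mode behavior that
# Note [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
# relies on: mutating one output of a multi-output view op in place is rejected by autograd,
# which is why hiding that aliasing from autograd in the compiled path is safe. The helper
# name is hypothetical.
def _example_multi_output_view_mutation_error() -> None:
    intermediate = torch.ones(3, 2, requires_grad=True).mul(2)
    o1, o2, o3 = intermediate.unbind(0)
    try:
        o1.mul_(2)  # case (a) from the note above
    except RuntimeError:
        # "Output 0 of UnbindBackward0 is a view and is being modified inplace. ..."
        pass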


@@ -0,0 +1,314 @@
# mypy: allow-untyped-defs
"""
This module dispatches the graphs to either the forward-only or joint compilation
pathways, taking into account the AOTConfig and the collected ViewAndMutationMetadata.
"""
import dataclasses
from typing import Any, List, Optional, Tuple
import torch
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import lazy_format_graph_code
from torch._logging import getArtifactLogger, trace_structured
from torch._subclasses.functional_tensor import FunctionalTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._python_dispatch import _detect_infra_mode
from .. import config
from .functional_utils import (
assert_functional_graph,
propagate_input_mutation_stacktraces,
)
from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta
from .traced_function_transforms import (
aot_dispatch_subclass,
create_functionalized_fn,
create_joint,
fn_input_mutations_to_outputs,
fn_prepped_for_autograd,
handle_effect_tokens_fn,
)
from .utils import (
copy_fwd_metadata_to_bw_nodes,
root_module_when_exporting_non_strict,
unlift_tokens,
)
aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")
def _create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule:
# FunctionalTensorMode must be enabled here.
# See Note [Accessing .grad_fn on FunctionalTensor]
with enable_python_dispatcher(), FunctionalTensorMode(
pre_dispatch=aot_config.pre_dispatch,
export=aot_config.is_export,
# Allow token discovery for joint fn tracing as tokens can be used in backward.
_allow_token_discovery=True,
):
fx_g = make_fx(
f,
decomposition_table=aot_config.decompositions,
record_module_stack=True,
pre_dispatch=aot_config.pre_dispatch,
)(*args)
return fx_g
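# Hedged illustration (not part of the original module): _create_graph is a thin wrapper
# around make_fx under FunctionalTensorMode. For intuition, a bare make_fx trace of a
# trivial function (no AOTConfig involved) looks like this; the input shape is arbitrary.
def _example_bare_make_fx_trace() -> torch.fx.GraphModule:
    def f(x):
        return x.sin().cos()

    gm = make_fx(f)(torch.randn(3))
    # gm.graph now contains aten.sin / aten.cos call_function nodes.
    return gm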
def aot_dispatch_base_graph(
flat_fn,
flat_args: List[Tensor],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
) -> Tuple[torch.fx.GraphModule, List[Any], Optional[SubclassMeta]]:
# aot_dispatch_base requires functionalization, but doesn't need to handle as many cases as the autograd case.
# The cases that aot_dispatch_base doesn't need to handle include:
# - outputs that are aliases of graph intermediates
# - outputs that are aliases of graph inputs
# While cases that it does need to handle include:
# - input mutations (including when inputs are aliases of each other)
# - input metadata mutations
fn_to_trace = fn_input_mutations_to_outputs(
flat_fn,
fw_metadata,
keep_data_input_mutations=aot_config.keep_inference_input_mutations,
)
fn_to_trace, updated_flat_args = create_functionalized_fn(
fn_to_trace,
flat_args,
meta=fw_metadata,
aot_config=aot_config,
trace_joint=False,
)
# TODO: replace with AOTDispatchSubclassWrapper once we refactor
# fn_input_mutations_to_outputs and create_functionalized_fn
# into CompilerWrappers.
(
fn_to_trace,
updated_flat_args_subclasses_desugared,
maybe_subclass_meta,
) = aot_dispatch_subclass(
fn_to_trace,
updated_flat_args,
is_joint_structure=False,
meta=fw_metadata,
fw_only=flat_fn,
)
(fn_to_trace, updated_flat_args_subclasses_desugared) = handle_effect_tokens_fn(
fn_to_trace,
updated_flat_args_subclasses_desugared,
meta=fw_metadata,
trace_joint=False,
)
aot_graphs_log.debug(
"aot_config id: %s, fw_metadata=%s,subclass_metadata=%s",
str(aot_config.aot_id),
str(fw_metadata),
str(maybe_subclass_meta),
)
# We track buffer assignments when exporting in non-strict mode.
# (In contrast, strict mode errors on any attribute assignment.)
mod_when_exporting_non_strict = root_module_when_exporting_non_strict(flat_fn)
if aot_config.is_export and mod_when_exporting_non_strict is not None:
# For any buffer that is assigned, we want to associate it to the final proxy node
# that it is assigned to. This node can then be added as a buffer mutation output.
assigned_buffers = {}
def _map_assigned_buffer_to_proxy(_mod, name, buffer):
# We intercept buffer assignments on the root module through this hook.
if _mod._buffers is mod_when_exporting_non_strict._buffers:
# The value assigned to a buffer is a functional tensor, which wraps a fake tensor.
assert isinstance(
buffer, torch._subclasses.functional_tensor.FunctionalTensor
)
fake = buffer.from_functional()
# The fake tensor in turn is associated with a proxy node.
proxy_mode = _detect_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
assert proxy_mode is not None
proxy = torch.fx.experimental.proxy_tensor.get_proxy_slot(
fake, proxy_mode.tracer
).proxy.node
# We map the assigned buffer to this proxy node.
assigned_buffers[name] = proxy.name
return buffer
handle = torch.nn.modules.module.register_module_buffer_registration_hook(
_map_assigned_buffer_to_proxy
)
saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only(
torch.Tensor, lambda t: t.detach(), updated_flat_args_subclasses_desugared
)
fw_module = _create_graph(
fn_to_trace,
updated_flat_args_subclasses_desugared,
aot_config=aot_config,
)
if aot_config.is_export and mod_when_exporting_non_strict is not None:
# We update metadata to consider any assigned buffers as buffer mutations.
i = len(dict(mod_when_exporting_non_strict.named_parameters()))
for name, _ in mod_when_exporting_non_strict.named_buffers():
if name in assigned_buffers and not fw_metadata.input_info[i].mutates_data: # type: ignore[possibly-undefined]
fw_metadata.input_info[i] = dataclasses.replace(
fw_metadata.input_info[i], mutates_data=True
)
fw_metadata.num_mutated_inp_runtime_indices += 1
i += 1
# We add nodes corresponding to buffer assignments as output nodes in the graph.
add_nodes = []
output_node = None
output_node = list(fw_module.graph.nodes)[-1]
for name in assigned_buffers.values(): # type: ignore[possibly-undefined]
for node in fw_module.graph.nodes:
if node.name == name:
add_nodes.append(node)
node.users[output_node] = None
output_node.args = ((*add_nodes, *output_node.args[0]),)
handle.remove() # type: ignore[possibly-undefined]
# As long as we opted to remove input mutations, then
# there should be *NO* mutating ops in the graph at this point.
copy_count = assert_functional_graph(fw_module.graph)
fw_module.graph.eliminate_dead_code()
fw_module.recompile()
copy_count2 = assert_functional_graph(fw_module.graph)
propagate_input_mutation_stacktraces(fw_module.graph)
# See Note [Side-Effectful Tokens in AOTAutograd]
num_tokens = len(fw_metadata.tokens)
if num_tokens != 0 and config.unlift_effect_tokens:
unlift_tokens(fw_module, fw_metadata, aot_config)
saved_updated_flat_args_subclasses_desugared = (
saved_updated_flat_args_subclasses_desugared[num_tokens:]
)
assert copy_count == copy_count2
if aot_config.enable_log:
aot_graphs_log.info(
"%s",
lazy_format_graph_code(
"Forward graph",
fw_module,
aot_config.aot_id,
include_stride=True,
include_device=True,
colored=True,
),
)
trace_structured(
"aot_forward_graph",
payload_fn=lambda: fw_module.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
# TODO: should factor this into a separate function for export that always returns just the graph.
if aot_config.is_export:
assert (
maybe_subclass_meta is None
), "aot_export_module does not support tensor subclass inputs for now."
return fw_module, saved_updated_flat_args_subclasses_desugared, maybe_subclass_meta
# Has the precondition that there are no duplicate arguments in flat_args
# (e.g., the same Tensor object never shows up twice). However, two tensor
# inputs MAY alias the same storage, so long as they have separate TensorImpls.
def aot_dispatch_autograd_graph(
flat_fn,
flat_args: List[Any],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
) -> Tuple[torch.fx.GraphModule, Tuple[List[Any], List[Any]], Optional[SubclassMeta]]:
# traced_tangents corresponds to the set of outputs in the traced forward that should get grad_outputs in the traced backward.
# It includes outputs of the original forward, *and* any updated inputs due to input mutations.
# However, it does *not* include any outputs that are aliases of inputs or intermediates, or any metadata-only input mutations.
joint_inputs = (flat_args, fw_metadata.traced_tangents)
fn_prepared_for_autograd = fn_prepped_for_autograd(
flat_fn,
fw_metadata,
)
joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config)
joint_fn_to_trace, updated_joint_inputs = create_functionalized_fn(
joint_fn_to_trace,
joint_inputs,
meta=fw_metadata,
aot_config=aot_config,
trace_joint=True,
)
# TODO: replace with AOTDispatchSubclassWrapper once we refactor
# fn_input_mutations_to_outputs and create_functionalized_fn
# into CompilerWrappers.
subclass_tracing_info = aot_dispatch_subclass(
joint_fn_to_trace,
updated_joint_inputs,
is_joint_structure=True,
meta=fw_metadata,
fw_only=flat_fn,
)
joint_fn_to_trace = subclass_tracing_info.plain_tensor_trace_fn
updated_joint_inputs = subclass_tracing_info.plain_tensor_args
(joint_fn_to_trace, updated_joint_inputs) = handle_effect_tokens_fn(
joint_fn_to_trace,
updated_joint_inputs,
meta=fw_metadata,
trace_joint=True,
)
# When we call _create_graph, this may mutate the metadata of joint
# inputs. But callers are expecting to get the original joint inputs. So
# we make aliases of all the inputs to make sure we have a copy that
# doesn't get modified.
#
# This destroys requires_grad/grad_fn information. However, backends
# beneath AOTAutograd are indifferent to this information, so it doesn't
# matter.
saved_updated_joint_inputs = pytree.tree_map_only(
torch.Tensor, lambda t: t.detach(), updated_joint_inputs
)
maybe_subclass_meta = subclass_tracing_info.maybe_subclass_meta
fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config)
# There should be *NO* mutating ops in the graph at this point.
assert_functional_graph(fx_g.graph)
# Redundant with the check above, but worth having in case tracing introduced
# a fake tensor. Unlikely.
# See Note: [Fake Modules and AOTAutograd]
torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)
fx_g.graph.eliminate_dead_code()
copy_fwd_metadata_to_bw_nodes(fx_g)
fx_g.recompile()
# TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect
# when we need to manually detach() some inputs in the forward.
# Higher order ops might eventually need to do the same.
if aot_config.is_export:
assert (
maybe_subclass_meta is None
), "aot_export_module does not support tensor subclass inputs for now."
return fx_g, saved_updated_joint_inputs, maybe_subclass_meta

View File

@ -0,0 +1,494 @@
# mypy: allow-untyped-defs
"""
This file contains utilities related to functionalization in AOTAutograd:
1. converting to/from functional tensors
2. detecting Tensor mutations - both metadata and Tensor value
3. regenerating/replaying views from their base
4. checking if a graph is functional i.e. whether it contains any mutation ops
"""
from __future__ import annotations
from typing import Optional
import torch
from torch import Tensor
from torch._logging import getArtifactLogger
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx.experimental.symbolic_shapes import definitely_true, sym_eq
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
is_traceable_wrapper_subclass,
transform_subclass,
)
aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")
def to_fun(t):
if isinstance(t, Tensor):
if is_traceable_wrapper_subclass(t):
# See Note [Functionalization always runs last]
# This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper
# goes at the bottom.
# recurse here, so we can support nested wrapper subclasses
out = transform_subclass(t, lambda _, inner_t: to_fun(inner_t))
torch._mirror_autograd_meta_to(t, out) # type: ignore[attr-defined]
return out
else:
return FunctionalTensor.to_functional(t)
else:
return t
def sync_functional_tensor(t):
if is_traceable_wrapper_subclass(t):
attrs, ctx = t.__tensor_flatten__() # type: ignore[attr-defined]
for attr in attrs:
sync_functional_tensor(getattr(t, attr))
else:
torch._sync(t)
# When subclasses are involved, t here will usually look something like:
# SubclassA(SubclassB(FunctionalTensor(_to_functional_tensor(FakeTensor))))
def from_fun(t):
if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t):
# See Note [Functionalization always runs last]
# This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper
# goes at the bottom.
# recurse here, so we can support nested wrapper subclasses
out = transform_subclass(t, lambda _, inner_t: from_fun(inner_t))
torch._mirror_autograd_meta_to(t, out) # type: ignore[attr-defined]
return out
if not isinstance(t, FunctionalTensor):
# quick sanity assert
if isinstance(t, torch.Tensor):
assert not torch._is_functional_tensor(t) # type: ignore[attr-defined]
return t
sync_functional_tensor(t)
return torch._from_functional_tensor(t.elem)
def is_fun(t):
if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t):
# See Note [Functionalization always runs last]
# This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper
# goes at the bottom.
# recurse here, so we can support nested wrapper subclasses
t_attrs, _ = t.__tensor_flatten__() # type: ignore[attr-defined]
t_inners = [getattr(t, attr) for attr in t_attrs]
any_fun = any(is_fun(x) for x in t_inners)
all_fun = all(is_fun(x) for x in t_inners)
assert any_fun == all_fun
return any_fun
return isinstance(t, FunctionalTensor)
# t here is either
# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor))
# (2) A traceable tensor subclass that holds a FunctionalTensor
# (3) Not a tensor
def has_data_mutation(t):
if is_traceable_wrapper_subclass(t):
attrs, _ = t.__tensor_flatten__()
# A tensor subclass was updated if any of its inner elements were updated
return any(has_data_mutation(getattr(t, attr)) for attr in attrs)
else:
if isinstance(t, torch.Tensor):
assert isinstance(t, FunctionalTensor)
return torch._functionalize_has_data_mutation(t.elem) # type: ignore[attr-defined]
return False
def are_all_mutations_hidden_from_autograd(t):
if is_traceable_wrapper_subclass(t):
attrs, _ = t.__tensor_flatten__()
# If all inner elements are mutations hidden from autograd, then it is a mutation hidden from autograd.
return all(
are_all_mutations_hidden_from_autograd(getattr(t, attr)) for attr in attrs
)
elif isinstance(t, torch.Tensor):
assert isinstance(t, FunctionalTensor)
return torch._functionalize_are_all_mutations_hidden_from_autograd(t.elem)
else:
return False
def are_all_mutations_under_no_grad_or_inference_mode(t):
if is_traceable_wrapper_subclass(t):
attrs, _ = t.__tensor_flatten__()
return all(
are_all_mutations_under_no_grad_or_inference_mode(getattr(t, attr))
for attr in attrs
)
else:
assert isinstance(t, FunctionalTensor)
return torch._functionalize_are_all_mutations_under_no_grad_or_inference_mode(
t.elem
)
def was_inductor_storage_resized(t):
if is_traceable_wrapper_subclass(t):
attrs, _ = t.__tensor_flatten__()
if any(was_inductor_storage_resized(getattr(t, attr)) for attr in attrs):
raise RuntimeError(
f"storage resizing is not supported on tensor subclass: {type(t)}"
)
elif not isinstance(t, torch.Tensor):
return False
else:
assert isinstance(t, FunctionalTensor)
return torch._functionalize_was_inductor_storage_resized(t.elem)
# f_arg here is either
# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor))
# (2) A traceable tensor subclass that holds a FunctionalTensor
# (3) Not a tensor
# Assumption: arg promises to be the "original" tensor wrapped by f_arg
# Note: "storage mutations" coming from set_() are a type of metadata mutation. So:
# - check_only_storage_mutation=True: only return true if there was a storage mutation
# - check_only_storage_mutation=False: return true if there was any metadata mutation (including a storage mutation)
def has_metadata_mutation(f_arg, arg, *, check_only_storage_mutation: bool):
if is_traceable_wrapper_subclass(f_arg):
attrs, _ = f_arg.__tensor_flatten__()
# A tensor subclass was updated if any of its inner elements were updated
f_inner_ts = [getattr(f_arg, attr) for attr in attrs]
inner_ts = [getattr(arg, attr) for attr in attrs]
return any(
has_metadata_mutation(
f_inner_t,
inner_t,
check_only_storage_mutation=check_only_storage_mutation,
)
for f_inner_t, inner_t in zip(f_inner_ts, inner_ts)
)
else:
if not isinstance(f_arg, torch.Tensor):
assert not isinstance(arg, torch.Tensor)
return False
assert isinstance(f_arg, FunctionalTensor)
assert isinstance(arg, FakeTensor)
arg_after = torch._from_functional_tensor(f_arg.elem)
# This is true if the current tensor experienced at least one set_() call
maybe_storage_changed = torch._functionalize_was_storage_changed(f_arg.elem) # type: ignore[attr-defined]
# However, multiple set_() calls can cancel out. So we also check whether the
# storage of the tensor has changed.
# Note: if an input experienced two set_() calls that cancel out, **and**
# it experiences a data mutation, we pessimistically think that the set_()
# call is necessary here. We could in theory fix this, but this will
# hopefully never happen in user code, and is not needed for fsdp.
if is_sparse_any(arg):
# TODO: add sparse tensor support to functionalization
same_storages = False
else:
same_storages = StorageWeakRef(arg.untyped_storage()) == StorageWeakRef(
arg_after.untyped_storage()
)
has_storage_metadata_mutation = maybe_storage_changed and not same_storages
if check_only_storage_mutation:
return has_storage_metadata_mutation
# storage metadata mutation is a type of metadata mutation, so return true if we saw one
if has_storage_metadata_mutation:
return True
maybe_metadata_mutated = torch._functionalize_has_metadata_mutation(f_arg.elem) # type: ignore[attr-defined]
# This is true if the current tensor experienced at least one metadata mutation.
# So if false, we know there was no metadata mutation
if not maybe_metadata_mutated:
return False
# However, multi metadata mutations can cancel out.
# So we also check if the concrete sizes/strides on the tensor have changed.
same_sizes = arg.shape == arg_after.shape
same_strides = arg.stride() == arg_after.stride()
same_offsets = arg.storage_offset() == arg_after.storage_offset()
has_metadata_mutation_ = maybe_metadata_mutated and not (
same_sizes and same_strides and same_offsets
)
# We consider a tensor to have been metadata mutated if its storage was mutated through a set_() call.
return has_metadata_mutation_
def gen_alias_from_base(
aliased_base_tensor,
target_meta_tensor,
target_requires_grad,
target_functional_tensor: Optional[FunctionalTensorMetadataEq] = None,
*,
replay_views,
):
# Patch the correct requires_grad field of the output tensor, depending on whether:
# (i) the reconstructed output (out) came from a tensor that requires grad or not;
# and (ii) the concrete returned output does require grad or not.
def patch_requires_grad(out):
if aliased_base_tensor.requires_grad and not target_requires_grad:
out = out.detach()
elif not aliased_base_tensor.requires_grad and target_requires_grad:
out.requires_grad_(True)
return out
# If provided, use the target functional tensor for replaying the views.
#
# In summary, we use the fact that FunctionalTensorWrapper saves the view
# functions applied to itself (collected during functionalization) so as
# to replay them (view functions) on the aliased_base_tensor.
if (
replay_views
and target_functional_tensor is not None
and not torch._functionalize_is_symbolic(target_functional_tensor.tensor)
):
functional_tensor = target_functional_tensor.tensor
out = torch._functionalize_apply_view_metas(
functional_tensor, aliased_base_tensor
)
# If re-applying the ViewMeta sequence succeeded, there should be no more
# problems going forward. We just check we got to the target shape and
# patch requires_grad flag.
assert out.shape == target_meta_tensor.shape, (
"incorrect out shape after application of ViewMeta sequence: "
f"{tuple(out.shape)} (actual) vs {tuple(target_meta_tensor.shape)} (expected)"
)
return patch_requires_grad(out)
# Try to do view-replay if possible.
# fall back to .as_strided() if we can't.
if target_meta_tensor._base is not None:
# The base that we want to replay our view off of might have a different shape than the view's original base.
b = target_meta_tensor._base
abt = aliased_base_tensor
# Don't unnecessarily call as_strided if nothing changed; as_strided's
# backward is poorly implemented and slow
if abt is not b and (
abt.size() != b.size()
or abt.stride() != b.stride()
or abt.storage_offset() != b.storage_offset()
):
reshaped_base_tensor = aliased_base_tensor.as_strided(
b.size(), b.stride(), b.storage_offset()
)
else:
reshaped_base_tensor = aliased_base_tensor
out = target_meta_tensor._view_func(reshaped_base_tensor)
# This shape mismatch can happen due to a bug in inplace/view handling in autograd.
# Try putting a breakpoint here and running
# `test/functorch/test_aotdispatch TestAOTAutograd.test_output_all_alias_types`
# Also, https://github.com/pytorch/pytorch/issues/49825
#
# As a stopgap, we'll fall back to as_strided.
if out is not None and out.shape == target_meta_tensor.shape:
return patch_requires_grad(out)
size = target_meta_tensor.size()
stride = target_meta_tensor.stride()
storage_offset = target_meta_tensor.storage_offset()
if aliased_base_tensor.is_complex() and not target_meta_tensor.is_complex():
aliased_out = torch.view_as_real(aliased_base_tensor).as_strided(
size, stride, storage_offset
)
elif not aliased_base_tensor.is_complex() and target_meta_tensor.is_complex():
aliased_out = torch.view_as_complex(aliased_base_tensor).as_strided(
size, stride, storage_offset
)
else:
aliased_out = aliased_base_tensor.as_strided(size, stride, storage_offset)
# For outputs aliasing inputs, we need to check if the requires-gradness has changed.
aliased_out = patch_requires_grad(aliased_out)
# For outputs aliasing inputs, we need to check if the dtype has changed.
# as_strided() is the "most generic" view, but it does not cover cross-dtype views
if aliased_out.dtype != target_meta_tensor.dtype:
aliased_out = aliased_out.view(target_meta_tensor.dtype)
return aliased_out
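# Hedged illustration (added example, not part of the original file; the helper name
# below is made up): a minimal sketch of regenerating a view from its base with
# gen_alias_from_base, going through the `_view_func` replay path on the target's
# base (replay_views=False, no target functional tensor provided).
def _example_gen_alias_from_base():
    base = torch.arange(16.0).reshape(4, 4)
    target_view = base[1:3, 1:3]  # the alias we want to regenerate off of `base`
    out = gen_alias_from_base(
        base, target_view, target_requires_grad=False, replay_views=False
    )
    # The regenerated alias has the target metadata and shares storage with `base`.
    assert out.shape == target_view.shape
    assert out.data_ptr() == target_view.data_ptr()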
def has_same_metadata(t1, t2):
return (
definitely_true(sym_eq(t1.size(), t2.size()))
and definitely_true(t1.layout == t2.layout)
and (
is_sparse_any(t1)
or (
definitely_true(sym_eq(t1.stride(), t2.stride()))
and definitely_true(t1.storage_offset() == t2.storage_offset())
)
)
and t1.is_conj() == t2.is_conj()
and t1.is_neg() == t2.is_neg()
)
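# Hedged illustration (added example, not part of the original file): two views with
# identical sizes/strides/storage offsets compare equal under has_same_metadata,
# while a transposed view of the same base does not.
def _example_has_same_metadata():
    base = torch.arange(16.0).reshape(4, 4)
    assert has_same_metadata(base[1:3], base.narrow(0, 1, 2))
    assert not has_same_metadata(base, base.t())  # strides differ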
# Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata
# after applying all the ViewMeta operations.
class FunctionalTensorMetadataEq:
def __init__(self, tensor: torch.Tensor) -> None:
assert torch._is_functional_tensor(tensor)
self.tensor = tensor
def __eq__(self, other: object) -> bool:
# If other is None, then it probably means that we weren't able to recreate
# the FunctionalTensorMetadataEq. One of these cases is when we update the
# view metadata by calling: create_synthetic_base_metadata.
if other is None:
return True
# Comparison against any other type is not implemented.
if not isinstance(other, FunctionalTensorMetadataEq):
return NotImplemented
return has_same_metadata(self.tensor, other.tensor)
# new_arg and arg here are either:
# (1) both a FakeTensor
# (2) both a traceable tensor subclass that holds a FakeTensor
# Pre-condition: the two args are the "old" and "new" inputs from running functionalization.
# When we run functionalization and wrap our inputs into FunctionalTensors,
# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed
#
# Normally it would be enough just to check if arg is new_arg, which is normally enough for functionalization
# to confirm that inputs were not mutated when running the user's model with functionalization on.
# But when we have subclass inputs, we can't rely on that:
# `from_fun(to_fun(x)) is x` will return False, because the call to `from_fun` constructs
# a brand new subclass instance: we are calling __tensor_unflatten__, and going
# from Subclass(FakeTensor) to Subclass(FunctionalTensor(FakeTensor))
def was_tensor_updated(arg, new_arg):
if is_traceable_wrapper_subclass(arg):
assert is_traceable_wrapper_subclass(new_arg)
attrs, _ = arg.__tensor_flatten__()
new_attrs, _ = new_arg.__tensor_flatten__()
assert attrs == new_attrs
# A tensor subclass was updated if any of its inner elements were updated
return any(
was_tensor_updated(getattr(arg, attr), getattr(new_arg, attr))
for attr in attrs
)
else:
return arg is not new_arg
# new_arg and arg here are either:
# (1) both a FakeTensor
# (2) both a traceable tensor subclass that holds a FakeTensor
# Pre-condition: the two args are the "old" and "new" inputs from running functionalization.
# When we run functionalization and wrap our inputs into FunctionalTensors,
# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed,
# but shares storage with the old input
def was_tensor_metadata_updated(arg, new_arg):
if is_traceable_wrapper_subclass(arg):
assert is_traceable_wrapper_subclass(new_arg)
attrs, _ = arg.__tensor_flatten__()
new_attrs, _ = new_arg.__tensor_flatten__()
assert attrs == new_attrs
# A tensor subclass was updated if any of its inner elements were updated
return any(
was_tensor_metadata_updated(getattr(arg, attr), getattr(new_arg, attr))
for attr in attrs
)
else:
return arg is not new_arg and StorageWeakRef(
arg.untyped_storage()
) == StorageWeakRef(new_arg.untyped_storage())
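# Hedged illustration (added example; plain dense tensors stand in for the
# FakeTensors these helpers see in practice):
def _example_was_tensor_updated():
    a = torch.ones(2)
    assert not was_tensor_updated(a, a)  # same object: no update detected
    assert was_tensor_updated(a, a.clone())  # different object, different storage
    # A fresh wrapper over the *same* storage counts as a metadata-only update.
    assert was_tensor_metadata_updated(a, a.view(1, 2))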
# Returns the number of detected copy_
def assert_functional_graph(fx_g: torch.fx.Graph) -> int:
allowed_mutation_ops = [
torch.ops.aten.copy_.default,
torch.ops.aten.set_.source_Tensor,
]
if hasattr(torch.ops.fsdp, "set_"):
allowed_mutation_ops.append(torch.ops.fsdp.set_.default)
placeholders = set()
mutation_count = 0
# NB: It would also be nice to verify that the mutations all happen at the
# end, but we also do some administrative views after mutations so this
# isn't actually true. (TODO: Could this cause problems for Inductor?)
for n in fx_g.nodes:
if n.op == "placeholder":
placeholders.add(n)
if isinstance(n.target, torch._ops.OpOverload):
if n.target in allowed_mutation_ops:
suffix = True
# Can only copy_/set_ into an input
# this is mostly a hack to avoid failing XLA tests.
# See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113
if "set_buffer_donor_" not in str(n.args[0]):
assert (
n.args[0] in placeholders
), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
mutation_count += 1
else:
assert (
not n.target._schema.is_mutable
), f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}"
return mutation_count
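# Hedged illustration (added example, not part of the original file): a tiny
# hand-built FX graph whose only mutation is a copy_ into a placeholder passes the
# check, and the returned count reflects that single copy_.
def _example_assert_functional_graph():
    g = torch.fx.Graph()
    x = g.placeholder("x")
    y = g.placeholder("y")
    copied = g.call_function(torch.ops.aten.copy_.default, (x, y))
    g.output((copied,))
    assert assert_functional_graph(g) == 1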
def propagate_input_mutation_stacktraces(fx_g: torch.fx.Graph) -> None:
placeholders = set()
for n in fx_g.nodes:
if n.op == "placeholder":
placeholders.add(n)
if isinstance(n.target, torch._ops.OpOverload):
if n.target is torch.ops.aten.copy_.default:
# Can only copy_ into an input, and can only do so once
if "set_buffer_donor_" not in str(n.args[0]):
assert (
n.args[0] in placeholders
), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
placeholders.remove(n.args[0])
copy_from_node = n.args[1]
# Pre-condition: every node has a "stack_trace" field in its meta,
# but copy_() nodes do not (since we manually added them during functionalization).
# Instead, we manually propagate here.
if "stack_trace" in copy_from_node.meta:
n.meta["stack_trace"] = copy_from_node.meta["stack_trace"]
def _check_if_mutation_can_be_in_graph(
keep_input_mutations: bool,
mutates_data,
mutates_metadata,
mutations_hidden_from_autograd,
mutations_under_no_grad_or_inference_mode,
mutates_storage_metadata,
mutation_inductor_storage_resize,
requires_grad,
):
if keep_input_mutations:
in_graph = (
mutates_data or mutates_storage_metadata or mutation_inductor_storage_resize
) and (
(not mutates_metadata and not requires_grad)
or mutations_hidden_from_autograd
or mutations_under_no_grad_or_inference_mode
)
else:
in_graph = False
# See Note [set_() Input Mutations in AOTAutograd]
# If there was a `set_()`, we require that all mutations were under no_grad,
# so we can (safely) emit the set_() in the graph at runtime
# resize_() gets the same treatment
if mutation_inductor_storage_resize or mutates_storage_metadata:
op_name = "resize_" if mutation_inductor_storage_resize else "set_"
assert in_graph, f"""\
Encountered a {op_name} on a graph input, but the input has other mutations that we cannot
keep in the graph. This is not supported today. Current state:
keep_input_mutations={keep_input_mutations}
mutates_data={mutates_data}
mutates_metadata={mutates_metadata}
mutations_hidden_from_autograd={mutations_hidden_from_autograd}
mutations_under_no_grad_or_inference_mode={mutations_under_no_grad_or_inference_mode}
mutation_inductor_storage_resize={mutation_inductor_storage_resize}
requires_grad={requires_grad}"""
return in_graph
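# Hedged illustration (added example): a plain data mutation on an input that does
# not require grad can be kept in the graph when keep_input_mutations is set.
def _example_check_if_mutation_can_be_in_graph():
    assert _check_if_mutation_can_be_in_graph(
        keep_input_mutations=True,
        mutates_data=True,
        mutates_metadata=False,
        mutations_hidden_from_autograd=False,
        mutations_under_no_grad_or_inference_mode=False,
        mutates_storage_metadata=False,
        mutation_inductor_storage_resize=False,
        requires_grad=False,
    )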

View File

@ -0,0 +1,493 @@
# mypy: allow-untyped-defs
"""
This module is one of the analysis modules - it takes as input a function or graph
and some preexisting properties, and returns some data that is useful for deciding
how to further proceed with compilation or construct runtime wrappers.
In particular, the following analyses are provided:
1. Refine the view and mutation metadata collected previously - removing duplicate
inputs or mapping views to their bases.
2. We also analyze the function signature for export graphs.
"""
import itertools
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.utils._pytree as pytree
from torch import Tensor
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.fx.experimental.symbolic_shapes import is_concrete_int
from .. import config
from .collect_metadata_analysis import coerce_tangent
from .schemas import (
BackwardSignature,
GraphSignature,
InputAliasInfo,
OutputAliasInfo,
OutputType,
ViewAndMutationMeta,
)
from .utils import strict_zip
zip = strict_zip
def remove_dupe_metadata(
m: ViewAndMutationMeta,
keep_arg_mask: List[bool],
add_dupe_map: List[int],
) -> ViewAndMutationMeta:
assert len(m.input_info) == len(keep_arg_mask)
# Easy invariant: the first argument should never be a dupe (it will be kept)
assert len(keep_arg_mask) > 0 and keep_arg_mask[0]
# Filter dupe'd mutated inputs out of traced_tangents
num_data_mutations = len([x for x in m.input_info if x.mutates_data])
other_traced_tangents = m.traced_tangents[num_data_mutations:]
inp_traced_tangents = m.traced_tangents[:num_data_mutations]
filtered_inp_traced_tangents = [
# See Note [Tangents must be contiguous]
x
for i, x in enumerate(inp_traced_tangents)
if keep_arg_mask[m.mutated_inp_runtime_indices[i]]
]
traced_tangents = filtered_inp_traced_tangents + other_traced_tangents
return ViewAndMutationMeta(
input_info=[x for i, x in enumerate(m.input_info) if keep_arg_mask[i]],
# For outputs that are views of inputs, we store the index of the input that the output
# was generated from. Need to update that index to account for removed dupes.
output_info=[
OutputAliasInfo(
output_type=o.output_type,
raw_type=o.raw_type,
dynamic_dims=o.dynamic_dims,
base_idx=None if o.base_idx is None else add_dupe_map[o.base_idx],
requires_grad=o.requires_grad,
functional_tensor=o.functional_tensor,
)
for o in m.output_info
],
num_intermediate_bases=m.num_intermediate_bases,
keep_input_mutations=m.keep_input_mutations,
traced_tangents=traced_tangents,
# We are guaranteed not to get here, since dupes are not supported today with subclass inputs.
subclass_inp_meta=[],
subclass_fw_graph_out_meta=[],
subclass_tangent_meta=[],
is_train=m.is_train,
)
# Given our ViewAndMutation metadata, this fn constructs a new set of metadata,
# after adding synthetic base arguments to the function.
# Most of the work in this fn is slogging through all of the metadata corresponding to inputs,
# and updating it with our synthetic base calling convention.
#
# When config.debug_assert is set, we automatically regenerate the metadata
# and compare it to this output for sanity.
#
# In addition to the updated metadata, also return the list of input indices
# that will need to be updated in the synthetic base epilogue
def create_synthetic_base_metadata(
m: ViewAndMutationMeta,
# Maps each outer argument idx to its inner idx (or, if this outer arg is generated from a
# synthetic base, you get a tuple of (i, TensorMeta), telling you the base tensor idx, and view metadata)
synthetic_base_info: List[Union[int, Tuple[int, torch.Tensor]]],
outer_args: List[Any],
inner_args: List[Any],
) -> Tuple[ViewAndMutationMeta, List[int]]:
# maps inner arg indices to outer arg indices
synthetic_base_to_indices: Dict[int, List[int]] = {}
for inner_idx in range(len(inner_args)):
outer_aliased_indices_of_current_base_arg = [
outer_idx
for outer_idx, inner_idx_or_tuple in enumerate(synthetic_base_info)
if (isinstance(inner_idx_or_tuple, int) and inner_idx_or_tuple == inner_idx)
or (
isinstance(inner_idx_or_tuple, tuple)
and inner_idx_or_tuple[0] == inner_idx
)
]
synthetic_base_to_indices[inner_idx] = outer_aliased_indices_of_current_base_arg
# given the requires_grad info on mutated inputs,
# generate the requires_grad info on those same mutated inputs, but after constructing synthetic bases.
input_infos = []
for outer_indices in synthetic_base_to_indices.values():
# leaf-ness should be all-or-nothing for aliased tensors.
# (aka if "a" and "b" are views, then a.is_leaf == b.is_leaf)
any_leaf = any(m.input_info[x].is_leaf for x in outer_indices)
all_leaf = all(m.input_info[x].is_leaf for x in outer_indices)
assert any_leaf == all_leaf
mutates_data = (
True
if len(outer_indices) > 1
else m.input_info[outer_indices[0]].mutates_data
)
mutates_metadata = (
False
if len(outer_indices) > 1
else m.input_info[outer_indices[0]].mutates_metadata
)
requires_grad = any(m.input_info[x].requires_grad for x in outer_indices)
mutations_hidden_from_autograd = all(
m.input_info[x].mutations_hidden_from_autograd for x in outer_indices
)
mutations_under_no_grad_or_inference_mode = all(
m.input_info[x].mutations_under_no_grad_or_inference_mode
for x in outer_indices
)
mutation_inductor_storage_resize = all(
m.input_info[x].mutation_inductor_storage_resize for x in outer_indices
)
inpt_info = InputAliasInfo(
# If len(outer_indices) > 1, then this input is a synthetic base.
# The invariant is that to the rest of aot autograd, synthetic bases only show up if
# one of their aliases gets a data mutation. And if any of their aliases get metadata
# mutations, they will be hidden from the rest of aot autograd.
mutates_data=mutates_data,
mutates_metadata=mutates_metadata,
mutations_hidden_from_autograd=all(
m.input_info[x].mutations_hidden_from_autograd for x in outer_indices
),
mutates_storage_metadata=False
if len(outer_indices) > 1
else m.input_info[outer_indices[0]].mutates_storage_metadata,
mutations_under_no_grad_or_inference_mode=mutations_under_no_grad_or_inference_mode,
mutation_inductor_storage_resize=mutation_inductor_storage_resize,
is_leaf=any_leaf,
requires_grad=requires_grad,
keep_input_mutations=m.keep_input_mutations,
)
input_infos.append(inpt_info)
# Find any inputs that fulfill the following criteria:
# (1) They are part of a synthetic base (because they alias another input,
# and at least one input experiences a data mutation)
# (2) They experience a metadata mutation
outer_aliased_arg_idx_with_metadata_mutations = [
outer_idx
for outer_idx, inpt_info in enumerate(m.input_info)
if inpt_info.mutates_metadata
and not isinstance(synthetic_base_info[outer_idx], int)
]
# grab the original requires grad info on the outputs, except the ones from the mutated inputs
input_metadata_output_info = [
OutputAliasInfo(
output_type=OutputType.alias_of_input,
raw_type=FunctionalTensor,
dynamic_dims={
i
for i, s in enumerate(outer_args[outer_idx].shape)
if not is_concrete_int(s)
},
base_idx=synthetic_base_info[outer_idx][0], # type: ignore[index]
requires_grad=outer_args[outer_idx].requires_grad,
)
for outer_idx in outer_aliased_arg_idx_with_metadata_mutations
]
existing_output_infos = []
for o in m.output_info:
new_base_idx = (
None
if o.base_idx is None
else (
synthetic_base_info[o.base_idx]
if isinstance(synthetic_base_info[o.base_idx], int)
else synthetic_base_info[o.base_idx][0] # type: ignore[index]
)
)
# If base_idx is changed for OutputType.is_input, we need to update the output type to reflect the change
new_output_type = (
OutputType.alias_of_input
if o.output_type == OutputType.is_input and o.base_idx != new_base_idx
else o.output_type
)
existing_output_infos.append(
OutputAliasInfo(
output_type=new_output_type,
raw_type=o.raw_type,
dynamic_dims=o.dynamic_dims,
# Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases
base_idx=new_base_idx, # type: ignore[arg-type]
requires_grad=o.requires_grad,
functional_tensor=o.functional_tensor,
)
)
inner_mutated_tangents = [
# See Note [Tangents must be contiguous]
coerce_tangent(x)
for inner_idx, x in enumerate(inner_args)
if input_infos[inner_idx].mutates_data and input_infos[inner_idx].requires_grad
]
output_info = existing_output_infos + input_metadata_output_info
# Regenerate traced tangents to include mutated inputs including synthetic bases
traced_tangents = (
inner_mutated_tangents + m.traced_tangents[len(inner_mutated_tangents) :]
)
return (
ViewAndMutationMeta(
input_info=input_infos,
output_info=output_info,
num_intermediate_bases=m.num_intermediate_bases,
keep_input_mutations=m.keep_input_mutations,
traced_tangents=traced_tangents,
# We are guaranteed not to get here, since synthetic_base codepaths are not supported today with subclass inputs.
subclass_inp_meta=[],
subclass_fw_graph_out_meta=[],
subclass_tangent_meta=[],
is_train=m.is_train,
),
outer_aliased_arg_idx_with_metadata_mutations,
)
def _get_last_mem_address(x):
out = x.storage_offset()
for size, stride in zip(x.size(), x.stride()):
out += (size - 1) * stride
return out
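# Hedged illustration (added example): the last-element address is just
# storage_offset plus (size - 1) * stride summed over all dims.
def _example_get_last_mem_address():
    base = torch.arange(32)
    x = base.as_strided((3, 4), (8, 1), 2)
    # 2 + (3 - 1) * 8 + (4 - 1) * 1 == 21
    assert _get_last_mem_address(x) == 21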
# Assumption: x and y are known to share a storage, and we are trying to determine
# if their memory is actually completely disjoint, based on sizes/strides/storage_offset
def _tensors_definitely_do_not_overlap(x, y):
if x is y:
return False
if x.numel() == 0 or y.numel() == 0:
return True
# Make x always on the left
if x.storage_offset() > y.storage_offset():
x, y = y, x
# Short-circuit in the "obvious" overlapping case: both tensors are contiguous
if x.is_contiguous() and y.is_contiguous():
if x.storage_offset() + x.numel() > y.storage_offset():
# definitely overlap
return False
else:
# definitely no overlap
return True
# Short-circuit: if last memory address of x is < start of y, then not overlapping.
x_last = _get_last_mem_address(x)
if x_last < y.storage_offset():
return True
if x.dim() == 2 and y.dim() == 2 and x.stride(1) == 1 and y.stride(1) == 1:
# This case is needed for the shampoo optimizer.
# All tensors are 2d (non-contiguous), have the same outer stride, and have an inner stride of 1
# (so rows are contiguous)
if x.stride(0) == y.stride(0):
offset_delta = y.storage_offset() - x.storage_offset()
if offset_delta < x.size(1):
# definitely overlaps (row 0 of y overlaps with row 0 of x)
# Example:
# base = torch.arange(32).reshape(4, 8)
# x = base.narrow(1, 0, 4)
# x: size=(4, 4), stride=(8, 1), offset=0
# y = base.narrow(1, 3, 4)
# y: size=(4, 4), stride=(8, 1), offset=3
return False
x_total_elems_covered = x.stride(0) * (x.size(0) - 1) + x.size(1)
if x_total_elems_covered <= offset_delta:
# definitely does not overlap (last byte of x is before start of y)
# Example:
# x: size=(4, 4), stride=(8, 1), offset=0 (last byte is 27)
# y: size=(4, 4), stride=(8, 1), offset=28 (start byte is 28)
return True
# At this point, we want to check if the 0th row of y
# overlaps with **some** row of x.
# We can check this by shifting y backward by the shared stride, repeatedly,
# until the first row of y is before the first row of x.
# Then we can check if these rows overlap.
# We can accomplish this by modding our offset by the stride.
offset_delta_mod = offset_delta % x.stride(0)
# Example:
# 0 1 2 3
# 9 10 11 12
# 18 19 20 21
# 27 28 29 30
# x: size=(4, 4), stride=(9, 1), offset=0
# y: size=(4, 4), stride=(9, 1), offset=22 (this would not overlap)
# y: size=(4, 4), stride=(9, 1), offset=23 (this would not overlap)
# y: size=(4, 4), stride=(9, 1), offset=24 (this would overlap)
# y: size=(4, 4), stride=(9, 1), offset=25 (this would overlap)
# If the interval [offset_delta_mod, offset_delta_mod + y.size(1)] ends before the
# next row of x begins (i.e. it stays within a single stride period), then row 0 of
# y does not spill into the following row of x, and we report no overlap.
if offset_delta_mod + y.size(1) <= x.stride(0):
return True
return False
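# Hedged illustration (added example): the 2d rows-contiguous case described in the
# comments above, exercised on real views of one shared buffer.
def _example_tensors_definitely_do_not_overlap():
    base = torch.arange(32).reshape(4, 8)
    x = base.narrow(1, 0, 4)  # size=(4, 4), stride=(8, 1), offset=0
    y = base.narrow(1, 4, 4)  # size=(4, 4), stride=(8, 1), offset=4
    z = base.narrow(1, 3, 4)  # size=(4, 4), stride=(8, 1), offset=3
    assert _tensors_definitely_do_not_overlap(x, y)  # disjoint columns
    assert not _tensors_definitely_do_not_overlap(x, z)  # column 3 is shared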
def compute_overlapping_inputs(fwd_inputs, aliased_input_indices):
max_aliased_inps_w_dyn_shapes = (
config._max_aliased_inputs_with_dynamic_shapes_enabled
)
definitely_error_on_dyn_shapes = False
# If the JK is false / not set, we will fall back to obeying the config above
# If it is true, we will always error when there are aliased + mutated inps with dynamic shapes
if torch._inductor.config.is_fbcode():
definitely_error_on_dyn_shapes = torch._utils_internal.justknobs_check(
"pytorch/dynamo:disable_aliased_inputs_with_mutation_and_dyn_shapes"
)
actual_aliased_indices = set()
num_aliases = len(aliased_input_indices)
# >= 2 check because num_aliases==1 means no aliasing
if num_aliases >= 2 and (
definitely_error_on_dyn_shapes or num_aliases > max_aliased_inps_w_dyn_shapes
):
dynamic_shape_indices = set()
for j in range(num_aliases):
j_ = aliased_input_indices[j]
curr_inp = fwd_inputs[j_]
if any(
isinstance(x, torch.SymInt)
for x in itertools.chain(
curr_inp.shape, curr_inp.stride(), [curr_inp.storage_offset()]
)
):
dynamic_shape_indices.add(j_)
assert (
len(dynamic_shape_indices) == 0
), f"""\
Encountered a graph where:
- {num_aliases} graph inputs all share the same storage (input indices: {str(aliased_input_indices)})
- at least one of these aliased inputs was mutated
- at least one of these inputs is being compiled with dynamic shapes (indices: {str(dynamic_shape_indices)})
Current limit: {str(max_aliased_inps_w_dyn_shapes)}
Killswitch enabled: {str(definitely_error_on_dyn_shapes)}
The most common way to run into this situation is when your model parameters are allocated as one giant buffer
and are all mutated by the optimizer, and some of your parameters end up getting compiled with dynamic shapes.
You can avoid this problem by marking your parameters so they explicitly do not participate in dynamic shapes,
by marking each dim of your parameter static:
torch._dynamo.mark_static(param, 0) # (1, 2, ... for every dimension on the parameter).
If you are running into this issue in a situation where your parameters are static but some other inputs
are aliased and mutated, and they should be dynamic, please file an issue.
"""
for j in range(num_aliases):
for i in range(j):
j_ = aliased_input_indices[j]
i_ = aliased_input_indices[i]
if not _tensors_definitely_do_not_overlap(fwd_inputs[i_], fwd_inputs[j_]):
actual_aliased_indices.add(i_)
actual_aliased_indices.add(j_)
return actual_aliased_indices
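# Hedged illustration (added example): inputs 0 and 1 below share elements of the
# same storage, while input 2 lives in a disjoint region, so only {0, 1} is returned.
def _example_compute_overlapping_inputs():
    base = torch.arange(16.0)
    fwd_inputs = [base[0:4], base[2:6], base[8:12]]
    assert compute_overlapping_inputs(fwd_inputs, [0, 1, 2]) == {0, 1}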
def _graph_input_names(gm):
return [node.name for node in gm.graph.find_nodes(op="placeholder")]
def _graph_output_names(gm):
output_node = next(iter(reversed(gm.graph.nodes)))
assert output_node.op == "output" and len(output_node.args) == 1
return_args = output_node.args[0]
return [getattr(return_arg, "name", None) for return_arg in return_args]
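# Hedged illustration (added example): placeholder and output names as read off a
# symbolically traced function with a tuple output.
def _example_graph_io_names():
    def f(a, b):
        return a + b, a * b

    gm = torch.fx.symbolic_trace(f)
    assert _graph_input_names(gm) == ["a", "b"]
    assert _graph_output_names(gm) == ["add", "mul"]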
def create_graph_signature(
fx_g: torch.fx.GraphModule,
fw_metadata: ViewAndMutationMeta,
in_spec: pytree.TreeSpec,
out_spec: pytree.TreeSpec,
*,
user_args_flat: List[Tensor],
params_and_buffers_flat: List[Tensor],
param_names: List[str],
buffer_names: List[str],
trace_joint: bool,
num_user_fw_outs: Optional[int],
loss_index: Optional[int],
) -> GraphSignature:
# Retrieve graph input names
graph_input_names = _graph_input_names(fx_g)
# Retrieve graph output names
graph_output_names = _graph_output_names(fx_g)
num_params_buffers = len(param_names) + len(buffer_names)
num_tokens = len(fw_metadata.tokens)
# We have enough restrictions on the graph (no de-duping, no synthetic bases, etc.)
# such that: # graph inps = # user inps + # params + # buffers + # tokens
num_user_args = len(graph_input_names) - num_params_buffers - num_tokens
if trace_joint:
assert num_user_fw_outs is not None
num_fw_outs = num_user_fw_outs + fw_metadata.num_mutated_inp_runtime_indices
backward_output_names = graph_output_names[num_fw_outs:]
grad_index = itertools.count(0)
gradients_to_parameters = {
backward_output_names[next(grad_index)]: param_names[i]
for i, param in enumerate(params_and_buffers_flat)
if param.requires_grad
}
gradients_to_user_inputs = {
backward_output_names[next(grad_index)]: graph_input_names[
i + len(params_and_buffers_flat)
]
for i, user_input in enumerate(user_args_flat)
if user_input.requires_grad
}
assert len(gradients_to_parameters) + len(gradients_to_user_inputs) == len(
backward_output_names
)
# Check that we have fully accounted for all graph outputs
backward_signature = BackwardSignature(
gradients_to_parameters,
gradients_to_user_inputs,
graph_output_names[loss_index],
)
else:
backward_signature = None
num_user_fw_outs = (
len(graph_output_names)
- fw_metadata.num_mutated_inp_runtime_indices
- num_tokens
)
return GraphSignature.from_tracing_metadata(
in_spec=in_spec,
out_spec=out_spec,
graph_input_names=graph_input_names,
graph_output_names=graph_output_names,
view_mutation_metadata=fw_metadata,
named_parameters=param_names,
named_buffers=buffer_names,
num_user_inputs=num_user_args,
num_user_outputs=num_user_fw_outs,
loss_index=loss_index,
backward_signature=backward_signature,
)

View File

@ -0,0 +1,784 @@
# mypy: allow-untyped-defs
"""
Functions in this module do most of the "work" of AOTAutograd.
An aot_dispatch_* function:
- Takes in the input flat_fn, flat_args, and some metadata
- Runs a set of pre compile wrappers (e.g. argument deduping)
- Runs the actual compiler
- Wraps the returned callable in a set of post compile wrappers
- Returns the wrapped callable and metadata.
"""
import itertools
import logging
import traceback
from contextlib import nullcontext
from typing import Any, Callable, List, Optional, Sequence, Tuple
import torch
import torch.utils.dlpack
from torch import Tensor
from torch._dynamo.utils import lazy_format_graph_code
from torch._guards import CompileContext, TracingContext
from torch._logging import getArtifactLogger, trace_structured
from torch._subclasses import FakeTensor
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import is_sym_node
from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals
from torch.multiprocessing.reductions import StorageWeakRef
from .. import config
from .autograd_cache import (
AOTAutogradCache,
AOTAutogradCacheEntry,
CompiledBackward,
CompiledForward,
)
from .dispatch_and_compile_graph import (
aot_dispatch_autograd_graph,
aot_dispatch_base_graph,
)
from .logging_utils import track_graph_compiling
from .runtime_wrappers import (
AOTDedupeWrapper,
AOTDispatchAutograd,
AOTDispatchSubclassWrapper,
AOTSyntheticBaseWrapper,
AutogradLazyBackwardCompileInfo,
CompilerWrapper,
DebugAssertWrapper,
EffectTokensWrapper,
FakifiedOutWrapper,
FunctionalizedRngRuntimeWrapper,
make_runtime_safe,
post_compile,
pre_compile,
RuntimeWrapper,
)
from .schemas import AOTConfig, MutationType, ViewAndMutationMeta
from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta
from .utils import _get_symint_hints, make_boxed_func, strict_zip, unlift_tokens
zip = strict_zip
log = logging.getLogger(__name__)
aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")
aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")
aten = torch.ops.aten
# Returns a Callable and a ViewAndMutationMeta.
# Currently, only export needs the ViewAndMutationMeta after this function.
DispatchReturn = Tuple[Callable, ViewAndMutationMeta]
def _create_wrappers_for_dispatch(needs_autograd: bool) -> List[CompilerWrapper]:
"""
Wrappers that run on every dispatch function
"""
return [AOTDedupeWrapper(), AOTSyntheticBaseWrapper(trace_joint=needs_autograd)]
# Export's dispatching logic is unique in a few ways: it only needs the "graph"
# bits of aot_autograd, and doesn't need to do any specific wrapping.
def aot_dispatch_export(
flat_fn: Callable,
flat_args: List[Any],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
needs_autograd: bool,
) -> DispatchReturn:
wrappers = _create_wrappers_for_dispatch(needs_autograd)
flat_fn, flat_args, fw_metadata = pre_compile(
wrappers,
flat_fn,
flat_args,
aot_config,
fw_metadata=fw_metadata,
)
if needs_autograd and not aot_config.pre_dispatch:
graph, _, _ = aot_dispatch_autograd_graph(
flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
)
else:
graph, _, _ = aot_dispatch_base_graph(
flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
)
# NB: the wrappers that run in pre_compile for export are
# either a no-op, because they're not needed, or will raise a runtime error,
# since they don't support export.
# We still run these wrappers to make sure that they're not needed pre compile,
# but we technically don't need to run them post compile at all here.
compiled_fn, fw_metadata = post_compile(
wrappers, graph, aot_config, runtime_metadata=fw_metadata
)
# Therefore, since no wrappers ran, we don't get back a callable - we get back the raw fx graph
# (either a joint or an inference-only graph)
assert isinstance(compiled_fn, torch.fx.GraphModule)
return compiled_fn, fw_metadata
def aot_dispatch_base(
flat_fn,
flat_args: List[Any],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
) -> DispatchReturn:
"""
Handles functions that don't need autograd. Runs wrappers and compiles with fw_compiler.
"""
wrappers = _create_wrappers_for_dispatch(needs_autograd=False)
flat_fn, flat_args, fw_metadata = pre_compile(
wrappers, flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
)
fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph( # type: ignore[misc]
flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
)
fakified_out_wrapper = FakifiedOutWrapper()
(
fw_module,
updated_flat_args,
fw_metadata,
) = fakified_out_wrapper.pre_compile(
fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
)
functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper()
(
fw_module,
updated_flat_args,
fw_metadata,
) = functionalized_rng_wrapper.pre_compile(
fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata
)
disable_amp = torch._C._is_any_autocast_enabled()
context = torch._C._DisableAutocast if disable_amp else nullcontext
with context(), track_graph_compiling(aot_config, "inference"):
compiler = (
aot_config.inference_compiler
if aot_config.inference_compiler is not None
else aot_config.fw_compiler
)
if tracing_context := torch._guards.TracingContext.try_get():
tracing_context.fw_metadata = (
fw_metadata
if maybe_subclass_meta is None
else maybe_subclass_meta.fw_metadata
)
with TracingContext.report_output_strides() as fwd_output_strides:
compiled_fw = compiler(fw_module, updated_flat_args)
if fakified_out_wrapper.needs_post_compile:
fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)
make_runtime_safe(fw_metadata, maybe_subclass_meta)
# However, RuntimeWrapper does not expect the rng offsets in the
# output. So, we have to create another wrapper and take out the offset. As
# a result, we also have to account for compilers that don't return boxed callables.
if not hasattr(compiled_fw, "_boxed_call"):
compiled_fw = make_boxed_func(compiled_fw)
# Create a wrapper to set up the rng functionalize and fakified out bits
compiled_fw = functionalized_rng_wrapper.post_compile(
compiled_fw, aot_config, runtime_metadata=fw_metadata
)
if config.enable_autograd_cache and aot_config.cache_key:
if fw_key := getattr(compiled_fw, "_fx_graph_cache_key", None):
entry = AOTAutogradCacheEntry(
compiled_fw=CompiledForward(fw_key),
compiled_bw=None,
runtime_metadata=fw_metadata,
dispatch_wrappers=wrappers,
maybe_subclass_meta=maybe_subclass_meta,
num_fw_outs_saved_for_bw=None,
indices_of_inps_to_detach=[],
)
AOTAutogradCache.save(aot_config.cache_key, entry)
compiled_fw = fakified_out_wrapper.post_compile(
compiled_fw,
aot_config,
runtime_metadata=fw_metadata,
)
compiled_fw = EffectTokensWrapper().post_compile(
compiled_fw,
aot_config,
runtime_metadata=fw_metadata,
)
# Why do we need to pass in num_fw_outs_saved_for_bw?
# See Note: [Partitioner handling for Subclasses, Part 2]
compiled_fw = AOTDispatchSubclassWrapper(
trace_joint=False,
# TODO: once we use pre_compile this will be flat_fn at the top of this function
fw_only=None,
maybe_subclass_meta=maybe_subclass_meta,
num_fw_outs_saved_for_bw=None,
).post_compile(
compiled_fw,
aot_config, # not used
runtime_metadata=fw_metadata,
)
if not hasattr(compiled_fw, "_boxed_call"):
compiled_fw = make_boxed_func(compiled_fw)
compiled_fn = RuntimeWrapper(
indices_of_inps_to_detach=[],
trace_joint=False,
disable_amp=disable_amp,
).post_compile(
compiled_fw,
aot_config,
runtime_metadata=fw_metadata,
)
compiled_fn = post_compile(
wrappers, compiled_fn, aot_config, runtime_metadata=fw_metadata
)
return compiled_fn
def collect_fw_donated_buffer_idxs(
fw_ins: List[Optional[FakeTensor]],
user_fw_outs: List[Optional[FakeTensor]],
bw_outs: List[Optional[FakeTensor]],
saved_tensors: List[FakeTensor],
) -> List[int]:
"""
Checks if the saved tensors are donated buffers, which means a saved tensor is not
an alias of any tensors in fw_ins, user_fw_outs, and bw_outs.
"""
storage_refs = set()
for t in itertools.chain(fw_ins, user_fw_outs, bw_outs):
if isinstance(t, FakeTensor):
storage_refs.add(StorageWeakRef(t.untyped_storage()))
num_saved_tensor = len(saved_tensors)
donated_buffer_idxs = []
for i in range(num_saved_tensor):
t = saved_tensors[i]
if StorageWeakRef(t.untyped_storage()) not in storage_refs:
donated_buffer_idxs.append(i)
return donated_buffer_idxs
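# Hedged illustration (added example): a saved tensor that aliases a forward output
# is not donated; a freshly allocated one is. FakeTensorMode is used here only to
# build FakeTensors for the sketch; the helper name below is made up.
def _example_collect_fw_donated_buffer_idxs():
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        fw_in = torch.empty(4)
        user_out = fw_in * 2
        saved_alias = user_out.view(2, 2)  # shares storage with a forward output
        saved_fresh = torch.empty(2, 2)  # brand-new buffer
    assert collect_fw_donated_buffer_idxs(
        [fw_in], [user_out], [], [saved_alias, saved_fresh]
    ) == [1]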
def collect_bw_donated_buffer_idxs(
fw_module: torch.fx.GraphModule,
bw_module: torch.fx.GraphModule,
fw_metadata: ViewAndMutationMeta,
) -> List[int]:
"""
Collects backward donated buffer indexes from fw_module and bw_module.
"""
fw_ins = fw_module.graph.find_nodes(op="placeholder")
bw_outs = next(reversed(bw_module.graph.find_nodes(op="output"))).args[0]
fw_outs = next(reversed(fw_module.graph.find_nodes(op="output"))).args[0]
fw_ins = [n.meta["val"] if hasattr(n, "meta") else None for n in fw_ins]
fw_outs = [n.meta["val"] if hasattr(n, "meta") else None for n in fw_outs]
bw_outs = [n.meta["val"] if hasattr(n, "meta") else None for n in bw_outs]
user_fw_outs = fw_outs[: fw_metadata.num_forward]
saved_tensors = fw_outs[fw_metadata.tensors_saved_for_backwards_slice]
fw_donated_buffer = collect_fw_donated_buffer_idxs(
fw_ins,
user_fw_outs,
bw_outs,
saved_tensors,
)
assert fw_metadata.num_symints_saved_for_bw is not None
return [fw_metadata.num_symints_saved_for_bw + i for i in fw_donated_buffer]
def aot_dispatch_autograd(
flat_fn,
flat_args: List[Any],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
) -> DispatchReturn:
"""
Autograd logic. Generates a joint graph, partitions it, manipulates the input with various wrappers,
and returns a wrapped torch.autograd.Function with a forward and backward.
"""
wrappers = _create_wrappers_for_dispatch(needs_autograd=True)
flat_fn, flat_args, fw_metadata = pre_compile(
wrappers,
flat_fn,
flat_args,
aot_config,
fw_metadata=fw_metadata,
)
fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled()
fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(
flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
)
# Copied from aot_dispatch_autograd_graph.
disable_amp = torch._C._is_any_autocast_enabled()
if aot_config.enable_log:
aot_joint_log.info(
"%s",
lazy_format_graph_code(
"Joint graph",
fx_g,
aot_config.aot_id,
include_stride=True,
include_device=True,
colored=True,
),
)
trace_structured(
"aot_joint_graph",
payload_fn=lambda: fx_g.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
with torch.no_grad():
inner_meta = (
fw_metadata
if maybe_subclass_meta is None
else maybe_subclass_meta.fw_metadata
)
with track_graph_compiling(aot_config, "joint"):
# See Note: [Partitioner handling for Subclasses, Part 1]
# See Note: [Recomputing subclass mutation handling]
mutated_inp_runtime_indices = (
compute_inner_mutated_inp_indices_from_subclass_meta(
fw_metadata, inner_meta
)
)
num_tokens = len(fw_metadata.tokens)
num_mutated_inp_runtime_indices = len(mutated_inp_runtime_indices)
num_inner_fwd_outputs = (
num_mutated_inp_runtime_indices
+ inner_meta.num_outputs
+ inner_meta.num_intermediate_bases
+ inner_meta.num_outputs_rng_offset
+ num_tokens # See Note [Side-Effectful Tokens in AOTAutograd]
)
fw_module, bw_module = aot_config.partition_fn(
fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
)
# See Note [Side-Effectful Tokens in AOTAutograd]
if config.unlift_effect_tokens and (
num_tokens > 0 or fw_metadata.num_backward_tokens > 0
):
unlift_tokens(fw_module, fw_metadata, aot_config, bw_module)
num_inner_fwd_outputs -= num_tokens
joint_inputs = (
joint_inputs[0][num_tokens:],
joint_inputs[1],
)
fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0]
# we only need to bookkeep the symints that are saved for bw, not any symints
# the user forward might have returned in its own output
fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:]
num_fw_outs_saved_for_bw = len(fw_outs_saved_for_bw)
symint_outs_saved_for_bw = [
n for n in fw_outs_saved_for_bw if is_sym_node(n)
]
fw_metadata.num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
inner_meta.num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
if torch._functorch.config.donated_buffer:
fw_metadata.bw_donated_idxs = collect_bw_donated_buffer_idxs(
fw_module,
bw_module,
inner_meta,
)
inner_meta.bw_donated_idxs = fw_metadata.bw_donated_idxs
if aot_config.enable_log:
aot_graphs_log.info(
"aot_config id: %s, fw_metadata=%s, inner_meta=%s",
str(aot_config.aot_id),
str(fw_metadata),
str(inner_meta),
)
# Note [Detaching inputs that never need gradients]
# See https://github.com/pytorch/pytorch/issues/97745
# Suppose we have a function like this that we want to compile:
#
# def f(x, y):
# return torch.mul(x, y.detach())
#
# What gradients should we compute for x and y?
# By default, AOTAutograd will compute a gradient for **every** input that requires gradients,
# and so we'll compute:
# x_grad_input = y
# y_grad_input = None
# Does this preserve the semantics of eager mode?
# Unfortunately, no.
# Doing the above will cause autograd to **continue** to backprop the autograd tape
# that was generated from constructing y.
#
# This is **different** from what would have happened in eager mode.
# In eager mode, if we backprop through the output of this function, autograd will only traverse
# the bit of the autograd tape corresponding to "x".
# In particular, if a user had previously backpropped through y's autograd tape,
# And then they try to backprop through the output of the above function,
# then we'll hit the dreaded "Trying to backward through the graph a second time" error.
#
# You might think: If autograd sees that a gradient is None, shouldn't it stop early,
# instead of continuing the backprop through the ancestors of that node in the graph?
#
# Autograd has two passes:
# (1) a first pass that traverses the autograd graph and figures out which nodes need to be executed
# (2) a second pass that actually goes ahead and executes each node when it becomes ready,
# propagating gradients
# By the time we're executing a node and we see that it produces a None, the set of nodes to execute
# is already locked-in.
#
# The fix: instead, we can recognize statically that the graph we're compiling will never contribute
# gradients to y, and prevent autograd from trying to traverse y's autograd tape at all.
# We can do this by manually detach'ing y before sending it through the `CompiledFunction`.
#
# Note that this solution is not bulletproof.
# It's possible to construct a case where eager may or may not have tried to autograd through y,
# depending on the actual grad_outputs that were passed in during the backward.
# There is no easy fix for this: the simplest fix would be to run with `retain_graph=True`,
# allowing autograd to re-use the graph.
#
# An example of this case is:
# def f(x):
# return x.detach() * 2, x * 3
# If we were to only backprop through outs[0], eager autograd would stop at the detach().
# So if we backward only on the first output, we shouldn't send a grad through x.
# But the custom autograd function doesn't know that: it will materialize zero grads for x * 3
# and we will end up with a zero grad at x.
# If we later backprop through the second output, this will also require backprop'ing through x.
# Meaning we'll need to use `retain_graph=True` to be able to backprop through x the second time.
_indices_of_inps_to_detach: List[int] = []
# reversed() since we expect output at end of graph
bw_output = next(reversed(bw_module.graph.find_nodes(op="output")))
bw_outs: Sequence[torch.fx.Node] = bw_output.args[0] # type: ignore[assignment]
# TODO: we should apply the below "detach inputs if their gradients are statically known to be None"
# optimization even if we have subclass inputs/outputs (we do not handle this today).
# Computing which of our inputs get None gradients is a bit more complicated,
# if any of our inputs are subclasses. Why?
# (a) we need to make sure that we call .detach() on the input subclasses, since autograd sees subclasses.
# (b) The grad_outputs that we AOT computed in our backward graph are the desugared plain tensors,
# so we need to figure out which subclass fw inputs they map to.
if maybe_subclass_meta is None:
num_backward_tokens: int = inner_meta.num_backward_tokens
assert (
len(bw_outs)
== len(fw_metadata.input_info)
+ inner_meta.num_outputs_rng_offset
+ num_backward_tokens
)
bw_outs_no_rng_no_tokens = bw_outs
if (inner_meta.num_outputs_rng_offset + num_backward_tokens) > 0:
bw_outs_no_rng_no_tokens = bw_outs[
: -(inner_meta.num_outputs_rng_offset + num_backward_tokens)
]
assert len(bw_outs_no_rng_no_tokens) == len(fw_metadata.input_info)
for i, bw_out in enumerate(bw_outs_no_rng_no_tokens):
# If our input experiences a metadata mutation inside the graph (e.g. set_()),
# we *must* not detach, otherwise it will be the detach'd input that gets the metadata mutation
metadata_mutation_in_graph = (
fw_metadata.input_info[i].mutation_type
== MutationType.MUTATED_IN_GRAPH
and fw_metadata.input_info[i].mutates_storage_metadata
)
is_non_leaf = (
fw_metadata.input_info[i].requires_grad
and not fw_metadata.input_info[i].is_leaf
)
if bw_out is None and not metadata_mutation_in_graph and is_non_leaf:
_indices_of_inps_to_detach.append(i)
if aot_config.enable_log:
aot_graphs_log.info(
"%s",
lazy_format_graph_code(
"Forward graph",
fw_module,
aot_config.aot_id,
include_stride=True,
include_device=True,
colored=True,
),
)
aot_graphs_log.info(
"%s",
lazy_format_graph_code(
"Backward graph",
bw_module,
aot_config.aot_id,
include_stride=True,
include_device=True,
colored=True,
),
)
trace_structured(
"aot_forward_graph",
payload_fn=lambda: fw_module.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
trace_structured(
"aot_backward_graph",
payload_fn=lambda: bw_module.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
# AMP is already traced out in the joint graph. We do not wish to reapply it accidentally
# in the compiler.
with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast():
# flat_args at this point might still be subclasses;
# make sure to pass the unwrapped fake tensors into the compiler!
adjusted_flat_args = joint_inputs[0]
fakified_out_wrapper = FakifiedOutWrapper()
(
fw_module,
adjusted_flat_args,
fw_metadata,
) = fakified_out_wrapper.pre_compile(
fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
)
functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper(
return_new_outs=False
)
(
fw_module,
adjusted_flat_args,
fw_metadata,
) = functionalized_rng_wrapper.pre_compile(
fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata
)
if tracing_context := torch._guards.TracingContext.try_get():
tracing_context.fw_metadata = inner_meta
with TracingContext.report_output_strides() as fwd_output_strides:
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
if not hasattr(compiled_fw_func, "_boxed_call"):
compiled_fw_func = make_boxed_func(compiled_fw_func)
if fakified_out_wrapper.needs_post_compile:
fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)
compiled_fw_func = EffectTokensWrapper().post_compile(
compiled_fw_func,
aot_config,
runtime_metadata=fw_metadata,
)
compiled_fw_func = AOTDispatchSubclassWrapper(
fw_only=None,
trace_joint=False,
maybe_subclass_meta=maybe_subclass_meta,
num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw,
).post_compile(
compiled_fw_func,
aot_config, # not used
runtime_metadata=fw_metadata,
)
compiled_fw_func = functionalized_rng_wrapper.post_compile(
compiled_fw_func, aot_config, runtime_metadata=fw_metadata
)
compiled_fw_func = fakified_out_wrapper.post_compile(
compiled_fw_func,
aot_config,
runtime_metadata=fw_metadata,
)
# NB: It's important to compile backwards ahead of time, as this may
# add extra guards which we need to apply to the Dynamo cache when
# running forwards
with track_graph_compiling(aot_config, "backward"), torch._C._DisableAutocast():
placeholder_list = fx_placeholder_vals(bw_module)
forward_saved_for_backwards_strides = None
if fwd_output_strides is not None:
forward_saved_for_backwards_strides = fwd_output_strides[
inner_meta.tensors_saved_for_backwards_slice
]
# saved activations can have different stride to eager if
# the compiler does layout optimization. We should restride the
# tensor passed in for compiling the backward graph using the
# saved tensor's stride.
for i in range(len(placeholder_list)):
ph_arg = placeholder_list[i]
if not isinstance(ph_arg, torch.Tensor):
continue
if forward_saved_for_backwards_strides is None:
continue
real_stride = None
# Per all_args calling convention
j = i - num_symints_saved_for_bw
if 0 <= j < len(forward_saved_for_backwards_strides):
real_stride = forward_saved_for_backwards_strides[j]
if real_stride is None:
continue
# Comparing ph_arg.stride() with real_stride directly may
# cause dynamic dimensions in ph_arg being specialized to static
# value. Using the hints to avoid that.
if _get_symint_hints(ph_arg.stride()) != real_stride:
# Note that here we use the stride of the real tensor to
# restride a FakeTensor. This does not cause trouble
# for dynamic shape since this code path only get
# executed if layout optimization is enabled. And we
# disable layout optimization for dynamic shape right
# now.
#
# A solution that decide stride order based on real
# tensor's stride and then apply that stride order to
# the FakeTensor does not work smoothly since some
# tensor's layout is not 'dense'. E.g. mixnet_l has a
# tensor with size [8, 64, 112, 112] and strides
# (2408448, 1, 21504, 192). The solution mentioned will
# decide a stride of (802816, 1, 7168, 64) for this
# tensor which is wrong.
placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_stride)
compiled_bw_func = None
if num_symints_saved_for_bw > 0:
try:
compiled_bw_func = aot_config.bw_compiler(
bw_module, placeholder_list
)
except Exception as e:
exc = e
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "eager_compile_backwards_failure",
"encoding": "string",
},
payload_fn=lambda: "\n".join(traceback.format_exception(exc)),
)
log.warning(
"failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
exc_info=True,
)
# Compiled autograd will run the bw_module in the backward pass,
# so recompilation needs to happen anyway if the backward pass is ever
# called.
#
# The reason we do the GraphModule recompilation here is because
# the lazy recompilation will cause issue in the backward pass
# with compiled autograd.
#
# Do the _LazyGraphModule.force_recompile here rather than when
# bw_module is first generated by the partitioner because the bw_module.recompile
# may be called in some code path later and cause _LazyGraphModule.forward
# to become the lazy version again. One example is when dynamic shape is enabled
# upfront, the bw_compiler will be called above which can cause extra
# graph module recompilation on bw_module.
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
from torch.fx._lazy_graph_module import _LazyGraphModule
_LazyGraphModule.force_recompile(bw_module)
saved_context = TracingContext.try_get()
saved_compile_context = CompileContext.try_get()
backward_state_indices = [
idx for idx, x in enumerate(flat_args) if isinstance(x, BackwardState)
]
assert len(backward_state_indices) <= 1
lazy_backward_info = AutogradLazyBackwardCompileInfo(
bw_module,
placeholder_list,
saved_context,
saved_compile_context,
)
make_runtime_safe(fw_metadata, maybe_subclass_meta)
try_save_cache_entry: Optional[Callable] = None
if config.enable_autograd_cache:
def try_save_cache_entry(compiled_bw_func, _fw_metadata): # noqa: F811
fw_key = getattr(compiled_fw_func, "_fx_graph_cache_key", None)
bw_key = getattr(compiled_bw_func, "_fx_graph_cache_key", None)
if aot_config.cache_key and fw_key and bw_key:
entry = AOTAutogradCacheEntry(
CompiledForward(fw_key),
CompiledBackward(
bw_key, backward_state_indices, num_symints_saved_for_bw
),
_fw_metadata,
wrappers,
maybe_subclass_meta,
num_fw_outs_saved_for_bw,
_indices_of_inps_to_detach,
)
AOTAutogradCache.save(aot_config.cache_key, entry)
if compiled_bw_func is not None:
# If the backward was already compiled eagerly, save the cache entry right now instead of waiting
try_save_cache_entry(compiled_bw_func, fw_metadata)
try_save_cache_entry = None
compiled_fn = AOTDispatchAutograd.post_compile(
compiled_fw_func,
compiled_bw_func,
maybe_subclass_meta,
num_symints_saved_for_bw,
backward_state_indices,
disable_amp,
_indices_of_inps_to_detach,
lazy_backward_info,
aot_config,
fw_metadata=fw_metadata,
try_save_cache_entry=try_save_cache_entry,
)
if config.debug_assert:
flat_requires_grad: List[Optional[bool]] = [
a.requires_grad if isinstance(a, Tensor) else None for a in flat_args
]
compiled_fn = DebugAssertWrapper(
flat_requires_grad=flat_requires_grad
).post_compile(compiled_fn, aot_config, runtime_metadata=fw_metadata)
compiled_fn = post_compile(
wrappers,
compiled_fn,
aot_config,
runtime_metadata=fw_metadata,
)
return compiled_fn

View File

@ -0,0 +1,147 @@
# mypy: allow-untyped-defs
"""
Contains utils for logging in AOTAutograd, including managing the names of the graphs under
compilation, capturing user-friendly tracebacks, and debug messages.
"""
import collections
from contextlib import contextmanager
from typing import List, Tuple
import torch
import torch.fx.traceback as fx_traceback
# This is a list because, looking forward, this can be arbitrarily nested.
graph_being_compiled: List[str] = []
# TODO: It would be nice to reset the numbering every time aot_id goes
# up, but this is annoying to do right now (because we don't know if
# an aot_id will come back from the dead), so right now this also happens
# to be a globally unique number too (at the cost of wobbling if you change
# how the graphs compile)
nth_graph: int = 0
model_name: str = "model"
def set_model_name(name):
global model_name
model_name = name
def get_aot_compilation_context() -> Tuple[List[str], str, int]:
return list(graph_being_compiled), model_name, nth_graph
def get_aot_graph_name() -> str:
"""
Returns the name of the graph being compiled.
"""
global model_name, graph_being_compiled, nth_graph
return f"{model_name}__{'_'.join(graph_being_compiled)}_{nth_graph}"
get_graph_being_compiled = get_aot_graph_name
@contextmanager
def track_graph_compiling(aot_config, graph_name):
global graph_being_compiled
# TODO: Don't shove the aot_id in here; set it in the context
graph_being_compiled = [f"{aot_config.aot_id}_{graph_name}"]
old_name = None
if tracing_context := torch._guards.TracingContext.try_get():
old_name = tracing_context.aot_graph_name
tracing_context.aot_graph_name = graph_being_compiled
has_tracing_context = True
else:
has_tracing_context = False
try:
yield
finally:
global nth_graph
nth_graph += 1
graph_being_compiled = []
if has_tracing_context:
if tracing_context := torch._guards.TracingContext.try_get():
tracing_context.aot_graph_name = old_name
# Set up hooks so that during backward the fx's stack_trace is properly set
callback_set = False
def setup_stacktrace_preservation_hooks(roots: List):
def iter_graph(roots):
if not roots:
return
seen = set()
q = collections.deque() # type: ignore[var-annotated]
for node in roots:
if node is not None and node not in seen:
seen.add(node)
q.append(node)
while q:
node = q.popleft()
for fn, _idx in node.next_functions:
if fn in seen or fn is None:
continue
seen.add(fn)
q.append(fn)
yield node
def get_callback(saved_stack_):
def callback():
global callback_set
fx_traceback.set_stack_trace(saved_stack_)
callback_set = False
return callback
def get_prehook(stack_, seq_nr):
def prehook(grad_output):
global callback_set
if not callback_set:
torch.autograd.variable.Variable._execution_engine.queue_callback( # type: ignore[attr-defined]
get_callback(fx_traceback.format_stack())
)
callback_set = True
fx_traceback.set_stack_trace(stack_)
fx_traceback.set_grad_fn_seq_nr(seq_nr)
return prehook
def get_posthook(special_stack_, seq_nr):
def posthook(grad_input, grad_output):
fx_traceback.set_stack_trace(special_stack_)
fx_traceback.reset_grad_fn_seq_nr()
return posthook
for node in iter_graph(roots):
forward_node_stack = node.metadata.get("traceback_", [])
node.register_prehook(get_prehook(forward_node_stack, node._sequence_nr()))
special_stack = forward_node_stack.copy()
special_stack.append(
"Gradient addition node due to multiple use of tensor around:"
)
node.register_hook(get_posthook(special_stack, node._sequence_nr()))
def describe_input(i, aot_config):
if i < aot_config.num_params_buffers:
return f"parameter/buffer {i}"
else:
return f"input {i - aot_config.num_params_buffers}"
def format_guard_bug_msg(aot_config, expected):
return (
f"At compilation time, graph {aot_config.aot_id} was compiled under the "
f"assumption that {expected}, but at runtime this was not the case. "
"This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch."
)
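# [Editor's sketch, not part of the original file] Minimal usage of the helpers above,
# with a hypothetical config object standing in for AOTConfig:
#
#     class _FakeConfig:
#         aot_id = 0
#         num_params_buffers = 2
#
#     with track_graph_compiling(_FakeConfig(), "forward"):
#         print(get_aot_graph_name())          # e.g. "model__0_forward_0"
#     print(describe_input(1, _FakeConfig()))  # "parameter/buffer 1"
#     print(describe_input(3, _FakeConfig()))  # "input 1"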

File diff suppressed because it is too large

View File

@ -0,0 +1,833 @@
# mypy: allow-untyped-defs
"""
The various dataclasses, Enums, namedtuples etc used in AOTAutograd. This includes
input/output types, metadata, config, function signatures etc.
"""
import collections
import dataclasses
import functools
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, NewType, Optional, Set, Union
import torch
import torch.utils._pytree as pytree
from torch._guards import Source
from torch._ops import OpOverload
from torch._subclasses import FakeTensor
from torch._subclasses.fake_tensor import is_fake
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from .. import config
from .functional_utils import (
_check_if_mutation_can_be_in_graph,
FunctionalTensorMetadataEq,
)
from .utils import strict_zip
zip = strict_zip
OutputType = Enum(
"OutputType",
(
# output is not an alias
"non_alias",
# output aliases an input
"alias_of_input",
# output **is** an input tensor
"is_input",
# output has a ._base tensor, which is a graph intermediate.
# We need to return its ._base as a graph output,
# so its requires_grad info is populated correctly.
# Instructs the runtime code to regenerate the current output
# from a base tensor, graph_intermediates[base_idx]
"alias_of_intermediate_save_as_output",
# Same as above; but we don't need to explicitly add its ._base
# as a graph output, because it already **is** a graph output.
"alias_of_intermediate",
# Same as above; but the output's ._base is **already** a user output.
# Instructs the runtime code to regenerate the current output from
# a base tensor, user_outputs[base_idx]
"alias_of_intermediate_base_is_user_output",
# See Note [Intermediate Bases Optimization]
"unsafe_view_alias",
# output is an alias, but has a custom autograd.Function backward.
# In this case, we don't want to do view-replay, since we won't be able to replay the custom function.
# Instead, we'll treat this output "normally", and trace its backward into the graph.
"custom_function_view",
),
)
# This class stores info about every user output.
@dataclass(frozen=True)
class OutputAliasInfo:
# Tells us if this output is:
# (1) a regular (non-aliased) output
# (2) an alias of a forward input
# (3) **is** a forward input (special case of "alias_of_input")
# (4) an alias of an intermediate (aka an alias of an output of the inner traced forward)
# (5) an alias of an intermediate, that explicitly requires returning the intermediate
# as a graph output
# (6) an alias of an intermediate, where that intermediate is also a user output
output_type: OutputType
# The raw type of the output (torch.Tensor, SymInt, etc)
raw_type: type
# If (1) above, then
# - base_idx is None
# If (2) or (3) above, then
# - Tells us that the base of this alias is user_fwd_input[base_idx]
# (This is an index into the inputs *before* we make synthetic bases)
# If (4) or (5) above, then
# - Tells us that the base of this alias is output_graph_intermediates[base_idx]
# here, this refers to the index into the *directly* traced graph intermediates
# If (6) above, then:
# - Tells us that the base of this alias is output_user_fwds[base_idx]
# here, this refers to the index into the *directly* traced user outputs
base_idx: Optional[int]
# If it is a Tensor, what the dynamic dims are (otherwise is None)
dynamic_dims: Optional[Set[int]]
# requires_grad
requires_grad: bool
# FunctionalTensorWrapper that represents this output.
#
# Provides us the means to replay views from it.
#
# We need to wrap the actual FunctionalTensorWrapper with this class so that
# we only compare the tensor's metadata. That's because with the transformations
# of the model throughout AOTAutograd, the sequence of ViewMeta and the base
# tensor might change.
functional_tensor: Optional[FunctionalTensorMetadataEq] = None
class MutationType(Enum):
NOT_MUTATED = 1
MUTATED_IN_GRAPH = 2
MUTATED_OUT_GRAPH = 3
# This class tells us info about user inputs.
@dataclass(frozen=True)
class InputAliasInfo:
is_leaf: bool
mutates_data: bool
mutates_metadata: bool
mutations_hidden_from_autograd: bool
mutations_under_no_grad_or_inference_mode: bool
mutation_inductor_storage_resize: bool
mutates_storage_metadata: bool
requires_grad: bool
keep_input_mutations: bool
def __post_init__(self):
if self.mutates_storage_metadata:
# For convenience, we guarantee that this is always true.
# In practice, If we call .set_(), then at runtime there is no need
# to additionally fix up the tensor metadata, since our runtime
# call to inp.set_(updated_inp) will already have the right metadata
assert self.mutates_metadata
@functools.cached_property
def mutation_type(self) -> MutationType:
if (
(not self.mutates_data)
and (not self.mutates_metadata)
and not (self.mutation_inductor_storage_resize)
):
return MutationType.NOT_MUTATED
if _check_if_mutation_can_be_in_graph(
self.keep_input_mutations,
self.mutates_data,
self.mutates_metadata,
self.mutations_hidden_from_autograd,
self.mutations_under_no_grad_or_inference_mode,
self.mutates_storage_metadata,
self.mutation_inductor_storage_resize,
self.requires_grad,
):
return MutationType.MUTATED_IN_GRAPH
return MutationType.MUTATED_OUT_GRAPH
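# Example (editor's sketch, not in the original source): a plain data mutation on an input
# that does not require grad is classified as MUTATED_IN_GRAPH when keep_input_mutations is
# set, and as MUTATED_OUT_GRAPH (handled in the runtime epilogue) when it is not.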
@dataclass
class SubclassCreationMeta:
"""
Used for AOTDispatch.
This dataclass gives us the information we need to reconstruct a tensor subclass
from our flat inputs.
Why is this important? The graph that we'd like to trace out contains flat tensor inputs,
but the user's original model may have subclass inputs and outputs.
So we need to wrap/unwrap subclasses as necessary to translate between the user's
view (subclass inps/outs), and the backend compiler's view (graph with no subclass args).
Complications arise mostly from the fact that a subclass can hold more than one inner tensor;
So for a given subclass input/output, we need to carefully track which indices map
to the subclass tensor in the corresponding "dense-tensor-only" graph.
"""
# In the inner graph that only takes in dense tensor inputs,
# this maps to the first index of "tensors that should go in this subclass wrapper"
flat_tensor_start_idx: int
# arg_count is inclusive of the arg_counts of any
# inner tensor subclasses: If I have a TwoTensor and
# both of its inner elements are TwoTensors, then the
# arg_count of the outer-most subclass will be 4
arg_count: int
# meta and attrs are produced by the subclass's __tensor_flatten__.
# We need to keep them around along with outer_size / outer_stride to plumb them
# into __tensor_unflatten__
attrs: Dict[str, Union["SubclassCreationMeta", None]]
outer_size: List[int]
outer_stride: List[int]
meta: Any
# Stores the original subclass itself.
# This is needed because we need the autograd metadata on the original subclass
# (this is guaranteed to be a wrapper subclass that holds a fake tensor,
# so holding onto this at runtime shouldn't leak memory)
# This field is nulled out after calling make_runtime_safe()
original_subclass: Optional[torch.Tensor]
# Used at runtime to determine the subclass type, so we don't need to save the original subclass
original_subclass_type: Optional[type] = None
def creation_fn(self, all_args, *, is_runtime: bool):
inner_tensors = {}
curr_start_idx = self.flat_tensor_start_idx
for attr, creation_meta in self.attrs.items():
if creation_meta is None:
subclass = all_args[curr_start_idx]
curr_start_idx += 1
else:
subclass = creation_meta.creation_fn(all_args, is_runtime=is_runtime)
curr_start_idx += creation_meta.arg_count
inner_tensors[attr] = subclass
if is_runtime:
assert self.original_subclass_type is not None
original_subclass_type = self.original_subclass_type
else:
original_subclass_type = type(self.original_subclass)
rebuilt = original_subclass_type.__tensor_unflatten__( # type: ignore[attr-defined]
inner_tensors, self.meta, self.outer_size, self.outer_stride
)
if not is_runtime:
# After wrapping up the inner dense tensors into a subclass, we need to make sure that our new wrapper
# has correct autograd metadata, since we'll be tracing through the autograd engine with the subclass.
# We don't trace through the autograd engine at runtime though, so no need
# to compute this extra metadata then!
torch._mirror_autograd_meta_to(self.original_subclass, rebuilt) # type: ignore[attr-defined]
return rebuilt
def make_runtime_safe(self):
assert self.original_subclass is not None
self.original_subclass_type = type(self.original_subclass)
self.original_subclass = None
# Recurse on nested subclass info
for creation_meta in self.attrs.values():
if creation_meta is not None:
creation_meta.make_runtime_safe()
def __post_init__(self):
# sanity assert to make sure we don't leak memory
assert is_fake(self.original_subclass)
# This saves the type of subclass nested structure to compare
# against runtime tangent inputs. We want to compute this at AOT
# time, since the comparison is invoked on the runtime hot path.
from .subclass_utils import get_types_for_subclass
self.subclass_type = get_types_for_subclass(self.original_subclass)
# This class encapsulates all aliasing + mutation info we need about the forward graph
# See a more detailed overview of the edge case handling at
# https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit
@dataclass(eq=False)
class ViewAndMutationMeta:
# length = # user inputs
# This gives us info about every input, and what sort of mutation happened to it (if any)
input_info: List[InputAliasInfo]
# length = # user outputs
# This gives us info about every output (mostly around whether it aliases other tensors)
output_info: List[OutputAliasInfo]
# length = the number of intermediate bases appended as outputs to the end of the forward graph.
# Note: this is not necessarily the same thing as:
# len([x for x in output_info if x.output_type == OutputType.alias_of_intermediate])
# Because outputs might share a ._base, or an output's ._base might itself be
# another user output (in both cases, we won't redundantly append bases to the end of the graph)
num_intermediate_bases: int
# For inference only: instructs us to keep data-only input mutations directly in the graph
keep_input_mutations: bool
# length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors)
# + (# intermediate bases)
# These are the FakeTensor (or potential SymInt) outputs that we traced from our
# metadata pass of the user's forward function.
# Their only use today is to pass them as a best-guess for tangents when tracing the joint.
# Stashing them as part of our "metadata" makes it simpler if we want to run our analysis
# pass once, and re-use the output throughout AOTAutograd
traced_tangents: List[Any]
# Each of these is a list telling us about subclasses for the inputs/outputs/grad_outs
# They are used throughout AOTDispatch to tell us how to generate a list of subclass tensors,
# Given a (potentially larger) list of plain torch tensors.
# Taking subclass_inp_meta as an example:
# subclass_inp_meta[i] = j (an int) tells us:
# "The i'th user input is not a subclass, and corresponds to inputs[j] of the plain-tensor graph."
# subclass_inp_meta[i] = SubclassCreationMeta(flat_tensor_start_idx=3, arg_count=2)
# "The i'th user input is subclass holding two inner tensors, which are
# inputs[3] and inputs[4] of the plain-tensor graph".
# length = # user inputs
subclass_inp_meta: List[Union[int, SubclassCreationMeta]]
# So, the full set of outputs to the forward graph looks something like:
# (*mutated_inps, *user_outs, *intermediate_bases, *saved_for_bw_tensors)
# where the first 3 of those 4 can be subclasses
# (but not saved_for_bw tensors, since these are internal to the compiler
# and not user visible, so there's no point in wrapping/unwrapping them at runtime).
# This list contains subclass information on all of the fw graph outputs
# except for saved_for_bw_tensors.
subclass_fw_graph_out_meta: List[Union[int, SubclassCreationMeta]]
# length = # backward graph inputs
subclass_tangent_meta: List[Union[int, SubclassCreationMeta]]
# TODO: we should kill this
# (need to default it to not break internal)
is_train: bool = False
# length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors)
# + (# intermediate bases)
# At runtime, we don't keep the traced_tangents around since they're not serializable.
# Instead, we keep any necessary subclass metadata necessary about each traced_tangent.
# This list is generated after calling make_runtime_safe().
traced_tangent_metas: Optional[List[Any]] = None
num_symints_saved_for_bw: Optional[int] = None
# The grad_enabled mutation that will be emitted in the runtime_wrapper epilogue
# NOTE: AOTAutograd will assume that the ambient `is_grad_enabled` is the grad mode
# that is intended to be in effect prior to running the graph, in keeping with
# equivalence to eager mode. It is the responsibility of upstream graph acquisition
# to reset the grad mode to its pre-graph value prior to calling aot_autograd.
grad_enabled_mutation: Optional[bool] = None
# Keeps track of whether `torch.use_deterministic_algorithms` was turned on
# when the forward was run. If deterministic mode was turned off during the
# forward, but is turned on during the backward call, then an error is
# raised
deterministic: Optional[bool] = None
# Keeps track of which input indices store parameters (which we will treat as static)
static_input_indices: List[int] = field(default_factory=list)
# Map of effect type (ex. _EffectType.ORDERED) to token. If there are
# side-effectful operators, FunctionalTensorMode will populate this
# dictionary telling us how many tokens we will need during tracing.
tokens: Dict[Any, torch.Tensor] = field(default_factory=dict)
# Only filled in if/when we trace the joint function
# If an input requires grad and is mutated in the backward, it is only safe to keep the mutation
# in the graph if gradients are disabled while the backward runs
# (grad mode is disabled by default when users run the backward, but can be turned on with create_graph=True)
# At runtime during the backward, we use this list of indices to error properly if we find out
# that it was not safe to include a backward mutation in the graph.
indices_of_inputs_that_requires_grad_with_mutations_in_bw: List[int] = field(
default_factory=list
)
# Indexes of saved tensors which are donated buffer.
# Donated buffer means the tensor is not alias of any forward user input, forward user output,
# and backward output.
bw_donated_idxs: Optional[List[int]] = None
# Number of tokens used in backward, appended at the end of backward outputs.
# Filled after tracing joint function.
num_backward_tokens: int = 0
def __post_init__(self):
# pre-compute the indices of the inputs that are mutated.
# When keep_input_mutations is set, we don't need to worry about our epilogue
# handling data-only mutations, because we keep them directly in the graph.
mutated_inp_runtime_indices = [
i
for i, m in enumerate(self.input_info)
if (m.mutation_type == MutationType.MUTATED_OUT_GRAPH)
]
mutated_graph_handled_indices = [
i
for i, m in enumerate(self.input_info)
if m.mutation_type == MutationType.MUTATED_IN_GRAPH
]
self.mutated_graph_handled_indices = mutated_graph_handled_indices
self.num_mutated_graph_handled_indices = len(self.mutated_graph_handled_indices)
mutated_graph_handled_indices_seen_by_autograd = [
i
for i in mutated_graph_handled_indices
if not self.input_info[i].mutations_hidden_from_autograd
]
self.mutated_graph_handled_indices_seen_by_autograd = (
mutated_graph_handled_indices_seen_by_autograd
)
self.num_mutated_graph_handled_indices_seen_by_autograd = len(
self.mutated_graph_handled_indices_seen_by_autograd
)
aliased_out_indices = [
i
for i, m in enumerate(self.output_info)
if m.output_type
not in [
OutputType.non_alias,
OutputType.unsafe_view_alias,
OutputType.custom_function_view,
]
]
unsafe_view_out_indices = [
i
for i, m in enumerate(self.output_info)
if m.output_type is OutputType.unsafe_view_alias
]
# This is pre-computed in post_init for perf.
# It contains the index of every element
# of input_info that corresponds to a mutation (data or metadata or both)
self.mutated_inp_runtime_indices = mutated_inp_runtime_indices
self.num_mutated_inp_runtime_indices = len(self.mutated_inp_runtime_indices)
# This is pre-computed for perf.
# It contains the index of every element
# of output_info that corresponds to an alias (either of an input or intermediate)
self.aliased_out_indices = aliased_out_indices
self.unsafe_view_out_indices = unsafe_view_out_indices
self.num_outputs = len(self.output_info)
self.num_outputs_non_aliased = len(
[
x
for x in self.output_info
if x.output_type
in [
OutputType.non_alias,
OutputType.unsafe_view_alias,
OutputType.custom_function_view,
]
]
)
self.num_outputs_aliased_to_inputs = len(
[
x
for x in self.output_info
if x.output_type
in [
OutputType.alias_of_input,
OutputType.is_input,
]
]
)
self.num_unsafe_view_outputs = len(self.unsafe_view_out_indices)
self.num_outputs_aliased_to_intermediates = len(
[
x
for x in self.output_info
if x.output_type
in [
OutputType.alias_of_intermediate,
OutputType.alias_of_intermediate_save_as_output,
OutputType.alias_of_intermediate_base_is_user_output,
]
]
)
self.num_outputs_aliased = (
self.num_outputs_aliased_to_inputs
+ self.num_outputs_aliased_to_intermediates
)
self.dynamic_outputs = any(o.dynamic_dims for o in self.output_info)
# See Note: [AOTAutograd Backward Guards]
# This is pre-computed for fast asserts on the types of our grad_outputs in the backward.
# Eventually, we should kill this and replace with real backward guards.
# (we want to precompute the "runtime" types, so replace FakeTensor with torch.Tensor)
self.output_types = [
torch.Tensor if isinstance(x, FakeTensor) else type(x)
for x in self.traced_tangents
]
self.is_rng_op_functionalized = config.functionalize_rng_ops
# All of the above metadata is collected by tracing the fw function.
# However, extra outputs for rng offsets behave differently. Both fwd
# and bwd graphs have their own outputs for the total consumed offsets.
# Unlike mutated inputs, we don't have to worry about sending the right
# set of tensors between fwd and bwd. Fwd and bwd offsets are
# independent and simpler to handle. Therefore, we track them
# separately.
self.num_outputs_rng_offset = 1 if self.is_rng_op_functionalized else 0
# Our forward() returns (tokens, mutated_inputs, outputs, output_intermediate_bases, saved_tensors, saved_symints)
# Tokens will be split out before mutations/view handling and we do not count them here.
self.num_forward_returns = (
self.num_mutated_inp_runtime_indices
+ self.num_outputs
+ self.num_intermediate_bases
)
# In case of functionalization of rng ops, the fw_module returns one
# additional output for rng offset. This rng offset is used right
# away to advance the rng state, and is not passed on to the raw
# outputs. However, we need to know the exact boundary to identify
# which tensors need to be saved for the bwd graph. num_forward captures
# this information.
self.num_forward = self.num_forward_returns + self.num_outputs_rng_offset
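# Example (editor's sketch): with 1 runtime-handled mutated input, 2 user outputs,
# 1 intermediate base, and RNG-op functionalization enabled:
#   num_forward_returns = 1 + 2 + 1 = 4
#   num_forward         = 4 + 1 = 5   # the extra slot is the fwd rng-offset output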
def make_runtime_safe(self):
"""
There are various fields in ViewAndMutationMeta that aren't serializable. This function is called after all tracing
is completed to simplify certain fields in the metadata so that they can be safely cached.
Doing so may lose information (in the case of traced_tangents), but none of the information is needed at runtime.
"""
# TODO: This function is only a best effort: there are other fields that may not be cache safe
# (i.e., there's no guarantee that tensor_flatten() returns a serializable result), or that
# SubclassCreationMeta is cache safe.
assert self.traced_tangent_metas is None
def extract_metadata(t):
if isinstance(t, torch.Tensor) and is_traceable_wrapper_subclass(t):
(inner_tensors, flatten_spec) = t.__tensor_flatten__() # type: ignore[attr-defined]
# Technically, we only need the flatten_spec, not the inner tensors.
# However, some Tensor subclasses (like TwoTensor) may have flatten_spec = None.
# And we want to be able to assert that this metadata is non-None,
# to distinguish between "this was a tensor subclass with no metadata" vs.
# "this wasn't a tensor subclass at all".
return (inner_tensors, flatten_spec)
else:
return None
self.traced_tangent_metas = [extract_metadata(t) for t in self.traced_tangents]
# Clear traced tangents at runtime
self.traced_tangents = []
new_output_info = []
for out in self.output_info:
if config.view_replay_for_aliased_outputs:
new_out = out
else:
# If we're not using view_replay, remove the functional tensor.
# Functional tensors are unfortunately not serializable,
# so doing this is required for AOTAutograd caching.
new_out = dataclasses.replace(out, functional_tensor=None)
new_output_info.append(new_out)
self.output_info = new_output_info
for inp_meta in self.subclass_inp_meta:
if isinstance(inp_meta, SubclassCreationMeta):
inp_meta.make_runtime_safe()
for inp_meta in self.subclass_fw_graph_out_meta:
if isinstance(inp_meta, SubclassCreationMeta):
inp_meta.make_runtime_safe()
for inp_meta in self.subclass_tangent_meta:
if isinstance(inp_meta, SubclassCreationMeta):
inp_meta.make_runtime_safe()
@property
def tensors_saved_for_backwards_slice(self):
assert self.num_symints_saved_for_bw is not None
if self.num_symints_saved_for_bw > 0:
return slice(self.num_forward, -self.num_symints_saved_for_bw)
else:
return slice(self.num_forward, None)
@property
def symints_saved_for_backwards_slice(self):
assert self.num_symints_saved_for_bw is not None
if self.num_symints_saved_for_bw > 0:
return slice(-self.num_symints_saved_for_bw, None)
else:
return slice(0, 0) # empty slice
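# Example (editor's sketch): with num_forward = 3 and num_symints_saved_for_bw = 2, a
# forward return of fw_outs = (o0, o1, o2, t0, t1, s0, s1) is partitioned as
#   fw_outs[tensors_saved_for_backwards_slice] -> (t0, t1)   # slice(3, -2)
#   fw_outs[symints_saved_for_backwards_slice] -> (s0, s1)   # slice(-2, None)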
def __eq__(self, other):
if not isinstance(other, ViewAndMutationMeta):
return NotImplemented
return (
self.input_info == other.input_info
and self.output_info == other.output_info
and self.num_intermediate_bases == other.num_intermediate_bases
and self.keep_input_mutations == other.keep_input_mutations
and self.is_rng_op_functionalized == other.is_rng_op_functionalized
and self.num_outputs_rng_offset == other.num_outputs_rng_offset
and len(self.traced_tangents) == len(other.traced_tangents)
and all(
x.shape == y.shape and x.dtype == y.dtype
for x, y, in zip(self.traced_tangents, other.traced_tangents)
)
and self.num_backward_tokens == other.num_backward_tokens
)
@dataclass(eq=False)
class SubclassMeta:
# A copy of all forward metadata, but computed on the *dense* tensor forward (after desugaring subclasses)
# So for example, if the user had a model containing two `TwoTensor` inputs,
# Then `SubclassMeta.fw_metadata.input_infos` would have length 4 here.
fw_metadata: ViewAndMutationMeta
# Note: [Computing Subclass Metadata about grad_inputs]
# Given a list of flattened, plain tensor grad_inputs, this tells us how to reconstruct the grad_input subclasses
#
# You might think: why not just assume that all grad_inputs will have the same subclass-ness as the original inputs?
# (AOTAutograd generally assumes other properties, e.g. that grad_outputs are contiguous)
#
# This doesn't really work though. Take this example:
#
# def f(DoubleTensor, DenseTensor):
# return DoubleTensor * DenseTensor
#
# In the above example, the .grad field of *both* DoubleTensor and DenseTensor will be a DoubleTensor.
# When we trace out a joint fw-bw graph, we'll end up returning two subclasses for the two grad_inputs.
# This means that our backward graph will return 4 outputs (two dense tensors for each DoubleTensor grad_input)
# and we need to properly store the metadata that tells us how to turn these 4 outputs back into DoubleTensors.
#
# Note that this info **cannot** easily be figured out from ViewAndMutationMeta.
# We can only compute this info by tracing the entire joint and examining the grad_inputs that we computed.
#
# See Note: [AOTAutograd Backward Guards]
# This will also eventually require us to install backward guards,
# in case we made incorrect assumptions about the subclass-ness of our grad_outputs
#
# Optional field because we don't compute for inference graphs
grad_input_metas: Optional[List[Union[int, SubclassCreationMeta]]] = None
def __init__(self) -> None:
# The fields in this class get set after its construction.
pass
# This class exists because:
# - the autograd.Function.forward() in aot autograd returns outputs that might alias inputs
# - we only care about the metadata on those aliases, so we can regenerate them.
# We do not want them to participate in the autograd.Function.
# We do that by wrapping them in an opaque class, so the autograd.Function
# does not know to treat them as tensors.
@dataclass(frozen=True)
class TensorAlias:
alias: torch.Tensor
@dataclass
class BackwardSignature:
"""
Provides information about the backward section of an exported
joint forward-backward graph.
For a particular fx GraphModule, this class contains information on:
(1) A mapping from each gradient (backwards output) to the parameter
it corresponds to (forward input)
(2) A mapping from each gradient (backwards output) to the user input
it corresponds to (forward input)
(3) Which of the forward outputs corresponds to the loss that we backprop on.
Each string name is the `node.name` of the corresponding node in the fx graph.
"""
gradients_to_parameters: Dict[str, str]
gradients_to_user_inputs: Dict[str, str]
loss_output: str
GraphOutputName = NewType("GraphOutputName", str)
GraphInputName = NewType("GraphInputName", str)
FQN = NewType("FQN", str)
@dataclass
class GraphSignature:
"""
Provides information about an exported module.
For a particular fx GraphModule, this class contains information on:
(1) Which graph inputs are parameters, buffers, or user inputs
(2) (for params/buffers) a mapping from the name of each graph argument
to its parameter/buffer FQN in the original nn.Module.
(3) If there are input mutations, these are represented as extra outputs
in the fx GraphModule. We provide a mapping from these
extra output names to the names of the actual inputs.
(4) The pytree metadata on how to flatten/unflatten inputs and outputs.
The corresponding FX GraphModule only accepts and returns
pytree-flattened inputs/outputs.
(5) (Optionally) if the FX is a joint forward-backward graph, we provide
a signature on the backward section of the joint graph.
"""
parameters: List[FQN]
buffers: List[FQN]
user_inputs: List[GraphInputName]
user_outputs: List[GraphOutputName]
inputs_to_parameters: Dict[GraphInputName, FQN]
inputs_to_buffers: Dict[GraphInputName, FQN]
# If the user's module mutates a buffer,
# it's represented in the graph as an extra graph output.
# This dict is a mapping from
# "graph outputs that correspond to updated buffers"
# to the FQN names of those mutated buffers.
buffers_to_mutate: Dict[GraphOutputName, FQN]
user_inputs_to_mutate: Dict[GraphOutputName, GraphInputName]
in_spec: pytree.TreeSpec
out_spec: pytree.TreeSpec
backward_signature: Optional[BackwardSignature]
input_tokens: List[GraphInputName]
output_tokens: List[GraphOutputName]
@classmethod
def from_tracing_metadata(
cls,
*,
in_spec: pytree.TreeSpec,
out_spec: pytree.TreeSpec,
graph_input_names: List[str],
graph_output_names: List[str],
view_mutation_metadata: ViewAndMutationMeta,
named_parameters: List[str],
named_buffers: List[str],
num_user_inputs: int,
num_user_outputs: int,
loss_index: Optional[int],
backward_signature: Optional[BackwardSignature],
) -> "GraphSignature":
graph_inputs = graph_input_names
graph_outputs = graph_output_names
parameters = list(named_parameters)
buffers = list(named_buffers)
num_tokens = len(view_mutation_metadata.tokens)
# Calling convention assumptions:
# (1) graph inputs = (input_tokens, params, buffers, user_inputs)
# (2) graph outputs = (output_tokens, mutated_inputs, user_outs, param_gradients)
# (If we are capturing an inference graph, this convention is identical
# except that param_gradients is empty)
# See Note [Side-Effectful Tokens in AOTAutograd] for information on tokens
# Address input calling conventions:
start, stop = 0, num_tokens
input_tokens = graph_inputs[start:stop]
start, stop = stop, stop + len(parameters)
inputs_to_parameters = dict(zip(graph_inputs[start:stop], parameters))
start, stop = stop, stop + len(buffers)
inputs_to_buffers = dict(
zip(
graph_inputs[start:stop],
buffers,
)
)
start, stop = stop, stop + num_user_inputs
user_inputs = graph_inputs[start:stop]
# We should've gone through all the inputs now
assert len(graph_inputs) - stop == 0
# Address output calling conventions:
start, stop = 0, num_tokens
output_tokens = graph_outputs[start:stop]
names = [*input_tokens, *parameters, *buffers, *user_inputs]
mutations = []
for idx, input_info in enumerate(view_mutation_metadata.input_info):
if input_info.mutates_data:
# Only buffers can be mutated, not parameters
assert idx >= len(parameters)
mutations.append(names[idx + num_tokens])
assert len(mutations) == view_mutation_metadata.num_mutated_inp_runtime_indices
start, stop = (
stop,
stop + view_mutation_metadata.num_mutated_inp_runtime_indices,
)
outputs_to_mutations = dict(zip(graph_outputs[start:stop], mutations))
user_inputs_to_mutate = {}
buffers_to_mutate = {}
for output_name, mutation_name in outputs_to_mutations.items():
if mutation_name in user_inputs:
user_inputs_to_mutate[output_name] = mutation_name
else:
assert mutation_name in buffers
buffers_to_mutate[output_name] = mutation_name
start, stop = stop, stop + num_user_outputs
user_outputs = graph_outputs[start:stop]
unused_outputs = len(graph_outputs) - stop
if backward_signature is not None:
unused_outputs -= len(backward_signature.gradients_to_parameters) + len(
backward_signature.gradients_to_user_inputs
)
assert unused_outputs == 0
return GraphSignature(
parameters=parameters, # type: ignore[arg-type]
buffers=buffers, # type: ignore[arg-type]
user_inputs=user_inputs, # type: ignore[arg-type]
user_outputs=user_outputs, # type: ignore[arg-type]
inputs_to_buffers=inputs_to_buffers, # type: ignore[arg-type]
inputs_to_parameters=inputs_to_parameters, # type: ignore[arg-type]
user_inputs_to_mutate=user_inputs_to_mutate,
buffers_to_mutate=buffers_to_mutate, # type: ignore[arg-type]
in_spec=in_spec,
out_spec=out_spec,
backward_signature=backward_signature,
input_tokens=input_tokens, # type: ignore[arg-type]
output_tokens=output_tokens, # type: ignore[arg-type]
)
@dataclass
class AOTConfig:
"""
Configuration for AOTDispatcher
"""
fw_compiler: Callable
bw_compiler: Callable
partition_fn: Callable
decompositions: Dict[OpOverload, Callable]
num_params_buffers: int
aot_id: int
keep_inference_input_mutations: bool
is_export: bool = False
no_tangents: bool = False
dynamic_shapes: bool = False
aot_autograd_arg_pos_to_source: Optional[List[Source]] = None
static_input_indices: Optional[List[int]] = None
inference_compiler: Optional[Callable] = None
enable_log: bool = True
# this is always false outside of export.
pre_dispatch: bool = False
# Key to use for AOTAutogradCache
cache_key: Optional[str] = None
def __post_init__(self):
if self.pre_dispatch:
assert self.is_export, "Can only have pre_dispatch IR for export."
SubclassTracingInfo = collections.namedtuple(
"SubclassTracingInfo",
["plain_tensor_trace_fn", "plain_tensor_args", "maybe_subclass_meta"],
)

View File

@ -0,0 +1,347 @@
# mypy: allow-untyped-defs
"""
This file contains utilities for tracing through __torch_dispatch__ based tensor subclasses and modes.
AOTAutograd's responsibility is to trace through all pytorch capabilities that live in the pytorch dispatcher,
and this includes tensor subclasses that implement __torch_dispatch__.
"""
import typing
from typing import Any, List, Optional, Tuple, Union
import torch.utils._pytree as pytree
from torch import Tensor
from torch._subclasses.fake_tensor import get_plain_tensors
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from .schemas import MutationType, SubclassCreationMeta, ViewAndMutationMeta
from .utils import strict_zip
zip = strict_zip
def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool:
args_flattened = pytree.arg_tree_leaves(*args)
any_subclass_args = any(
is_traceable_wrapper_subclass(x)
for x in args_flattened
if isinstance(x, Tensor)
)
from torch._functorch._aot_autograd.schemas import SubclassCreationMeta
any_subclass_outputs = any(
type(x) is SubclassCreationMeta for x in fw_metadata.subclass_fw_graph_out_meta
)
# This tells us whether or not we need to perform any unwrapping/wrapping of tensor subclasses at runtime.
return any_subclass_args or any_subclass_outputs
def create_subclass_metadata(a, start_idx):
if not is_traceable_wrapper_subclass(a):
return None, start_idx + 1
inner_keys, metadata = a.__tensor_flatten__()
new_start_idx = start_idx
attrs = {}
for key in inner_keys:
new_subclass_meta, new_start_idx = create_subclass_metadata(
getattr(a, key), new_start_idx
)
attrs[key] = new_subclass_meta
# It *must* be because is_traceable_wrapper_subclass() - but mypy is not smart.
assert isinstance(a, Tensor)
return (
SubclassCreationMeta(
flat_tensor_start_idx=start_idx,
arg_count=new_start_idx - start_idx,
attrs=attrs,
meta=metadata,
outer_size=a.size(), # type: ignore[attr-defined, arg-type]
outer_stride=a.stride(), # type: ignore[arg-type]
original_subclass=a,
),
new_start_idx,
)
# Given a real tensor subclass, returns a flat list of plain tensor type names
def get_types_for_subclass(tensor_subclass):
if not is_traceable_wrapper_subclass(tensor_subclass):
return ["Tensor"]
inner_keys, _ = tensor_subclass.__tensor_flatten__()
result = []
for key in inner_keys:
inner_tensor = getattr(tensor_subclass, key)
result.extend(get_types_for_subclass(inner_tensor))
return result
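# Example (editor's sketch): for a hypothetical TwoTensor-style subclass whose inner
# attributes are one plain tensor and another TwoTensor of two plain tensors,
# get_types_for_subclass returns ["Tensor", "Tensor", "Tensor"]: one entry per innermost
# dense tensor, in flattening order.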
# Given a flat list of arguments, some of which may be tensor subclasses,
# computes metadata about "how to reconstruct the current list of subclasses,
# if we were given their flattened dense tensors instead"
def create_subclass_meta(
curr_args: Union[List[Any], Tuple[Any, ...]]
) -> List[Union[int, SubclassCreationMeta]]:
idx = 0
infos: List[Union[int, SubclassCreationMeta]] = []
for a in curr_args:
if is_traceable_wrapper_subclass(a):
assert isinstance(a, Tensor)
start_idx = idx
subclass_meta, _ = create_subclass_metadata(a, start_idx)
infos.append(subclass_meta)
cnt = subclass_meta.arg_count
else:
infos.append(idx)
cnt = 1
idx += cnt
return infos
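# Example (editor's sketch): for curr_args = [dense, two_tensor], where two_tensor is a
# traceable wrapper subclass holding two plain tensors, create_subclass_meta returns
#   [0, SubclassCreationMeta(flat_tensor_start_idx=1, arg_count=2, ...)]
# i.e. the dense arg is flat input 0 and the subclass consumes flat inputs 1 and 2.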
# Output structure:
# - List[Tensor] if tracing an inference graph
# - Tuple[List[Tensor], List[Tensor]] if tracing a joint graph.
# This function effectively concats each inner list of subclass tensors
# into a (potentially longer) list of inner tensors.
#
# This function takes in a pytree of arguments and unwraps any tensor subclasses.
# Annoyingly, we can't use pytrees to perform the unwrapping, because unwrapping returns
# a list of tensors that we would then need to concat together.
# Instead, we specialize the logic for the inference vs. joint graph case.
# NOTE: this function is hot, since we unwrap tensor subclass inputs at runtime
def unwrap_tensor_subclasses(wrapped_args, *, is_joint_structure: bool):
def concat_inner_tensors_from_subclasses(xs):
xs_inner = []
for x in xs:
if is_traceable_wrapper_subclass(x):
xs_inner.extend(get_plain_tensors(typing.cast(Tensor, x)))
else:
xs_inner.append(x)
return xs_inner
if is_joint_structure:
assert isinstance(wrapped_args, tuple) and len(wrapped_args) == 2
assert isinstance(wrapped_args[0], (tuple, list)) and isinstance(
wrapped_args[1], (tuple, list)
)
unwrapped_args_fw = concat_inner_tensors_from_subclasses(wrapped_args[0])
unwrapped_args_tangents = concat_inner_tensors_from_subclasses(wrapped_args[1])
unwrapped_args = (unwrapped_args_fw, unwrapped_args_tangents)
else:
assert isinstance(wrapped_args, (list, tuple))
unwrapped_args_fw = concat_inner_tensors_from_subclasses(wrapped_args)
unwrapped_args = unwrapped_args_fw
return unwrapped_args
def remap_unwrapped_subclass_arg_indices(wrapped_args, static_input_indices):
static_input_indices = set(static_input_indices)
new_ind = 0
remapped_static_indices = []
for i, arg in enumerate(wrapped_args):
num_indices = 1
if is_traceable_wrapper_subclass(arg):
num_indices = len(get_plain_tensors(typing.cast(Tensor, arg)))
for _ in range(num_indices):
if i in static_input_indices:
remapped_static_indices.append(new_ind)
new_ind += 1
return remapped_static_indices
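# Example (editor's sketch): for wrapped_args = [dense, two_tensor, dense] (the subclass
# holds two plain tensors) and static_input_indices = [1], this returns [1, 2]: the static
# subclass at outer index 1 expands to flat (unwrapped) indices 1 and 2.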
# Turns a flattened list of tensor arguments into (maybe) subclass tensors.
# This function is used both at trace time and runtime, so we have an is_runtime flag telling us which context we're in.
def wrap_tensor_subclasses(
unwrapped_args: Union[Tuple[Any, ...], List[Any]],
*,
subclass_metas: List[Union[int, SubclassCreationMeta]],
num_fw_outs_saved_for_bw: Optional[int] = None,
is_runtime: bool = False,
) -> Tuple[Any, ...]:
wrapped_args = []
num_args_tallied = 0
for subclass_meta in subclass_metas:
if isinstance(subclass_meta, int):
wrapped_args.append(unwrapped_args[subclass_meta])
num_args_tallied += 1
else:
assert isinstance(subclass_meta, SubclassCreationMeta)
wrapped_args.append(
subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime)
)
num_args_tallied += subclass_meta.arg_count
# Note: [Partitioner handling for Subclasses, Part 2]
# At the beginning of AOTAutograd, we collect metadata on the inputs and outputs of the user fw,
# to figure out which inputs/outputs are subclasses, and how to reconstruct the subclasses after flattening them.
#
# When this function is called at runtime in the forward,
# we have been passed a list of (flattened) dense-tensor fw-outs, and need to reconstruct any subclass fw outs.
#
# One reasonable question that you should ask: when should the dense_tensor -> subclass_tensor wrapping happen?
# Answer: we do it **inside of our compiled autograd.Function**.
# This seems like morally the right place: autograd happens above subclass desugaring,
# so autograd should see actual tensor subclasses at runtime, and not flattened dense tensors.
#
# This causes a tricky interaction though: when we run the min-cut partitioner to divvy up the joint graph
# into a forward and backward graph, we end up with some activations that show up as extra outputs
# in the compiled forward graph, that are **not** user outputs.
# These activations are not visible to the user, and so there's no need for us to wrap them back into subclasses.
#
# On top of that, when we first computed subclass metadata (in `run_functionalized_fw_and_collect_metadata`),
# we computed subclass metadata on every forward output, but this did **not** include activations
# created by the partitioner.
# As a result, `unwrapped_args` here will correspond to (*unwrapped_user_fw_outs, *activations),
# but `subclass_metas` will only correspond to subclass metadata on `user_fw_outs`.
# We then need to make sure that we return (*wrapped_user_fw_outs, *activations).
if num_fw_outs_saved_for_bw is not None:
assert len(unwrapped_args) == num_args_tallied + num_fw_outs_saved_for_bw, (
f"Expected the number actual unwrapped-subclass outputs {len(unwrapped_args)} to equal "
f"the number of args calculated from subclasses ({num_args_tallied}) plus the number of "
f"additional activations saved for the backward pass ({num_fw_outs_saved_for_bw})"
)
activations = unwrapped_args[num_args_tallied:]
if isinstance(wrapped_args, tuple) and isinstance(activations, tuple):
return wrapped_args + activations
return tuple(list(wrapped_args) + list(activations))
else:
assert len(unwrapped_args) == num_args_tallied
return tuple(wrapped_args)
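# Example (editor's sketch): with subclass_metas = [SubclassCreationMeta(arg_count=2, ...), 2],
# unwrapped_args = (a, b, c, act0) and num_fw_outs_saved_for_bw = 1, this returns
# (subclass_rebuilt_from(a, b), c, act0): the subclass meta consumes a and b, the int meta
# indexes element 2 directly, and the trailing saved activation is passed through unwrapped.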
# Given a bunch of "dense" tensor arguments, this function (potentially) wraps them into tensor subclasses.
# This function carefully handles the inference vs. joint cases:
# - when is_joint_structure is True, args is (primals, tangents)
# - when is_joint_structure is False, args is [*primals]
def wrap_tensor_subclasses_maybe_joint(
unwrapped_args, *, is_joint_structure: bool, meta: ViewAndMutationMeta
) -> Union[Tuple[Any, ...], List[Any]]:
# Since this function is re-used for both inference and joint graphs, it needs to handle both calling conventions.
if is_joint_structure:
assert isinstance(unwrapped_args, tuple) and len(unwrapped_args) == 2
assert isinstance(unwrapped_args[0], (tuple, list)) and isinstance(
unwrapped_args[1], (tuple, list)
)
primals, tangents = unwrapped_args[0], unwrapped_args[1]
wrapped_primals = wrap_tensor_subclasses(
primals, subclass_metas=meta.subclass_inp_meta
)
wrapped_tangents = wrap_tensor_subclasses(
tangents, subclass_metas=meta.subclass_tangent_meta
)
return (wrapped_primals, wrapped_tangents)
else:
wrapped_args = wrap_tensor_subclasses(
unwrapped_args, subclass_metas=meta.subclass_inp_meta
)
return wrapped_args
# TODO: UNUSED. delete?
def create_metadata_for_subclass(meta: ViewAndMutationMeta) -> ViewAndMutationMeta:
# input infos
input_info = []
for inp, subclass_meta in zip(meta.input_info, meta.subclass_inp_meta):
num_inps = 1 if isinstance(subclass_meta, int) else subclass_meta.arg_count
for _ in range(num_inps):
input_info.append(inp)
# output infos
output_info = []
subclass_out_meta_user_outs_only = meta.subclass_fw_graph_out_meta[
meta.num_mutated_inp_runtime_indices :
]
if meta.num_intermediate_bases > 0:
subclass_out_meta_user_outs_only = subclass_out_meta_user_outs_only[
: -meta.num_intermediate_bases
]
# sanity assert
assert len(meta.output_info) == len(subclass_out_meta_user_outs_only)
# Assume that the information on the output is shared by all of its inner tensors.
for out, subclass_meta in zip(meta.output_info, subclass_out_meta_user_outs_only):
num_outs = 1 if isinstance(subclass_meta, int) else subclass_meta.arg_count
for _ in range(num_outs):
output_info.append(out)
# A bit hacky, but we don't actually care about all of the metadata here.
# This metadata is used **underneath** both autograd and subclass de-sugaring,
# So all we really care about is stuff like:
# - num inputs/outputs (needed by the partitioner)
# - input mutations (**not** used today, since we don't handle input mutations inside the subclass,
# although we should handle this eventually)
# TODO: add a test case to assert we error when this happens, instead of getting silent correctness
num_intermediate_bases = None
keep_input_mutations = meta.keep_input_mutations
traced_tangents = None
subclass_inp_meta = None
subclass_fw_graph_out_meta = None
subclass_tangent_meta = None
metadata = ViewAndMutationMeta(
input_info=input_info, # type: ignore[arg-type]
output_info=output_info, # type: ignore[arg-type]
num_intermediate_bases=num_intermediate_bases, # type: ignore[arg-type]
keep_input_mutations=keep_input_mutations, # type: ignore[arg-type]
traced_tangents=traced_tangents, # type: ignore[arg-type]
subclass_inp_meta=subclass_inp_meta, # type: ignore[arg-type]
subclass_fw_graph_out_meta=subclass_fw_graph_out_meta, # type: ignore[arg-type]
subclass_tangent_meta=subclass_tangent_meta, # type: ignore[arg-type]
)
return metadata
def compute_inner_mutated_inp_indices_from_subclass_meta(
fw_metadata: ViewAndMutationMeta,
inner_metadata: ViewAndMutationMeta,
) -> List[int]:
# Note: [Recomputing subclass mutation handling]
#
# Generally, if a subclass requires grad, its components will not require grad.
# But for the purposes of tracking returned tensors, we should treat those component
# tensors as if they require grad.
#
# For example, if the subclass tensor requires grad and will be mutated in a way that
# requires us to handle the mutation outside of the graph, we need to return it
# from the forward graph. The inner_meta data won't consider the component tensors
# as if they need to be returned, because they don't require grad; but really, we
# should handle those tensors the same way we handle the subclass tensor itself; i.e.
# if we'd include the subclass tensor as part of the outputs, then we should also
# include the component tensors.
#
# To do this, we patch num_mutated_inp_runtime_indices below by expanding the inputs
# from the outer subclass tensors and propagating their mutation info to the inner dense tensors.
updated_input_info = []
inner_idx = 0
if not fw_metadata.subclass_inp_meta:
# Sometimes we don't have subclass info, e.g. synthetic_base codepaths
return inner_metadata.mutated_inp_runtime_indices
assert len(fw_metadata.subclass_inp_meta) == len(fw_metadata.input_info)
for outer_idx, inp_meta in enumerate(fw_metadata.subclass_inp_meta):
if isinstance(inp_meta, int):
assert outer_idx < len(fw_metadata.input_info)
if inner_metadata is not None:
assert inner_idx < len(inner_metadata.input_info)
assert (
inner_metadata.input_info[inner_idx]
== fw_metadata.input_info[outer_idx]
)
updated_input_info.append(fw_metadata.input_info[outer_idx])
inner_idx += 1
else:
for _ in range(inp_meta.arg_count):
updated_input_info.append(fw_metadata.input_info[outer_idx])
inner_idx += 1
if inner_metadata is not None:
assert len(inner_metadata.input_info) == len(updated_input_info)
return [
i
for i, inp in enumerate(updated_input_info)
if inp.mutation_type == MutationType.MUTATED_OUT_GRAPH
]
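# Example (editor's sketch): if the only mutated input is a subclass holding two dense
# tensors, fw_metadata reports a single mutated outer index, while this helper expands the
# input info per inner tensor so that both dense indices are reported for the inner graph.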

View File

@ -0,0 +1,881 @@
# mypy: allow-untyped-defs
"""
This module is responsible for transforming functions to be traced into a form
that is easier for the downstream infra (e.g. Autograd, FX, AOTAutograd analysis)
to handle.
It does so by:
1. functionalization (including RNG functionalization)
2. creating a joint graph when required
3. transforming mutations into extra outputs
4. dispatching subclasses
"""
import warnings
from contextlib import contextmanager, nullcontext
from functools import wraps
from typing import Any, Callable, List, Tuple, Union
from unittest.mock import patch
import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch import Tensor
from torch._decomp.decompositions_for_rng import PhiloxStateTracker
from torch._guards import detect_fake_mode
from torch._prims_common import CUDARngStateHelper
from torch.fx.experimental.proxy_tensor import (
maybe_disable_thunkify,
maybe_enable_thunkify,
)
from torch.fx.experimental.symbolic_shapes import (
definitely_false,
PropagateUnbackedSymInts,
sym_eq,
)
from torch.nn.utils import stateless
from .. import config
from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
from .functional_utils import (
from_fun,
has_data_mutation,
has_metadata_mutation,
is_fun,
sync_functional_tensor,
to_fun,
was_inductor_storage_resized,
)
from .logging_utils import setup_stacktrace_preservation_hooks
from .schemas import (
AOTConfig,
MutationType,
OutputType,
SubclassMeta,
SubclassTracingInfo,
ViewAndMutationMeta,
)
from .subclass_utils import (
create_subclass_meta,
remap_unwrapped_subclass_arg_indices,
requires_subclass_dispatch,
unwrap_tensor_subclasses,
wrap_tensor_subclasses_maybe_joint,
)
from .utils import maybe_to_fresh_input
# This function returns a new function that returns mutated inputs as outputs.
# if keep_data_input_mutations is set, then we assume that data-only mutations
# will be left in the graph, and we only return metadata-mutated inputs as outputs.
def fn_input_mutations_to_outputs(
fn: Callable,
meta: ViewAndMutationMeta,
keep_data_input_mutations: bool,
) -> Any:
@wraps(fn)
def inner_fn(*args):
outs = fn(*args)
assert len(meta.output_info) == len(outs)
# The compiled fw will return mutated input tensors, *including* metadata-only mutation.
# However, if keep_data_input_mutations is set, the compiled fw only needs to return metadata-mutated inputs.
# (because data-only input mutations are handled directly in the compiled graph)
mutated_inputs_to_return = [
x for (i, x) in enumerate(args) if i in meta.mutated_inp_runtime_indices
]
return *mutated_inputs_to_return, *outs
return inner_fn
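# A minimal, standalone sketch of the wrapper above on plain tensors, assuming
# a single data-mutated input at index 0 (a hardcoded stand-in for
# meta.mutated_inp_runtime_indices); no ViewAndMutationMeta is involved.
def _demo_mutated_inputs_as_outputs():
    import torch

    def user_fn(x, y):
        x.mul_(2)  # data mutation on input 0
        return (x + y,)

    mutated_inp_runtime_indices = [0]  # hypothetical, hardcoded for the sketch

    def wrapped(*args):
        outs = user_fn(*args)
        mutated_inputs_to_return = [
            a for (i, a) in enumerate(args) if i in mutated_inp_runtime_indices
        ]
        return *mutated_inputs_to_return, *outs

    x, y = torch.ones(3), torch.ones(3)
    new_x, out = wrapped(x, y)
    assert torch.equal(new_x, torch.full((3,), 2.0))
    assert torch.equal(out, torch.full((3,), 3.0))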
# This function takes in a fn with external aliasing and mutation,
# and returns a new fn with no external aliasing and mutation,
# as needed for autograd.
# The main transformations are:
# - Return mutated inputs as extra outputs
# - Clone mutated inputs that require gradients,
# because autograd will require us to pass the pre-mutated inputs into autograd.grad
# - Return intermediate bases of outputs as additional outputs,
# needed to appease autograd.Function
# The new function returns:
# (1) The updated outputs
# (2) A boolean mask of len(new_fn_outputs),
# that can be used to tell autograd.grad which outputs should get tangents
# if we trace the backward.
def fn_prepped_for_autograd(
fn: Callable,
meta: ViewAndMutationMeta,
) -> Any:
@wraps(fn)
def inner_fn(*args):
args_maybe_cloned = [
maybe_to_fresh_input(i, t, meta) for i, t in enumerate(args)
]
outs = fn(*args_maybe_cloned)
assert isinstance(outs, (tuple, list))
outs = list(outs)
assert len(meta.output_info) == len(outs)
mutated_inputs_to_return = [
x
for (i, x) in enumerate(args_maybe_cloned)
if i in meta.mutated_inp_runtime_indices
]
intermediate_bases = []
for i, (o, info) in enumerate(zip(outs, meta.output_info)):
if info.output_type == OutputType.alias_of_intermediate_save_as_output:
intermediate_bases.append(o._base)
assert meta.num_intermediate_bases == len(intermediate_bases)
# the compiled forward should return (mutated_inputs, user_outs, intermediate_bases)
fw_outs_to_return = *mutated_inputs_to_return, *outs, *intermediate_bases
# Also return a boolean mask specifying which outputs of this function will be used as tangents
mutated_inputs_grad_mask = [
meta.input_info[meta.mutated_inp_runtime_indices[i]].mutates_data
and meta.input_info[meta.mutated_inp_runtime_indices[i]].requires_grad
for (i, x) in enumerate(mutated_inputs_to_return)
]
# Pass any (non-aliased) outputs in as tangents, since they'll be returned as outputs in the fw
# For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead,
# which we *should* send to grad()
output_grad_mask = [
meta.output_info[i].output_type
in [
OutputType.non_alias,
OutputType.unsafe_view_alias,
OutputType.custom_function_view,
]
# Also, only tensor outputs should participate in the backward
# (in particular, Symint outputs in the forward graph shouldn't get tangents)
and issubclass(meta.output_info[i].raw_type, Tensor)
and meta.output_info[i].requires_grad
for (i, x) in enumerate(outs)
]
intermediate_base_grad_mask = [True for _ in range(len(intermediate_bases))]
out_grad_mask = (
mutated_inputs_grad_mask + output_grad_mask + intermediate_base_grad_mask
)
assert len(out_grad_mask) == len(fw_outs_to_return)
# Take care to grab and sync the updated inputs from primals_after_cloning (the inputs we actually mutate!)
# and not primals (the preserved inputs, pre-mutation, that we pass to grad())
# This is annoying: our joint function needs to be aware of functionalization
# (syncing mutated inputs before calling autograd.grad())
# In theory, we could make the autograd engine do this automatically, although that probably isn't any cleaner.
for arg in args_maybe_cloned:
if not isinstance(arg, Tensor):
continue
sync_functional_tensor(arg)
return fw_outs_to_return, out_grad_mask
return inner_fn
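# A minimal, standalone sketch of why mutated inputs that require grad are
# cloned above: autograd.grad() must see the pre-mutation value through the
# original input, so the mutation is applied to a clone instead.
def _demo_clone_before_mutation_for_autograd():
    import torch

    x = torch.tensor([2.0], requires_grad=True)
    x_work = x.clone()  # what maybe_to_fresh_input would hand to the user fn
    x_work.mul_(3.0)    # the "input mutation" happens on the clone
    out = x_work * x_work
    (grad_x,) = torch.autograd.grad(out.sum(), x)
    # out = (3x)^2 = 9x^2, so d(out)/dx = 18x = 36 at x = 2
    assert torch.allclose(grad_x, torch.tensor([36.0]))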
# Given a fn, computes the joint.
# NOTE: fn is expected to have the following behavior:
# (1) fn() needs to return a tuple of (outs, mask),
# where `mask` tells us which outputs are meant to have tangents.
# we don't know this info automatically, because we don't actually want to blindly
# compute tangents for every output that requires grad.
# Specifically, outputs that alias inputs won't participate in the backward or receive tangents.
# (2) fn() cannot mutate any inputs that require gradient.
# otherwise, when we compute autograd.grad(), we will not take those input mutations into account
# (the way this is handled is that we ensure any inputs that normally get mutated are cloned first)
def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
def inner_fn(primals: List[Any], tangents: List[Any]):
outs, tangent_mask = fn(*primals)
assert len(tangent_mask) == len(outs)
outs_to_grad = [
o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent
]
assert len(outs_to_grad) == len(tangents)
# Get the inputs that need gradients
grad_primals = []
inputs_needs_grads = []
# Note that we're not using primals here,
# being careful not to pass any mutated inputs into autograd.grad()
for p in primals:
is_grad_tensor = isinstance(p, Tensor) and p.requires_grad
inputs_needs_grads.append(is_grad_tensor)
if is_grad_tensor:
grad_primals.append(p)
# Get the outputs that need gradients
needed_outs = []
needed_tangents = []
for out, tangent in zip(outs_to_grad, tangents):
if isinstance(out, Tensor) and out.requires_grad:
# A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32
# The issue is that we are sensitive to decomps that don't accurately maintain
# their output's _base.shape compared to eager mode, and this helps mitigate a bit.
# The not definitely_false is also sketchy; if unbacked
# symints are involved, we're just going to assume that the
# decomps set up the base shape correctly
needed_outs.append(
out
if not definitely_false(sym_eq(out.shape, tangent.shape))
else out.view(tangent.shape)
)
needed_tangents.append(tangent)
setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs])
if config.functionalize_rng_ops:
PhiloxStateTracker.mark_beginning_of_backward()
backward_out: Tuple[Tensor, ...] = ()
# Call the backwards pass
if grad_primals:
functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
torch._C._TorchDispatchModeKey.FUNCTIONAL
)
if functional_tensor_mode is not None:
# Side-Effect Tokens:
# We want to have independent chains of tokens for forward and backward.
# functional_tensor_mode._tokens is used by both.
# We memoize the result tokens of forward in functional_tensor_mode._tokens_forward_output,
# to return them as joint graph outputs.
# We clean functional_tensor_mode._tokens before backward, to prevent reuse of forward tokens in backward.
# Joint graph tracing allows token discovery,
# so all the tokens in the backward will be created and added as graph inputs during tracing.
functional_tensor_mode._tokens_forward_output = (
functional_tensor_mode._tokens
)
functional_tensor_mode._tokens = {}
with set_partitioner_tag_is_backward(), fx_traceback.preserve_node_meta():
# for full graph export, we always export a joint graph where we assume no tangents are needed.
if aot_config.no_tangents:
assert len(needed_tangents) == 1 and needed_tangents[0].numel() == 1
backward_out = torch.autograd.grad(
needed_outs,
grad_primals,
allow_unused=True,
)
else:
backward_out = torch.autograd.grad(
needed_outs,
grad_primals,
grad_outputs=needed_tangents,
allow_unused=True,
)
backward_out_iter = iter(backward_out)
return outs, [
next(backward_out_iter) if i else None for i in inputs_needs_grads
]
def inner_fn_with_anomaly(*args):
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
warnings.filterwarnings("ignore", "Anomaly Detection has been enabled.")
with torch.autograd.detect_anomaly(check_nan=False):
return inner_fn(*args)
return inner_fn_with_anomaly
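# A minimal, standalone sketch of the (primals, tangents) -> (outs, grad_inputs)
# shape that create_joint produces, with none of the RNG/functionalization or
# anomaly-mode handling above; fw here is a toy forward returning (outs, mask).
def _demo_minimal_joint():
    import torch

    def fw(x):
        out = x.sin()
        return [out], [True]  # (outs, tangent_mask): this output gets a tangent

    def joint(primals, tangents):
        outs, tangent_mask = fw(*primals)
        outs_to_grad = [o for needs, o in zip(tangent_mask, outs) if needs]
        grads = torch.autograd.grad(
            outs_to_grad, primals, grad_outputs=tangents, allow_unused=True
        )
        return outs, list(grads)

    x = torch.tensor([0.0], requires_grad=True)
    outs, grad_inputs = joint([x], [torch.ones(1)])
    assert torch.allclose(grad_inputs[0], x.cos())  # d/dx sin(x) = cos(x)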
def create_functionalized_rng_ops_wrapper(func, args, trace_joint=True) -> Any:
# Functionalization of rng ops changes the calling convention of the joint graph.
# It goes from (primals, tangents) to (seed, offset, primals, tangents)
# At runtime, we pass on the current seed and offset. This is hidden from
# the user.
fake_mode = detect_fake_mode()
if fake_mode is None:
fake_mode = nullcontext()
def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"):
out = PhiloxStateTracker.get_state_as_tensor()
return out
def override_set_rng_state(x, device: Union[int, str, torch.device] = "cuda"):
PhiloxStateTracker.set_state_from_tensor(x)
def append_rng_offsets(args):
if trace_joint:
# args signature before: Tuple(fwd_outputs), Tuple(bwd_outputs)
# args signature after: Tuple(fwd_outputs, new_fwd_rng_offset), Tuple(bwd_outputs, new_bwd_rng_offset)
return (
(*args[0], PhiloxStateTracker.get_updated_fwd_offset()),
(*args[1], PhiloxStateTracker.get_updated_bwd_offset()),
)
else:
# args signature before: Tuple(fwd_outputs)
# args signature after: Tuple(fwd_outputs, new_fwd_rng_offset)
return (*args, PhiloxStateTracker.get_updated_fwd_offset())
def traced_joint(
primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset
):
with patch("torch.cuda.get_rng_state", override_get_rng_state), patch(
"torch.cuda.set_rng_state", override_set_rng_state
):
return append_rng_offsets(func(primals, tangents))
def traced_forward(*primals_fwd_seed_fwd_base_offset):
# The signature is (*primals, seed, offset)
with patch("torch.cuda.get_rng_state", override_get_rng_state), patch(
"torch.cuda.set_rng_state", override_set_rng_state
):
return append_rng_offsets(func(*primals_fwd_seed_fwd_base_offset[:-2]))
if trace_joint:
# Get the current seed and offset to setup tracing.
fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
fake_mode
)
bwd_seed, bwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
fake_mode
)
PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
PhiloxStateTracker.record_state(bwd_seed, bwd_base_offset, "backward")
return traced_joint, (
*args,
fwd_seed,
fwd_base_offset,
bwd_seed,
bwd_base_offset,
)
else:
# Get the current seed and offset to setup tracing.
fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
fake_mode
)
PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
return traced_forward, (*args, fwd_seed, fwd_base_offset)
@contextmanager
def set_partitioner_tag(tag: str):
meta_key = "partitioner_tag"
assert fx_traceback.has_preserved_node_meta()
original_val = fx_traceback.current_meta.get(meta_key, None)
fx_traceback.current_meta[meta_key] = tag
try:
yield
finally:
fx_traceback.current_meta[meta_key] = original_val
def set_partitioner_tag_is_backward():
return set_partitioner_tag("is_backward")
def set_partitioner_tag_must_be_in_backward():
return set_partitioner_tag("must_be_in_backward")
# This creates the final function that we want to trace using make_fx(),
# in both aot_dispatch_autograd and aot_dispatch_base.
# Preconditions:
# - fn corresponds to the user's fw function
# - fn arguments have been flattened, duplicate arguments have been handled
# - In the returned function, the "primals" arguments *include* synthetic bases.
# This function does the work of functionalizing the input function,
# and performing copy_() calls at the end of the function if `keep_input_mutations` is set.
# The returned function has a signature that is either:
# (1) "traced_fn(primals: List[Any])" if trace_joint is False
# (2) "traced_fn(primals: List[Any], tangents: List[Any])" if trace_joint is True
# Returns a new (functionalized) function, and updated arguments to call it with.
def create_functionalized_fn(
fn,
args,
*,
meta: ViewAndMutationMeta,
aot_config: AOTConfig,
trace_joint: bool,
) -> Any:
@wraps(fn)
def _functionalized_f_helper(*args):
with maybe_enable_thunkify():
# See Note [Disabling Functionalize TLS Above Python Functionalization]
disable_above = torch._C._ExcludeDispatchKeyGuard(
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
)
with disable_above:
# The functionalization code here can potentially trigger traces
# into the graph, but we'd prefer to NOT do this, because if we
# trace them now, we will end up with FX nodes that don't have
# module stack annotations, which makes unflattener unhappy.
# Wrap inputs into functional wrappers
f_args = pytree.tree_map(to_fun, args)
# Run the joint
f_outs = fn(*f_args)
if trace_joint:
# We support a limited amount of mutation of graph inputs during the backward pass.
# (This is used e.g. by Float8, which needs to update buffers during the backward pass)
# Here, we perform extra checks for primals that were mutated in the **backward**
# We're doing the checks here instead of doing them with the rest of the input mutation handling because:
# - We need to detect inputs that were mutated in the backward **separately** from mutations that happened
# during the forward, because the handling is different: some input mutations from the forward
# can be only handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same
# types of mutations in the backward we would need a bw-only runtime epilogue.
# - We could in theory have our analysis pass differentiate mutations in the fw from mutations in
# the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would
# require an extra round of tracing though, so it's more efficient to do in-line here.
assert (
isinstance(args, tuple)
and len(args) == 2
and isinstance(args[0], (list, tuple))
)
# Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw)
primals_before = args[0]
primals_after = pytree.tree_map(from_fun, f_args[0])
for idx, (f_inpt, before, after, inpt_info) in enumerate(
zip(f_args[0], primals_before, primals_after, meta.input_info)
):
# Store information about mutations in joint(for backward analysis)
joint_mutates_data = has_data_mutation(f_inpt)
joint_mutates_metadata = has_metadata_mutation(
f_inpt, before, check_only_storage_mutation=False
)
# Ban metadata mutations on fw inputs during the bw
if not inpt_info.mutates_metadata:
assert (
not joint_mutates_metadata
), "Found a graph input that had its metadata mutated in the backward. This is not supported"
# Ban storage resizing on fw inputs during the bw
if not inpt_info.mutation_inductor_storage_resize:
assert not was_inductor_storage_resized(
f_inpt
), "Found a graph input that had storage resizing in the backward. This is not supported"
# Allow data mutations on fw inputs during the bw, but only if they do not require grad
# So we can guarantee that we can keep the mutations in the graph
if (
joint_mutates_data
and not inpt_info.mutates_data
and not inpt_info.mutates_storage_metadata
):
# We don't ban mutations here when inpt_info.requires_grad -
# we'll check at runtime and fail only when the backward runs under torch.is_grad_enabled (create_graph)
# Add node meta for copy_ for partitioner that this node should be in backward graph.
with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag_must_be_in_backward():
before.copy_(after)
meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append(
idx
)
# Now that we covered mutations to *forward* inputs during the backward,
# we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out).
# Today, we will just error in all cases of this happening unless someone needs us to support it.
tangents_before = args[1]
tangents_after = pytree.tree_map(from_fun, f_args[1])
for f_inpt, before, after in zip(
f_args[1], tangents_before, tangents_after
):
assert not has_metadata_mutation(
f_inpt, before, check_only_storage_mutation=False
) and not has_data_mutation(
f_inpt
), "Found an input to the backward that was mutated during the backward pass. This is not supported"
if aot_config.keep_inference_input_mutations:
# Note: This is a bit annoying. There's a layering issue here, where:
# (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs.
# (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs.
# However, we **only** want to support this for inputs that have data-only (and no metadata) mutations,
# because inductor (and backends in general) would prefer not to see these (e.g. as_strided_(), resize_()).
# This makes it pretty difficult for this logic to operate on synthetic bases.
# (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual
# (unpacked) input aliases, instead of the synthetic base.
# Example case where (3) could be important:
#
# def f(x, y):
# x.mul_(2)
# y.mul_(3)
# return x, y
# a = torch.ones(1_000_000)
# x, y = f(a[0:9], a[1:10])
#
# It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing
# a giant "updated synthetic base" and copying into a's entire storage.
#
# For now, we are pessimistically not performing the optimization from (3);
# we will materialize an "updated" synthetic base, and copy it back to the synthetic input base.
# This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry
# about synthetic bases.
for i, (inpt_old, inpt_f) in enumerate(
zip(args, f_args) if not trace_joint else zip(args[0], f_args[0])
):
if not isinstance(inpt_f, torch.Tensor):
continue
assert is_fun(inpt_f)
inpt_new = from_fun(inpt_f)
if (
meta.input_info[i].mutation_type
== MutationType.MUTATED_IN_GRAPH
):
# See Note [set_() Input Mutations in AOTAutograd]
# all mutations on the input must be under no_grad, so it is safe to put in the graph
# Here, we're saying that if an input experienced a set call, inp.set_(other),
# then we can effectively not have to worry about whether its data was mutated.
# There are 3 cases:
# (1) We mutate inp *after* the set_() call. other is a graph intermediate.
# In this case, we're not really mutating the input storage of "inp";
# we're mutating the storage of an intermediate value (other),
# and slamming that storage into the input tensor. So no data mutation is necessary.
# (2) We mutate inp *after* the set_() call. other is a graph *input*.
# In this case, the data mutation will be properly handled in the runtime
# epilogue during the processing of "other"
# (3) We mutate inp *before* the set_() call.
# This case is *not* currently handled.
if meta.input_info[i].mutates_storage_metadata:
with torch.no_grad():
inpt_old.set_(inpt_new)
# Note [Ordering of resize_() and set_()]
# Importantly: the common usage in FSDP is that we have a dummy parameter
# that sees a set_() and **Then** a resize_().
# We must put those mutations into the graph in the same order,
# since running them in the opposite order will have different behavior.
# We fully ban resize_() followed by set_() for now, although in principle
# we could support this.
if meta.input_info[i].mutation_inductor_storage_resize:
# resizing is not supported on subclasses (we error earlier if this happens)
from torch._subclasses.functional_tensor import (
FunctionalTensor,
)
assert isinstance(inpt_f, FunctionalTensor)
old_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
inpt_f.elem, before=True
)
new_storage_size = torch._functionalize_get_storage_size( # type: ignore[attr-defined]
inpt_f.elem, before=False
)
if old_storage_size != new_storage_size:
assert (
old_storage_size == 0 or new_storage_size == 0
), f"""\
Encountered a storage resize during tracing on input {i}. Old nbytes={old_storage_size}, new nbytes={new_storage_size}
We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0
(the case for FSDP)"""
torch.ops.inductor.resize_storage_bytes_(
inpt_old, new_storage_size
)
if new_storage_size == 0:
# Even if we marked the input as having a data mutation (thus needing a copy_()),
# We should **ignore** it if our input has no storage
# (this can happen if, e.g. we temporarily resize our input, copy data into it,
# and resize it back down to zero)
continue
# Optimization: if the copy_() is a no-op then don't include it in the graph.
# In theory inductor could optimize this away, however in fsdp, we end up with
# param.copy_(param), where param is a zero-storage-size tensor,
# and running this op in eager mode (using the aot_eager backend) will result in a segfault.
# So we may as well optimize it away here.
if inpt_old is inpt_new:
# (This check needs to be done after putting resize_() in the graph,
# since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor)
continue
# We found an input that had a (data-only) mutation.
# Since keep_input_mutations is set, we need to faithfully apply a copy_()
# so the compiler will see the input mutation in the graph.
if (
meta.input_info[i].mutates_data
and meta.input_info[i].mutations_hidden_from_autograd
):
# Hidden from autograd = run under no_grad, **and** don't bump VC
# (although if the tensor was created in inference mode, it has no VC)
if inpt_old.is_inference():
maybe_preserve_vc = nullcontext()
else:
maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
inpt_old # type: ignore[assignment]
)
with torch.no_grad(), maybe_preserve_vc:
inpt_old.copy_(inpt_new)
elif (
meta.input_info[i].mutates_data
and meta.input_info[
i
].mutations_under_no_grad_or_inference_mode
):
# Under no_grad = run under no_grad (we still bump the VC though)
# (inference_mode will also bump the VC, as long as the tensor in question
# was created outside of inference_mode)
with torch.no_grad():
inpt_old.copy_(inpt_new)
elif meta.input_info[i].mutates_data:
inpt_old.copy_(inpt_new)
# When an output tensor is a functionalized mutated input, and we
# were able to move the mutation into the graph, then we can return
# the mutated input directly. This prevents duplicating the
# tensor's contents.
flat_outs, outs_spec = pytree.tree_flatten(f_outs)
flat_outs = [from_fun(o) for o in flat_outs]
num_outs = len(meta.output_info)
for i, outp in enumerate(flat_outs[:num_outs]):
info = meta.output_info[i]
if info.output_type != OutputType.is_input:
continue
assert info.base_idx is not None
if (
meta.input_info[info.base_idx].mutation_type
== MutationType.MUTATED_IN_GRAPH
):
fw_args = args[0] if trace_joint else args
flat_outs[i] = fw_args[info.base_idx]
return pytree.tree_unflatten(flat_outs, outs_spec)
return pytree.tree_map(from_fun, f_outs)
# Kinda annoying, but needed to make sure that the fx graph we trace out has "primals"
# and "tangents" as its input names (which are special-cased by the partitioner)
# TODO (tmanlaibaatar) revisit this if we ever need to turn on non-strict joint graph export
def joint_helper(primals, tangents):
return _functionalized_f_helper(primals, tangents)
helper = joint_helper if trace_joint else _functionalized_f_helper
if config.functionalize_rng_ops:
# Setup the wrapper for functionalization of rng ops
helper, args = create_functionalized_rng_ops_wrapper(helper, args, trace_joint)
return helper, args
def handle_effect_tokens_fn(
fn,
args,
*,
meta: ViewAndMutationMeta,
trace_joint: bool,
) -> Any:
num_tokens = len(meta.tokens)
@wraps(fn)
def inner_fn(*args):
# See Note [Disabling Functionalize TLS Above Python Functionalization]
disable_above = torch._C._ExcludeDispatchKeyGuard(
torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
)
with disable_above:
# See Note [Side-Effectful Tokens in AOTAutograd]
if trace_joint:
assert isinstance(args, tuple) and isinstance(args[0], (list, tuple))
tokens = args[0][:num_tokens]
assert all(token.numel() == 0 for token in tokens)
args = (args[0][num_tokens:], *args[1:])
else:
tokens = args[:num_tokens]
assert all(token.numel() == 0 for token in tokens)
args = args[num_tokens:]
# Populate the current FunctionalTensorMode with the tokens per
# operator. See Note [FunctionalTensorMode is Stateful]
functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
torch._C._TorchDispatchModeKey.FUNCTIONAL
)
assert functional_tensor_mode is not None
f_tokens = pytree.tree_map(to_fun, tokens)
for i, k in enumerate(meta.tokens.keys()):
functional_tensor_mode._tokens[k] = f_tokens[i]
# Run the joint
outs = fn(*args)
# Return both the tokens and the outputs
# See Note [Side-Effectful Tokens in AOTAutograd]
if trace_joint:
assert len(outs) == 2
assert len(functional_tensor_mode._tokens_forward_output) == num_tokens
fwd_out_tokens = functional_tensor_mode._tokens_forward_output.values()
bwd_out_tokens = functional_tensor_mode._tokens.values()
f_fwd_out_tokens = [from_fun(t) for t in fwd_out_tokens]
f_bwd_out_tokens = [from_fun(t) for t in bwd_out_tokens]
meta.num_backward_tokens = len(bwd_out_tokens)
return ((*f_fwd_out_tokens, *outs[0]), (*outs[1], *f_bwd_out_tokens))
out_tokens = [from_fun(t) for t in functional_tensor_mode._tokens.values()]
return (*out_tokens, *outs)
# Additionally pass in tokens as inputs
# See Note [Side-Effectful Tokens in AOTAutograd]
additional_fwd_token_inputs = [torch.tensor([])] * num_tokens
if trace_joint:
args = ([*additional_fwd_token_inputs, *args[0]], *args[1:])
else:
args = [*additional_fwd_token_inputs, *args]
return inner_fn, args
# Given a function operating on Subclass -> Subclass, returns a function that operates on Tensor -> Tensor
# Also returns:
# - the new set of arguments to pass into this function (now that tensor subclasses have been eliminated)
# - the updated ViewAndMutationMeta for this dense -> dense function.
# The other important arguments are:
# - flat_fn_maybe_joint: when is_joint_structure=True, this is the joint fw-bw function.
# when is_joint_structure=False, this is just the forward function.
# - fw_only: this is *always* the forward-only function.
# Why do we need this? We need to collect updated ViewAndMutationMeta on our new dense -> dense functions.
# In particular, we need this to tell the partitioner how many dense forward outputs there are.
def aot_dispatch_subclass(
flat_fn_maybe_joint,
args: List[Any],
*,
is_joint_structure: bool,
meta: ViewAndMutationMeta,
fw_only: Callable,
) -> SubclassTracingInfo:
# Skip logic if we don't need to trace through any subclasses
req_subclass_dispatch = requires_subclass_dispatch(args, meta)
if not req_subclass_dispatch:
return SubclassTracingInfo(
plain_tensor_trace_fn=flat_fn_maybe_joint,
plain_tensor_args=args,
maybe_subclass_meta=None,
)
# TODO: add subclass guards (later PR).
# What's going on here? We need to compute subclass metadata about the outputs of the joint (grad_inputs).
# Annoying: we don't know the grad input metas until we're in the middle of tracing the joint,
# so we set it later, while we're tracing the joint (see inner_fn() below).
# Another option would be to run our run_functionalized_fw_and_collect_metadata() function
# directly on the joint, but this would hurt compile time (adding yet another pass through the joint).
subclass_meta = SubclassMeta()
def inner_fn(fn, args, *, use_trace_joint: bool):
# Step 1: wrap tensor inputs into subclasses if necessary
all_args = wrap_tensor_subclasses_maybe_joint(
args, is_joint_structure=use_trace_joint, meta=meta
)
# Step 2: call the inner function, with our (maybe subclass) inputs
wrapped_outs = fn(*all_args)
if use_trace_joint:
# See Note: [Computing Subclass Metadata about grad_inputs]
# We also stash subclass info on our grad_inputs, if we're tracing the joint.
nonlocal subclass_meta
assert isinstance(wrapped_outs, tuple) and len(wrapped_outs) == 2
# Don't need fw outs since we already have subclass metadata on them
grad_inputs = wrapped_outs[1]
subclass_meta.grad_input_metas = create_subclass_meta(grad_inputs)
# Step 3: Unwrap any subclass outputs back into dense tensors
unwrapped_outs = unwrap_tensor_subclasses(
wrapped_outs, is_joint_structure=use_trace_joint
)
return unwrapped_outs
def joint_fn(primals, tangents):
with maybe_enable_thunkify():
return inner_fn(
flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True
)
def fw_fn(*primals):
with maybe_enable_thunkify():
return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False)
def metadata_fn(*primals):
return inner_fn(fw_only, primals, use_trace_joint=False)
args_unwrapped = unwrap_tensor_subclasses(
args, is_joint_structure=is_joint_structure
)
remapped_static_indices = remap_unwrapped_subclass_arg_indices(
args, meta.static_input_indices
)
if is_joint_structure:
primals_unwrapped = args_unwrapped[0]
fn_to_trace = joint_fn
else:
primals_unwrapped = args_unwrapped
fn_to_trace = fw_fn
# Note: [Partitioner handling for Subclasses, Part 1]
# The way the partitioner works is that:
# (1) we pass in a single graph containing the joint fw/bw,
# where the # of graph outputs corresponds to # fw_outputs + # grad_inputs
# (2) The partitioner accepts an argument, num_fwd_outputs,
# and assumes that the first "num_fwd_outputs" graph outputs correspond
# to outputs of the forward graph.
# How do tensor subclasses enter the picture?
# the num_fwd_outputs in the final graph is actually non-trivial to compute,
# because it can be influenced by input mutations and intermediate bases.
# So we compute it by inspecting the current ViewAndMutationMeta object.
# However, the original ViewAndMutationMeta that we computed was created
# on the subclass -> subclass graph,
# which can have a different number of outputs than the dense -> dense graph.
# That's why we create a fresh metadata object on the dense -> dense function here,
# and plumb it back up to the partitioner.
# See Note: [Partitioner handling for Subclasses, Part 2] for more info.
meta_updated = run_functionalized_fw_and_collect_metadata(
metadata_fn,
static_input_indices=remapped_static_indices,
keep_input_mutations=meta.keep_input_mutations,
is_train=meta.is_train,
)(*primals_unwrapped)
subclass_meta.fw_metadata = meta_updated
return SubclassTracingInfo(
plain_tensor_trace_fn=fn_to_trace,
plain_tensor_args=args_unwrapped,
maybe_subclass_meta=subclass_meta,
)
def create_functional_call(mod, params_spec, params_len, store_orig_mod=False):
# Redundant with dynamo, but worth having in case this gets invoked elsewhere.
# https://github.com/pytorch/pytorch/issues/103569
def functional_call(*args, **kwargs):
with stateless._reparametrize_module(
mod, pytree.tree_unflatten(args[:params_len], params_spec)
), maybe_disable_thunkify():
if isinstance(mod, torch.fx.GraphModule):
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "Anomaly Detection has been enabled."
)
with torch.autograd.detect_anomaly(check_nan=False):
detect_fake_mode().epoch += 1
out = PropagateUnbackedSymInts(mod).run(
*args[params_len:], **kwargs
)
else:
out = mod(*args[params_len:], **kwargs)
if not isinstance(out, (tuple, list)):
raise RuntimeError(
"Graph output must be a (). This is so that we can avoid "
"pytree processing of the outputs. Please change the module to "
"have tuple outputs or use aot_module instead."
)
return out
# Note [Preserving the nn module stack metadata during export non-strict mode]
# This path is currently only used by the non-strict export flow,
# where we cannot rely on dynamo to preserve nn stack metadata in our captured graph.
# Instead, we stash the original user nn module here, and rely on `make_fx` to grab
# this stashed module and use it to track nn module stack metadata
if store_orig_mod and not hasattr(functional_call, "_orig_mod"):
functional_call._orig_mod = mod # type: ignore[attr-defined]
return functional_call

View File

@ -0,0 +1,446 @@
# mypy: allow-untyped-defs
"""
Contains various utils for AOTAutograd, including those for handling collections.
"""
import dataclasses
import operator
import warnings
from contextlib import nullcontext
from functools import wraps
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
import torch.utils._pytree as pytree
from torch._library.fake_class_registry import FakeScriptObject
from torch._logging import getArtifactLogger
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import py_sym_types
KNOWN_TYPES = [
torch.Tensor,
BackwardState,
int,
str,
float,
bool,
type(None),
*py_sym_types,
FakeScriptObject,
torch.ScriptObject,
]
original_zip = zip
aot_graphs_effects_log = getArtifactLogger(__name__, "aot_graphs_effects")
def strict_zip(*iterables, strict=True, **kwargs):
if not strict:
return original_zip(*iterables, **kwargs)
length = len(iterables[0])
for iterable in iterables[1:]:
if len(iterable) != length:
raise ValueError(
"The iterables have different lengths and strict mode is enabled."
)
return original_zip(*iterables, **kwargs)
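# Usage sketch for strict_zip: same behavior as zip on equal-length inputs,
# but a length mismatch raises instead of silently truncating.
def _demo_strict_zip():
    assert list(strict_zip([1, 2], ["a", "b"])) == [(1, "a"), (2, "b")]
    try:
        strict_zip([1, 2], ["a"])
    except ValueError:
        pass  # mismatched lengths are rejected in strict mode
    else:
        raise AssertionError("expected strict_zip to reject mismatched lengths")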
def _get_symint_hints(exprs):
"""
Get the hints of a list/tuple of int/SymInt.
"""
if isinstance(exprs, (list, tuple)):
return type(exprs)(_get_symint_hints(e) for e in exprs)
elif isinstance(exprs, torch.SymInt):
return exprs.node.shape_env.size_hint(exprs.node.expr)
else:
return exprs
def partial_flatten_asdict(obj: Any) -> Any:
if dataclasses.is_dataclass(obj):
return {
field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)
}
elif isinstance(obj, (list, tuple)):
return obj.__class__([partial_flatten_asdict(item) for item in obj])
elif isinstance(obj, dict):
return {k: partial_flatten_asdict(v) for k, v in obj.items()}
else:
return obj
def normalize_as_list(x):
if isinstance(x, tuple):
return list(x)
elif isinstance(x, list):
return x
return [x]
def _get_autocast_states():
return [
torch.is_autocast_enabled("cuda"),
torch.is_autocast_enabled("cpu"),
torch.get_autocast_dtype("cuda"),
torch.get_autocast_dtype("cpu"),
torch.is_autocast_cache_enabled(),
]
def make_boxed_func(f):
def g(args):
return f(*args)
g._boxed_call = True # type: ignore[attr-defined]
return g
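# Usage sketch for make_boxed_func: the boxed wrapper takes a single list of
# args (the "boxed" calling convention expected of AOTAutograd-compiled callables).
def _demo_make_boxed_func():
    def add(a, b):
        return a + b

    boxed_add = make_boxed_func(add)
    assert boxed_add([1, 2]) == 3
    assert getattr(boxed_add, "_boxed_call", False)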
def make_boxed_compiler(compiler):
@wraps(compiler)
def f(fx_g, inps):
out_f = compiler(fx_g, inps)
fx_g = make_boxed_func(out_f)
return fx_g
return f
def call_func_at_runtime_with_args(
f, args: Union[Tuple[Any], List[Any]], steal_args=False, disable_amp=False
):
if not steal_args:
args = list(args)
assert isinstance(args, list)
context = torch._C._DisableAutocast if disable_amp else nullcontext
with context():
if hasattr(f, "_boxed_call"):
out = normalize_as_list(f(args))
else:
# TODO: Please remove soon
# https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
warnings.warn(
"Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. "
"Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
"See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
)
out = normalize_as_list(f(*args))
return out
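# Usage sketch for call_func_at_runtime_with_args: boxed callees receive the
# arg list directly, unboxed callees are splatted (and trigger the warning
# above), and the result is always normalized to a list.
def _demo_call_func_at_runtime_with_args():
    def mul(a, b):
        return a * b

    boxed_mul = make_boxed_func(mul)
    assert call_func_at_runtime_with_args(boxed_mul, (3, 4)) == [12]
    assert call_func_at_runtime_with_args(mul, (3, 4)) == [12]  # emits the boxed-args warning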
# Inspired by autodidax (thanks!)
class PytreeThunk:
spec: Optional[pytree.TreeSpec] = None
# These are some kinda dumb microoptimizations that save about 3-4 us of overhead.
is_simple: Optional[
bool
] = None # if the output spec is a tuple/list, we won't bother unflattening it.
is_really_simple: Optional[bool] = None # if the output spec is a LeafSpec
def set(self, spec: pytree.TreeSpec) -> None:
assert self.spec is None or self.spec == spec
assert spec is not None
self.spec: pytree.TreeSpec = spec
if self.spec.type in {tuple, list} and all(
child.is_leaf() for child in spec.children_specs
):
self.is_simple = True
if self.spec.is_leaf():
self.is_really_simple = True
def unflatten(self, x: List[Any]) -> Any:
if self.is_really_simple:
return x[0]
if self.is_simple:
return x
assert self.spec is not None
return pytree.tree_unflatten(x, self.spec)
# Creates a function that returns flattened inputs and outputs
# Also returns the output tree spec, which is needed to recover the "unflattened"
# output tree structure later.
def create_tree_flattened_fn(fn, args, kwargs=None) -> Tuple[Callable, PytreeThunk]:
if kwargs is None:
kwargs = {}
# Save the args_spec for flat_tensor_args to unflatten while tracing
_, tensor_args_spec = pytree.tree_flatten((args, kwargs))
out_spec = PytreeThunk()
def flat_fn(*flat_args):
# The inputs are flattened tensor args. Prepare the args in the
# order that the original function expects. Add static args as well.
# They will appear as tensor constants in the traced graph.
nonlocal out_spec
args, kwargs = pytree.tree_unflatten(flat_args, tensor_args_spec)
tree_out = fn(*args, **kwargs)
flat_out, spec = pytree.tree_flatten(tree_out)
for i in flat_out:
is_known_type = False
for j in KNOWN_TYPES:
if isinstance(i, j):
is_known_type = True
break
if not is_known_type:
raise RuntimeError(
f"Found {type(i)} in output, which is not a known type. "
"If this type holds tensors, you need to register a pytree for it. "
"See https://github.com/pytorch/functorch/issues/475 for a brief "
"explanation why. If you don't need to register a pytree, please "
"leave a comment explaining your use case and we'll make this more "
"ergonomic to deal with"
)
out_spec.set(spec)
return flat_out
# Can't use functools.wraps here because the wrapper has different
# calling convention
if hasattr(fn, "_orig_mod"):
flat_fn._orig_mod = fn._orig_mod # type: ignore[attr-defined]
return flat_fn, out_spec
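# Usage sketch for create_tree_flattened_fn: flatten a pytree-taking function
# into a flat calling convention, then recover the output structure via the
# returned PytreeThunk.
def _demo_create_tree_flattened_fn():
    import torch
    import torch.utils._pytree as pytree

    def fn(inputs):
        return {"sum": inputs["a"] + inputs["b"]}

    args = ({"a": torch.ones(2), "b": torch.ones(2)},)
    flat_fn, out_spec = create_tree_flattened_fn(fn, args)
    flat_args, _ = pytree.tree_flatten((args, {}))
    flat_out = flat_fn(*flat_args)
    result = out_spec.unflatten(flat_out)
    assert torch.equal(result["sum"], torch.full((2,), 2.0))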
# This function takes in a tensor t, and returns one of t, t.view(), or t.clone().
# When tracing the joint forward + backward, for any inputs in the graph that are mutated,
# we need to clone them first (and similarly for metadata-only mutations, we need to view them first).
# The idea is that when we trace the backward, we need to pass in the *original* primals
# to autograd.grad(), before they were mutated.
# Note: when we have synthetic base inputs, we need to clone them *before* creating views off of them.
# This means that "idx" here represents the index of the (potentially) synthetic base.
# What we need to do is:
# (1) map the current (post-synthetic-base calling convention) input argument index
# to its index in the pre-synthetic-base calling convention.
# (2) There could be multiple, if this index corresponds to a synthetic base
# that has multiple input aliases.
# (3) If any of those corresponding inputs get metadata mutations, then we clone the base.
def maybe_to_fresh_input(idx, t, meta):
if not isinstance(t, torch.Tensor):
return t
if idx in meta.mutated_inp_runtime_indices:
# We only need to bother cloning mutated inputs that participate in autograd.
mutated_inp_idx = meta.mutated_inp_runtime_indices.index(idx)
if meta.input_info[idx].requires_grad and meta.input_info[idx].mutates_data:
# Make sure the primal we pass to autograd.grad()
# sees the tensor before the mutation
return t.clone()
if meta.input_info[idx] and meta.input_info[idx].mutates_metadata:
# Make sure the primal we pass to autograd.grad()
# sees the tensor before the metadata mutation
return t.view(t.shape)
return t
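# A standalone sketch of the clone-vs-view choice above on plain tensors,
# without the ViewAndMutationMeta bookkeeping: data mutations get a clone so
# autograd.grad() still sees the pre-mutation value, while metadata-only
# mutations only need a fresh view.
def _demo_fresh_input_choice():
    import torch

    t = torch.ones(3, requires_grad=True)
    fresh_for_data_mutation = t.clone()
    fresh_for_metadata_mutation = t.view(t.shape)
    assert fresh_for_data_mutation is not t
    assert fresh_for_metadata_mutation._base is t  # still a view of t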
def is_with_effects(node):
return (
node.op == "call_function"
and node.target == torch.ops.higher_order.with_effects
)
def is_with_effects_op(node, op):
return is_with_effects(node) and node.args[1] == op
def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None):
# Remove the tokens from the inputs/outputs of the graph since inductor does
# not want these extra inputs/outputs, and replace them with
# _make_token() to create a token, and _sink_tokens() to collect the
# tokens. See Note [Side-Effectful Tokens in AOTAutograd]
# Logic:
# 1. Inputs identified as input tokens:
# - If used as a first argument in with_effects
#
# 2. Outputs identified as output tokens:
# - If produced by getitem(with_effects, 0)
#
# 3. Check invariants on the number of input/output tokens:
# forward:
# expected_num_erased_inputs == len(fw_metadata.tokens)
# expected_num_erased_outputs == len(fw_metadata.tokens)
# backward:
# expected_num_erased_inputs == fw_metadata.num_backward_tokens
# expected_num_erased_outputs == fw_metadata.num_backward_tokens
num_forward_tokens = len(fw_metadata.tokens)
num_backward_tokens = fw_metadata.num_backward_tokens
def rewrite_with_effects_input_token(module, node):
with module.graph.inserting_before(node):
new_token_node = module.graph.call_function(
torch.ops.prims._make_token.default, ()
)
new_token_node.meta["val"] = torch.tensor([])
new_token_node.meta["tensor_meta"] = torch.tensor([])
args = list(node.args)
args[0] = new_token_node
node.args = tuple(args)
def rewrite_output(module, node, output_token_nodes, other_output_args):
for output_token_node in output_token_nodes:
assert (
output_token_node.op == "call_function"
and output_token_node.target == operator.getitem
and output_token_node.args[1] == 0
)
with module.graph.inserting_before(node):
module.graph.call_function(
torch.ops.prims._sink_tokens.default,
(output_token_nodes,),
)
node.args = (other_output_args,)
def do(module, subgraph, expected_num_erased):
num_erased_inputs = 0
num_erased_outs = 0
input_nodes = []
input_token_nodes = set()
with_effect_nodes = []
output_token_nodes = []
other_output_nodes = []
for i, node in enumerate(module.graph.nodes):
if node.op == "placeholder":
input_nodes.append(node)
elif is_with_effects(node):
with_effect_nodes.append(node)
if node.args[0] in input_nodes:
input_token_nodes.add(node.args[0])
rewrite_with_effects_input_token(module, node)
elif node.op == "output":
outs = node.args[0]
for out in outs:
if (
isinstance(out, torch.fx.node.Node)
and out.op == "call_function"
and out.target == operator.getitem
and out.args[1] == 0
and out.args[0] in with_effect_nodes
):
output_token_nodes.append(out)
else:
other_output_nodes.append(out)
rewrite_output(module, node, output_token_nodes, other_output_nodes)
num_erased_outs = len(output_token_nodes)
for input_token_node in input_token_nodes:
module.graph.erase_node(input_token_node)
num_erased_inputs = len(input_token_nodes)
assert (
num_erased_inputs == expected_num_erased
), f"{subgraph} num_erased_inputs:{num_erased_inputs} {input_token_nodes}!=expected {expected_num_erased}"
assert (
num_erased_outs == expected_num_erased
), f"{subgraph} num_erased_outs:{num_erased_outs} {output_token_nodes}!=expected {expected_num_erased}"
module.recompile()
if num_forward_tokens > 0:
if aot_config.enable_log:
from torch._dynamo.utils import lazy_format_graph_code
aot_graphs_effects_log.debug(
"%s",
lazy_format_graph_code(
"Forward graph before unlifting tokens",
fw_module,
aot_config.aot_id,
include_stride=True,
include_device=True,
colored=True,
),
)
do(
fw_module,
"forward",
num_forward_tokens,
)
if bw_module is not None and num_backward_tokens > 0:
if aot_config.enable_log:
from torch._dynamo.utils import lazy_format_graph_code
aot_graphs_effects_log.debug(
"%s",
lazy_format_graph_code(
"Backward graph before unlifting tokens",
bw_module,
aot_config.aot_id,
include_stride=True,
include_device=True,
colored=True,
),
)
do(bw_module, "backward", num_backward_tokens)
# This is sad, but we need to update the metadata to get rid of
# the tokens.
fw_metadata.tokens = {}
fw_metadata.num_backward_tokens = 0
def root_module_when_exporting_non_strict(flat_fn):
# When exporting in non-strict mode, we wrap the root module in a specific pattern.
# See `_aot_export_non_strict` in torch.export._trace.py.
# We look for that wrapping pattern here.
if hasattr(flat_fn, "_orig_mod") and hasattr(flat_fn._orig_mod, "_export_root"):
return flat_fn._orig_mod._export_root
else:
return None
def copy_fwd_metadata_to_bw_nodes(fx_g):
"""
Input: `fx_g` which contains the joint fwd+bwd FX graph created by
aot_autograd.
This function walks the graph and copies over metadata from forward nodes
to backward nodes, using the `seq_nr` field as a one-to-many mapping
from forward node to backward node. This metadata is useful for performance
profiling and debugging.
"""
def _is_forward_node_with_seq_nr(node):
# For now, assume that if nn_module_stack_metadata is populated, this
# node is from the forward. Ignore nodes without `seq_nr`.
# TODO(future): there is likely a less brittle way to do this by walking
# the descendants of graph inputs corresponding to fwd inputs, didn't
# seem obvious at first glance on how to partition graph inputs into
# fwd vs bwd without relying on string names.
return "nn_module_stack" in node.meta and "seq_nr" in node.meta
def _is_backward_node_with_seq_nr(node):
# For now, assume that if nn_module_stack_metadata is not populated,
# this node is from the backward. Ignore nodes without `seq_nr`.
# TODO(future): there is likely a less brittle way to do this, same
# as with the forward.
return ("nn_module_stack" not in node.meta) and "seq_nr" in node.meta
fwd_seq_nr_to_node = {}
for node in fx_g.graph.nodes:
if not _is_forward_node_with_seq_nr(node):
continue
seq_nr = node.meta["seq_nr"]
if seq_nr in fwd_seq_nr_to_node:
# If we already saw an op with the current `seq_nr`, that means
# that the current op did not create an autograd node, and there
# is no corresponding backward node, so we skip.
continue
fwd_seq_nr_to_node[node.meta["seq_nr"]] = node
for node in fx_g.graph.nodes:
if not _is_backward_node_with_seq_nr(node):
continue
# fwd_node should always exist, but handle non-existence just in case
fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"])
if fwd_node is not None:
node.meta["fwd_nn_module_stack"] = fwd_node.meta["nn_module_stack"]
node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack")

File diff suppressed because it is too large

View File

@ -0,0 +1,449 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# NOTE: We allow Dynamo to see this file (via torch/_dynamo/trace_rules.py) so that it can
# trace through functorch transforms.
# Currently, we can't allow Dynamo to see `eager_transforms.py`/`vmap.py` as that breaks a lot of things
# and there isn't a mechanism to selectively expose only some functions (eg. grad) from a file
# to Dynamo.
import functools
from torch._functorch.utils import argnums_t, exposed_in
from torch._functorch.vmap import (
_check_out_dims_is_int_or_int_pytree,
_check_randomness_arg,
_chunked_vmap,
_process_batched_inputs,
Callable,
in_dims_t,
out_dims_t,
vmap_impl,
)
# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
# sends those into func, and then unwraps the output BatchedTensors. Operations
# on BatchedTensors perform the batched operations that the user is asking for.
#
# vmap's randomness behavior differs from JAX's, which would require a PRNG key
# to be passed everywhere.
@exposed_in("torch.func")
def vmap(
func: Callable,
in_dims: in_dims_t = 0,
out_dims: out_dims_t = 0,
randomness: str = "error",
*,
chunk_size=None,
) -> Callable:
"""
vmap is the vectorizing map; ``vmap(func)`` returns a new function that
maps ``func`` over some dimension of the inputs. Semantically, vmap
pushes the map into PyTorch operations called by ``func``, effectively
vectorizing those operations.
vmap is useful for handling batch dimensions: one can write a function
``func`` that runs on examples and then lift it to a function that can
take batches of examples with ``vmap(func)``. vmap can also be used to
compute batched gradients when composed with autograd.
.. note::
:func:`torch.vmap` is aliased to :func:`torch.func.vmap` for
convenience. Use whichever one you'd like.
Args:
func (function): A Python function that takes one or more arguments.
Must return one or more Tensors.
in_dims (int or nested structure): Specifies which dimension of the
inputs should be mapped over. ``in_dims`` should have a
structure like the inputs. If the ``in_dim`` for a particular
input is None, then that indicates there is no map dimension.
Default: 0.
out_dims (int or Tuple[int]): Specifies where the mapped dimension
should appear in the outputs. If ``out_dims`` is a Tuple, then
it should have one element per output. Default: 0.
randomness (str): Specifies whether the randomness in this
vmap should be the same or different across batches. If 'different',
the randomness for each batch will be different. If 'same', the
randomness will be the same across batches. If 'error', any calls to
random functions will error. Default: 'error'. WARNING: this flag
only applies to random PyTorch operations and does not apply to
Python's random module or numpy randomness.
chunk_size (None or int): If None (default), apply a single vmap over inputs.
If not None, then compute the vmap :attr:`chunk_size` samples at a time.
Note that :attr:`chunk_size=1` is equivalent to computing the vmap with a for-loop.
If you run into memory issues computing the vmap, please try a non-None chunk_size.
Returns:
Returns a new "batched" function. It takes the same inputs as
``func``, except each input has an extra dimension at the index
specified by ``in_dims``. It returns the same outputs as
``func``, except each output has an extra dimension at the index
specified by ``out_dims``.
.. warning::
:func:`vmap` works best with functional-style code. Please do not
perform any side-effects in ``func``, with the exception of
in-place PyTorch operations. Examples of side-effects include mutating
Python data structures and assigning values to variables not captured
in ``func``.
One example of using :func:`vmap` is to compute batched dot products. PyTorch
doesn't provide a batched ``torch.dot`` API; instead of unsuccessfully
rummaging through docs, use :func:`vmap` to construct a new function.
>>> torch.dot # [D], [D] -> []
>>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)
:func:`vmap` can be helpful in hiding batch dimensions, leading to a simpler
model authoring experience.
>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>> # Very simple linear model with activation
>>> return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = torch.vmap(model)(examples)
:func:`vmap` can also help vectorize computations that were previously difficult
or impossible to batch. One example is higher-order gradient computation.
The PyTorch autograd engine computes vjps (vector-Jacobian products).
Computing a full Jacobian matrix for some function f: R^N -> R^N usually
requires N calls to ``autograd.grad``, one per Jacobian row. Using :func:`vmap`,
we can vectorize the whole computation, computing the Jacobian in a single
call to ``autograd.grad``.
>>> # Setup
>>> N = 5
>>> f = lambda x: x ** 2
>>> x = torch.randn(N, requires_grad=True)
>>> y = f(x)
>>> I_N = torch.eye(N)
>>>
>>> # Sequential approach
>>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
>>> for v in I_N.unbind()]
>>> jacobian = torch.stack(jacobian_rows)
>>>
>>> # vectorized gradient computation
>>> def get_vjp(v):
>>> return torch.autograd.grad(y, x, v)
>>> jacobian = torch.vmap(get_vjp)(I_N)
:func:`vmap` can also be nested, producing an output with multiple batched dimensions
>>> torch.dot # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0]
>>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
>>> batched_dot(x, y) # tensor of size [2, 3]
If the inputs are not batched along the first dimension, ``in_dims`` specifies
the dimension that each input is batched along as
>>> torch.dot # [N], [N] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y) # output is [5] instead of [2] if batched along the 0th dimension
If there are multiple inputs each of which is batched along different dimensions,
``in_dims`` must be a tuple with the batch dimension for each input as
>>> torch.dot # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None
If the input is a Python struct, ``in_dims`` must be a tuple containing a struct
matching the shape of the input:
>>> f = lambda dict: torch.dot(dict['x'], dict['y'])
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> input = {'x': x, 'y': y}
>>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},))
>>> batched_dot(input)
By default, the output is batched along the first dimension. However, it can be batched
along any dimension by using ``out_dims``
>>> f = lambda x: x ** 2
>>> x = torch.randn(2, 5)
>>> batched_pow = torch.vmap(f, out_dims=1)
>>> batched_pow(x) # [5, 2]
For any function that uses kwargs, the returned function will not batch the kwargs but will
accept kwargs
>>> x = torch.randn([2, 5])
>>> def fn(x, scale=4.):
>>> return x * scale
>>>
>>> batched_pow = torch.vmap(fn)
>>> assert torch.allclose(batched_pow(x), x * 4)
>>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]
.. note::
vmap does not provide general autobatching or handle variable-length
sequences out of the box.
"""
from torch._dynamo import is_compiling
_check_randomness_arg(randomness)
if not (chunk_size is None or chunk_size > 0):
raise ValueError(
f"vmap: chunk_size should be None or greater than 0. (got {chunk_size})"
)
def wrapped(*args, **kwargs):
return vmap_impl(
func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
)
if not is_compiling():
wrapped = functools.wraps(func)(wrapped)
return wrapped
def chunk_vmap(
func: Callable,
in_dims: in_dims_t = 0,
out_dims: out_dims_t = 0,
randomness: str = "error",
chunks=2,
) -> Callable:
"""
chunk_vmap is the vectorizing map (vmap) using chunks of input data. It is a mix of vmap (which vectorizes
everything) and map (which executes things sequentially). ``chunk_vmap`` splits the input into the given number of
chunks and vectorizes one chunk at a time. For more details about the vectorizing map, see :func:`vmap`.
.. note::
Please use :func:`vmap` with ``chunk_size`` argument instead of this API.
Args:
func (function): A Python function that takes one or more arguments.
Must return one or more Tensors.
in_dims (int or nested structure): Specifies which dimension of the
inputs should be mapped over. ``in_dims`` should have a
structure like the inputs. If the ``in_dim`` for a particular
input is None, then that indicates there is no map dimension.
Default: 0.
out_dims (int or Tuple[int]): Specifies where the mapped dimension
should appear in the outputs. If ``out_dims`` is a Tuple, then
it should have one element per output. Default: 0.
randomness (str): Specifies whether the randomness in this
vmap should be the same or different across batches. If 'different',
the randomness for each batch will be different. If 'same', the
randomness will be the same across batches. If 'error', any calls to
random functions will error. Default: 'error'. WARNING: this flag
only applies to random PyTorch operations and does not apply to
Python's random module or numpy randomness.
chunks (int): Number of chunks to use to split the input data. Default is 2.
If equals to 1 then :func:`vmap` is called.
Returns:
Returns a new "batched" function. It takes the same inputs as
``func``, except each input has an extra dimension at the index
specified by ``in_dims``. It returns the same outputs as
``func``, except each output has an extra dimension at the index
specified by ``out_dims``.
"""
_check_randomness_arg(randomness)
if chunks == 1:
return vmap(func, in_dims=in_dims, out_dims=out_dims, randomness=randomness)
def _get_chunk_flat_args(flat_args_, flat_in_dims_, chunks_):
flat_args_chunks = tuple(
t.chunk(chunks_, dim=in_dim)
if in_dim is not None
else [
t,
]
* chunks_
for t, in_dim in zip(flat_args_, flat_in_dims_)
)
# transpose chunk dim and flatten structure
# chunks_flat_args is a list of flatten args
chunks_flat_args = zip(*flat_args_chunks)
return chunks_flat_args
@functools.wraps(func)
def wrapped_with_chunks(*args, **kwargs):
_check_out_dims_is_int_or_int_pytree(out_dims, func)
_, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
in_dims, args, func
)
# Chunk flat arguments
chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims, chunks)
# Apply vmap on chunks
return _chunked_vmap(
func,
flat_in_dims,
chunks_flat_args,
args_spec,
out_dims,
randomness,
**kwargs,
)
return wrapped_with_chunks
@exposed_in("torch.func")
def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
"""``grad`` operator helps computing gradients of ``func`` with respect to the
input(s) specified by ``argnums``. This operator can be nested to
compute higher-order gradients.
Args:
func (Callable): A Python function that takes one or more arguments.
Must return a single-element Tensor. If specified ``has_aux`` equals ``True``,
function can return a tuple of single-element Tensor and other auxiliary objects:
``(output, aux)``.
argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to.
``argnums`` can be single integer or tuple of integers. Default: 0.
has_aux (bool): Flag indicating that ``func`` returns a tensor and other
auxiliary objects: ``(output, aux)``. Default: False.
Returns:
Function to compute gradients with respect to its inputs. By default, the output of
the function is the gradient tensor(s) with respect to the first argument.
If specified ``has_aux`` equals ``True``, tuple of gradients and output auxiliary objects
is returned. If ``argnums`` is a tuple of integers, a tuple of output gradients with
respect to each ``argnums`` value is returned.
Example of using ``grad``:
>>> # xdoctest: +SKIP
>>> from torch.func import grad
>>> x = torch.randn([])
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())
When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:
>>> # xdoctest: +SKIP
>>> from torch.func import grad, vmap
>>> batch_size, feature_size = 3, 5
>>>
>>> def model(weights, feature_vec):
>>> # Very simple linear model with activation
>>> assert feature_vec.dim() == 1
>>> return feature_vec.dot(weights).relu()
>>>
>>> def compute_loss(weights, example, target):
>>> y = model(weights, example)
>>> return ((y - target) ** 2).mean() # MSELoss
>>>
>>> weights = torch.randn(feature_size, requires_grad=True)
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> inputs = (weights, examples, targets)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
Example of using ``grad`` with ``has_aux`` and ``argnums``:
>>> # xdoctest: +SKIP
>>> from torch.func import grad
>>> def my_loss_func(y, y_pred):
>>> loss_per_sample = (0.5 * y_pred - y) ** 2
>>> loss = loss_per_sample.mean()
>>> return loss, (y_pred, loss_per_sample)
>>>
>>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
>>> y_true = torch.rand(4)
>>> y_preds = torch.rand(4, requires_grad=True)
>>> out = fn(y_true, y_preds)
>>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
.. note::
Using PyTorch ``torch.no_grad`` together with ``grad``.
Case 1: Using ``torch.no_grad`` inside a function:
>>> # xdoctest: +SKIP
>>> def f(x):
>>> with torch.no_grad():
>>> c = x ** 2
>>> return x - c
In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``.
Case 2: Using ``grad`` inside ``torch.no_grad`` context manager:
>>> # xdoctest: +SKIP
>>> with torch.no_grad():
>>> grad(f)(x)
In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the
outer one. This is because ``grad`` is a "function transform": its result
should not depend on the result of a context manager outside of ``f``.
"""
# To avoid cyclical dependency.
import torch._functorch.eager_transforms as eager_transforms
from torch._dynamo import is_compiling
def wrapper(*args, **kwargs):
return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
if not is_compiling():
wrapper = functools.wraps(func)(wrapper)
return wrapper
@exposed_in("torch.func")
def grad_and_value(
func: Callable, argnums: argnums_t = 0, has_aux: bool = False
) -> Callable:
"""
Returns a function to compute a tuple of the gradient and primal, or
forward, computation.
Args:
func (Callable): A Python function that takes one or more arguments.
Must return a single-element Tensor. If specified ``has_aux``
equals ``True``, function can return a tuple of single-element
Tensor and other auxiliary objects: ``(output, aux)``.
argnums (int or Tuple[int]): Specifies arguments to compute gradients
with respect to. ``argnums`` can be single integer or tuple of
integers. Default: 0.
has_aux (bool): Flag indicating that ``func`` returns a tensor and
other auxiliary objects: ``(output, aux)``. Default: False.
Returns:
Function to compute a tuple of gradients with respect to its inputs
and the forward computation. By default, the output of the function is
a tuple of the gradient tensor(s) with respect to the first argument
and the primal computation. If specified ``has_aux`` equals
``True``, tuple of gradients and tuple of the forward computation with
output auxiliary objects is returned. If ``argnums`` is a tuple of
integers, a tuple of a tuple of the output gradients with respect to
each ``argnums`` value and the forward computation is returned.
See :func:`grad` for examples
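A minimal sketch (illustrative; ``x`` is a hypothetical input):
>>> # xdoctest: +SKIP
>>> from torch.func import grad_and_value
>>> x = torch.randn([])
>>> g, value = grad_and_value(torch.sin)(x)
>>> assert torch.allclose(g, x.cos()) and torch.allclose(value, x.sin())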
"""
from torch._dynamo import is_compiling
from torch._functorch import eager_transforms
def wrapper(*args, **kwargs):
return eager_transforms.grad_and_value_impl(
func, argnums, has_aux, args, kwargs
)
if not is_compiling():
wrapper = functools.wraps(func)(wrapper)
return wrapper

View File

@ -0,0 +1,752 @@
# mypy: allow-untyped-defs
from typing import Any, NamedTuple, Tuple
import torch
import torch.utils._pytree as pytree
from torch._C._functorch import (
_unwrap_for_grad,
_wrap_for_grad,
current_level,
TransformType,
)
from torch._functorch.apis import vmap
from torch._functorch.utils import enable_single_level_autograd_function
from torch._functorch.vmap import (
_add_batch_dim,
_broadcast_to_and_flatten,
restore_vmap,
unwrap_batched,
wrap_batched,
)
from torch._ops import HigherOrderOperator
from torch.autograd.forward_ad import _set_fwd_grad_enabled
# autograd.Function technically runs before the regular PyTorch dispatcher.
# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot)
# work with it. One day we might decide to change this, but until then,
# we need to give the illusion that autograd.Function runs before those things.
#
# We do this by creating a custom HigherOrderOperator that only functorch
# dispatches specially.
class CustomFunctionHigherOrderOperator(HigherOrderOperator):
def __init__(self) -> None:
super().__init__("custom_function_call")
def __call__(self, autograd_function, *args, **kwargs):
# When custom_function_call is done dispatching through functorch,
# it should just invoke the autograd.Function. This is consistent
# with the autograd.Function behavior of being invoked before the
# PyTorch dispatcher.
#
# This will lead us into trouble later down the line, but this is
# pre-existing. There is an invariant that a function traced by
# make_fx should have the same behavior when provided the same
# Tensor. However, make_fx sees autograd.Function as a composite
# (because autograd.Function happens before the Python dispatch key)
# and only traces the forward pass.
if torch._C._are_functorch_transforms_active():
return super().__call__(autograd_function, *args, **kwargs)
return autograd_function.apply(*args, **kwargs)
# "custom_function_call"
# This is the mechanism for an autograd.Function that works with functorch transforms.
# It wraps an autograd.Function; interactions with functorch transforms are defined
# via PyDispatcher and HigherOrderOperator rather than through the traditional PyTorch
# dispatcher.
custom_function_call = CustomFunctionHigherOrderOperator()
# The grad rule for custom_function_call is to construct a new _SingleLevelFunction
# (autograd.Function that only works with a single layer (level) of functorch) that:
# - unwraps the inputs
# - redispatches to custom_function_call
# - wraps the outputs
# and whose backward pass calls the original autograd.Function's backward.
#
# Why do we need to redispatch to custom_function_call?
# -----------------------------------------------------
# This is consistent with how ATen operators work with functorch's grad transform:
# they always redispatch to the original operator.
# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x)
#
# grad1 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin (*)
# - rewrap the outputs on the return
#
# On the redispatch in (*), grad0 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin
# - rewrap the outputs on the return
#
# To "set up the autograd graph", we generate a _SingleLevelFunction
# and apply it.
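# Illustrative sketch (not part of the original comment): from the user's point of
# view, nesting grad over an autograd.Function exercises exactly this redispatch
# path. `MySin` below is a hypothetical autograd.Function whose forward computes sin:
#
#   x = torch.randn([])
#   second_deriv = torch.func.grad(torch.func.grad(MySin.apply))(x)
#
# Each grad level builds its own Generated _SingleLevelFunction, unwraps the input
# at that level, redispatches to custom_function_call, and rewraps the output.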
@custom_function_call.py_impl(TransformType.Grad)
@custom_function_call.py_impl(TransformType.Jvp)
def custom_function_call_grad(interpreter, autograd_function, *operands):
Generated = generate_single_level_function(interpreter, autograd_function)
with enable_single_level_autograd_function():
flat_out = Generated.apply(*operands)
return flat_out
def generate_single_level_function(interpreter, autograd_function):
level = interpreter.level()
def forward(*operands):
unwrapped_operands = pytree.tree_map_only(
torch.Tensor, lambda x: _unwrap_for_grad(x, level), operands
)
# Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
# the transform. _SingleLevelFunction will turn off both fwd and bwd
# gradient computation and we need to turn it back on here.
with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
unwrapped_output = custom_function_call(
autograd_function, *unwrapped_operands
)
# See NOTE [mark_dirty object identity check]
def wrap_fn(output):
return _wrap_for_grad(output, level)
return wrap_outputs_maintaining_identity(
unwrapped_output, unwrapped_operands, operands, wrap_fn
)
def setup_context(ctx, inputs, output):
return autograd_function.setup_context(ctx, inputs, output)
# backward is only used if the transform is TransformType.Grad
def backward(ctx, *grads):
result = autograd_function.backward(ctx, *grads)
return result
# jvp is only used if the transform is TransformType.Jvp
def jvp(ctx, *tangents):
result = autograd_function.jvp(ctx, *tangents)
return result
# This is the sequence of magic words to dynamically generate a Subclass with
# a given name. A Tensor's .grad_fn field has a class name that is the original
# autograd.Function's name + Backward, so we do this to generate some
# meaningful name.
name = f"{autograd_function.__name__}Generated"
Generated = type(
name,
(torch.autograd.function._SingleLevelFunction,),
{
"forward": staticmethod(forward),
"backward": staticmethod(backward),
"jvp": staticmethod(jvp),
"setup_context": staticmethod(setup_context),
},
)
return Generated
# wrap_outputs_maintaining_identity handles outputs from the vmap,
# backward (vjp), and jvp staticmethod. The way it distinguishes
# between the vmap case and the {backward, jvp} case is whether the out_dims
# are specified or not.
#
# NB: we cannot use out_dims=None as the deciding factor. This is because
# out_dims=None can still happen in the vmap staticmethod! What the
# user is saying in that case is that their output does not have a
# dimension that is being vmapped over, which is valid.
NO_OUT_DIMS = "not specified"
# NOTE [mark_dirty object identity check]
# autograd.Function's ctx.mark_dirty expects a returned input
# to have the same object identity as the input.
# Mode-only functorch will greatly simplify this logic.
def wrap_outputs_maintaining_identity(
outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS
):
flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs)
flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs)
unwrapped_input_to_orig_input = {
id(unwrapped): orig
for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs)
}
flat_outputs, spec = pytree.tree_flatten(outputs)
result = []
out_dims_specified = out_dims != NO_OUT_DIMS
if out_dims_specified:
flat_out_dims = _broadcast_to_and_flatten(out_dims, spec)
# _broadcast_to_and_flatten returns None if it is unable to broadcast.
# TODO: update following link from master to stable once that's out
if flat_out_dims is None:
raise RuntimeError(
f"The autograd.Function's vmap staticmethod returned an "
f"incompatible (output, out_dims) tuple. "
f"Expected out_dims={out_dims} "
f"to be compatible with the structure of `output`. "
f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} "
f"but output has structure {spec}. "
f"For more details, please see "
f"https://pytorch.org/docs/main/notes/extending.func.html"
)
for i, output in enumerate(flat_outputs):
if not isinstance(output, torch.Tensor):
result.append(output)
continue
if id(output) in unwrapped_input_to_orig_input:
result.append(unwrapped_input_to_orig_input[id(output)])
continue
if out_dims_specified:
result.append(wrap_fn(output, flat_out_dims[i])) # type: ignore[possibly-undefined, index]
else:
result.append(wrap_fn(output))
return pytree.tree_unflatten(result, spec)
# NOTE: [functorch vjp and autograd interaction]
# There's an edge case with the functorch vjp and autograd interaction
# that will eventually be fixed by mode-only functorch.
# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper,
# so we (the framework) need to do it manually. Regular PyTorch operators
# automatically do so, so this is consistent.
#
# class MyExp(torch.autograd.Function):
# @staticmethod
# def forward(x):
# return x.exp()
#
# @staticmethod
# def setup_context(ctx, inputs, output):
# y = output
# ctx.save_for_backward(y)
#
# @staticmethod
# def backward(ctx, gy):
# y, = ctx.saved_tensors
# return MyMul.apply(gy, y)
#
# x = torch.randn([], requires_grad=True)
# gy = torch.randn([], requires_grad=True)
# _, vjp_fn = vjp(MyExp.apply, x)
# result = vjp_fn(gy)
#
# MyMul is an autograd.Function that is not shown here.
# It saves a `y` for backward (since gy requires grad).
#
# in vjp_fn(gy), we get:
# > MyMul.apply(gy, GradTensorWrapper(y, level=dead))
# Because the y that is saved for backward by MyExp is a GradTensorWrapper
# but is now dead since we are outside the vjp context.
#
# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper,
# will automatically unwrap the GradTensorWrapper when applied.
# But since autograd.Function technically sits above the regular PyTorch
# dispatcher, it doesn't get this treatment. So we manually do
# the unwrapping to be consistent with regular PyTorch dispatcher operations.
class VmapInfo(NamedTuple):
batch_size: int
randomness: str
def has_overriden_vmap_rule(autograd_function):
return autograd_function.vmap is not torch.autograd.Function.vmap
def validate_vmap_returns_tuple_of_two_elements(result):
base_error_msg = (
"Expected the vmap staticmethod to have two returns, an output "
"and out_dims with pytree structure compatible with the output. "
)
if not isinstance(result, tuple):
raise RuntimeError(base_error_msg + f"Got a {type(result)} instead")
if not len(result) == 2:
raise RuntimeError(base_error_msg + f"Got {len(result)} returns instead")
@custom_function_call.py_impl(TransformType.Vmap)
def custom_function_call_vmap(interpreter, autograd_function, *operands, **kwargs):
if any(
isinstance(val, torch.Tensor)
for val in torch.utils._pytree.tree_flatten(kwargs)[0]
):
raise NotImplementedError(
f"Run vmap on autograd.Function with kwarg-only Tensor args. "
f"Please do not pass kwarg-only Tensors to autograd.Function. "
f"Got: {kwargs}"
)
if autograd_function.generate_vmap_rule:
if has_overriden_vmap_rule(autograd_function):
# TODO: Update link to stable once that's out
# https://github.com/pytorch/pytorch/issues/92029
raise RuntimeError(
f"You tried to vmap over {autograd_function.__name__}, but "
f"it has both generate_vmap_rule=True and an overriden vmap "
f"staticmethod. Please set generate_vmap_rule=False or delete "
f"the overriden vmap staticmethod to avoid ambiguity. "
f"For more details, please see "
f"https://pytorch.org/docs/main/notes/extending.func.html"
)
return custom_function_call_vmap_generate_rule(
interpreter, autograd_function, *operands
)
if not has_overriden_vmap_rule(autograd_function):
# TODO: Update link to stable once that's out
# https://github.com/pytorch/pytorch/issues/92029
raise RuntimeError(
f"You tried to vmap over {autograd_function.__name__}, but "
f"it does not have vmap support. Please override and implement the "
f"vmap staticmethod or set generate_vmap_rule=True. "
f"For more details, please see "
f"https://pytorch.org/docs/main/notes/extending.func.html"
)
return custom_function_call_vmap_helper(
interpreter, autograd_function.vmap, autograd_function, *operands, **kwargs
)
def custom_function_call_vmap_helper(
interpreter, vmap_function, op, *operands, **kwargs
):
current_level = interpreter.level()
info = VmapInfo(
batch_size=interpreter.batch_size(),
randomness=interpreter.randomness(),
)
unwrapped_operands, in_dims = unwrap_batched(operands, current_level)
# If none of the tensors are batched at the current level, then we skip the
# current level. This saves the user from needing to handle this case in
# their vmap staticmethod (and is consistent with our C++ batching rule API)
if pytree.tree_all(lambda dim: dim is None, in_dims):
with interpreter.lower():
if isinstance(op, torch.autograd.function.FunctionMeta):
return custom_function_call(op, *operands)
else:
return op(*operands, **kwargs)
with interpreter.lower():
result = vmap_function(info, in_dims, *unwrapped_operands, **kwargs)
validate_vmap_returns_tuple_of_two_elements(result)
unwrapped_output, out_dims = result
# See NOTE [mark_dirty object identity check]
def wrap_fn(output, out_dim):
return (
output
if out_dim is None
else _add_batch_dim(output, out_dim, current_level)
)
return wrap_outputs_maintaining_identity(
unwrapped_output, unwrapped_operands, operands, wrap_fn, out_dims=out_dims
)
def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands):
unwrapped_operands, in_dims = unwrap_batched(operands, interpreter.level())
vmapped_function, get_out_dims = vmapify_autograd_function(
autograd_function, in_dims, interpreter.batch_size(), interpreter.randomness()
)
with interpreter.lower():
output = custom_function_call(vmapped_function, *unwrapped_operands)
out_dims = get_out_dims()
return wrap_batched(output, out_dims, interpreter.level())
@custom_function_call.py_impl(TransformType.Functionalize)
def custom_function_call_functionalize(
interpreter, autograd_function, generate_vmap_rule, *operands
):
raise RuntimeError("NYI: Functionalize rule for custom_function_call")
def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness):
# The following values are saved from the forward() and setup_context()
# and used in backward().
# Why do we save the values out here instead of on the ctx object?
# - out_dims: There's no way to retrieve this from forward()
# - input_shapes, saved_tensors_bdims: I'm a bit scared of nesting
# vmap(vmap( but not completely sure if it is a problem. If we
# assigned those fields to the ctx object, the worry is that they
# get overwritten.
init_val = "not populated"
out_dims = init_val
input_shapes: Any = init_val
saved_tensors_bdims: Any = init_val
def forward(*operands):
nonlocal out_dims
outputs, out_dims = restore_vmap(
autograd_function.forward, in_dims, batch_size, randomness
)(*operands)
return outputs
def setup_context(ctx, inputs, outputs):
input_shapes_ = None
saved_tensors_bdims_ = None
def inner(inputs, outputs):
# wrapped_ctx.save_for_backward will:
# - unwrap batchedtensors into (tensor, bdim)
# - save_for_backward(*unwrapped_tensors)
# - assign the bdims to wrapped_ctx._pt_saved_tensors_bdims
wrapped_ctx = CtxCustomSave(ctx, current_level())
autograd_function.setup_context(wrapped_ctx, inputs, outputs)
# input_shapes are used for reductify later to reduce expanded gradients
# to the correct shape.
# See NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
# for more details
nonlocal input_shapes_
input_shapes_ = tuple(
inp.shape if isinstance(inp, torch.Tensor) else None for inp in inputs
)
nonlocal saved_tensors_bdims_
saved_tensors_bdims_ = wrapped_ctx._pt_saved_tensors_bdims
# See NOTE: [Why do we need to run setup_context under a vmap?]
restore_vmap(
inner,
(in_dims, out_dims),
batch_size,
randomness,
)(inputs, outputs)
nonlocal input_shapes
input_shapes = input_shapes_
nonlocal saved_tensors_bdims
saved_tensors_bdims = saved_tensors_bdims_
def jvp(ctx, *tangents):
assert out_dims != init_val
assert saved_tensors_bdims != init_val
def jvp_no_context(saved_tensors, tangents):
wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
return autograd_function.jvp(wrapped_ctx, *tangents)
tangent_in_dims = get_tangents_in_dims(in_dims, tangents)
out_tangents, out_tangents_dims = restore_vmap(
jvp_no_context,
(saved_tensors_bdims, tangent_in_dims),
batch_size,
randomness,
)(ctx.saved_tensors, tangents)
result = reductify(out_tangents, out_tangents_dims, out_dims, batch_size)
return result
def backward(ctx, *grad_outputs):
assert out_dims != init_val
assert input_shapes != init_val
assert saved_tensors_bdims != init_val
def backward_no_context(inputs):
saved_tensors, grad_outputs = inputs
wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
return autograd_function.backward(wrapped_ctx, *grad_outputs)
grad_ins, grad_ins_dims = restore_vmap(
backward_no_context,
((saved_tensors_bdims, out_dims),),
batch_size,
randomness,
)((ctx.saved_tensors, grad_outputs))
result = reductify(grad_ins, grad_ins_dims, in_dims, batch_size, input_shapes)
return result
name = f"Vmapped{autograd_function.__name__}"
Generated = type(
name,
(torch.autograd.Function,),
{
"forward": staticmethod(forward),
"backward": staticmethod(backward),
"jvp": staticmethod(jvp),
"setup_context": staticmethod(setup_context),
"generate_vmap_rule": True,
},
)
def get_out_dims():
assert out_dims != init_val
return out_dims
return Generated, get_out_dims
# tangents might be None, so we need to replace
# the corresponding in_dims with None.
def get_tangents_in_dims(input_dims, tangents):
flat_in_dims, spec = pytree.tree_flatten(input_dims)
flat_tangents = pytree.arg_tree_leaves(*tangents)
result = [
None if tangent is None else in_dim
for in_dim, tangent in zip(flat_in_dims, flat_tangents)
]
return pytree.tree_unflatten(result, spec)
# NOTE: [Why do we need to run setup_context under a vmap?]
# Consider the following autograd.Function
#
# class Sum(torch.autograd.Function):
# @staticmethod
# def forward(x):
# return x.sum()
# @staticmethod
# def setup_context(ctx, inputs, outputs):
# ctx.x_shape = inputs[0]
# @staticmethod
# def backward(ctx, gy):
# return gy.expand(ctx.x_shape)
#
# x = torch.randn(B, 4)
# in_dims = 0
# vmap(Sum.apply, in_dims)(x)
#
# Let's assume for a moment that we didn't vmap setup_context in VmappedSum:
#
# class VmappedSum(torch.autograd.Function):
# @staticmethod
# def forward(x):
# return vmap(Sum.forward, in_dims)(x)
#
# @staticmethod
# def setup_context(ctx, inputs, outputs):
# Sum.setup_context(ctx, inputs, outputs)
#
# @staticmethod
# def backward(ctx, gy):
# def backward_no_context(gy):
# return gy.expand(ctx.x_shape)
#
# dims = (0,)
# gx = vmap(backward_no_context, dims)(gy)
# return gx
#
# We end up saving [B, 4] as x_shape. In the backward, gy has shape [B],
# and we're doing:
#
# def backward_no_context(gy):
# return gy.expand([B, 4])
#
# gx = vmap(backward_no_context, dims)(gy: "Tensor[B]")
#
# This gives us the wrong result (gx has shape [B, B, 4], but it should
# have shape [4]). Performing vmap over setup_context means the shape
# saved has shape [4] and leads to a correct result shape for gx.
# Wraps a ctx object. Forwards all attr accesses to the underlying object
# except for the attrs in _pt_reserved_attrs
class WrappedCtx:
_pt_reserved_attrs: Tuple[str, ...] = ("_pt_reserved_attrs", "_pt_inner_ctx")
def __init__(self, ctx):
if not isinstance(ctx, WrappedCtx):
reserved_attrs = type(self)._pt_reserved_attrs
for name in reserved_attrs:
if not hasattr(ctx, name):
continue
raise RuntimeError(
f"PyTorch reserves the {reserved_attrs} field on ctx. "
"Please name your fields on ctx something else to avoid name "
"collision."
)
self._pt_inner_ctx = ctx
def __getattr__(self, name):
return getattr(self._pt_inner_ctx, name)
def __setattr__(self, name, value):
if name in type(self)._pt_reserved_attrs:
self.__dict__[name] = value
return
return setattr(self._pt_inner_ctx, name, value)
# Wraps ctx to create a new ctx object that overrides saved_tensors.
class CtxWithSavedTensors(WrappedCtx):
_pt_reserved_attrs = ("_pt_new_saved_tensors", *WrappedCtx._pt_reserved_attrs)
def __init__(self, ctx, new_saved_tensors):
super().__init__(ctx)
self._pt_new_saved_tensors = new_saved_tensors
@property
def saved_tensors(self):
return self._pt_new_saved_tensors
class CtxCustomSave(WrappedCtx):
_pt_reserved_attrs = (
"_pt_saved_tensors_bdims",
"_pt_current_level",
*WrappedCtx._pt_reserved_attrs,
)
def __init__(self, ctx, current_level):
super().__init__(ctx)
self._pt_saved_tensors_bdims = ()
self._pt_current_level = current_level
def save_for_backward(self, *tensors):
unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
self._pt_inner_ctx.save_for_backward(*unwrapped_tensors)
self._pt_saved_tensors_bdims = bdims
def save_for_forward(self, *tensors):
unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
self._pt_inner_ctx.save_for_forward(*unwrapped_tensors)
self._pt_saved_tensors_bdims = bdims
def reductify(
grad_input,
grad_input_bdim,
input_bdim,
batch_size,
target_shape_without_bdim_to_reduce_to=None,
):
if not isinstance(grad_input, tuple):
grad_input = (grad_input,)
if not isinstance(grad_input_bdim, tuple):
grad_input_bdim = (grad_input_bdim,)
if not isinstance(input_bdim, tuple):
input_bdim = (input_bdim,)
if target_shape_without_bdim_to_reduce_to is None:
target_shape_without_bdim_to_reduce_to = len(grad_input) * (None,)
result = tuple(
reductify_leaf(gi, gi_bdim, i_bdim, batch_size, maybe_ishape)
for gi, gi_bdim, i_bdim, maybe_ishape in zip(
grad_input,
grad_input_bdim,
input_bdim,
target_shape_without_bdim_to_reduce_to,
)
)
return result
def reductify_leaf(
grad_input,
grad_input_bdim,
input_bdim,
batch_size,
target_shape_without_bdim_to_reduce_to=None,
):
if grad_input is None:
return None
if grad_input_bdim is None and input_bdim is None:
return grad_input
if grad_input_bdim is not None and input_bdim is None:
return grad_input.sum(grad_input_bdim)
# NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
# For reverse-mode AD,
# given a grad_input and input, it is valid for the user to return a
# grad_input that has a broadcasted shape when compared to the input.
# In this situation, autograd automatically reduces the grad_input to
# the shape of the input.
#
# However, when input_bdim is not None, we have problems.
#
# [example 1]
# grad_input: Tensor[3, 4], input: Tensor[B, 4]
# We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable
# from [B, 4].
#
# [example 2]
# grad_input: Tensor[3, B, 4], input: Tensor[B, 4]
# We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable
# from [B, 4].
#
# This means that we need to also reduce the grad_input to the shape of the
# input. This behavior is controlled by the `target_shape_without_bdim_to_reduce_to` flag;
# if it is not None, then we do the reduction manually; otherwise, we do not reduce.
assert input_bdim is not None
if grad_input_bdim is None:
grad_input = grad_input.unsqueeze(input_bdim)
new_shape = list(grad_input.shape)
new_shape[input_bdim] = batch_size
grad_input = grad_input.expand(new_shape)
grad_input_bdim = input_bdim
if target_shape_without_bdim_to_reduce_to is not None:
return vmap(
torch.Tensor.sum_to_size,
in_dims=(grad_input_bdim, None),
out_dims=input_bdim,
)(grad_input, target_shape_without_bdim_to_reduce_to)
if input_bdim != grad_input_bdim:
grad_input = grad_input.movedim(grad_input_bdim, input_bdim)
return grad_input
def autograd_function_forward_rewritten(original_forward, original_setup_context):
def new_forward(ctx, *args, **kwargs):
output = original_forward(*args, **kwargs)
original_setup_context(ctx, args, output)
return output
return new_forward
class AutogradFunctionApply(HigherOrderOperator):
def __init__(self) -> None:
super().__init__("autograd_function_apply")
def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs):
saved_values = None
args_tensor_mask = fwd_kwargs["args_tensor_mask"]
non_differentiable_idx = fwd_kwargs["non_differentiable_idx"]
length_of_tensor_args = sum(args_tensor_mask)
# Filter out the original tensor args from fwd_args,
# lifted freevars should not be args of ApplyTemplate.apply
# since we don't need to calculate their gradients.
new_fwd_args = fwd_args[:length_of_tensor_args]
class ApplyTemplate(torch.autograd.Function):
@staticmethod
def forward(ctx, *args):
nonlocal saved_values
output, saved_values = fwd(None, *fwd_args)
# Handle the case where users call ctx.mark_non_differentiable() in the original fwd function.
if len(non_differentiable_idx) > 0:
non_differentiable_output = []
for i, x in enumerate(output):
if i in non_differentiable_idx:
non_differentiable_output.append(x)
ctx.mark_non_differentiable(*non_differentiable_output)
return output
@staticmethod
def backward(ctx, *grad):
return bwd(None, *grad, *saved_values)
return ApplyTemplate.apply(*new_fwd_args)
autograd_function_apply = AutogradFunctionApply()

View File

@ -0,0 +1,29 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import torch.nn as nn
from torch._functorch.utils import exposed_in
def batch_norm_without_running_stats(module: nn.Module):
if (
isinstance(module, nn.modules.batchnorm._BatchNorm)
and module.track_running_stats
):
module.running_mean = None
module.running_var = None
module.num_batches_tracked = None
module.track_running_stats = False
@exposed_in("torch.func")
def replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module:
"""
In place updates :attr:`root` by setting the ``running_mean`` and ``running_var`` to be None and
setting track_running_stats to be False for any nn.BatchNorm module in :attr:`root`
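Example (an illustrative sketch; ``net`` is a hypothetical module):
>>> # xdoctest: +SKIP
>>> net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
>>> replace_all_batch_norm_modules_(net)
>>> assert net[1].track_running_stats is False
>>> assert net[1].running_mean is None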
"""
# base case
batch_norm_without_running_stats(root)
for obj in root.modules():
batch_norm_without_running_stats(obj)
return root

View File

@ -0,0 +1,231 @@
# mypy: ignore-errors
import contextlib
import json
import operator
import os
import time
import torch
from torch.profiler import profile, ProfilerActivity
def synchronize():
pass
def dump_chrome_trace(
f,
input,
trace_filename,
optimize_ctx,
activities,
num_runs=1,
devices=None,
kwargs_for_f=None,
kwargs_for_profiler=None,
):
"""
Output the chrome trace of running f(input, **kwargs_for_f) with [optimize_ctx]
[num_runs] times to [trace_filename].
[activities] are the activities that the profiler will record, e.g. ProfilerActivity.CUDA.
Returns the total runtime (in seconds) measured without the profiler.
The trace itself is written to [trace_filename].
"""
if devices is None:
devices = ["cuda"]
global synchronize
if devices != ["cpu"] and torch.cuda.is_available():
synchronize = torch.cuda.synchronize
if kwargs_for_f is None:
kwargs_for_f = {}
if kwargs_for_profiler is None:
kwargs_for_profiler = {}
with optimize_ctx:
torch.manual_seed(1337)
for _ in range(5): # warmup runs
f(input, **kwargs_for_f)
synchronize()
torch.manual_seed(1337)
t0 = time.perf_counter()
for _ in range(num_runs):
f(input, **kwargs_for_f)
synchronize()
t1 = time.perf_counter()
timing = t1 - t0
with profile(activities=activities, **kwargs_for_profiler) as prof:
with optimize_ctx:
synchronize()
torch.manual_seed(1337)
for _ in range(num_runs):
f(input, **kwargs_for_f)
synchronize()
prof.export_chrome_trace(trace_filename)
return timing
def get_chrome_trace_events(filename):
f = open(filename)
data = json.load(f)
events = data["traceEvents"]
return events
def is_gpu_compute_event(event):
global gpu_pids
return (
"pid" in event
and event["pid"] in gpu_pids
and "ph" in event
and event["ph"] == "X"
)
def get_sorted_gpu_events(events):
sorted_gpu_events = []
for event in events:
if not is_gpu_compute_event(event):
continue
sorted_gpu_events.append(event)
return sorted(sorted_gpu_events, key=operator.itemgetter("ts"))
def get_duration(sorted_gpu_events):
if len(sorted_gpu_events) == 0:
return 0
event = sorted_gpu_events[0]
current_end_time = event["ts"] + event["dur"]
total_duration = event["dur"]
for event in sorted_gpu_events[1:]:
start_time = max(event["ts"], current_end_time)
end_time = event["ts"] + event["dur"]
total_duration = total_duration + max(end_time - start_time, 0)
current_end_time = max(current_end_time, end_time)
return total_duration
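# Worked example (illustrative, not from the original file): for two overlapping
# events with (ts, dur) = (0, 4) and (2, 4), get_duration counts the overlap [2, 4)
# only once, so the total busy time is 6 rather than 8.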
def get_sorted_gpu_mm_conv_events(events):
def is_mm_conv_event(event):
return "name" in event and (
"gemm" in event["name"]
or "conv" in event["name"]
or "cutlass" in event["name"]
or "wgrad" in event["name"]
)
gpu_events = get_sorted_gpu_events(events)
sorted_events = []
for event in gpu_events:
if not is_mm_conv_event(event):
continue
sorted_events.append(event)
return sorted_events
gpu_pids = []
def compute_utilization(filename: str, total_length: float):
"""
Process the chrome trace output by the PyTorch profiler to compute GPU utilization
and the percentage of time spent on matmul and convolution.
Args:
filename(str): Name of chrome traces file produced by pytorch profiler
total_length(float): total length of the process without the profiler, in seconds
Return:
tuple: (GPU Utilization, percent of time spent on matmul and convolution)
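Example (illustrative; the path and timing below are hypothetical placeholders):
utilization, mm_conv_utilization = compute_utilization("tmp/tmp_chrome_trace.json", 1.25)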
"""
events = get_chrome_trace_events(filename)
# get pids of GPU events
global gpu_pids
gpu_pids = []
for event in events:
if "name" not in event:
continue
if event["name"] == "process_labels" and "GPU" in event["args"]["labels"]:
gpu_pids.append(event["pid"])
total_length = total_length * 1e6
sorted_gpu_events = get_sorted_gpu_events(events)
utilization = get_duration(sorted_gpu_events) / total_length
sorted_gpu_mm_conv_events = get_sorted_gpu_mm_conv_events(events)
mm_conv_utilization = get_duration(sorted_gpu_mm_conv_events) / total_length
return utilization, mm_conv_utilization
def benchmark_utilization(
f,
input,
trace_folder,
optimize_ctx=None,
trace_file_name="tmp_chrome_trace",
num_runs=1,
):
"""
Benchmark the GPU Utilization and percent of time spent on matmul and convolution operations of
running f(input, **kwargs_for_f) with [optimize_ctx] [num_runs] times.
It will produce a chrome trace file in trace_folder/trace_file_name.json
Example:
```
def f(a):
return a.sum()
a = torch.rand(2**20, device="cuda")
utilization, mm_conv_utilization = benchmark_utilization(f, a, "tmp", trace_file_name = "tmp_chrome_trace")
```
Args:
f: function to benchmark
input: input to :attr:`f`
trace_folder: name of the folder to store the chrome trace
optimize_ctx: the context in which f will run
trace_file_name: name of the dumped chrome trace file, default to "tmp_chrome_trace"
num_runs: number of times to run f, excluding the warm-up runs, default to 1.
Return:
tuple: (GPU Utilization, percent of time spent on matmul and convolution)
"""
isExist = os.path.exists(trace_folder)
if not isExist:
os.makedirs(trace_folder)
print("create folder " + trace_folder)
if optimize_ctx is None:
optimize_ctx = contextlib.nullcontext()
chrome_trace_file_name = os.path.join(trace_folder, trace_file_name + ".json")
total_length = dump_chrome_trace(
f,
input,
chrome_trace_file_name,
optimize_ctx,
[ProfilerActivity.CUDA],
num_runs=num_runs,
devices="cuda",
)
utilization, mm_conv_utilization = compute_utilization(
chrome_trace_file_name, total_length
)
return utilization, mm_conv_utilization

View File

@ -0,0 +1,176 @@
# mypy: ignore-errors
from typing import Callable
import torch
import torch.fx as fx
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_flatten
aten = torch.ops.aten
def get_aten_target(node: fx.Node) -> Callable:
if hasattr(node.target, "overloadpacket"):
return node.target.overloadpacket
return node.target
rand_ops = [
aten.dropout,
aten._fused_dropout,
aten._standard_gamma,
aten.bernoulli,
aten.multinomial,
aten.native_dropout,
aten.normal,
aten.poisson,
aten.binomial,
aten.rrelu,
aten.rand_like,
aten.rand,
aten.randint,
aten.randn,
aten.randperm,
]
# return a new copy of torch.fx.graph.Graph with CSE applied to the input graph
def fx_graph_cse(fx_g: torch.fx.graph.Graph):
new_graph = fx.Graph()
env = {} # map from node in the old graph to node in the new graph
hash_env = {} # map from hash to a node in the new graph
token_map = {} # map from hash to token
from torch._inductor.pattern_matcher import (
compute_mutation_region_ids,
same_mutation_regions,
)
compute_mutation_region_ids(fx_g) # type: ignore[arg-type]
# Make a set of separate storages returned from the output, which will be preserved
# when pruning. This prevents us from deduplicating returned tensors which have
# experienced identical operations, but are separate data structures in eager mode.
output_node: fx.Node = list(fx_g.nodes)[-1]
assert output_node.op == "output"
def checkable_node(node: fx.Node) -> bool:
"""We can evaluate only nodes that represent tensors with defined storage."""
if "val" not in node.meta or not isinstance(node.meta["val"], torch.Tensor):
return False
try:
node.meta["val"].untyped_storage()
except NotImplementedError:
return False
return True
output_storages = {
StorageWeakRef(n.meta["val"].untyped_storage())
for n in output_node.all_input_nodes
if checkable_node(n)
}
nodes_that_alias_outputs = {
n
for n in fx_g.nodes
if checkable_node(n)
and StorageWeakRef(n.meta["val"].untyped_storage()) in output_storages
}
for n in fx_g.nodes:
# The placeholder, output, and get_attr nodes are copied to the new graph without change
# do not CSE away random operations
if (
n.op == "placeholder"
or n.op == "output"
or n.op == "get_attr"
or get_aten_target(n) in rand_ops
# aten.empty is non-deterministic, so don't CSE it.
# Also, aten.empty is almost always fusible into its consumer,
# so it's not worth CSEing.
or get_aten_target(n) is aten.empty
or n in nodes_that_alias_outputs
):
new_node = new_graph.node_copy(n, lambda x: env[x])
env[n] = new_node
else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
# substitute args and kwargs members to their mapping in env if exists
# specs can be used to reconstruct nested list/dictionaries
def substitute(arg_list):
arg_list, spec = tree_flatten(arg_list)
for i in range(len(arg_list)):
v = arg_list[i]
if isinstance(v, torch.fx.node.Node) and v in env:
arg_list[i] = env[v]
if isinstance(v, (torch.SymBool, torch.SymInt, torch.SymFloat)):
arg_list[i] = v.node
return tuple(arg_list), spec
args, args_spec = substitute(n.args)
kwargs, kwargs_spec = substitute(n.kwargs)
# each token corresponds to a unique node
# nodes with the same token can be substituted
token = {
"target": n.target,
"args": args,
"args_spec": args_spec,
"kwargs": kwargs,
"kwargs_spec": kwargs_spec,
}
# hash substituted args to a number, do not hash specs because specs are not hashable
# We need to add type into hash to avoid situations like:
# hash((primals_2, 1.0)) == hash((primals_2, 1))
hash_arg = hash(
(tuple((a, type(a)) for a in args), tuple((a, type(a)) for a in kwargs))
)
hash_val = (n.target, hash_arg)
# check if a node has a substitute and can be eliminated
hash_val_in_hash_env = hash_val in hash_env
overwrite_due_to_mutation = False
if hash_val_in_hash_env and token_map[hash_val] == token:
duplicate_n_prev = hash_env[hash_val]
if same_mutation_regions(n, duplicate_n_prev):
env[n] = duplicate_n_prev
continue
else:
# any future duplicates should be replaced with n, not duplicate_n_prev
overwrite_due_to_mutation = True
new_node = new_graph.node_copy(n, lambda x: env[x])
env[n] = new_node
if overwrite_due_to_mutation or not hash_val_in_hash_env:
hash_env[hash_val] = new_node
token_map[hash_val] = token
return new_graph
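# Illustrative usage sketch (not part of the original file; `gm` is hypothetical, and
# real callers pass graphs produced by AOTAutograd whose nodes carry `val` metadata):
#
#   gm = torch.fx.symbolic_trace(lambda x: x.cos() + x.cos())
#   gm = torch.fx.GraphModule(gm, fx_graph_cse(gm.graph))
#
# After CSE, the two identical `cos` calls collapse into a single node.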
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
def get_placeholders(graph):
return graph.find_nodes(op="placeholder")
def get_outputs(graph):
for node in graph.find_nodes(op="output"):
return pytree.tree_leaves(node.args[0])
raise AssertionError("No output node found")

View File

@ -0,0 +1,445 @@
# mypy: ignore-errors
import copy
import logging
import os
import pickle
import random
from contextlib import contextmanager
from functools import partial
from typing import Callable, Union
import sympy
import torch
import torch.fx as fx
import torch.nn as nn
import torch.utils._pytree as pytree
from torch import SymInt
from torch._decomp import get_decompositions
from torch.fx.experimental.symbolic_shapes import bind_symbols
from .aot_autograd import aot_function, aot_module, make_boxed_compiler
from .compile_utils import strip_overloads
from .partitioners import (
default_partition,
draw_graph,
min_cut_rematerialization_partition,
)
log = logging.getLogger(__name__)
# These canonicalizations are needed here (and not decompositions), as the ops
# we're trying to canonicalize to are CompositeImplicitAutograd.
def _canonicalize(fx_g):
for node in fx_g.graph.find_nodes(
op="call_function", target=torch.ops.aten._to_copy
):
node.target = torch.ops.aten.to
fx_g.recompile()
return fx_g
@contextmanager
def _disable_jit_autocast():
old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
try:
yield
finally:
torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
@make_boxed_compiler
def ts_compile(fx_g: fx.GraphModule, inps) -> Callable:
"""
Compiles the :attr:`fx_g` with the TorchScript compiler.
.. warning::
This API is experimental and likely to change.
Args:
fx_g(fx.GraphModule): The input Fx graph module to be compiled.
Returns:
Torch scripted model.
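Example (an illustrative sketch; ``fn`` and the input are hypothetical):
>>> # xdoctest: +SKIP
>>> fn = lambda x: x.sin().cos()
>>> aot_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile)
>>> out = aot_fn(torch.randn(4, requires_grad=True))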
"""
with _disable_jit_autocast():
strip_overloads(fx_g)
for node in fx_g.graph.find_nodes(
op="call_function", target=torch.ops.aten._to_copy
):
if len(node.args) == 1 and len(node.kwargs) == 1 and "dtype" in node.kwargs:
node.target = torch.ops.aten.to
for node in fx_g.graph.nodes:
new_kwargs = {}
for k, v in node.kwargs.items():
if isinstance(v, torch.device):
v = v.type
new_kwargs[k] = v
node.kwargs = new_kwargs
fx_g.graph.lint()
fx_g.recompile()
f = torch.jit.script(fx_g)
torch._C._jit_pass_remove_mutation(f.graph)
f = torch.jit.freeze(f.eval())
f = torch.jit.optimize_for_inference(f)
if not any(isinstance(t, torch._subclasses.FakeTensor) for t in inps):
f(*inps)
return f
def _draw_graph_compile(fx_g, _, name, clear_meta=True):
print(fx_g.code)
draw_graph(fx_g, name, clear_meta=clear_meta)
return fx_g
def draw_graph_compile(name):
return make_boxed_compiler(partial(_draw_graph_compile, name=name))
@make_boxed_compiler
def nop(fx_g: fx.GraphModule, _) -> Callable:
"""
Returns the :attr:`fx_g` Fx graph module as it is. This is a no-op compiler
and can be used to check accuracy.
.. warning::
This API is experimental and likely to change.
"""
return fx_g
class DebugInterpreter(fx.Interpreter):
def run(self, *args):
self.symbol_mapping = bind_symbols(self.module, *args)
super().run(*args)
def run_node(self, n):
def subst_symint(ni):
if not isinstance(ni, SymInt):
return ni
r = sympy.expand(ni.node.expr.xreplace(self.symbol_mapping))
assert r.is_number, r
return int(r)
def subst_symint_tuple(nis):
return tuple(subst_symint(ni) for ni in nis)
def check_significant_strides(a, b):
if subst_symint(a.numel()) > 0:
for idx in range(a.ndim):
if (
subst_symint(a.stride(idx)) != b.stride(idx)
and subst_symint(a.size(idx)) > 1
):
return False
return True
def check(nv, rv, desc):
assert callable(desc)
assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}"
assert (
subst_symint_tuple(nv.size()) == rv.size()
), f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}"
same_strides = check_significant_strides(nv, rv)
assert (
same_strides
), f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}"
r = super().run_node(n)
if "val" in n.meta:
n_vals, n_spec = pytree.tree_flatten(n.meta["val"])
r_vals, r_spec = pytree.tree_flatten(r)
# TODO: There is some sort of problem where we record that an
# operator returned a tuple/list, and then later it turns out the
# real version of the operator returned a list/tuple. Need to
# figure out what's actually going on here, the error itself is
# harmless enough as we only getitem out the outputs.
# assert n_spec == r_spec, f"{n_spec} != {r_spec}"
assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
if not isinstance(rv, torch.Tensor):
continue
check(nv, rv, lambda: f"output {i} where {self.symbol_mapping}")
return r
@make_boxed_compiler
def debug_nop(fx_g: fx.GraphModule, _) -> Callable:
"""
Returns a (slow) interpreter over the FX graph module that also checks
various debugging properties (e.g., that traced strides match real
strides).
"""
return DebugInterpreter(fx_g).run
@make_boxed_compiler
def simple_ts_compile(fx_g, _):
strip_overloads(fx_g)
f = torch.jit.script(fx_g)
f = torch.jit.freeze(f.eval())
return f
def nnc_jit(f):
return aot_function(f, simple_ts_compile)
aten = torch.ops.aten
default_decompositions = {
aten.detach,
aten.gelu_backward,
aten.leaky_relu_backward,
aten.sigmoid_backward,
aten.threshold_backward,
aten.hardtanh_backward,
aten.hardsigmoid_backward,
aten.hardswish_backward,
aten.tanh_backward,
aten.silu_backward,
aten.elu_backward,
aten.cudnn_batch_norm,
aten.cudnn_batch_norm_backward,
aten.masked_fill.Scalar,
aten.masked_fill.Tensor,
aten.elu,
aten.leaky_relu,
aten.hardtanh,
aten.hardswish,
aten.hardsigmoid,
aten.conj_physical,
aten.is_same_size,
}
default_decompositions = get_decompositions(default_decompositions)
@make_boxed_compiler
def print_compile(fx_g, _):
print(fx_g.code)
return fx_g
def memory_efficient_fusion(
fn: Union[Callable, nn.Module],
**kwargs,
):
"""
Wrapper function over :func:`aot_function` and :func:`aot_module` to perform
memory efficient fusion. It uses the
:func:`min_cut_rematerialization_partition` partitioner to perform efficient
recomputation. It uses NVFuser to compile the generated forward and backward
graphs.
.. warning::
This API is experimental and likely to change.
Args:
fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module``
that takes one or more arguments. Must return one or more Tensors.
**kwargs: Any other overrides you want to make to the settings
Returns:
Returns a ``Callable`` or ``nn.Module`` that retains the eager behavior
of the original :attr:`fn`, but whose forward and backward graphs have
gone through recomputation optimizations, and the graphs have been
compiled with nvfuser.
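Example (an illustrative sketch; ``fn`` and the inputs are hypothetical):
>>> # xdoctest: +SKIP
>>> def fn(a, b):
>>>     return (a + b).sin().sum()
>>> fused_fn = memory_efficient_fusion(fn)
>>> out = fused_fn(torch.randn(8, device="cuda"), torch.randn(8, device="cuda"))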
"""
config = {
"fw_compiler": ts_compile,
"bw_compiler": ts_compile,
"partition_fn": min_cut_rematerialization_partition,
"decompositions": default_decompositions,
}
config.update(kwargs)
if isinstance(fn, torch.nn.Module):
return aot_module(fn, **config)
else:
return aot_function(fn, **config)
def debug_compile(fx_g, inps):
fx_g.to_folder("foo")
print(
f"""
##############################################################
# To minimize FX graph, copy and paste the below and run it #
##############################################################
import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess
inps = {[(i.shape, i.dtype) for i in inps]}
inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule().cuda()
with torch.jit.fuser("fuser2"):
# check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess
minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)
"""
)
from foo import FxModule
FxModule().cuda()(*inps)
return ts_compile(fx_g, inps)
graph_index = 0
def get_inputs(input_data_path):
"""
Return random inputs matching the input metadata generated by _save_fx_default.
"""
inputs = []
with open(input_data_path, "rb") as f:
inputs_meta = pickle.load(f)
inputs = []
for meta in inputs_meta:
if len(meta) == 1:
type = meta[0]
input = type(random.random())
else:
type, shape, stride, dtype, device = meta
if dtype in {
torch.int,
torch.int32,
torch.int64,
torch.bool,
torch.uint8,
int,
float,
}:
input = torch.randint(0, 1, shape, dtype=dtype, device=device)
else:
input = torch.rand(shape, dtype=dtype, device=device)
inputs.append(input)
return inputs
def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_inputs):
"""
The forward, backward, and joint computation graph will be stored in
{folder_name}/{current_name}/{current_name}_forward_{graph_index},
{folder_name}/{current_name}/{current_name}_backward_{graph_index}, and
{folder_name}/{current_name}/{current_name}_joint_{graph_index} respectively.
The input shape of the graphs will be stored in the .input files.
These files can be loaded with pickle; each contains a list of tuples of the form
(type, shape, stride, dtype, device).
In the case of type = int or float, it is just (type,).
For joint graph input, it is a nested list [[],[]]
where the two inner lists have the same format.
If dump_example_input is True, example_inputs will be stored in .pt file.
Since each function might produce multiple graphs,
the graph_index is used to distinguish different graphs
"""
from functorch.compile import aot_module_simplified
def get_input_meta(args):
input_meta = []
if len(args) > 0 and isinstance(args[0], tuple): # joint input
input_meta += get_input_meta(args[0])
input_meta += get_input_meta(args[1])
return input_meta
for arg in args:
if type(arg) == int or type(arg) == float:
input_meta.append((type(arg),))
else:
input_meta.append(
(type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)
)
return input_meta
def graph_saver_helper(gm_to_save, args, type_name):
global graph_index
if len(gm_to_save.graph.nodes) == 0:
log.log(
logging.WARNING,
"No nodes in graph {%s}_{%s}_{%s}.",
current_name,
type_name,
graph_index,
)
return
gm = copy.deepcopy(gm_to_save)
gm.graph.set_codegen(torch.fx.graph.CodeGen()) # remove codegen
gm.recompile()
input_meta = get_input_meta(args)
os.makedirs(f"{folder_name}/{current_name}", exist_ok=True)
gm.to_folder(
f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}"
)
pickle.dump(
input_meta,
open(
f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input", # noqa: B950
"wb",
),
) # noqa: E501
if dump_example_input:
torch.save(
args,
f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt", # noqa: B950
) # noqa: E501
def graph_saver_forward(gm, fw_args):
graph_saver_helper(gm, fw_args, "forward")
return gm
def graph_saver_backward(gm, bw_args):
graph_saver_helper(gm, bw_args, "backward")
global graph_index
graph_index += 1
return gm
def graph_saver_joint(gm, joint_args):
graph_saver_helper(gm, joint_args, "joint")
return default_partition(gm, joint_args)
return aot_module_simplified(
gm,
example_inputs,
fw_compiler=graph_saver_forward,
bw_compiler=graph_saver_backward,
partition_fn=graph_saver_joint,
decompositions=default_decompositions,
)
# WARNING: This isn't tested anywhere!!
def graph_dumper_aot(current_name, folder_name, dump_example_input=False):
"""
Dump the forward, backward, and joint computation graph.
Example Usage:
save_fx_func = graph_dumper_aot(current_name, folder_name, dump_example_input = False)
optimize_ctx = torchdynamo.optimize(
save_fx_func
)
with torch.enable_grad():
with optimize_ctx:
result = forward_and_backward_pass(model, example_inputs)
"""
global graph_index
graph_index = 0
return partial(_save_fx_default, current_name, folder_name, dump_example_input)

View File

@ -0,0 +1,203 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Global flags for aot autograd
"""
import os
import sys
from typing import TYPE_CHECKING
# Converts torch rng ops to their functional philox rng equivalents. Note that
# we functionalize only CUDA rng ops today.
functionalize_rng_ops = False
# can be useful for debugging if we are incorrectly creating meta fake tensors
fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", "1") != "0"
# Enables optional asserts in hotpath code to check for errors. If
# you are seeing weird accuracy problems, try turning this on.
# This is currently off by default as it will harm tracing time,
# but it is on by default for aot_eager.
debug_assert = False
debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0"
# Today, if you are in a situation where there is "false aliasing"
# (e.g. you have a bunch of model parameters that all alias the same underlying buffer),
# our checks for this situation are very slow if these inputs have dynamic shapes.
# This config is set to ensure that there aren't too many aliased inputs in this situation,
# so that we error loudly instead of compiling forever.
# Eventually, we should make these checks faster.
# For now, however, you can simply turn off dynamic shapes by marking your inputs static
# when you run into this situation.
_max_aliased_inputs_with_dynamic_shapes_enabled = 5
static_weight_shapes = True
# Applies CSE to the graph before partitioning
cse = True
enable_autograd_cache = os.environ.get("ENABLE_AOT_AUTOGRAD_CACHE", "0") == "1"
# When AOTAutograd regenerates aliased graph outputs,
# attempt to use functionalization's view-replay logic
# before falling back to the autograd engine's view replay or as_strided.
# This can have some perf implications
# (although for many models this will not matter).
# (1) If you have many view ops chained together, replaying all of them
# at runtime can have more overhead compared to a single as_strided call
# (2) If you are doing training, AsStridedBackward is quite slow,
# and the individual view op backward formulas will likely be faster.
# (3) Some backends like XLA do not support as_strided
# Temporary hack: disable this flag for internal
# (needed to fix an internal issue while avoiding bumping XLA pin)
# eventually: either default this config to false completely
# once XLA pin update works,
# or default config to true and fix relevant bugs
from torch._inductor.config import is_fbcode
# View replay is currently not compatible with AOTAutogradCache, since
# FunctionalTensors are not serializable. We'll need to make them
# serializable before enabling warm cache with this config turned on.
view_replay_for_aliased_outputs = (not is_fbcode()) and (not enable_autograd_cache)
# Restricts the amount of computation AOTAutograd can do.
# NB: We have essentially disabled this heuristic now. However, this is kept
# here for now in case it's useful. Setting it low can artificially reduce the
# amount of recomputation AOTAutograd performs, although not in any kind of
# principled way.
max_dist_from_bw = 1000
# Bans recomputation of nodes that read from nodes that are far before
# the current node
ban_recompute_used_far_apart = True
# Breaks up long chain of fusible ops, as otherwise we can have an arbitrarily
# long chain of recomputation in the backwards pass.
ban_recompute_long_fusible_chains = True
# Bans recomputation of nodes that must be materialized in the backwards pass
# (used by a non-fusible node)
ban_recompute_materialized_backward = True
# Chooses to ban recomputation of nodes based off an allowlist. Setting it to
# False changes it to use a denylist. The main difference is for operators like
# sort/pool that aren't cheap enough to be fusible for free but also aren't
# that expensive.
ban_recompute_not_in_allowlist = True
# Chooses to ban recomputation of reductions. This is generally a good idea, as
# the result of reductions is generally very small but recomputing reductions in
# a fusion can be expensive.
ban_recompute_reductions = True
# Prevents the partitioner from ever saving views (i.e. always recompute them).
# Generally a good idea since views are free to recompute.
recompute_views = False
# By default, the partitioner is purely trying to optimize for runtime (although
# it should always use less memory than eager)
# This knob controls the partitioner to make that tradeoff for you, choosing the
# fastest option that saves less activations than the memory budget.
# Specifically, 0.0 corresponds to the activation memory from applying
# activation checkpointing to the full compiled region, and 1.0 corresponds to
# the activation memory from the default runtime-optimized strategy. So, 0.4
# would result in a strategy that saves 40% of the activations compared to the
# default strategy.
# It solves a 0-1 knapsack to find the minimum recompute necessary to stay below
# the activation memory budget.
# NOTE: This *cannot* be treated as
activation_memory_budget = 1.0
# This controls how we estimate the runtime when deciding what the cheapest
# operators to recompute are. The 3 options are
# "flops": Bases it off of the flop count provided by torch.utils.flop_counter
# "profile": Benchmarks each operator to come up with a runtime
# "testing": Returns 1 for everything
activation_memory_budget_runtime_estimator = "flops"
# This controls the solver used for the 0-1 knapsack. By default we use a
# quantized DP solution ("dp"). The other approaches are a "greedy" and a "ilp"
# (which has a scipy dependency).
activation_memory_budget_solver = "dp"
# This dumps out a png visualization of the expected runtime vs. activation
# memory tradeoffs for all memory budget values from 0 to 1 in increments of
# 0.5. See an example here:
# https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015
visualize_memory_budget_pareto = (
os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO", "0") == "1"
)
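# --- Illustrative sketch (added for exposition; not part of the original file) ---
# The memory-budget knobs above can be patched per compilation via the `patch`
# helper that install_config_module (see the bottom of this file) attaches to
# this module. The model, shapes, and budget value below are placeholders:
def _example_activation_memory_budget():  # illustrative only; never called from this module
    import torch
    from torch._functorch import config as functorch_config

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
    compiled = torch.compile(model)
    # Per the comment above, 0.4 asks the partitioner for the fastest
    # recompute strategy that stays within that activation memory budget.
    with functorch_config.patch(
        activation_memory_budget=0.4,
        activation_memory_budget_solver="dp",
    ):
        out = compiled(torch.randn(8, 64))
        out.sum().backward()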
# Sets all of the ban_recompute heuristics to False except ban_recompute_reductions
# Generally, this will probably result in some memory improvement, but at the
# cost of some performance
aggressive_recomputation = False
# If FakeTensor.data_ptr() should error.
# This option is independent of AOTAutograd and torch.compile, but our policy
# is to turn it off during torch.compile.
fake_tensor_allow_unsafe_data_ptr_access = True
# Unlifts effect tokens from the inputs/outputs in the traced graph and instead
# inserts make_token/sink_token calls in the graph to create tokens and then
# sink them at the end. Note that this means the graph is no longer functional
# which may lead to silent errors unless the backend knows how to handle the
# tokens.
unlift_effect_tokens = False
# This mode specifies that we should also keep track of the real
# tensor along with the fake tensor, and do real compute. While
# seemingly this eliminates the whole point of fake tensors, there are
# two obvious use cases for it:
#
# 1. When users call item()/other data dependent operations,
# if we propagate_real_tensors we are able to determine what
# the true value is and keep going.
#
# 2. It can be useful for testing, when you want to see if the fake
# and real tensors agree with each other. (Note that there are
# currently known inaccuracies in how we clone real tensors, that
# would have to be tightened up for this to be useful in this
# case.)
#
# Note that fake tensors are typically understood to be cheap to store
# indefinitely, so we tend to hold on to them longer than we would
# hold onto the real tensors. So we also support you explicitly
# deallocating the real tensor associated with a fake tensor, at which
# point we will stop propagating real tensors.
#
# One more thing: when you provide a real tensor to fakeify, we will
# clone it, so that we can safely perform mutations on it if necessary.
# This will increase live memory usage. This could potentially be
# optimized by using COW. We also currently do not faithfully
# maintain autograd metadata on the real tensor; this is fine because
# AOTAutograd will only use the fake tensor to determine leafness/etc
# of tensors in question.
fake_tensor_propagate_real_tensors = False
# This controls whether we collect donated buffers. This flag must be set to
# False if a user wants to use retain_graph=True for backward.
donated_buffer = False
# Controls the default graph output format used by draw_graph
# Supported formats are defined here https://graphviz.org/docs/outputs/
torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg")
# Error on BypassAOTAutogradCache instead of just a warning
# Used for tests
strict_autograd_cache = False
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
from torch.utils._config_module import install_config_module
# adds patch, save_config, invalid config checks, etc
install_config_module(sys.modules[__name__])

View File

@ -0,0 +1,172 @@
# mypy: allow-untyped-defs
"""
The APIs in this file are exposed as `functorch.*`. They are thin wrappers
around the torch.func.* APIs that have deprecation warnings -- we're trying
to move people to the torch.func.* equivalents.
NB: We don't use *args, **kwargs in the signatures because that changes the
documentation.
"""
import textwrap
import warnings
from typing import Any, Callable, Optional, Tuple, Union
import torch._functorch.apis as apis
import torch._functorch.eager_transforms as _impl
import torch._functorch.make_functional as _nn_impl
import torch.nn as nn
from torch._functorch.eager_transforms import argnums_t
from torch._functorch.vmap import in_dims_t, out_dims_t
def get_warning(api, new_api=None, replace_newlines=False):
if new_api is None:
new_api = f"torch.func.{api}"
warning = (
f"We've integrated functorch into PyTorch. As the final step of the \n"
f"integration, `functorch.{api}` is deprecated as of PyTorch \n"
f"2.0 and will be deleted in a future version of PyTorch >= 2.3. \n"
f"Please use `{new_api}` instead; see the PyTorch 2.0 release notes \n"
f"and/or the `torch.func` migration guide for more details \n"
f"https://pytorch.org/docs/main/func.migrating.html"
)
if replace_newlines:
warning = warning.replace("\n", "")
return warning
def warn_deprecated(api, new_api=None):
warning = get_warning(api, new_api, replace_newlines=True)
warnings.warn(warning, FutureWarning, stacklevel=3)
def setup_docs(functorch_api, torch_func_api=None, new_api_name=None):
api_name = functorch_api.__name__
if torch_func_api is None:
torch_func_api = getattr(_impl, api_name)
# See https://docs.python.org/3/using/cmdline.html#cmdoption-OO
if torch_func_api.__doc__ is None:
return
warning = get_warning(api_name, new_api_name)
warning_note = "\n.. warning::\n\n" + textwrap.indent(warning, " ")
warning_note = textwrap.indent(warning_note, " ")
functorch_api.__doc__ = torch_func_api.__doc__ + warning_note
def vmap(
func: Callable,
in_dims: in_dims_t = 0,
out_dims: out_dims_t = 0,
randomness: str = "error",
*,
chunk_size=None,
) -> Callable:
warn_deprecated("vmap", "torch.vmap")
return apis.vmap(func, in_dims, out_dims, randomness, chunk_size=chunk_size)
def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
warn_deprecated("grad")
return apis.grad(func, argnums, has_aux)
def grad_and_value(
func: Callable, argnums: argnums_t = 0, has_aux: bool = False
) -> Callable:
warn_deprecated("grad_and_value")
return apis.grad_and_value(func, argnums, has_aux)
def vjp(func: Callable, *primals, has_aux: bool = False):
warn_deprecated("vjp")
return _impl.vjp(func, *primals, has_aux=has_aux)
def jvp(
func: Callable,
primals: Any,
tangents: Any,
*,
strict: bool = False,
has_aux: bool = False,
):
warn_deprecated("jvp")
return _impl.jvp(func, primals, tangents, strict=strict, has_aux=has_aux)
def jacrev(
func: Callable,
argnums: Union[int, Tuple[int]] = 0,
*,
has_aux=False,
chunk_size: Optional[int] = None,
_preallocate_and_copy=False,
):
warn_deprecated("jacrev")
return _impl.jacrev(
func,
argnums,
has_aux=has_aux,
chunk_size=chunk_size,
_preallocate_and_copy=_preallocate_and_copy,
)
def jacfwd(
func: Callable,
argnums: argnums_t = 0,
has_aux: bool = False,
*,
randomness: str = "error",
):
warn_deprecated("jacfwd")
return _impl.jacfwd(func, argnums, has_aux, randomness=randomness)
def hessian(func, argnums=0):
warn_deprecated("hessian")
return _impl.hessian(func, argnums=argnums)
def functionalize(func: Callable, *, remove: str = "mutations") -> Callable:
warn_deprecated("functionalize")
return _impl.functionalize(func, remove=remove)
def make_functional(model: nn.Module, disable_autograd_tracking: bool = False):
warn_deprecated("make_functional", "torch.func.functional_call")
return _nn_impl.make_functional(model, disable_autograd_tracking)
def make_functional_with_buffers(
model: nn.Module, disable_autograd_tracking: bool = False
):
warn_deprecated("make_functional_with_buffers", "torch.func.functional_call")
return _nn_impl.make_functional_with_buffers(model, disable_autograd_tracking)
def combine_state_for_ensemble(models):
warn_deprecated("combine_state_for_ensemble", "torch.func.stack_module_state")
return _nn_impl.combine_state_for_ensemble(models)
setup_docs(vmap, apis.vmap, "torch.vmap")
setup_docs(grad, apis.grad)
setup_docs(grad_and_value, apis.grad_and_value)
setup_docs(vjp)
setup_docs(jvp)
setup_docs(jacrev)
setup_docs(jacfwd)
setup_docs(hessian)
setup_docs(functionalize)
setup_docs(make_functional, _nn_impl.make_functional, "torch.func.functional_call")
setup_docs(
make_functional_with_buffers, _nn_impl.make_functional, "torch.func.functional_call"
)
setup_docs(
combine_state_for_ensemble,
_nn_impl.combine_state_for_ensemble,
"torch.func.stack_module_state",
)
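# --- Illustrative migration sketch (added for exposition; not part of the original file) ---
# The wrappers above only emit a FutureWarning and forward to torch.func /
# torch._functorch; new code should call the torch.func equivalents directly.
# A minimal, hypothetical before/after:
def _example_migrate_to_torch_func():  # illustrative only; never called from this module
    import torch
    from torch.func import grad, vmap

    def f(x):
        return (x ** 2).sum()

    x = torch.randn(3)
    g = grad(f)(x)                               # instead of functorch.grad(f)(x)
    batched = vmap(grad(f))(torch.randn(5, 3))   # instead of functorch.vmap(functorch.grad(f))
    return g, batched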

File diff suppressed because it is too large

View File

@ -0,0 +1,253 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
from torch._functorch.utils import exposed_in
@exposed_in("torch.func")
def functional_call(
module: "torch.nn.Module",
parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
args: Union[Any, Tuple],
kwargs: Optional[Dict[str, Any]] = None,
*,
tie_weights: bool = True,
strict: bool = False,
):
r"""Performs a functional call on the module by replacing the module parameters
and buffers with the provided ones.
.. note:: If the module has active parametrizations, passing a value in the
:attr:`parameter_and_buffer_dicts` argument with the name set to the regular parameter
name will completely disable the parametrization.
If you want to apply the parametrization function to the value passed
please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.
.. note:: If the module performs in-place operations on parameters/buffers, these will be reflected
in the ``parameter_and_buffer_dicts`` input.
Example::
>>> a = {'foo': torch.zeros(())}
>>> # xdoctest: +SKIP
>>> mod = Foo() # does self.foo = self.foo + 1
>>> print(mod.foo) # tensor(0.)
>>> functional_call(mod, a, torch.ones(()))
>>> print(mod.foo) # tensor(0.)
>>> print(a['foo']) # tensor(1.)
.. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
tie_weights flag.
Example::
>>> a = {'foo': torch.zeros(())}
>>> # xdoctest: +SKIP
>>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
>>> print(mod.foo) # tensor(1.)
>>> mod(torch.zeros(())) # tensor(2.)
>>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated
>>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
>>> functional_call(mod, new_a, torch.zeros(())) # tensor(0.)
An example of passing multiple dictionaries
.. code-block:: python
a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)}) # two separate dictionaries
mod = nn.Bar(1, 1) # return self.weight @ x + self.buffer
print(mod.weight) # tensor(...)
print(mod.buffer) # tensor(...)
x = torch.randn((1, 1))
print(x)
functional_call(mod, a, x) # same as x
print(mod.weight) # same as before functional_call
And here is an example of applying the grad transform over the parameters
of a model.
.. code-block:: python
import torch
import torch.nn as nn
from torch.func import functional_call, grad
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
def compute_loss(params, x, t):
y = functional_call(model, params, x)
return nn.functional.mse_loss(y, t)
grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)
.. note:: If the user does not need grad tracking outside of grad transforms, they can detach all of the
parameters for better performance and memory usage
Example::
>>> detached_params = {k: v.detach() for k, v in model.named_parameters()}
>>> grad_weights = grad(compute_loss)(detached_params, x, t)
>>> grad_weights.grad_fn # None--it's not tracking gradients outside of grad
This means that the user cannot call ``grad_weight.backward()``. However, if they don't need autograd tracking
outside of the transforms, this will result in less memory usage and faster speeds.
Args:
module (torch.nn.Module): the module to call
        parameter_and_buffer_dicts (Dict[str, Tensor] or tuple of Dict[str, Tensor]): the parameters that will be used in
the module call. If given a tuple of dictionaries, they must have distinct keys so that all dictionaries can
be used together
args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
kwargs (dict): keyword arguments to be passed to the module call
tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
tied in the reparameterized version. Therefore, if True and different values are passed for the tied
parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
buffers unless the values passed for both weights are the same. Default: True.
strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and
buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will
error. Default: False.
Returns:
Any: the result of calling ``module``.
"""
if isinstance(parameter_and_buffer_dicts, dict):
parameters_and_buffers = parameter_and_buffer_dicts
elif isinstance(parameter_and_buffer_dicts, Sequence):
if not all(isinstance(d, dict) for d in parameter_and_buffer_dicts):
raise ValueError(
"Expected all elements of parameter_and_buffer_dicts to be dictionaries"
)
all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()]
all_keys_counter: Dict[str, int] = {}
for k in all_keys:
v = all_keys_counter.get(k, 0)
all_keys_counter[k] = v + 1
repeated_keys = [key for key, n in all_keys_counter.items() if n > 1]
if len(repeated_keys) > 0:
raise ValueError(
f"{repeated_keys} appeared in multiple dictionaries; behavior of functional call is ambiguous"
)
parameters_and_buffers = {
k: v for d in parameter_and_buffer_dicts for k, v in d.items()
}
else:
raise ValueError(
f"Expected parameter_and_buffer_dicts to be a dict, or a list/tuple of dicts, "
f"but got {type(parameter_and_buffer_dicts)}"
)
return nn.utils.stateless._functional_call(
module,
parameters_and_buffers,
args,
kwargs,
tie_weights=tie_weights,
strict=strict,
)
@exposed_in("torch.func")
def stack_module_state(
models: List[nn.Module],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""stack_module_state(models) -> params, buffers
Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
Given a list of ``M`` ``nn.Modules`` of the same class, returns two dictionaries
that stack all of their parameters and buffers together, indexed by name.
The stacked parameters are optimizable (i.e. they are new leaf nodes in the
autograd history that are unrelated to the original parameters and can be
passed directly to an optimizer).
Here's an example of how to ensemble over a very simple model:
.. code-block:: python
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)
def wrapper(params, buffers, data):
return torch.func.functional_call(models[0], (params, buffers), data)
params, buffers = stack_module_state(models)
output = vmap(wrapper, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
When there's submodules, this follows state dict naming conventions
.. code-block:: python
import torch.nn as nn
class Foo(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
hidden = 4
self.l1 = nn.Linear(in_features, hidden)
self.l2 = nn.Linear(hidden, out_features)
def forward(self, x):
return self.l2(self.l1(x))
num_models = 5
in_features, out_features = 3, 3
models = [Foo(in_features, out_features) for i in range(num_models)]
params, buffers = stack_module_state(models)
print(list(params.keys())) # "l1.weight", "l1.bias", "l2.weight", "l2.bias"
.. warning::
All of the modules being stacked together must be the same (except for
the values of their parameters/buffers). For example, they should be in the
same mode (training vs eval).
"""
if len(models) == 0:
raise RuntimeError("stack_module_state: Expected at least one model, got 0.")
if not (all(m.training for m in models) or all(not m.training for m in models)):
raise RuntimeError(
"stack_module_state: Expected all models to have the same training/eval mode."
)
model0_typ = type(models[0])
if not all(type(m) == model0_typ for m in models):
raise RuntimeError(
"stack_module_state: Expected all models to be of the same class."
)
all_params = [dict(model.named_parameters()) for model in models]
params = {
k: construct_stacked_leaf(tuple(params[k] for params in all_params), k)
for k in all_params[0]
}
all_buffers = [dict(model.named_buffers()) for model in models]
buffers = {
k: construct_stacked_leaf(tuple(buffers[k] for buffers in all_buffers), k)
for k in all_buffers[0]
}
return params, buffers
def construct_stacked_leaf(
tensors: Union[Tuple[Tensor, ...], List[Tensor]], name: str
) -> Tensor:
all_requires_grad = all(t.requires_grad for t in tensors)
none_requires_grad = all(not t.requires_grad for t in tensors)
if not all_requires_grad and not none_requires_grad:
raise RuntimeError(
f"Expected {name} from each model to have the same .requires_grad"
)
result = torch.stack(tensors)
if all_requires_grad:
result = result.detach().requires_grad_()
return result

View File

@ -0,0 +1,501 @@
# mypy: ignore-errors
import copy
import math
import os
import sys
from dataclasses import dataclass
from functools import partial, wraps
from typing import Callable, List
import torch
import torch.fx as fx
from torch.hub import tqdm
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._content_store import ContentStoreWriter
from .compile_utils import get_outputs, get_placeholders
is_tuple = object()
@dataclass
class LoadTensorMeta:
size: List[int]
stride: List[int]
dtype: torch.dtype
device: torch.device
class ConcreteProp(torch.fx.Interpreter):
def __init__(self, mod, *, writer=None, skip_offload=False):
super().__init__(mod)
self.writer = writer
self.skip_offload = skip_offload
self.seen_storages = set()
def run_node(self, n):
self.pbar.update(1)
r = super().run_node(n)
name = n.name
if isinstance(r, torch.Tensor):
if self.writer is None:
n.meta["concrete_value"] = r
else:
if StorageWeakRef(r.untyped_storage()) in self.seen_storages:
# Refuse to offload tensors which alias other live
# tensors, because this will violate operator contracts
n.meta["concrete_value"] = None
else:
if not self.skip_offload:
self.writer.write_tensor(os.path.join("eager", name), r)
n.meta["concrete_value"] = LoadTensorMeta(
r.size(), r.stride(), r.dtype, r.device
)
self.seen_storages.add(StorageWeakRef(r.untyped_storage()))
else:
n.meta["concrete_value"] = is_tuple
return r
def propagate(self, *args):
with tqdm(
desc="Saving intermediates for delta debugging",
total=len(self.module.graph.nodes),
disable=self.writer is None,
) as pbar:
self.pbar = pbar
r = super().run(*args)
if not self.skip_offload:
pbar.set_description(
"Saved! To skip next time, run with --skip-saving-eager-intermediates"
)
return r
def is_load_tensor_node(node):
return (
node.op == "call_function"
and node.target is torch.ops.debugprims.load_tensor.default
)
# inplace modifies node/inps
def _convert_node_to_placeholder(graph, node, inps):
if node.op == "output" or node.op == "placeholder":
return False
if is_load_tensor_node(node):
return False
concrete_val = node.meta.get("concrete_value", None)
if isinstance(concrete_val, torch.Tensor):
node.op = "placeholder"
node.target = node.name
node.args = ()
node.kwargs = {}
inps.append(concrete_val)
return True
elif concrete_val is None:
return False
elif concrete_val is is_tuple:
r = False
for tuple_user in list(node.users):
r = _convert_node_to_placeholder(graph, tuple_user, inps) or r
# NB: We must not erase the node at this point, because
# we are iterating over the nodes and this would change
# the iteration order
# graph.erase_node(node)
return r
elif isinstance(concrete_val, LoadTensorMeta):
node.op = "call_function"
node.target = torch.ops.debugprims.load_tensor.default
node.args = (
os.path.join("eager", node.name),
concrete_val.size,
concrete_val.stride,
)
node.kwargs = {
"device": concrete_val.device,
"dtype": concrete_val.dtype,
}
return True
return False
def create_minified_hlo_graph(minified_fx_graph, inputs):
"""
Takes minified FX graph as primary input, and ports it to HLO via StableHLO
Provides minified HLO graph as output, and archive them to local directory
"""
hlo_dir = f"{os.getcwd()}/hlo_files"
    os.makedirs(hlo_dir, exist_ok=True)
from torch_xla.stablehlo import save_torch_model_as_stablehlo
save_torch_model_as_stablehlo(minified_fx_graph, inputs, hlo_dir)
def dump_state(fx_g, inps):
print(
f"""
# Working Repro with {len(fx_g.graph.nodes)} nodes
inps = {[(i.shape, i.dtype, i.device.type) for i in inps]}
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device=device) for (shape, dtype, device) in inps]
{fx_g.code}
"""
)
def is_power_of_two(n):
if n == 0:
return False
return (n & (n - 1)) == 0
@dataclass
class ReproState:
graph: fx.Graph
inps: List[torch.Tensor]
def __post_init__(self):
ph_nodes = get_placeholders(self.graph)
assert len(ph_nodes) == len(self.inps)
def minifier(
fail_f: fx.GraphModule,
inps,
module_fails,
dump_state: Callable = dump_state,
*,
save_dir=None,
offload_to_disk=False,
skip_offload=False,
skip_sanity=False,
max_granularity=None,
):
"""
    Minimizes an FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.
Does 2 main strategies:
1. Truncates suffix: Removes some suffix from the graph and sets a new output.
    2. Delta Debugging: Tries replacing half of the graph with inputs. If that fails,
        tries replacing a quarter of the graph, etc.
>>> # xdoctest: +SKIP(failing)
>>> failing_function = fx.symbolic_trace(f)
    >>> minifier(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))
note: module_fails returns True if it fails.
"""
assert isinstance(inps, (tuple, list))
failing_graph = fail_f.graph
cur_size = len(failing_graph.nodes)
if max_granularity is not None and not is_power_of_two(max_granularity):
raise RuntimeError(f"max_granularity {max_granularity} not power of two")
num_queries = 0
def deepcopy_fx_graph(fx_graph):
return fx.GraphModule(fail_f, copy.deepcopy(fx_graph)).graph
def graph_fails(graph, inps):
nonlocal num_queries
graph = copy.deepcopy(graph)
num_queries += 1
mod = fx.GraphModule(fail_f, graph)
mod.graph.lint()
return module_fails(mod, inps)
writer = None
if offload_to_disk:
writer = ContentStoreWriter(save_dir)
ConcreteProp(fail_f, writer=writer, skip_offload=skip_offload).propagate(*inps)
if not skip_sanity and not graph_fails(failing_graph, inps):
raise RuntimeError("Input graph did not fail the tester")
print(f"Started off with {cur_size} nodes", file=sys.stderr)
def _register_strategy(strategy: Callable, name: str):
@wraps(strategy)
def new_func(old_state: ReproState, granularity=1):
print(file=sys.stderr)
print(
f"Strategy: {name} (G: {granularity}) "
f"({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)",
file=sys.stderr,
)
new_state = strategy(
deepcopy_fx_graph(old_state.graph), list(old_state.inps), granularity
)
if new_state is not None:
new_nodes = len(new_state.graph.nodes)
old_nodes = len(old_state.graph.nodes)
new_inps = len(new_state.inps)
old_inps = len(old_state.inps)
new_outs = len(get_outputs(new_state.graph))
old_outs = len(get_outputs(old_state.graph))
progress_made = False
if new_nodes < old_nodes:
progress_made = True
print(
f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes",
file=sys.stderr,
)
if new_inps > old_inps:
progress_made = True
print(
f"SUCCESS: Went from {old_inps} to {new_inps} inputs",
file=sys.stderr,
)
if new_outs < old_outs:
progress_made = True
print(
f"SUCCESS: Went from {old_outs} to {new_outs} outputs",
file=sys.stderr,
)
if not progress_made:
raise RuntimeError("Success raised but no progress made?")
if not graph_fails(new_state.graph, new_state.inps):
print(
"WARNING: Something went wrong, not applying this minification",
file=sys.stderr,
)
return None
return new_state
else:
print(f"FAIL: {name}", file=sys.stderr)
return None
return new_func
def register_strategy(name: str):
return partial(_register_strategy, name=name)
@register_strategy("Truncate suffix")
def remove_suffix(cur_graph, cur_inps, granularity):
tested = set()
new_graph = fx.Graph()
env = {}
for idx, node in enumerate(cur_graph.nodes):
new_node = new_graph.node_copy(node, lambda x: env[x])
if node.op not in ["placeholder", "output"]:
# If idx is divisible by (granularity * 2), it would have been checked already.
if (
idx % granularity == 0
and (idx % (granularity * 2) != 0)
and idx not in tested
):
output_node = new_graph.output((new_node,))
if len(new_graph.nodes) < len(cur_graph.nodes) and graph_fails(
new_graph, cur_inps
):
return ReproState(new_graph, cur_inps)
else:
tested.add(idx)
new_graph.erase_node(output_node)
env[node] = new_node
return None
@register_strategy("Remove outputs")
def remove_outputs(cur_graph, cur_inps, granularity):
granularity = max(1, granularity // 2)
for idx, node in enumerate(cur_graph.nodes):
node.idx = idx
if node.op == "output":
output = node
break
if isinstance(output.args[0], fx.Node):
return None
output_args = sorted(
output.args[0], key=lambda x: x.idx if isinstance(x, fx.Node) else int(1e9)
)
if len(output_args) == 1:
return None
for idx in range(0, len(output_args), granularity):
output.args = (output_args[:idx] + output_args[idx + granularity :],)
if graph_fails(cur_graph, cur_inps):
return ReproState(cur_graph, cur_inps)
return None
def remove_unused_inputs_unchecked(cur_state: ReproState):
cur_graph = cur_state.graph
cur_inps = cur_state.inps
ph_nodes = get_placeholders(cur_graph)
assert len(ph_nodes) == len(cur_inps)
new_inps = []
for idx in range(len(ph_nodes)):
if len(ph_nodes[idx].users) == 0:
cur_graph.erase_node(ph_nodes[idx])
else:
new_inps.append(cur_inps[idx])
if len(new_inps) < len(cur_inps):
return ReproState(cur_graph, new_inps)
return None
def remove_unused_inputs_checked(cur_state: ReproState):
new_state = remove_unused_inputs_unchecked(cur_state)
if new_state is not None and graph_fails(new_state.graph, new_state.inps):
return new_state
return None
def _remove_unused_wrapper(cur_graph, cur_inps, granularity):
return remove_unused_inputs_checked(ReproState(cur_graph, cur_inps))
remove_unused_inputs = register_strategy("Remove unused inputs")(
_remove_unused_wrapper
)
@register_strategy("Eliminate dead code")
def eliminate_dead_code(cur_graph, cur_inps, granularity):
if cur_graph.eliminate_dead_code() and graph_fails(cur_graph, cur_inps):
return ReproState(cur_graph, cur_inps)
return None
def _consolidate_placeholders(cur_graph, inps):
new_graph = fx.Graph()
env = {}
seen_non_placeholder = False
# Move all placeholders to the front; also, if any load_tensor
# is at the front, convert it into an input (because it can be live
# all the time)
for node in cur_graph.nodes:
if node.op == "placeholder":
new_node = new_graph.node_copy(node, lambda x: env[x])
env[node] = new_node
elif not seen_non_placeholder and is_load_tensor_node(node):
new_node = new_graph.placeholder(node.name)
env[node] = new_node
inps.append(
torch.ops.debugprims.load_tensor.default(*node.args, **node.kwargs)
)
else:
seen_non_placeholder = True
# Move everyone else
for node in cur_graph.nodes:
if node not in env:
new_node = new_graph.node_copy(node, lambda x: env[x])
env[node] = new_node
return new_graph
@register_strategy("Delta Debugging")
def delta_debugging(cur_graph: fx.Graph, cur_inps, granularity):
num_nodes = len(cur_graph.nodes)
for start_range in range(0, num_nodes, granularity):
is_removing = False
new_graph = deepcopy_fx_graph(cur_graph)
new_inps = cur_inps[:]
end_range = min(num_nodes, start_range + granularity)
for idx in range(start_range, end_range):
new_node = list(new_graph.nodes)[idx]
if _convert_node_to_placeholder(new_graph, new_node, new_inps):
is_removing = True
if not is_removing:
continue
new_graph.eliminate_dead_code()
new_graph = _consolidate_placeholders(new_graph, new_inps)
new_state = remove_unused_inputs_unchecked(ReproState(new_graph, new_inps))
if new_state is None:
new_state = ReproState(new_graph, new_inps)
if graph_fails(new_state.graph, new_state.inps):
return ReproState(new_state.graph, new_state.inps)
return None
@register_strategy("Consolidate Inputs")
def consolidate_inputs(cur_graph, cur_inps, granularity):
old_len = len(cur_inps)
cur_graph = _consolidate_placeholders(cur_graph, cur_inps)
if len(cur_inps) > old_len and graph_fails(cur_graph, cur_inps):
return ReproState(cur_graph, cur_inps)
return None
failing_state = ReproState(failing_graph, inps)
def try_granularity(failing_state, granularity, use_non_granular):
print(f"Trying granularity {granularity}", file=sys.stderr)
strategies = []
num_nodes = len(failing_state.graph.nodes)
num_outputs = len(get_outputs(failing_state.graph))
if num_outputs > num_nodes // 2:
strategies += [remove_outputs]
if use_non_granular:
strategies += [
eliminate_dead_code,
remove_unused_inputs,
consolidate_inputs,
]
strategies += [remove_suffix, delta_debugging]
for strategy in strategies:
new_state = strategy(failing_state, granularity)
if new_state is not None:
return new_state
return None
while True:
dump_state(fx.GraphModule(fail_f, failing_state.graph), failing_state.inps)
granularity = int(2 ** (math.floor(math.log2(len(failing_state.graph.nodes)))))
if max_granularity is not None:
granularity = min(max_granularity, granularity)
new_state = try_granularity(failing_state, granularity, use_non_granular=True)
if new_state is not None:
failing_state = new_state
continue
granularity //= 2
has_progress = False
while granularity >= 1:
new_state = try_granularity(
failing_state, granularity, use_non_granular=False
)
if new_state is not None:
failing_state = new_state
has_progress = True
break
granularity //= 2
if has_progress:
continue
new_state = remove_outputs(failing_state, 1)
if new_state is not None:
failing_state = new_state
continue
break
if not graph_fails(failing_state.graph, failing_state.inps):
raise RuntimeError("Uh oh, something went wrong :( Final graph is not failing")
print(f"Made {num_queries} queries", file=sys.stderr)
failing_fx = fx.GraphModule(fail_f, failing_state.graph)
# If XLA debugging environment is enabled, create minified HLO graph as well
if "XLA_HLO_DEBUG" in os.environ:
create_minified_hlo_graph(failing_fx, failing_state.inps)
dump_state(failing_fx, failing_state.inps)
print("Wrote minimal repro out to repro.py", file=sys.stderr)
return failing_fx, failing_state.inps
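# --- Illustrative usage sketch (added for exposition; not part of the original file) ---
# A toy driver for `minifier` above. `module_fails` should return True as long
# as the property of interest (a crash, a NaN, a particular op, ...) still
# reproduces; here "the graph still contains a sin call" stands in for a real
# failure predicate so the example is self-contained:
def _example_minify():  # illustrative only; never called from this module
    import torch
    import torch.fx as fx

    def f(x):
        return (x.relu() + 1).sin().cos()

    gm = fx.symbolic_trace(f)
    inps = [torch.randn(8)]

    def module_fails(mod, inputs):
        # Keep shrinking while a sin node survives in the graph.
        return any(
            node.op == "call_method" and node.target == "sin"
            for node in mod.graph.nodes
        )

    # Returns the minified GraphModule and its (possibly reduced) inputs.
    return minifier(gm, inps, module_fails)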

View File

@ -0,0 +1,617 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
NoReturn,
Sequence,
Tuple,
Type,
Union,
)
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils._named_member_accessor import NamedMemberAccessor
# Utilities to make nn.Module "functional"
# In particular the goal is to be able to provide a function that takes as input
# the parameters and evaluates the nn.Module using fixed inputs.
def raise_parameter_tying_error() -> NoReturn:
raise RuntimeError(
"make_functional(module): we don't yet support models that "
"do parameter tying (also sometimes known as weight sharing). "
"Please try to rewrite your model by replacing all instances of the "
"tied parameter with another and/or comment your support in "
"https://github.com/pytorch/functorch/issues/446"
)
def create_names_map(
named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
tied_named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
) -> Dict[str, List[str]]:
"""
named_params is a dictionary of tensors: {'A': A, 'B': B}
tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B}
with potentially tied (or 'duplicated') tensors
This function creates a mapping from the names in named_params to the
names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
"""
named_params = dict(named_params)
tied_named_params = dict(tied_named_params)
tensors_dict_keys = set(named_params.keys())
tied_tensors_dict_keys = set(tied_named_params.keys())
assert tensors_dict_keys.issubset(tied_tensors_dict_keys)
tensor_to_mapping: Dict[Tensor, Tuple[str, List[str]]] = {}
for key, tensor in named_params.items():
tensor_to_mapping[tensor] = (key, [])
for key, tensor in tied_named_params.items():
assert tensor in tensor_to_mapping
tensor_to_mapping[tensor][1].append(key)
return dict(tensor_to_mapping.values())
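# --- Illustrative check (added for exposition; not part of the original file) ---
# A tiny concrete run of the mapping described in the docstring above:
def _example_create_names_map():  # illustrative only; never called from this module
    import torch

    A, B = torch.randn(2), torch.randn(2)
    mapping = create_names_map({"A": A, "B": B}, {"A": A, "B": B, "B_tied": B})
    assert mapping == {"A": ["A"], "B": ["B", "B_tied"]}
    return mapping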
def _extract_members(
mod: nn.Module,
named_members: Callable[..., Iterable[Tuple[str, Tensor]]],
subclass: Callable[[Tensor], Tensor],
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
all_named_members = tuple(named_members(remove_duplicate=False))
unique_named_members = tuple(named_members(remove_duplicate=True))
names_map = create_names_map(unique_named_members, all_named_members)
# Remove all the members in the model
memo = {}
accessor = NamedMemberAccessor(mod)
for name, p in all_named_members:
if p not in memo:
memo[p] = subclass(torch.empty_like(p, device="meta"))
replacement = memo[p]
accessor.set_tensor(name, replacement)
if len(unique_named_members) == 0:
names, params = (), ()
else:
names, params = zip(*unique_named_members) # type: ignore[assignment]
return params, names, names_map
def extract_weights(
mod: nn.Module,
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
"""
    This function removes all the Parameters from the model and
    returns them as a tuple as well as their original attribute names.
The weights must be re-loaded with `load_weights` before the model
can be used again.
Note that this function modifies the model in place and after this
call, mod.parameters() will be empty.
"""
return _extract_members(mod, mod.named_parameters, nn.Parameter)
def extract_buffers(
mod: nn.Module,
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
return _extract_members(mod, mod.named_buffers, lambda x: x)
def load_weights(
mod: nn.Module,
names: Sequence[str],
params: Sequence[Tensor],
as_params: bool = False,
) -> None:
"""
Reload a set of weights so that `mod` can be used again to perform a forward pass.
Note that the `params` are regular Tensors (that can have history) and so are left
as Tensors. This means that mod.parameters() will still be empty after this call.
"""
accessor = NamedMemberAccessor(mod)
if as_params:
params = [nn.Parameter(p) for p in params]
accessor.set_tensors(names, params)
def _swap_state(
mod: nn.Module, names_map: Dict[str, List[str]], elems: Iterable[Tensor]
) -> List[Tensor]:
result: List[Tensor] = []
accessor = NamedMemberAccessor(mod)
for (_, attr_names), elem in zip(names_map.items(), elems):
for i, attr_name in enumerate(attr_names):
if i == 0:
result.append(accessor.swap_tensor(attr_name, elem))
else:
accessor.set_tensor(attr_name, elem)
return result
def load_buffers(
mod: nn.Module,
names: Sequence[str],
buffers: Sequence[Tensor],
as_params: bool = False,
) -> None:
accessor = NamedMemberAccessor(mod)
accessor.set_tensors(names, buffers)
def load_state(
model: nn.Module,
weights: Sequence[Tensor],
weight_names: Sequence[str],
buffers: Sequence[Tensor] = (),
buffer_names: Sequence[str] = (),
) -> nn.Module:
"""load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model
load_state takes `weights` and `buffers` and assigns them to the model.
This is the inverse operation of `make_functional_deprecated_v1`.
"""
assert len(weight_names) == len(weights)
load_weights(model, weight_names, weights)
if len(buffers) > 0:
assert len(buffer_names) == len(buffers)
load_buffers(model, buffer_names, buffers)
return model
def make_functional_deprecated_v1(model: nn.Module):
"""make_functional_deprecated_v1(model) -> weights, func, weight_names
Given an nn.Module, make_functional_deprecated_v1 extracts the state (weights)
and returns a functional version of the model, `func`. This makes
    it so that it is possible to use transforms over the parameters of
`model`.
`func` can be invoked as follows:
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
weights, func, _ = make_functional_deprecated_v1(model)
func(weights, (x,))
```
And here is an example of applying the grad transform:
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
    weights, func, _ = make_functional_deprecated_v1(model)
grad_weights = grad(func)(weights, (x,))
```
To put the state back into a model, use `load_state`.
"""
buffers = list(model.buffers())
if len(buffers) > 0:
raise RuntimeError(
"make_functional_deprecated_v1(model): `model` has buffers. Please use "
"make_functional_with_buffers_deprecated_v1(model) instead."
)
weights, descriptors, _ = extract_weights(model)
def fun(weights, data):
mutable_model = copy.deepcopy(model)
load_weights(mutable_model, descriptors, weights)
return mutable_model(*data)
return weights, fun, descriptors
def make_functional_with_buffers_deprecated_v1(model: nn.Module):
"""make_functional_with_buffers_deprecated_v1(model) -> weights, buffers, func, weight_names, buffer_names
Given an nn.Module, make_functional_with_buffers_deprecated_v1 extracts the state (weights and buffers)
and returns a functional version of the model, `func`.
`func` can be invoked as follows:
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
func(weights, buffers, (x,))
```
And here is an example of applying the grad transform:
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
func(weights, buffers, (x,))
grad_weights = grad(func)(weights, buffers, (x,))
```
To put the state back into a model, use `load_state`.
"""
weights, weight_descriptors, _ = extract_weights(model)
buffers, buf_descriptors, _ = extract_buffers(model)
def fun(weights, buffers, data):
mutable_model = copy.deepcopy(model)
load_weights(mutable_model, weight_descriptors, weights)
load_buffers(mutable_model, buf_descriptors, buffers)
return mutable_model(*data)
return weights, buffers, fun, weight_descriptors, buf_descriptors
class FunctionalModuleWithBuffers(nn.Module):
"""
This is the callable object returned by :func:`make_functional_with_buffers`.
"""
def __init__(
self,
stateless_model: nn.Module,
param_names: Tuple[str, ...],
buffer_names: Tuple[str, ...],
param_names_map: Dict[str, List[str]],
buffer_names_map: Dict[str, List[str]],
) -> None:
super().__init__()
self.stateless_model = stateless_model
self.param_names = param_names
self.buffer_names = buffer_names
self.all_names_map = dict(param_names_map)
self.all_names_map.update(buffer_names_map)
@staticmethod
def _create_from(
model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple["FunctionalModuleWithBuffers", Tuple[Tensor, ...], Tuple[Tensor, ...]]:
# TODO: We don't need to copy the model to create a stateless copy
model_copy = copy.deepcopy(model)
params, param_names, param_names_map = extract_weights(model_copy)
buffers, buffer_names, buffer_names_map = extract_buffers(model_copy)
if disable_autograd_tracking:
for param in params:
param.requires_grad_(False)
return (
FunctionalModuleWithBuffers(
model_copy, param_names, buffer_names, param_names_map, buffer_names_map
),
params,
buffers,
)
def forward(
self, params: Iterable[Tensor], buffers: Iterable[Tensor], *args, **kwargs
) -> Any:
# Temporarily load the state back onto self.stateless_model
old_state = _swap_state(
self.stateless_model,
self.all_names_map,
tuple(params) + tuple(buffers),
)
try:
return self.stateless_model(*args, **kwargs)
finally:
# Remove the loaded state on self.stateless_model
_swap_state(self.stateless_model, self.all_names_map, old_state)
class FunctionalModule(nn.Module):
"""
This is the callable object returned by :func:`make_functional`.
"""
def __init__(
self,
stateless_model: nn.Module,
param_names: Tuple[str, ...],
names_map: Dict[str, List[str]],
) -> None:
super().__init__()
self.stateless_model = stateless_model
self.param_names = param_names
self.names_map = names_map
@staticmethod
def _create_from(
model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple["FunctionalModule", Tuple[Tensor, ...]]:
# TODO: We don't need to copy the model to create a stateless copy
model_copy = copy.deepcopy(model)
params, param_names, names_map = extract_weights(model_copy)
if disable_autograd_tracking:
for param in params:
param.requires_grad_(False)
return FunctionalModule(model_copy, param_names, names_map), params
def forward(self, params: Iterable[Tensor], *args, **kwargs) -> Any:
# Temporarily load the state back onto self.stateless_model
old_state = _swap_state(self.stateless_model, self.names_map, params)
try:
return self.stateless_model(*args, **kwargs)
finally:
# Remove the loaded state on self.stateless_model
_swap_state(self.stateless_model, self.names_map, old_state)
def make_functional(
model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple[FunctionalModule, Tuple[Tensor, ...]]:
"""make_functional(model, disable_autograd_tracking=False) -> func, params
Given a ``torch.nn.Module``, :func:`make_functional` extracts the state
(params) and returns a functional version of the model, ``func``. This
    makes it so that it is possible to use transforms over the parameters of
``model``.
``func`` can be invoked as follows:
.. code-block:: python
import torch
import torch.nn as nn
from functorch import make_functional
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional(model)
func(params, x)
And here is an example of applying the grad transform over the parameters
of a model.
.. code-block:: python
import torch
import torch.nn as nn
from functorch import make_functional, grad
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional(model)
def compute_loss(params, x, t):
y = func(params, x)
return nn.functional.mse_loss(y, t)
grad_weights = grad(compute_loss)(params, x, t)
If the model has any buffers, please use :func:`make_functional_with_buffers` instead.
Args:
model (torch.nn.Module): Input model.
disable_autograd_tracking (bool): Flag to disable gradients tracking for output parameters.
The returned params are unrelated to the set of params from the original model. If False (default),
the params will have ``requires_grad=True`` on them (aka they will be trackable with regular
PyTorch autograd), matching the requires_grad-ness of the params from the original model.
        Otherwise, the returned params will have ``requires_grad=False``. Default: False.
        If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or
        ``torch.autograd.grad()``), then set ``disable_autograd_tracking=False``.
Otherwise, if you're only planning on using functorch's gradient transforms,
then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking
history with PyTorch autograd.
"""
buffers = list(model.buffers())
if len(buffers) > 0:
raise RuntimeError(
"make_functional(model): `model` has buffers. Please use "
"make_functional_with_buffers(model) instead."
)
return FunctionalModule._create_from(
model, disable_autograd_tracking=disable_autograd_tracking
)
def make_functional_with_buffers(
model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
"""make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers
Given a ``torch.nn.Module``, make_functional_with_buffers extracts the
state (params and buffers) and returns a functional version of the model
``func`` that can be invoked like a function.
``func`` can be invoked as follows:
.. code-block:: python
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params, buffers = make_functional_with_buffers(model)
func(params, buffers, x)
And here is an example of applying the grad transform over the parameters
of a model:
.. code-block:: python
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers, grad
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params, buffers = make_functional_with_buffers(model)
def compute_loss(params, buffers, x, t):
y = func(params, buffers, x)
return nn.functional.mse_loss(y, t)
grad_weights = grad(compute_loss)(params, buffers, x, t)
Args:
model (torch.nn.Module): Input model.
disable_autograd_tracking (bool): Flag to disable gradients tracking for output parameters.
The returned params are unrelated to the set of params from the original model. If False (default),
the params will have ``requires_grad=True`` on them (aka they will be trackable with regular
PyTorch autograd), matching the requires_grad-ness of the params from the original model.
        Otherwise, the returned params will have ``requires_grad=False``. Default: False.
        If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or
        ``torch.autograd.grad()``), then set ``disable_autograd_tracking=False``.
Otherwise, if you're only planning on using functorch's gradient transforms,
then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking
history with PyTorch autograd.
"""
return FunctionalModuleWithBuffers._create_from(
model, disable_autograd_tracking=disable_autograd_tracking
)
def transpose_stack(
tuple_of_tuple_of_tensors: Tuple[Tuple[Tensor, ...], ...]
) -> Tuple[Tensor, ...]:
tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))
results = tuple(
torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors
)
return results
def combine_state_for_ensemble(
models: Sequence[nn.Module],
) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
"""combine_state_for_ensemble(models) -> func, params, buffers
Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
Given a list of ``M`` ``nn.Modules`` of the same class, stacks all of their
parameters and buffers together to make ``params`` and ``buffers``.
Each parameter and buffer in the result will have an additional dimension
of size ``M``.
:func:`combine_state_for_ensemble` also returns ``func``, a functional
    version of one of the models in :attr:`models`. One cannot run
    ``func(params, buffers, *args, **kwargs)`` directly; you probably want to
    use ``vmap(func, ...)(params, buffers, *args, **kwargs)``
Here's an example of how to ensemble over a very simple model:
.. code-block:: python
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)
fmodel, params, buffers = combine_state_for_ensemble(models)
output = vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
.. warning::
All of the modules being stacked together must be the same (except for
the values of their parameters/buffers). For example, they should be in the
same mode (training vs eval).
This API is subject to change -- we're investigating better ways to
        create ensembles and would love your feedback on how to improve this.
"""
if len(models) == 0:
raise RuntimeError(
"combine_state_for_ensemble: Expected at least one model, got 0."
)
if not (all(m.training for m in models) or all(not m.training for m in models)):
raise RuntimeError(
"combine_state_for_ensemble: Expected all models to "
"have the same training/eval mode."
)
model0_typ = type(models[0])
if not all(type(m) == model0_typ for m in models):
raise RuntimeError(
"combine_state_for_ensemble: Expected all models to be of the same class."
)
funcs, params, buffers = zip(
*[make_functional_with_buffers(model) for model in models]
)
params = transpose_stack(params)
buffers = transpose_stack(buffers)
return funcs[0], params, buffers
def functional_init(
model_class: Type[nn.Module],
ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
device: torch.types.Device = "cpu",
):
def wrapped(*args, **kwargs):
if len(ensemble_shape) >= 2:
raise ValueError("NYI: ensemble_shape with more than 1 element")
if len(ensemble_shape) == 0:
model = model_class(*args, **kwargs).to(device)
return make_functional_deprecated_v1(model)
num_models = ensemble_shape[0] # type: ignore[misc]
if num_models <= 0:
raise ValueError(f"num_models {num_models} should be > 0")
# NB: Not very efficient, more of a POC
models = tuple(
model_class(*args, **kwargs).to(device) for _ in range(num_models)
)
_, fn, names = make_functional_deprecated_v1(model_class(*args, **kwargs))
weights = tuple(make_functional_deprecated_v1(model)[0] for model in models)
weights = tuple(zip(*weights))
weights = tuple(torch.stack(shards).detach() for shards in weights)
return weights, fn, names
return wrapped
def functional_init_with_buffers(
model_class: Type[nn.Module],
ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
device: torch.types.Device = "cpu",
):
def wrapped(*args, **kwargs):
if len(ensemble_shape) >= 2:
raise ValueError("NYI: ensemble_shape with more than 1 element")
if len(ensemble_shape) == 0:
model = model_class(*args, **kwargs).to(device)
return make_functional_deprecated_v1(model)
num_models = ensemble_shape[0] # type: ignore[misc]
if num_models <= 0:
raise ValueError(f"num_models {num_models} should be > 0")
# NB: Not very efficient, more of a POC
models = tuple(
model_class(*args, **kwargs).to(device) for _ in range(num_models)
)
(
_,
_,
fn,
weight_names,
buffer_names,
) = make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs))
weights, buffers = zip(
*tuple(
make_functional_with_buffers_deprecated_v1(model)[:2]
for model in models
)
)
weights = tuple(zip(*weights))
weights = tuple(torch.stack(shards).detach() for shards in weights)
buffers = tuple(zip(*buffers))
buffers = tuple(torch.stack(shards).detach() for shards in buffers)
return weights, buffers, fn, weight_names, buffer_names
return wrapped

File diff suppressed because it is too large

View File

@ -0,0 +1,294 @@
# mypy: allow-untyped-defs
import contextlib
from abc import ABC, abstractmethod
from typing import Any, List, Tuple
import torch
import torch.utils._pytree as pytree
from torch._C._functorch import (
CFunctionalizeInterpreterPtr,
CGradInterpreterPtr,
CInterpreter,
CJvpInterpreterPtr,
CVmapInterpreterPtr,
pop_dynamic_layer_stack,
push_dynamic_layer_stack,
RandomnessType,
TransformType,
)
from torch.autograd.forward_ad import _set_fwd_grad_enabled
"""
This file contains the functorch integration with PyDispatcher.
PyDispatcher does not understand functorch's DynamicLayerStack dispatching
logic because it is entirely implemented in C++ in the fallbacks for two
dispatch keys, FuncTorchDynamicLayer{Front, Back}Mode (PyDispatcher is unable
to directly reuse C++ boxed fallbacks).
Instead of trying to hammer PyDispatcher into understanding those fallbacks,
we re-implement the logic of peeking the top of the stack for an interpreter,
selecting the interpreter to dispatch on, etc, in Python. This leads to a
simpler design.
The main difference between C++ functorch and PyDispatcher's functorch logic
is that:
- C++ functorch needs to manually tweak dispatch keys to ping-pong between
DynamicLayerFrontMode and DynamicLayerBackMode.
- PyDispatcher's functorch logic pops an Interpreter from the top of the stack
and asks it to execute the rule associated with the Interpreter.
In C++ we do the ping-pong because e.g. vmap rules are associated with the
batched DispatchKey, but in PyDispatcher we are able to avoid this by asking
the user to register a batching rule directly to a transform that an
interpreter then invokes.
"""
# FuncTorchInterpreter is the Python version of Interpreter (recall that
# the DynamicLayerStack is a stack of interpreters).
# It is a wrapper around the actual C++ Interpreter object.
#
# Keep the methods in sync with aten/src/ATen/functorch/Interpreter.h
class FuncTorchInterpreter(ABC):
def __init__(self, cptr: Any):
self._cptr = cptr
    # Process an operation. E.g., for vmap, this is invoking a batching rule.
# Conceptually this is analogous to Interpreter::process in C++
@abstractmethod
def process(self, op, args, kwargs):
pass
# lower an operation from this Interpreter to the next Interpreter on the stack.
# Concretely, this involves temporarily popping the current Interpreter.
# Conceptually this is analogous to Interpreter::sendToNextInterpreter in C++
def lower(self):
return temporarily_pop_interpreter_stack()
def level(self):
return self._cptr.level()
def key(self):
return self._cptr.key()
def get_state(self):
raise NotImplementedError
def check_state(self, state):
return state == self.get_state()
@contextlib.contextmanager
def temporarily_pop_interpreter_stack():
try:
saved = pop_dynamic_layer_stack()
yield
finally:
push_dynamic_layer_stack(saved)
@contextlib.contextmanager
def temporarily_clear_interpreter_stack():
stack = []
try:
while torch._C._functorch.peek_interpreter_stack() is not None:
stack.append(pop_dynamic_layer_stack())
yield list(stack)
finally:
while stack:
push_dynamic_layer_stack(stack.pop())
@contextlib.contextmanager
def temporarily_restore_interpreter_stack(stack):
pushed = []
try:
for s in reversed(stack):
push_dynamic_layer_stack(s)
pushed.append(s)
yield
finally:
for s in reversed(pushed):
# TODO: would be nice to assert that the layers are the same, but
# Python object identity is not preserved
pop_dynamic_layer_stack()
class VmapInterpreter(FuncTorchInterpreter):
def __init__(self, cdata: CInterpreter):
assert cdata.key() == TransformType.Vmap
# NOTE: [Interpreter cdata vs cptr]
# cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr
# so that we can access methods specific to the vmap interpreter
self._cdata = cdata
self._cptr = CVmapInterpreterPtr(cdata)
def process(self, op, args, kwargs):
kernel = op.functorch_table[TransformType.Vmap]
return kernel(self, *args, **kwargs)
def batch_size(self):
return self._cptr.batchSize()
def randomness(self):
typ = self._cptr.randomness()
if typ == RandomnessType.Error:
return "error"
elif typ == RandomnessType.Same:
return "same"
elif typ == RandomnessType.Different:
return "different"
raise RuntimeError(f"Unknown RandomnessType: {typ}")
def get_state(self):
return (self.key().name, self.level(), self.randomness())
@contextlib.contextmanager
def nested(*contexts):
with contextlib.ExitStack() as stack:
for ctx in contexts:
stack.enter_context(ctx)
yield contexts
class GradInterpreter(FuncTorchInterpreter):
def __init__(self, cdata: CInterpreter):
assert cdata.key() == TransformType.Grad
# See NOTE: [Interpreter cdata vs cptr]
self._cdata = cdata
self._cptr = CGradInterpreterPtr(cdata)
def lift(self, args, kwargs):
args, kwargs = pytree.tree_map_only(
torch.Tensor, self._cptr.lift, [args, kwargs]
)
return args, kwargs
def process(self, op, args, kwargs):
kernel = op.functorch_table[TransformType.Grad]
args, kwargs = self.lift(args, kwargs)
return kernel(self, *args, **kwargs)
# GradInterpreter has custom lower because of the no_grad interaction
# See NOTE [grad and vjp interaction with no_grad]
# This logic is mirrored from C++ GradInterpreterPtr::sendToNextInterpreter
def lower(self):
prev_grad_mode = self.prev_grad_mode()
if not prev_grad_mode:
return nested(torch.no_grad(), super().lower())
return super().lower()
def prev_grad_mode(self):
return self._cptr.prevGradMode()
def get_state(self):
return (self.key().name, self.level(), self.prev_grad_mode())
class JvpInterpreter(FuncTorchInterpreter):
def __init__(self, cdata: CInterpreter):
assert cdata.key() == TransformType.Jvp
# See NOTE: [Interpreter cdata vs cptr]
self._cdata = cdata
self._cptr = CJvpInterpreterPtr(cdata)
def lift(self, args, kwargs):
args, kwargs = pytree.tree_map_only(
torch.Tensor, self._cptr.lift, [args, kwargs]
)
return args, kwargs
def process(self, op, args, kwargs):
kernel = op.functorch_table[TransformType.Jvp]
args, kwargs = self.lift(args, kwargs)
return kernel(self, *args, **kwargs)
# Jvp has custom lower because of the no_fwd_grad interaction
# See NOTE [grad and vjp interaction with no_grad] for related info.
# This logic is mirrored from C++ JvpInterpreterPtr::sendToNextInterpreter
def lower(self):
prev_fwd_grad_mode = self.prev_fwd_grad_mode()
if not prev_fwd_grad_mode:
return nested(_set_fwd_grad_enabled(False), super().lower())
return super().lower()
def prev_fwd_grad_mode(self):
return self._cptr.prevFwdGradMode()
def get_state(self):
return (self.key().name, self.level(), self.prev_fwd_grad_mode())
class FunctionalizeInterpreter(FuncTorchInterpreter):
def __init__(self, cdata: CInterpreter):
assert cdata.key() == TransformType.Functionalize
self._cdata = cdata
self._cptr = CFunctionalizeInterpreterPtr(cdata)
def process(self, op, args, kwargs):
kernel = op.functorch_table[TransformType.Functionalize]
return kernel(self, *args, **kwargs)
def functionalize_add_back_views(self):
return self._cptr.functionalizeAddBackViews()
def get_state(self):
return (self.key().name, self.level())
def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
key = cinterpreter.key()
if key == TransformType.Grad:
return GradInterpreter(cinterpreter)
if key == TransformType.Vmap:
return VmapInterpreter(cinterpreter)
if key == TransformType.Jvp:
return JvpInterpreter(cinterpreter)
if key == TransformType.Functionalize:
return FunctionalizeInterpreter(cinterpreter)
raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}")
def retrieve_current_functorch_interpreter() -> FuncTorchInterpreter:
interpreter = torch._C._functorch.peek_interpreter_stack()
assert interpreter is not None
return coerce_cinterpreter(interpreter)
def retrieve_all_functorch_interpreters() -> List[FuncTorchInterpreter]:
cis = torch._C._functorch.get_interpreter_stack()
if cis is None:
return []
return [coerce_cinterpreter(ci) for ci in cis]
def compare_functorch_state(states: List[Tuple[Any, ...]]) -> bool:
# There are four possible cases covered here:
# 1. Current stack empty AND stack when generated not empty -> Invalidate
# 2. Current stack not empty AND stack when generated empty -> Invalidate
# 3. Current stack and generated stack empty -> Valid FX graph
# 4. Current stack and generated stack not empty -> Valid if both states match
peek = torch._C._functorch.peek_interpreter_stack()
if (peek is None and len(states) != 0) or (peek is not None and len(states) == 0):
return False
cis = retrieve_all_functorch_interpreters()
return len(cis) == len(states) and all(
ci.check_state(state) for ci, state in zip(cis, states)
)
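# Editor's note: hedged sketch, not part of the original file, of the intended
# capture/compare round trip: snapshot interpreter states when a graph is
# generated, then later ask whether the current stack still matches.
def _example_capture_and_compare():
    saved_states = [ci.get_state() for ci in retrieve_all_functorch_interpreters()]
    # ... a captured FX graph would be associated with `saved_states` here ...
    return compare_functorch_state(saved_states)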
def dispatch_functorch(op, args, kwargs):
interpreter = retrieve_current_functorch_interpreter()
# In traditional PyTorch operators, DispatchKey::FuncTorchTensorWrapper's
# unwrap_dead_tensors fallback handles unwrapping dead tensor wrappers.
# PyDispatcher sidesteps the PyTorch dispatcher when dealing with functorch
# transforms, so we manually unwrap the dead tensors here.
# This logic won't need to exist when we have mode-only functorch.
args, kwargs = pytree.tree_map_only(
torch.Tensor, torch._C._functorch.unwrap_if_dead, (args, kwargs)
)
return interpreter.process(op, args, kwargs)
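# Editor's note: hedged sketch (added by the editor) of what a kernel looked up
# from `op.functorch_table` receives: the current interpreter first, then the
# operator's arguments. `my_op_vmap_rule` is a hypothetical name used only to
# illustrate the calling convention seen in VmapInterpreter.process above.
def my_op_vmap_rule(interpreter: VmapInterpreter, x, *, dim=0):
    # A batching rule may inspect interpreter-specific state before lowering.
    _ = (interpreter.batch_size(), interpreter.randomness())
    with interpreter.lower():
        # Run the underlying computation one level down the transform stack.
        return torch.sum(x, dim=dim)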

View File

@ -0,0 +1,15 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
__all__ = ["make_fx", "dispatch_trace", "PythonKeyTracer", "pythonkey_decompose"]
from torch.fx.experimental.proxy_tensor import (
decompose,
dispatch_trace,
make_fx,
PythonKeyTracer,
)
pythonkey_decompose = decompose
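# Editor's note: a minimal, hedged usage sketch for the re-exports above (added
# by the editor, not part of the original file).
def _example_make_fx():
    import torch  # local import; this module does not otherwise import torch
    def f(x):
        return torch.sin(x) + x
    # make_fx traces `f` into a torch.fx.GraphModule at the given example input.
    gm = make_fx(f)(torch.randn(3))
    return gm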

View File

@ -0,0 +1,23 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
# TODO: remove this file when the migration of the pytree utility is done
from torch.utils._pytree import tree_map_, treespec_pprint
__all__ = ["tree_map_", "treespec_pprint"]
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"`torch._functorch.pytree_hacks` is deprecated and will be removed in a future release. "
"Please `use torch.utils._pytree` instead.",
DeprecationWarning,
stacklevel=2,
)

View File

@ -0,0 +1,632 @@
# mypy: ignore-errors
"""
From https://docs.google.com/spreadsheets/d/12R3nCOLskxPYjjiNkdqy4OdQ65eQp_htebXGODsjSeA/edit#gid=0
Try to keep this list in sync with that.
"""
import operator
top_torch = [
("t", 6837449),
("tensor", 585786),
("mode", 462182),
("cat", 394818),
("max", 368038),
("zeros", 329495),
("load", 327756),
("no_grad", 294694),
("save", 265130),
("from_numpy", 243063),
("manual_seed", 165044),
("ones", 153696),
("randn", 150796),
("stack", 133358),
("sum", 130772),
("arange", 98087),
("rand", 94715),
("mean", 88546),
("exp", 73883),
("zeros_like", 72831),
("min", 72248),
("sigmoid", 66798),
("log", 62135),
("matmul", 47811),
("clamp", 45304),
("sqrt", 44911),
("abs", 43535),
("tanh", 42793),
("empty", 40311),
("argmax", 38435),
("bmm", 33984),
("pow", 33571),
("norm", 31125),
("mm", 30995),
("is_tensor", 29546),
("ones_like", 29512),
("nonzero", 28681),
("full", 28373),
("unsqueeze", 27911),
("where", 26585),
("randperm", 26450),
("eye", 24342),
("mul", 23236),
("topk", 22537),
("as_tensor", 21967),
("sort", 21412),
("squeeze", 20863),
("randint", 20771),
("linspace", 20041),
("add", 19201),
("transpose", 18663),
("split", 18325),
("gather", 17904),
("set_grad_enabled", 16013),
("sin", 15669),
("cos", 15562),
("div", 15513),
("index_select", 14866),
("multinomial", 14331),
("flatten", 14267),
("isnan", 14170),
("randn_like", 13096),
("eq", 12680),
("einsum", 12480),
("round", 12367),
("floor", 11628),
("allclose", 11000),
("reshape", 10605),
("diag", 10167),
("chunk", 9581),
("std", 9379),
("set_default_tensor_type", 9281),
("triu", 8559),
("meshgrid", 8292),
("set_num_threads", 8126),
("unique", 7964),
("full_like", 7780),
("tril", 7538),
("dot", 7275),
("sign", 6943),
("equal", 6916),
("normal", 6750),
("cumsum", 6556),
("dist", 6058),
("isfinite", 6030),
("gt", 5935),
("set_printoptions", 5888),
("range", 5491),
("empty_like", 5351),
("flip", 5342),
("masked_select", 5341),
("bernoulli", 5262),
("atan", 5253),
("var", 5247),
("prod", 5200),
("erf", 5088),
("inverse", 5072),
("addmm", 4854),
("logsumexp", 4582),
("fft", 4436),
("lt", 4421),
("log2", 4316),
("enable_grad", 4238),
("rand_like", 4187),
("argsort", 3972),
("seed", 3932),
("mv", 3547),
("ger", 3309),
("ge", 3248),
("atan2", 3210),
("ceil", 3202),
("ne", 3075),
("bincount", 3063),
("acos", 3055),
("rsqrt", 3031),
("svd", 3029),
("numel", 3003),
("log1p", 2840),
("unbind", 2808),
("le", 2714),
("isinf", 2707),
("cross", 2646),
("set_default_dtype", 2536),
("argmin", 2535),
("sparse_coo_tensor", 2489),
("log10", 2304),
("kthvalue", 2192),
("set_rng_state", 2158),
("get_rng_state", 1996),
("get_default_dtype", 1879),
("det", 1868),
("qr", 1864),
("histc", 1852),
("symeig", 1832),
("trace", 1801),
("median", 1795),
("addcmul", 1751),
("remainder", 1717),
("baddbmm", 1693),
("lgamma", 1665),
("repeat_interleave", 1598),
("fmod", 1576),
("reciprocal", 1575),
("tan", 1560),
("initial_seed", 1532),
("take", 1529),
("stft", 1487),
("get_num_threads", 1477),
("real", 1459),
("cholesky", 1406),
("quantize_per_tensor", 1392),
("diag_embed", 1364),
("lerp", 1363),
("asin", 1345),
("eig", 1333),
("trunc", 1290),
("diagonal", 1287),
("cosh", 1279),
("rfft", 1269),
("cumprod", 1260),
("addr", 1211),
("roll", 1198),
("narrow", 1188),
("digamma", 1172),
("square", 1163),
("sinh", 1131),
("logspace", 1084),
("broadcast_tensors", 1070),
("irfft", 1013),
("frac", 997),
("hann_window", 994),
("solve", 989),
("logdet", 977),
("expm1", 968),
("cdist", 946),
("addmv", 903),
("randint_like", 888),
("tensordot", 888),
("ifft", 877),
("true_divide", 854),
("erfinv", 830),
("addcdiv", 819),
("addbmm", 813),
("renorm", 781),
("pinverse", 753),
("isclose", 740),
("erfc", 729),
("is_storage", 725),
("triangular_solve", 723),
("rot90", 709),
("logical_not", 686),
("geqrf", 681),
("slogdet", 677),
("lu", 665),
("hamming_window", 659),
("orgqr", 651),
("ormqr", 622),
("is_floating_point", 602),
("diagflat", 562),
("cholesky_solve", 559),
("tril_indices", 552),
("chain_matmul", 551),
("triu_indices", 548),
("angle", 522),
("poisson", 505),
("matrix_power", 485),
("unique_consecutive", 471),
("quantize_per_channel", 465),
("std_mean", 458),
("bartlett_window", 447),
("var_mean", 428),
("lstsq", 421),
("logical_and", 419),
("mvlgamma", 411),
("blackman_window", 400),
("bitwise_not", 395),
("cholesky_inverse", 388),
("as_strided", 384),
("floor_divide", 353),
("cartesian_prod", 321),
("lu_solve", 317),
("set_flush_denormal", 310),
("empty_strided", 283),
("logical_xor", 282),
("polygamma", 282),
("logical_or", 280),
("set_num_interop_threads", 278),
("combinations", 274),
("trapz", 270),
("matrix_rank", 260),
("lu_unpack", 255),
("result_type", 244),
("conj", 231),
("cummax", 230),
("lobpcg", 229),
("bitwise_xor", 217),
("promote_types", 213),
("get_num_interop_threads", 211),
("cummin", 205),
("bitwise_and", 198),
("dequantize", 192),
("bitwise_or", 191),
("imag", 191),
("can_cast", 184),
("istft", 180),
("compiled_with_cxx11_abi", 159),
("is_complex", 151),
("block_diag", 136),
("pca_lowrank", 124),
("absolute", 122),
("svd_lowrank", 108),
("neg", 2),
]
top_nn_functional = [
("nn.functional.softmax", 10522),
("nn.functional.relu", 8572),
("nn.functional.interpolate", 7277),
("nn.functional.pad", 5207),
("nn.functional.log_softmax", 4699),
("nn.functional.normalize", 2338),
("nn.functional.cross_entropy", 2083),
("nn.functional.grid_sample", 1970),
("nn.functional.one_hot", 1967),
("nn.functional.mse_loss", 1920),
("nn.functional.conv2d", 1593),
("nn.functional.dropout", 1516),
("nn.functional.softplus", 1385),
("nn.functional.sigmoid", 1128),
("nn.functional.linear", 1036),
("nn.functional.gelu", 930),
("nn.functional.avg_pool2d", 899),
("nn.functional.max_pool2d", 876),
("nn.functional.nll_loss", 863),
("nn.functional.embedding", 737),
("nn.functional.tanh", 664),
("nn.functional.leaky_relu", 640),
("nn.functional.adaptive_avg_pool2d", 633),
("nn.functional.cosine_similarity", 627),
("nn.functional.unfold", 609),
("nn.functional.conv1d", 596),
("nn.functional.binary_cross_entropy_with_logits", 591),
("nn.functional.l1_loss", 571),
("nn.functional.binary_cross_entropy", 492),
("nn.functional.elu", 416),
("nn.functional.batch_norm", 413),
("nn.functional.upsample", 413),
("nn.functional.fold", 305),
("nn.functional.affine_grid", 298),
("nn.functional.max_pool1d", 297),
("nn.functional.torch", 294),
("nn.functional.threshold", 263),
("nn.functional.smooth_l1_loss", 262),
("nn.functional.pairwise_distance", 253),
("nn.functional.logsigmoid", 243),
("nn.functional.adaptive_max_pool2d", 235),
("nn.functional.relu6", 213),
("nn.functional.pixel_shuffle", 209),
("nn.functional.avg_pool3d", 203),
("nn.functional.bilinear", 203),
("nn.functional.conv_transpose2d", 201),
("nn.functional.gumbel_softmax", 197),
("nn.functional.max_unpool2d", 196),
("nn.functional.kl_div", 191),
("nn.functional.hardtanh", 189),
("nn.functional.ctc_loss", 185),
("nn.functional.layer_norm", 178),
("nn.functional.conv3d", 172),
("nn.functional.max_unpool3d", 167),
("nn.functional.hardshrink", 165),
("nn.functional.hardswish", 156),
("nn.functional.selu", 156),
("nn.functional.glu", 155),
("nn.functional.assert_int_or_pair", 150),
("nn.functional.hardsigmoid", 146),
("nn.functional.upsample_bilinear", 146),
("nn.functional.max_pool3d", 140),
("nn.functional.adaptive_avg_pool3d", 139),
("nn.functional.instance_norm", 124),
("nn.functional.embedding_bag", 122),
("nn.functional.upsample_nearest", 110),
("nn.functional.avg_pool1d", 105),
("nn.functional.prelu", 102),
("nn.functional.celu", 92),
("nn.functional.dropout2d", 86),
("nn.functional.hinge_embedding_loss", 82),
("nn.functional.softsign", 81),
("nn.functional.max_unpool1d", 74),
("nn.functional.silu", 74),
("nn.functional.softshrink", 70),
("nn.functional.leaky_relu_", 68),
("nn.functional.softmin", 67),
("nn.functional.channel_shuffle", 66),
("nn.functional.multilabel_margin_loss", 66),
("nn.functional.dropout3d", 65),
("nn.functional.multi_margin_loss", 65),
("nn.functional.lp_pool2d", 64),
("nn.functional.conv_transpose1d", 62),
("nn.functional.triplet_margin_loss", 62),
("nn.functional.tanhshrink", 61),
("nn.functional.adaptive_max_pool1d", 59),
("nn.functional.cosine_embedding_loss", 58),
("nn.functional.multi_head_attention_forward", 58),
("nn.functional.max_pool1d_with_indices", 53),
("nn.functional.poisson_nll_loss", 53),
("nn.functional.margin_ranking_loss", 52),
("nn.functional.soft_margin_loss", 52),
("nn.functional.adaptive_max_pool3d", 51),
("nn.functional.group_norm", 51),
("nn.functional.local_response_norm", 51),
("nn.functional.multilabel_soft_margin_loss", 51),
("nn.functional.relu_", 50),
("nn.functional.alpha_dropout", 49),
("nn.functional.feature_alpha_dropout", 49),
("nn.functional.lp_pool1d", 49),
("nn.functional.adaptive_max_pool1d_with_indices", 48),
("nn.functional.adaptive_max_pool2d_with_indices", 48),
("nn.functional.adaptive_max_pool3d_with_indices", 48),
("nn.functional.fractional_max_pool2d", 48),
("nn.functional.fractional_max_pool2d_with_indices", 48),
("nn.functional.fractional_max_pool3d", 48),
("nn.functional.fractional_max_pool3d_with_indices", 48),
("nn.functional.max_pool2d_with_indices", 48),
("nn.functional.max_pool3d_with_indices", 48),
("nn.functional.handle_torch_function", 47),
("nn.functional.has_torch_function", 47),
("nn.functional.adaptive_avg_pool1d", 43),
("nn.functional.pdist", 43),
("nn.functional.rrelu_", 37),
("nn.functional.elu_", 34),
("nn.functional.boolean_dispatch", 33),
("nn.functional.hardtanh_", 26),
("nn.functional.triplet_margin_with_distance_loss", 23),
("nn.functional.selu_", 20),
("nn.functional.pixel_unshuffle", 19),
("nn.functional.conv_transpose3d", 18),
("nn.functional.gaussian_nll_loss", 15),
("nn.functional.has_torch_function_unary", 15),
("nn.functional.has_torch_function_variadic", 15),
("nn.functional.celu_", 13),
("nn.functional.huber_loss", 7),
("nn.functional.mish", 4),
("nn.functional.threshold_", 3),
("nn.functional.grad", 2),
("nn.functional.conv_tbc", 1),
("nn.functional.math", 1),
]
top_nn_module = [
("nn.Module", 927129, None),
("nn.Linear", 530688, "nn.functional.linear"),
("nn.Sequential", 384968, None),
("nn.Conv2d", 383320, "nn.functional.conv2d"),
("nn.ReLU", 318877, "nn.functional.relu"),
("nn.BatchNorm2d", 233265, "nn.functional.batch_norm"),
("nn.Dropout", 179268, "nn.functional.dropout"),
("nn.ModuleList", 171225, None),
("nn.Parameter", 153291, None),
("nn.CrossEntropyLoss", 152696, "nn.functional.cross_entropy"),
("nn.MaxPool2d", 138619, "nn.functional.max_pool2d"),
("nn.Embedding", 111844, "nn.functional.embedding"),
("nn.DataParallel", 104238, None),
("nn.MSELoss", 82954, "nn.functional.mse_loss"),
("nn.Sigmoid", 75810, "nn.functional.sigmoid"),
("nn.LeakyReLU", 65632, "nn.functional.leaky_relu"),
("nn.BatchNorm1d", 65374, "nn.functional.batch_norm"),
("nn.Softmax", 65114, "nn.functional.softmax"),
("nn.Tanh", 59445, "nn.functional.tanh"),
("nn.AdaptiveAvgPool2d", 59071, "nn.functional.adaptive_avg_pool2d"),
("nn.AvgPool2d", 58377, "nn.functional.avg_pool2d"),
("nn.ConvTranspose2d", 57524, "nn.functional.conv_transpose2d"),
("nn.LSTM", 57411, None),
("nn.Conv1d", 41108, "nn.functional.conv1d"),
("nn.LayerNorm", 36089, "nn.functional.layer_norm"),
("nn.BCELoss", 34005, "nn.functional.binary_cross_entropy"),
("nn.Upsample", 32527, "nn.functional.interpolate"),
("nn.BCEWithLogitsLoss", 29944, "nn.functional.binary_cross_entropy_with_logits"),
("nn.GRU", 25421, None),
("nn.Dropout2d", 23512, "nn.functional.dropout2d"),
("nn.LogSoftmax", 22897, "nn.functional.log_softmax"),
("nn.L1Loss", 22778, "nn.functional.l1_loss"),
("nn.GroupNorm", 22183, "nn.functional.group_norm"),
("nn.NLLLoss", 21751, "nn.functional.nll_loss"),
("nn.Conv3d", 20874, "nn.functional.conv3d"),
("nn.Identity", 17911, None),
("nn.InstanceNorm2d", 16426, "nn.functional.instance_norm"),
("nn.BatchNorm3d", 16378, "nn.functional.batch_norm"),
("nn.PReLU", 13472, "nn.functional.prelu"),
("nn.ReLU6", 12622, "nn.functional.relu6"),
("nn.ELU", 12508, "nn.functional.elu"),
("nn.LSTMCell", 10885, None),
("nn.Flatten", 10384, "torch.flatten"),
("nn.ModuleDict", 10255, None),
("nn.ReflectionPad2d", 9954, "nn.functional.pad"),
("nn.MaxPool3d", 9526, "nn.functional.max_pool3d"),
("nn.MaxPool1d", 9154, "nn.functional.max_pool1d"),
("nn.RNN", 9154, None),
("nn.ZeroPad2d", 8847, "nn.functional.pad"),
("nn.ParameterList", 7702, None),
("nn.SyncBatchNorm", 6814, None),
("nn.PixelShuffle", 6571, "nn.functional.pixel_shuffle"),
("nn.SmoothL1Loss", 6517, "nn.functional.smooth_l1_loss"),
("nn.Hardswish", 6458, "nn.functional.hardswish"),
("nn.AdaptiveMaxPool2d", 6071, "nn.functional.adaptive_max_pool2d"),
("nn.SELU", 6043, "nn.functional.selu"),
("nn.ConvTranspose3d", 6039, "nn.functional.conv_transpose3d"),
("nn.GRUCell", 5840, None),
("nn.ReplicationPad2d", 5600, "nn.functional.pad"),
("nn.KLDivLoss", 5541, "nn.functional.kl_div"),
("nn.ConvTranspose1d", 5183, "nn.functional.conv_transpose1d"),
("nn.Softplus", 5120, "nn.functional.softplus"),
("nn.SiLU", 4895, "nn.functional.silu"),
("nn.AvgPool3d", 4523, "nn.functional.avg_pool3d"),
("nn.CosineSimilarity", 4058, "nn.functional.cosine_similarity"),
("nn.GELU", 3932, "nn.functional.gelu"),
("nn.UpsamplingBilinear2d", 3673, "nn.functional.interpolate"),
("nn.InstanceNorm1d", 3658, "nn.functional.instance_norm"),
("nn.Transformer", 3604, None),
("nn.MultiheadAttention", 3435, "nn.functional.multi_head_attention_forward"),
("nn.AvgPool1d", 3195, "nn.functional.avg_pool1d"),
("nn.Dropout3d", 2964, "nn.functional.dropout3d"),
("nn.AdaptiveAvgPool3d", 2915, "nn.functional.adaptive_avg_pool3d"),
("nn.InstanceNorm3d", 2893, "nn.functional.instance_norm"),
("nn.Hardtanh", 2613, "nn.functional.hardtanh"),
("nn.MarginRankingLoss", 2568, "nn.functional.margin_ranking_loss"),
("nn.GLU", 2526, "nn.functional.glu"),
("nn.AdaptiveAvgPool1d", 2481, "nn.functional.adaptive_avg_pool1d"),
("nn.EmbeddingBag", 2344, "nn.functional.embedding_bag"),
("nn.TransformerEncoderLayer", 2292, None),
("nn.TransformerEncoder", 2091, None),
("nn.MaxUnpool2d", 2031, "nn.functional.max_unpool2d"),
("nn.UpsamplingNearest2d", 2004, "nn.functional.interpolate"),
("nn.ConstantPad1d", 1904, "nn.functional.pad"),
("nn.ConstantPad2d", 1791, "nn.functional.pad"),
("nn.CTCLoss", 1789, "nn.functional.ctc_loss"),
("nn.AdaptiveMaxPool1d", 1713, "nn.functional.adaptive_max_pool1d"),
("nn.AdaptiveLogSoftmaxWithLoss", 1665, None),
("nn.Bilinear", 1664, "nn.functional.bilinear"),
("nn.RNNCell", 1653, None),
("nn.MultiLabelSoftMarginLoss", 1624, "nn.functional.multilabel_soft_margin_loss"),
("nn.Unfold", 1452, "nn.functional.unfold"),
("nn.RReLU", 1431, "nn.functional.rrelu"),
("nn.CosineEmbeddingLoss", 1357, "nn.functional.cosine_embedding_loss"),
("nn.LocalResponseNorm", 1331, "nn.functional.local_response_norm"),
("nn.Softmax2d", 1300, "nn.functional.softmax"),
("nn.PairwiseDistance", 1241, "nn.functional.pairwise_distance"),
("nn.LogSigmoid", 1235, "nn.functional.logsigmoid"),
("nn.TripletMarginLoss", 1230, "nn.functional.triplet_margin_loss"),
("nn.RNNBase", 1133, None),
("nn.Threshold", 1043, "nn.functional.threshold"),
("nn.AdaptiveMaxPool3d", 1025, "nn.functional.adaptive_max_pool3d"),
("nn.CELU", 1018, "nn.functional.celu"),
("nn.NLLLoss2d", 966, "nn.functional.nll_loss"),
("nn.Softsign", 877, "nn.functional.softsign"),
("nn.ReplicationPad1d", 862, "nn.functional.pad"),
("nn.SoftMarginLoss", 856, "nn.functional.soft_margin_loss"),
("nn.ParameterDict", 742, None),
("nn.ReflectionPad1d", 731, "nn.functional.pad"),
("nn.Softshrink", 713, "nn.functional.softshrink"),
("nn.AlphaDropout", 710, "nn.functional.alpha_dropout"),
("nn.Tanhshrink", 681, "nn.functional.tanhshrink"),
("nn.PoissonNLLLoss", 676, "nn.functional.poisson_nll_loss"),
("nn.MaxUnpool3d", 660, "nn.functional.max_unpool3d"),
("nn.Fold", 630, "nn.functional.fold"),
("nn.MultiMarginLoss", 622, "nn.functional.multi_margin_loss"),
("nn.TransformerDecoderLayer", 614, None),
("nn.TransformerDecoder", 607, None),
("nn.Hardshrink", 592, "nn.functional.hardshrink"),
("nn.ConstantPad3d", 582, "nn.functional.pad"),
("nn.MultiLabelMarginLoss", 580, "nn.functional.multilabel_margin_loss"),
("nn.LPPool2d", 550, "nn.functional.lp_pool2d"),
("nn.Softmin", 537, "nn.functional.softmin"),
("nn.MaxUnpool1d", 518, "nn.functional.max_unpool1d"),
("nn.FractionalMaxPool2d", 484, "nn.functional.fractional_max_pool2d"),
("nn.Hardsigmoid", 477, "nn.functional.hardsigmoid"),
("nn.ReplicationPad3d", 470, "nn.functional.pad"),
("nn.HingeEmbeddingLoss", 442, "nn.functional.hinge_embedding_loss"),
("nn.LPPool1d", 386, "nn.functional.lp_pool1d"),
("nn.FractionalMaxPool3d", 252, "nn.functional.fractional_max_pool3d"),
("nn.Container", 217, None),
("nn.Unflatten", 206, "nn.functional.unflatten"),
("nn.FeatureAlphaDropout", 136, "nn.functional.feature_alpha_dropout"),
(
"nn.TripletMarginWithDistanceLoss",
107,
"nn.functional.triplet_margin_with_distance_loss",
),
("nn.ChannelShuffle", 90, "nn.functional.channel_shuffle"),
("nn.RNNCellBase", 88, None),
("nn.LazyLinear", 81, "nn.functional.linear"),
("nn.UninitializedParameter", 60, None),
("nn.CrossMapLRN2d", 59, None),
("nn.GaussianNLLLoss", 55, "nn.functional.gaussian_nll_loss"),
("nn.PixelUnshuffle", 45, "nn.functional.pixel_unshuffle"),
("nn.Mish", 31, "nn.functional.mish"),
("nn.ReflectionPad3d", 22, "nn.functional.pad"),
("nn.HuberLoss", 18, "nn.functional.huber_loss"),
("nn.LazyConv2d", 15, None),
("nn.LazyConv1d", 9, None),
("nn.LazyConv3d", 8, None),
("nn.LazyConvTranspose1d", 8, None),
("nn.LazyConvTranspose2d", 8, None),
("nn.LazyConvTranspose3d", 8, None),
("nn.LazyBatchNorm1d", 3, None),
("nn.LazyBatchNorm2d", 3, None),
("nn.LazyBatchNorm3d", 3, None),
("nn.UninitializedBuffer", 3, None),
]
# No rankings because these are a little hard to get rankings for
method_only_ops = [
"bfloat16",
"bool",
"byte",
"char",
"contiguous",
"cpu",
"cuda",
"detach",
"double",
"expand",
"expand_as",
"float",
"get_device",
"half",
"hardshrink",
"index_add",
"index_copy",
"index_fill",
"index_put",
"int",
"is_contiguous",
"is_pinned",
"is_set_to",
"is_shared",
"is_signed",
"item",
"long",
"masked_scatter",
"masked_fill",
"narrow_copy",
"numpy",
"pin_memory",
"repeat",
"reshape_as",
"select",
"short",
"storage_offset",
"sum_to_size",
"to",
"to_mkldnn",
"tolist",
"type",
"type_as",
"unfold",
"view",
"view_as",
]
def get_nn_functional_top_list():
top_nn_functional_ = dict(top_nn_functional)
for _, count, functional_name in top_nn_module:
if functional_name is None:
continue
if functional_name == "torch.flatten":
continue
if functional_name not in top_nn_functional_:
top_nn_functional_[functional_name] = count
else:
top_nn_functional_[functional_name] += count
top_nn_functional_ = list(top_nn_functional_.items())
top_nn_functional_.sort(key=operator.itemgetter(1), reverse=True)
return top_nn_functional_
usage_count = {}
for k, v in get_nn_functional_top_list():
usage_count[k] = v
for k, v in top_torch:
usage_count[k] = v
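# Editor's note: hedged sketch (added by the editor) showing how these tables
# are typically consumed: rank entries by recorded usage count.
def _top_k_ops(k=10):
    return sorted(usage_count.items(), key=operator.itemgetter(1), reverse=True)[:k]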

View File

@ -0,0 +1,40 @@
# mypy: allow-untyped-defs
import contextlib
from typing import Tuple, Union
import torch
from torch._C._functorch import (
get_single_level_autograd_function_allowed,
set_single_level_autograd_function_allowed,
unwrap_if_dead,
)
from torch.utils._exposed_in import exposed_in
__all__ = [
"exposed_in",
"argnums_t",
"enable_single_level_autograd_function",
"unwrap_dead_wrappers",
]
@contextlib.contextmanager
def enable_single_level_autograd_function():
try:
prev_state = get_single_level_autograd_function_allowed()
set_single_level_autograd_function_allowed(True)
yield
finally:
set_single_level_autograd_function_allowed(prev_state)
def unwrap_dead_wrappers(args):
# NB: doesn't use tree_map_only for performance reasons
result = tuple(
unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
)
return result
argnums_t = Union[int, Tuple[int, ...]]
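# Editor's note: a hedged usage sketch, not part of the original file. `fn` and
# `args` are placeholders for an autograd.Function-style callable and its inputs.
def _example_single_level_call(fn, args):
    # Temporarily allow single-level autograd.Function usage and strip any dead
    # functorch tensor wrappers from the inputs; the previous setting is
    # restored when the block exits.
    with enable_single_level_autograd_function():
        return fn(*unwrap_dead_wrappers(args))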

View File

@ -0,0 +1,532 @@
# mypy: ignore-errors
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import functools
import itertools
import os
import threading
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch._C._functorch import (
_add_batch_dim,
_remove_batch_dim,
_vmap_decrement_nesting,
_vmap_increment_nesting,
is_batchedtensor,
)
from torch.utils._pytree import (
_broadcast_to_and_flatten,
tree_flatten,
tree_map_,
tree_unflatten,
TreeSpec,
)
in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]]
def doesnt_support_saved_tensors_hooks(f):
message = (
"torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. "
"Please open an issue with your use case."
)
@functools.wraps(f)
def fn(*args, **kwargs):
with torch.autograd.graph.disable_saved_tensors_hooks(message):
return f(*args, **kwargs)
return fn
# Checks that all args-to-be-batched have the same batch dim size
def _validate_and_get_batch_size(
flat_in_dims: List[Optional[int]], flat_args: List
) -> int:
batch_sizes = [
arg.size(in_dim)
for in_dim, arg in zip(flat_in_dims, flat_args)
if in_dim is not None
]
if len(batch_sizes) == 0:
raise ValueError("vmap: Expected at least one Tensor to vmap over")
if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes):
raise ValueError(
f"vmap: Expected all tensors to have the same size in the mapped "
f"dimension, got sizes {batch_sizes} for the mapped dimension"
)
return batch_sizes[0]
def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
if isinstance(batched_outputs, tuple):
return len(batched_outputs)
return 1
# If value is a tuple, check it has length `num_elements`.
# If value is not a tuple, make a tuple with `value` repeated `num_elements` times
def _as_tuple(
value: Any, num_elements: int, error_message_lambda: Callable[[], str]
) -> Tuple:
if not isinstance(value, tuple):
return (value,) * num_elements
if len(value) != num_elements:
raise ValueError(error_message_lambda())
return value
def _process_batched_inputs(
in_dims: in_dims_t, args: Tuple, func: Callable
) -> Tuple[int, List[Any], List[Any], TreeSpec]:
if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
raise ValueError(
f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
f"expected `in_dims` to be int or a (potentially nested) tuple "
f"matching the structure of inputs, got: {type(in_dims)}."
)
if len(args) == 0:
raise ValueError(
f"vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add "
f"inputs, or you are trying to vmap over a function with no inputs. "
f"The latter is unsupported."
)
flat_args, args_spec = tree_flatten(args)
flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
if flat_in_dims is None:
raise ValueError(
f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
f"in_dims is not compatible with the structure of `inputs`. "
f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs "
f"has structure {args_spec}."
)
for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)):
if not isinstance(in_dim, int) and in_dim is not None:
raise ValueError(
f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
f"Got in_dim={in_dim} for an input but in_dim must be either "
f"an integer dimension or None."
)
if isinstance(in_dim, int) and not isinstance(arg, Tensor):
raise ValueError(
f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
f"Got in_dim={in_dim} for an input but the input is of type "
f"{type(arg)}. We cannot vmap over non-Tensor arguments, "
f"please use None as the respective in_dim"
)
if in_dim is not None and (in_dim < -arg.dim() or in_dim >= arg.dim()):
raise ValueError(
f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
f"Got in_dim={in_dim} for some input, but that input is a Tensor "
f"of dimensionality {arg.dim()} so expected in_dim to satisfy "
f"-{arg.dim()} <= in_dim < {arg.dim()}."
)
if in_dim is not None and in_dim < 0:
flat_in_dims[i] = in_dim % arg.dim()
return (
_validate_and_get_batch_size(flat_in_dims, flat_args),
flat_in_dims,
flat_args,
args_spec,
)
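# Editor's note: a hedged, concrete illustration (added by the editor) of the
# tuple returned by _process_batched_inputs for a simple call.
def _example_process_batched_inputs():
    x = torch.randn(4, 3)
    y = torch.randn(4)
    batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
        in_dims=(0, 0), args=(x, y), func=torch.add
    )
    # Here batch_size == 4, flat_in_dims == [0, 0] and flat_args == [x, y].
    return batch_size, flat_in_dims, flat_args, args_spec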
# Creates BatchedTensors for every Tensor in flat_args that should be batched.
# Returns the (potentially) batched arguments and the batch_size.
def _create_batched_inputs(
flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec
) -> Tuple:
# See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
batched_inputs = [
arg if in_dim is None else _add_batch_dim(arg, in_dim, vmap_level)
for in_dim, arg in zip(flat_in_dims, flat_args)
]
return tree_unflatten(batched_inputs, args_spec)
def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_dim):
if out_dim is None:
if isinstance(batched_output, torch.Tensor) and is_batchedtensor(
batched_output
):
raise ValueError(
f"vmap({name}, ...): `{name}` can not return a "
f"BatchedTensor when out_dim is None"
)
return batched_output
# out_dim is not None
if not isinstance(batched_output, torch.Tensor):
raise ValueError(
f"vmap({name}, ...): `{name}` must only return "
f"Tensors, got type {type(batched_output)}. "
"Did you mean to set out_dims= to None for output?"
)
return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)
# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
def _unwrap_batched(
batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
out_dims: out_dims_t,
vmap_level: int,
batch_size: int,
func: Callable,
) -> Tuple:
flat_batched_outputs, output_spec = tree_flatten(batched_outputs)
def incompatible_error():
raise ValueError(
f"vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): "
f"out_dims is not compatible with the structure of `outputs`. "
f"out_dims has structure {tree_flatten(out_dims)[1]} but outputs "
f"has structure {output_spec}."
)
if isinstance(batched_outputs, torch.Tensor):
# Some weird edge case requires us to spell out the following
# see test_out_dims_edge_case
if isinstance(out_dims, int):
flat_out_dims = [out_dims]
elif isinstance(out_dims, tuple) and len(out_dims) == 1:
flat_out_dims = out_dims
elif out_dims is None:
flat_out_dims = [out_dims]
else:
incompatible_error()
else:
flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec)
if flat_out_dims is None:
incompatible_error()
flat_outputs = [
_maybe_remove_batch_dim(
_get_name(func), batched_output, vmap_level, batch_size, out_dim
)
for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims)
]
return tree_unflatten(flat_outputs, output_spec)
def _check_int_or_none(x, func, out_dims):
if isinstance(x, int):
return
if x is None:
return
raise ValueError(
f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be "
f"an int, None or a python collection of ints representing where in the outputs the "
f"vmapped dimension should appear."
)
def _check_out_dims_is_int_or_int_pytree(out_dims: out_dims_t, func: Callable) -> None:
if isinstance(out_dims, int):
return
tree_map_(partial(_check_int_or_none, func=func, out_dims=out_dims), out_dims)
def _get_name(func: Callable):
if hasattr(func, "__name__"):
return func.__name__
# Not all callables have a __name__; in fact, only static functions/methods do.
# A callable created via functools.partial or an nn.Module, to name two
# examples, doesn't have a __name__.
return repr(func)
DECOMPOSITIONS_LOADED = False
DECOMPOSITIONS_LOCK = threading.Lock()
VMAP_DECOMPOSITIONS_LIB = None
# torch.package, Python 3.11, and torch.jit-less environments are unhappy with
# decompositions. Only load them when needed, if possible.
def lazy_load_decompositions():
global DECOMPOSITIONS_LOADED
if DECOMPOSITIONS_LOADED:
return
with DECOMPOSITIONS_LOCK:
if DECOMPOSITIONS_LOADED:
return
if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__):
DECOMPOSITIONS_LOADED = True
return
# Use an alternate way to register an operator into the decomposition table.
# _register_jit_decomposition doesn't work for some operators (e.g. addr)
# because the Tensor types generated cannot be unioned by TorchScript.
# `decomp` should be of type OpOverload.
global VMAP_DECOMPOSITIONS_LIB
VMAP_DECOMPOSITIONS_LIB = torch.library.Library(
"aten", "IMPL", "FuncTorchBatched"
)
from torch._decomp import decomposition_table
def _register_python_decomposition_vmap(decomp):
if decomp in decomposition_table:
VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp])
else:
raise RuntimeError(f"could not find decomposition for {decomp}")
_register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
_register_python_decomposition_vmap(
torch.ops.aten.smooth_l1_loss_backward.default
)
_register_python_decomposition_vmap(torch.ops.aten.huber_loss_backward.default)
_register_python_decomposition_vmap(torch.ops.aten.nll_loss_forward.default)
_register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_forward.default)
_register_python_decomposition_vmap(torch.ops.aten.nll_loss_backward.default)
_register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_backward.default)
_register_python_decomposition_vmap(torch.ops.aten.addr.default)
DECOMPOSITIONS_LOADED = True
def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs):
lazy_load_decompositions()
_check_out_dims_is_int_or_int_pytree(out_dims, func)
batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
in_dims, args, func
)
if chunk_size is not None:
chunks_flat_args = _get_chunked_inputs(
flat_args, flat_in_dims, batch_size, chunk_size
)
return _chunked_vmap(
func,
flat_in_dims,
chunks_flat_args,
args_spec,
out_dims,
randomness,
**kwargs,
)
# If chunk_size is not specified.
return _flat_vmap(
func,
batch_size,
flat_in_dims,
flat_args,
args_spec,
out_dims,
randomness,
**kwargs,
)
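# Editor's note: hedged sketch (added by the editor) of the public entry point
# that funnels into vmap_impl above; torch.func.vmap forwards its arguments here.
def _example_chunked_vmap_call():
    def f(x):
        return x.sum(-1)
    x = torch.randn(10, 3)
    # With chunk_size=4, the 10-element batch is processed in chunks of 4, 4, 2.
    return torch.func.vmap(f, in_dims=0, out_dims=0, chunk_size=4)(x)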
def get_chunk_sizes(total_elems, chunk_size):
n_chunks = total_elems // chunk_size
chunk_sizes = [chunk_size] * n_chunks
# remainder chunk
remainder = total_elems % chunk_size
if remainder != 0:
chunk_sizes.append(remainder)
return chunk_sizes
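# Editor's note: a tiny worked example (added by the editor) of the helper above.
def _example_chunk_sizes():
    assert get_chunk_sizes(10, 4) == [4, 4, 2]  # two full chunks plus a remainder
    assert get_chunk_sizes(8, 4) == [4, 4]  # evenly divisible: no remainder chunk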
def _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size):
split_idxs = (batch_size,)
if chunk_size is not None:
chunk_sizes = get_chunk_sizes(batch_size, chunk_size)
split_idxs = tuple(itertools.accumulate(chunk_sizes))
flat_args_chunks = tuple(
t.tensor_split(split_idxs, dim=in_dim)
if in_dim is not None
else [
t,
]
* len(split_idxs)
for t, in_dim in zip(flat_args, flat_in_dims)
)
# transpose the chunk dimension and flatten the structure:
# chunks_flat_args is an iterable of flattened args, one entry per chunk
chunks_flat_args = zip(*flat_args_chunks)
return chunks_flat_args
def _flatten_chunks_output(chunks_output_):
# chunks_output_ is a list of chunked outputs;
# flatten each chunked output:
flat_chunks_output = []
arg_spec = None
for output in chunks_output_:
flat_output, arg_specs = tree_flatten(output)
flat_chunks_output.append(flat_output)
if arg_spec is None:
arg_spec = arg_specs
# transpose chunk dim and flatten structure
# flat_output_chunks is flat list of chunks
flat_output_chunks = list(zip(*flat_chunks_output))
return flat_output_chunks, arg_spec
def _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks):
# concat chunks on out_dim
flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec)
assert len(flat_out_dims) == len(flat_output_chunks)
flat_output = []
for idx, out_dim in enumerate(flat_out_dims):
flat_output.append(torch.cat(flat_output_chunks[idx], dim=out_dim))
# release tensors
flat_output_chunks[idx] = None
return flat_output
# Applies vmap to each chunk of inputs and returns the output concatenated over the chunks.
def _chunked_vmap(
func, flat_in_dims, chunks_flat_args, args_spec, out_dims, randomness, **kwargs
):
chunks_output = []
rs = torch.get_rng_state() if randomness == "same" else None
for flat_args in chunks_flat_args:
batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
# Because of the way we split the input in `_get_chunked_inputs`,
# we may get a chunk with batch size `0`. We skip any computation
# in that case.
# Eg.
# >>> chunk_size = 1
# >>> batch_size = 6
# >>> t = torch.zeros(batch_size, 1)
# >>> t.tensor_split([1, 2, 3, 4, 5, 6])
# (tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), tensor([[0.]]),
# tensor([[0.]]), tensor([[0.]]), tensor([], size=(0, 1)))
if batch_size == 0:
continue
if rs is not None:
torch.set_rng_state(rs)
chunks_output.append(
_flat_vmap(
func,
batch_size,
flat_in_dims,
flat_args,
args_spec,
out_dims,
randomness,
**kwargs,
)
)
flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output)
# chunked output tensors are held by both `flat_output_chunks` and `chunks_output`.
# eagerly remove the reference from `chunks_output`.
del chunks_output
# concat chunks on out_dim
flat_output = _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks)
# finally unflatten the output
return tree_unflatten(flat_output, arg_spec)
# Helper functions refactored out of vmap:
def _check_randomness_arg(randomness):
if randomness not in ["error", "different", "same"]:
raise RuntimeError(
f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}"
)
@contextlib.contextmanager
def vmap_increment_nesting(batch_size, randomness):
try:
vmap_level = _vmap_increment_nesting(batch_size, randomness)
yield vmap_level
finally:
_vmap_decrement_nesting()
def _flat_vmap(
func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
):
with vmap_increment_nesting(batch_size, randomness) as vmap_level:
batched_inputs = _create_batched_inputs(
flat_in_dims, flat_args, vmap_level, args_spec
)
batched_outputs = func(*batched_inputs, **kwargs)
return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
# `restore_vmap` is a private helper function. It is vmap but has the following
# differences:
# - instead of returning outputs, it returns an (outputs, out_dims) tuple.
# out_dims is a pytree of same shape as outputs and contains Optional[int]
# specifying where the vmapped dimension, if it exists, is in the corresponding output.
# - does no validation on in_dims or inputs (vmap expects at least one Tensor to be vmapped).
# restore_vmap allows for no inputs to have the vmap dimension
# - does no validation on outputs (vmap expects only Tensor outputs)
# restore_vmap allows for return of arbitrary outputs (not just Tensors)
#
# The TL;DR is that restore_vmap is more general than vmap and has a slightly
# different API. The relaxations are so that we can "pause" vmap in the middle
# of its execution and then "restore" it later (this is what we do in
# the generate_vmap_rule=True implementation of autograd.Function).
#
# restore_vmap could technically be used in the implementation of vmap, but doing
# that refactor is a bit technically challenging because:
# - vmap couples the tensor-wrapping code with error checking
# - vmap's tensor unwrapping code is in C++; we would need to rewrite part of it
# in python because it overlaps with unwrap_batched
def restore_vmap(func, in_dims, batch_size, randomness):
def inner(*args, **kwargs):
with vmap_increment_nesting(batch_size, randomness) as vmap_level:
batched_inputs = wrap_batched(args, in_dims, vmap_level)
batched_outputs = func(*batched_inputs, **kwargs)
return unwrap_batched(batched_outputs, vmap_level)
return inner
def wrap_batched(args, bdims, level):
flat_args, spec = tree_flatten(args)
flat_bdims = _broadcast_to_and_flatten(bdims, spec)
assert flat_bdims is not None
result = _create_batched_inputs(flat_bdims, flat_args, level, spec)
return result
def unwrap_batched(args, level):
flat_args, spec = tree_flatten(args)
if len(flat_args) == 0:
return args, ()
result = [
torch._C._functorch._unwrap_batched(arg, level)
if isinstance(arg, torch.Tensor)
else (arg, None)
for arg in flat_args
]
output, bdims = zip(*result)
return tree_unflatten(output, spec), tree_unflatten(bdims, spec)
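# Editor's note: a hedged sketch (added by the editor, not part of the original
# file) tying together restore_vmap, wrap_batched and unwrap_batched: unlike
# vmap, restore_vmap returns both the outputs and their batch dimensions.
def _example_restore_vmap():
    def f(x, s):
        return x * 2, s  # one batched output, one pass-through output
    x = torch.randn(5, 3)
    outputs, out_dims = restore_vmap(
        f, in_dims=(0, None), batch_size=5, randomness="error"
    )(x, 1.0)
    # out_dims mirrors the structure of `outputs`: 0 for the batched tensor,
    # None for the unbatched scalar.
    return outputs, out_dims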