I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@ -0,0 +1,599 @@
# mypy: allow-untyped-defs
"""
``torch.autograd`` provides classes and functions implementing automatic differentiation of arbitrary scalar-valued functions.
It requires minimal changes to the existing code - you only need to declare :class:`Tensor` s
for which gradients should be computed with the ``requires_grad=True`` keyword.
As of now, we only support autograd for floating point :class:`Tensor` types (
half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdouble).
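A minimal sketch of the typical workflow (illustrative; shapes and values are arbitrary)::
>>> x = torch.randn(3, requires_grad=True)
>>> loss = (x * x).sum()
>>> loss.backward()  # populates x.grad with d(loss)/dx, i.e. 2 * x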
"""
import warnings
from typing import cast, List, Optional, Sequence, Tuple, Union
import torch
from torch import _vmap_internals
from torch.overrides import handle_torch_function, has_torch_function, is_tensor_like
from torch.types import _size, _TensorOrTensors, _TensorOrTensorsOrGradEdge
from . import forward_ad, functional, graph
from .anomaly_mode import detect_anomaly, set_detect_anomaly
from .function import Function, NestedIOFunction
from .grad_mode import (
_force_original_view_tracking,
_unsafe_preserve_version_counter,
enable_grad,
inference_mode,
no_grad,
set_grad_enabled,
set_multithreading_enabled,
)
from .gradcheck import gradcheck, gradgradcheck
from .graph import _engine_run_backward
from .variable import Variable
__all__ = [
"Variable",
"Function",
"backward",
"grad_mode",
"NestedIOFunction",
"detect_anomaly",
"enable_grad",
"grad",
"gradcheck",
"gradgradcheck",
"inference_mode",
"no_grad",
"set_detect_anomaly",
"set_grad_enabled",
"set_multithreading_enabled",
"variable",
]
_OptionalTensor = Optional[torch.Tensor]
_ShapeorNestedShape = Union[_size, Sequence[_size], torch.Tensor]
def _calculate_shape(
output: Union[torch.Tensor, graph.GradientEdge],
grad: torch.Tensor,
is_grads_batched: bool,
) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
# is_same_size ensures that both tensors are either nested or non-nested
# circular import
from torch.nested._internal.nested_tensor import NestedTensor
if isinstance(output, graph.GradientEdge):
# We have already checked that we are not a C++ NestedTensor
if is_grads_batched:
raise RuntimeError("Batched grads are not supported with GradientEdge")
out_metadata = output.node._input_metadata[output.output_nr]
return torch.Size(out_metadata.shape), grad.shape
if output.is_nested and not isinstance(output, NestedTensor):
if is_grads_batched:
raise RuntimeError("Batched grads are not supported with Nested Tensor.")
out_shape = output._nested_tensor_size()
grad_shape = grad._nested_tensor_size()
return out_shape, grad_shape
reg_out_shape = output.shape
reg_grad_shape = grad.shape if not is_grads_batched else grad.shape[1:]
return reg_out_shape, reg_grad_shape
def _make_grads(
outputs: Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]],
grads: Sequence[_OptionalTensor],
is_grads_batched: bool,
) -> Tuple[_OptionalTensor, ...]:
new_grads: List[_OptionalTensor] = []
for out, grad in zip(outputs, grads):
out = cast(Union[torch.Tensor, graph.GradientEdge], out)
out_size = None
out_device = None
if isinstance(out, graph.GradientEdge):
out_metadata = out.node._input_metadata[out.output_nr]
out_size = torch.Size(out_metadata.shape)
out_dtype = out_metadata.dtype
out_device = out_metadata.device
out_is_nested = out_metadata.is_nested_tensor
if out_metadata.is_cpp_nested_tensor:
raise RuntimeError(
"C++ NestedTensor are not supported with GradientEdge"
)
out_is_cpp_nested = False
else:
# circular import
from torch.nested._internal.nested_tensor import NestedTensor
assert isinstance(out, torch.Tensor)
out_dtype = out.dtype
out_is_nested = out.is_nested
out_is_cpp_nested = out_is_nested and not isinstance(out, NestedTensor)
if not out_is_cpp_nested:
out_size = out.shape
if isinstance(grad, torch.Tensor):
from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq
first_grad = grad if not is_grads_batched else grad[0]
# TODO: We can remove this conditional once we uniformly use
# singleton int to represent jagged dimension, so that size() call
# on nested tensor works.
if out_is_cpp_nested:
assert isinstance(out, torch.Tensor)
shape_matches = torch.is_same_size(out, first_grad)
else:
# We need to do a regular size check, without going through
# the operator, to be able to handle unbacked symints
# (expect_true ensures we can deal with unbacked)
assert out_size is not None
shape_matches = expect_true(sym_eq(out_size, first_grad.size()))
if not shape_matches:
out = cast(Union[torch.Tensor, graph.GradientEdge], out)
out_shape, grad_shape = _calculate_shape(
out, first_grad, is_grads_batched
)
if is_grads_batched:
raise RuntimeError(
"If `is_grads_batched=True`, we interpret the first "
"dimension of each grad_output as the batch dimension. "
"The sizes of the remaining dimensions are expected to match "
"the shape of corresponding output, but a mismatch "
"was detected: grad_output["
+ str(grads.index(grad))
+ "] has a shape of "
+ str(grad_shape)
+ " and output["
+ str(outputs.index(out))
+ "] has a shape of "
+ str(out_shape)
+ ". "
"If you only want some tensors in `grad_output` to be considered "
"batched, consider using vmap."
)
else:
raise RuntimeError(
"Mismatch in shape: grad_output["
+ str(grads.index(grad))
+ "] has a shape of "
+ str(grad_shape)
+ " and output["
+ str(outputs.index(out))
+ "] has a shape of "
+ str(out_shape)
+ "."
)
if out_dtype.is_complex != grad.dtype.is_complex:
raise RuntimeError(
"For complex Tensors, both grad_output and output"
" are required to have the same dtype."
" Mismatch in dtype: grad_output["
+ str(grads.index(grad))
+ "] has a dtype of "
+ str(grad.dtype)
+ " and output["
+ str(outputs.index(out))
+ "] has a dtype of "
+ str(out_dtype)
+ "."
)
new_grads.append(grad)
elif grad is None:
if isinstance(out, graph.GradientEdge) or out.requires_grad: # type: ignore[attr-defined]
if isinstance(out, graph.GradientEdge):
assert out_size is not None
out_numel_is_1 = all(o == 1 for o in out_size)
else:
assert isinstance(out, torch.Tensor)
out_numel_is_1 = out.numel() == 1
if not out_numel_is_1:
raise RuntimeError(
"grad can be implicitly created only for scalar outputs"
)
if not out_dtype.is_floating_point:
msg = (
"grad can be implicitly created only for real scalar outputs"
f" but got {out_dtype}"
)
raise RuntimeError(msg)
if isinstance(out, graph.GradientEdge):
assert out_size is not None
assert out_device is not None
new_grads.append(
torch.ones(
out_size,
dtype=out_dtype,
device=out_device,
)
)
else:
assert isinstance(out, torch.Tensor)
new_grads.append(
torch.ones_like(out, memory_format=torch.preserve_format)
)
else:
new_grads.append(None)
else:
raise TypeError(
"gradients can be either Tensors or None, but got "
+ type(grad).__name__
)
return tuple(new_grads)
def _tensor_or_tensors_to_tuple(
tensors: Optional[_TensorOrTensors], length: int
) -> Tuple[_OptionalTensor, ...]:
if tensors is None:
return (None,) * length
if isinstance(tensors, torch.Tensor):
return (tensors,)
return tuple(tensors)
def backward(
tensors: _TensorOrTensors,
grad_tensors: Optional[_TensorOrTensors] = None,
retain_graph: Optional[bool] = None,
create_graph: bool = False,
grad_variables: Optional[_TensorOrTensors] = None,
inputs: Optional[_TensorOrTensorsOrGradEdge] = None,
) -> None:
r"""Compute the sum of gradients of given tensors with respect to graph leaves.
The graph is differentiated using the chain rule. If any of ``tensors``
are non-scalar (i.e. their data has more than one element) and require
gradient, then the Jacobian-vector product would be computed, in this
case the function additionally requires specifying ``grad_tensors``.
It should be a sequence of matching length, that contains the "vector"
in the Jacobian-vector product, usually the gradient of the differentiated
function w.r.t. corresponding tensors (``None`` is an acceptable value for
all tensors that don't need gradient tensors).
This function accumulates gradients in the leaves - you might need to zero
``.grad`` attributes or set them to ``None`` before calling it.
See :ref:`Default gradient layouts<default-grad-layouts>`
for details on the memory layout of accumulated gradients.
.. note::
Using this method with ``create_graph=True`` will create a reference cycle
between the parameter and its gradient which can cause a memory leak.
We recommend using ``autograd.grad`` when creating the graph to avoid this.
If you have to use this function, make sure to reset the ``.grad`` fields of your
parameters to ``None`` after use to break the cycle and avoid the leak.
.. note::
If you run any forward ops, create ``grad_tensors``, and/or call ``backward``
in a user-specified CUDA stream context, see
:ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.
.. note::
When ``inputs`` are provided and a given input is not a leaf,
the current implementation will call its grad_fn (even though it is not strictly needed to get these gradients).
It is an implementation detail on which the user should not rely.
See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
Args:
tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be
computed.
grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in
the Jacobian-vector product, usually gradients w.r.t. each element of
corresponding tensors. None values can be specified for scalar Tensors or
ones that don't require grad. If a None value would be acceptable for all
grad_tensors, then this argument is optional.
retain_graph (bool, optional): If ``False``, the graph used to compute the grad
will be freed. Note that in nearly all cases setting this option to ``True``
is not needed and often can be worked around in a much more efficient
way. Defaults to the value of ``create_graph``.
create_graph (bool, optional): If ``True``, graph of the derivative will
be constructed, allowing to compute higher order derivative products.
Defaults to ``False``.
inputs (Sequence[Tensor] or Tensor or Sequence[GradientEdge], optional): Inputs w.r.t. which the gradient
will be accumulated into ``.grad``. All other Tensors will be ignored. If
not provided, the gradient is accumulated into all the leaf Tensors that
were used to compute the :attr:`tensors`.
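Example (a minimal illustrative sketch; shapes and values are arbitrary)::
>>> x = torch.randn(3, requires_grad=True)
>>> y = x * 2
>>> # y is non-scalar, so a matching "vector" must be supplied via grad_tensors.
>>> torch.autograd.backward([y], grad_tensors=[torch.ones_like(y)])
>>> x.grad  # == torch.full((3,), 2.)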
"""
if torch._C._are_functorch_transforms_active():
raise RuntimeError(
"backward() called inside a functorch transform. This is not "
"supported, please use functorch.grad or functorch.vjp instead "
"or call backward() outside of functorch transforms."
)
if grad_variables is not None:
warnings.warn(
"`grad_variables` is deprecated. Use `grad_tensors` instead.",
FutureWarning,
stacklevel=2,
)
if grad_tensors is None:
grad_tensors = grad_variables
else:
raise RuntimeError(
"`grad_tensors` and `grad_variables` (deprecated) "
"arguments both passed to `backward()`. Please only "
"use `grad_tensors`."
)
if inputs is not None and len(inputs) == 0:
raise RuntimeError("`inputs` argument to `backward()` cannot be empty.")
tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
inputs = (
(inputs,)
if isinstance(inputs, (torch.Tensor, graph.GradientEdge))
else tuple(inputs)
if inputs is not None
else ()
)
grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
if retain_graph is None:
retain_graph = create_graph
# The reason we repeat the same comment below is that
# some Python versions print out the first line of a multi-line function
# calls in the traceback and some print out the last line
_engine_run_backward(
tensors,
grad_tensors_,
retain_graph,
create_graph,
inputs,
allow_unreachable=True,
accumulate_grad=True,
)
def grad(
outputs: _TensorOrTensorsOrGradEdge,
inputs: _TensorOrTensorsOrGradEdge,
grad_outputs: Optional[_TensorOrTensors] = None,
retain_graph: Optional[bool] = None,
create_graph: bool = False,
only_inputs: bool = True,
allow_unused: Optional[bool] = None,
is_grads_batched: bool = False,
materialize_grads: bool = False,
) -> Tuple[torch.Tensor, ...]:
r"""Compute and return the sum of gradients of outputs with respect to the inputs.
``grad_outputs`` should be a sequence of length matching ``outputs``,
containing the "vector" in the vector-Jacobian product, usually the pre-computed
gradients w.r.t. each of the outputs. If an output doesn't require_grad,
then the gradient can be ``None``.
.. note::
If you run any forward ops, create ``grad_outputs``, and/or call ``grad``
in a user-specified CUDA stream context, see
:ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.
.. note::
``only_inputs`` argument is deprecated and is ignored now (defaults to ``True``).
To accumulate gradient for other parts of the graph, please use
``torch.autograd.backward``.
Args:
outputs (sequence of Tensor or GradientEdge): outputs of the differentiated function.
inputs (sequence of Tensor or GradientEdge): Inputs w.r.t. which the gradient will be
returned (and not accumulated into ``.grad``).
grad_outputs (sequence of Tensor): The "vector" in the vector-Jacobian product.
Usually gradients w.r.t. each output. None values can be specified for scalar
Tensors or ones that don't require grad. If a None value would be acceptable
for all grad_tensors, then this argument is optional. Default: None.
retain_graph (bool, optional): If ``False``, the graph used to compute the grad
will be freed. Note that in nearly all cases setting this option to ``True``
is not needed and often can be worked around in a much more efficient
way. Defaults to the value of ``create_graph``.
create_graph (bool, optional): If ``True``, graph of the derivative will
be constructed, allowing to compute higher order derivative products.
Default: ``False``.
allow_unused (Optional[bool], optional): If ``False``, specifying inputs
that were not used when computing outputs (and therefore their grad is
always zero) is an error. Defaults to the value of ``materialize_grads``.
is_grads_batched (bool, optional): If ``True``, the first dimension of each
tensor in ``grad_outputs`` will be interpreted as the batch dimension.
Instead of computing a single vector-Jacobian product, we compute a
batch of vector-Jacobian products for each "vector" in the batch.
We use the vmap prototype feature as the backend to vectorize calls
to the autograd engine so that this computation can be performed in a
single call. This should lead to performance improvements when compared
to manually looping and performing backward multiple times. Note that
due to this feature being experimental, there may be performance
cliffs. Please use ``torch._C._debug_only_display_vmap_fallback_warnings(True)``
to show any performance warnings and file an issue on github if warnings exist
for your use case. Defaults to ``False``.
materialize_grads (bool, optional): If ``True``, set the gradient for unused inputs
to zero instead of None. This is useful when computing higher-order derivatives.
If ``materialize_grads`` is ``True`` and ``allow_unused`` is ``False``, an error
will be raised. Defaults to ``False``.
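Example (a minimal illustrative sketch; shapes and values are arbitrary)::
>>> x = torch.randn(3, requires_grad=True)
>>> out = (x ** 2).sum()
>>> (gx,) = torch.autograd.grad(out, x)
>>> gx  # == 2 * x; nothing is accumulated into x.grad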
"""
if materialize_grads and allow_unused is False:
raise ValueError(
"Expected allow_unused to be True or not passed when materialize_grads=True, "
"but got: allow_unused=False."
)
if allow_unused is None:
allow_unused = materialize_grads
if is_tensor_like(outputs) or isinstance(outputs, graph.GradientEdge):
outputs = cast(
Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], (outputs,)
)
else:
outputs = tuple(outputs)
if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge):
inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,))
else:
inputs = tuple(inputs)
t_outputs = tuple(i for i in outputs if is_tensor_like(i))
t_inputs = tuple(i for i in inputs if is_tensor_like(i))
overridable_args = t_outputs + t_inputs
if has_torch_function(overridable_args):
return handle_torch_function(
grad,
overridable_args,
outputs,
inputs,
grad_outputs=grad_outputs,
retain_graph=retain_graph,
create_graph=create_graph,
only_inputs=only_inputs,
allow_unused=allow_unused,
is_grads_batched=is_grads_batched,
materialize_grads=materialize_grads,
)
if not only_inputs:
warnings.warn(
"only_inputs argument is deprecated and is ignored now "
"(defaults to True). To accumulate gradient for other "
"parts of the graph, please use torch.autograd.backward.",
FutureWarning,
stacklevel=2,
)
grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
grad_outputs_ = _make_grads(
outputs, grad_outputs_, is_grads_batched=is_grads_batched
)
if retain_graph is None:
retain_graph = create_graph
# The reason we repeat the same comment several times below is because
# some Python versions print out the first line of multi-line function
# calls in the traceback and some print out the last line
if is_grads_batched:
def vjp(gO):
return _engine_run_backward(
outputs,
gO,
retain_graph,
create_graph,
inputs,
allow_unused,
accumulate_grad=False,
)
result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(
grad_outputs_
)
else:
result = _engine_run_backward(
outputs,
grad_outputs_,
retain_graph,
create_graph,
inputs,
allow_unused,
accumulate_grad=False,
)
if materialize_grads:
if any(
result[i] is None and not is_tensor_like(inputs[i])
for i in range(len(inputs))
):
raise RuntimeError(
"materialize_grads cannot be used when the given input is a GradientEdge"
)
result = tuple(
output
if output is not None
else torch.zeros_like(input, requires_grad=True)
for (output, input) in zip(result, inputs)
)
return result
# This function applies in case of gradient checkpointing for memory
# optimization. Currently, gradient checkpointing is supported only if the
# execution engine is invoked through torch.autograd.backward() and its
# inputs argument is not passed. It is not supported for torch.autograd.grad().
# This is because if inputs are specified, the gradient won't be calculated for
# anything else e.g. model parameters like weights, bias etc.
#
# This function returns whether checkpointing is valid, i.e. whether the engine
# was invoked through torch.autograd.backward (valid) or torch.autograd.grad
# (not valid). The implementation works by maintaining a thread-local variable in
# torch/csrc/autograd/engine.cpp which looks at the NodeTask in the stack, and
# before a NodeTask is executed in evaluate_function, it checks whether
# reentrant backwards is imperative or not.
# See https://github.com/pytorch/pytorch/pull/4594 for more discussion/context
def _is_checkpoint_valid():
return Variable._execution_engine.is_checkpoint_valid()
def variable(*args, **kwargs): # noqa: D103
raise RuntimeError(
"torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead"
)
# Monkey patching variable.Variable to fix FX codegen. FX generates a call by roughly doing
# f"{fn.__module__}.{fn.__name__}(...). This yields torch.autograd.variable.Variable(...) in the
# output of an FX graph. Unfortunately the module name torch.autograd.variable is shadowed by the
# deprecated function - variable(...).
variable.Variable = Variable # type: ignore[attr-defined]
if not torch._C._autograd_init():
raise RuntimeError("autograd initialization failed")
# Import all native method/classes
from torch._C._autograd import (
_add_metadata_json,
_disable_profiler,
_disable_profiler_legacy,
_enable_profiler,
_enable_profiler_legacy,
_enable_record_function,
_get_sequence_nr,
_kineto_step,
_KinetoEvent,
_pop_saved_tensors_default_hooks,
_prepare_profiler,
_profiler_enabled,
_ProfilerResult,
_push_saved_tensors_default_hooks,
_record_function_with_args_enter,
_record_function_with_args_exit,
_set_empty_test_observer,
_supported_activities,
_toggle_collection_dynamic,
DeviceType,
kineto_available,
ProfilerEvent,
SavedTensor,
)
from torch._C._profiler import ProfilerActivity, ProfilerConfig, ProfilerState
from . import profiler
def _register_py_tensor_class_for_device(device, cls):
if not isinstance(cls, type):
raise RuntimeError("cls isn't a typeinfo object")
torch._C._register_py_class_for_device(device, cls)
is_multithreading_enabled = torch._C._is_multithreading_enabled
torch._C._add_docstr(
is_multithreading_enabled, "Returns True if multithreading is currently enabled."
)
is_view_replay_enabled = torch._C._is_view_replay_enabled
torch._C._add_docstr(
is_view_replay_enabled, "Returns True if view-replay is currently enabled."
)


@ -0,0 +1 @@
from .tensor import * # noqa: F403


@ -0,0 +1,65 @@
# mypy: allow-untyped-defs
import operator
from functools import reduce
from typing_extensions import deprecated
import torch
import torch._utils
from torch.autograd.function import Function
class Type(Function):
@staticmethod
@deprecated(
"`torch.autograd._functions.Type` is deprecated as of PyTorch 2.1, "
"please use `torch.tensor.to(dtype=dtype)` instead.",
category=FutureWarning,
)
def forward(ctx, i, dest_type):
ctx.input_type = type(i)
ctx.input_device = -1 if not i.is_cuda else i.get_device()
return i.type(dest_type)
@staticmethod
def backward(ctx, grad_output):
if ctx.input_device == -1:
return grad_output.type(ctx.input_type), None
else:
with torch.cuda.device(ctx.input_device):
return grad_output.type(ctx.input_type), None
# TODO: deprecate this
class Resize(Function):
@staticmethod
def forward(ctx, tensor, sizes):
ctx.sizes = sizes
ctx.numel = reduce(operator.mul, sizes, 1)
if tensor.numel() != ctx.numel:
raise RuntimeError(
(
"requested resize to {} ({} elements in total), "
"but the given tensor has a size of {} ({} elements). "
"autograd's resize can only change the shape of a given "
"tensor, while preserving the number of elements. "
).format(
"x".join(map(str, sizes)),
ctx.numel,
"x".join(map(str, tensor.size())),
tensor.numel(),
)
)
ctx.input_sizes = tensor.size()
if tensor.is_quantized:
tensor.copy_(tensor)
return tensor.contiguous().view(*sizes)
if tensor.is_contiguous():
result = tensor.new(tensor).contiguous().view(*sizes)
return result
else:
return tensor.contiguous().view(*sizes)
@staticmethod
def backward(ctx, grad_output):
assert grad_output.numel() == ctx.numel
return grad_output.contiguous().view(ctx.input_sizes), None


@ -0,0 +1,63 @@
# mypy: allow-untyped-defs
import operator
from functools import reduce
def maybe_view(tensor, size, check_same_size=True):
if check_same_size and tensor.size() == size:
return tensor
return tensor.contiguous().view(size)
def maybe_unexpand(tensor, old_size, check_same_size=True):
if check_same_size and tensor.size() == old_size:
return tensor
num_unsqueezed = tensor.dim() - len(old_size)
expanded_dims = [
dim
for dim, (expanded, original) in enumerate(
zip(tensor.size()[num_unsqueezed:], old_size)
)
if expanded != original
]
for _ in range(num_unsqueezed):
tensor = tensor.sum(0, keepdim=False)
for dim in expanded_dims:
tensor = tensor.sum(dim, keepdim=True)
return tensor
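# Illustrative sketch of `maybe_unexpand` (hypothetical shapes): reducing a
# broadcast gradient of shape (2, 4, 3) back to an old_size of (1, 3) first sums
# out the extra leading dimension, then sums (with keepdim=True) every dimension
# that was expanded from 1, e.g. (assuming `torch` is available):
#   maybe_unexpand(torch.ones(2, 4, 3), torch.Size([1, 3])).shape  # -> (1, 3)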
# Check whether the op enables broadcasting, and whether it is supported by ONNX.
# If dims1 and dims2 are different, then broadcast is True.
# We always assume the combination of dims1 and dims2 is broadcastable.
# The following types of broadcasting are supported in ONNX:
# 1) Only one element in dims2, such as dims2 = [1, 1]
# 2) dims2 is suffix of dims1, such as dims1 = [2, 3, 4], and dims2 = [3, 4]
# Details can be found here: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm
def check_onnx_broadcast(dims1, dims2):
broadcast = False
supported = True
len1 = len(dims1)
len2 = len(dims2)
numel1 = reduce(operator.mul, dims1)
numel2 = reduce(operator.mul, dims2)
if len1 < len2:
broadcast = True
if numel2 != 1:
supported = False
elif len1 > len2:
broadcast = True
if numel2 != 1 and dims1[len1 - len2 :] != dims2:
supported = False
else:
if dims1 != dims2:
broadcast = True
if numel2 != 1:
supported = False
if not supported:
raise ValueError(
f"Numpy style broadcasting is not supported in ONNX. Input dims are: {dims1}, {dims2}"
)
return broadcast
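# Illustrative sketch of the rules above (hypothetical shapes):
#   check_onnx_broadcast([2, 3, 4], [3, 4])  # dims2 is a suffix of dims1 -> returns True
#   check_onnx_broadcast([2, 3, 4], [1, 1])  # dims2 has a single element -> returns True
#   check_onnx_broadcast([2, 3, 4], [4, 4])  # unsupported -> raises ValueError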


@ -0,0 +1,121 @@
# mypy: allow-untyped-defs
r"""Autograd anomaly mode."""
import warnings
import torch
__all__ = ["detect_anomaly", "set_detect_anomaly"]
class detect_anomaly:
r"""Context-manager that enable anomaly detection for the autograd engine.
This does two things:
- Running the forward pass with detection enabled will allow the backward
pass to print the traceback of the forward operation that created the failing
backward function.
- If ``check_nan`` is ``True``, any backward computation that generates a "nan"
value will raise an error. Default ``True``.
.. warning::
This mode should be enabled only for debugging, as the additional checks
will slow down your program execution.
Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ANOMALY)
>>> import torch
>>> from torch import autograd
>>> class MyFunc(autograd.Function):
... @staticmethod
... def forward(ctx, inp):
... return inp.clone()
... @staticmethod
... def backward(ctx, gO):
... # Error during the backward pass
... raise RuntimeError("Some error in backward")
... return gO.clone()
>>> def run_fn(a):
... out = MyFunc.apply(a)
... return out.sum()
>>> inp = torch.rand(10, 10, requires_grad=True)
>>> out = run_fn(inp)
>>> out.backward()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
return self._forward_cls.backward(self, *args)
File "<stdin>", line 8, in backward
RuntimeError: Some error in backward
>>> with autograd.detect_anomaly():
... inp = torch.rand(10, 10, requires_grad=True)
... out = run_fn(inp)
... out.backward()
Traceback of forward call that caused the error:
File "tmp.py", line 53, in <module>
out = run_fn(inp)
File "tmp.py", line 44, in run_fn
out = MyFunc.apply(a)
Traceback (most recent call last):
File "<stdin>", line 4, in <module>
File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
return self._forward_cls.backward(self, *args)
File "<stdin>", line 8, in backward
RuntimeError: Some error in backward
"""
def __init__(self, check_nan=True) -> None: # noqa: D107
self.prev = torch.is_anomaly_enabled()
self.check_nan = check_nan
self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
warnings.warn(
"Anomaly Detection has been enabled. "
"This mode will increase the runtime "
"and should only be enabled for debugging.",
stacklevel=2,
)
def __enter__(self) -> None: # noqa: D105
torch.set_anomaly_enabled(True, self.check_nan)
def __exit__(self, *args: object) -> None: # noqa: D105
torch.set_anomaly_enabled(self.prev, self.prev_check_nan)
class set_detect_anomaly:
r"""Context-manager that sets the anomaly detection for the autograd engine on or off.
``set_detect_anomaly`` will enable or disable the autograd anomaly detection
based on its argument :attr:`mode`.
It can be used as a context-manager or as a function.
See ``detect_anomaly`` above for details of the anomaly detection behaviour.
Args:
mode (bool): Flag whether to enable anomaly detection (``True``),
or disable (``False``).
check_nan (bool): Flag whether to raise an error when the backward pass
generates "nan" values.
"""
def __init__(self, mode: bool, check_nan: bool = True) -> None: # noqa: D107
self.prev = torch.is_anomaly_enabled()
self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
torch.set_anomaly_enabled(mode, check_nan)
def __enter__(self) -> None: # noqa: D105
pass
def __exit__(self, *args: object) -> None: # noqa: D105
torch.set_anomaly_enabled(self.prev, self.prev_check_nan)


@ -0,0 +1,231 @@
# mypy: allow-untyped-defs
import os
from collections import namedtuple
from typing import Any
import torch
from .grad_mode import _DecoratorContextManager
__all__ = [
"UnpackedDualTensor",
"enter_dual_level",
"exit_dual_level",
"make_dual",
"unpack_dual",
"dual_level",
]
# Global variable used to make the python API simpler to use
_current_level = -1
def enter_dual_level():
r"""Enter a new forward grad level.
This level can be used to make and unpack dual Tensors to compute
forward gradients.
This function also updates the current level that is used by default
by the other functions in this API.
"""
global _current_level
new_level = torch._C._enter_dual_level()
if new_level != _current_level + 1:
raise RuntimeError(
"Entering a new forward AD level but the current level "
"is not valid. Make sure you did not modify it directly."
)
_current_level = new_level
return new_level
def exit_dual_level(*, level=None):
r"""Exit a forward grad level.
This function deletes all the gradients associated with this
level. Only deleting the latest entered level is allowed.
This function also updates the current level that is used by default
by the other functions in this API.
"""
global _current_level
if level is None:
level = _current_level
if level != _current_level:
raise RuntimeError(
"Trying to exit a forward AD level that was not the last one "
"that was created. This is not supported."
)
torch._C._exit_dual_level(level=level)
_current_level = level - 1
def _maybe_load_decompositions():
if os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__:
from torch._decomp import decompositions_for_jvp # noqa: F401
def make_dual(tensor, tangent, *, level=None):
r"""Associate a tensor value with its tangent to create a "dual tensor" for forward AD gradient computation.
The result is a new tensor aliased to :attr:`tensor` with :attr:`tangent` embedded
as an attribute as-is if it has the same storage layout or copied otherwise.
The tangent attribute can be recovered with :func:`unpack_dual`.
This function is backward differentiable.
Given a function `f` whose Jacobian is `J`, it allows one to compute the Jacobian-vector product (`jvp`)
between `J` and a given vector `v` as follows.
Example::
>>> # xdoctest: +SKIP("Undefined variables")
>>> with dual_level():
... inp = make_dual(x, v)
... out = f(inp)
... y, jvp = unpack_dual(out)
Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
for detailed steps on how to use this API.
"""
# See NOTE: [forward-mode AD decompositions mechanism]
#
# Import from torch._decomp import decompositions_for_jvp to register
# decompositions for jvp to the jit registry
#
# FIXME: We specify that __debug__ must be True because
# if python is run with -OO or -O flags (i.e., __debug__ is False), we encounter the
# following error:
#
# Return value was annotated as having type Tuple[NoneType, NoneType] but is actually of
# type Tuple[Tensor, Tensor]:
# File ".../torch/_decomp/__init__.py", line 1585
# else:
# buffer = z
# return min - torch.log1p(z), buffer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
_maybe_load_decompositions()
if level is None:
level = _current_level
if level < 0:
raise RuntimeError(
"Trying to create a dual Tensor for forward AD but no level "
"exists, make sure to enter_dual_level() first."
)
if not (tensor.is_floating_point() or tensor.is_complex()):
raise ValueError(
f"Expected primal to be floating point or complex, but got: {tensor.dtype}"
)
if not (tangent.is_floating_point() or tangent.is_complex()):
raise ValueError(
f"Expected tangent to be floating point or complex, but got: {tangent.dtype}"
)
return torch._VF._make_dual(tensor, tangent, level=level)
_UnpackedDualTensor = namedtuple("_UnpackedDualTensor", ["primal", "tangent"])
class UnpackedDualTensor(_UnpackedDualTensor):
r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor.
See :func:`unpack_dual` for more details.
"""
def unpack_dual(tensor, *, level=None):
r"""Unpack a "dual tensor" to get both its Tensor value and its forward AD gradient.
The result is a namedtuple ``(primal, tangent)`` where ``primal`` is a view of
:attr:`tensor`'s primal and ``tangent`` is :attr:`tensor`'s tangent as-is.
Neither of these tensors can be a dual tensor of level :attr:`level`.
This function is backward differentiable.
Example::
>>> # xdoctest: +SKIP("Undefined variables")
>>> with dual_level():
... inp = make_dual(x, x_t)
... out = f(inp)
... y, jvp = unpack_dual(out)
... jvp = unpack_dual(out).tangent
Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
for detailed steps on how to use this API.
"""
if level is None:
level = _current_level
if level < 0:
return UnpackedDualTensor(tensor, None)
primal, dual = torch._VF._unpack_dual(tensor, level=level)
return UnpackedDualTensor(primal, dual)
class dual_level(_DecoratorContextManager):
r"""Context-manager for forward AD, where all forward AD computation must occur within the ``dual_level`` context.
.. Note::
The ``dual_level`` context appropriately enters and exits the dual level to
control the current forward AD level, which is used by default by the other
functions in this API.
We currently don't plan to support nested ``dual_level`` contexts, however, so
only a single forward AD level is supported. To compute higher-order
forward grads, one can use :func:`torch.func.jvp`.
Example::
>>> # xdoctest: +SKIP("Undefined variables")
>>> x = torch.tensor([1])
>>> x_t = torch.tensor([1])
>>> with dual_level():
... inp = make_dual(x, x_t)
... # Do computations with inp
... out = your_fn(inp)
... _, grad = unpack_dual(out)
>>> grad is None
False
>>> # After exiting the level, the grad is deleted
>>> _, grad_after = unpack_dual(out)
>>> grad_after is None
True
Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
for detailed steps on how to use this API.
"""
def __enter__(self):
return enter_dual_level()
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
exit_dual_level()
# Private helper functions
_is_fwd_grad_enabled = torch._C._is_fwd_grad_enabled
# Private helper function to enable or disable fwd grad.
# If you're a user and want to use this, please file an issue to discuss the use case.
class _set_fwd_grad_enabled(_DecoratorContextManager):
def __init__(self, mode: bool) -> None:
self.prev = _is_fwd_grad_enabled()
torch._C._set_fwd_grad_enabled(mode)
def __enter__(self) -> None:
pass
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch._C._set_fwd_grad_enabled(self.prev)


@ -0,0 +1,844 @@
# mypy: allow-untyped-defs
import functools
import inspect
import itertools
import warnings
from collections import OrderedDict
from typing import Any, List, Optional, Tuple
from typing_extensions import deprecated
import torch
import torch._C as _C
import torch._functorch as _functorch
import torch.utils.hooks as hooks
from torch._C import _functions
from torch._functorch.autograd_function import custom_function_call
__all__ = [
"FunctionCtx",
"BackwardCFunction",
"FunctionMeta",
"Function",
"once_differentiable",
"InplaceFunction",
"NestedIOFunction",
]
# Unique id provider for each class inheriting from Function
# This is incremented in FunctionMeta during class definition
AUTOGRAD_FUNCTION_COUNTER = itertools.count()
# Formerly known as: _ContextMethodMixin
class FunctionCtx:
def save_for_backward(self, *tensors: torch.Tensor):
r"""Save given tensors for a future call to :func:`~Function.backward`.
``save_for_backward`` should be called at most once, in either the
:func:`setup_context` or :func:`forward` methods, and only with tensors.
All tensors intended to be used in the backward pass should be saved
with ``save_for_backward`` (as opposed to directly on ``ctx``) to prevent
incorrect gradients and memory leaks, and enable the application of saved
tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`.
Note that if intermediary tensors, tensors that are neither inputs
nor outputs of :func:`forward`, are saved for backward, your custom Function
may not support double backward.
Custom Functions that do not support double backward should decorate their
:func:`backward` method with ``@once_differentiable`` so that performing
double backward raises an error. If you'd like to support double backward,
you can either recompute intermediaries based on the inputs during backward
or return the intermediaries as the outputs of the custom Function. See the
`double backward tutorial <https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html>`_
for more details.
In :func:`backward`, saved tensors can be accessed through the :attr:`saved_tensors`
attribute. Before returning them to the user, a check is made to ensure
they weren't used in any in-place operation that modified their content.
Arguments can also be ``None``. This is a no-op.
See :ref:`extending-autograd` for more details on how to use this method.
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Func(Function):
>>> @staticmethod
>>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>> w = x * z
>>> out = x * y + y * z + w * y
>>> ctx.save_for_backward(x, y, w, out)
>>> ctx.z = z # z is not a tensor
>>> return out
>>>
>>> @staticmethod
>>> @once_differentiable
>>> def backward(ctx, grad_out):
>>> x, y, w, out = ctx.saved_tensors
>>> z = ctx.z
>>> gx = grad_out * (y + y * z)
>>> gy = grad_out * (x + z + w)
>>> gz = None
>>> return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)
"""
self.to_save = tensors
def save_for_forward(self, *tensors: torch.Tensor):
r"""Save given tensors for a future call to :func:`~Function.jvp`.
``save_for_forward`` should be called at most once, in either the
:func:`setup_context` or :func:`forward` methods, and all arguments
should be tensors.
In :func:`jvp`, saved objects can be accessed through the :attr:`saved_tensors`
attribute.
Arguments can also be ``None``. This is a no-op.
See :ref:`extending-autograd` for more details on how to use this method.
Example::
>>> # xdoctest: +SKIP
>>> class Func(torch.autograd.Function):
>>> @staticmethod
>>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>> ctx.save_for_backward(x, y)
>>> ctx.save_for_forward(x, y)
>>> ctx.z = z
>>> return x * y * z
>>>
>>> @staticmethod
>>> def jvp(ctx, x_t, y_t, _):
>>> x, y = ctx.saved_tensors
>>> z = ctx.z
>>> return z * (y * x_t + x * y_t)
>>>
>>> @staticmethod
>>> def vjp(ctx, grad_out):
>>> x, y = ctx.saved_tensors
>>> z = ctx.z
>>> return z * grad_out * y, z * grad_out * x, None
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> t = torch.tensor(1., dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>>
>>> with fwAD.dual_level():
>>> a_dual = fwAD.make_dual(a, t)
>>> d = Func.apply(a_dual, b, c)
"""
for tensor in tensors:
assert isinstance(tensor, torch.Tensor) or tensor is None, (
"save_for_forward expects all arguments to be tensors; you should "
"save non-tensors as attributes on ctx."
)
self.saved_for_forward = tensors
def mark_dirty(self, *args: torch.Tensor):
r"""Mark given tensors as modified in an in-place operation.
This should be called at most once, in either the :func:`setup_context`
or :func:`forward` methods, and all arguments should be inputs.
Every tensor that's been modified in-place in a call to :func:`forward`
should be given to this function, to ensure correctness of our checks.
It doesn't matter whether the function is called before or after
modification.
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Inplace(Function):
>>> @staticmethod
>>> def forward(ctx, x):
>>> x_npy = x.numpy() # x_npy shares storage with x
>>> x_npy += 1
>>> ctx.mark_dirty(x)
>>> return x
>>>
>>> @staticmethod
>>> @once_differentiable
>>> def backward(ctx, grad_output):
>>> return grad_output
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
>>> b = a * a
>>> Inplace.apply(a) # This would lead to wrong gradients!
>>> # but the engine would not know unless we mark_dirty
>>> # xdoctest: +SKIP
>>> b.backward() # RuntimeError: one of the variables needed for gradient
>>> # computation has been modified by an inplace operation
"""
self.dirty_tensors = args
@deprecated(
"`mark_shared_storage` is deprecated. "
"Tensors with shared storages are automatically tracked. "
"Note that calls to `set_()` are not tracked",
category=FutureWarning,
)
def mark_shared_storage(self, *pairs):
pass
def mark_non_differentiable(self, *args: torch.Tensor):
r"""Mark outputs as non-differentiable.
This should be called at most once, in either the :func:`setup_context`
or :func:`forward` methods, and all arguments should be tensor outputs.
This will mark outputs as not requiring gradients, increasing the
efficiency of backward computation. You still need to accept a gradient
for each output in :meth:`~Function.backward`, but it's always going to
be a zero tensor with the same shape as the corresponding
output.
This is used e.g. for indices returned from a sort. See example::
>>> class Func(Function):
>>> @staticmethod
>>> def forward(ctx, x):
>>> sorted, idx = x.sort()
>>> ctx.mark_non_differentiable(idx)
>>> ctx.save_for_backward(x, idx)
>>> return sorted, idx
>>>
>>> @staticmethod
>>> @once_differentiable
>>> def backward(ctx, g1, g2): # still need to accept g2
>>> x, idx = ctx.saved_tensors
>>> grad_input = torch.zeros_like(x)
>>> grad_input.index_add_(0, idx, g1)
>>> return grad_input
"""
self.non_differentiable = args
def set_materialize_grads(self, value: bool):
r"""Set whether to materialize grad tensors. Default is ``True``.
This should be called only from either the :func:`setup_context` or
:func:`forward` methods.
If ``True``, undefined grad tensors will be expanded to tensors full of zeros
prior to calling the :func:`backward` and :func:`jvp` methods.
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class SimpleFunc(Function):
>>> @staticmethod
>>> def forward(ctx, x):
>>> return x.clone(), x.clone()
>>>
>>> @staticmethod
>>> @once_differentiable
>>> def backward(ctx, g1, g2):
>>> return g1 + g2 # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>> @staticmethod
>>> def forward(ctx, x):
>>> ctx.set_materialize_grads(False)
>>> ctx.save_for_backward(x)
>>> return x.clone(), x.clone()
>>>
>>> @staticmethod
>>> @once_differentiable
>>> def backward(ctx, g1, g2):
>>> x, = ctx.saved_tensors
>>> grad_input = torch.zeros_like(x)
>>> if g1 is not None: # We must check for None now
>>> grad_input += g1
>>> if g2 is not None:
>>> grad_input += g2
>>> return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a) # induces g2 to be undefined
"""
self.materialize_grads = value
# DO NOT USE: This is only defined to be able to load old serialized models
_ContextMethodMixin = FunctionCtx
class _HookMixin:
@staticmethod
def _register_hook(backward_hooks, hook):
if backward_hooks is None:
backward_hooks = OrderedDict()
handle = hooks.RemovableHandle(backward_hooks)
backward_hooks[handle.id] = hook
return backward_hooks, handle
class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
r"""
This class is used for internal autograd work. Do not use.
"""
def apply(self, *args):
r"""
Apply method used when executing this Node during the backward
"""
# _forward_cls is defined by derived class
# The user should define either backward or vjp but never both.
backward_fn = self._forward_cls.backward # type: ignore[attr-defined]
vjp_fn = self._forward_cls.vjp # type: ignore[attr-defined]
if backward_fn is not Function.backward and vjp_fn is not Function.vjp:
raise RuntimeError(
"Implementing both 'backward' and 'vjp' for a custom "
"Function is not allowed. You should only implement one "
"of them."
)
user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
return user_fn(self, *args)
def apply_jvp(self, *args):
r"""
Apply method used when executing forward mode AD during the forward
"""
# _forward_cls is defined by derived class
return self._forward_cls.jvp(self, *args) # type: ignore[attr-defined]
def _compiled_autograd_key(self):
return self._forward_cls._compiled_autograd_key(self) # type: ignore[attr-defined]
class FunctionMeta(type):
"""Function metaclass.
This metaclass sets up the following properties:
_backward_cls: The Function class corresponding to the differentiated
version of this function (which is generated on the fly by this
metaclass).
"""
def __init__(cls, name, bases, attrs):
backward_fn = type(
name + "Backward", (BackwardCFunction,), {"_forward_cls": cls}
)
backward_fn._autograd_function_id = next(AUTOGRAD_FUNCTION_COUNTER) # type: ignore[attr-defined]
backward_fn._compiled_autograd_should_lift = attrs.get( # type: ignore[attr-defined]
"_compiled_autograd_should_lift", True
)
cls._backward_cls = backward_fn
super().__init__(name, bases, attrs)
class _SingleLevelFunction(
_C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta
):
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
r"""Define the forward of the custom autograd Function.
This function is to be overridden by all subclasses.
There are two ways to define forward:
Usage 1 (Combined forward and ctx)::
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
pass
- It must accept a context ctx as the first argument, followed by any
number of arguments (tensors or other types).
- See :ref:`combining-forward-context` for more details
Usage 2 (Separate forward and ctx)::
@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
pass
@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
pass
- The forward no longer accepts a ctx argument.
- Instead, you must also override the :meth:`torch.autograd.Function.setup_context`
staticmethod to handle setting up the ``ctx`` object.
``output`` is the output of the forward, ``inputs`` are a Tuple of inputs
to the forward.
- See :ref:`extending-autograd` for more details
The context can be used to store arbitrary data that can be then
retrieved during the backward pass. Tensors should not be stored
directly on `ctx` (though this is not currently enforced for
backward compatibility). Instead, tensors should be saved either with
:func:`ctx.save_for_backward` if they are intended to be used in
``backward`` (equivalently, ``vjp``) or :func:`ctx.save_for_forward`
if they are intended to be used in ``jvp``.
"""
raise NotImplementedError(
"You must implement the forward function for custom autograd.Function."
)
@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any:
r"""There are two ways to define the forward pass of an autograd.Function.
Either:
1. Override forward with the signature ``forward(ctx, *args, **kwargs)``.
``setup_context`` is not overridden. Setting up the ctx for backward
happens inside the ``forward``.
2. Override forward with the signature ``forward(*args, **kwargs)`` and
override ``setup_context``. Setting up the ctx for backward happens
inside ``setup_context`` (as opposed to inside the ``forward``)
See :meth:`torch.autograd.Function.forward` and :ref:`extending-autograd` for more details.
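A minimal sketch of the second style (illustrative; ``Mul`` is a hypothetical example)::
>>> # xdoctest: +SKIP
>>> class Mul(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(x, y):
>>>         return x * y
>>>
>>>     @staticmethod
>>>     def setup_context(ctx, inputs, output):
>>>         x, y = inputs
>>>         ctx.save_for_backward(x, y)
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_out):
>>>         x, y = ctx.saved_tensors
>>>         return grad_out * y, grad_out * x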
"""
raise NotImplementedError("setup_context is not implemented.")
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
r"""Define a formula for differentiating the operation with backward mode automatic differentiation.
This function is to be overridden by all subclasses.
(Defining this function is equivalent to defining the ``vjp`` function.)
It must accept a context :attr:`ctx` as the first argument, followed by
as many outputs as the :func:`forward` returned (None will be passed in
for non tensor outputs of the forward function),
and it should return as many tensors, as there were inputs to
:func:`forward`. Each argument is the gradient w.r.t the given output,
and each returned value should be the gradient w.r.t. the
corresponding input. If an input is not a Tensor or is a Tensor not
requiring grads, you can just pass None as a gradient for that input.
The context can be used to retrieve tensors saved during the forward
pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
of booleans representing whether each input needs gradient. E.g.,
:func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
first input to :func:`forward` needs gradient computed w.r.t. the
output.
"""
raise NotImplementedError(
"You must implement either the backward or vjp method for "
"your custom autograd.Function to use it with backward "
"mode AD."
)
# vjp and backward are alias of each other
vjp = backward
@staticmethod
def jvp(ctx: Any, *grad_inputs: Any) -> Any:
r"""Define a formula for differentiating the operation with forward mode automatic differentiation.
This function is to be overridden by all subclasses.
It must accept a context :attr:`ctx` as the first argument, followed by
as many inputs as the :func:`forward` got (None will be passed in
for non tensor inputs of the forward function),
and it should return as many tensors as there were outputs to
:func:`forward`. Each argument is the gradient w.r.t the given input,
and each returned value should be the gradient w.r.t. the
corresponding output. If an output is not a Tensor or the function is not
differentiable with respect to that output, you can just pass None as a
gradient for that output.
You can use the :attr:`ctx` object to pass any value from the forward to this
function.
"""
raise NotImplementedError(
"You must implement the jvp function for custom "
"autograd.Function to use it with forward mode AD."
)
class Function(_SingleLevelFunction):
r"""Base class to create custom `autograd.Function`.
To create a custom `autograd.Function`, subclass this class and implement
the :meth:`forward` and :meth:`backward` static methods. Then, to use your custom
op in the forward pass, call the class method ``apply``. Do not call
:meth:`forward` directly.
To ensure correctness and best performance, make sure you are calling the
correct methods on ``ctx`` and validating your backward function using
:func:`torch.autograd.gradcheck`.
See :ref:`extending-autograd` for more details on how to use this class.
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> class Exp(Function):
>>> @staticmethod
>>> def forward(ctx, i):
>>> result = i.exp()
>>> ctx.save_for_backward(result)
>>> return result
>>>
>>> @staticmethod
>>> def backward(ctx, grad_output):
>>> result, = ctx.saved_tensors
>>> return grad_output * result
>>>
>>> # Use it by calling the apply method:
>>> # xdoctest: +SKIP
>>> output = Exp.apply(input)
"""
def __init__(self, *args, **kwargs):
warnings.warn(
f"{self.__class__} should not be instantiated. Methods on autograd functions"
"are all static, so you should invoke them on the class itself. "
"Instantiating an autograd function will raise an "
"error in a future version of PyTorch.",
DeprecationWarning,
stacklevel=2,
)
def __call__(self, *args, **kwargs):
raise RuntimeError(
"Legacy autograd function with non-static forward method is deprecated. "
"Please use new-style autograd function with static forward method. "
"(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)"
)
"""
Bool that specifies if PyTorch should attempt to autogenerate
:func:`torch.vmap` support for this autograd.Function. You may set this to
True only if this autograd.Function's forward, backward, and jvp (if they
exist) are written using PyTorch operations; otherwise, please override
:meth:`torch.autograd.Function.vmap` to add support for :func:`torch.vmap`.
Please see :ref:`func-autograd-function` for more details.
"""
generate_vmap_rule = False
@staticmethod
def vmap(info, in_dims, *args):
r"""Define the behavior for this autograd.Function underneath :func:`torch.vmap`.
For a :func:`torch.autograd.Function` to support
:func:`torch.vmap`, you must either override this static method, or set
``generate_vmap_rule`` to ``True`` (you may not do both).
If you choose to override this staticmethod: it must accept
- an ``info`` object as the first argument. ``info.batch_size``
specifies the size of the dimension being vmapped over,
while ``info.randomness`` is the randomness option passed to
:func:`torch.vmap`.
- an ``in_dims`` tuple as the second argument.
For each arg in ``args``, ``in_dims`` has a corresponding
``Optional[int]``. It is ``None`` if the arg is not a Tensor or if
the arg is not being vmapped over, otherwise, it is an integer
specifying what dimension of the Tensor is being vmapped over.
- ``*args``, which is the same as the args to :meth:`~Function.forward`.
The return of the vmap staticmethod is a tuple of ``(output, out_dims)``.
Similar to ``in_dims``, ``out_dims`` should be of the same structure as
``output`` and contain one ``out_dim`` per output that specifies if the
output has the vmapped dimension and what index it is in.
Please see :ref:`func-autograd-function` for more details.
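A minimal sketch (illustrative; ``ScaleBy2`` is a hypothetical element-wise Function)::
>>> # xdoctest: +SKIP
>>> class ScaleBy2(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(x):
>>>         return x * 2
>>>
>>>     @staticmethod
>>>     def setup_context(ctx, inputs, output):
>>>         pass
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_out):
>>>         return grad_out * 2
>>>
>>>     @staticmethod
>>>     def vmap(info, in_dims, x):
>>>         # The op is element-wise, so the batched tensor can be used
>>>         # directly and the batch dim stays where it already was.
>>>         return x * 2, in_dims[0]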
"""
raise NotImplementedError(
"To use autograd.Function with vmap, you must either override the "
"vmap staticmethod or set generate_vmap_rule=True."
)
@classmethod
def apply(cls, *args, **kwargs):
def bind_default_args(func, *args, **kwargs):
signature = inspect.signature(func)
bound_args = signature.bind(*args, **kwargs)
bound_args.apply_defaults()
return bound_args.args
is_setup_ctx_defined = _is_setup_context_defined(cls.setup_context)
if is_setup_ctx_defined:
args = bind_default_args(cls.forward, *args, **kwargs)
if not torch._C._are_functorch_transforms_active():
# See NOTE: [functorch vjp and autograd interaction]
args = _functorch.utils.unwrap_dead_wrappers(args)
return super().apply(*args, **kwargs) # type: ignore[misc]
if not is_setup_ctx_defined:
raise RuntimeError(
"In order to use an autograd.Function with functorch transforms "
"(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
"staticmethod. For more details, please see "
"https://pytorch.org/docs/main/notes/extending.func.html"
)
return custom_function_call(cls, *args, **kwargs)
@staticmethod
def _compiled_autograd_key(ctx):
return (ctx._autograd_function_id,)
def _is_setup_context_defined(fn):
return fn != _SingleLevelFunction.setup_context
def once_differentiable(fn):
@functools.wraps(fn)
def wrapper(ctx, *args):
with torch.no_grad():
outputs = fn(ctx, *args)
if not torch.is_grad_enabled():
return outputs
# If any of the inputs have requires_grad=True, we force the outputs
# to have requires_grad=True but point to a grad_fn which throws an
# error message during (double) back-propagation.
# XXX: this is only an approximation of requires_grad - there's no way
# to figure out if fn didn't use ctx.saved_tensors and as a result
# some Tensors might require grad, even if no args do.
# Unfortunately, this leads to unexpected error messages ("no nodes
# require computing gradients"), but I don't have a better idea.
# These functions would raise an error in backward anyway.
requires_grad = any(
isinstance(arg, torch.Tensor) and arg.requires_grad for arg in args
)
if not requires_grad:
return outputs
if not isinstance(outputs, tuple):
outputs = (outputs,)
err_fn = _functions.DelayedError(
b"trying to differentiate twice a function that was marked "
b"with @once_differentiable",
len(outputs),
)
# Create aliases of each output that has requires_grad=True. We need
# at least one of the inputs to err_fn to require grad so that the
# output will have a grad_fn.
def fake_requires_grad(var):
if var is not None:
var = var.detach()
var.requires_grad = True
return var
return err_fn(*[fake_requires_grad(v) for v in outputs])
return wrapper
class InplaceFunction(Function):
r"""
This class is here only for backward compatibility reasons.
Use :class:`Function` instead of this for any new use case.
"""
def __init__(self, inplace=False):
super().__init__()
self.inplace = inplace
def _nested_map(condition, fn, condition_msg=None):
def _map(obj):
if condition(obj):
return fn(obj)
elif obj is None:
return None
elif isinstance(obj, (list, tuple)):
mapped = (_map(x) for x in obj)
if hasattr(obj, "_fields"):
# obj is namedtuple
return type(obj)(*mapped)
return type(obj)(mapped)
elif isinstance(obj, dict):
return {x: _map(obj[x]) for x in obj}
else:
raise ValueError(
"Auto nesting doesn't know how to process "
"an input object of type "
+ torch.typename(obj)
+ (
". Accepted types: " + condition_msg + ", or lists/tuples of them"
if condition_msg
else ""
)
)
return _map
def _jit_unwrap_structured(obj):
if hasattr(obj, "_jit_unwrap"):
return obj._jit_unwrap()
return obj
def _iter_filter(condition, allow_unknown=False, condition_msg=None, conversion=None):
def _iter(obj):
if conversion is not None:
obj = conversion(obj)
if condition(obj):
yield obj
elif obj is None:
return
elif isinstance(obj, (list, tuple)):
for o in obj:
yield from _iter(o)
elif isinstance(obj, dict):
# We only accept primitive key types, so we needn't inspect them
for o in obj.values():
yield from _iter(o)
elif allow_unknown:
yield obj
else:
raise ValueError(
"Auto nesting doesn't know how to process "
"an input object of type "
+ torch.typename(obj)
+ (
". Accepted types: " + condition_msg + ", or lists/tuples of them"
if condition_msg
else ""
)
)
return _iter
def _unflatten(input, proto):
# unflatten a list or tuple input into a nested list/tuple structure
# specified by proto
def unflatten_helper(input, proto):
res: List[Optional[torch.Tensor]] = []
if hasattr(proto, "_jit_wrap"):
return proto._jit_wrap(input)
if not isinstance(proto, (list, tuple)):
return input[0], input[1:]
for e in proto:
if e is None:
res.append(e)
else:
res_e, input = unflatten_helper(input, e)
res.append(res_e)
return type(proto)(res), input
return unflatten_helper(input, proto)[0]
_iter_jit_values = _iter_filter(
lambda o: o is None or isinstance(o, torch._C.Value),
condition_msg="jit's Values or None",
)
_iter_tensors = _iter_filter(
lambda x: isinstance(x, torch.Tensor),
condition_msg="Tensors",
conversion=_jit_unwrap_structured,
)
_iter_tensors_permissive = _iter_filter(
lambda x: isinstance(x, torch.Tensor),
allow_unknown=True,
condition_msg="Tensors (permissive)",
)
_iter_None_tensors = _iter_filter(
lambda o: o is None or isinstance(o, torch.Tensor), condition_msg="Tensors or None"
)
_map_tensor_data = _nested_map(
lambda x: isinstance(x, torch.Tensor), lambda o: o.data, condition_msg="Tensors"
)
class NestedIOFunction(Function):
r"""
This class is here only for backward compatibility reasons.
Use :class:`Function` instead of this for any new use case.
"""
# The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
# superclass (Function) but are instance methods here, which mypy reports as incompatible.
def _do_forward(self, *input):
self._nested_input = input
flat_input = tuple(_iter_tensors(input))
flat_output = super()._do_forward(*flat_input) # type: ignore[misc]
nested_output = self._nested_output
nested_tensors = _unflatten(flat_output, nested_output)
return nested_tensors
def _do_backward(self, gradients, retain_variables):
self.retain_variables = retain_variables
result = super()._do_backward(gradients, retain_variables) # type: ignore[misc]
if not retain_variables:
del self._nested_output
del self._to_save_nested
return result
def backward(self, *gradients: Any) -> Any: # type: ignore[override]
r"""
Shared backward utility.
"""
nested_gradients = _unflatten(gradients, self._nested_output)
result = self.backward_extended(*nested_gradients) # type: ignore[func-returns-value]
return tuple(_iter_None_tensors(result))
__call__ = _do_forward
def forward(self, *args: Any) -> Any: # type: ignore[override]
r"""
Shared forward utility.
"""
nested_tensors = _map_tensor_data(self._nested_input)
result = self.forward_extended(*nested_tensors) # type: ignore[func-returns-value]
del self._nested_input
self._nested_output = result
return tuple(_iter_tensors(result))
def save_for_backward(self, *args: Any) -> None:
r"""
See :meth:`Function.save_for_backward`.
"""
self.to_save = tuple(_iter_tensors(args))
self._to_save_nested = args
@property
def saved_tensors(self):
r"""
See :meth:`Function.saved_tensors`.
"""
flat_tensors = super().saved_tensors # type: ignore[misc]
return _unflatten(flat_tensors, self._to_save_nested)
def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
r"""
See :meth:`Function.mark_dirty`.
"""
self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))
def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
r"""
See :meth:`Function.mark_non_differentiable`.
"""
self.non_differentiable = tuple(_iter_tensors((args, kwargs)))
def forward_extended(self, *input: Any) -> None:
r"""
User defined forward.
"""
raise NotImplementedError
def backward_extended(self, *grad_output: Any) -> None:
r"""
User defined backward.
"""
raise NotImplementedError

File diff suppressed because it is too large

@ -0,0 +1,397 @@
# mypy: allow-untyped-defs
from typing import Any
import torch
from torch.utils._contextlib import (
_DecoratorContextManager,
_NoParamDecoratorContextManager,
F,
)
__all__ = [
"no_grad",
"enable_grad",
"set_grad_enabled",
"inference_mode",
"set_multithreading_enabled",
]
class no_grad(_NoParamDecoratorContextManager):
r"""Context-manager that disables gradient calculation.
Disabling gradient calculation is useful for inference, when you are sure
that you will not call :meth:`Tensor.backward()`. It will reduce memory
consumption for computations that would otherwise have `requires_grad=True`.
In this mode, the result of every computation will have
`requires_grad=False`, even when the inputs have `requires_grad=True`.
There is an exception! All factory functions, or functions that create
a new Tensor and take a requires_grad kwarg, will NOT be affected by
this mode.
This context manager is thread local; it will not affect computation
in other threads.
Also functions as a decorator.
.. note::
No-grad is one of several mechanisms that can enable or
disable gradients locally; see :ref:`locally-disable-grad-doc` for
more information on how they compare.
.. note::
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
If you want to disable forward AD for a computation, you can unpack
your dual tensors.
Example::
>>> # xdoctest: +SKIP
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
False
>>> @torch.no_grad()
... def doubler(x):
... return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False
>>> @torch.no_grad()
... def tripler(x):
... return x * 3
>>> z = tripler(x)
>>> z.requires_grad
False
>>> # factory function exception
>>> with torch.no_grad():
... a = torch.nn.Parameter(torch.rand(10))
>>> a.requires_grad
True
"""
def __init__(self) -> None:
if not torch._jit_internal.is_scripting():
super().__init__()
self.prev = False
def __enter__(self) -> None:
self.prev = torch.is_grad_enabled()
torch.set_grad_enabled(False)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch.set_grad_enabled(self.prev)
class enable_grad(_NoParamDecoratorContextManager):
r"""Context-manager that enables gradient calculation.
Enables gradient calculation, if it has been disabled via :class:`~no_grad`
or :class:`~set_grad_enabled`.
This context manager is thread local; it will not affect computation
in other threads.
Also functions as a decorator.
.. note::
enable_grad is one of several mechanisms that can enable or
disable gradients locally; see :ref:`locally-disable-grad-doc` for
more information on how they compare.
.. note::
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
Example::
>>> # xdoctest: +SKIP
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
... with torch.enable_grad():
... y = x * 2
>>> y.requires_grad
True
>>> y.backward()
>>> x.grad
tensor([2.])
>>> @torch.enable_grad()
... def doubler(x):
... return x * 2
>>> with torch.no_grad():
... z = doubler(x)
>>> z.requires_grad
True
>>> @torch.enable_grad()
... def tripler(x):
... return x * 3
>>> with torch.no_grad():
... z = tripler(x)
>>> z.requires_grad
True
"""
def __enter__(self) -> None:
self.prev = torch.is_grad_enabled()
torch._C._set_grad_enabled(True)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch._C._set_grad_enabled(self.prev)
class set_grad_enabled(_DecoratorContextManager):
r"""Context-manager that sets gradient calculation on or off.
``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
It can be used as a context-manager or as a function.
This context manager is thread local; it will not affect computation
in other threads.
Args:
mode (bool): Flag whether to enable grad (``True``), or disable
(``False``). This can be used to conditionally enable
gradients.
.. note::
set_grad_enabled is one of several mechanisms that can enable or
disable gradients locally; see :ref:`locally-disable-grad-doc` for
more information on how they compare.
.. note::
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
Example::
>>> # xdoctest: +SKIP
>>> x = torch.tensor([1.], requires_grad=True)
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
... y = x * 2
>>> y.requires_grad
False
>>> _ = torch.set_grad_enabled(True)
>>> y = x * 2
>>> y.requires_grad
True
>>> _ = torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False
"""
def __init__(self, mode: bool) -> None:
self.prev = torch.is_grad_enabled()
self.mode = mode
torch._C._set_grad_enabled(mode)
def __call__(self, orig_func: F) -> F:
torch._C._set_grad_enabled(self.prev)
return super().__call__(orig_func)
def __enter__(self) -> None:
torch._C._set_grad_enabled(self.mode)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch._C._set_grad_enabled(self.prev)
def clone(self) -> "set_grad_enabled":
r"""
Create a copy of this class
"""
return self.__class__(self.mode)
class inference_mode(_DecoratorContextManager):
r"""Context-manager that enables or disables inference mode.
InferenceMode is a context manager analogous to :class:`~no_grad`
to be used when you are certain your operations will have no interactions
with autograd (e.g., model evaluation rather than training). Code run under this
mode gets better performance by disabling view tracking and version counter
bumps. Note that, unlike some other mechanisms that locally enable or disable
grad, entering inference_mode also disables :ref:`forward-mode AD <forward-mode-ad>`.
This context manager is thread local; it will not affect computation
in other threads.
Also functions as a decorator.
.. note::
Inference mode is one of several mechanisms that can enable or
disable gradients locally; see :ref:`locally-disable-grad-doc` for
more information on how they compare.
Args:
mode (bool or function): Either a boolean flag whether to enable or
disable inference mode or a Python function to decorate with
inference mode enabled
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> import torch
>>> x = torch.ones(1, 2, 3, requires_grad=True)
>>> with torch.inference_mode():
... y = x * x
>>> y.requires_grad
False
>>> # xdoctest: +SKIP("want string isnt quite right")
>>> y._version
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Inference tensors do not track version counter.
>>> @torch.inference_mode()
... def func(x):
... return x * x
>>> out = func(x)
>>> out.requires_grad
False
>>> @torch.inference_mode()
... def doubler(x):
... return x * 2
>>> out = doubler(x)
>>> out.requires_grad
False
"""
def __init__(self, mode: bool = True) -> None:
if not torch._jit_internal.is_scripting():
super().__init__()
self.mode = mode
def __new__(cls, mode=True):
if isinstance(mode, bool):
return super().__new__(cls)
return cls()(mode)
def __enter__(self) -> None:
self._inference_mode_context = torch._C._InferenceMode(self.mode)
self._inference_mode_context.__enter__()
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
def clone(self) -> "inference_mode":
r"""
Create a copy of this class
"""
return self.__class__(self.mode)
def _enter_inference_mode(mode):
mode_context = torch._C._InferenceMode(mode)
mode_context.__enter__()
return mode_context
def _exit_inference_mode(mode):
mode.__exit__(None, None, None)
class set_multithreading_enabled(_DecoratorContextManager):
r"""Context-manager that sets multithreaded backwards on or off.
``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`.
It can be used as a context-manager or as a function.
This context manager is thread local; it will not affect computation
in other threads.
Args:
mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable
(``False``).
.. note::
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
"""
def __init__(self, mode: bool) -> None:
self.prev = torch._C._is_multithreading_enabled()
torch._C._set_multithreading_enabled(mode)
self.mode = mode
def __enter__(self) -> None:
pass
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch._C._set_multithreading_enabled(self.prev)
def clone(self) -> "set_multithreading_enabled":
r"""
Create a copy of this class
"""
return self.__class__(self.mode)
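# A brief usage sketch (``loss`` is an assumed scalar tensor): temporarily run
# the autograd engine single-threaded for one backward call.
#
# >>> with torch.autograd.set_multithreading_enabled(False):
# ...     loss.backward()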
class _force_original_view_tracking(_DecoratorContextManager):
r"""Context-manager that sets whether or not to always enable view-replay in autograd.
``set_view_replay_enabled`` will enable or disable view-replay based on its argument :attr:`mode`.
It can be used as a context-manager or as a function.
This context manager is thread local; it will not affect computation
in other threads.
When a tensor view is mutated, the autograd engine needs to decide whether or not
to regenerate the "updated view" by either replaying the chain of views from the updated base,
or with a single call to as_strided.
If set_view_replay_enabled is set to True, then autograd will always use view replay.
Otherwise, it will fall back to its existing logic.
Args:
mode (bool): Flag whether to enable view-replay (``True``), or disable
(``False``).
"""
def __init__(self, mode: bool) -> None:
self.prev = torch._C._is_view_replay_enabled()
torch._C._set_view_replay_enabled(mode)
self.mode = mode
def __enter__(self) -> None:
pass
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch._C._set_view_replay_enabled(self.prev)
def clone(self):
return self.__class__(self.mode)
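# An illustrative sketch of forcing view replay around an in-place update of a
# view (internal API; the tensor names are assumptions):
#
# >>> base = torch.randn(4, requires_grad=True).clone()
# >>> view = base[:2]
# >>> with torch.autograd._force_original_view_tracking(True):
# ...     view.mul_(2)  # the updated view's grad_fn is rebuilt via view replay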
class _unsafe_preserve_version_counter(_DecoratorContextManager):
r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING.
This context manager can lead to arbitrary silent-correctness issues in any other part of your code
(even the ones not touched directly by the context manager)!
Ordinarily, autograd will track mutations to tensors by incrementing their `._version` attribute.
This is generally important for correctness, as for example, mutating a tensor that autograd has saved
for the backwards pass can result in incorrect gradients, and autograd uses the version counter to detect
and error out in this situation.
However, there are rare instances where it might be useful to hide mutations from autograd: for example,
if a tensor is very large, you might want to free its memory by storing its contents elsewhere and re-populate
the tensor right before it is needed by autograd.
Args:
tensor (torch.Tensor): the tensor in question, that you would like to preserve the version counter of.
.. note::
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
"""
def __init__(self, tensor: torch.Tensor) -> None:
self.tensor = tensor
self.prev_version = tensor._version
def __enter__(self) -> None:
pass
def __exit__(self, *args) -> None:
torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)
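# A hedged sketch of the intended use described in the docstring above
# (the tensor and the repopulation step are assumptions, not library code):
#
# >>> t = torch.randn(8)
# >>> stash = t.detach().clone()   # keep the contents elsewhere
# >>> with torch.autograd._unsafe_preserve_version_counter(t):
# ...     t.copy_(stash)           # repopulate; the version counter is restored on exit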

File diff suppressed because it is too large

@ -0,0 +1,830 @@
import abc
import contextlib
import functools
import logging
import threading
from collections import defaultdict, deque
from typing import (
Any,
Callable,
cast,
Deque,
Dict,
Generator,
Iterable,
Iterator,
List,
Literal,
MutableMapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
from typing_extensions import TypeAlias
from weakref import WeakKeyDictionary, WeakValueDictionary
import torch
from torch.autograd.variable import Variable
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle
if TYPE_CHECKING:
from torch._ops import OpOverload
__all__ = [
"saved_tensors_hooks",
"save_on_cpu",
"disable_saved_tensors_hooks",
"register_multi_grad_hook",
"allow_mutation_on_saved_tensors",
"Node",
"GradientEdge",
"get_gradient_edge",
"increment_version",
]
log = logging.getLogger(__name__)
class Node(abc.ABC):
@abc.abstractmethod
def name(self) -> str:
r"""Return the name.
Example::
>>> import torch
>>> a = torch.tensor([0., 0., 0.], requires_grad=True)
>>> b = a.clone()
>>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
>>> print(b.grad_fn.name())
CloneBackward0
"""
raise NotImplementedError
@property
@abc.abstractmethod
def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
raise NotImplementedError
@abc.abstractmethod
def metadata(self) -> dict:
r"""Return the metadata."""
raise NotImplementedError
@property
@abc.abstractmethod
def _input_metadata(self) -> List[Any]:
raise NotImplementedError
@abc.abstractmethod
def _register_hook_dict(self, tensor: torch.Tensor) -> None:
raise NotImplementedError
@abc.abstractmethod
def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
r"""Register a backward hook.
The hook will be called every time a gradient with respect to the
Node is computed. The hook should have the following signature::
hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None
The hook should not modify its argument, but it can optionally return
a new gradient which will be used in place of :attr:`grad_inputs`.
This function returns a handle with a method ``handle.remove()``
that removes the hook from the node.
.. note::
See :ref:`backward-hooks-execution` for more information on when this hook
is executed, and how its execution is ordered relative to other hooks.
.. note::
In the rare case where the hook is registered while the Node has already
begun execution, there is no longer any guarantee on :attr:`grad_outputs`
content (it might be as usual or empty depending on other factors). The
hook can still optionally return a new gradient to be used in place of
:attr:`grad_inputs` independent of :attr:`grad_outputs`.
Example::
>>> import torch
>>> a = torch.tensor([0., 0., 0.], requires_grad=True)
>>> b = a.clone()
>>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
>>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
>>> b.sum().backward(retain_graph=True)
>>> print(a.grad)
tensor([2., 2., 2.])
>>> handle.remove() # Removes the hook
>>> a.grad = None
>>> b.sum().backward(retain_graph=True)
>>> print(a.grad)
tensor([1., 1., 1.])
"""
raise NotImplementedError
@abc.abstractmethod
def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
r"""Register a backward pre-hook.
The hook will be called every time a gradient with respect to the
Node is computed. The hook should have the following signature::
hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None
The hook should not modify its argument, but it can optionally return
a new gradient which will be used in place of :attr:`grad_outputs`.
This function returns a handle with a method ``handle.remove()``
that removes the hook from the node.
.. note::
See :ref:`backward-hooks-execution` for more information on when this hook
is executed, and how its execution is ordered relative to other hooks.
Example::
>>> a = torch.tensor([0., 0., 0.], requires_grad=True)
>>> b = a.clone()
>>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
>>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
>>> b.sum().backward(retain_graph=True)
>>> print(a.grad)
tensor([2., 2., 2.])
>>> handle.remove()
>>> a.grad = None
>>> b.sum().backward(retain_graph=True)
>>> print(a.grad)
tensor([1., 1., 1.])
"""
raise NotImplementedError
@classmethod
def __subclasshook__(cls, subclass: type) -> bool:
if cls is Node and (
(
subclass is not None
and subclass is getattr(torch._C._functions, subclass.__name__, None)
)
or issubclass(subclass, torch.autograd.function.BackwardCFunction)
):
return True
return NotImplemented
def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, "GradientEdge"]) -> Node:
if isinstance(t, GradientEdge):
return t.node
if t.requires_grad and t.grad_fn is None:
with torch.enable_grad():
node = t.view_as(t).grad_fn.next_functions[0][0] # type: ignore[union-attr]
else:
node = t.grad_fn
assert node is not None
return node
class GradientEdge(NamedTuple):
"""Object representing a given gradient edge within the autograd graph.
To get the gradient edge where a given Tensor gradient will be computed,
you can do ``edge = autograd.graph.get_gradient_edge(tensor)``.
"""
node: Node
output_nr: int
def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge:
"""Get the gradient edge for computing the gradient of the given Tensor.
In particular, calling ``g = autograd.grad(loss, input)`` is equivalent to calling
``g = autograd.grad(loss, get_gradient_edge(input))``.
"""
if not tensor.requires_grad:
raise RuntimeError(
"It is not possible to get the gradient edge for a Tensor "
"that does not require gradients",
)
grad_fn = _get_grad_fn_or_grad_acc(tensor)
# Note that output_nr defaults to 0, which is the right value
# for the AccumulateGrad node.
return GradientEdge(grad_fn, tensor.output_nr)
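# An illustrative use of the equivalence stated in the docstring above
# (the tensor and the loss are assumptions):
#
# >>> x = torch.randn(3, requires_grad=True)
# >>> edge = get_gradient_edge(x)
# >>> loss = (x * 2).sum()
# >>> (gx,) = torch.autograd.grad(loss, edge)  # same result as torch.autograd.grad(loss, x)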
def increment_version(tensor: Union[torch.Tensor, Iterable[torch.Tensor]]) -> None:
"""Update autograd metadata tracking whether the given Tensor was modified in place.
This is to enable more accurate error checking within the autograd engine.
This is already done automatically by PyTorch functions and within custom Functions
when mark_dirty() is called appropriately, so you only need to call this explicitly
if you are performing an in-place operation on the Tensor data in a way that PyTorch doesn't
know about: for example, a custom kernel that reads the Tensor's data_ptr and modifies
the memory in place based on this pointer. This function accepts either a single tensor or a list of tensors.
Note that incrementing the version counter multiple times for a single in-place operation
is not problematic.
Note that if you pass in a tensor constructed under torch.inference_mode(),
we will not bump its version counter (because your tensor does not have one).
"""
if isinstance(tensor, torch.Tensor):
tensor = (tensor,)
torch._C._increment_version(tensor)
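# A short usage sketch (the external write is hypothetical): after a custom
# kernel mutates the tensor's memory behind autograd's back, bump the version
# counter manually so saved-tensor checks stay accurate.
#
# >>> t = torch.ones(4)
# >>> # ... an external kernel writes into t.data_ptr() here ...
# >>> torch.autograd.graph.increment_version(t)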
class saved_tensors_hooks:
"""Context-manager that sets a pair of pack / unpack hooks for saved tensors.
Use this context-manager to define how intermediary results of an operation
should be packed before saving, and unpacked on retrieval.
In that context, the ``pack_hook`` function will be called every time an
operation saves a tensor for backward (this includes intermediary results
saved using
:func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
also those recorded by a PyTorch-defined operation). The output of
``pack_hook`` is then stored in the computation graph instead of the
original tensor.
The ``unpack_hook`` is called when the saved tensor needs to be accessed,
namely when executing :func:`torch.Tensor.backward()` or
:func:`torch.autograd.grad()`. It takes as argument the *packed* object
returned by ``pack_hook`` and should return a tensor which has the same
content as the original tensor (passed as input to the corresponding
``pack_hook``).
The hooks should have the following signatures:
pack_hook(tensor: Tensor) -> Any
unpack_hook(Any) -> Tensor
where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.
In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
of value, size, dtype and device.
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> def pack_hook(x):
... print("Packing", x)
... return x
>>>
>>> def unpack_hook(x):
... print("Unpacking", x)
... return x
>>>
>>> a = torch.ones(5, requires_grad=True)
>>> b = torch.ones(5, requires_grad=True) * 2
>>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
... y = a * b
Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
>>> y.sum().backward()
Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
.. warning ::
Performing an in-place operation on the input to either hook may lead
to undefined behavior.
.. warning ::
Only one pair of hooks is allowed at a time. When recursively nesting this
context-manager, only the inner-most pair of hooks will be applied.
"""
def __init__(
self,
pack_hook: Callable[[torch.Tensor], Any],
unpack_hook: Callable[[Any], torch.Tensor],
) -> None:
self.pack_hook = pack_hook
self.unpack_hook = unpack_hook
def __enter__(self) -> None:
torch._C._autograd._push_saved_tensors_default_hooks(
self.pack_hook, self.unpack_hook
)
def __exit__(self, *args: object) -> None:
torch._C._autograd._pop_saved_tensors_default_hooks()
class save_on_cpu(saved_tensors_hooks):
"""Context manager under which tensors saved by the forward pass will be stored on cpu, then retrieved for backward.
When performing operations within this context manager, intermediary
results saved in the graph during the forward pass will be moved to CPU,
then copied back to the original device when needed for the backward pass.
If the saved tensors were already on CPU, no copy is performed.
Use this context-manager to trade compute for GPU memory usage (e.g.
when your model doesn't fit in GPU memory during training).
Args:
pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
during packing and copied to GPU asynchronously during unpacking.
Defaults to ``False``.
Also see :ref:`cuda-memory-pinning`.
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> a = torch.randn(5, requires_grad=True, device="cuda")
>>> b = torch.randn(5, requires_grad=True, device="cuda")
>>> c = torch.randn(5, requires_grad=True, device="cuda")
>>>
>>> def f(a, b, c):
... prod_1 = a * b # a and b are saved on GPU
... with torch.autograd.graph.save_on_cpu():
... prod_2 = prod_1 * c # prod_1 and c are saved on CPU
... y = prod_2 * a # prod_2 and a are saved on GPU
... return y
>>>
>>> y = f(a, b, c)
>>> del a, b, c # for illustration only
>>> # the content of a, b, and prod_2 are still alive on GPU
>>> # the content of prod_1 and c only live on CPU
>>> y.sum().backward() # all CPU tensors are moved back to GPU, for backward
>>> # all intermediary tensors are released (deleted) after the call to backward
"""
def __init__(self, pin_memory: bool = False, device_type: str = "cuda") -> None:
device_module = getattr(torch, device_type, torch.cuda)
def pack_to_cpu(tensor: torch.Tensor) -> Tuple[torch.device, torch.Tensor]:
if not pin_memory:
return (tensor.device, tensor.cpu())
packed = torch.empty(
tensor.size(),
dtype=tensor.dtype,
layout=tensor.layout,
pin_memory=(device_module.is_available() and not tensor.is_sparse),
)
packed.copy_(tensor)
return (tensor.device, packed)
def unpack_from_cpu(packed: Tuple[torch.device, torch.Tensor]) -> torch.Tensor:
device, tensor = packed
return tensor.to(device, non_blocking=pin_memory)
super().__init__(pack_to_cpu, unpack_from_cpu)
@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message: str) -> Generator[None, None, None]:
"""Context-manager that disables the saved tensors default hooks feature.
Useful if you are creating a feature that does not work with saved
tensors default hooks.
Args:
error_message (str): When saved tensors default hooks are used while they
are disabled, a RuntimeError with this
error message is raised.
Example::
>>> # xdoctest: +SKIP(failing)
>>> message = "saved tensors default hooks are disabled"
>>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
... # Raises RuntimeError: saved tensors default hooks are disabled
... with torch.autograd.graph.save_on_cpu():
... pass
"""
maybe_prev_message = None
try:
maybe_prev_message = (
torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
)
torch._C._autograd._saved_tensors_hooks_disable(error_message)
yield
finally:
# See NOTE: [disabled_error_message invariant]
if maybe_prev_message is None:
torch._C._autograd._saved_tensors_hooks_enable()
else:
torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
class _MultiHandle(RemovableHandle):
handles: Tuple[RemovableHandle, ...]
def __init__(self, handles: Tuple[RemovableHandle, ...]) -> None:
self.handles = handles
def remove(self) -> None:
for handle in self.handles:
handle.remove()
def __getstate__(self) -> Tuple[RemovableHandle, ...]:
return self.handles
def __setstate__(self, state: Tuple[RemovableHandle, ...]) -> None:
self.handles = state
def register_multi_grad_hook(
tensors: Sequence[torch.Tensor],
fn: Union[
Callable[[Sequence[Optional[torch.Tensor]]], None],
Callable[[torch.Tensor], None],
],
*,
mode: Literal["all", "any"] = "all",
) -> RemovableHandle:
r"""Register a multi-grad backward hook.
There are two supported modes: ``"all"`` and ``"any"``.
Under the ``"all"`` mode, the hook will be called after gradients with respect to every tensor in
:attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but
is not part of the graph, or if a tensor is not needed to compute the gradients
for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
this tensor will be ignored and the hook will not wait for its gradient to be
computed.
After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
called with those gradients. ``None`` will be passed for tensors that did not
have their gradients computed.
Under the ``"any"`` mode, the hook will be called after the first gradient
with respect to a tensor in :attr:`tensors` has been computed. The hook
will be called with that gradient as its argument.
The hook should not modify its arguments.
This function returns a handle with a method ``handle.remove()`` that removes the hook.
.. note::
See :ref:`backward-hooks-execution` for more information on when this hook
is executed, and how its execution is ordered relative to other hooks.
Example::
>>> import torch
>>>
>>> a = torch.rand(2, 3, requires_grad=True)
>>> b = torch.rand(2, 3, requires_grad=True)
>>> c = a * b
>>> d = a * b
>>>
>>> def fn(grads):
... print([g is not None for g in grads])
...
>>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
>>>
>>> c.sum().backward(retain_graph=True)
[True, True, True, False]
>>> c.sum().backward(inputs=(a,), retain_graph=True)
[True, False, True, False]
>>>
"""
supported_modes = ("all", "any")
lock = threading.Lock()
if mode not in supported_modes:
raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}")
if mode == "all":
count: Dict[int, int] = {}
nb_calls = None
buffer: Dict[int, List[Optional[torch.Tensor]]] = {}
grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
len_tensors = len(tensors)
def get_inner_hook(idx: int) -> Callable[[torch.Tensor], None]:
def inner_hook(grad: torch.Tensor) -> None:
nonlocal count, nb_calls, buffer, fn
id = torch._C._current_graph_task_id()
assert (
id != -1
), "expected this hook to be called inside a backward call"
count[id] = count.get(id, 0)
buffer[id] = buffer.get(id, [None] * len_tensors)
with lock:
curr_count, count[id] = count[id], count[id] + 1
if curr_count == 0:
# On the first call, compute the actual nb_calls and buffer
nb_calls = sum(
map(torch._C._will_engine_execute_node, grad_fns)
)
buffer[id][idx] = grad
assert nb_calls is not None
if curr_count == nb_calls - 1:
fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn)
fn(buffer[id])
del count[id]
del buffer[id]
return inner_hook
handles = tuple(
t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors)
)
elif mode == "any":
fn = cast(Callable[[torch.Tensor], None], fn)
ran_hook: Dict[int, bool] = defaultdict(bool)
@functools.wraps(fn)
def wrapped_fn(grad: torch.Tensor) -> None:
nonlocal ran_hook
id = torch._C._current_graph_task_id()
assert id != -1, "expected this hook to be called inside a backward call"
with lock:
prev, ran_hook[id] = ran_hook[id], True
if prev:
return
fn(grad)
handles = tuple(
tensor.register_hook(wrapped_fn)
for tensor in tensors
if tensor.requires_grad
)
return _MultiHandle(handles) # type: ignore[possibly-undefined]
# NOTE [Allow mutation on tensors saved for backward]
#
# 1. Tensor gets saved for backward
# - remember the python object id and the version of the tensor
# - remember aliasing information (data_ptr of base + version)
# - save the original so we control its lifetime
# 2. Any time a tensor gets in-placed
# - for each tensor aliased to it:
# - check using its object id and version to see if it has been saved
# - if it has been saved, clone it
# - delete the reference to the original
# 3. during backward
# - if the clone exists, the tensor must've been modified in-place
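# A short illustration of the flow above using the public context manager
# defined later in this file (the tensor names are assumptions):
#
# >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
# ...     a = torch.rand(3, requires_grad=True).clone()
# ...     out = (a ** 2).sum()   # step 1: `a` is saved for backward
# ...     a.sin_()               # step 2: `a` is cloned before the in-place op
# ...     out.backward()         # step 3: backward uses the saved clone of `a`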
_allow_mutation_on_saved_tensors_enabled: bool = False
_TID: TypeAlias = Tuple[int, int, int]
_SID: TypeAlias = Tuple[int, int]
def _get_tid(tensor: torch.Tensor) -> _TID:
# FIXME: This is almost definitely a bug.
if isinstance(
tensor,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
):
data_ptr = 0
else:
data_ptr = tensor.data_ptr()
return (id(tensor), data_ptr, tensor._version)
def _get_sid(tensor: torch.Tensor) -> _SID:
# FIXME: This is almost definitely a bug.
if isinstance(
tensor,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
):
data_ptr = 0
else:
data_ptr = tensor.data_ptr()
return (data_ptr, tensor._version)
class _Handle:
pass
class _swap_with_cloned(saved_tensors_hooks):
def __init__(self, ctx: "_AllowMutationOnSavedContext") -> None:
def pack_hook(tensor: torch.Tensor) -> _Handle:
tid = _get_tid(tensor)
sid = _get_sid(tensor)
# Tensors saved for backward have an entry in _tid_to_weakhandle
handle: Optional[_Handle] = None
# Save aliasing information
ctx.sid_to_tid[sid].add(tid)
# NB: The same tensor (of the same version) can be saved multiple times
if tid not in ctx.tid_to_weakhandle:
handle = _Handle()
ctx.tid_to_weakhandle[tid] = handle
ctx.original[handle] = tensor
else:
# Store an additional strong reference to the handle
handle = ctx.tid_to_weakhandle[tid]
return handle
def unpack_hook(handle: _Handle) -> torch.Tensor:
error_msg = (
"Trying to backward outside of the 'allow_mutation_on_saved_tensors' context"
"in which the graph was originally recorded."
)
assert _allow_mutation_on_saved_tensors_enabled, error_msg
if handle in ctx.cloned:
res = ctx.cloned[handle]
else:
assert handle in ctx.original, error_msg
res = ctx.original[handle]
return res
super().__init__(pack_hook, unpack_hook)
class _CloneArgBeforeMutateMode(TorchDispatchMode):
def __init__(self, ctx: "_AllowMutationOnSavedContext") -> None:
self.ctx = ctx
def __torch_dispatch__(
self,
func: "OpOverload",
types: Iterable[type],
args: Tuple[Any, ...] = (),
kwargs: Optional[Dict[Any, Any]] = None,
) -> Any:
kwargs = kwargs or {}
for idx, arg in enumerate(func._schema.arguments):
if arg.alias_info is not None and arg.alias_info.is_write:
t = kwargs["out"] if arg.is_out else args[idx]
tid = _get_tid(t)
sid = _get_sid(t)
ctx = self.ctx
if sid in ctx.sid_to_tid:
for tid in ctx.sid_to_tid[sid]:
if tid not in ctx.tid_to_weakhandle:
# We know that if tid is in sid_to_tid, then it must also be in
# tid_to_weakhandle. However, it is possible for the tensor to be
# saved at one point, but cleared by backward before it is modified
# in-place. Consider the following example:
#
# >>> a = torch.randn(2, 3, requires_grad=True).clone()
# >>> out = (a**2).sum()
# >>> out.backward()
# >>> a.sin_()
continue
handle = ctx.tid_to_weakhandle[tid]
if handle in ctx.cloned:
# The same exact tensor has been cloned already
continue
ctx.cloned[handle] = ctx.original[handle].clone()
del ctx.original[handle]
return func(*args, **kwargs)
class _AllowMutationOnSavedContext:
def __init__(self) -> None:
self.cloned: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary()
self.original: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary()
self.tid_to_weakhandle: MutableMapping[_TID, _Handle] = WeakValueDictionary()
self.sid_to_tid: Dict[_SID, Set[_TID]] = defaultdict(set)
def clear(self) -> None:
self.cloned.clear()
self.original.clear()
self.tid_to_weakhandle.clear()
self.sid_to_tid.clear()
@contextlib.contextmanager
def allow_mutation_on_saved_tensors() -> (
Generator[_AllowMutationOnSavedContext, None, None]
):
"""Context manager under which mutating tensors saved for backward is allowed.
Under this context manager, tensors saved for backward are cloned on mutation,
so the original version can still be used during backward. Normally, mutating a tensor
saved for backward will result in an error raised when it's used during backward.
To ensure the correct behavior, both the forward and backward should be run under
the same context manager.
Returns:
An _AllowMutationOnSavedContext object storing the state managed by this
context manager. This object can be useful for debugging purposes. The state
managed by the context manager is automatically cleared upon exiting.
Example::
>>> import torch
>>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
... # forward
... a = torch.ones(2, 3, requires_grad=True)
... b = a.clone()
... out = (b**2).sum()
... b.sin_()
... # backward
... out.sum().backward()
...
tensor([[0.8415, 0.8415, 0.8415],
[0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
"""
global _allow_mutation_on_saved_tensors_enabled
ctx = _AllowMutationOnSavedContext()
with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
try:
if _allow_mutation_on_saved_tensors_enabled:
raise RuntimeError(
"allow_mutation_on_saved_tensors contexts cannot be nested"
)
_allow_mutation_on_saved_tensors_enabled = True
yield ctx
finally:
ctx.clear()
_allow_mutation_on_saved_tensors_enabled = False
def _register_logging_hooks_on_whole_graph(
t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
) -> Callable[[], None]:
grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))
def iter_graph(roots: List[Node]) -> Iterator[Node]:
if not roots:
return
seen: Set[Node] = set()
q: Deque[Node] = deque()
for node in roots:
if node is not None:
seen.add(node)
q.append(node)
while q:
node = q.popleft()
for fn, _ in node.next_functions:
if fn in seen or fn is None:
continue
seen.add(fn)
q.append(fn)
yield node
def fmt(t: Optional[torch.Tensor]) -> str:
# Avoid circular import
from torch.testing._internal.common_utils import dtype_abbrs
if t is None:
return "None"
return f"{dtype_abbrs[t.dtype]}[{', '.join(map(str, t.shape))}]"
def prehook(grad_outputs: Sequence[Optional[torch.Tensor]]) -> None:
node = torch._C._current_autograd_node()
grad_outputs_str = f"[{','.join(fmt(t) for t in grad_outputs)}]"
log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
log.debug(log_str)
handles = []
for node in iter_graph(grad_fns):
handles.append(node.register_prehook(prehook))
def unregister_hooks() -> None:
for handle in handles:
handle.remove()
return unregister_hooks
def _engine_run_backward(
t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
*args: Any,
**kwargs: Any,
) -> Tuple[torch.Tensor, ...]:
attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
if attach_logging_hooks:
unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
try:
return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
t_outputs, *args, **kwargs
)
finally:
if attach_logging_hooks:
unregister_hooks() # type: ignore[possibly-undefined]
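# A minimal way to exercise the DEBUG logging path above (logging handler
# configuration is assumed to exist elsewhere): lowering this module's logger
# to DEBUG makes every backward call attach the prehooks and log each executed
# Node with its grad_outputs.
#
# >>> import logging
# >>> logging.getLogger("torch.autograd.graph").setLevel(logging.DEBUG)
# >>> x = torch.randn(3, requires_grad=True)
# >>> (x * 2).sum().backward()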

File diff suppressed because it is too large

@ -0,0 +1,314 @@
# mypy: allow-untyped-defs
import itertools
import warnings
from typing_extensions import deprecated
import torch
import torch.cuda
from torch.autograd import (
_disable_profiler_legacy,
_enable_profiler_legacy,
DeviceType,
ProfilerConfig,
ProfilerState,
)
from torch.autograd.profiler_util import (
_filter_name,
_filter_stack_entry,
_rewrite_name,
EventList,
FunctionEvent,
MEMORY_EVENT_NAME,
)
__all__ = ["profile"]
@deprecated(
"`torch.autograd.profiler_legacy.profile` is deprecated and will be removed in a future release. "
"Please use `torch.profiler` instead.",
category=None, # TODO: change to `FutureWarning`
)
class profile:
"""DEPRECATED: use torch.profiler instead."""
def __init__(
self,
enabled=True,
*,
use_cuda=False,
record_shapes=False,
with_flops=False,
profile_memory=False,
with_stack=False,
with_modules=False,
):
self.enabled: bool = enabled
if not self.enabled:
return
self.use_cuda = use_cuda
self.function_events = None
self.entered = False
self.record_shapes = record_shapes
self.with_flops = with_flops
self.record_shapes |= self.with_flops
self.profile_memory = profile_memory
self.with_stack = with_stack
self.with_modules = with_modules
if self.use_cuda and not torch.cuda.is_available():
warnings.warn(
"CUDA is not available, disabling CUDA profiling",
stacklevel=2,
)
self.use_cuda = False
if self.use_cuda:
self.profiler_kind = ProfilerState.CUDA
else:
self.profiler_kind = ProfilerState.CPU
def config(self):
return ProfilerConfig(
self.profiler_kind,
self.record_shapes,
self.profile_memory,
self.with_stack,
self.with_flops,
self.with_modules,
# avoid exposing _ExperimentalConfig in the legacy public API
torch._C._profiler._ExperimentalConfig(),
)
def __enter__(self):
if not self.enabled:
return
if self.entered:
raise RuntimeError("Profiler context manager is not reentrant")
self.entered = True
self._start_trace()
return self
def _start_trace(self):
_enable_profiler_legacy(self.config())
def __exit__(self, exc_type, exc_val, exc_tb):
if not self.enabled:
return
if self.use_cuda:
torch.cuda.synchronize()
records = _disable_profiler_legacy()
parsed_results = _parse_legacy_records(records)
self.function_events = EventList(
parsed_results,
use_device="cuda" if self.use_cuda else None,
profile_memory=self.profile_memory,
with_flops=self.with_flops,
)
self.function_events._build_tree()
return False
def __repr__(self):
if self.function_events is None:
return "<unfinished profiler_legacy.profile>"
return repr(self.function_events)
def __str__(self):
if self.function_events is None:
return "<unfinished profile.profiler_legacy.profile>"
return str(self.function_events)
def _check_finish(self):
if self.function_events is None:
raise RuntimeError("Profiler didn't finish running")
def table(
self,
sort_by=None,
row_limit=100,
max_src_column_width=75,
max_name_column_width=55,
max_shapes_column_width=80,
header=None,
top_level_events_only=False,
):
self._check_finish()
assert self.function_events is not None
return self.function_events.table(
sort_by=sort_by,
row_limit=row_limit,
max_src_column_width=max_src_column_width,
max_name_column_width=max_name_column_width,
max_shapes_column_width=max_shapes_column_width,
header=header,
top_level_events_only=top_level_events_only,
)
table.__doc__ = EventList.table.__doc__
def export_chrome_trace(self, path):
self._check_finish()
assert self.function_events is not None
return self.function_events.export_chrome_trace(path)
export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__
def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
self._check_finish()
assert self.function_events is not None, "Expected profiling results"
assert self.with_stack, "export_stacks() requires with_stack=True"
return self.function_events.export_stacks(path, metric)
def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
self._check_finish()
assert self.function_events is not None, "Expected profiling results"
return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)
key_averages.__doc__ = EventList.key_averages.__doc__
def total_average(self):
self._check_finish()
assert self.function_events is not None, "Expected profiling results"
return self.function_events.total_average()
total_average.__doc__ = EventList.total_average.__doc__
@property
def self_cpu_time_total(self):
"""Return CPU time as the sum of self times across all events."""
self._check_finish()
assert self.function_events is not None
return self.function_events.self_cpu_time_total
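# A brief usage sketch of this deprecated legacy profiler (shown for
# completeness; prefer torch.profiler as stated in the deprecation notice):
#
# >>> with torch.autograd.profiler_legacy.profile() as prof:
# ...     torch.randn(100, 100) @ torch.randn(100, 100)
# >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))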
def _parse_legacy_records(thread_records):
def _get_record_key(record):
"""Return a tuple for correlating start and end records in `_parse_legacy_records`."""
return (record.handle(), record.node_id())
next_id = 0
start_record = None
functions = []
record_stack = []
# '__start_profile' is not guaranteed to be first, so we must find it here
for record in itertools.chain.from_iterable(thread_records):
name = record.name()
if start_record is None and name == "__start_profile":
start_record = record
assert start_record is not None and not start_record.is_remote()
for thread_record_list in thread_records:
# accumulated memory allocations per handle
cpu_memory_allocs = {}
cuda_memory_allocs = {}
# ranges per handle
range_starts = {}
filtered_handles = set()
prev_record = None
for record in thread_record_list:
record_key = _get_record_key(record)
if _filter_name(record.name()) or record_key in filtered_handles:
filtered_handles.add(record_key)
continue
if record.kind() == "push":
# workaround to reduce double logging from operator
# wrappers and redispatch
if prev_record is not None:
duplicate = (
prev_record.name() == record.name()
and prev_record.kind() == record.kind()
and prev_record.node_id() == record.node_id()
)
if duplicate:
filtered_handles.add(record_key)
continue
range_starts[record_key] = record
cpu_memory_allocs[record_key] = 0
cuda_memory_allocs[record_key] = 0
elif record.kind() == "pop":
assert (
record_key in range_starts
), f"""Expected record with key {record_key} to exist in range_starts.
This means that the pop event did not have a corresponding push."""
start = range_starts[record_key]
cpu_memory_usage = cpu_memory_allocs[record_key]
cuda_memory_usage = cuda_memory_allocs[record_key]
is_async = start.is_async() or (start.thread_id() != record.thread_id())
is_remote_event = record.is_remote()
start_flops = start.flops()
fe = FunctionEvent(
id=record.handle(),
node_id=record.node_id(),
name=_rewrite_name(name=start.name(), with_wildcard=True),
trace_name=_rewrite_name(name=start.name(), with_wildcard=False),
thread=start.thread_id(),
start_us=start_record.cpu_elapsed_us(start),
end_us=start_record.cpu_elapsed_us(record),
fwd_thread=start.fwd_thread_id(),
input_shapes=start.shapes(),
stack=[
entry for entry in start.stack() if _filter_stack_entry(entry)
],
scope=start.scope(),
use_device="cuda" if start.has_cuda() else None,
cpu_memory_usage=cpu_memory_usage,
device_memory_usage=cuda_memory_usage,
is_async=is_async,
is_remote=is_remote_event,
sequence_nr=start.sequence_nr(),
device_type=DeviceType.CPU,
is_legacy=True,
flops=start_flops,
)
# note: async events have only cpu total time
if not is_async and start.has_cuda():
duration = start.cuda_elapsed_us(record)
if duration > 0:
fe.append_kernel(start.name(), start.device(), duration)
functions.append(fe)
del range_starts[record_key]
del cpu_memory_allocs[record_key]
del cuda_memory_allocs[record_key]
elif record.kind() == "memory_alloc":
num_open_handles_cpu = len(cpu_memory_allocs)
num_open_handles_cuda = len(cuda_memory_allocs)
assert num_open_handles_cpu == num_open_handles_cuda
for handle in cpu_memory_allocs.keys():
cpu_memory_allocs[handle] += record.cpu_memory_usage()
for handle in cuda_memory_allocs.keys():
cuda_memory_allocs[handle] += record.cuda_memory_usage()
if num_open_handles_cpu == 0:
# output event as a top-level memory event
fe = FunctionEvent(
id=0,
name=MEMORY_EVENT_NAME,
trace_name=None,
thread=0,
start_us=0,
end_us=0,
stack=[],
cpu_memory_usage=record.cpu_memory_usage(),
device_memory_usage=record.cuda_memory_usage(),
is_legacy=True,
)
functions.append(fe)
prev_record = record
# Sort functions by start time then by end time ascending.
# This ensures that--in the case of nested events which
# have the same start time (which may happen due to the
# granularity of the given clock tick)--we always show
# the outermost nested call first. This adds stability
# in how FunctionEvents appear
functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
return functions

File diff suppressed because it is too large

@ -0,0 +1,15 @@
# mypy: allow-untyped-defs
import torch
from torch._C import _ImperativeEngine as ImperativeEngine
__all__ = ["VariableMeta", "Variable"]
class VariableMeta(type):
def __instancecheck__(cls, other):
return isinstance(other, torch.Tensor)
class Variable(torch._C._LegacyVariableBase, metaclass=VariableMeta): # type: ignore[misc]
_execution_engine = ImperativeEngine()