I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,5 @@
from torch._C import FileCheck as FileCheck
from . import _utils
from ._comparison import assert_allclose, assert_close as assert_close
from ._creation import make_tensor as make_tensor
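For context, a minimal usage sketch of the two entry points re-exported above, assert_close and make_tensor; the shape and dtype below are arbitrary and a CPU device is assumed.

import torch
from torch.testing import assert_close, make_tensor

# Draw a float32 tensor from the default value range documented in the creation module.
a = make_tensor((2, 3), dtype=torch.float32, device="cpu")
# assert_close raises an AssertionError if the two tensors differ beyond
# dtype-dependent tolerances; identical tensors always pass.
assert_close(a, a.clone())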

File diff suppressed because it is too large


@@ -0,0 +1,269 @@
"""
This module contains tensor creation utilities.
"""
import collections.abc
import math
import warnings
from typing import cast, List, Optional, Tuple, Union
import torch
_INTEGRAL_TYPES = [
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint16,
torch.uint32,
torch.uint64,
]
_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
_FLOATING_8BIT_TYPES = [
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]
_COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128]
_BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES]
_FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES]
def _uniform_random_(t: torch.Tensor, low: float, high: float) -> torch.Tensor:
# uniform_ requires to-from <= std::numeric_limits<scalar_t>::max()
# Work around this by scaling the range before and after the PRNG
if high - low >= torch.finfo(t.dtype).max:
return t.uniform_(low / 2, high / 2).mul_(2)
else:
return t.uniform_(low, high)
def make_tensor(
*shape: Union[int, torch.Size, List[int], Tuple[int, ...]],
dtype: torch.dtype,
device: Union[str, torch.device],
low: Optional[float] = None,
high: Optional[float] = None,
requires_grad: bool = False,
noncontiguous: bool = False,
exclude_zero: bool = False,
memory_format: Optional[torch.memory_format] = None,
) -> torch.Tensor:
r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with
values uniformly drawn from ``[low, high)``.
If :attr:`low` or :attr:`high` are specified and are outside the range of the :attr:`dtype`'s representable
finite values then they are clamped to the lowest or highest representable finite value, respectively.
If ``None``, then the following table describes the default values for :attr:`low` and :attr:`high`,
which depend on :attr:`dtype`.
+---------------------------+------------+----------+
| ``dtype`` | ``low`` | ``high`` |
+===========================+============+==========+
| boolean type | ``0`` | ``2`` |
+---------------------------+------------+----------+
| unsigned integral type | ``0`` | ``10`` |
+---------------------------+------------+----------+
| signed integral types | ``-9`` | ``10`` |
+---------------------------+------------+----------+
| floating types | ``-9`` | ``9`` |
+---------------------------+------------+----------+
| complex types | ``-9`` | ``9`` |
+---------------------------+------------+----------+
Args:
shape (Tuple[int, ...]): Single integer or a sequence of integers defining the shape of the output tensor.
dtype (:class:`torch.dtype`): The data type of the returned tensor.
device (Union[str, torch.device]): The device of the returned tensor.
low (Optional[Number]): Sets the lower limit (inclusive) of the given range. If a number is provided it is
clamped to the least representable finite value of the given dtype. When ``None`` (default),
this value is determined based on the :attr:`dtype` (see the table above). Default: ``None``.
high (Optional[Number]): Sets the upper limit (exclusive) of the given range. If a number is provided it is
clamped to the greatest representable finite value of the given dtype. When ``None`` (default) this value
is determined based on the :attr:`dtype` (see the table above). Default: ``None``.
.. deprecated:: 2.1
Passing ``low==high`` to :func:`~torch.testing.make_tensor` for floating or complex types is deprecated
since 2.1 and will be removed in 2.3. Use :func:`torch.full` instead.
requires_grad (Optional[bool]): If autograd should record operations on the returned tensor. Default: ``False``.
noncontiguous (Optional[bool]): If `True`, the returned tensor will be noncontiguous. This argument is
ignored if the constructed tensor has fewer than two elements. Mutually exclusive with ``memory_format``.
exclude_zero (Optional[bool]): If ``True`` then zeros are replaced with a small positive value that depends
on the :attr:`dtype`. For bool and integer types zero is replaced with one. For floating
point types it is replaced with the dtype's smallest positive normal number (the "tiny" value of the
:attr:`dtype`'s :func:`~torch.finfo` object), and for complex types it is replaced with a complex number
whose real and imaginary parts are both the smallest positive normal number representable by the complex
type. Default ``False``.
memory_format (Optional[torch.memory_format]): The memory format of the returned tensor. Mutually exclusive
with ``noncontiguous``.
Raises:
ValueError: If ``requires_grad=True`` is passed for a boolean or integral :attr:`dtype`.
ValueError: If ``low >= high``.
ValueError: If either :attr:`low` or :attr:`high` is ``nan``.
ValueError: If both :attr:`noncontiguous` and :attr:`memory_format` are passed.
TypeError: If :attr:`dtype` isn't supported by this function.
Examples:
>>> # xdoctest: +SKIP
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> from torch.testing import make_tensor
>>> # Creates a float tensor with values in [-1, 1)
>>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
>>> # xdoctest: +SKIP
tensor([ 0.1205, 0.2282, -0.6380])
>>> # Creates a bool tensor on CUDA
>>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
tensor([[False, False],
[False, True]], device='cuda:0')
"""
def modify_low_high(
low: Optional[float],
high: Optional[float],
*,
lowest_inclusive: float,
highest_exclusive: float,
default_low: float,
default_high: float,
) -> Tuple[float, float]:
"""
Modifies (and raises ValueError when appropriate) the low and high values given by the user, if required.
"""
def clamp(a: float, l: float, h: float) -> float:
return min(max(a, l), h)
low = low if low is not None else default_low
high = high if high is not None else default_high
if any(isinstance(value, float) and math.isnan(value) for value in [low, high]):
raise ValueError(
f"`low` and `high` cannot be NaN, but got {low=} and {high=}"
)
elif low == high and dtype in _FLOATING_OR_COMPLEX_TYPES:
warnings.warn(
"Passing `low==high` to `torch.testing.make_tensor` for floating or complex types "
"is deprecated since 2.1 and will be removed in 2.3. "
"Use `torch.full(...)` instead.",
FutureWarning,
stacklevel=3,
)
elif low >= high:
raise ValueError(f"`low` must be less than `high`, but got {low} >= {high}")
elif high < lowest_inclusive or low >= highest_exclusive:
raise ValueError(
f"The value interval specified by `low` and `high` is [{low}, {high}), "
f"but {dtype} only supports [{lowest_inclusive}, {highest_exclusive})"
)
low = clamp(low, lowest_inclusive, highest_exclusive)
high = clamp(high, lowest_inclusive, highest_exclusive)
if dtype in _BOOLEAN_OR_INTEGRAL_TYPES:
# 1. `low` is ceiled to avoid creating values smaller than `low` and thus outside the specified interval
# 2. Following the same reasoning as for 1., `high` should be floored. However, the higher bound of
# `torch.randint` is exclusive, and thus we need to ceil here as well.
return math.ceil(low), math.ceil(high)
return low, high
if len(shape) == 1 and isinstance(shape[0], collections.abc.Sequence):
shape = shape[0] # type: ignore[assignment]
shape = cast(Tuple[int, ...], tuple(shape))
if noncontiguous and memory_format is not None:
raise ValueError(
f"The parameters `noncontiguous` and `memory_format` are mutually exclusive, "
f"but got {noncontiguous=} and {memory_format=}"
)
if requires_grad and dtype in _BOOLEAN_OR_INTEGRAL_TYPES:
raise ValueError(
f"`requires_grad=True` is not supported for boolean and integral dtypes, but got {dtype=}"
)
if dtype is torch.bool:
low, high = cast(
Tuple[int, int],
modify_low_high(
low,
high,
lowest_inclusive=0,
highest_exclusive=2,
default_low=0,
default_high=2,
),
)
result = torch.randint(low, high, shape, device=device, dtype=dtype)
elif dtype in _BOOLEAN_OR_INTEGRAL_TYPES:
low, high = cast(
Tuple[int, int],
modify_low_high(
low,
high,
lowest_inclusive=torch.iinfo(dtype).min,
highest_exclusive=torch.iinfo(dtype).max
# In theory, `highest_exclusive` should always be the maximum value + 1. However, `torch.randint`
# internally converts the bounds to an int64 and would overflow. In other words: `torch.randint` cannot
# sample 2**63 - 1, i.e. the maximum value of `torch.int64` and we need to account for that here.
+ (1 if dtype is not torch.int64 else 0),
# This is incorrect for `torch.uint8`, but since we clamp to `lowest`, i.e. 0 for `torch.uint8`,
# _after_ we use the default value, we don't need to special case it here
default_low=-9,
default_high=10,
),
)
result = torch.randint(low, high, shape, device=device, dtype=dtype)
elif dtype in _FLOATING_OR_COMPLEX_TYPES:
low, high = modify_low_high(
low,
high,
lowest_inclusive=torch.finfo(dtype).min,
highest_exclusive=torch.finfo(dtype).max,
default_low=-9,
default_high=9,
)
result = torch.empty(shape, device=device, dtype=dtype)
_uniform_random_(
torch.view_as_real(result) if dtype in _COMPLEX_TYPES else result, low, high
)
elif dtype in _FLOATING_8BIT_TYPES:
low, high = modify_low_high(
low,
high,
lowest_inclusive=torch.finfo(dtype).min,
highest_exclusive=torch.finfo(dtype).max,
default_low=-9,
default_high=9,
)
result = torch.empty(shape, device=device, dtype=torch.float32)
_uniform_random_(result, low, high)
result = result.to(dtype)
else:
raise TypeError(
f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()."
" To request support, file an issue at: https://github.com/pytorch/pytorch/issues"
)
if noncontiguous and result.numel() > 1:
result = torch.repeat_interleave(result, 2, dim=-1)
result = result[..., ::2]
elif memory_format is not None:
result = result.clone(memory_format=memory_format)
if exclude_zero:
result[result == 0] = (
1 if dtype in _BOOLEAN_OR_INTEGRAL_TYPES else torch.finfo(dtype).tiny
)
if dtype in _FLOATING_OR_COMPLEX_TYPES:
result.requires_grad = requires_grad
return result
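A short sketch of the keyword arguments documented above, assuming a CPU device; exclude_zero replaces any sampled zeros, and noncontiguous forces a strided result for tensors with more than one element.

import torch
from torch.testing import make_tensor

# Integral tensor drawn from the default range [-9, 10) with zeros replaced by one.
t = make_tensor((4, 4), dtype=torch.int32, device="cpu", exclude_zero=True)
assert not (t == 0).any()

# noncontiguous=True builds the tensor via repeat_interleave plus slicing (see above),
# so the result is not contiguous in memory.
nc = make_tensor((4, 4), dtype=torch.float32, device="cpu", noncontiguous=True)
assert not nc.is_contiguous()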


@@ -0,0 +1,474 @@
# mypy: ignore-errors
import collections
import torch
from torch.testing._internal.common_utils import TEST_WITH_ROCM
from torch.testing._internal.common_utils import TestCase
class AutocastTestLists:
def _rnn_cell_args(self, n, num_chunks, is_lstm, dev, dtype):
input = (torch.randn((n, n), device=dev, dtype=torch.float32),)
hx = ((torch.randn((n, n), device=dev, dtype=torch.float32),
torch.randn((n, n), device=dev, dtype=torch.float32)) if is_lstm else
torch.randn((n, n), device=dev, dtype=torch.float32),)
weights = (torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32), # weight_ih
torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32), # weight_hh
torch.randn((num_chunks * n), device=dev, dtype=torch.float32), # bias_ih
torch.randn((num_chunks * n), device=dev, dtype=torch.float32)) # bias_hh
# returns args as a tuple
return input + hx + weights
# Supplies ops and arguments for test_autocast_* in test/test_cuda.py
def __init__(self, dev):
super().__init__()
n = 8
# Utility arguments, created as one-element tuples
pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
pointwise2_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
mat0_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
mat1_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
mat2_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))
conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
torch.randn(dimset, dtype=torch.float32, device=dev))
for dimset in dimsets]
bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),)
element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
# The lists below organize ops that autocast needs to test.
# self.list_name corresponds to test_autocast_list_name in test/test_cuda.py.
# Each op is associated with a tuple of valid arguments.
# In addition, cudnn conv ops are not supported on ROCm and hence will
# be skipped by passing the TEST_WITH_ROCM flag with those ops in the self.torch_fp16 list.
# Some ops implement built-in type promotion. These don't need autocasting,
# but autocasting relies on their promotion, so we include tests to double-check.
self.torch_expect_builtin_promote = [
("eq", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("ge", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("gt", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("le", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("lt", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("ne", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("add", pointwise0_fp32 + pointwise1_fp16, torch.float32),
("div", pointwise0_fp32 + pointwise1_fp16, torch.float32),
("mul", pointwise0_fp32 + pointwise1_fp16, torch.float32),
("cat", (pointwise0_fp16 + pointwise1_fp32,), torch.float32),
("equal", pointwise0_fp32 + pointwise1_fp16, torch.float32),
("stack", (pointwise0_fp16 + pointwise1_fp32,), torch.float32),
]
self.methods_expect_builtin_promote = [
("__eq__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__ge__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__gt__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__le__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__lt__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__ne__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__add__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
("__div__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
("__mul__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
]
# The remaining lists organize ops that autocast treats explicitly.
self.torch_fp16 = [
# deprecated _convolution
("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False,
(0, 0), 1, False, True, True)),
# the current _convolution
("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False,
(0, 0), 1, False, True, True, True)),
("conv1d", conv_args_fp32[0]),
("conv2d", conv_args_fp32[1]),
("conv3d", conv_args_fp32[2]),
("conv_tbc", conv_args_fp32[0] + bias_fp32),
("conv_transpose1d", conv_args_fp32[0]),
("conv_transpose2d", conv_args_fp32[1]),
("conv_transpose3d", conv_args_fp32[2]),
("convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, (0, 0), 1)),
("cudnn_convolution", conv_args_fp32[1] + ((0, 0), (1, 1), (1, 1), 1, False, True, True), TEST_WITH_ROCM),
("cudnn_convolution_transpose", conv_args_fp32[1] + ((0, 0), (0, 0), (1, 1),
(1, 1), 1, False, True, True), TEST_WITH_ROCM),
("prelu", pointwise0_fp32 + element0_fp32),
("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
("addmv", pointwise0_fp32 + mat2_fp32 + pointwise1_fp32),
("addr", mat0_fp32 + pointwise0_fp32 + pointwise1_fp32),
("matmul", mat0_fp32 + mat1_fp32),
("einsum", "bkhd,bqhd->bqkh", mat0_fp32 + mat1_fp32),
("mm", mat0_fp32 + mat1_fp32),
("mv", mat0_fp32 + pointwise0_fp32),
("chain_matmul", mat0_fp32 + mat1_fp32 + mat2_fp32),
("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
# _thnn_fused_lstm_cell and _thnn_fused_gru_cell are not Python-exposed as far as I can tell.
# ("_thnn_fused_lstm_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32),
# ("_thnn_fused_gru_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32),
("lstm_cell", self._rnn_cell_args(n, num_chunks=4, is_lstm=True, dev=dev, dtype=torch.float32)),
("gru_cell", self._rnn_cell_args(n, num_chunks=3, is_lstm=False, dev=dev, dtype=torch.float32)),
("rnn_tanh_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)),
("rnn_relu_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)),
]
self.torch_fp32 = [
("acos", (pointwise0_fp16[0].clamp(-.9, 0.9),)),
("asin", (pointwise0_fp16[0].clamp(-.9, 0.9),)),
("cosh", pointwise0_fp16),
("erfinv", (pointwise0_fp16[0].clamp(-.9, .9),)),
("exp", pointwise0_fp16),
("expm1", pointwise0_fp16),
("log", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
("log10", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
("log2", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
("log1p", (pointwise0_fp16[0].clamp(-0.9, 100.0),)),
("reciprocal", pointwise0_fp16),
("rsqrt", (pointwise0_fp16[0].clamp(0.0, 100.0),)),
("sinh", pointwise0_fp16),
("tan", (pointwise0_fp16[0].clamp(-3.1 / 2, 3.1 / 2),)),
("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + pointwise1_fp16),
("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + (1.7,)),
# ("pow", (1.7,) + pointwise0_fp16), # This variant has a backend, but is not documented in the API.
("softmax", pointwise0_fp16 + (0,)),
("log_softmax", pointwise0_fp16 + (0,)),
("layer_norm", pointwise0_fp16 + ((pointwise0_fp16[0].numel(),),)),
("group_norm", mat0_fp16 + (1,)),
("norm", pointwise0_fp16),
("norm", pointwise0_fp16, {"dim": 0}),
# these need magma
# ("norm", mat0_fp16, {"p": "nuc"}),
# ("norm", mat0_fp16, {"p": "nuc", "dim": 0}),
("norm", pointwise0_fp16, {"p": 1}),
("norm", pointwise0_fp16, {"p": 1, "dim": 0}),
("cosine_similarity", mat0_fp16 + mat1_fp16),
("poisson_nll_loss", mat0_fp16 + mat1_fp16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))),
("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.float16),
torch.tensor([[1, 3, 4]], device=dev, dtype=torch.float16),
torch.tensor([1], device=dev, dtype=torch.int))),
("hinge_embedding_loss", mat0_fp16 + (torch.ones(n, device=dev, dtype=torch.int),)),
("kl_div", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)),
("margin_ranking_loss", mat0_fp16 + mat1_fp16 + (torch.ones((n,), device=dev, dtype=torch.float16),)),
("triplet_margin_loss", mat0_fp16 + mat1_fp16 + mat2_fp16),
("binary_cross_entropy_with_logits", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)),
("cumprod", pointwise0_fp16 + (0,)),
("cumsum", pointwise0_fp16 + (0,)),
("dist", pointwise0_fp16 + pointwise1_fp16),
("pdist", mat0_fp16),
("cdist", mat0_fp16 + mat1_fp16),
("prod", pointwise0_fp16),
("prod", pointwise0_fp16 + (0,)),
("renorm", mat0_fp16 + (2, 0, 1.0)),
("sum", pointwise0_fp16),
("sum", mat0_fp16 + (1,)),
("logsumexp", mat0_fp16 + (1,)),
]
self.torch_need_autocast_promote = [
("addcdiv", pointwise0_fp32 + pointwise1_fp16 + (pointwise2_fp16[0].clamp(0.1, 100),)),
("addcmul", pointwise0_fp32 + pointwise1_fp16 + pointwise2_fp16),
("atan2", pointwise0_fp32 + (pointwise1_fp16[0].clamp(0.1, 100),)),
("bilinear", (torch.randn((1, 2), dtype=torch.float16, device=dev),
torch.randn((1, 2), dtype=torch.float32, device=dev),
torch.randn((1, 2, 2), dtype=torch.float16, device=dev),
torch.randn((1,), dtype=torch.float32, device=dev))),
("cross", (torch.randn(3, dtype=torch.float32, device=dev),
torch.randn(3, dtype=torch.float16, device=dev))),
("dot", pointwise0_fp16 + pointwise1_fp32),
("vdot", pointwise0_fp16 + pointwise1_fp32),
("grid_sampler", (torch.randn((2, 3, 33, 22), dtype=torch.float16, device=dev),
torch.randn((2, 22, 11, 2), dtype=torch.float32, device=dev),
0, 0, False)),
("index_put", pointwise0_fp32 + ((torch.tensor([1], device=dev, dtype=torch.long),),
torch.randn(1, device=dev, dtype=torch.float16))),
("index_put", pointwise0_fp16 + ((torch.tensor([1], device=dev, dtype=torch.long),),
torch.randn(1, device=dev, dtype=torch.float32))),
("tensordot", (torch.randn((2, 2, 2), dtype=torch.float32, device=dev),
torch.randn((2, 2, 2), dtype=torch.float16, device=dev))),
("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float32, device=dev),
0,
torch.randint(0, 2, (2, 2, 2), device=dev),
torch.randn((2, 2, 2), dtype=torch.float16, device=dev))),
("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float16, device=dev),
0,
torch.randint(0, 2, (2, 2, 2), device=dev),
torch.randn((2, 2, 2), dtype=torch.float32, device=dev))),
]
self.nn_fp16 = [
("linear", mat0_fp32 + mat1_fp32 + mat2_fp32),
]
self.nn_fp32 = [
("softplus", pointwise0_fp16),
("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.float),
torch.zeros((n,), device=dev, dtype=torch.long))),
("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.half),
torch.zeros((n, n, n), device=dev, dtype=torch.long))),
("l1_loss", mat0_fp16 + mat1_fp16),
("smooth_l1_loss", mat0_fp16 + mat1_fp16),
("mse_loss", mat0_fp16 + mat1_fp16),
("multilabel_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
("soft_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
("multi_margin_loss", mat0_fp16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
]
self.linalg_fp16 = [
("linalg_vecdot", mat0_fp32 + mat0_fp32),
("linalg_multi_dot", (mat0_fp32 + mat1_fp32 + mat2_fp32,)),
]
self.methods_fp16 = [
("__matmul__", mat0_fp32 + mat1_fp32)
]
self.methods_fp32 = [
("__pow__", (torch.rand(n, device=dev, dtype=torch.float16), 1.5)),
]
self.banned = [
("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.float32),
torch.rand((n, n), device=dev, dtype=torch.float32)), torch._C._nn),
]
class AutocastCPUTestLists:
# Supplies ops and arguments for test_autocast_* in test/test_cpu.py
def __init__(self, dev):
super().__init__()
n = 8
# Utility arguments, created as one-element tuples
pointwise0_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
pointwise1_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
pointwise2_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
mat0_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n))
dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),)
for dimset in dummy_dimsets]
dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))
conv_args_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),
torch.randn(dimset, dtype=torch.bfloat16, device=dev))
for dimset in dimsets]
conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
torch.randn(dimset, dtype=torch.float32, device=dev))
for dimset in dimsets]
bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),)
element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
dummy_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),)
for dimset in dummy_dimsets]
# The lists below organize ops that autocast needs to test.
# self.list_name corresponds to test_autocast_list_name in test/test_cpu.py.
# Each op is associated with a tuple of valid arguments.
# Some ops implement built-in type promotion. These don't need autocasting,
# but autocasting relies on their promotion, so we include tests to double-check.
self.torch_expect_builtin_promote = [
("eq", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("ge", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("gt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("le", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("lt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("ne", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("add", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
("div", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
("mul", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
]
self.methods_expect_builtin_promote = [
("__eq__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__ge__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__gt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__le__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__lt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__ne__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__add__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
("__div__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
("__mul__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
]
# The remaining lists organize ops that autocast treats explicitly.
self.torch_16 = [
("conv1d", conv_args_fp32[0]),
("conv2d", conv_args_fp32[1]),
("conv3d", conv_args_fp32[2]),
("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("mm", mat0_fp32 + mat1_fp32),
("matmul", mat0_fp32 + mat1_fp32),
("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32),
torch.randn((5, 3, 5), device=dev, dtype=torch.float32),
torch.randn(5, device=dev, dtype=torch.float32),
0)),
("conv_transpose1d", conv_args_fp32[0]),
("conv_transpose2d", conv_args_fp32[1]),
("conv_transpose3d", conv_args_fp32[2]),
("prelu", pointwise0_fp32 + element0_fp32),
("_native_multi_head_attention", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32),
n, 4, torch.randn((3 * n, n), device=dev, dtype=torch.float32),
torch.randn((3 * n), device=dev, dtype=torch.float32),
torch.randn((n, n), device=dev, dtype=torch.float32),
torch.randn((n), device=dev, dtype=torch.float32))),
]
self.torch_fp32 = [
("poisson_nll_loss", mat0_bf16 + mat1_bf16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))),
("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.bfloat16),
torch.tensor([[1, 3, 4]], device=dev, dtype=torch.bfloat16),
torch.tensor([1], device=dev, dtype=torch.int))),
("hinge_embedding_loss", mat0_bf16 + (torch.ones(n, device=dev, dtype=torch.int),)),
("margin_ranking_loss", mat0_bf16 + mat1_bf16 + (torch.ones((n,), device=dev, dtype=torch.bfloat16),)),
("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16),
("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
]
self.nn_16 = [
("linear", mat0_fp32 + mat1_fp32, {}),
]
self.nn_fp32 = [
("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}),
("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),) +
(torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
("reflection_pad1d", dummy_bf16[2], {"padding": (3, 3)}),
("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),
torch.zeros((n,), device=dev, dtype=torch.long))),
("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.bfloat16),
torch.zeros((n, n, n), device=dev, dtype=torch.long))),
("l1_loss", mat0_bf16 + mat1_bf16),
("smooth_l1_loss", mat0_bf16 + mat1_bf16),
("mse_loss", mat0_bf16 + mat1_bf16),
("multilabel_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
("soft_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
("multi_margin_loss", mat0_bf16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
("huber_loss", mat0_bf16 + mat1_bf16),
]
self.torch_need_autocast_promote = [
("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
]
class TestAutocast(TestCase):
def args_maybe_kwargs(self, op_with_args):
if len(op_with_args) == 2:
return op_with_args[0], op_with_args[1], {}
else:
return op_with_args[0], op_with_args[1], op_with_args[2]
def _run_autocast_outofplace(
self,
op,
args,
run_as_type,
device,
out_type=None,
module=torch,
add_kwargs=None,
amp_dtype=torch.bfloat16,
):
# helper to cast args
def cast(val, to_type):
if isinstance(val, torch.Tensor):
return val.to(to_type) if val.is_floating_point() else val
elif isinstance(val, collections.abc.Iterable):
return type(val)(cast(v, to_type) for v in val)
else:
return val
if add_kwargs is None:
add_kwargs = {}
self.assertFalse(torch.is_autocast_enabled(device_type=device))
with torch.amp.autocast(device_type=device, dtype=amp_dtype):
self.assertTrue(torch.is_autocast_enabled(device_type=device))
out_type = out_type if out_type is not None else run_as_type
output = output_method = None
# Try module.* variant, if requested:
if module is not None and hasattr(module, op):
output = getattr(module, op)(*args, **add_kwargs)
if isinstance(output, torch.Tensor):
self.assertTrue(
out_type == output.dtype,
f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
)
# Try Tensor.* variant:
if hasattr(torch.Tensor, op):
output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
if isinstance(output_method, torch.Tensor):
self.assertTrue(
out_type == output_method.dtype,
f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
)
self.assertTrue(
(output is not None) or (output_method is not None),
f"{op} not found as an attribute on either Tensor or the requested module {module}",
)
# Accounts for ops that return Tensors, iterables, and other non-Tensors.
# For example, lstm_cell returns a tuple and equal returns bool.
def compare(first, second):
if isinstance(first, torch.Tensor):
return torch.equal(first, second)
elif isinstance(first, collections.abc.Iterable):
return all(compare(f, s) for f, s in zip(first, second))
else:
return first == second
# If both torch.* and Tensor.* variants were found, check outputs are identical
if (output is not None) and (output_method is not None):
self.assertTrue(type(output) == type(output_method))
comparison = compare(output, output_method)
self.assertTrue(
comparison, f"torch.{op} result did not match Tensor.{op} result"
)
# Compare numerics to Python-side "autocasting" that (we expect) does the same thing
# as the C++-side autocasting, and should be bitwise accurate.
output_to_compare = output if output is not None else output_method
with torch.amp.autocast(device_type=device, enabled=False):
self.assertFalse(
torch.is_autocast_enabled(device_type=device)
)
if module is not None and hasattr(module, op):
control = getattr(module, op)(
*cast(args, run_as_type), **add_kwargs
)
else:
control = getattr(args[0].to(run_as_type), op)(
*cast(args[1:], run_as_type), **add_kwargs
)
self.assertTrue(type(output_to_compare) == type(control))
comparison = compare(output_to_compare, control)
self.assertTrue(comparison, f"torch.{op} result did not match control")
self.assertTrue(torch.is_autocast_enabled(device_type=device))
self.assertFalse(torch.is_autocast_enabled(device_type=device))
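To show how these lists are meant to be consumed, here is a hypothetical driver in the spirit of test/test_cuda.py; the class name, the module path in the import, and the choice of the torch_fp32 list are illustrative assumptions rather than code from this commit, and a CUDA device is assumed.

import torch
# Module path assumed for illustration; the diff does not show the file name.
from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
from torch.testing._internal.common_utils import run_tests

class TestAutocastCUDAExample(TestAutocast):
    def setUp(self):
        super().setUp()
        self.autocast_lists = AutocastTestLists(torch.device("cuda:0"))

    def test_autocast_torch_fp32(self):
        # Entries are (op, args) or (op, args, kwargs); args_maybe_kwargs normalizes them.
        for op_with_args in self.autocast_lists.torch_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.float32, device="cuda",
                add_kwargs=maybe_kwargs, amp_dtype=torch.float16,
            )

if __name__ == "__main__":
    run_tests()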


@@ -0,0 +1,635 @@
# mypy: ignore-errors
import torch
from functools import partial
from torch.testing import make_tensor
from torch.testing._internal.opinfo.core import (
OpInfo,
SampleInput,
)
from torch.testing._internal.common_dtype import all_types_and
import numpy as np
# Note: [autograd.Function db]
#
# This is a collection of autograd.Function test cases written as OpInfos
# so they can easily be consumed by OpInfo-based tests to check if a subsystem
# supports autograd.Function.
#
# Axes:
# - saves {output, input, intermediate, non-tensor}
# - {inputs, output} x {single tensor, tensors, arbitrary objects}
# - Uses {mark_dirty, mark_non_differentiable, once_differentiable}
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpyCube(torch.autograd.Function):
@staticmethod
def forward(input):
input_np = to_numpy(input)
dinput = torch.tensor(3 * input_np ** 2, device=input.device)
return torch.tensor(input_np ** 3, device=input.device), dinput
@staticmethod
def setup_context(ctx, inputs, output):
ctx.save_for_backward(inputs[0], output[1])
ctx.save_for_forward(inputs[0], output[1])
@staticmethod
def backward(ctx, grad_output, grad_saved):
input, dinput = ctx.saved_tensors
return NumpyMul.apply(grad_output, dinput) + 6 * NumpyMul.apply(grad_saved, input)
@staticmethod
def vmap(info, in_dims, input):
result = NumpyCube.apply(input)
return result, (in_dims[0], in_dims[0])
@staticmethod
def jvp(ctx, input_tangent):
input, dinput = ctx.saved_tensors
return NumpyMul.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input)
class CubeGenVmap(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
def forward(x):
return x ** 3, 3 * x ** 2
@staticmethod
def setup_context(ctx, inputs, outputs):
ctx.save_for_backward(inputs[0], outputs[1])
ctx.save_for_forward(inputs[0], outputs[1])
@staticmethod
def backward(ctx, grad_output, grad_saved):
input, dinput = ctx.saved_tensors
result = grad_output * dinput + 6 * dinput
return result
@staticmethod
def jvp(ctx, input_tangent):
input, dinput = ctx.saved_tensors
return MulGenVmap.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input)
def sample_inputs_numpy_cube(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
yield SampleInput(make_arg(1, low=0.8, high=2), args=())
class NumpyCubeNotComposable(torch.autograd.Function):
@staticmethod
def forward(input):
input_np = to_numpy(input)
return torch.tensor(input_np ** 3, device=input.device), input_np
@staticmethod
def setup_context(ctx, inputs, output):
_, input_np = output
ctx.input_np = input_np
ctx.device = inputs[0].device
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output, grad_saved):
result_np = 3 * (ctx.input_np ** 2)
return torch.tensor(result_np, device=ctx.device)
class NumpyMul(torch.autograd.Function):
@staticmethod
def forward(x, y):
return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
@staticmethod
def setup_context(ctx, inputs, output):
ctx.save_for_backward(*inputs)
ctx.save_for_forward(*inputs)
@staticmethod
def backward(ctx, grad_output):
x, y = ctx.saved_tensors
gx = None
if ctx.needs_input_grad[0]:
gx = NumpyMul.apply(grad_output, y)
gy = None
if ctx.needs_input_grad[1]:
gy = NumpyMul.apply(grad_output, x)
return gx, gy
@staticmethod
def vmap(info, in_dims, x, y):
x_bdim, y_bdim = in_dims
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
result = NumpyMul.apply(x, y)
result = result.movedim(-1, 0)
return result, 0
@staticmethod
def jvp(ctx, x_tangent, y_tangent):
x, y = ctx.saved_tensors
return x_tangent * y + y_tangent * x
def sample_inputs_numpy_mul(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
# Broadcasting
yield SampleInput(make_arg(4, low=0.9, high=2), args=(make_arg(3, 4, low=0.9, high=2),))
def sample_inputs_numpy_mul_scalar(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
yield SampleInput(make_arg(4, low=0.9, high=2), args=(), kwargs={"scalar": 3.14})
class MulGenVmap(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
def forward(x, y):
return x * y
@staticmethod
def setup_context(ctx, inputs, outputs):
ctx.save_for_backward(*inputs)
ctx.save_for_forward(*inputs)
@staticmethod
def backward(ctx, grad_output):
x, y = ctx.saved_tensors
gx = None
if ctx.needs_input_grad[0]:
gx = MulGenVmap.apply(grad_output, y)
gy = None
if ctx.needs_input_grad[1]:
gy = MulGenVmap.apply(grad_output, x)
return gx, gy
@staticmethod
def jvp(ctx, x_tangent, y_tangent):
x, y = ctx.saved_tensors
return x_tangent * y + y_tangent * x
class NumpyExp_(torch.autograd.Function):
@staticmethod
def forward(x):
x_np = to_numpy(x)
np.exp(x_np, x_np)
return x
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
ctx.mark_dirty(x)
ctx.save_for_backward(output)
ctx.save_for_forward(output)
@staticmethod
def backward(ctx, grad_output):
output, = ctx.saved_tensors
return NumpyMul.apply(grad_output, output)
@staticmethod
def vmap(info, in_dims, x):
NumpyExp_.apply(x)
return x, in_dims[0]
@staticmethod
def jvp(ctx, x_tangent):
# Doesn't call numpy operations because I didn't want to write NumpyMul_
output, = ctx.saved_tensors
x_tangent.mul_(output)
return x_tangent
class NumpySort(torch.autograd.Function):
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
return (
torch.tensor(x, device=device),
torch.tensor(ind, device=device),
torch.tensor(ind_inv, device=device),
)
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
ctx.save_for_backward(ind, ind_inv)
ctx.save_for_forward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
@staticmethod
def vmap(info, in_dims, x, dim):
x_bdim, _ = in_dims
x = x.movedim(x_bdim, 0)
# wrap dim
dim = dim if dim >= 0 else dim + x.dim() - 1
return NumpySort.apply(x, dim + 1), (0, 0, 0)
@staticmethod
def jvp(ctx, x_tangent, _):
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim), None, None
class SortGenVmap(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
def forward(x, dim):
device = x.device
ind = torch.argsort(x, dim=dim)
ind_inv = torch.argsort(ind, axis=dim)
result = torch.take_along_dim(x, ind, dim=dim)
return result, ind, ind_inv
@staticmethod
def setup_context(ctx, inputs, outputs):
x, dim = inputs
_, ind, ind_inv = outputs
ctx.mark_non_differentiable(ind, ind_inv)
ctx.save_for_backward(ind, ind_inv)
ctx.save_for_forward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
ind, ind_inv = ctx.saved_tensors
return TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim), None
@staticmethod
def jvp(ctx, x_tangent, _):
ind, ind_inv = ctx.saved_tensors
return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim), None, None
def sample_inputs_numpy_sort(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
yield SampleInput(make_arg(3, 5), args=(1,))
def sample_inputs_numpy_take(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
tensor = make_arg(3, 5)
dim = 1
_, ind, ind_inv = NumpySort.apply(tensor, 1)
yield SampleInput(tensor, args=(ind, ind_inv, dim))
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.save_for_forward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
@staticmethod
def vmap(info, in_dims, x, ind, ind_inv, dim):
x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
# wrap dim
logical_dim = x.dim() if x_bdim is None else x_bdim - 1
dim = dim if dim >= 0 else dim + logical_dim
def expand_bdim(x, x_bdim):
if x_bdim is None:
return x.expand(info.batch_size, *x.shape)
return x.movedim(x_bdim, 0)
x = expand_bdim(x, x_bdim)
ind = expand_bdim(ind, ind_bdim)
ind_inv = expand_bdim(ind_inv, ind_inv_bdim)
return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0
@staticmethod
def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _):
assert ind_tangent is None
assert ind_inv_tangent is None
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim)
class TakeGenVmap(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
def forward(x, ind, ind_inv, dim):
return torch.take_along_dim(x, ind, dim)
@staticmethod
def setup_context(ctx, inputs, outputs):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.save_for_forward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
@staticmethod
def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _):
ind, ind_inv = ctx.saved_tensors
return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim)
class Select(torch.autograd.Function):
@staticmethod
def forward(x, idx):
return x[idx]
@staticmethod
def setup_context(ctx, inputs, output):
x, idx = inputs
ctx.x_shape = x.shape
ctx.idx = idx
@staticmethod
def backward(ctx, grad_output):
result = grad_output.new_zeros(ctx.x_shape)
result[ctx.idx] = grad_output
return result, None
@staticmethod
def vmap(info, in_dims, x, idx):
x_bdim, _ = in_dims
x = x.movedim(x_bdim, 1)
return Select.apply(x, idx), 0
@staticmethod
def jvp(ctx, x_tangent, _):
return Select.apply(x_tangent, ctx.idx)
class SelectGenVmap(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
def forward(x, idx):
return x[idx]
@staticmethod
def setup_context(ctx, inputs, outputs):
x, idx = inputs
ctx.x_shape = x.shape
ctx.idx = idx
@staticmethod
def backward(ctx, grad_output):
result = grad_output.new_zeros(ctx.x_shape)
result[ctx.idx] = grad_output
return result, None
@staticmethod
def jvp(ctx, x_tangent, _):
return SelectGenVmap.apply(x_tangent, ctx.idx)
def sample_inputs_select(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
yield SampleInput(make_arg(3, 5), args=(2,))
class ScaleGradGenVmap(torch.autograd.Function):
generate_vmap_rule = True
scale = 3.14
@staticmethod
def forward(x):
return x.clone()
@staticmethod
def setup_context(ctx, inputs, outputs):
pass
@staticmethod
def backward(ctx, grad_output):
return grad_output * ScaleGradGenVmap.scale
@staticmethod
def jvp(ctx, x_tangent):
return x_tangent * ScaleGradGenVmap.scale
class ZeroGradientsGenVmap(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
def forward(x, y):
return x.clone(), y.clone()
@staticmethod
def setup_context(ctx, inputs, outputs):
pass
@staticmethod
def backward(ctx, gx, gy):
# Intentionally returning torch.zeros instead of zeros_like or new_zeros.
# Also intentionally not None.
return (
# Intentionally too-large gradient
torch.zeros(3, 4, *gx.shape, dtype=gx.dtype, device=gx.device),
torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
)
@staticmethod
def jvp(ctx, gx, gy):
# Intentionally returning torch.zeros instead of zeros_like or new_zeros.
# Also intentionally not None.
return (
torch.zeros(gx.shape, dtype=gx.dtype, device=gx.device),
torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
)
def sample_inputs_forward_default_args(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
yield SampleInput(make_arg(3, 5))
class ForwardHasDefaultArgs(torch.autograd.Function):
@staticmethod
def forward(x, idx=(2,)):
return x[idx]
@staticmethod
def setup_context(ctx, inputs, output):
x, idx = inputs
ctx.x_shape = x.shape
ctx.idx = idx
@staticmethod
def backward(ctx, grad_output):
result = grad_output.new_zeros(ctx.x_shape)
result[ctx.idx] = grad_output
return result, None
@staticmethod
def vmap(info, in_dims, x, idx):
x_bdim, _ = in_dims
x = x.movedim(x_bdim, 1)
return ForwardHasDefaultArgs.apply(x, idx), 0
@staticmethod
def jvp(ctx, x_tangent, _):
return ForwardHasDefaultArgs.apply(x_tangent, ctx.idx)
autograd_function_db = [
OpInfo(
'NumpyCubeAutogradFunction',
op=NumpyCube.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_cube,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyExpMarkDirtyAutogradFunction',
op=lambda x: NumpyExp_.apply(x.clone()),
inplace_variant=NumpyExp_.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_cube,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyMulAutogradFunction',
op=NumpyMul.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_mul,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyCubeNotComposableAutogradFunction',
op=lambda x: NumpyCubeNotComposable.apply(x)[0],
supports_forward_ad=False,
supports_fwgrad_bwgrad=False,
sample_inputs_func=sample_inputs_numpy_cube,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpySortAutogradFunction',
op=NumpySort.apply,
supports_forward_ad=False,
supports_fwgrad_bwgrad=False,
sample_inputs_func=sample_inputs_numpy_sort,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
gradcheck_wrapper=lambda y, ind: y,
),
OpInfo(
'NumpyTakeAutogradFunction',
op=NumpyTake.apply,
supports_forward_ad=False,
supports_fwgrad_bwgrad=False,
sample_inputs_func=sample_inputs_numpy_take,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'SelectAutogradFunction',
op=Select.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_select,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'CubeGenVmapAutogradFunction',
op=CubeGenVmap.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_cube,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'MulGenVmapAutogradFunction',
op=MulGenVmap.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_mul,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'SortGenVmapAutogradFunction',
op=SortGenVmap.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_sort,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
gradcheck_wrapper=lambda y, ind: y,
),
OpInfo(
'SelectGenVmapAutogradFunction',
op=SelectGenVmap.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_select,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'ScaleGradGenVmapAutogradFunction',
op=ScaleGradGenVmap.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_cube,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'ZeroGradientsGenVmapAutogradFunction',
op=ZeroGradientsGenVmap.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_mul,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'ForwardHasDefaultArgsAutogradFunction',
op=ForwardHasDefaultArgs.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_forward_default_args,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
]
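To illustrate how this database is consumed, the sketch below pulls one entry, generates its sample inputs, and runs gradcheck over them; the OpInfo-based suites do the same thing more exhaustively. The module path in the import is an assumption.

import torch
from torch.autograd import gradcheck
# Module path assumed for illustration.
from torch.testing._internal.autograd_function_db import autograd_function_db

entry = next(op for op in autograd_function_db if op.name == 'NumpyCubeAutogradFunction')
for sample in entry.sample_inputs('cpu', torch.double, requires_grad=True):
    # entry.op is NumpyCube.apply; gradcheck compares its analytical backward
    # against numerical finite differences for every sample input.
    assert gradcheck(entry.op, (sample.input, *sample.args))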


@@ -0,0 +1,165 @@
# mypy: ignore-errors
import os
import re
import sys
from typing import List
__all__ = [
"check_code_for_cuda_kernel_launches",
"check_cuda_kernel_launches",
]
# FILES TO EXCLUDE (match is done with suffix using `endswith`)
# You wouldn't drive without a seatbelt, though, so why would you
# launch a kernel without some safety? Use this as a quick workaround
# for a problem with the checker, fix the checker, then de-exclude
# the files in question.
exclude_files: List[str] = []
# Without using a C++ AST we can't 100% detect kernel launches, so we
# model them as having the pattern "<<<parameters>>>(arguments);"
# We then require that `C10_CUDA_KERNEL_LAUNCH_CHECK` be
# the next statement.
#
# We model the next statement as ending at the next `}` or `;`.
# If we see `}` then a clause ended (bad); if we see a semicolon then
# we expect the launch check just before it.
#
# Since the kernel launch can include lambda statements, it's important
# to find the correct end-paren of the kernel launch. Doing this with
# pure regex would require recursive regexes, which aren't part of the Python
# standard library. To avoid an additional dependency, we build a prefix
# regex that finds the start of a kernel launch, use a paren-matching
# algorithm to find the end of the launch, and then another regex to
# determine if a launch check is present.
# Finds potential starts of kernel launches
kernel_launch_start = re.compile(
r"^.*<<<[^>]+>>>\s*\(", flags=re.MULTILINE
)
# This pattern should start at the character after the final paren of the
# kernel launch. It returns a match if the launch check is not the next statement
has_check = re.compile(
r"\s*;(?![^;}]*C10_CUDA_KERNEL_LAUNCH_CHECK\(\);)", flags=re.MULTILINE
)
def find_matching_paren(s: str, startpos: int) -> int:
"""Given a string "prefix (unknown number of characters) suffix"
and the position of the first `(` returns the index of the character
1 past the `)`, accounting for paren nesting
"""
opening = 0
for i, c in enumerate(s[startpos:]):
if c == '(':
opening += 1
elif c == ')':
opening -= 1
if opening == 0:
return startpos + i + 1
raise IndexError("Closing parens not found!")
def should_exclude_file(filename) -> bool:
for exclude_suffix in exclude_files:
if filename.endswith(exclude_suffix):
return True
return False
def check_code_for_cuda_kernel_launches(code, filename=None):
"""Checks code for CUDA kernel launches without cuda error checks.
Args:
code - The code to check
filename - Filename of the file containing the code. Used only for display
purposes, so you can put anything here.
Returns:
The number of unsafe kernel launches in the code
"""
if filename is None:
filename = "##Python Function Call##"
# We break the code apart and put it back together to add
# helpful line numberings for identifying problem areas
code = enumerate(code.split("\n")) # Split by line breaks
code = [f"{lineno}: {linecode}" for lineno, linecode in code] # Number the lines
code = '\n'.join(code) # Put it back together
num_launches_without_checks = 0
for m in kernel_launch_start.finditer(code):
end_paren = find_matching_paren(code, m.end() - 1)
if has_check.match(code, end_paren):
num_launches_without_checks += 1
context = code[m.start():end_paren + 1]
print(f"Missing C10_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. Context:\n{context}", file=sys.stderr)
return num_launches_without_checks
def check_file(filename):
"""Checks a file for CUDA kernel launches without cuda error checks
Args:
filename - File to check
Returns:
The number of unsafe kernel launches in the file
"""
if not (filename.endswith((".cu", ".cuh"))):
return 0
if should_exclude_file(filename):
return 0
with open(filename) as fo:
contents = fo.read()
unsafeCount = check_code_for_cuda_kernel_launches(contents, filename)
return unsafeCount
def check_cuda_kernel_launches():
"""Checks all pytorch code for CUDA kernel launches without cuda error checks
Returns:
The number of unsafe kernel launches in the codebase
"""
torch_dir = os.path.dirname(os.path.realpath(__file__))
torch_dir = os.path.dirname(torch_dir) # Go up to parent torch
torch_dir = os.path.dirname(torch_dir) # Go up to parent caffe2
kernels_without_checks = 0
files_without_checks = []
for root, dirnames, filenames in os.walk(torch_dir):
# `$BASE/build` and `$BASE/torch/include` are generated
# so we don't want to flag their contents
if root == os.path.join(torch_dir, "build") or root == os.path.join(torch_dir, "torch/include"):
# Curtail search by modifying dirnames and filenames in place
# Yes, this is the way to do this, see `help(os.walk)`
dirnames[:] = []
continue
for x in filenames:
filename = os.path.join(root, x)
file_result = check_file(filename)
if file_result > 0:
kernels_without_checks += file_result
files_without_checks.append(filename)
if kernels_without_checks > 0:
count_str = f"Found {kernels_without_checks} instances in " \
f"{len(files_without_checks)} files where kernel " \
"launches didn't have checks."
print(count_str, file=sys.stderr)
print("Files without checks:", file=sys.stderr)
for x in files_without_checks:
print(f"\t{x}", file=sys.stderr)
print(count_str, file=sys.stderr)
return kernels_without_checks
if __name__ == "__main__":
unsafe_launches = check_cuda_kernel_launches()
sys.exit(0 if unsafe_launches == 0 else 1)
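A hedged usage sketch of the checker above on an in-memory CUDA snippet; the kernel and statements are made up, and the module name in the import is assumed since the diff does not show the file path.

# Module name assumed for illustration.
from check_kernel_launches import check_code_for_cuda_kernel_launches

sample = """
my_kernel<<<blocks, threads>>>(arg);
C10_CUDA_KERNEL_LAUNCH_CHECK();

my_kernel<<<blocks, threads>>>(arg);
other_statement();
"""
# The first launch is followed by the check and is accepted; the second is not,
# so it is printed to stderr and counted, giving a return value of 1.
assert check_code_for_cuda_kernel_launches(sample, "example.cu") == 1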


@@ -0,0 +1 @@
# mypy: ignore-errors


@@ -0,0 +1,301 @@
# mypy: ignore-errors
r"""This file is allowed to initialize CUDA context when imported."""
import functools
import torch
import torch.cuda
from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS
import inspect
import contextlib
import os
CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized()
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None
# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
if TEST_WITH_ROCM:
TEST_CUDNN = LazyVal(lambda: TEST_CUDA)
else:
TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0)
SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3))
SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0))
SM70OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 0))
SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5))
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)])
def CDNA2OrLater():
if TEST_WITH_ROCM:
gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
return any(arch in gcn_arch_name for arch in {"gfx90a", "gfx940", "gfx941", "gfx942"})
return False
def evaluate_gfx_arch_exact(matching_arch):
if not torch.cuda.is_available():
return False
gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
return arch == matching_arch
GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-'))
GFX942_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-'))
def evaluate_platform_supports_flash_attention():
if TEST_WITH_ROCM:
return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
if TEST_CUDA:
return not IS_WINDOWS and SM80OrLater
return False
def evaluate_platform_supports_efficient_attention():
if TEST_WITH_ROCM:
return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
if TEST_CUDA:
return True
return False
def evaluate_platform_supports_cudnn_attention():
return (not TEST_WITH_ROCM) and SM80OrLater and (TEST_CUDNN_VERSION >= 90000)
PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention())
PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_cudnn_attention())
# This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate
PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or
PLATFORM_SUPPORTS_CUDNN_ATTENTION or
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION)
PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM
PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)
def evaluate_platform_supports_fp8():
if torch.cuda.is_available():
if torch.version.hip:
return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName
else:
return SM90OrLater or torch.cuda.get_device_capability() == (8, 9)
return False
PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
if TEST_NUMBA:
try:
import numba.cuda
TEST_NUMBA_CUDA = numba.cuda.is_available()
except Exception as e:
TEST_NUMBA_CUDA = False
TEST_NUMBA = False
else:
TEST_NUMBA_CUDA = False
# Used below in `initialize_cuda_context_rng` to ensure that CUDA context and
# RNG have been initialized.
__cuda_ctx_rng_initialized = False
# after this call, CUDA context and RNG must have been initialized on each GPU
def initialize_cuda_context_rng():
global __cuda_ctx_rng_initialized
assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
if not __cuda_ctx_rng_initialized:
# initialize cuda context and rng for memory tests
for i in range(torch.cuda.device_count()):
torch.randn(1, device=f"cuda:{i}")
__cuda_ctx_rng_initialized = True
# Test whether hardware TF32 math mode is enabled. It is enabled only on:
# - CUDA >= 11
# - arch >= Ampere
def tf32_is_not_fp32():
if not torch.cuda.is_available() or torch.version.cuda is None:
return False
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
return False
if int(torch.version.cuda.split('.')[0]) < 11:
return False
return True
@contextlib.contextmanager
def tf32_off():
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
try:
torch.backends.cuda.matmul.allow_tf32 = False
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
yield
finally:
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
old_precision = self.precision
try:
torch.backends.cuda.matmul.allow_tf32 = True
self.precision = tf32_precision
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
yield
finally:
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
self.precision = old_precision
# This is a wrapper that wraps a test to run this test twice, one with
# allow_tf32=True, another with allow_tf32=False. When running with
# allow_tf32=True, it will use reduced precision as specified by the
# argument. For example:
# @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
# @tf32_on_and_off(0.005)
# def test_matmul(self, device, dtype):
# a = ...; b = ...;
# c = torch.matmul(a, b)
# self.assertEqual(c, expected)
# In the above example, when testing torch.float32 and torch.complex64 on CUDA
# with a CUDA >= 11 build on an >= Ampere architecture, the matmul runs twice,
# once with TF32 mode on and once with it off; when TF32 mode is on, the
# assertEqual uses the reduced precision to check values.
#
# This decorator can be used for function with or without device/dtype, such as
# @tf32_on_and_off(0.005)
# def test_my_op(self)
# @tf32_on_and_off(0.005)
# def test_my_op(self, device)
# @tf32_on_and_off(0.005)
# def test_my_op(self, device, dtype)
# @tf32_on_and_off(0.005)
# def test_my_op(self, dtype)
# if neither device nor dtype is specified, it will check if the system has an Ampere device
# if device is specified, it will check if the device is cuda
# if dtype is specified, it will check if the dtype is float32 or complex64
# TF32 and FP32 differ only when all three checks pass
def tf32_on_and_off(tf32_precision=1e-5):
def with_tf32_disabled(self, function_call):
with tf32_off():
function_call()
def with_tf32_enabled(self, function_call):
with tf32_on(self, tf32_precision):
function_call()
def wrapper(f):
params = inspect.signature(f).parameters
arg_names = tuple(params.keys())
@functools.wraps(f)
def wrapped(*args, **kwargs):
for k, v in zip(arg_names, args):
kwargs[k] = v
cond = tf32_is_not_fp32()
if 'device' in kwargs:
cond = cond and (torch.device(kwargs['device']).type == 'cuda')
if 'dtype' in kwargs:
cond = cond and (kwargs['dtype'] in {torch.float32, torch.complex64})
if cond:
with_tf32_disabled(kwargs['self'], lambda: f(**kwargs))
with_tf32_enabled(kwargs['self'], lambda: f(**kwargs))
else:
f(**kwargs)
return wrapped
return wrapper
# This is a wrapper that wraps a test to run it with TF32 turned off.
# This wrapper is designed to be used when a test uses matmul or convolutions
# but the purpose of that test is not testing matmul or convolutions.
# Disabling TF32 forces torch.float tensors to always be computed
# at full precision.
def with_tf32_off(f):
@functools.wraps(f)
def wrapped(*args, **kwargs):
with tf32_off():
return f(*args, **kwargs)
return wrapped
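# A minimal usage sketch (hypothetical test name; illustrative only): any
# float32 matmul or convolution inside the wrapped test runs with TF32 disabled.
#
#   @with_tf32_off
#   def test_linear_state_dict(self, device):
#       ...  # matmuls here are incidental, so full FP32 precision is enforced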
def _get_magma_version():
if 'Magma' not in torch.__config__.show():
return (0, 0)
position = torch.__config__.show().find('Magma ')
version_str = torch.__config__.show()[position + len('Magma '):].split('\n')[0]
return tuple(int(x) for x in version_str.split("."))
def _get_torch_cuda_version():
if torch.version.cuda is None:
return (0, 0)
cuda_version = str(torch.version.cuda)
return tuple(int(x) for x in cuda_version.split("."))
def _get_torch_rocm_version():
if not TEST_WITH_ROCM:
return (0, 0)
rocm_version = str(torch.version.hip)
rocm_version = rocm_version.split("-")[0] # ignore git sha
return tuple(int(x) for x in rocm_version.split("."))
def _check_cusparse_generic_available():
return not TEST_WITH_ROCM
def _check_hipsparse_generic_available():
if not TEST_WITH_ROCM:
return False
rocm_version = str(torch.version.hip)
rocm_version = rocm_version.split("-")[0] # ignore git sha
rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1))
TEST_CUSPARSE_GENERIC = _check_cusparse_generic_available()
TEST_HIPSPARSE_GENERIC = _check_hipsparse_generic_available()
# Shared by test_torch.py and test_multigpu.py
def _create_scaling_models_optimizers(device="cuda", optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
# Create a module+optimizer that will use scaling, and a control module+optimizer
# that will not use scaling, against which the scaling-enabled module+optimizer can be compared.
mod_control = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
mod_scaling = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
with torch.no_grad():
for c, s in zip(mod_control.parameters(), mod_scaling.parameters()):
s.copy_(c)
kwargs = {"lr": 1.0}
if optimizer_kwargs is not None:
kwargs.update(optimizer_kwargs)
opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)
opt_scaling = optimizer_ctor(mod_scaling.parameters(), **kwargs)
return mod_control, mod_scaling, opt_control, opt_scaling
# Shared by test_torch.py, test_cuda.py and test_multigpu.py
def _create_scaling_case(device="cuda", dtype=torch.float, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
data = [(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device))]
loss_fn = torch.nn.MSELoss().to(device)
skip_iter = 2
return _create_scaling_models_optimizers(
device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs,
) + (data, loss_fn, skip_iter)
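# A sketch of how the returned case is typically consumed in a gradient-scaling
# test (illustrative only; GradScaler/autocast follow the public torch.cuda.amp
# API, and skip_iter is assumed to mark the iteration where callers inject a
# non-finite gradient to exercise step skipping):
#
#   mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, skip_iter = \
#       _create_scaling_case()
#   scaler = torch.cuda.amp.GradScaler()
#   for i, (input, target) in enumerate(data):
#       with torch.autocast("cuda"):
#           loss = loss_fn(mod_scaling(input), target)
#       scaler.scale(loss).backward()
#       scaler.step(opt_scaling)
#       scaler.update()
#       opt_scaling.zero_grad()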
# Importing this module should NOT eagerly initialize CUDA
if not CUDA_ALREADY_INITIALIZED_ON_IMPORT:
assert not torch.cuda.is_initialized()

File diff suppressed because it is too large

View File

@@ -0,0 +1,111 @@
# mypy: ignore-errors
# Owner(s): ["oncall: distributed"]
from typing import Tuple
import torch
import torch.nn as nn
class UnitModule(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
self.l1 = nn.Linear(100, 100, device=device)
self.seq = nn.Sequential(
nn.ReLU(),
nn.Linear(100, 100, device=device),
nn.ReLU(),
)
self.l2 = nn.Linear(100, 100, device=device)
def forward(self, x):
return self.l2(self.seq(self.l1(x)))
class CompositeModel(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
self.l1 = nn.Linear(100, 100, device=device)
self.u1 = UnitModule(device)
self.u2 = UnitModule(device)
self.l2 = nn.Linear(100, 100, device=device)
def forward(self, x):
return self.l2(self.u2(self.u1(self.l1(x))))
class UnitParamModule(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
self.l = nn.Linear(100, 100, device=device)
self.seq = nn.Sequential(
nn.ReLU(),
nn.Linear(100, 100, device=device),
nn.ReLU(),
)
self.p = nn.Parameter(torch.randn((100, 100), device=device))
def forward(self, x):
return torch.mm(self.seq(self.l(x)), self.p)
class CompositeParamModel(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
self.l = nn.Linear(100, 100, device=device)
self.u1 = UnitModule(device)
self.u2 = UnitModule(device)
self.p = nn.Parameter(torch.randn((100, 100), device=device))
self.register_buffer(
"buffer", torch.randn((100, 100), device=device), persistent=True
)
def forward(self, x):
a = self.u2(self.u1(self.l(x)))
b = self.p
return torch.mm(a, b)
class FakeSequential(nn.Module):
# Define this class to achieve a desired nested wrapping using the module
# wrap policy with `nn.Sequential`
def __init__(self, *modules: Tuple[nn.Module, ...]) -> None:
super().__init__()
self._module_sequence = list(modules)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for module in self._module_sequence:
x = module(x)
return x
class NestedSequentialModel(nn.Module):
def __init__(self, device: torch.device) -> None:
super().__init__()
# This nested structure exercises traversal order to catch differences
# between valid traversals (e.g. BFS and DFS variations).
self.seq1 = nn.Sequential(
nn.Linear(1, 1, device=device),
FakeSequential(
nn.Linear(1, 1, device=device),
nn.ReLU(),
FakeSequential(
nn.Linear(1, 1, device=device),
),
nn.ReLU(),
),
nn.Linear(1, 2, device=device),
)
self.lin = nn.Linear(2, 2, device=device)
self.seq2 = nn.Sequential(
nn.ReLU(),
nn.Linear(2, 3, device=device),
FakeSequential(
nn.Linear(3, 2, bias=False, device=device),
nn.Linear(2, 4, bias=False, device=device),
),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.seq2(self.lin(self.seq1(x)))
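# Shape sketch for these fixtures (values follow from the definitions above;
# illustrative only):
#
#   model = CompositeParamModel(torch.device("cpu"))
#   model(torch.randn(4, 100)).shape      # torch.Size([4, 100])
#   nested = NestedSequentialModel(torch.device("cpu"))
#   nested(torch.randn(4, 1)).shape       # torch.Size([4, 4])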

File diff suppressed because it is too large

View File

@@ -0,0 +1,191 @@
# mypy: ignore-errors
from typing import List
import torch
# Functions and classes for describing the dtypes a function supports
# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros
# Verifies each given dtype is a torch.dtype
def _validate_dtypes(*dtypes):
for dtype in dtypes:
assert isinstance(dtype, torch.dtype)
return dtypes
# class for tuples corresponding to a PyTorch dispatch macro
class _dispatch_dtypes(tuple):
def __add__(self, other):
assert isinstance(other, tuple)
return _dispatch_dtypes(tuple.__add__(self, other))
_empty_types = _dispatch_dtypes(())
def empty_types():
return _empty_types
_floating_types = _dispatch_dtypes((torch.float32, torch.float64))
def floating_types():
return _floating_types
_floating_types_and_half = _floating_types + (torch.half,)
def floating_types_and_half():
return _floating_types_and_half
def floating_types_and(*dtypes):
return _floating_types + _validate_dtypes(*dtypes)
_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)
def floating_and_complex_types():
return _floating_and_complex_types
def floating_and_complex_types_and(*dtypes):
return _floating_and_complex_types + _validate_dtypes(*dtypes)
_double_types = _dispatch_dtypes((torch.float64, torch.complex128))
def double_types():
return _double_types
# NB: Does not contain uint16/uint32/uint64 for BC reasons
_integral_types = _dispatch_dtypes(
(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
)
def integral_types():
return _integral_types
def integral_types_and(*dtypes):
return _integral_types + _validate_dtypes(*dtypes)
_all_types = _floating_types + _integral_types
def all_types():
return _all_types
def all_types_and(*dtypes):
return _all_types + _validate_dtypes(*dtypes)
_complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble))
def complex_types():
return _complex_types
def complex_types_and(*dtypes):
return _complex_types + _validate_dtypes(*dtypes)
_all_types_and_complex = _all_types + _complex_types
def all_types_and_complex():
return _all_types_and_complex
def all_types_and_complex_and(*dtypes):
return _all_types_and_complex + _validate_dtypes(*dtypes)
_all_types_and_half = _all_types + (torch.half,)
def all_types_and_half():
return _all_types_and_half
def custom_types(*dtypes):
"""Create a list of arbitrary dtypes"""
return _empty_types + _validate_dtypes(*dtypes)
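# For example, the helpers compose by tuple concatenation (values follow from
# the definitions above):
#
#   floating_types_and(torch.half)         # (torch.float32, torch.float64, torch.float16)
#   all_types_and_complex_and(torch.bool)  # every dtype above plus torch.bool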
# The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro
# See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
def get_all_dtypes(
include_half=True,
include_bfloat16=True,
include_bool=True,
include_complex=True,
include_complex32=False,
include_qint=False,
) -> List[torch.dtype]:
dtypes = get_all_int_dtypes() + get_all_fp_dtypes(
include_half=include_half, include_bfloat16=include_bfloat16
)
if include_bool:
dtypes.append(torch.bool)
if include_complex:
dtypes += get_all_complex_dtypes(include_complex32)
if include_qint:
dtypes += get_all_qint_dtypes()
return dtypes
def get_all_math_dtypes(device) -> List[torch.dtype]:
return (
get_all_int_dtypes()
+ get_all_fp_dtypes(
include_half=device.startswith("cuda"), include_bfloat16=False
)
+ get_all_complex_dtypes()
)
def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
return (
[torch.complex32, torch.complex64, torch.complex128]
if include_complex32
else [torch.complex64, torch.complex128]
)
def get_all_int_dtypes() -> List[torch.dtype]:
return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
dtypes = [torch.float32, torch.float64]
if include_half:
dtypes.append(torch.float16)
if include_bfloat16:
dtypes.append(torch.bfloat16)
return dtypes
def get_all_qint_dtypes() -> List[torch.dtype]:
return [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4]
float_to_corresponding_complex_type_map = {
torch.float16: torch.complex32,
torch.float32: torch.complex64,
torch.float64: torch.complex128,
}

File diff suppressed because it is too large

View File

@ -0,0 +1,323 @@
# mypy: ignore-errors
# Torch
import torch
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
import torch.jit.quantized
# Testing utils
from torch.testing._internal.common_dtype import floating_and_complex_types_and
from torch.testing._internal.common_utils import TestCase, \
freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests, is_iterable_of_tensors
from torch.testing._internal.common_utils import enable_profiling_mode # noqa: F401
# Standard library
from itertools import chain
from typing import List, Union
from torch._C import TensorType
import io
def check_output_types(self, func, ref_outputs, args, kwargs):
graph = getattr(func, 'last_graph', None)
types = [o.type() for o in graph.outputs()]
self.assertTrue(len(types) == 1)
t = types[0]
torch._C._jit_assert_is_instance(ref_outputs, t)
# Test names in this set are only checked for a single derivative
nn_functional_single_grad = frozenset('test_nn_' + name for name in [
'pdist',
'multilabel_margin_loss',
'max_unpool3d',
'multi_margin_loss',
'binary_cross_entropy',
'binary_cross_entropy_size_average',
'ctc_loss',
'grid_sample',
])
def check_against_reference(self, func, reference_func, output_func, args, kwargs=None,
allow_unused=True, check_types=True, no_grad=False, no_gradgrad=False):
"""Verifies a function performs identically to some reference implementation.
Commonly, this is used to verify that a JIT implementation
(func) matches the behavior of the eager implementation
(reference_func); output_func post-processes the outputs before comparison.
"""
kwargs = kwargs if kwargs else {}
def allSum(vs):
if isinstance(vs, torch.Tensor):
vs = (vs,)
return sum((i + 1) * v.sum().abs() if v.dtype.is_complex else (i + 1) * v.sum()
for i, v in enumerate(vs)
if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16))
def clone_tensor(t, preserve_requires_grad):
require_grad = preserve_requires_grad and t.requires_grad
return t.detach().clone().requires_grad_(require_grad)
def clone_inputs(preserve_requires_grad: bool):
inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []
for arg in args:
if isinstance(arg, torch.Tensor):
inputs.append(clone_tensor(arg, preserve_requires_grad))
elif is_iterable_of_tensors(arg):
inputs.append([clone_tensor(t, preserve_requires_grad) for t in arg])
else:
inputs.append(arg)
return inputs
# Returns tensors in args that requires_grad, including tensors in TensorList args
def get_recording_tensors(args):
recording_tensors: List[torch.Tensor] = []
for arg in args:
if isinstance(arg, torch.Tensor) and arg.requires_grad:
recording_tensors.append(arg)
elif is_iterable_of_tensors(arg):
recording_tensors.extend(filter(lambda t: t.requires_grad, arg))
return recording_tensors
# test no gradients case
nograd_inputs = clone_inputs(preserve_requires_grad=False)
outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
with enable_profiling_mode_for_profiling_tests():
outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
self.assertEqual(outputs, outputs_test)
if check_types:
check_output_types(self, func, outputs_test, nograd_inputs, kwargs)
if no_grad:
# skip grad tests
return
with enable_profiling_mode_for_profiling_tests():
# test single grad case
recording_inputs = clone_inputs(preserve_requires_grad=True)
recording_tensors = get_recording_tensors(recording_inputs)
outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
grads = torch.autograd.grad(allSum(outputs), recording_tensors,
allow_unused=allow_unused)
outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
allow_unused=allow_unused)
self.assertEqual(outputs, outputs_test)
self.assertEqual(grads, grads_test)
# test the grad grad case
if self._testMethodName in nn_functional_single_grad or no_gradgrad:
return
outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
l1 = allSum(outputs)
grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
allow_unused=allow_unused)
l2 = (allSum(grads) * l1)
grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
recording_inputs = clone_inputs(preserve_requires_grad=True)
recording_tensors = get_recording_tensors(recording_inputs)
outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
l1_test = allSum(outputs_test)
grads_test = torch.autograd.grad(
l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)
l2_test = (allSum(grads_test) * l1_test)
grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)
self.assertEqual(outputs, outputs_test)
self.assertEqual(grads, grads_test)
for g2, g2_test in zip(grads2, grads2_test):
if g2 is None and g2_test is None:
continue
self.assertEqual(g2, g2_test, atol=5e-4, rtol=1e-4)
class JitCommonTestCase(TestCase):
def createFunctionFromGraph(self, trace):
graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
return torch._C._create_function_from_graph("forward", graph)
def assertExportImport(self, trace, inputs):
m = self.createFunctionFromGraph(trace)
self.assertExportImportModule(m, inputs)
def assertExportImportModule(self, m, inputs):
m_import = self.getExportImportCopy(m)
a = self.runAndSaveRNG(m, inputs)
b = self.runAndSaveRNG(m_import, inputs)
self.assertEqual(a, b, "Results of original model and "
"exported/imported version of model differed")
def runAndSaveRNG(self, func, inputs, kwargs=None):
kwargs = kwargs if kwargs else {}
with freeze_rng_state():
results = func(*inputs, **kwargs)
return results
def getExportImportCopy(self, m, also_test_file=True, map_location=None):
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
imported = torch.jit.load(buffer, map_location=map_location)
if not also_test_file:
return imported
with TemporaryFileName() as fname:
torch.jit.save(imported, fname)
return torch.jit.load(fname, map_location=map_location)
def autoDiffErrorMessage(self, should_autodiff_node, nodes_not_in_diff_graph,
fusion_nodes_not_found, non_fusible_nodes_being_fused,
fusion_nodes_found, nodes_in_diff_graph):
err_msg = "\nFailure in testing nodes' autodifferentiation. "
if should_autodiff_node:
err_msg += "One or more nodes were expected to be autodiffed, " \
"but were not found in specified fusible/nonfusible " \
"DifferentiableGraph groups. \nSpecifically:"
# The node is intended to appear in a differentiable graph but doesn't
diff_nodes_missing = []
# The node is intended to appear in a differentiable graph
# outside of a fusion group but instead is in a fusion group
diff_nodes_in_fusion = []
# The node is intended to appear in a fusion group but doesn't
fusion_nodes_missing = []
# The node is intended to appear in a fusion group but instead
# is just in an outer differentiable graph
fusion_nodes_in_diff = []
for node in nodes_not_in_diff_graph:
if node in non_fusible_nodes_being_fused:
diff_nodes_in_fusion.append(node)
else:
diff_nodes_missing.append(node)
for node in fusion_nodes_not_found:
if node in nodes_in_diff_graph:
fusion_nodes_in_diff.append(node)
else:
fusion_nodes_missing.append(node)
if len(diff_nodes_missing) > 0:
err_msg += f"\n {diff_nodes_missing} were not in one of the " \
"DifferentiableGraphs when they were expected to be. " \
"Did you intend for these nodes to be autodiffed? " \
"If not, remove them from the list of nonfusible nodes."
if len(diff_nodes_in_fusion) > 0:
err_msg += f"\n {diff_nodes_in_fusion} were found in one of the FusionGroups " \
"when they were expected to be just in a DifferentiableGraph. If it was " \
"intended for these nodes to be in FusionGroups, reclassify these nodes as " \
"fusible nodes. If these nodes were not intended to be fused, your " \
"autodifferentiation logic might be wrong."
if len(fusion_nodes_missing) > 0:
err_msg += f"\n {fusion_nodes_missing} were not in one of the FusionGroups " \
"of the DifferentiableGraphs when they were expected to be. " \
"They were also not found in an outer DifferentiableGraph. Did you " \
"intend for these nodes to be autodifferentiated? If not, you should " \
"remove these nodes from the test's fusible nodes. Otherwise your " \
"autodifferentiation logic might be wrong."
if len(fusion_nodes_in_diff) > 0:
err_msg += f"\n {fusion_nodes_in_diff} were not in one of the FusionGroups " \
"of the DifferentiableGraphs when they were expected to be, " \
"instead they were found just in an outer DifferentiableGraph. " \
"Did you intend for these nodes to be fused? If not, you should " \
"move these nodes into the test's nonfusible nodes. Otherwise your " \
"autodifferentiation logic might be wrong."
else:
err_msg += "One or more nodes were not expected to be autodiffed " \
"but were found in a DifferentiableGraph or in a FusionGroup " \
"of a DifferentiableGraph. Did you intend for these nodes to be " \
"autodiffed? If so, change this test to expect autodifferentiation. " \
"\nSpecifically:"
if len(fusion_nodes_found) > 0:
err_msg += f"\n {fusion_nodes_found} were not expected to be in " \
"one of the DifferentiableGraphs, but appeared in a FusionGroup " \
"of a DifferentiableGraph. "
if len(nodes_in_diff_graph) > 0:
err_msg += f"\n {nodes_in_diff_graph} were not expected to " \
"be in one of the DifferentiableGraphs but were."
return err_msg
def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes):
diff_nodes = graph.findAllNodes('prim::DifferentiableGraph')
diff_subgraphs = [node.g('Subgraph') for node in diff_nodes]
# Note: currently no tests have fusible_nodes
fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs]))
fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes]
# For any non-fusible node, it must show up in one of the DifferentiableGraphs.
nodes_in_diff_graph = []
nodes_not_in_diff_graph = []
non_fusible_nodes_being_fused = []
for node in nonfusible_nodes:
if any(g.findNode(node) is not None for g in diff_subgraphs):
nodes_in_diff_graph.append(node)
else:
nodes_not_in_diff_graph.append(node)
if any(g.findNode(node) is not None for g in fusion_subgraphs):
non_fusible_nodes_being_fused.append(node)
found_all_nonfusible_nodes = len(nodes_in_diff_graph) == len(nonfusible_nodes)
# For any fusible node, it must show up in one of the FusionGroups in one of the DifferentiableGraphs.
fusion_nodes_found = []
fusion_nodes_not_found = []
for node in fusible_nodes:
if any(g.findNode(node) is not None for g in fusion_subgraphs):
fusion_nodes_found.append(node)
else:
fusion_nodes_not_found.append(node)
found_all_fusible_nodes = len(fusion_nodes_found) == len(fusible_nodes)
if should_autodiff_node is not None:
err_msg = self.autoDiffErrorMessage(should_autodiff_node,
nodes_not_in_diff_graph,
fusion_nodes_not_found,
non_fusible_nodes_being_fused,
fusion_nodes_found,
nodes_in_diff_graph)
self.assertEqual(should_autodiff_node,
found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg)
def checkShapeAnalysis(self, out_sizes: Union[List[int], List[List[int]]],
traced_graph, assert_propagation, constant_prop=True):
# re-propagate input shapes provided by tracing
prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
for enable_test_mode in [True, False]:
# here we test both allowing and disallowing the substitution of complete shapes as constants;
# disallowing constants helps stress-test the partial evaluation and substitution pipeline
torch._C._jit_set_symbolic_shapes_test_mode(enable_test_mode)
torch._C._jit_erase_non_input_shape_information(traced_graph)
if constant_prop:
torch._C._jit_pass_constant_propagation(traced_graph)
torch._C._jit_pass_propagate_shapes_on_graph(traced_graph)
# Add sizes to a default tensor type to avoid checking anything out of scope
# and to avoid difficulties with the tracer leaving data in other parts of the tensor type
output = next(traced_graph.outputs()).type()
def test_type(type, actual_size):
sizes = type.symbolic_sizes()
out_type = TensorType.get().with_sizes(sizes)
actual_type = TensorType.get().with_sizes(actual_size)
# always check actual shape is a subtype of the output
self.assertTrue(actual_type.isSubtypeOf(out_type))
# and then if assertion flag is provided, check shape analysis
# is successful
if assert_propagation:
self.assertEqual(out_type.sizes(), actual_size)
if output.isSubtypeOf(torch._C.TensorType.get()):
test_type(output, out_sizes)
else:
tuple_elements = output.elements()
for i in range(len(tuple_elements)):
test_type(tuple_elements[i], out_sizes[i])
torch._C._jit_set_symbolic_shapes_test_mode(prev_symbolic_shapes_test_enabled)

File diff suppressed because it is too large

View File

@@ -0,0 +1,78 @@
# mypy: ignore-errors
import contextlib
import functools
import inspect
import torch
# Test whether hardware BF32 math mode is enabled. It is enabled only on:
# - MKLDNN is available
# - BF16 is supported by MKLDNN
def bf32_is_not_fp32():
if not torch.backends.mkldnn.is_available():
return False
if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
return False
return True
@contextlib.contextmanager
def bf32_off():
old_matmul_precision = torch.get_float32_matmul_precision()
try:
torch.set_float32_matmul_precision("highest")
yield
finally:
torch.set_float32_matmul_precision(old_matmul_precision)
@contextlib.contextmanager
def bf32_on(self, bf32_precision=1e-5):
old_matmul_precision = torch.get_float32_matmul_precision()
old_precision = self.precision
try:
torch.set_float32_matmul_precision("medium")
self.precision = bf32_precision
yield
finally:
torch.set_float32_matmul_precision(old_matmul_precision)
self.precision = old_precision
# This is a wrapper that wraps a test to run this test twice, one with
# allow_bf32=True, another with allow_bf32=False. When running with
# allow_bf32=True, it will use reduced precision as specified by the
# argument
def bf32_on_and_off(bf32_precision=1e-5):
def with_bf32_disabled(self, function_call):
with bf32_off():
function_call()
def with_bf32_enabled(self, function_call):
with bf32_on(self, bf32_precision):
function_call()
def wrapper(f):
params = inspect.signature(f).parameters
arg_names = tuple(params.keys())
@functools.wraps(f)
def wrapped(*args, **kwargs):
for k, v in zip(arg_names, args):
kwargs[k] = v
cond = bf32_is_not_fp32()
if "device" in kwargs:
cond = cond and (torch.device(kwargs["device"]).type == "cpu")
if "dtype" in kwargs:
cond = cond and (kwargs["dtype"] == torch.float)
if cond:
with_bf32_disabled(kwargs["self"], lambda: f(**kwargs))
with_bf32_enabled(kwargs["self"], lambda: f(**kwargs))
else:
f(**kwargs)
return wrapped
return wrapper
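# A minimal usage sketch (hypothetical test method; illustrative only): the body
# runs twice, once with float32 matmul precision set to "medium" and once with
# "highest", mirroring tf32_on_and_off.
#
#   @bf32_on_and_off(0.01)
#   def test_addmm(self, device, dtype):
#       ...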

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,385 @@
# Owner(s): ["module: unknown"]
from typing import Dict, Any, Tuple
from torch.ao.pruning import BaseSparsifier
import torch
import torch.nn.functional as F
from torch import nn
class ImplementedSparsifier(BaseSparsifier):
def __init__(self, **kwargs: Dict[str, Any]) -> None:
super().__init__(defaults=kwargs)
def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: Dict[str, Any]) -> None:
module.parametrizations.weight[0].mask[0] = 0
linear_state = self.state['linear1.weight']
linear_state['step_count'] = linear_state.get('step_count', 0) + 1
class MockSparseLinear(nn.Linear):
"""
This class is a MockSparseLinear class to check convert functionality.
It is the same as a normal Linear layer, except with a different type, as
well as an additional from_dense method.
"""
@classmethod
def from_dense(cls, mod: nn.Linear) -> 'MockSparseLinear':
"""
"""
linear = cls(mod.in_features,
mod.out_features)
return linear
def rows_are_subset(subset_tensor: torch.Tensor, superset_tensor: torch.Tensor) -> bool:
"""
Checks to see if all rows in subset tensor are present in the superset tensor
"""
i = 0
for row in subset_tensor:
while i < len(superset_tensor):
if not torch.equal(row, superset_tensor[i]):
i += 1
else:
break
else:
return False
return True
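# For example (illustrative values; note that rows must appear in the same
# relative order as in the superset, since the scan index only moves forward):
#
#   sup = torch.tensor([[1., 0.], [0., 1.], [2., 2.]])
#   rows_are_subset(torch.tensor([[1., 0.], [2., 2.]]), sup)  # True
#   rows_are_subset(torch.tensor([[3., 3.]]), sup)            # False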
class SimpleLinear(nn.Module):
r"""Model with only Linear layers without biases, some wrapped in a Sequential,
some following the Sequential. Used to test basic pruned Linear-Linear fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Linear(7, 5, bias=False),
nn.Linear(5, 6, bias=False),
nn.Linear(6, 4, bias=False),
)
self.linear1 = nn.Linear(4, 4, bias=False)
self.linear2 = nn.Linear(4, 10, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.linear1(x)
x = self.linear2(x)
return x
class LinearBias(nn.Module):
r"""Model with only Linear layers, alternating layers with biases,
wrapped in a Sequential. Used to test pruned Linear-Bias-Linear fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Linear(7, 5, bias=True),
nn.Linear(5, 6, bias=False),
nn.Linear(6, 3, bias=True),
nn.Linear(3, 3, bias=True),
nn.Linear(3, 10, bias=False),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
return x
class LinearActivation(nn.Module):
r"""Model with only Linear layers, some with bias, some in a Sequential and some following.
Activation function modules in between each Linear in the Sequential, and each outside layer.
Used to test pruned Linear(Bias)-Activation-Linear fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Linear(7, 5, bias=True),
nn.ReLU(),
nn.Linear(5, 6, bias=False),
nn.Tanh(),
nn.Linear(6, 4, bias=True),
)
self.linear1 = nn.Linear(4, 3, bias=True)
self.act1 = nn.ReLU()
self.linear2 = nn.Linear(3, 10, bias=False)
self.act2 = nn.Tanh()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.linear1(x)
x = self.act1(x)
x = self.linear2(x)
x = self.act2(x)
return x
class LinearActivationFunctional(nn.Module):
r"""Model with only Linear layers, some with bias, some in a Sequential and some following.
Activation function modules in between each Linear in the Sequential, and functional
activations are called in between each outside layer.
Used to test pruned Linear(Bias)-Activation-Linear fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Linear(7, 5, bias=True),
nn.ReLU(),
nn.Linear(5, 6, bias=False),
nn.ReLU(),
nn.Linear(6, 4, bias=True),
)
self.linear1 = nn.Linear(4, 3, bias=True)
self.linear2 = nn.Linear(3, 8, bias=False)
self.linear3 = nn.Linear(8, 10, bias=False)
self.act1 = nn.ReLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.linear1(x)
x = F.relu(x)
x = self.linear2(x)
x = F.relu(x)
x = self.linear3(x)
x = F.relu(x)
return x
class SimpleConv2d(nn.Module):
r"""Model with only Conv2d layers, all without bias, some in a Sequential and some following.
Used to test pruned Conv2d-Conv2d fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, bias=False),
nn.Conv2d(32, 64, 3, 1, bias=False),
)
self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.conv2d1(x)
x = self.conv2d2(x)
return x
class Conv2dBias(nn.Module):
r"""Model with only Conv2d layers, some with bias, some in a Sequential and some outside.
Used to test pruned Conv2d-Bias-Conv2d fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, bias=True),
nn.Conv2d(32, 32, 3, 1, bias=True),
nn.Conv2d(32, 64, 3, 1, bias=False),
)
self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=True)
self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.conv2d1(x)
x = self.conv2d2(x)
return x
class Conv2dActivation(nn.Module):
r"""Model with only Conv2d layers, some with bias, some in a Sequential and some following.
Activation function modules in between each Sequential layer, functional activations called
in-between each outside layer.
Used to test pruned Conv2d-Bias-Activation-Conv2d fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, bias=True),
nn.ReLU(),
nn.Conv2d(32, 64, 3, 1, bias=True),
nn.Tanh(),
nn.Conv2d(64, 64, 3, 1, bias=False),
nn.ReLU(),
)
self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.conv2d1(x)
x = F.relu(x)
x = self.conv2d2(x)
x = F.hardtanh(x)
return x
class Conv2dPadBias(nn.Module):
r"""Model with only Conv2d layers, all with bias and some with padding > 0,
some in a Sequential and some following. Activation function modules in between each layer.
Used to test that bias is propagated correctly in the special case of
pruned Conv2d-Bias-(Activation)Conv2d fusion, when the second Conv2d layer has padding > 0."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, padding=1, bias=True),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, bias=False),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
nn.ReLU(),
nn.Conv2d(32, 64, 3, 1, bias=True),
nn.Tanh(),
)
self.conv2d1 = nn.Conv2d(64, 48, 3, 1, padding=1, bias=True)
self.act1 = nn.ReLU()
self.conv2d2 = nn.Conv2d(48, 52, 3, 1, padding=1, bias=True)
self.act2 = nn.Tanh()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.conv2d1(x)
x = self.act1(x)
x = self.conv2d2(x)
x = self.act2(x)
return x
class Conv2dPool(nn.Module):
r"""Model with only Conv2d layers, all with bias, some in a Sequential and some following.
Activation function modules in between each layer, Pool2d modules in between each layer.
Used to test pruned Conv2d-Pool2d-Conv2d fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True),
nn.Tanh(),
nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
)
self.conv2d1 = nn.Conv2d(64, 48, kernel_size=3, padding=1, bias=True)
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
self.af1 = nn.ReLU()
self.conv2d2 = nn.Conv2d(48, 52, kernel_size=3, padding=1, bias=True)
self.conv2d3 = nn.Conv2d(52, 52, kernel_size=3, padding=1, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.conv2d1(x)
x = self.maxpool(x)
x = self.af1(x)
x = self.conv2d2(x)
x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1)
x = F.relu(x)
x = self.conv2d3(x)
return x
class Conv2dPoolFlattenFunctional(nn.Module):
r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d
and a functional Flatten followed by a Linear layer.
Activation functions and Pool2ds in between each layer also.
Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
nn.Tanh(),
nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
)
self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
self.af1 = nn.ReLU()
self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(11, 13, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.conv2d1(x)
x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
x = self.af1(x)
x = self.conv2d2(x)
x = self.avg_pool(x)
x = torch.flatten(x, 1) # test functional flatten
x = self.fc(x)
return x
class Conv2dPoolFlatten(nn.Module):
r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d
and a Flatten module followed by a Linear layer.
Activation functions and Pool2ds in between each layer also.
Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
nn.Tanh(),
nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
)
self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
self.af1 = nn.ReLU()
self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
self.avg_pool = nn.AdaptiveAvgPool2d((2, 2))
self.flatten = nn.Flatten()
self.fc = nn.Linear(44, 13, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.conv2d1(x)
x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
x = self.af1(x)
x = self.conv2d2(x)
x = self.avg_pool(x)
x = self.flatten(x)
x = self.fc(x)
return x
class LSTMLinearModel(nn.Module):
"""Container module with an encoder, a recurrent module, and a linear."""
def __init__(
self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
) -> None:
super().__init__()
self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
self.linear = nn.Linear(hidden_dim, output_dim)
def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
output, hidden = self.lstm(input)
decoded = self.linear(output)
return decoded, output
class LSTMLayerNormLinearModel(nn.Module):
"""Container module with an LSTM, a LayerNorm, and a linear."""
def __init__(
self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
) -> None:
super().__init__()
self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
self.norm = nn.LayerNorm(hidden_dim)
self.linear = nn.Linear(hidden_dim, output_dim)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
x, state = self.lstm(x)
x = self.norm(x)
x = self.linear(x)
return x, state

File diff suppressed because it is too large

View File

@@ -0,0 +1,227 @@
# mypy: ignore-errors
r"""Importing this file includes common utility methods for checking quantized
tensors and modules.
"""
import numpy as np
import torch
from contextlib import contextmanager
from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_TSAN, TEST_WITH_UBSAN, IS_PPC, IS_MACOS, IS_WINDOWS
supported_qengines = torch.backends.quantized.supported_engines
supported_qengines.remove('none')
# Note: We currently do not run QNNPACK tests on WINDOWS and MACOS as it is flaky. Issue #29326
# QNNPACK is not supported on PPC
# QNNPACK throws ASAN heap-buffer-overflow error.
if 'qnnpack' in supported_qengines and any([IS_PPC, TEST_WITH_ASAN, TEST_WITH_TSAN, TEST_WITH_UBSAN, IS_MACOS, IS_WINDOWS]):
supported_qengines.remove('qnnpack')
def _conv_output_shape(input_size, kernel_size, padding, stride, dilation,
output_padding=0):
"""Computes the output shape given convolution parameters."""
return np.floor((input_size + 2 * padding - kernel_size - (kernel_size - 1)
* (dilation - 1)) / stride) + 2 * output_padding + 1
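# For example, a 5-wide input with kernel 3, no padding, stride 1, dilation 1
# (illustrative):
#   _conv_output_shape(5, 3, 0, 1, 1)  # -> 3.0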
# Quantization references
def _quantize(x, scale, zero_point, qmin=None, qmax=None, dtype=np.uint8):
"""Quantizes a numpy array."""
if qmin is None:
qmin = np.iinfo(dtype).min
if qmax is None:
qmax = np.iinfo(dtype).max
qx = np.round(x / scale + zero_point).astype(np.int64)
qx = np.clip(qx, qmin, qmax)
qx = qx.astype(dtype)
return qx
def _dequantize(qx, scale, zero_point):
"""Dequantizes a numpy array."""
x = (qx.astype(float) - zero_point) * scale
return x
def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
"""Requantizes a numpy array, i.e., intermediate int32 or int16 values are
converted back to given type"""
qx = (x * multiplier).round() + zero_point
qx = np.clip(qx, qmin, qmax).astype(qtype)
return qx
def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine):
"""Calculate the dynamic quantization parameters (scale, zero_point)
according to the min and max element of the tensor"""
assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
if qscheme == torch.per_tensor_symmetric:
assert dtype == torch.qint8
if isinstance(X, torch.Tensor):
X = X.numpy()
if dtype == torch.qint8:
if reduce_range:
qmin, qmax = -64, 63
else:
qmin, qmax = -128, 127
else: # dtype == torch.quint8
if reduce_range:
qmin, qmax = 0, 127
else:
qmin, qmax = 0, 255
min_val = X.min()
max_val = X.max()
is_symmetric = (qscheme == torch.per_tensor_symmetric)
if min_val == max_val:
scale = 1.0
zero_point = 0
else:
if is_symmetric:
max_val = max(max_val, -min_val)
min_val = -max_val
scale = (max_val - min_val) / (qmax - qmin)
scale = max(scale, np.finfo(np.float32).eps)
zero_point = 0
else:
max_val = max(max_val, 0.0)
min_val = min(min_val, 0.0)
scale = (max_val - min_val) / (qmax - qmin)
scale = max(scale, np.finfo(np.float32).eps)
zero_point = qmin - round(min_val / scale)
zero_point = max(qmin, zero_point)
zero_point = min(qmax, zero_point)
return [float(scale), int(zero_point)]
def _calculate_dynamic_per_channel_qparams(X, dtype):
"""Calculate the dynamic quantization parameters (scale, zero_point)
according to the min and max element of the tensor"""
if isinstance(X, torch.Tensor):
X = X.numpy()
qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
n_levels = qmax - qmin
scale = np.zeros(X.shape[0], dtype=np.float64)
zero_point = np.zeros(X.shape[0], dtype=np.int64)
for i in range(zero_point.shape[0]):
min_val = X.min()
max_val = X.max()
if min_val == max_val:
scale[i] = 1.0
zero_point[i] = 0
else:
max_val = max(max_val, 0.0)
min_val = min(min_val, 0.0)
scale[i] = (max_val - min_val) / n_levels
scale[i] = max(scale[i], np.finfo(np.float32).eps)
zero_point[i] = qmin - round(min_val / scale[i])
zero_point[i] = max(qmin, zero_point[i])
zero_point[i] = min(qmax, zero_point[i])
return scale, zero_point
def _snr(x, x_hat):
"""Calculates the signal to noise ratio and returns the signal and noise
power, as well as the SNR in dB.
If the input is a list/tuple this function is called recursively on each
element. The result will have the same nested structure as the inputs.
Args:
x, x_hat: Either a tensor or a nested list/tuple of tensors.
Returns:
signal, noise, SNR(in dB): Either floats or a nested list of floats
"""
if isinstance(x, (list, tuple)):
assert len(x) == len(x_hat)
res = []
for idx in range(len(x)):
res.append(_snr(x[idx], x_hat[idx]))
return res
if x_hat.is_quantized:
x_hat = x_hat.dequantize()
if x.is_quantized:
x = x.dequantize()
noise = (x - x_hat).norm()
if noise == 0:
return 0.0, float('inf'), float('inf')
signal = x.norm()
snr = signal / noise
snr_db = 20 * snr.log10()
return signal, noise, snr_db
@contextmanager
def override_quantized_engine(qengine):
previous = torch.backends.quantized.engine
torch.backends.quantized.engine = qengine
try:
yield
finally:
torch.backends.quantized.engine = previous
@contextmanager
def override_cpu_allocator_for_qnnpack(qengine_is_qnnpack):
try:
if qengine_is_qnnpack:
torch._C._set_default_mobile_cpu_allocator()
yield
finally:
if qengine_is_qnnpack:
torch._C._unset_default_mobile_cpu_allocator()
# TODO: Update all quantization tests to use this decorator.
# Currently for some of the tests it seems to have inconsistent params
# for fbgemm vs qnnpack.
def override_qengines(qfunction):
def test_fn(*args, **kwargs):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
# qfunction should not return anything.
qfunction(*args, **kwargs)
return test_fn
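# Typical usage (illustrative): the decorated body executes once per engine in
# supported_qengines, with torch.backends.quantized.engine set accordingly.
#
#   @override_qengines
#   def test_quantized_linear(self):
#       qengine = torch.backends.quantized.engine
#       ...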
def qengine_is_fbgemm():
return torch.backends.quantized.engine == 'fbgemm'
def qengine_is_qnnpack():
return torch.backends.quantized.engine == 'qnnpack'
def qengine_is_onednn():
return torch.backends.quantized.engine == 'onednn'
def qengine_is_x86():
return torch.backends.quantized.engine == 'x86'
# Helper function used to simulate per-channel fake-quant against any axis
def _permute_to_axis_zero(X, axis):
new_axis_list = list(range(X.dim()))
new_axis_list[axis] = 0
new_axis_list[0] = axis
y = X.permute(tuple(new_axis_list))
return y, new_axis_list
# Reference method for fake quantize
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
dtype = X.dtype
X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
res = torch.zeros_like(X)
for i in range(X.size()[0]):
res[i] = (torch.clamp(torch.round(X[i] * (1.0 / per_channel_scale[i]) +
per_channel_zero_point[i]), quant_min, quant_max) - per_channel_zero_point[i]) * per_channel_scale[i]
out = res.permute(tuple(permute_axis_list))
return out.to(dtype)
# Reference method for the gradient of the fake quantize operator
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
dtype = X.dtype
X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
Xq = torch.zeros_like(X)
for i in range(X.size()[0]):
Xq[i] = torch.round(X[i] * (1.0 / per_channel_scale[i]) + per_channel_zero_point[i])
Xq = Xq.permute(tuple(permute_axis_list))
mask = (Xq >= quant_min) * (Xq <= quant_max)
res = torch.zeros_like(dY)
res[mask] = dY[mask]
return res.to(dtype)
def to_tensor(X, device):
if not isinstance(X, torch.Tensor):
X = torch.tensor(X)
else:
X = X.clone().detach()
return X.to(device=torch.device(device), dtype=torch.float32)

View File

@@ -0,0 +1,266 @@
# mypy: ignore-errors
import torch
from copy import deepcopy
from torch.utils._pytree import tree_map
import torch.utils._pytree as pytree
# TODO: Move LoggingTensor here.
from torch.testing._internal.logging_tensor import LoggingTensor
# Base class for wrapper-style tensors.
class WrapperTensor(torch.Tensor):
@staticmethod
def __new__(cls, *args, **kwargs):
t, kwargs = cls.get_wrapper_properties(*args, **kwargs)
if "size" not in kwargs:
size = t.size()
else:
size = kwargs["size"]
del kwargs["size"]
if "dtype" not in kwargs:
kwargs["dtype"] = t.dtype
if "layout" not in kwargs:
kwargs["layout"] = t.layout
if "device" not in kwargs:
kwargs["device"] = t.device
if "requires_grad" not in kwargs:
kwargs["requires_grad"] = False
# Ignore memory_format and pin memory for now as I don't know how to
# safely access them on a Tensor (if possible??)
wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **kwargs)
wrapper._validate_methods()
return wrapper
@classmethod
def get_wrapper_properties(cls, *args, **kwargs):
# Should return both an example Tensor and a dictionary of kwargs
# to override any of that example Tensor's properties.
# This is very similar to the `t.new_*(args)` API
raise NotImplementedError("You need to implement get_wrapper_properties")
def _validate_methods(self):
# Skip this if not in debug mode?
# Changing these on the python side is wrong as it would not be properly reflected
# on the c++ side
# This doesn't catch attributes set in the __init__
forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"]
for el in forbidden_overrides:
if getattr(self.__class__, el) is not getattr(torch.Tensor, el):
raise RuntimeError(f"Subclass {self.__class__.__name__} is overwriting the "
f"property {el} but this is not allowed as such change would "
"not be reflected to c++ callers.")
class DiagTensorBelow(WrapperTensor):
@classmethod
def get_wrapper_properties(cls, diag, requires_grad=False):
assert diag.ndim == 1
return diag, {"size": diag.size() + diag.size(), "requires_grad": requires_grad}
def __init__(self, diag, requires_grad=False):
self.diag = diag
handled_ops = {}
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if not all(issubclass(cls, t) for t in types):
return NotImplemented
# For everything else, call the handler:
fn = cls.handled_ops.get(func.__name__, None)
if fn:
return fn(*args, **(kwargs or {}))
else:
# Note that here, because we don't need to provide the autograd formulas
# we can have a default "fallback" that creates a plain Tensor based
# on the diag elements and calls the func again.
def unwrap(e):
return e.diag.diag() if isinstance(e, DiagTensorBelow) else e
def wrap(e):
if isinstance(e, torch.Tensor) and e.ndim == 1:
return DiagTensorBelow(e)
if isinstance(e, torch.Tensor) and e.ndim == 2 and e.count_nonzero() == e.diag().count_nonzero():
return DiagTensorBelow(e.diag())
return e
rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
return rs
def __repr__(self):
return super().__repr__(tensor_contents=f"diag={self.diag}")
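# Illustrative behaviour (values assumed): ops fall back to a dense diag()
# matrix, and a still-diagonal 2-D result is re-wrapped as a DiagTensorBelow.
#
#   d = DiagTensorBelow(torch.tensor([1., 2., 3.]))
#   (d + d).diag  # tensor([2., 4., 6.])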
class SparseTensor(WrapperTensor):
@classmethod
def get_wrapper_properties(cls, size, values, indices, requires_grad=False):
assert values.device == indices.device
return values, {"size": size, "requires_grad": requires_grad}
def __init__(self, size, values, indices, requires_grad=False):
self.values = values
self.indices = indices
def __repr__(self):
return super().__repr__(tensor_contents=f"values={self.values}, indices={self.indices}")
def sparse_to_dense(self):
res = torch.zeros(self.size(), dtype=self.values.dtype)
res[self.indices.unbind(1)] = self.values
return res
@staticmethod
def from_dense(t):
indices = t.nonzero()
values = t[indices.unbind(1)]
return SparseTensor(t.size(), values, indices)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
func_name = f"{func.__module__}.{func.__name__}"
res = cls._try_call_special_impl(func_name, args, kwargs)
if res is not NotImplemented:
return res
# Otherwise, use a default implementation that construct dense
# tensors and use that to compute values
def unwrap(e):
return e.sparse_to_dense() if isinstance(e, SparseTensor) else e
# Wrap back all Tensors into our custom class
def wrap(e):
# Check for zeros and use that to get indices
return SparseTensor.from_dense(e) if isinstance(e, torch.Tensor) else e
rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
return rs
# To show how things happen later
def __rmul__(self, other):
return super().__rmul__(other)
_SPECIAL_IMPLS = {}
@classmethod
def _try_call_special_impl(cls, func, args, kwargs):
if func not in cls._SPECIAL_IMPLS:
return NotImplemented
return cls._SPECIAL_IMPLS[func](args, kwargs)
# Example non-wrapper subclass that stores extra state.
class NonWrapperTensor(torch.Tensor):
def __new__(cls, data):
t = torch.Tensor._make_subclass(cls, data)
t.extra_state = {
'last_func_called': None
}
return t
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
result = super().__torch_function__(func, types, args, kwargs)
if isinstance(result, cls):
# Do something with the extra state. For the example here, just store the name of the
# last function called (skip for deepcopy so the copy has the same extra state).
if func is torch.Tensor.__deepcopy__:
result.extra_state = deepcopy(args[0].extra_state)
else:
result.extra_state = {
'last_func_called': func.__name__,
}
return result
# new_empty() must be defined for deepcopy to work
def new_empty(self, shape):
return type(self)(torch.empty(shape))
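# Illustrative (name assumed to follow func.__name__ of the overridden op):
#
#   t = NonWrapperTensor(torch.ones(2))
#   (t * 2).extra_state['last_func_called']  # '__mul__'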
# Class used to store info about subclass tensors used in testing.
class SubclassInfo:
__slots__ = ['name', 'create_fn', 'closed_under_ops']
def __init__(self, name, create_fn, closed_under_ops=True):
self.name = name
self.create_fn = create_fn # create_fn(shape) -> tensor instance
self.closed_under_ops = closed_under_ops
subclass_db = {
torch.Tensor: SubclassInfo(
'base_tensor', create_fn=torch.randn
),
NonWrapperTensor: SubclassInfo(
'non_wrapper_tensor',
create_fn=lambda shape: NonWrapperTensor(torch.randn(shape))
),
LoggingTensor: SubclassInfo(
'logging_tensor',
create_fn=lambda shape: LoggingTensor(torch.randn(shape))
),
SparseTensor: SubclassInfo(
'sparse_tensor',
create_fn=lambda shape: SparseTensor.from_dense(torch.randn(shape).relu())
),
DiagTensorBelow: SubclassInfo(
'diag_tensor_below',
create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)),
closed_under_ops=False # sparse semantics
),
}
class SubclassWithTensorFactory(torch.Tensor):
@staticmethod
def __new__(cls, src):
shape = src.shape
kwargs = {}
kwargs["strides"] = src.stride()
kwargs["storage_offset"] = src.storage_offset()
kwargs["device"] = src.device
kwargs["layout"] = src.layout
kwargs["requires_grad"] = src.requires_grad
kwargs["dtype"] = src.dtype
out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
return out
def __init__(self, src):
self.src = src
def __repr__(self):
return f"{self.__class__.__name__}"
def __tensor_flatten__(self):
return ["src"], None
@classmethod
def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride):
src = inner_tensors["src"]
return cls(src)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if kwargs is None:
kwargs = {}
def _fn(x):
return x.src * torch.ones(x.src.shape) if x.src.dtype == torch.float32 else x.src
_args = pytree.tree_map_only(cls, _fn, args)
_kwargs = pytree.tree_map_only(cls, _fn, kwargs)
_out = func(*_args, **_kwargs)
_out_flat, _out_spec = pytree.tree_flatten(_out)
out_flat = [cls(o) if isinstance(o, torch.Tensor) else o for o in _out_flat]
return pytree.tree_unflatten(out_flat, _out_spec)

File diff suppressed because it is too large

View File

@ -0,0 +1,581 @@
# mypy: ignore-errors
import torch
from torch import Tensor
import itertools
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch.utils import _pytree as pytree
from functools import partial
from torch.utils._mode_utils import no_dispatch, all_same_mode
import torch.autograd.forward_ad as fwAD
from typing import Callable
import re
def check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor):
elem = wrapper_tensor.elem
metadata_wrapper_tensor = metadata_accessor(wrapper_tensor)
metadata_elem = metadata_accessor(elem)
if metadata_wrapper_tensor == metadata_elem:
return
raise RuntimeError(
f"This operator is not Composite Compliant: the "
f"{metadata_name} of the tensor was modified directly without "
f"going through the PyTorch dispatcher.")
def check_metadata_consistency(wrapper_tensor, CCT):
    # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
if not isinstance(wrapper_tensor, CCT):
return
things_to_check = {
'shape': Tensor.size,
'dtype': lambda x: x.dtype,
'device': lambda x: x.device,
'numel': Tensor.numel,
'stride': Tensor.stride,
'storage_offset': Tensor.storage_offset,
}
for metadata_name, metadata_accessor in things_to_check.items():
check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor)
def is_view_fn(func):
return func.overloadpacket.__name__ in {
'as_strided',
'detach',
'diagonal',
'expand',
'expand_as',
'movedim',
'narrow',
'permute',
'select',
'squeeze',
'transpose',
't',
'real',
'imag',
'view_as_real',
'view_as_complex',
'unflatten',
'unfold',
'unsqueeze',
'view',
'view_as',
'unbind',
'split',
'split_with_sizes',
'vsplit',
'hsplit',
'tensor_split',
'chunk',
'swapaxes',
'slice',
'_reshape_alias',
'_unsafe_view',
'_conj',
'alias',
}
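# Hedged illustration (not part of the original file): is_view_fn only looks at
# the overload packet name, so for example
#   is_view_fn(torch.ops.aten.view.default)  -> True   ('view' is in the set)
#   is_view_fn(torch.ops.aten.add.Tensor)    -> False  ('add' is not)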
# manually populated from native_functions that have inplace_view: True.
# In the future we will probably be able to grab that list directly
def is_inplace_view_fn(func):
return func.overloadpacket.__name__ in {
'as_strided_',
'detach_',
'squeeze_',
'swapaxes_',
'swapdims_',
't_',
'transpose_',
'unsqueeze_',
}
# Introspection please save us
def is_inplace(func):
name = func.overloadpacket.__name__
if re.match('__i.+__', name):
return True
if re.match('__.+__', name):
return False
return name[-1] == '_'
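# Hedged illustration (not part of the original file): the same naming heuristic
# applied to plain strings, to show exactly what is_inplace keys off of;
# _is_inplace_name is a made-up helper used only for illustration.
def _is_inplace_name(name):
    if re.match('__i.+__', name):
        return True   # in-place dunder, e.g. '__iadd__'
    if re.match('__.+__', name):
        return False  # other dunders, e.g. '__add__', are out-of-place
    return name[-1] == '_'  # trailing-underscore convention, e.g. 'add_'
assert _is_inplace_name('add_') and _is_inplace_name('__iadd__')
assert not _is_inplace_name('add') and not _is_inplace_name('__add__')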
def generate_cct_and_mode(autograd_view_consistency=True):
# This function returns a new class CompositeCompliantTensor
    # The argument controls the behaviour described below.
# autograd_view_consistency:
# If True, alias result using `set_` if func returns a view
# (See Note [Alias Result]).
# Since Forward AD doesn't work with `set_`
# we disable it by setting alias to False.
class CompositeCompliantTensor(torch.Tensor):
elem: torch.Tensor
__slots__ = ['elem']
@staticmethod
def __new__(cls, elem, mode, *args, **kwargs):
assert type(elem) is not cls, \
"Wrapping a CompositeCompliantTensor in a CompositeCompliantTensor is not supported"
# The storage of CompositeCompliantTensor should never be used directly
# by a Composite operation; if the Composite
# operator attempts to read from the storage without dispatching then it'll
# raise a RuntimeError due to it being a meta storage.
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls, elem.size(),
dtype=elem.dtype, layout=elem.layout,
device=elem.device, requires_grad=elem.requires_grad,
strides=elem.stride(), storage_offset=elem.storage_offset())
if elem.requires_grad:
# CompositeCompliantTensor steals the "requires_grad"-ness.
# Why a new copy of `elem`? Because sometimes OpInfo shares inputs between tests...
tmp = torch.empty_strided(elem.shape, elem.stride(), dtype=elem.dtype,
device=elem.device, layout=elem.layout,
requires_grad=False)
tmp.copy_(elem.detach())
r.elem = tmp
else:
r.elem = elem
assert r.stride() == r.elem.stride()
# Propagate conjugate bits to the wrapper tensor
# Ref: https://github.com/albanD/subclass_zoo/issues/24
# Ref: https://github.com/albanD/subclass_zoo/issues/21
torch._C._set_conj(r, r.elem.is_conj())
torch._C._set_neg(r, r.elem.is_neg())
r.mode = mode
return r
def __repr__(self):
return f"CompositeCompliantTensor({self.elem})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
all_args = pytree.arg_tree_leaves(*args, **(kwargs or {}))
modes = tuple(e.mode for e in all_args if isinstance(e, CompositeCompliantTensor))
if not all_same_mode(modes):
raise RuntimeError("Multiple CompositeCompliantTensorModes NYI")
with modes[0]:
return func(*args, **kwargs)
class CompositeCompliantTensorMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
def unwrap(e):
return e.elem if isinstance(e, CompositeCompliantTensor) else e
def wrap(e):
return CompositeCompliantTensor(e, self) if isinstance(e, torch.Tensor) else e
if func == torch.ops.aten._local_scalar_dense.default:
raise RuntimeError(
".item() is not allowed to be called inside of composite "
"functions in the PyTorch library because not all backends "
"and/or Tensor subclasses (e.g. vmap, ProxyTensor) support them.")
if func.overloadpacket.__name__ in ('set_', 'resize_'):
raise RuntimeError(
f"{func.__name__} is not allowed to be called inside of "
f"Composite operators.")
if is_inplace(func):
# NB: We are making an assumption that if the function is in-place,
# then the first argument is being written to. Introspection please save us!
mutated_argument = args[0]
if not isinstance(mutated_argument, CompositeCompliantTensor) and \
any(isinstance(a, CompositeCompliantTensor) for a in args[1:]):
raise RuntimeError(
'Not composite compliant: performing in-place operation '
f'{func.__name__} where the Tensor being written to is '
'regular Tensor but the other tensors are Tensor Subclasses. '
'Please try to avoid this in-place operation.')
unwrapped_args = tree_map(unwrap, args)
unwrapped_kwargs = tree_map(unwrap, kwargs)
unwrapped_rs = func(*unwrapped_args, **unwrapped_kwargs)
rs = tree_map(wrap, unwrapped_rs)
if is_view_fn(func) and autograd_view_consistency:
# Note [Alias Result]
# Autograd asserts that for B = A.view_fn(...), B and A's storages
# are the same. Here we try to make B alias A to avoid those asserts.
# See https://github.com/pytorch/pytorch/issues/65339 for more information
# about the issue.
with no_dispatch():
# Idea: this is a weird way of getting a storage that aliases the input.
# This is a workaround for #65339.
# 1. under no_dispatch, all of the wrapper tensors look like regular
# tensors with special storage (the storage is nullptr and
                    #    advertises a CPU/CUDA device).
# 2. we run func, which ends up running the view operation
# 3. All view operations reuse the input's storage and return
# result Tensor(s) with new sizes/strides/offset that alias
# the input.
# 4. we set the storage (and sizes/strides/offset) of the wrapper
# tensor results to be that of the tensors that alias the input
result = func(*args, **kwargs)
if isinstance(result, (tuple, list)):
for a, b in zip(rs, result):
a.set_(b)
else:
rs.set_(result)
# Some operations are allowed to in-place modify the metadata of the
# inputs. The only ones are the "inplace view functions"; when we
# run into these, we manually modify the metadata of the input.
with no_dispatch():
if is_inplace_view_fn(func):
func(*args, **kwargs)
# For each CompositeCompliantTensor t, we check that t and t.elem
# have consistent metadata. If they don't have consistent metadata,
# that means the operator did something fishy.
check = partial(check_metadata_consistency, CCT=CompositeCompliantTensor)
pytree.tree_map_(check, args)
pytree.tree_map_(check, kwargs)
pytree.tree_map_(check, rs)
return rs
return CompositeCompliantTensor, CompositeCompliantTensorMode()
def is_tensorlist(lst):
if not isinstance(lst, list) and not isinstance(lst, tuple):
return False
if len(lst) == 0:
return False
all_tensors = all(isinstance(elt, torch.Tensor) for elt in lst)
if all_tensors:
return True
    exists_one_tensor = any(isinstance(elt, torch.Tensor) for elt in lst)
if exists_one_tensor:
raise RuntimeError('This test assumes that PyTorch APIs cannot take '
'mixed lists of Tensor and other things')
return False
def maybe_map(fn, should_map, arg):
return fn(arg) if should_map else arg
def wrap(arg, CCT, cct_mode):
# CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
if isinstance(arg, torch.Tensor):
return CCT(arg, cct_mode)
if is_tensorlist(arg):
return [CCT(a, cct_mode) for a in arg]
raise RuntimeError("wrap assumes that the input can be wrapped")
# Given a list of flat arguments, some of which may be Tensors, return all
# possible ways some of the arguments could be CompositeCompliantTensors (CCT).
# For example, given Tensors A, B, C and flat_args = [A, 1, B],
# We would return the following 4 options:
# [CCT(A), 1, CCT(B)]
# [CCT(A), 1, B]
# [A, 1, CCT(B)]
# [A, 1, B]
# NB: Yes, this is exponential. No, we don't care too much because PyTorch ops
# don't accept that many input Tensors.
def generate_subclass_choices(flat_args, CCT, cct_mode):
# CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
is_tensor_likes = [isinstance(arg, torch.Tensor) or is_tensorlist(arg) for arg in flat_args]
subclass_options = [[False, True] if is_tensor_like else [False] for is_tensor_like in is_tensor_likes]
for which_args_are_wrapped in itertools.product(*subclass_options):
result = [maybe_map(partial(wrap, CCT=CCT, cct_mode=cct_mode), should_wrap_arg, arg)
for should_wrap_arg, arg in zip(which_args_are_wrapped, flat_args)]
yield result, which_args_are_wrapped
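# Hedged sketch (not part of the original file): the choices yielded above are
# just the Cartesian product of wrap/don't-wrap flags over the Tensor-like
# positions; _demo_wrap_flags is a made-up helper showing that enumeration.
def _demo_wrap_flags(flat_args):
    is_tensor_like = [isinstance(a, torch.Tensor) or is_tensorlist(a) for a in flat_args]
    options = [[False, True] if t else [False] for t in is_tensor_like]
    # e.g. [A, 1, B] -> (False, False), (False, True), (True, False), (True, True)
    return list(itertools.product(*options))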
# For an operation f(*args, **kwargs), each Tensor argument may either be
# a regular Tensor or a Tensor Subclass. This iterator iterates through
# all of those options.
def generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
# CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
flat_kwargs, spec = tree_flatten(kwargs)
flat_args_kwargs = list(args) + list(flat_kwargs)
for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT, cct_mode):
new_args = choice[:len(args)]
new_kwargs = tree_unflatten(choice[len(args):], spec)
which_args_are_wrapped = debug_metadata[:len(args)]
which_kwargs_are_wrapped = tree_unflatten(debug_metadata[len(args):], spec)
yield new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped
def raise_composite_compliance_error(err, additional_info=''):
raise RuntimeError(
"Composite compliance check failed with "
"the above error.\n"
f"{additional_info}"
"If you are adding an OpInfo of an "
"existing operator, please feel free to skip this test "
"because the problem was pre-existing and file an issue. "
"Otherwise, if you added a new operator, please read "
"through the Composite Compliance section in "
"aten/src/ATen/native/README.md for how to resolve this. "
) from err
# This test checks ALL possible permutations of calling `op` with arguments
# that are individually either a regular Tensor or a Tensor subclass.
#
# The general strategy is to wrap some Tensor args and kwargs in
# CompositeCompliantTensor wrappers and call the operation.
# If some composite operation does any non-compliant behavior,
# CompositeCompliantTensor will raise an error.
def check_all_permutations(op, args, kwargs, assert_equal_fn):
CCT, cct_mode = generate_cct_and_mode()
expected = op(*args, **kwargs)
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
try:
actual = op(*new_args, **new_kwargs)
# NOTE: [What errors are Composite Compliance trying to catch?]
#
# There's two things we want to catch:
# - errors that would raise within the torch_dispatch impl
# - data_ptr accesses
# The first is easy to filter for (we could make the error a different
# error class), the second is always going to be a RuntimeError due to
            # how it is implemented (if you try to access the data_ptr of the
            # wrapper Tensor, it raises an internal RuntimeError).
#
            # So the most general thing to catch here is RuntimeError. If you
# are here and debugging why your test failed, it's plausible that
# the operator itself is broken and that there are other tests failing.
except RuntimeError as err:
raise_composite_compliance_error(
err,
f"- wrapped_args: {which_args_are_wrapped}\n"
f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
)
def unwrap(e):
return e.elem if isinstance(e, CCT) else e
assert_equal_fn(tree_map(unwrap, actual), expected)
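# Hedged usage sketch (not part of the original file): torch.add is used as a
# stand-in op and torch.testing.assert_close as the comparison function.
#
#   check_all_permutations(
#       torch.add,
#       (torch.randn(3), torch.randn(3)),
#       {},
#       assert_equal_fn=torch.testing.assert_close,
#   )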
# Checks via the usage of torch dispatch mode certain anti-patterns that
# are not composite compliant.
#
# In particular, the anti-pattern we are trying to prevent is a user
# creating an empty tensor and then resize_-ing it. Torch Dispatch Mode helps
# here because all factory functions will create tensors that are
# CompositeCompliantTensor.
#
# The general strategy is to wrap all Tensor args and kwargs in
# CompositeCompliantTensor wrappers. If an operator that is
# Composite does any non-compliant behavior,
# CompositeCompliantTensor will raise an error.
def check_with_mode(op, args, kwargs, assert_equal_fn):
CCT, cct_mode = generate_cct_and_mode()
def wrap(e):
return CCT(e, cct_mode) if isinstance(e, torch.Tensor) else e
expected = op(*args, **kwargs)
args = tree_map(wrap, args)
kwargs = tree_map(wrap, kwargs)
try:
with cct_mode:
actual = op(*args, **kwargs)
# see NOTE: [What errors are Composite Compliance trying to catch?]
except RuntimeError as err:
raise_composite_compliance_error(err)
def unwrap(e):
return e.elem if isinstance(e, CCT) else e
assert_equal_fn(tree_map(unwrap, actual), expected)
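# Hedged usage sketch (not part of the original file), mirroring the
# check_all_permutations example above but with the dispatch mode active so
# that factory calls made inside the op are wrapped as well:
#
#   check_with_mode(torch.add, (torch.randn(3), torch.randn(3)), {},
#                   assert_equal_fn=torch.testing.assert_close)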
def gather_leaf_tensors(args, kwargs):
leaf_tensors = []
args, args_spec = tree_flatten(args)
kwargs, kwargs_spec = tree_flatten(kwargs)
args = args + kwargs
for arg in args:
if not isinstance(arg, torch.Tensor):
continue
if arg.requires_grad:
leaf_tensors.append(arg)
return leaf_tensors
def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradcheck_wrapper=None):
if gradcheck_wrapper is None:
results = op(*args, **kwargs)
else:
results = gradcheck_wrapper(op, *args, **kwargs)
if output_process_fn_grad is not None:
results = output_process_fn_grad(results)
flat_results = pytree.tree_leaves(results)
flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
flat_diff_results = [r for r in flat_results if r.requires_grad]
assert len(flat_diff_results) > 0
grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype) for r in flat_diff_results]
leaf_tensors = gather_leaf_tensors(args, kwargs)
assert len(leaf_tensors) > 0
return torch.autograd.grad(flat_diff_results, leaf_tensors,
grads, allow_unused=True, retain_graph=True)
# Checks if the backward formula is composite compliant by testing
# all possible permutations of {inputs, grad_outputs} being
# CompositeCompliantTensor or regular Tensors.
#
# NB: it is important that op is accepted as a Callable and not an OpInfo,
# this means we can apply check_backward_formula to things that aren't OpInfos
# while debugging.
def check_backward_formula(op: Callable, args, kwargs,
output_process_fn_grad=None,
gradcheck_wrapper=None, assert_equal_fn=None):
CCT, cct_mode = generate_cct_and_mode()
expected = compute_expected_grads(op, args, kwargs, output_process_fn_grad, gradcheck_wrapper)
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
leaf_tensors = gather_leaf_tensors(new_args, new_kwargs)
assert len(leaf_tensors) > 0
try:
if gradcheck_wrapper is None:
results = op(*new_args, **new_kwargs)
else:
results = gradcheck_wrapper(op, *new_args, **new_kwargs)
if output_process_fn_grad is not None:
results = output_process_fn_grad(results)
# see NOTE: [What errors are Composite Compliance trying to catch?]
except RuntimeError as err:
raise_composite_compliance_error(
err,
f"- wrapped_args: {which_args_are_wrapped}\n"
f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
)
flat_results = pytree.tree_leaves(results)
flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
flat_diff_results = [r for r in flat_results if r.requires_grad]
assert len(flat_diff_results) > 0
# NB: ones, not ones_like, so we get a regular Tensor here
grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype)
for r in flat_diff_results]
for flat_new_grads, which_grad_is_batched in generate_subclass_choices(grads, CCT, cct_mode):
try:
actual = torch.autograd.grad(flat_diff_results, leaf_tensors, flat_new_grads,
allow_unused=True, retain_graph=True)
# see NOTE: [What errors are Composite Compliance trying to catch?]
except RuntimeError as err:
raise_composite_compliance_error(
err,
f"- wrapped_args: {which_args_are_wrapped}\n"
f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
f"- wrapped_grads: {which_grad_is_batched}\n"
)
def unwrap(e):
return e.elem if isinstance(e, CCT) else e
assert_equal_fn(tuple(map(unwrap, actual)), expected, equal_nan=True)
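# Hedged usage sketch (not part of the original file), exercising the backward
# check on a simple differentiable op; torch.mul and assert_close are stand-ins:
#
#   a = torch.randn(3, requires_grad=True)
#   b = torch.randn(3, requires_grad=True)
#   check_backward_formula(torch.mul, (a, b), {},
#                          assert_equal_fn=torch.testing.assert_close)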
# Checks if the forward AD formula is composite compliant by testing
# all possible permutations of {primals, tangents} being
# CompositeCompliantTensor or regular Tensors.
#
# NB: it is important that op is accepted as a Callable and not an OpInfo,
# this means we can apply check_forward_ad_formula to things that aren't OpInfos
# while debugging.
def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None, assert_equal_fn=None):
CCT, cct_mode = generate_cct_and_mode(autograd_view_consistency=False)
def maybe_tangent(t):
assert type(t) is not CCT
        # Generate a `tangent` tensor if the given object is a Tensor
        # with requires_grad set.
if isinstance(t, torch.Tensor) and t.requires_grad:
return torch.randn_like(t)
elif is_tensorlist(t):
return [torch.randn_like(e) if e.requires_grad else None for e in t]
return None
tangent_args = tuple(maybe_tangent(arg) for arg in args)
flat_kwargs, spec = tree_flatten(kwargs)
flat_tangent_kwargs = tuple(maybe_tangent(arg) for arg in flat_kwargs)
tangent_kwargs = tree_unflatten(flat_tangent_kwargs, spec)
with fwAD.dual_level():
def maybe_make_dual(dual):
# Returns dual tensor if primal is a tensor/tensor subclass
# with requires_grad set.
primal, tangent = dual
if isinstance(primal, torch.Tensor) and primal.requires_grad:
return fwAD.make_dual(primal.detach(), tangent)
elif is_tensorlist(primal):
return tuple(fwAD.make_dual(pri.detach(), tang) if tang is not None else pri
for pri, tang in zip(primal, tangent))
return primal
def compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs):
op_args = tuple(map(maybe_make_dual, zip(args, tangent_args)))
op_kwargs = {k: maybe_make_dual((v, tangent_kwargs[k])) for k, v in kwargs.items()}
if gradcheck_wrapper is None:
return op(*op_args, **op_kwargs)
return gradcheck_wrapper(op, *op_args, **op_kwargs)
expected = compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs)
expected = tree_map(fwAD.unpack_dual, expected)
expected_primals = tree_map(lambda x: x.primal, expected)
expected_tangents = tree_map(lambda x: x.tangent, expected)
        # Permutations of args and kwargs wrapped in CCT.
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
            # Permutations of tangent args and tangent kwargs wrapped in CCT.
for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT, cct_mode):
new_tang_args, new_tang_kwargs, \
which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice
op_args = tuple(map(maybe_make_dual, zip(new_args, new_tang_args)))
op_kwargs = {k: maybe_make_dual((v, new_tang_kwargs[k])) for k, v in new_kwargs.items()}
try:
if gradcheck_wrapper is None:
actual = op(*op_args, **op_kwargs)
else:
actual = gradcheck_wrapper(op, *op_args, **op_kwargs)
# see NOTE: [What errors are Composite Compliance trying to catch?]
except RuntimeError as err:
raise_composite_compliance_error(
err,
f"- wrapped_args: {which_args_are_wrapped}\n"
f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
f"- wrapped_tangent_args: {which_tang_args_are_wrapped}\n"
f"- wrapped_tangent_kwargs: {which_tang_kwargs_are_wrapped}\n"
)
def unwrap(e):
return e.elem if isinstance(e, CCT) else e
actual = tree_map(fwAD.unpack_dual, actual)
actual_primals = tree_map(lambda x: unwrap(x.primal), actual)
actual_tangents = tree_map(lambda x: unwrap(x.tangent), actual)
assert_equal_fn(actual_primals, expected_primals, equal_nan=True)
assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True)

View File

@ -0,0 +1,586 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import torch
import functools
from torch.testing import make_tensor
from torch.testing._internal.opinfo.core import (
OpInfo,
SampleInput,
)
from torch.testing._internal.common_dtype import all_types_and
import numpy as np
from torch.testing._internal.autograd_function_db import (
sample_inputs_numpy_cube,
sample_inputs_numpy_mul,
sample_inputs_numpy_mul_scalar,
sample_inputs_numpy_sort,
sample_inputs_numpy_take,
)
from torch import Tensor
from torch.types import Number
from typing import * # noqa: F403
# Note: [custom op db]
#
# This is a collection of custom operator test cases written as OpInfos
# so they can easily be consumed by OpInfo-based tests to check if subsystems
# support them correctly.
def to_numpy(tensor):
return tensor.cpu().numpy()
@torch.library.custom_op("_torch_testing::numpy_cube", mutates_args=())
def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
x_np = to_numpy(x)
dx = torch.tensor(3 * x_np ** 2, device=x.device)
return torch.tensor(x_np ** 3, device=x.device), dx
@numpy_cube.register_fake
def _(x):
return x.clone(), x.clone()
def numpy_cube_setup_context(ctx, inputs, output):
x, = inputs
cube, dx = output
ctx.save_for_backward(x, dx)
def numpy_cube_backward(ctx, grad_out, grad_dx):
x, dx = ctx.saved_tensors
grad_x = numpy_mul(grad_out, dx) + 6 * numpy_mul(grad_dx, x)
return grad_x
numpy_cube.register_autograd(numpy_cube_backward, setup_context=numpy_cube_setup_context)
def numpy_cube_vmap(info, in_dims, x):
result = numpy_cube(x)
return result, (in_dims[0], in_dims[0])
numpy_cube.register_vmap(numpy_cube_vmap)
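# Hedged usage sketch (not part of the original file): the registered custom op
# is callable both through the returned handle and through the torch.ops
# namespace, and both return the (cube, dcube) pair defined above.
#
#   x = torch.randn(3, requires_grad=True)
#   cube, dcube = numpy_cube(x)
#   cube2, dcube2 = torch.ops._torch_testing.numpy_cube(x)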
@torch.library.custom_op("_torch_testing::numpy_mul", mutates_args=())
def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
@numpy_mul.register_fake
def _(x, y):
assert x.device == y.device
return (x * y).contiguous()
def numpy_mul_setup_context(ctx, inputs, output):
ctx.save_for_backward(*inputs)
def numpy_mul_backward(ctx, grad_out):
x, y = ctx.saved_tensors
grad_x = grad_out * y if ctx.needs_input_grad[0] else None
grad_y = grad_out * x if ctx.needs_input_grad[1] else None
return grad_x, grad_y
numpy_mul.register_autograd(numpy_mul_backward, setup_context=numpy_mul_setup_context)
def numpy_mul_vmap(info, in_dims, x, y):
x_bdim, y_bdim = in_dims
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
result = x * y
result = result.movedim(-1, 0)
return result, 0
numpy_mul.register_vmap(numpy_mul_vmap)
@torch.library.custom_op("_torch_testing::numpy_mul_scalar", mutates_args=())
def numpy_mul_scalar(x: Tensor, *, scalar: float) -> Tensor:
return torch.tensor(to_numpy(x) * scalar, device=x.device)
@numpy_mul_scalar.register_fake
def _(x, *, scalar):
return (x * scalar).contiguous()
def numpy_mul_scalar_setup_context(ctx, inputs, keyword_only_inputs, output):
ctx.scalar = keyword_only_inputs["scalar"]
def numpy_mul_scalar_backward(ctx, grad_out):
grad_x = grad_out * ctx.scalar
return grad_x
numpy_mul_scalar.register_autograd(numpy_mul_scalar_backward, setup_context=numpy_mul_scalar_setup_context)
def numpy_mul_scalar_vmap(info, in_dims, x, *, scalar):
x_bdim, = in_dims
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
result = x * scalar
result = result.movedim(-1, 0)
return result, 0
numpy_mul_scalar.register_vmap(numpy_mul_scalar_vmap)
@torch.library.custom_op("_torch_testing::numpy_sort", mutates_args=())
def numpy_sort(x: Tensor, dim: int) -> Tuple[Tensor, Tensor, Tensor]:
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
return (
torch.tensor(result, device=device),
torch.tensor(ind, device=device),
torch.tensor(ind_inv, device=device),
)
@numpy_sort.register_fake
def _(x, dim):
return torch.empty_like(x), torch.empty_like(x, dtype=torch.long), torch.empty_like(x, dtype=torch.long)
def numpy_sort_setup_context(ctx, inputs, output):
out, ind, ind_inv = output
ctx.dim = inputs[1]
ctx.save_for_backward(ind, ind_inv)
ctx.mark_non_differentiable(ind, ind_inv)
def numpy_sort_backward(ctx, grad_out, grad_ind, grad_ind_inv):
ind, ind_inv = ctx.saved_tensors
return numpy_take(grad_out, ind_inv, ind, ctx.dim), None
numpy_sort.register_autograd(numpy_sort_backward, setup_context=numpy_sort_setup_context)
def numpy_sort_vmap(info, in_dims, x, dim):
x_bdim, _ = in_dims
x = x.movedim(x_bdim, 0)
dim = dim if dim >= 0 else dim + x.dim() - 1
result = numpy_sort(x, dim + 1)
return result, (0, 0, 0)
numpy_sort.register_vmap(numpy_sort_vmap)
@torch.library.custom_op("_torch_testing::numpy_take", mutates_args=())
def numpy_take(x: Tensor, ind: Tensor, ind_inv: Tensor, dim: int) -> Tensor:
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@numpy_take.register_fake
def _(x, ind, ind_inv, dim):
assert x.device == ind.device
assert x.device == ind_inv.device
assert ind.dtype == torch.long
assert ind_inv.dtype == torch.long
return torch.empty_like(x)
def numpy_take_setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.dim = dim
ctx.save_for_backward(ind, ind_inv)
def numpy_take_backward(ctx, grad_out):
ind, ind_inv = ctx.saved_tensors
grad_x = numpy_take(grad_out, ind_inv, ind, ctx.dim)
return grad_x, None, None, None
numpy_take.register_autograd(numpy_take_backward, setup_context=numpy_take_setup_context)
def numpy_take_vmap(info, in_dims, x, ind, ind_inv, dim):
x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
# wrap dim
logical_dim = x.dim() if x_bdim is None else x_bdim - 1
dim = dim if dim >= 0 else dim + logical_dim
def expand_bdim(x, x_bdim):
if x_bdim is None:
return x.expand(info.batch_size, *x.shape)
return x.movedim(x_bdim, 0)
x = expand_bdim(x, x_bdim)
ind = expand_bdim(ind, ind_bdim)
ind_inv = expand_bdim(ind_inv, ind_inv_bdim)
return numpy_take(x, ind, ind_inv, dim + 1), 0
numpy_take.register_vmap(numpy_take_vmap)
@torch.library.custom_op("_torch_testing::numpy_nonzero", mutates_args=())
def numpy_nonzero(x: Tensor) -> Tensor:
x_np = to_numpy(x)
res = np.stack(np.nonzero(x_np), axis=1)
if res.shape[0] <= 1:
raise RuntimeError("not supported")
return torch.tensor(res, device=x.device)
@numpy_nonzero.register_fake
def _(x):
ctx = torch._custom_op.impl.get_ctx()
i0 = ctx.create_unbacked_symint()
shape = [i0, x.dim()]
result = x.new_empty(shape, dtype=torch.long)
return result
def sample_inputs_numpy_nonzero(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
shape = 10
result = make_arg(shape, low=0.9, high=2)
mask = make_tensor(shape, low=0, high=2, device=device, dtype=torch.long)
with torch.no_grad():
result *= mask
yield SampleInput(result, args=())
def numpy_nonzero_vmap(info, in_dims, x):
raise NotImplementedError("Operator is data-dependent and cannot be vmapped.")
numpy_nonzero.register_vmap(numpy_nonzero_vmap)
@torch.library.custom_op("_torch_testing::numpy_view_copy", mutates_args=())
def numpy_view_copy(x: Tensor, shape: Sequence[int]) -> Tensor:
return torch.tensor(np.copy(to_numpy(x).reshape(shape)), device=x.device)
@numpy_view_copy.register_fake
def _(x, shape) -> Tensor:
return x.clone().view(shape).clone()
def numpy_view_copy_setup_context(ctx, inputs, output) -> None:
ctx.x_shape = inputs[0].shape
def numpy_view_copy_backward(ctx, grad_out):
return torch.ops._torch_testing.numpy_view_copy(grad_out, ctx.x_shape), None
numpy_view_copy.register_autograd(numpy_view_copy_backward, setup_context=numpy_view_copy_setup_context)
def numpy_view_copy_vmap(info, in_dims, x, shape):
x_bdim, _ = in_dims
x = x.movedim(x_bdim, 0)
x_shape = x.shape[0]
batch_shape = (x_shape, *shape)
result = numpy_view_copy(x, batch_shape)
return result, 0
numpy_view_copy.register_vmap(numpy_view_copy_vmap)
def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
result = make_arg(2, 3, 4, low=0.9, high=2)
yield SampleInput(result, args=([2, 12],))
@torch.library.custom_op('_torch_testing::numpy_cat', mutates_args=())
def numpy_cat(xs: Sequence[Tensor], dim: int) -> Tensor:
assert len(xs) > 0
assert all(x.device == xs[0].device for x in xs)
assert all(x.dtype == xs[0].dtype for x in xs)
np_xs = [to_numpy(x) for x in xs]
np_out = np.concatenate(np_xs, axis=dim)
return torch.tensor(np_out, device=xs[0].device)
@numpy_cat.register_fake
def _(xs, dim):
assert len(xs) > 0
assert all(x.device == xs[0].device for x in xs)
assert all(x.dtype == xs[0].dtype for x in xs)
return torch.cat(xs, dim=dim)
def numpy_cat_setup_context(ctx, inputs, output):
xs, dim = inputs
ctx.dim_sizes = [x.shape[dim] for x in xs]
ctx.dim = dim
def numpy_cat_backward(ctx, grad_out):
dim_sizes = ctx.dim_sizes
dim = ctx.dim
splits = list(np.cumsum(dim_sizes)[:-1])
grad_xs = torch.ops._torch_testing.numpy_split_copy(grad_out, splits, dim)
return grad_xs, None
numpy_cat.register_autograd(numpy_cat_backward, setup_context=numpy_cat_setup_context)
def numpy_cat_vmap(info, in_dims, x, dim):
x_bdim, = in_dims
result = numpy_cat(x, dim)
return result, x_bdim
numpy_cat.register_vmap(numpy_cat_vmap)
def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
r0 = make_arg(2, 3, 4, low=0.9, high=2)
r1 = make_arg(4, 3, 4, low=0.9, high=2)
r2 = make_arg(5, 3, 4, low=0.9, high=2)
yield SampleInput([r0, r1, r2], args=(0,))
@torch.library.custom_op('_torch_testing::numpy_split_copy', mutates_args=())
def numpy_split_copy(x: Tensor, splits: Sequence[int], dim: int) -> List[Tensor]:
x_np = to_numpy(x)
arrs = np.split(x_np, splits, axis=dim)
return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs]
@numpy_split_copy.register_fake
def _(x, splits, dim):
return [xi.clone() for xi in torch.tensor_split(x, splits, dim)]
def numpy_split_copy_setup_context(ctx, inputs, output):
_, _, dim = inputs
ctx.dim = dim
def numpy_split_copy_backward(ctx, grad_out):
result = torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim)
return result, None, None
numpy_split_copy.register_autograd(numpy_split_copy_backward, setup_context=numpy_split_copy_setup_context)
def numpy_split_copy_vmap(info, in_dims, x, splits, dim):
x_bdim, _ , _ = in_dims
x = x.movedim(x_bdim, 0)
result = numpy_split_copy(x, splits, dim + 1)
return result, 0
numpy_split_copy.register_vmap(numpy_split_copy_vmap)
def sample_inputs_numpy_split_copy(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
x = make_arg(2, 9, low=0.9, high=2)
yield SampleInput(x, args=([1, 3, 6], 1))
@torch.library.custom_op('_torch_testing::numpy_split_copy_with_int', mutates_args=())
def numpy_split_copy_with_int(x: Tensor, splits: Sequence[int], dim: int) -> Tuple[List[Tensor], int]:
x_np = to_numpy(x)
arrs = np.split(x_np, splits, axis=dim)
return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs], len(splits)
@numpy_split_copy_with_int.register_fake
def _(x, splits, dim):
return [xi.clone() for xi in torch.tensor_split(x, splits, dim)], len(splits)
def numpy_split_copy_with_int_setup_context(ctx, inputs, output):
_, _, dim = inputs
ctx.dim = dim
def numpy_split_copy_with_int_backward(ctx, grad_out, _):
return torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim), None, None
numpy_split_copy_with_int.register_autograd(
numpy_split_copy_with_int_backward,
setup_context=numpy_split_copy_with_int_setup_context)
def numpy_split_copy_with_int_vmap(info, in_dims, x, splits, dim):
x_bdim, _ , _ = in_dims
x = x.movedim(x_bdim, 0)
result, len_split = numpy_split_copy_with_int(x, splits, dim + 1)
return (result, len_split), ([0 for _ in range(len(result))], None)
numpy_split_copy_with_int.register_vmap(numpy_split_copy_with_int_vmap)
@torch.library.custom_op("_torch_testing::numpy_nms", mutates_args=())
def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor:
# Adapted from Ross Girshick's fast-rcnn implementation at
# https://github.com/rbgirshick/fast-rcnn/blob/master/lib/utils/nms.py
assert boxes.device == scores.device
device = boxes.device
boxes = to_numpy(boxes)
scores = to_numpy(scores)
N = boxes.shape[0]
assert boxes.shape == (N, 4)
assert scores.shape == (N,)
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= iou_threshold)[0]
order = order[inds + 1]
result = torch.tensor(np.stack(keep), device=device)
# Needed for data-dependent condition :(
assert result.size(0) >= 2
return result
@numpy_nms.register_fake
def _(boxes, scores, iou_threshold):
assert boxes.device == scores.device
N = boxes.shape[0]
assert boxes.shape == (N, 4)
assert scores.shape == (N,)
ctx = torch._custom_op.impl.get_ctx()
i0 = ctx.create_unbacked_symint()
result = boxes.new_empty([i0], dtype=torch.int64)
return result
def numpy_nms_vmap(info, in_dims, boxes, scores, iou_threshold):
raise NotImplementedError("Operator is data-dependent and cannot be vmapped.")
numpy_nms.register_vmap(numpy_nms_vmap)
def sample_inputs_numpy_nms(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype)
N = 64
xs = make_arg([N], low=0, high=28)
dx = make_arg([N], low=0, high=4)
ys = make_arg([N], low=0, high=28)
dy = make_arg([N], low=0, high=4)
boxes = torch.stack([xs, ys, xs + dx, ys + dy], dim=1).requires_grad_(requires_grad)
scores = make_arg([N], low=0, high=1, requires_grad=requires_grad)
iou_threshold = make_arg([], low=0, high=1).item()
yield SampleInput(boxes, args=(scores, iou_threshold))
custom_op_db = [
OpInfo(
'NumpyCubeCustomOp',
op=numpy_cube._opoverload,
sample_inputs_func=sample_inputs_numpy_cube,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyMulCustomOp',
op=numpy_mul._opoverload,
sample_inputs_func=sample_inputs_numpy_mul,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyMulScalarCustomOp',
op=numpy_mul_scalar._opoverload,
sample_inputs_func=sample_inputs_numpy_mul_scalar,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpySortCustomOp',
op=numpy_sort._opoverload,
sample_inputs_func=sample_inputs_numpy_sort,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyTakeCustomOp',
op=numpy_take._opoverload,
sample_inputs_func=sample_inputs_numpy_take,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyNonzeroCustomOp',
op=numpy_nonzero._opoverload,
sample_inputs_func=sample_inputs_numpy_nonzero,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=False,
supports_out=False,
),
OpInfo(
'NumpyNMSCustomOp',
op=torch.ops._torch_testing.numpy_nms,
sample_inputs_func=sample_inputs_numpy_nms,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=False,
supports_out=False,
),
OpInfo(
'NumpyViewCopyCustomOp',
op=torch.ops._torch_testing.numpy_view_copy,
sample_inputs_func=sample_inputs_numpy_view_copy,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=True,
supports_out=False,
),
OpInfo(
'NumpyCatCustomOp',
op=torch.ops._torch_testing.numpy_cat,
sample_inputs_func=sample_inputs_numpy_cat,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=True,
check_batched_grad=False,
check_batched_gradgrad=False,
supports_out=False,
),
OpInfo(
'NumpySplitCopyCustomOp',
op=torch.ops._torch_testing.numpy_split_copy,
sample_inputs_func=sample_inputs_numpy_split_copy,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=True,
check_batched_grad=False,
check_batched_gradgrad=False,
supports_out=False,
),
OpInfo(
'NumpySplitCopyWithIntCustomOp',
op=torch.ops._torch_testing.numpy_split_copy_with_int,
sample_inputs_func=sample_inputs_numpy_split_copy,
dtypes=all_types_and(torch.bool, torch.half),
gradcheck_wrapper=lambda op, *args, **kwargs: op(*args, **kwargs)[0],
supports_autograd=True,
check_batched_grad=False,
check_batched_gradgrad=False,
supports_out=False,
),
]
# ==============================================================
# some mechanical test cases
# ==============================================================
lib = torch.library.Library("_torch_testing", "FRAGMENT") # noqa: TOR901
lib.define("source0(Tensor x) -> Tensor")
@torch.library.register_fake("_torch_testing::source0", lib=lib)
def _(x):
return x.clone()
lib.define("source1(Tensor x) -> Tensor")
def source1_fake(x):
return x.clone()
torch.library.register_fake("_torch_testing::source1", source1_fake, lib=lib)
lib.define("source2(Tensor x) -> Tensor")
@torch.library.register_fake("_torch_testing::source2", lib=lib)
def _(x):
return x.clone()
lib.define("source3(Tensor x) -> Tensor")
def source3_fake(x):
return x.clone()
torch.library.register_fake("_torch_testing::source3", source3_fake, lib=lib)
@torch.library.custom_op("_torch_testing::source4", mutates_args=())
def source4(x: Tensor) -> Tensor:
return x.clone()
@source4.register_fake
def _(x):
return x.clone()
@torch.library.custom_op("_torch_testing::source5", mutates_args=())
def source5(x: Tensor) -> Tensor:
return x.clone()
def source5_fake(x):
return x.clone()
source5.register_fake(source5_fake)

View File

@ -0,0 +1,67 @@
# mypy: ignore-errors
import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import return_and_correct_aliasing
# A simple tensor subclass that holds a tensor with custom metadata and custom method
class ConstantExtraMetadataTensor(torch.Tensor):
@staticmethod
def __new__(cls, elem):
shape = elem.shape
kwargs = {}
kwargs["strides"] = elem.stride()
kwargs["storage_offset"] = elem.storage_offset()
kwargs["device"] = elem.device
kwargs["layout"] = elem.layout
kwargs["requires_grad"] = elem.requires_grad
kwargs["dtype"] = elem.dtype
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
def __init__(self, elem):
self.elem = elem
self.constant_attribute = 4
def __repr__(self):
inner_repr = repr(self.elem)
return f"CustomTensor({inner_repr})"
def __tensor_flatten__(self):
return ["elem"], self.constant_attribute
def add_constant(self, a):
self.constant_attribute += a
@staticmethod
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
assert meta is not None
elem = inner_tensors["elem"]
out = ConstantExtraMetadataTensor(elem)
out.constant_attribute = meta
return out
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if kwargs is None:
kwargs = {}
args_inner = pytree.tree_map_only(
ConstantExtraMetadataTensor, lambda x: x.elem, args
)
kwargs_inner = pytree.tree_map_only(
ConstantExtraMetadataTensor, lambda x: x.elem, kwargs
)
out_inner = func(*args_inner, **kwargs_inner)
out_inner_flat, spec = pytree.tree_flatten(out_inner)
# for aten ops that return non-tensors, just assume that
        # our custom inner tensors return the same value
out_flat = [
ConstantExtraMetadataTensor(o_inner)
if isinstance(o_inner, torch.Tensor)
else o_inner
for o_inner in out_inner_flat
]
out = pytree.tree_unflatten(out_flat, spec)
return return_and_correct_aliasing(func, args, kwargs, out)
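# Hedged usage sketch (not part of the original file): wrapping a tensor and
# hitting the dispatch path above.
#
#   t = ConstantExtraMetadataTensor(torch.randn(3))
#   t.add_constant(2)   # constant_attribute: 4 -> 6
#   u = t + 1           # unwraps .elem, runs the aten op, rewraps the output
#   isinstance(u, ConstantExtraMetadataTensor)   # True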

View File

@ -0,0 +1 @@
# mypy: ignore-errors

View File

@ -0,0 +1,10 @@
# mypy: ignore-errors
import torch.nn as nn
class Net(nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(10, 20)

View File

@ -0,0 +1,11 @@
# mypy: ignore-errors
import torch.nn as nn
class Net(nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(10, 20)
self.relu = nn.ReLU()

View File

@ -0,0 +1,200 @@
# mypy: ignore-errors
import re
import sys
import time
from functools import partial, wraps
from typing import Tuple
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.distributed.rpc import _rref_context_get_debug_info
from torch.testing._internal.common_utils import FILE_SCHEMA, TEST_WITH_TSAN
if not dist.is_available():
print("c10d not available, skipping tests", file=sys.stderr)
sys.exit(0)
INIT_METHOD_TEMPLATE = FILE_SCHEMA + "{file_name}"
def dist_init(
old_test_method=None,
setup_rpc: bool = True,
clean_shutdown: bool = True,
faulty_messages=None,
messages_to_delay=None,
):
"""
We use this decorator for setting up and tearing down state since
MultiProcessTestCase runs each `test*` method in a separate process and
each process just runs the `test*` method without actually calling
'setUp' and 'tearDown' methods of unittest.
Note: pass the string representation of MessageTypes that should be used
with the faulty agent's send function. By default, all retriable messages
("RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT", "RREF_USER_DELETE",
"CLEANUP_AUTOGRAD_CONTEXT_REQ") will use the faulty send (this default is
set from faulty_rpc_agent_test_fixture.py).
"""
# If we use dist_init without arguments (ex: @dist_init), old_test_method is
# appropriately set and we return the wrapper appropriately. On the other
# hand if dist_init has arguments (ex: @dist_init(clean_shutdown=False)),
# old_test_method is None and we return a functools.partial which is the real
# decorator that is used and as a result we recursively call dist_init with
# old_test_method and the rest of the arguments appropriately set.
if old_test_method is None:
return partial(
dist_init,
setup_rpc=setup_rpc,
clean_shutdown=clean_shutdown,
faulty_messages=faulty_messages,
messages_to_delay=messages_to_delay,
)
@wraps(old_test_method)
def new_test_method(self, *arg, **kwargs):
# Setting _ignore_rref_leak to make sure OwnerRRefs are properly deleted
# in tests.
import torch.distributed.rpc.api as api
api._ignore_rref_leak = False
self.worker_id = self.rank
self.setup_fault_injection(faulty_messages, messages_to_delay)
rpc_backend_options = self.rpc_backend_options
if setup_rpc:
if TEST_WITH_TSAN:
# TSAN runs much slower.
rpc_backend_options.rpc_timeout = rpc.constants.DEFAULT_RPC_TIMEOUT_SEC * 5
rpc.constants.DEFAULT_SHUTDOWN_TIMEOUT = 60
rpc.init_rpc(
name="worker%d" % self.rank,
backend=self.rpc_backend,
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=rpc_backend_options,
)
return_value = old_test_method(self, *arg, **kwargs)
if setup_rpc:
rpc.shutdown(graceful=clean_shutdown)
return return_value
return new_test_method
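# Hedged usage sketch (not part of the original file); the fixture class and
# test method names below are made up for illustration:
#
#   class MyRpcTest(RpcAgentTestFixture):
#       @dist_init
#       def test_add(self):
#           dst = worker_name((self.rank + 1) % self.world_size)
#           ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), 1))
#
# @dist_init(setup_rpc=False) is also valid when the test manages RPC itself.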
def noop() -> None:
pass
def wait_until_node_failure(rank: int, expected_error_regex: str = ".*") -> str:
"""
Loops until an RPC to the given rank fails. This is used to
indicate that the node has failed in unit tests.
Args:
rank (int): Rank of the node expected to fail
expected_error_regex (optional, str): Regex of exception message expected. Useful to ensure a specific failure
occurs, not just any.
"""
while True:
try:
rpc.rpc_sync(f"worker{rank}", noop, args=())
time.sleep(0.1)
except Exception as e:
if re.search(pattern=expected_error_regex, string=str(e)):
return str(e)
def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None:
"""
The RRef protocol holds forkIds of rrefs in a map until those forks are
confirmed by the owner. The message confirming the fork may arrive after
our tests check whether this map is empty, which leads to failures and
flaky tests. to_here also does not guarantee that we have finished
    processing the owner's confirmation message for the RRef. This function
    loops until the map is empty, which means the messages have been received
    and processed. Call this function before asserting the map returned by
_get_debug_info is empty.
"""
start = time.time()
while True:
debug_info = _rref_context_get_debug_info()
num_pending_futures = int(debug_info["num_pending_futures"])
num_pending_users = int(debug_info["num_pending_users"])
if num_pending_futures == 0 and num_pending_users == 0:
break
time.sleep(0.1)
if time.time() - start > timeout:
raise ValueError(
f"Timed out waiting to flush pending futures and users, "
f"had {num_pending_futures} pending futures and {num_pending_users} pending users"
)
def get_num_owners_and_forks() -> Tuple[str, str]:
"""
Retrieves number of OwnerRRefs and forks on this node from
_rref_context_get_debug_info.
"""
rref_dbg_info = _rref_context_get_debug_info()
num_owners = rref_dbg_info["num_owner_rrefs"]
num_forks = rref_dbg_info["num_forks"]
return num_owners, num_forks
def wait_until_owners_and_forks_on_rank(
num_owners: int, num_forks: int, rank: int, timeout: int = 20
) -> None:
"""
Waits until timeout for num_forks and num_owners to exist on the rank. Used
to ensure proper deletion of RRefs in tests.
"""
start = time.time()
while True:
num_owners_on_rank, num_forks_on_rank = rpc.rpc_sync(
worker_name(rank), get_num_owners_and_forks, args=(), timeout=5
)
num_owners_on_rank = int(num_owners_on_rank)
num_forks_on_rank = int(num_forks_on_rank)
if num_owners_on_rank == num_owners and num_forks_on_rank == num_forks:
return
time.sleep(1)
if time.time() - start > timeout:
raise ValueError(
f"Timed out waiting {timeout} sec for {num_owners} owners and {num_forks} forks on rank,"
f" had {num_owners_on_rank} owners and {num_forks_on_rank} forks"
)
def initialize_pg(init_method, rank: int, world_size: int) -> None:
# This is for tests using `dist.barrier`.
if not dist.is_initialized():
dist.init_process_group(
backend="gloo",
init_method=init_method,
rank=rank,
world_size=world_size,
)
def worker_name(rank: int) -> str:
return f"worker{rank}"
def get_function_event(function_events, partial_event_name):
"""
Returns the first event that matches partial_event_name in the provided
function_events. These function_events should be the output of
torch.autograd.profiler.function_events().
Args:
function_events: function_events returned by the profiler.
event_name (str): partial key that the event was profiled with.
"""
event = [event for event in function_events if partial_event_name in event.name][0] # noqa: RUF015
return event

View File

@ -0,0 +1 @@
# mypy: allow-untyped-defs

View File

@ -0,0 +1,98 @@
# mypy: allow-untyped-defs
import sys
from functools import wraps, partial
import torch
import torch.distributed as dist
from torch.distributed import rpc
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
TEST_SKIPS,
tp_transports,
)
TEST_GPU_NUM = 4
class ShardedTensorTestBase(MultiProcessTestCase):
@property
def world_size(self):
return TEST_GPU_NUM
def init_pg(self, backend="nccl"):
if backend not in ["nccl", "gloo", "mpi"]:
raise RuntimeError(f"Backend {backend} not supported!")
dist.init_process_group(
backend=backend,
world_size=self.world_size,
rank=self.rank,
init_method=f"file://{self.file_name}",
)
# set device for nccl pg for collectives
if backend == "nccl":
torch.cuda.set_device(self.rank)
def init_rpc(self):
rpc_backend_options = rpc.TensorPipeRpcBackendOptions(_transports=tp_transports())
rpc_backend_options.init_method = f"file://{self.file_name}"
for rank in range(self.world_size):
rpc_backend_options.set_device_map(
f"worker{rank}", {rank: self.rank, self.rank: rank}
)
rpc.init_rpc(
name="worker%d" % self.rank,
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=rpc_backend_options,
)
def init_comms(self, init_rpc=True, backend="nccl"):
if init_rpc:
self.init_rpc()
self.init_pg(backend=backend)
def destroy_comms(self, destroy_rpc=True):
# Wait for all ranks to reach here before starting shutdown.
dist.barrier()
if destroy_rpc:
rpc.shutdown()
dist.destroy_process_group()
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def assert_sharded_tensor_equal(self, st1, st2):
st1_local_shards = st1.local_shards()
st2_local_shards = st2.local_shards()
self.assertEqual(len(st1_local_shards), len(st2_local_shards))
for i, st1_local_shard in enumerate(st1_local_shards):
self.assertEqual(st1_local_shard.tensor, st2_local_shards[i].tensor)
self.assertEqual(st1_local_shard.metadata, st2_local_shards[i].metadata)
self.assertEqual(st1.metadata(), st2.metadata())
self.assertEqual(st1.sharding_spec(), st2.sharding_spec())
self.assertEqual(len(st1.remote_shards()), len(st2.remote_shards()))
# wrapper to initialize comms (processgroup + rpc)
def with_comms(func=None, init_rpc=True, backend="nccl"):
if func is None:
return partial(
with_comms,
init_rpc=init_rpc,
backend=backend,
)
@wraps(func)
def wrapper(self, *args, **kwargs):
if backend == "nccl" and torch.cuda.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
self.init_comms(init_rpc=init_rpc, backend=backend)
func(self, *args, **kwargs)
self.destroy_comms(destroy_rpc=init_rpc)
return wrapper
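# Hedged usage sketch (not part of the original file); the class and test names
# are made up:
#
#   class MyShardedTensorTest(ShardedTensorTestBase):
#       @with_comms(init_rpc=False, backend="gloo")
#       def test_basic(self):
#           ...  # process group (and optionally RPC) is live inside the test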

View File

@ -0,0 +1,136 @@
# mypy: allow-untyped-defs
import builtins
import torch
from torch.distributed._shard.sharding_spec import (
ChunkShardingSpec,
EnumerableShardingSpec,
ShardMetadata,
)
from torch.distributed._shard.sharding_spec._internals import (
get_chunked_dim_size,
get_split_size,
)
def generate_chunk_sharding_specs_for_test(sharding_dim):
return [
ChunkShardingSpec(
dim=sharding_dim,
placements=[
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
"rank:3/cuda:3",
],
),
# Test different ordering. (Case 1)
ChunkShardingSpec(
dim=sharding_dim,
placements=[
"rank:2/cuda:2",
"rank:3/cuda:3",
"rank:0/cuda:0",
"rank:1/cuda:1",
],
),
# Test different ordering. (Case 2)
ChunkShardingSpec(
dim=sharding_dim,
placements=[
"rank:3/cuda:3",
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
],
),
]
def generate_enumerable_sharding_specs_for_test():
return [
EnumerableShardingSpec(
[
ShardMetadata(
shard_offsets=[0, 0],
shard_sizes=[5, 5],
placement="rank:0/cuda:0",
),
ShardMetadata(
shard_offsets=[5, 0],
shard_sizes=[5, 5],
placement="rank:1/cuda:1",
),
ShardMetadata(
shard_offsets=[0, 5],
shard_sizes=[5, 5],
placement="rank:2/cuda:2",
),
ShardMetadata(
shard_offsets=[5, 5],
shard_sizes=[5, 5],
placement="rank:3/cuda:3",
),
]
)
]
def generate_local_weight_sharding_params_for_test(
local_weight, sharded_dim, gpu_num, spec, rank
):
"""
    Shard the local weight based on the given spec, so we can compare against
the one from sharded tensor.
Args:
local_weight: weight matrix to be sharded.
sharded_dim: The dimension which we shard on.
gpu_num: number of ranks.
spec: sharding spec.
        rank: rank of the CUDA process.
Returns:
start_pos: start position of sharded weight on the given rank.
chunk_size: chunk size of sharded weight on the given rank.
"""
sharding_dim_size = local_weight.size(sharded_dim)
split_size = get_split_size(sharding_dim_size, gpu_num)
current_offsets = 0
start_pos = current_offsets
for idx, placement in enumerate(spec.placements):
chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
if rank == placement.rank():
start_pos = current_offsets
break
current_offsets += chunk_size
return start_pos, chunk_size
def clone_module_parameter(module, param_name):
"""
Clone a parameter from a given existing module.
Args:
module (:class:`torch.nn.Module`): Module whose parameter needs to be cloned.
param_name (str): Name of the parameter of ``module`` that needs to be cloned.
Returns: cloned tensor as :class:`torch.nn.Parameter`.
"""
tensor = getattr(module, param_name)
return torch.nn.Parameter(tensor.detach().clone())
def gen_binary_op_func(python_op, inplace=False):
src_lines = ['def f(lhs, rhs):']
if "torch" in python_op:
src_lines.append(f' return {python_op}(lhs, rhs)\n')
elif inplace:
src_lines.append(f' lhs {python_op}= rhs\n return lhs\n')
else:
src_lines.append(f' return lhs {python_op} rhs\n')
code_str = '\n'.join(src_lines)
g = {'torch': torch}
builtins.exec(code_str, g)
return g["f"]
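# Hedged usage examples (not part of the original file), one per code path above:
#
#   add = gen_binary_op_func("+")                  # def f(lhs, rhs): return lhs + rhs
#   iadd = gen_binary_op_func("+", inplace=True)   # lhs += rhs; return lhs
#   tadd = gen_binary_op_func("torch.add")         # return torch.add(lhs, rhs)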

View File

@ -0,0 +1,66 @@
# mypy: allow-untyped-defs
import copy
import random
import torch
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import (
ChunkShardingSpec,
)
PLACEMENTS = [
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
"rank:3/cuda:3",
]
DEFAULT_GPU_NUM = 4
def _chunk_sharding_specs_list_for_test(sharding_dims, seed=0):
spec_list = []
for i in range(len(sharding_dims)):
random.Random(seed + i).shuffle(PLACEMENTS)
spec_list.append(
ChunkShardingSpec(
dim=sharding_dims[i],
placements=copy.deepcopy(PLACEMENTS),
)
)
return spec_list
class MyShardedModel2(torch.nn.Module):
def __init__(
self,
spec=None,
group=None,
init_rrefs=True
) -> None:
super().__init__()
if spec is not None:
self.sharded_tensor2 = sharded_tensor.rand(
spec, 10, 20, process_group=group, init_rrefs=init_rrefs
)
else:
self.sharded_tensor2 = None
self.random_tensor2 = torch.nn.Parameter(torch.rand(2, 2))
class MyShardedModel1(torch.nn.Module):
def __init__(
self,
spec=None,
group=None,
init_rrefs=True
) -> None:
super().__init__()
if spec is not None:
self.sharded_tensor1 = sharded_tensor.rand(
spec, 10, 20, process_group=group, init_rrefs=init_rrefs
)
else:
self.sharded_tensor1 = None
self.random_tensor1 = torch.nn.Parameter(torch.rand(2, 2))
self.submodule = MyShardedModel2(spec, group, init_rrefs)

View File

@ -0,0 +1,42 @@
# mypy: allow-untyped-defs
import torch
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
class SimpleMegatronLM(nn.Module):
def __init__(self, linear_size, rank=None, dtype=torch.float32):
super().__init__()
self.fc1 = nn.Linear(*linear_size[0], dtype=dtype)
self.gelu = nn.GELU()
self.fc2 = nn.Linear(*linear_size[1], dtype=dtype)
if rank is not None:
self.fc1.cuda(rank)
self.fc2.cuda(rank)
def forward(self, inp):
return self.fc2(self.gelu(self.fc1(inp)))
def get_weights(self):
if isinstance(self.fc1.weight, ShardedTensor):
weight1 = self.fc1.weight.local_tensor()
else:
weight1 = self.fc1.weight
if isinstance(self.fc2.weight, ShardedTensor):
weight2 = self.fc2.weight.local_tensor()
else:
weight2 = self.fc2.weight
return (weight1, weight2)
def get_biases(self):
return (self.fc1.bias, self.fc2.bias)
def get_weight_grads(self):
return (self.fc1.weight.grad, self.fc2.weight.grad)
def get_bias_grads(self):
return (self.fc1.bias.grad, self.fc2.bias.grad)

Some files were not shown because too many files have changed in this diff