I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,474 @@
# mypy: ignore-errors
import collections
import torch
from torch.testing._internal.common_utils import TEST_WITH_ROCM
from torch.testing._internal.common_utils import TestCase
class AutocastTestLists:
def _rnn_cell_args(self, n, num_chunks, is_lstm, dev, dtype):
input = (torch.randn((n, n), device=dev, dtype=torch.float32),)
hx = ((torch.randn((n, n), device=dev, dtype=torch.float32),
torch.randn((n, n), device=dev, dtype=torch.float32)) if is_lstm else
torch.randn((n, n), device=dev, dtype=torch.float32),)
weights = (torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32), # weight_ih
torch.randn((num_chunks * n, n), device=dev, dtype=torch.float32), # weight_hh
torch.randn((num_chunks * n), device=dev, dtype=torch.float32), # bias_ih
torch.randn((num_chunks * n), device=dev, dtype=torch.float32)) # bias_hh
# returns args as a tuple
return input + hx + weights
# Supplies ops and arguments for test_autocast_* in test/test_cuda.py
def __init__(self, dev):
super().__init__()
n = 8
# Utility arguments, created as one-element tuples
pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
pointwise2_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
mat0_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
mat1_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
mat2_fp16 = (torch.randn((n, n), dtype=torch.float16, device=dev),)
dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))
conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
torch.randn(dimset, dtype=torch.float32, device=dev))
for dimset in dimsets]
bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),)
element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
# The lists below organize ops that autocast needs to test.
# self.list_name corresponds to test_autocast_list_name in test/test_cuda.py.
# Each op is associated with a tuple of valid arguments.
# In addition, cudnn conv ops are not supported on ROCm, so they are
# skipped by passing the TEST_WITH_ROCM flag with those ops in the self.torch_fp16 list.
# Some ops implement built-in type promotion. These don't need autocasting,
# but autocasting relies on their promotion, so we include tests to double-check.
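# Illustrative only (hypothetical consumer; `lists`, `dev` and the loop are not part of
# this file): a test typically walks one of these lists and feeds each entry to
# TestAutocast._run_autocast_outofplace defined below, e.g.
#   for op, args, out_dtype in lists.torch_expect_builtin_promote:
#       self._run_autocast_outofplace(op, args, torch.float32, dev, out_type=out_dtype)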
self.torch_expect_builtin_promote = [
("eq", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("ge", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("gt", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("le", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("lt", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("ne", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("add", pointwise0_fp32 + pointwise1_fp16, torch.float32),
("div", pointwise0_fp32 + pointwise1_fp16, torch.float32),
("mul", pointwise0_fp32 + pointwise1_fp16, torch.float32),
("cat", (pointwise0_fp16 + pointwise1_fp32,), torch.float32),
("equal", pointwise0_fp32 + pointwise1_fp16, torch.float32),
("stack", (pointwise0_fp16 + pointwise1_fp32,), torch.float32),
]
self.methods_expect_builtin_promote = [
("__eq__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__ge__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__gt__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__le__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__lt__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__ne__", pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__add__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
("__div__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
("__mul__", pointwise0_fp32 + pointwise1_fp16, torch.float32),
]
# The remaining lists organize ops that autocast treats explicitly.
self.torch_fp16 = [
# deprecated _convolution
("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False,
(0, 0), 1, False, True, True)),
# the current _convolution
("_convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False,
(0, 0), 1, False, True, True, True)),
("conv1d", conv_args_fp32[0]),
("conv2d", conv_args_fp32[1]),
("conv3d", conv_args_fp32[2]),
("conv_tbc", conv_args_fp32[0] + bias_fp32),
("conv_transpose1d", conv_args_fp32[0]),
("conv_transpose2d", conv_args_fp32[1]),
("conv_transpose3d", conv_args_fp32[2]),
("convolution", conv_args_fp32[1] + bias_fp32 + ((1, 1), (0, 0), (1, 1), False, (0, 0), 1)),
("cudnn_convolution", conv_args_fp32[1] + ((0, 0), (1, 1), (1, 1), 1, False, True, True), TEST_WITH_ROCM),
("cudnn_convolution_transpose", conv_args_fp32[1] + ((0, 0), (0, 0), (1, 1),
(1, 1), 1, False, True, True), TEST_WITH_ROCM),
("prelu", pointwise0_fp32 + element0_fp32),
("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
("addmv", pointwise0_fp32 + mat2_fp32 + pointwise1_fp32),
("addr", mat0_fp32 + pointwise0_fp32 + pointwise1_fp32),
("matmul", mat0_fp32 + mat1_fp32),
("einsum", "bkhd,bqhd->bqkh", mat0_fp32 + mat1_fp32),
("mm", mat0_fp32 + mat1_fp32),
("mv", mat0_fp32 + pointwise0_fp32),
("chain_matmul", mat0_fp32 + mat1_fp32 + mat2_fp32),
("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
# _thnn_fused_lstm_cell and _thnn_fused_gru_cell are not Python-exposed as far as I can tell.
# ("_thnn_fused_lstm_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32),
# ("_thnn_fused_gru_cell", mat0_fp32 + mat1_fp32 + mat2_fp32 + pointwise0_fp32 + pointwise1_fp32),
("lstm_cell", self._rnn_cell_args(n, num_chunks=4, is_lstm=True, dev=dev, dtype=torch.float32)),
("gru_cell", self._rnn_cell_args(n, num_chunks=3, is_lstm=False, dev=dev, dtype=torch.float32)),
("rnn_tanh_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)),
("rnn_relu_cell", self._rnn_cell_args(n, num_chunks=1, is_lstm=False, dev=dev, dtype=torch.float32)),
]
self.torch_fp32 = [
("acos", (pointwise0_fp16[0].clamp(-.9, 0.9),)),
("asin", (pointwise0_fp16[0].clamp(-.9, 0.9),)),
("cosh", pointwise0_fp16),
("erfinv", (pointwise0_fp16[0].clamp(-.9, .9),)),
("exp", pointwise0_fp16),
("expm1", pointwise0_fp16),
("log", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
("log10", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
("log2", (pointwise0_fp16[0].clamp(0.1, 100.0),)),
("log1p", (pointwise0_fp16[0].clamp(-0.9, 100.0),)),
("reciprocal", pointwise0_fp16),
("rsqrt", (pointwise0_fp16[0].clamp(0.0, 100.0),)),
("sinh", pointwise0_fp16),
("tan", (pointwise0_fp16[0].clamp(-3.1 / 2, 3.1 / 2),)),
("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + pointwise1_fp16),
("pow", ((pointwise0_fp16[0] + 1.).clamp(0.0, 100.0),) + (1.7,)),
# ("pow", (1.7,) + pointwise0_fp16), # This variant has a backend, but is not documented in the API.
("softmax", pointwise0_fp16 + (0,)),
("log_softmax", pointwise0_fp16 + (0,)),
("layer_norm", pointwise0_fp16 + ((pointwise0_fp16[0].numel(),),)),
("group_norm", mat0_fp16 + (1,)),
("norm", pointwise0_fp16),
("norm", pointwise0_fp16, {"dim": 0}),
# these need magma
# ("norm", mat0_fp16, {"p": "nuc"}),
# ("norm", mat0_fp16, {"p": "nuc", "dim": 0}),
("norm", pointwise0_fp16, {"p": 1}),
("norm", pointwise0_fp16, {"p": 1, "dim": 0}),
("cosine_similarity", mat0_fp16 + mat1_fp16),
("poisson_nll_loss", mat0_fp16 + mat1_fp16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))),
("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.float16),
torch.tensor([[1, 3, 4]], device=dev, dtype=torch.float16),
torch.tensor([1], device=dev, dtype=torch.int))),
("hinge_embedding_loss", mat0_fp16 + (torch.ones(n, device=dev, dtype=torch.int),)),
("kl_div", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)),
("margin_ranking_loss", mat0_fp16 + mat1_fp16 + (torch.ones((n,), device=dev, dtype=torch.float16),)),
("triplet_margin_loss", mat0_fp16 + mat1_fp16 + mat2_fp16),
("binary_cross_entropy_with_logits", mat0_fp16 + (torch.rand((n, n), device=dev, dtype=torch.float16),)),
("cumprod", pointwise0_fp16 + (0,)),
("cumsum", pointwise0_fp16 + (0,)),
("dist", pointwise0_fp16 + pointwise1_fp16),
("pdist", mat0_fp16),
("cdist", mat0_fp16 + mat1_fp16),
("prod", pointwise0_fp16),
("prod", pointwise0_fp16 + (0,)),
("renorm", mat0_fp16 + (2, 0, 1.0)),
("sum", pointwise0_fp16),
("sum", mat0_fp16 + (1,)),
("logsumexp", mat0_fp16 + (1,)),
]
self.torch_need_autocast_promote = [
("addcdiv", pointwise0_fp32 + pointwise1_fp16 + (pointwise2_fp16[0].clamp(0.1, 100),)),
("addcmul", pointwise0_fp32 + pointwise1_fp16 + pointwise2_fp16),
("atan2", pointwise0_fp32 + (pointwise1_fp16[0].clamp(0.1, 100),)),
("bilinear", (torch.randn((1, 2), dtype=torch.float16, device=dev),
torch.randn((1, 2), dtype=torch.float32, device=dev),
torch.randn((1, 2, 2), dtype=torch.float16, device=dev),
torch.randn((1,), dtype=torch.float32, device=dev))),
("cross", (torch.randn(3, dtype=torch.float32, device=dev),
torch.randn(3, dtype=torch.float16, device=dev))),
("dot", pointwise0_fp16 + pointwise1_fp32),
("vdot", pointwise0_fp16 + pointwise1_fp32),
("grid_sampler", (torch.randn((2, 3, 33, 22), dtype=torch.float16, device=dev),
torch.randn((2, 22, 11, 2), dtype=torch.float32, device=dev),
0, 0, False)),
("index_put", pointwise0_fp32 + ((torch.tensor([1], device=dev, dtype=torch.long),),
torch.randn(1, device=dev, dtype=torch.float16))),
("index_put", pointwise0_fp16 + ((torch.tensor([1], device=dev, dtype=torch.long),),
torch.randn(1, device=dev, dtype=torch.float32))),
("tensordot", (torch.randn((2, 2, 2), dtype=torch.float32, device=dev),
torch.randn((2, 2, 2), dtype=torch.float16, device=dev))),
("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float32, device=dev),
0,
torch.randint(0, 2, (2, 2, 2), device=dev),
torch.randn((2, 2, 2), dtype=torch.float16, device=dev))),
("scatter_add", (torch.zeros(2, 2, 2, dtype=torch.float16, device=dev),
0,
torch.randint(0, 2, (2, 2, 2), device=dev),
torch.randn((2, 2, 2), dtype=torch.float32, device=dev))),
]
self.nn_fp16 = [
("linear", mat0_fp32 + mat1_fp32 + mat2_fp32),
]
self.nn_fp32 = [
("softplus", pointwise0_fp16),
("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.float),
torch.zeros((n,), device=dev, dtype=torch.long))),
("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.half),
torch.zeros((n, n, n), device=dev, dtype=torch.long))),
("l1_loss", mat0_fp16 + mat1_fp16),
("smooth_l1_loss", mat0_fp16 + mat1_fp16),
("mse_loss", mat0_fp16 + mat1_fp16),
("multilabel_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
("soft_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
("multi_margin_loss", mat0_fp16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
]
self.linalg_fp16 = [
("linalg_vecdot", mat0_fp32 + mat0_fp32),
("linalg_multi_dot", (mat0_fp32 + mat1_fp32 + mat2_fp32,)),
]
self.methods_fp16 = [
("__matmul__", mat0_fp32 + mat1_fp32)
]
self.methods_fp32 = [
("__pow__", (torch.rand(n, device=dev, dtype=torch.float16), 1.5)),
]
self.banned = [
("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.float32),
torch.rand((n, n), device=dev, dtype=torch.float32)), torch._C._nn),
]
class AutocastCPUTestLists:
# Supplies ops and arguments for test_autocast_* in test/test_cpu.py
def __init__(self, dev):
super().__init__()
n = 8
# Utility arguments, created as one-element tuples
pointwise0_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
pointwise1_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
pointwise2_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
mat0_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
pointwise0_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
pointwise1_fp16 = (torch.randn(n, dtype=torch.float16, device=dev),)
dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n))
dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),)
for dimset in dummy_dimsets]
dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))
conv_args_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),
torch.randn(dimset, dtype=torch.bfloat16, device=dev))
for dimset in dimsets]
conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
torch.randn(dimset, dtype=torch.float32, device=dev))
for dimset in dimsets]
bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),)
element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
dummy_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),)
for dimset in dummy_dimsets]
# The lists below organize ops that autocast needs to test.
# self.list_name corresponds to test_autocast_list_name in test/test_cpu.py.
# Each op is associated with a tuple of valid arguments.
# Some ops implement built-in type promotion. These don't need autocasting,
# but autocasting relies on their promotion, so we include tests to double-check.
self.torch_expect_builtin_promote = [
("eq", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("ge", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("gt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("le", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("lt", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("ne", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("add", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
("div", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
("mul", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
]
self.methods_expect_builtin_promote = [
("__eq__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__ge__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__gt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__le__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__lt__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__ne__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.bool),
("__add__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
("__div__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
("__mul__", pointwise0_fp32 + pointwise1_bf16, pointwise0_fp32 + pointwise1_fp16, torch.float32),
]
# The remaining lists organize ops that autocast treats explicitly.
self.torch_16 = [
("conv1d", conv_args_fp32[0]),
("conv2d", conv_args_fp32[1]),
("conv3d", conv_args_fp32[2]),
("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("mm", mat0_fp32 + mat1_fp32),
("matmul", mat0_fp32 + mat1_fp32),
("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32),
torch.randn((5, 3, 5), device=dev, dtype=torch.float32),
torch.randn(5, device=dev, dtype=torch.float32),
0)),
("conv_transpose1d", conv_args_fp32[0]),
("conv_transpose2d", conv_args_fp32[1]),
("conv_transpose3d", conv_args_fp32[2]),
("prelu", pointwise0_fp32 + element0_fp32),
("_native_multi_head_attention", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32),
n, 4, torch.randn((3 * n, n), device=dev, dtype=torch.float32),
torch.randn((3 * n), device=dev, dtype=torch.float32),
torch.randn((n, n), device=dev, dtype=torch.float32),
torch.randn((n), device=dev, dtype=torch.float32))),
]
self.torch_fp32 = [
("poisson_nll_loss", mat0_bf16 + mat1_bf16 + (True, False, 1.e-8, torch.nn._reduction.get_enum('mean'))),
("cosine_embedding_loss", (torch.tensor([[1, 2, 3]], device=dev, dtype=torch.bfloat16),
torch.tensor([[1, 3, 4]], device=dev, dtype=torch.bfloat16),
torch.tensor([1], device=dev, dtype=torch.int))),
("hinge_embedding_loss", mat0_bf16 + (torch.ones(n, device=dev, dtype=torch.int),)),
("margin_ranking_loss", mat0_bf16 + mat1_bf16 + (torch.ones((n,), device=dev, dtype=torch.bfloat16),)),
("triplet_margin_loss", mat0_bf16 + mat1_bf16 + mat2_bf16),
("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
]
self.nn_16 = [
("linear", mat0_fp32 + mat1_fp32, {}),
]
self.nn_fp32 = [
("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}),
("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),) +
(torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
("reflection_pad1d", dummy_bf16[2], {"padding": (3, 3)}),
("nll_loss", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),
torch.zeros((n,), device=dev, dtype=torch.long))),
("nll_loss2d", (torch.rand((n, n, n, n), device=dev, dtype=torch.bfloat16),
torch.zeros((n, n, n), device=dev, dtype=torch.long))),
("l1_loss", mat0_bf16 + mat1_bf16),
("smooth_l1_loss", mat0_bf16 + mat1_bf16),
("mse_loss", mat0_bf16 + mat1_bf16),
("multilabel_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
("soft_margin_loss", mat0_bf16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
("multi_margin_loss", mat0_bf16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
("huber_loss", mat0_bf16 + mat1_bf16),
]
self.torch_need_autocast_promote = [
("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)),
]
class TestAutocast(TestCase):
def args_maybe_kwargs(self, op_with_args):
if len(op_with_args) == 2:
return op_with_args[0], op_with_args[1], {}
else:
return op_with_args[0], op_with_args[1], op_with_args[2]
def _run_autocast_outofplace(
self,
op,
args,
run_as_type,
device,
out_type=None,
module=torch,
add_kwargs=None,
amp_dtype=torch.bfloat16,
):
# helper to cast args
def cast(val, to_type):
if isinstance(val, torch.Tensor):
return val.to(to_type) if val.is_floating_point() else val
elif isinstance(val, collections.abc.Iterable):
return type(val)(cast(v, to_type) for v in val)
else:
return val
if add_kwargs is None:
add_kwargs = {}
self.assertFalse(torch.is_autocast_enabled(device_type=device))
with torch.amp.autocast(device_type=device, dtype=amp_dtype):
self.assertTrue(torch.is_autocast_enabled(device_type=device))
out_type = out_type if out_type is not None else run_as_type
output = output_method = None
# Try module.* variant, if requested:
if module is not None and hasattr(module, op):
output = getattr(module, op)(*args, **add_kwargs)
if isinstance(output, torch.Tensor):
self.assertTrue(
out_type == output.dtype,
f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}",
)
# Try Tensor.* variant:
if hasattr(torch.Tensor, op):
output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
if isinstance(output_method, torch.Tensor):
self.assertTrue(
out_type == output_method.dtype,
f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}",
)
self.assertTrue(
(output is not None) or (output_method is not None),
f"{op} not found as an attribute on either Tensor or the requested module {module}",
)
# Accounts for ops that return Tensors, iterables, and other non-Tensors.
# For example, lstm_cell returns a tuple and equal returns bool.
def compare(first, second):
if isinstance(first, torch.Tensor):
return torch.equal(first, second)
elif isinstance(first, collections.abc.Iterable):
return all(compare(f, s) for f, s in zip(first, second))
else:
return first == second
# If both torch.* and Tensor.* variants were found, check outputs are identical
if (output is not None) and (output_method is not None):
self.assertTrue(type(output) == type(output_method))
comparison = compare(output, output_method)
self.assertTrue(
comparison, f"torch.{op} result did not match Tensor.{op} result"
)
# Compare numerics to Python-side "autocasting" that (we expect) does the same thing
# as the C++-side autocasting, and should be bitwise accurate.
output_to_compare = output if output is not None else output_method
with torch.amp.autocast(device_type=device, enabled=False):
self.assertFalse(
torch.is_autocast_enabled(device_type=device)
)
if module is not None and hasattr(module, op):
control = getattr(module, op)(
*cast(args, run_as_type), **add_kwargs
)
else:
control = getattr(args[0].to(run_as_type), op)(
*cast(args[1:], run_as_type), **add_kwargs
)
self.assertTrue(type(output_to_compare) == type(control))
comparison = compare(output_to_compare, control)
self.assertTrue(comparison, f"torch.{op} result did not match control")
self.assertTrue(torch.is_autocast_enabled(device_type=device))
self.assertFalse(torch.is_autocast_enabled(device_type=device))


@@ -0,0 +1,635 @@
# mypy: ignore-errors
import torch
from functools import partial
from torch.testing import make_tensor
from torch.testing._internal.opinfo.core import (
OpInfo,
SampleInput,
)
from torch.testing._internal.common_dtype import all_types_and
import numpy as np
# Note: [autograd.Function db]
#
# This is a collection of autograd.Function test cases written as OpInfos
# so they can easily be consumed by OpInfo-based tests to check if a subsystem
# supports autograd.Function.
#
# Axes:
# - saves {output, input, intermediate, non-tensor}
# - {inputs, output} x {single tensor, tensors, arbitrary objects}
# - Uses {mark_dirty, mark_non_differentiable, once_differentiable}
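# Illustrative only (hypothetical OpInfo-based consumer, not defined in this file):
#   @ops(autograd_function_db)
#   def test_autograd_function_db_gradcheck(self, device, dtype, op):
#       for sample in op.sample_inputs(device, dtype, requires_grad=True):
#           torch.autograd.gradcheck(op, (sample.input, *sample.args))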
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpyCube(torch.autograd.Function):
@staticmethod
def forward(input):
input_np = to_numpy(input)
dinput = torch.tensor(3 * input_np ** 2, device=input.device)
return torch.tensor(input_np ** 3, device=input.device), dinput
@staticmethod
def setup_context(ctx, inputs, output):
ctx.save_for_backward(inputs[0], output[1])
ctx.save_for_forward(inputs[0], output[1])
@staticmethod
def backward(ctx, grad_output, grad_saved):
input, dinput = ctx.saved_tensors
return NumpyMul.apply(grad_output, dinput) + 6 * NumpyMul.apply(grad_saved, input)
@staticmethod
def vmap(info, in_dims, input):
result = NumpyCube.apply(input)
return result, (in_dims[0], in_dims[0])
@staticmethod
def jvp(ctx, input_tangent):
input, dinput = ctx.saved_tensors
return NumpyMul.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input)
class CubeGenVmap(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
def forward(x):
return x ** 3, 3 * x ** 2
@staticmethod
def setup_context(ctx, inputs, outputs):
ctx.save_for_backward(inputs[0], outputs[1])
ctx.save_for_forward(inputs[0], outputs[1])
@staticmethod
def backward(ctx, grad_output, grad_saved):
input, dinput = ctx.saved_tensors
result = grad_output * dinput + 6 * dinput
return result
@staticmethod
def jvp(ctx, input_tangent):
input, dinput = ctx.saved_tensors
return MulGenVmap.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input)
def sample_inputs_numpy_cube(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
yield SampleInput(make_arg(1, low=0.8, high=2), args=())
class NumpyCubeNotComposable(torch.autograd.Function):
@staticmethod
def forward(input):
input_np = to_numpy(input)
return torch.tensor(input_np ** 3, device=input.device), input_np
@staticmethod
def setup_context(ctx, inputs, output):
_, input_np = output
ctx.input_np = input_np
ctx.device = inputs[0].device
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output, grad_saved):
result_np = 3 * (ctx.input_np ** 2)
return torch.tensor(result_np, device=ctx.device)
class NumpyMul(torch.autograd.Function):
@staticmethod
def forward(x, y):
return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
@staticmethod
def setup_context(ctx, inputs, output):
ctx.save_for_backward(*inputs)
ctx.save_for_forward(*inputs)
@staticmethod
def backward(ctx, grad_output):
x, y = ctx.saved_tensors
gx = None
if ctx.needs_input_grad[0]:
gx = NumpyMul.apply(grad_output, y)
gy = None
if ctx.needs_input_grad[1]:
gy = NumpyMul.apply(grad_output, x)
return gx, gy
@staticmethod
def vmap(info, in_dims, x, y):
x_bdim, y_bdim = in_dims
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
result = NumpyMul.apply(x, y)
result = result.movedim(-1, 0)
return result, 0
@staticmethod
def jvp(ctx, x_tangent, y_tangent):
x, y = ctx.saved_tensors
return x_tangent * y + y_tangent * x
def sample_inputs_numpy_mul(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
# Broadcasting
yield SampleInput(make_arg(4, low=0.9, high=2), args=(make_arg(3, 4, low=0.9, high=2),))
def sample_inputs_numpy_mul_scalar(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
yield SampleInput(make_arg(4, low=0.9, high=2), args=(), kwargs={"scalar": 3.14})
class MulGenVmap(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
def forward(x, y):
return x * y
@staticmethod
def setup_context(ctx, inputs, outputs):
ctx.save_for_backward(*inputs)
ctx.save_for_forward(*inputs)
@staticmethod
def backward(ctx, grad_output):
x, y = ctx.saved_tensors
gx = None
if ctx.needs_input_grad[0]:
gx = MulGenVmap.apply(grad_output, y)
gy = None
if ctx.needs_input_grad[1]:
gy = MulGenVmap.apply(grad_output, x)
return gx, gy
@staticmethod
def jvp(ctx, x_tangent, y_tangent):
x, y = ctx.saved_tensors
return x_tangent * y + y_tangent * x
class NumpyExp_(torch.autograd.Function):
@staticmethod
def forward(x):
x_np = to_numpy(x)
np.exp(x_np, x_np)
return x
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
ctx.mark_dirty(x)
ctx.save_for_backward(output)
ctx.save_for_forward(output)
@staticmethod
def backward(ctx, grad_output):
output, = ctx.saved_tensors
return NumpyMul.apply(grad_output, output)
@staticmethod
def vmap(info, in_dims, x):
NumpyExp_.apply(x)
return x, in_dims[0]
@staticmethod
def jvp(ctx, x_tangent):
# Doesn't call numpy operations because I didn't want to write NumpyMul_
output, = ctx.saved_tensors
x_tangent.mul_(output)
return x_tangent
class NumpySort(torch.autograd.Function):
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
return (
torch.tensor(result, device=device),  # return the sorted values rather than the raw input
torch.tensor(ind, device=device),
torch.tensor(ind_inv, device=device),
)
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
ctx.save_for_backward(ind, ind_inv)
ctx.save_for_forward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
@staticmethod
def vmap(info, in_dims, x, dim):
x_bdim, _ = in_dims
x = x.movedim(x_bdim, 0)
# wrap dim
dim = dim if dim >= 0 else dim + x.dim() - 1
return NumpySort.apply(x, dim + 1), (0, 0, 0)
@staticmethod
def jvp(ctx, x_tangent, _):
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim), None, None
class SortGenVmap(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
def forward(x, dim):
device = x.device
ind = torch.argsort(x, dim=dim)
ind_inv = torch.argsort(ind, axis=dim)
result = torch.take_along_dim(x, ind, dim=dim)
return result, ind, ind_inv
@staticmethod
def setup_context(ctx, inputs, outputs):
x, dim = inputs
_, ind, ind_inv = outputs
ctx.mark_non_differentiable(ind, ind_inv)
ctx.save_for_backward(ind, ind_inv)
ctx.save_for_forward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
ind, ind_inv = ctx.saved_tensors
return TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim), None
@staticmethod
def jvp(ctx, x_tangent, _):
ind, ind_inv = ctx.saved_tensors
return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim), None, None
def sample_inputs_numpy_sort(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
yield SampleInput(make_arg(3, 5), args=(1,))
def sample_inputs_numpy_take(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
tensor = make_arg(3, 5)
dim = 1
_, ind, ind_inv = NumpySort.apply(tensor, 1)
yield SampleInput(tensor, args=(ind, ind_inv, dim))
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.save_for_forward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
@staticmethod
def vmap(info, in_dims, x, ind, ind_inv, dim):
x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
# wrap dim
logical_dim = x.dim() if x_bdim is None else x_bdim - 1
dim = dim if dim >= 0 else dim + logical_dim
def expand_bdim(x, x_bdim):
if x_bdim is None:
return x.expand(info.batch_size, *x.shape)
return x.movedim(x_bdim, 0)
x = expand_bdim(x, x_bdim)
ind = expand_bdim(ind, ind_bdim)
ind_inv = expand_bdim(ind_inv, ind_inv_bdim)
return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0
@staticmethod
def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _):
assert ind_tangent is None
assert ind_inv_tangent is None
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim)
class TakeGenVmap(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
def forward(x, ind, ind_inv, dim):
return torch.take_along_dim(x, ind, dim)
@staticmethod
def setup_context(ctx, inputs, outputs):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.save_for_forward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
@staticmethod
def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _):
ind, ind_inv = ctx.saved_tensors
return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim)
class Select(torch.autograd.Function):
@staticmethod
def forward(x, idx):
return x[idx]
@staticmethod
def setup_context(ctx, inputs, output):
x, idx = inputs
ctx.x_shape = x.shape
ctx.idx = idx
@staticmethod
def backward(ctx, grad_output):
result = grad_output.new_zeros(ctx.x_shape)
result[ctx.idx] = grad_output
return result, None
@staticmethod
def vmap(info, in_dims, x, idx):
x_bdim, _ = in_dims
x = x.movedim(x_bdim, 1)
return Select.apply(x, idx), 0
@staticmethod
def jvp(ctx, x_tangent, _):
return Select.apply(x_tangent, ctx.idx)
class SelectGenVmap(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
def forward(x, idx):
return x[idx]
@staticmethod
def setup_context(ctx, inputs, outputs):
x, idx = inputs
ctx.x_shape = x.shape
ctx.idx = idx
@staticmethod
def backward(ctx, grad_output):
result = grad_output.new_zeros(ctx.x_shape)
result[ctx.idx] = grad_output
return result, None
@staticmethod
def jvp(ctx, x_tangent, _):
return SelectGenVmap.apply(x_tangent, ctx.idx)
def sample_inputs_select(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
yield SampleInput(make_arg(3, 5), args=(2,))
class ScaleGradGenVmap(torch.autograd.Function):
generate_vmap_rule = True
scale = 3.14
@staticmethod
def forward(x):
return x.clone()
@staticmethod
def setup_context(ctx, inputs, outputs):
pass
@staticmethod
def backward(ctx, grad_output):
return grad_output * ScaleGradGenVmap.scale
@staticmethod
def jvp(ctx, x_tangent):
return x_tangent * ScaleGradGenVmap.scale
class ZeroGradientsGenVmap(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
def forward(x, y):
return x.clone(), y.clone()
@staticmethod
def setup_context(ctx, inputs, outputs):
pass
@staticmethod
def backward(ctx, gx, gy):
# Intentionally returning torch.zeros instead of zeros_like or new_zeros.
# Also intentionally not None.
return (
# Intentionally too-large gradient
torch.zeros(3, 4, *gx.shape, dtype=gx.dtype, device=gx.device),
torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
)
@staticmethod
def jvp(ctx, gx, gy):
# Intentionally returning torch.zeros instead of zeros_like or new_zeros.
# Also intentionally not None.
return (
torch.zeros(gx.shape, dtype=gx.dtype, device=gx.device),
torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
)
def sample_inputs_forward_default_args(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
yield SampleInput(make_arg(3, 5))
class ForwardHasDefaultArgs(torch.autograd.Function):
@staticmethod
def forward(x, idx=(2,)):
return x[idx]
@staticmethod
def setup_context(ctx, inputs, output):
x, idx = inputs
ctx.x_shape = x.shape
ctx.idx = idx
@staticmethod
def backward(ctx, grad_output):
result = grad_output.new_zeros(ctx.x_shape)
result[ctx.idx] = grad_output
return result, None
@staticmethod
def vmap(info, in_dims, x, idx):
x_bdim, _ = in_dims
x = x.movedim(x_bdim, 1)
return ForwardHasDefaultArgs.apply(x, idx), 0
@staticmethod
def jvp(ctx, x_tangent, _):
return ForwardHasDefaultArgs.apply(x_tangent, ctx.idx)
autograd_function_db = [
OpInfo(
'NumpyCubeAutogradFunction',
op=NumpyCube.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_cube,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyExpMarkDirtyAutogradFunction',
op=lambda x: NumpyExp_.apply(x.clone()),
inplace_variant=NumpyExp_.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_cube,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyMulAutogradFunction',
op=NumpyMul.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_mul,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyCubeNotComposableAutogradFunction',
op=lambda x: NumpyCubeNotComposable.apply(x)[0],
supports_forward_ad=False,
supports_fwgrad_bwgrad=False,
sample_inputs_func=sample_inputs_numpy_cube,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpySortAutogradFunction',
op=NumpySort.apply,
supports_forward_ad=False,
supports_fwgrad_bwgrad=False,
sample_inputs_func=sample_inputs_numpy_sort,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
gradcheck_wrapper=lambda y, ind: y,
),
OpInfo(
'NumpyTakeAutogradFunction',
op=NumpyTake.apply,
supports_forward_ad=False,
supports_fwgrad_bwgrad=False,
sample_inputs_func=sample_inputs_numpy_take,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'SelectAutogradFunction',
op=Select.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_select,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'CubeGenVmapAutogradFunction',
op=CubeGenVmap.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_cube,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'MulGenVmapAutogradFunction',
op=MulGenVmap.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_mul,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'SortGenVmapAutogradFunction',
op=SortGenVmap.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_sort,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
gradcheck_wrapper=lambda y, ind: y,
),
OpInfo(
'SelectGenVmapAutogradFunction',
op=SelectGenVmap.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_select,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'ScaleGradGenVmapAutogradFunction',
op=ScaleGradGenVmap.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_cube,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'ZeroGradientsGenVmapAutogradFunction',
op=ZeroGradientsGenVmap.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_numpy_mul,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'ForwardHasDefaultArgsAutogradFunction',
op=ForwardHasDefaultArgs.apply,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_forward_default_args,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
]


@@ -0,0 +1,165 @@
# mypy: ignore-errors
import os
import re
import sys
from typing import List
__all__ = [
"check_code_for_cuda_kernel_launches",
"check_cuda_kernel_launches",
]
# FILES TO EXCLUDE (match is done with suffix using `endswith`)
# You wouldn't drive without a seatbelt, though, so why would you
# launch a kernel without some safety? Use this as a quick workaround
# for a problem with the checker, fix the checker, then de-exclude
# the files in question.
exclude_files: List[str] = []
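# e.g. (hypothetical entry, not a real path):
#   exclude_files.append("legacy/UncheckedKernel.cu")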
# Without using a C++ AST we can't 100% detect kernel launches, so we
# model them as having the pattern "<<<parameters>>>(arguments);"
# We then require that `C10_CUDA_KERNEL_LAUNCH_CHECK` be
# the next statement.
#
# We model the next statement as ending at the next `}` or `;`.
# If we see `}` then a clause ended (bad); if we see a semicolon then
# we expect the launch check just before it.
#
# Since the kernel launch can include lambda statements, it's important
# to find the correct end-paren of the kernel launch. Doing this with
# pure regex requires recursive regex, which aren't part of the Python
# standard library. To avoid an additional dependency, we build a prefix
# regex that finds the start of a kernel launch, use a paren-matching
# algorithm to find the end of the launch, and then another regex to
# determine if a launch check is present.
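# For example (hypothetical CUDA snippet, names are illustrative), this launch passes:
#     my_kernel<<<grid, block>>>(data, n);
#     C10_CUDA_KERNEL_LAUNCH_CHECK();
# while this one is flagged because another statement precedes the check:
#     my_kernel<<<grid, block>>>(data, n);
#     postprocess(data);
#     C10_CUDA_KERNEL_LAUNCH_CHECK();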
# Finds potential starts of kernel launches
kernel_launch_start = re.compile(
r"^.*<<<[^>]+>>>\s*\(", flags=re.MULTILINE
)
# This pattern should start at the character after the final paren of the
# kernel launch. It returns a match if the launch check is not the next statement
has_check = re.compile(
r"\s*;(?![^;}]*C10_CUDA_KERNEL_LAUNCH_CHECK\(\);)", flags=re.MULTILINE
)
def find_matching_paren(s: str, startpos: int) -> int:
"""Given a string "prefix (unknown number of characters) suffix"
and the position of the first `(` returns the index of the character
1 past the `)`, accounting for paren nesting
"""
opening = 0
for i, c in enumerate(s[startpos:]):
if c == '(':
opening += 1
elif c == ')':
opening -= 1
if opening == 0:
return startpos + i + 1
raise IndexError("Closing parens not found!")
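# Worked example (illustrative): find_matching_paren("f(g(x), y); rest", 1) returns 10,
# the index of the ';' one character past the matching ')'.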
def should_exclude_file(filename) -> bool:
for exclude_suffix in exclude_files:
if filename.endswith(exclude_suffix):
return True
return False
def check_code_for_cuda_kernel_launches(code, filename=None):
"""Checks code for CUDA kernel launches without cuda error checks.
Args:
code - The code to check
filename - Filename of file containing the code. Used only for display
purposes, so you can put anything here.
Returns:
The number of unsafe kernel launches in the code
"""
if filename is None:
filename = "##Python Function Call##"
# We break the code apart and put it back together to add
# helpful line numberings for identifying problem areas
code = enumerate(code.split("\n")) # Split by line breaks
code = [f"{lineno}: {linecode}" for lineno, linecode in code] # Number the lines
code = '\n'.join(code) # Put it back together
num_launches_without_checks = 0
for m in kernel_launch_start.finditer(code):
end_paren = find_matching_paren(code, m.end() - 1)
if has_check.match(code, end_paren):
num_launches_without_checks += 1
context = code[m.start():end_paren + 1]
print(f"Missing C10_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. Context:\n{context}", file=sys.stderr)
return num_launches_without_checks
def check_file(filename):
"""Checks a file for CUDA kernel launches without cuda error checks
Args:
filename - File to check
Returns:
The number of unsafe kernel launches in the file
"""
if not (filename.endswith((".cu", ".cuh"))):
return 0
if should_exclude_file(filename):
return 0
with open(filename) as fo:
contents = fo.read()
unsafeCount = check_code_for_cuda_kernel_launches(contents, filename)
return unsafeCount
def check_cuda_kernel_launches():
"""Checks all pytorch code for CUDA kernel launches without cuda error checks
Returns:
The number of unsafe kernel launches in the codebase
"""
torch_dir = os.path.dirname(os.path.realpath(__file__))
torch_dir = os.path.dirname(torch_dir) # Go up to parent torch
torch_dir = os.path.dirname(torch_dir) # Go up to parent caffe2
kernels_without_checks = 0
files_without_checks = []
for root, dirnames, filenames in os.walk(torch_dir):
# `$BASE/build` and `$BASE/torch/include` are generated
# so we don't want to flag their contents
if root == os.path.join(torch_dir, "build") or root == os.path.join(torch_dir, "torch/include"):
# Curtail search by modifying dirnames and filenames in place
# Yes, this is the way to do this, see `help(os.walk)`
dirnames[:] = []
continue
for x in filenames:
filename = os.path.join(root, x)
file_result = check_file(filename)
if file_result > 0:
kernels_without_checks += file_result
files_without_checks.append(filename)
if kernels_without_checks > 0:
count_str = f"Found {kernels_without_checks} instances in " \
f"{len(files_without_checks)} files where kernel " \
"launches didn't have checks."
print(count_str, file=sys.stderr)
print("Files without checks:", file=sys.stderr)
for x in files_without_checks:
print(f"\t{x}", file=sys.stderr)
print(count_str, file=sys.stderr)
return kernels_without_checks
if __name__ == "__main__":
unsafe_launches = check_cuda_kernel_launches()
sys.exit(0 if unsafe_launches == 0 else 1)


@@ -0,0 +1 @@
# mypy: ignore-errors


@@ -0,0 +1,301 @@
# mypy: ignore-errors
r"""This file is allowed to initialize CUDA context when imported."""
import functools
import torch
import torch.cuda
from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS
import inspect
import contextlib
import os
CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized()
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None
# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
if TEST_WITH_ROCM:
TEST_CUDNN = LazyVal(lambda: TEST_CUDA)
else:
TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0)
SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3))
SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0))
SM70OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 0))
SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5))
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)])
def CDNA2OrLater():
if TEST_WITH_ROCM:
gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
return any(arch in gcn_arch_name for arch in {"gfx90a", "gfx940", "gfx941", "gfx942"})
return False
def evaluate_gfx_arch_exact(matching_arch):
if not torch.cuda.is_available():
return False
gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
return arch == matching_arch
GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-'))
GFX942_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-'))
def evaluate_platform_supports_flash_attention():
if TEST_WITH_ROCM:
return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
if TEST_CUDA:
return not IS_WINDOWS and SM80OrLater
return False
def evaluate_platform_supports_efficient_attention():
if TEST_WITH_ROCM:
return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
if TEST_CUDA:
return True
return False
def evaluate_platform_supports_cudnn_attention():
return (not TEST_WITH_ROCM) and SM80OrLater and (TEST_CUDNN_VERSION >= 90000)
PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention())
PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_cudnn_attention())
# This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate
PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or
PLATFORM_SUPPORTS_CUDNN_ATTENTION or
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION)
PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM
PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)
def evaluate_platform_supports_fp8():
if torch.cuda.is_available():
if torch.version.hip:
return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName
else:
return SM90OrLater or torch.cuda.get_device_capability() == (8, 9)
return False
PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
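# Illustrative only (hypothetical test; `unittest` is not imported in this file):
# these flags are typically consumed as skip conditions, e.g.
#   @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "flash attention not supported")
#   def test_flash_sdpa(self): ...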
if TEST_NUMBA:
try:
import numba.cuda
TEST_NUMBA_CUDA = numba.cuda.is_available()
except Exception as e:
TEST_NUMBA_CUDA = False
TEST_NUMBA = False
else:
TEST_NUMBA_CUDA = False
# Used below in `initialize_cuda_context_rng` to ensure that CUDA context and
# RNG have been initialized.
__cuda_ctx_rng_initialized = False
# after this call, CUDA context and RNG must have been initialized on each GPU
def initialize_cuda_context_rng():
global __cuda_ctx_rng_initialized
assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
if not __cuda_ctx_rng_initialized:
# initialize cuda context and rng for memory tests
for i in range(torch.cuda.device_count()):
torch.randn(1, device=f"cuda:{i}")
__cuda_ctx_rng_initialized = True
# Test whether hardware TF32 math mode enabled. It is enabled only on:
# - CUDA >= 11
# - arch >= Ampere
def tf32_is_not_fp32():
if not torch.cuda.is_available() or torch.version.cuda is None:
return False
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
return False
if int(torch.version.cuda.split('.')[0]) < 11:
return False
return True
@contextlib.contextmanager
def tf32_off():
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
try:
torch.backends.cuda.matmul.allow_tf32 = False
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
yield
finally:
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
old_precision = self.precision
try:
torch.backends.cuda.matmul.allow_tf32 = True
self.precision = tf32_precision
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
yield
finally:
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
self.precision = old_precision
# This is a wrapper that wraps a test to run this test twice, one with
# allow_tf32=True, another with allow_tf32=False. When running with
# allow_tf32=True, it will use reduced precision as specified by the
# argument. For example:
# @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
# @tf32_on_and_off(0.005)
# def test_matmul(self, device, dtype):
# a = ...; b = ...;
# c = torch.matmul(a, b)
# self.assertEqual(c, expected)
# In the above example, when testing torch.float32 and torch.complex64 on CUDA
# on a CUDA >= 11 build on an >=Ampere architecture, the matmul will be running at
# TF32 mode and TF32 mode off, and on TF32 mode, the assertEqual will use reduced
# precision to check values.
#
# This decorator can be used for function with or without device/dtype, such as
# @tf32_on_and_off(0.005)
# def test_my_op(self)
# @tf32_on_and_off(0.005)
# def test_my_op(self, device)
# @tf32_on_and_off(0.005)
# def test_my_op(self, device, dtype)
# @tf32_on_and_off(0.005)
# def test_my_op(self, dtype)
# if neither device nor dtype is specified, it will check if the system has ampere device
# if device is specified, it will check if device is cuda
# if dtype is specified, it will check if dtype is float32 or complex64
# tf32 and fp32 are different only when all the three checks pass
def tf32_on_and_off(tf32_precision=1e-5):
def with_tf32_disabled(self, function_call):
with tf32_off():
function_call()
def with_tf32_enabled(self, function_call):
with tf32_on(self, tf32_precision):
function_call()
def wrapper(f):
params = inspect.signature(f).parameters
arg_names = tuple(params.keys())
@functools.wraps(f)
def wrapped(*args, **kwargs):
for k, v in zip(arg_names, args):
kwargs[k] = v
cond = tf32_is_not_fp32()
if 'device' in kwargs:
cond = cond and (torch.device(kwargs['device']).type == 'cuda')
if 'dtype' in kwargs:
cond = cond and (kwargs['dtype'] in {torch.float32, torch.complex64})
if cond:
with_tf32_disabled(kwargs['self'], lambda: f(**kwargs))
with_tf32_enabled(kwargs['self'], lambda: f(**kwargs))
else:
f(**kwargs)
return wrapped
return wrapper
# This is a wrapper that wraps a test to run it with TF32 turned off.
# This wrapper is designed to be used when a test uses matmul or convolutions
# but the purpose of that test is not testing matmul or convolutions.
# Disabling TF32 will enforce torch.float tensors to be always computed
# at full precision.
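# For example (hypothetical test):
#   @with_tf32_off
#   def test_uses_matmul_internally(self, device):
#       ...  # matmuls here run in full fp32 precision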
def with_tf32_off(f):
@functools.wraps(f)
def wrapped(*args, **kwargs):
with tf32_off():
return f(*args, **kwargs)
return wrapped
def _get_magma_version():
if 'Magma' not in torch.__config__.show():
return (0, 0)
position = torch.__config__.show().find('Magma ')
version_str = torch.__config__.show()[position + len('Magma '):].split('\n')[0]
return tuple(int(x) for x in version_str.split("."))
def _get_torch_cuda_version():
if torch.version.cuda is None:
return (0, 0)
cuda_version = str(torch.version.cuda)
return tuple(int(x) for x in cuda_version.split("."))
def _get_torch_rocm_version():
if not TEST_WITH_ROCM:
return (0, 0)
rocm_version = str(torch.version.hip)
rocm_version = rocm_version.split("-")[0] # ignore git sha
return tuple(int(x) for x in rocm_version.split("."))
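# Illustrative guard built on these helpers (hypothetical threshold):
#   if _get_torch_cuda_version() >= (11, 6):
#       ...  # take a code path that needs CUDA >= 11.6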
def _check_cusparse_generic_available():
return not TEST_WITH_ROCM
def _check_hipsparse_generic_available():
if not TEST_WITH_ROCM:
return False
rocm_version = str(torch.version.hip)
rocm_version = rocm_version.split("-")[0] # ignore git sha
rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1))
TEST_CUSPARSE_GENERIC = _check_cusparse_generic_available()
TEST_HIPSPARSE_GENERIC = _check_hipsparse_generic_available()
# Shared by test_torch.py and test_multigpu.py
def _create_scaling_models_optimizers(device="cuda", optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
# Create a module+optimizer that will use scaling, and a control module+optimizer
# that will not use scaling, against which the scaling-enabled module+optimizer can be compared.
mod_control = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
mod_scaling = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
with torch.no_grad():
for c, s in zip(mod_control.parameters(), mod_scaling.parameters()):
s.copy_(c)
kwargs = {"lr": 1.0}
if optimizer_kwargs is not None:
kwargs.update(optimizer_kwargs)
opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)
opt_scaling = optimizer_ctor(mod_scaling.parameters(), **kwargs)
return mod_control, mod_scaling, opt_control, opt_scaling
# Shared by test_torch.py, test_cuda.py and test_multigpu.py
def _create_scaling_case(device="cuda", dtype=torch.float, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
data = [(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device))]
loss_fn = torch.nn.MSELoss().to(device)
skip_iter = 2
return _create_scaling_models_optimizers(
device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs,
) + (data, loss_fn, skip_iter)
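# Illustrative unpacking of the scaling fixture (hypothetical test body):
#   mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, skip_iter = \
#       _create_scaling_case()
#   scaler = torch.cuda.amp.GradScaler()
#   # ... train mod_scaling via scaler.scale(loss).backward() and compare against mod_control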
# Importing this module should NOT eagerly initialize CUDA
if not CUDA_ALREADY_INITIALIZED_ON_IMPORT:
assert not torch.cuda.is_initialized()

File diff suppressed because it is too large.


@@ -0,0 +1,111 @@
# mypy: ignore-errors
# Owner(s): ["oncall: distributed"]
from typing import Tuple
import torch
import torch.nn as nn
class UnitModule(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
self.l1 = nn.Linear(100, 100, device=device)
self.seq = nn.Sequential(
nn.ReLU(),
nn.Linear(100, 100, device=device),
nn.ReLU(),
)
self.l2 = nn.Linear(100, 100, device=device)
def forward(self, x):
return self.l2(self.seq(self.l1(x)))
class CompositeModel(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
self.l1 = nn.Linear(100, 100, device=device)
self.u1 = UnitModule(device)
self.u2 = UnitModule(device)
self.l2 = nn.Linear(100, 100, device=device)
def forward(self, x):
return self.l2(self.u2(self.u1(self.l1(x))))
class UnitParamModule(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
self.l = nn.Linear(100, 100, device=device)
self.seq = nn.Sequential(
nn.ReLU(),
nn.Linear(100, 100, device=device),
nn.ReLU(),
)
self.p = nn.Parameter(torch.randn((100, 100), device=device))
def forward(self, x):
return torch.mm(self.seq(self.l(x)), self.p)
class CompositeParamModel(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
self.l = nn.Linear(100, 100, device=device)
self.u1 = UnitModule(device)
self.u2 = UnitModule(device)
self.p = nn.Parameter(torch.randn((100, 100), device=device))
self.register_buffer(
"buffer", torch.randn((100, 100), device=device), persistent=True
)
def forward(self, x):
a = self.u2(self.u1(self.l(x)))
b = self.p
return torch.mm(a, b)
class FakeSequential(nn.Module):
# Define this class to achieve a desired nested wrapping using the module
# wrap policy with `nn.Sequential`
def __init__(self, *modules: Tuple[nn.Module, ...]) -> None:
super().__init__()
self._module_sequence = list(modules)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for module in self._module_sequence:
x = module(x)
return x
class NestedSequentialModel(nn.Module):
def __init__(self, device: torch.device) -> None:
super().__init__()
# This nested structure exercises traversal order to catch differences
# between valid traversals (e.g. BFS and DFS variations).
self.seq1 = nn.Sequential(
nn.Linear(1, 1, device=device),
FakeSequential(
nn.Linear(1, 1, device=device),
nn.ReLU(),
FakeSequential(
nn.Linear(1, 1, device=device),
),
nn.ReLU(),
),
nn.Linear(1, 2, device=device),
)
self.lin = nn.Linear(2, 2, device=device)
self.seq2 = nn.Sequential(
nn.ReLU(),
nn.Linear(2, 3, device=device),
FakeSequential(
nn.Linear(3, 2, bias=False, device=device),
nn.Linear(2, 4, bias=False, device=device),
),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.seq2(self.lin(self.seq1(x)))
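# Illustrative sketch (an assumption for documentation; `_example_nested_forward`
# is not used by any test): the nested model above maps the last dimension
# 1 -> 2 in seq1, keeps it at 2 in lin, and maps 2 -> 4 in seq2.
def _example_nested_forward() -> torch.Tensor:
    model = NestedSequentialModel(torch.device("cpu"))
    x = torch.randn(5, 1)
    out = model(x)
    assert out.shape == (5, 4)
    return out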

File diff suppressed because it is too large

View File

@ -0,0 +1,191 @@
# mypy: ignore-errors
from typing import List
import torch
# Functions and classes for describing the dtypes a function supports
# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros
# Verifies each given dtype is a torch.dtype
def _validate_dtypes(*dtypes):
for dtype in dtypes:
assert isinstance(dtype, torch.dtype)
return dtypes
# class for tuples corresponding to a PyTorch dispatch macro
class _dispatch_dtypes(tuple):
def __add__(self, other):
assert isinstance(other, tuple)
return _dispatch_dtypes(tuple.__add__(self, other))
_empty_types = _dispatch_dtypes(())
def empty_types():
return _empty_types
_floating_types = _dispatch_dtypes((torch.float32, torch.float64))
def floating_types():
return _floating_types
_floating_types_and_half = _floating_types + (torch.half,)
def floating_types_and_half():
return _floating_types_and_half
def floating_types_and(*dtypes):
return _floating_types + _validate_dtypes(*dtypes)
_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)
def floating_and_complex_types():
return _floating_and_complex_types
def floating_and_complex_types_and(*dtypes):
return _floating_and_complex_types + _validate_dtypes(*dtypes)
_double_types = _dispatch_dtypes((torch.float64, torch.complex128))
def double_types():
return _double_types
# NB: Does not contain uint16/uint32/uint64 for BC reasons
_integral_types = _dispatch_dtypes(
(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
)
def integral_types():
return _integral_types
def integral_types_and(*dtypes):
return _integral_types + _validate_dtypes(*dtypes)
_all_types = _floating_types + _integral_types
def all_types():
return _all_types
def all_types_and(*dtypes):
return _all_types + _validate_dtypes(*dtypes)
_complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble))
def complex_types():
return _complex_types
def complex_types_and(*dtypes):
return _complex_types + _validate_dtypes(*dtypes)
_all_types_and_complex = _all_types + _complex_types
def all_types_and_complex():
return _all_types_and_complex
def all_types_and_complex_and(*dtypes):
return _all_types_and_complex + _validate_dtypes(*dtypes)
_all_types_and_half = _all_types + (torch.half,)
def all_types_and_half():
return _all_types_and_half
def custom_types(*dtypes):
"""Create a list of arbitrary dtypes"""
return _empty_types + _validate_dtypes(*dtypes)
# The functions below are used for convenience in our test suite and thus have no corresponding C++ dispatch macro
# See AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
def get_all_dtypes(
include_half=True,
include_bfloat16=True,
include_bool=True,
include_complex=True,
include_complex32=False,
include_qint=False,
) -> List[torch.dtype]:
dtypes = get_all_int_dtypes() + get_all_fp_dtypes(
include_half=include_half, include_bfloat16=include_bfloat16
)
if include_bool:
dtypes.append(torch.bool)
if include_complex:
dtypes += get_all_complex_dtypes(include_complex32)
if include_qint:
dtypes += get_all_qint_dtypes()
return dtypes
def get_all_math_dtypes(device) -> List[torch.dtype]:
return (
get_all_int_dtypes()
+ get_all_fp_dtypes(
include_half=device.startswith("cuda"), include_bfloat16=False
)
+ get_all_complex_dtypes()
)
def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
return (
[torch.complex32, torch.complex64, torch.complex128]
if include_complex32
else [torch.complex64, torch.complex128]
)
def get_all_int_dtypes() -> List[torch.dtype]:
return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
dtypes = [torch.float32, torch.float64]
if include_half:
dtypes.append(torch.float16)
if include_bfloat16:
dtypes.append(torch.bfloat16)
return dtypes
def get_all_qint_dtypes() -> List[torch.dtype]:
return [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4]
float_to_corresponding_complex_type_map = {
torch.float16: torch.complex32,
torch.float32: torch.complex64,
torch.float64: torch.complex128,
}
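# Illustrative sketch (the helper below is an assumption, not part of this
# module's API): how the dispatch-style helpers above are typically combined
# when declaring the dtypes an operator supports.
def _example_supported_dtypes() -> List[torch.dtype]:
    # floating + complex, extended with the low-precision floating dtypes
    dtypes = floating_and_complex_types_and(torch.half, torch.bfloat16)
    # the result behaves like a tuple, so membership checks just work
    assert torch.float32 in dtypes and torch.bfloat16 in dtypes
    return list(dtypes)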

File diff suppressed because it is too large

View File

@ -0,0 +1,323 @@
# mypy: ignore-errors
# Torch
import torch
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
import torch.jit.quantized
# Testing utils
from torch.testing._internal.common_dtype import floating_and_complex_types_and
from torch.testing._internal.common_utils import TestCase, \
freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests, is_iterable_of_tensors
from torch.testing._internal.common_utils import enable_profiling_mode # noqa: F401
# Standard library
from itertools import chain
from typing import List, Union
from torch._C import TensorType
import io
def check_output_types(self, func, ref_outputs, args, kwargs):
graph = getattr(func, 'last_graph', None)
types = [o.type() for o in graph.outputs()]
self.assertTrue(len(types) == 1)
t = types[0]
torch._C._jit_assert_is_instance(ref_outputs, t)
# Test names in this set are only checked for a single derivative
nn_functional_single_grad = frozenset('test_nn_' + name for name in [
'pdist',
'multilabel_margin_loss',
'max_unpool3d',
'multi_margin_loss',
'binary_cross_entropy',
'binary_cross_entropy_size_average',
'ctc_loss',
'grid_sample',
])
def check_against_reference(self, func, reference_func, output_func, args, kwargs=None,
allow_unused=True, check_types=True, no_grad=False, no_gradgrad=False):
"""Verifies a function performs identically to some reference implementation.
Commonly, this is used to verify that a JIT implementation
    (func) matches the behavior of the eager implementation
(reference_func).
"""
kwargs = kwargs if kwargs else {}
def allSum(vs):
if isinstance(vs, torch.Tensor):
vs = (vs,)
return sum((i + 1) * v.sum().abs() if v.dtype.is_complex else (i + 1) * v.sum()
for i, v in enumerate(vs)
if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16))
def clone_tensor(t, preserve_requires_grad):
require_grad = preserve_requires_grad and t.requires_grad
return t.detach().clone().requires_grad_(require_grad)
def clone_inputs(preserve_requires_grad: bool):
inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []
for arg in args:
if isinstance(arg, torch.Tensor):
inputs.append(clone_tensor(arg, preserve_requires_grad))
elif is_iterable_of_tensors(arg):
inputs.append([clone_tensor(t, preserve_requires_grad) for t in arg])
else:
inputs.append(arg)
return inputs
# Returns tensors in args that requires_grad, including tensors in TensorList args
def get_recording_tensors(args):
recording_tensors: List[torch.Tensor] = []
for arg in args:
if isinstance(arg, torch.Tensor) and arg.requires_grad:
recording_tensors.append(arg)
elif is_iterable_of_tensors(arg):
recording_tensors.extend(filter(lambda t: t.requires_grad, arg))
return recording_tensors
# test no gradients case
nograd_inputs = clone_inputs(preserve_requires_grad=False)
outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
with enable_profiling_mode_for_profiling_tests():
outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
self.assertEqual(outputs, outputs_test)
if check_types:
check_output_types(self, func, outputs_test, nograd_inputs, kwargs)
if no_grad:
# skip grad tests
return
with enable_profiling_mode_for_profiling_tests():
# test single grad case
recording_inputs = clone_inputs(preserve_requires_grad=True)
recording_tensors = get_recording_tensors(recording_inputs)
outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
grads = torch.autograd.grad(allSum(outputs), recording_tensors,
allow_unused=allow_unused)
outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
allow_unused=allow_unused)
self.assertEqual(outputs, outputs_test)
self.assertEqual(grads, grads_test)
# test the grad grad case
if self._testMethodName in nn_functional_single_grad or no_gradgrad:
return
outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
l1 = allSum(outputs)
grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
allow_unused=allow_unused)
l2 = (allSum(grads) * l1)
grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
recording_inputs = clone_inputs(preserve_requires_grad=True)
recording_tensors = get_recording_tensors(recording_inputs)
outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
l1_test = allSum(outputs_test)
grads_test = torch.autograd.grad(
l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)
l2_test = (allSum(grads_test) * l1_test)
grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)
self.assertEqual(outputs, outputs_test)
self.assertEqual(grads, grads_test)
for g2, g2_test in zip(grads2, grads2_test):
if g2 is None and g2_test is None:
continue
self.assertEqual(g2, g2_test, atol=5e-4, rtol=1e-4)
class JitCommonTestCase(TestCase):
def createFunctionFromGraph(self, trace):
graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
return torch._C._create_function_from_graph("forward", graph)
def assertExportImport(self, trace, inputs):
m = self.createFunctionFromGraph(trace)
self.assertExportImportModule(m, inputs)
def assertExportImportModule(self, m, inputs):
m_import = self.getExportImportCopy(m)
a = self.runAndSaveRNG(m, inputs)
b = self.runAndSaveRNG(m_import, inputs)
self.assertEqual(a, b, "Results of original model and "
"exported/imported version of model differed")
def runAndSaveRNG(self, func, inputs, kwargs=None):
kwargs = kwargs if kwargs else {}
with freeze_rng_state():
results = func(*inputs, **kwargs)
return results
def getExportImportCopy(self, m, also_test_file=True, map_location=None):
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
imported = torch.jit.load(buffer, map_location=map_location)
if not also_test_file:
return imported
with TemporaryFileName() as fname:
torch.jit.save(imported, fname)
return torch.jit.load(fname, map_location=map_location)
def autoDiffErrorMessage(self, should_autodiff_node, nodes_not_in_diff_graph,
fusion_nodes_not_found, non_fusible_nodes_being_fused,
fusion_nodes_found, nodes_in_diff_graph):
err_msg = "\nFailure in testing nodes' autodifferentiation. "
if should_autodiff_node:
err_msg += "One or more nodes were expected to be autodiffed, " \
"but were not found in specified fusible/nonfusible " \
"DifferentiableGraph groups. \nSpecifically:"
# The node is intended to appear in a differentiable graph but doesn't
diff_nodes_missing = []
# The node is intended to appear in a differentiable graph
# outside of a fusion group but instead is in a fusion group
diff_nodes_in_fusion = []
# The node is intended to appear in a fusion group but doesn't
fusion_nodes_missing = []
# The node is intended to appear in a fusion group but instead
# is just in an outer differentiable graph
fusion_nodes_in_diff = []
for node in nodes_not_in_diff_graph:
if node in non_fusible_nodes_being_fused:
diff_nodes_in_fusion.append(node)
else:
diff_nodes_missing.append(node)
for node in fusion_nodes_not_found:
if node in nodes_in_diff_graph:
fusion_nodes_in_diff.append(node)
else:
fusion_nodes_missing.append(node)
if len(diff_nodes_missing) > 0:
err_msg += f"\n {diff_nodes_missing} were not in one of the " \
"DifferentiableGraphs when they were expected to be. " \
"Did you intend for these nodes to be autodiffed? " \
"If not, remove them from the list of nonfusible nodes."
if len(diff_nodes_in_fusion) > 0:
err_msg += f"\n {diff_nodes_in_fusion} were found in one of the FusionGroups " \
"when they were expected to be just in a DifferentiableGraph. If it was " \
"intended for these nodes to be in FusionGroups, reclassify these nodes as " \
"fusible nodes. If these nodes were not intended to be fused, your " \
"autodifferentiation logic might be wrong."
if len(fusion_nodes_missing) > 0:
err_msg += f"\n {fusion_nodes_missing} were not in one of the FusionGroups " \
"of the DifferentiableGraphs when they were expected to be. " \
"They were also not found in an outer DifferentiableGraph. Did you " \
"intend for these nodes to be autodifferentiated? If not, you should " \
"remove these nodes from the test's fusible nodes. Otherwise your " \
"autodifferentiation logic might be wrong."
if len(fusion_nodes_in_diff) > 0:
err_msg += f"\n {fusion_nodes_in_diff} were not in one of the FusionGroups " \
"of the DifferentiableGraphs when they were expected to be, " \
"instead they were found just in an outer DifferentiableGraph. " \
"Did you intend for these nodes to be fused? If not, you should " \
"move these nodes into the test's nonfusible nodes. Otherwise your " \
"autodifferentiation logic might be wrong."
else:
err_msg += "One or more nodes were not expected to be autodiffed " \
"but were found in a DifferentiableGraph or in a FusionGroup " \
"of a DifferentiableGraph. Did you intend for these nodes to be " \
"autodiffed? If so, change this test to expect autodifferentiation. " \
"\nSpecifically:"
if len(fusion_nodes_found) > 0:
err_msg += f"\n {fusion_nodes_found} were not expected to be in " \
"one of the DifferentiableGraphs, but appeared in a FusionGroup " \
"of a DifferentiableGraph. "
if len(nodes_in_diff_graph) > 0:
err_msg += f"\n {nodes_in_diff_graph} were not expected to " \
"be in one of the DifferentiableGraphs but were."
return err_msg
def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes):
diff_nodes = graph.findAllNodes('prim::DifferentiableGraph')
diff_subgraphs = [node.g('Subgraph') for node in diff_nodes]
# Note: currently no tests have fusible_nodes
fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs]))
fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes]
# For any non-fusible node, it must show up in one of the DifferentiableGraphs.
nodes_in_diff_graph = []
nodes_not_in_diff_graph = []
non_fusible_nodes_being_fused = []
for node in nonfusible_nodes:
if any(g.findNode(node) is not None for g in diff_subgraphs):
nodes_in_diff_graph.append(node)
else:
nodes_not_in_diff_graph.append(node)
if any(g.findNode(node) is not None for g in fusion_subgraphs):
non_fusible_nodes_being_fused.append(node)
found_all_nonfusible_nodes = len(nodes_in_diff_graph) == len(nonfusible_nodes)
# For any fusible node, it must show up in one of the FusionGroups in one of the DifferentiableGraphs.
fusion_nodes_found = []
fusion_nodes_not_found = []
for node in fusible_nodes:
if any(g.findNode(node) is not None for g in fusion_subgraphs):
fusion_nodes_found.append(node)
else:
fusion_nodes_not_found.append(node)
found_all_fusible_nodes = len(fusion_nodes_found) == len(fusible_nodes)
if should_autodiff_node is not None:
err_msg = self.autoDiffErrorMessage(should_autodiff_node,
nodes_not_in_diff_graph,
fusion_nodes_not_found,
non_fusible_nodes_being_fused,
fusion_nodes_found,
nodes_in_diff_graph)
self.assertEqual(should_autodiff_node,
found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg)
def checkShapeAnalysis(self, out_sizes: Union[List[int], List[List[int]]],
traced_graph, assert_propagation, constant_prop=True):
        # re-propagate input shapes provided by tracing,
prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
for enable_test_mode in [True, False]:
# here we are testing allowing/disallowing substituting in complete shapes as constants,
# disallowing constants helps stress test partial eval and substitution pipeline
torch._C._jit_set_symbolic_shapes_test_mode(enable_test_mode)
torch._C._jit_erase_non_input_shape_information(traced_graph)
if constant_prop:
torch._C._jit_pass_constant_propagation(traced_graph)
torch._C._jit_pass_propagate_shapes_on_graph(traced_graph)
# Add sizes to default tensor type to avoid checking something out of scope
# and difficulties with tracer leaving in other parts of tensor type
output = next(traced_graph.outputs()).type()
def test_type(type, actual_size):
sizes = type.symbolic_sizes()
out_type = TensorType.get().with_sizes(sizes)
actual_type = TensorType.get().with_sizes(actual_size)
# always check actual shape is a subtype of the output
self.assertTrue(actual_type.isSubtypeOf(out_type))
# and then if assertion flag is provided, check shape analysis
# is successful
if assert_propagation:
self.assertEqual(out_type.sizes(), actual_size)
if output.isSubtypeOf(torch._C.TensorType.get()):
test_type(output, out_sizes)
else:
tuple_elements = output.elements()
for i in range(len(tuple_elements)):
test_type(tuple_elements[i], out_sizes[i])
torch._C._jit_set_symbolic_shapes_test_mode(prev_symbolic_shapes_test_enabled)

File diff suppressed because it is too large

View File

@ -0,0 +1,78 @@
# mypy: ignore-errors
import contextlib
import functools
import inspect
import torch
# Test whether the hardware BF32 math mode is enabled. It is enabled only when:
# - MKLDNN is available
# - BF16 is supported by MKLDNN
def bf32_is_not_fp32():
if not torch.backends.mkldnn.is_available():
return False
if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
return False
return True
@contextlib.contextmanager
def bf32_off():
old_matmul_precision = torch.get_float32_matmul_precision()
try:
torch.set_float32_matmul_precision("highest")
yield
finally:
torch.set_float32_matmul_precision(old_matmul_precision)
@contextlib.contextmanager
def bf32_on(self, bf32_precision=1e-5):
old_matmul_precision = torch.get_float32_matmul_precision()
old_precision = self.precision
try:
torch.set_float32_matmul_precision("medium")
self.precision = bf32_precision
yield
finally:
torch.set_float32_matmul_precision(old_matmul_precision)
self.precision = old_precision
# This is a wrapper that wraps a test to run this test twice, one with
# allow_bf32=True, another with allow_bf32=False. When running with
# allow_bf32=True, it will use reduced precision as specified by the
# argument
def bf32_on_and_off(bf32_precision=1e-5):
def with_bf32_disabled(self, function_call):
with bf32_off():
function_call()
def with_bf32_enabled(self, function_call):
with bf32_on(self, bf32_precision):
function_call()
def wrapper(f):
params = inspect.signature(f).parameters
arg_names = tuple(params.keys())
@functools.wraps(f)
def wrapped(*args, **kwargs):
for k, v in zip(arg_names, args):
kwargs[k] = v
cond = bf32_is_not_fp32()
if "device" in kwargs:
cond = cond and (torch.device(kwargs["device"]).type == "cpu")
if "dtype" in kwargs:
cond = cond and (kwargs["dtype"] == torch.float)
if cond:
with_bf32_disabled(kwargs["self"], lambda: f(**kwargs))
with_bf32_enabled(kwargs["self"], lambda: f(**kwargs))
else:
f(**kwargs)
return wrapped
return wrapper
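# Illustrative sketch (the class and method below are assumptions, not real
# tests): the decorator reads the `device` and `dtype` arguments of a
# device/dtype-parametrized test, so it is used roughly like this.
def _example_bf32_decorated_check() -> None:
    class _FakeTestCase:
        # minimal stand-in for a TestCase: bf32_on() only touches `precision`
        precision = 1e-5
        @bf32_on_and_off(bf32_precision=1e-2)
        def check_mm(self, device, dtype):
            a = torch.randn(8, 8, device=device, dtype=dtype)
            b = torch.randn(8, 8, device=device, dtype=dtype)
            assert torch.allclose(a @ b, torch.mm(a, b), atol=self.precision)
    # Runs once if BF32 is not available, twice (off and on) otherwise.
    _FakeTestCase().check_mm(device="cpu", dtype=torch.float)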

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,385 @@
# Owner(s): ["module: unknown"]
from typing import Dict, Any, Tuple
from torch.ao.pruning import BaseSparsifier
import torch
import torch.nn.functional as F
from torch import nn
class ImplementedSparsifier(BaseSparsifier):
def __init__(self, **kwargs: Dict[str, Any]) -> None:
super().__init__(defaults=kwargs)
def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: Dict[str, Any]) -> None:
module.parametrizations.weight[0].mask[0] = 0
linear_state = self.state['linear1.weight']
linear_state['step_count'] = linear_state.get('step_count', 0) + 1
class MockSparseLinear(nn.Linear):
"""
This class is a MockSparseLinear class to check convert functionality.
It is the same as a normal Linear layer, except with a different type, as
well as an additional from_dense method.
"""
@classmethod
def from_dense(cls, mod: nn.Linear) -> 'MockSparseLinear':
"""
"""
linear = cls(mod.in_features,
mod.out_features)
return linear
def rows_are_subset(subset_tensor: torch.Tensor, superset_tensor: torch.Tensor) -> bool:
"""
Checks to see if all rows in subset tensor are present in the superset tensor
"""
i = 0
for row in subset_tensor:
while i < len(superset_tensor):
if not torch.equal(row, superset_tensor[i]):
i += 1
else:
break
else:
return False
return True
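# Illustrative example (an assumption, not used by the models below):
# `rows_are_subset` requires the subset rows to appear in the superset in the
# same relative order.
def _example_rows_are_subset() -> None:
    superset = torch.tensor([[1., 0.], [0., 1.], [2., 2.]])
    in_order = torch.tensor([[1., 0.], [2., 2.]])      # rows 0 and 2, in order
    reordered = torch.tensor([[2., 2.], [1., 0.]])     # same rows, out of order
    assert rows_are_subset(in_order, superset)
    assert not rows_are_subset(reordered, superset)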
class SimpleLinear(nn.Module):
r"""Model with only Linear layers without biases, some wrapped in a Sequential,
some following the Sequential. Used to test basic pruned Linear-Linear fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Linear(7, 5, bias=False),
nn.Linear(5, 6, bias=False),
nn.Linear(6, 4, bias=False),
)
self.linear1 = nn.Linear(4, 4, bias=False)
self.linear2 = nn.Linear(4, 10, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.linear1(x)
x = self.linear2(x)
return x
class LinearBias(nn.Module):
r"""Model with only Linear layers, alternating layers with biases,
wrapped in a Sequential. Used to test pruned Linear-Bias-Linear fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Linear(7, 5, bias=True),
nn.Linear(5, 6, bias=False),
nn.Linear(6, 3, bias=True),
nn.Linear(3, 3, bias=True),
nn.Linear(3, 10, bias=False),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
return x
class LinearActivation(nn.Module):
r"""Model with only Linear layers, some with bias, some in a Sequential and some following.
    Activation function modules in between each Linear in the Sequential, and each outside layer.
Used to test pruned Linear(Bias)-Activation-Linear fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Linear(7, 5, bias=True),
nn.ReLU(),
nn.Linear(5, 6, bias=False),
nn.Tanh(),
nn.Linear(6, 4, bias=True),
)
self.linear1 = nn.Linear(4, 3, bias=True)
self.act1 = nn.ReLU()
self.linear2 = nn.Linear(3, 10, bias=False)
self.act2 = nn.Tanh()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.linear1(x)
x = self.act1(x)
x = self.linear2(x)
x = self.act2(x)
return x
class LinearActivationFunctional(nn.Module):
r"""Model with only Linear layers, some with bias, some in a Sequential and some following.
    Activation function modules in between each Linear in the Sequential, and functional
    activations are called in between each outside layer.
Used to test pruned Linear(Bias)-Activation-Linear fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Linear(7, 5, bias=True),
nn.ReLU(),
nn.Linear(5, 6, bias=False),
nn.ReLU(),
nn.Linear(6, 4, bias=True),
)
self.linear1 = nn.Linear(4, 3, bias=True)
self.linear2 = nn.Linear(3, 8, bias=False)
self.linear3 = nn.Linear(8, 10, bias=False)
self.act1 = nn.ReLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.linear1(x)
x = F.relu(x)
x = self.linear2(x)
x = F.relu(x)
x = self.linear3(x)
x = F.relu(x)
return x
class SimpleConv2d(nn.Module):
r"""Model with only Conv2d layers, all without bias, some in a Sequential and some following.
Used to test pruned Conv2d-Conv2d fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, bias=False),
nn.Conv2d(32, 64, 3, 1, bias=False),
)
self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.conv2d1(x)
x = self.conv2d2(x)
return x
class Conv2dBias(nn.Module):
r"""Model with only Conv2d layers, some with bias, some in a Sequential and some outside.
Used to test pruned Conv2d-Bias-Conv2d fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, bias=True),
nn.Conv2d(32, 32, 3, 1, bias=True),
nn.Conv2d(32, 64, 3, 1, bias=False),
)
self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=True)
self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.conv2d1(x)
x = self.conv2d2(x)
return x
class Conv2dActivation(nn.Module):
r"""Model with only Conv2d layers, some with bias, some in a Sequential and some following.
Activation function modules in between each Sequential layer, functional activations called
in-between each outside layer.
Used to test pruned Conv2d-Bias-Activation-Conv2d fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, bias=True),
nn.ReLU(),
nn.Conv2d(32, 64, 3, 1, bias=True),
nn.Tanh(),
nn.Conv2d(64, 64, 3, 1, bias=False),
nn.ReLU(),
)
self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.conv2d1(x)
x = F.relu(x)
x = self.conv2d2(x)
x = F.hardtanh(x)
return x
class Conv2dPadBias(nn.Module):
r"""Model with only Conv2d layers, all with bias and some with padding > 0,
some in a Sequential and some following. Activation function modules in between each layer.
Used to test that bias is propagated correctly in the special case of
pruned Conv2d-Bias-(Activation)Conv2d fusion, when the second Conv2d layer has padding > 0."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, padding=1, bias=True),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, bias=False),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
nn.ReLU(),
nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
nn.ReLU(),
nn.Conv2d(32, 64, 3, 1, bias=True),
nn.Tanh(),
)
self.conv2d1 = nn.Conv2d(64, 48, 3, 1, padding=1, bias=True)
self.act1 = nn.ReLU()
self.conv2d2 = nn.Conv2d(48, 52, 3, 1, padding=1, bias=True)
self.act2 = nn.Tanh()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.conv2d1(x)
x = self.act1(x)
x = self.conv2d2(x)
x = self.act2(x)
return x
class Conv2dPool(nn.Module):
r"""Model with only Conv2d layers, all with bias, some in a Sequential and some following.
Activation function modules in between each layer, Pool2d modules in between each layer.
Used to test pruned Conv2d-Pool2d-Conv2d fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True),
nn.Tanh(),
nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
)
self.conv2d1 = nn.Conv2d(64, 48, kernel_size=3, padding=1, bias=True)
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
self.af1 = nn.ReLU()
self.conv2d2 = nn.Conv2d(48, 52, kernel_size=3, padding=1, bias=True)
self.conv2d3 = nn.Conv2d(52, 52, kernel_size=3, padding=1, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.conv2d1(x)
x = self.maxpool(x)
x = self.af1(x)
x = self.conv2d2(x)
x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1)
x = F.relu(x)
x = self.conv2d3(x)
return x
class Conv2dPoolFlattenFunctional(nn.Module):
r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d
and a functional Flatten followed by a Linear layer.
Activation functions and Pool2ds in between each layer also.
Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
nn.Tanh(),
nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
)
self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
self.af1 = nn.ReLU()
self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(11, 13, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.conv2d1(x)
x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
x = self.af1(x)
x = self.conv2d2(x)
x = self.avg_pool(x)
x = torch.flatten(x, 1) # test functional flatten
x = self.fc(x)
return x
class Conv2dPoolFlatten(nn.Module):
r"""Model with Conv2d layers, all with bias, some in a Sequential and some following, and then a Pool2d
and a Flatten module followed by a Linear layer.
Activation functions and Pool2ds in between each layer also.
Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""
def __init__(self) -> None:
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
nn.Tanh(),
nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
)
self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
self.af1 = nn.ReLU()
self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
self.avg_pool = nn.AdaptiveAvgPool2d((2, 2))
self.flatten = nn.Flatten()
self.fc = nn.Linear(44, 13, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.seq(x)
x = self.conv2d1(x)
x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
x = self.af1(x)
x = self.conv2d2(x)
x = self.avg_pool(x)
x = self.flatten(x)
x = self.fc(x)
return x
class LSTMLinearModel(nn.Module):
"""Container module with an encoder, a recurrent module, and a linear."""
def __init__(
self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
) -> None:
super().__init__()
self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
self.linear = nn.Linear(hidden_dim, output_dim)
def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
output, hidden = self.lstm(input)
decoded = self.linear(output)
return decoded, output
class LSTMLayerNormLinearModel(nn.Module):
"""Container module with an LSTM, a LayerNorm, and a linear."""
def __init__(
self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
) -> None:
super().__init__()
self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
self.norm = nn.LayerNorm(hidden_dim)
self.linear = nn.Linear(hidden_dim, output_dim)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
x, state = self.lstm(x)
x = self.norm(x)
x = self.linear(x)
return x, state

File diff suppressed because it is too large

View File

@ -0,0 +1,227 @@
# mypy: ignore-errors
r"""Importing this file includes common utility methods for checking quantized
tensors and modules.
"""
import numpy as np
import torch
from contextlib import contextmanager
from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_TSAN, TEST_WITH_UBSAN, IS_PPC, IS_MACOS, IS_WINDOWS
supported_qengines = torch.backends.quantized.supported_engines
supported_qengines.remove('none')
# Note: We currently do not run QNNPACK tests on WINDOWS and MACOS as it is flaky. Issue #29326
# QNNPACK is not supported on PPC
# QNNPACK throws ASAN heap-buffer-overflow error.
if 'qnnpack' in supported_qengines and any([IS_PPC, TEST_WITH_ASAN, TEST_WITH_TSAN, TEST_WITH_UBSAN, IS_MACOS, IS_WINDOWS]):
supported_qengines.remove('qnnpack')
def _conv_output_shape(input_size, kernel_size, padding, stride, dilation,
output_padding=0):
"""Computes the output shape given convolution parameters."""
return np.floor((input_size + 2 * padding - kernel_size - (kernel_size - 1)
* (dilation - 1)) / stride) + 2 * output_padding + 1
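# Worked example (illustrative only; `_example_conv_output_shape` is an
# assumption): a size-32 spatial dim with kernel 3, padding 1, stride 2 and
# dilation 1 gives floor((32 + 2 - 3 - 0) / 2) + 1 = 16, matching nn.Conv2d's
# documented output-size formula for a single dimension.
def _example_conv_output_shape() -> None:
    out = _conv_output_shape(input_size=32, kernel_size=3, padding=1,
                             stride=2, dilation=1)
    assert out == 16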
# Quantization references
def _quantize(x, scale, zero_point, qmin=None, qmax=None, dtype=np.uint8):
"""Quantizes a numpy array."""
if qmin is None:
qmin = np.iinfo(dtype).min
if qmax is None:
qmax = np.iinfo(dtype).max
qx = np.round(x / scale + zero_point).astype(np.int64)
qx = np.clip(qx, qmin, qmax)
qx = qx.astype(dtype)
return qx
def _dequantize(qx, scale, zero_point):
"""Dequantizes a numpy array."""
x = (qx.astype(float) - zero_point) * scale
return x
def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
"""Requantizes a numpy array, i.e., intermediate int32 or int16 values are
converted back to given type"""
qx = (x * multiplier).round() + zero_point
qx = np.clip(qx, qmin, qmax).astype(qtype)
return qx
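# Illustrative round trip (an assumption, not used by the other helpers):
# quantizing and then dequantizing with the reference functions above should
# reproduce the input to within one quantization step.
def _example_quantize_roundtrip() -> None:
    x = np.linspace(-1.0, 1.0, num=16, dtype=np.float32)
    scale, zero_point = 2.0 / 255, 128
    qx = _quantize(x, scale, zero_point)          # uint8 values in [0, 255]
    x_hat = _dequantize(qx, scale, zero_point)
    assert np.abs(x - x_hat).max() <= scale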
def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine):
"""Calculate the dynamic quantization parameters (scale, zero_point)
according to the min and max element of the tensor"""
assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
if qscheme == torch.per_tensor_symmetric:
assert dtype == torch.qint8
if isinstance(X, torch.Tensor):
X = X.numpy()
if dtype == torch.qint8:
if reduce_range:
qmin, qmax = -64, 63
else:
qmin, qmax = -128, 127
else: # dtype == torch.quint8
if reduce_range:
qmin, qmax = 0, 127
else:
qmin, qmax = 0, 255
min_val = X.min()
max_val = X.max()
is_symmetric = (qscheme == torch.per_tensor_symmetric)
if min_val == max_val:
scale = 1.0
zero_point = 0
else:
if is_symmetric:
max_val = max(max_val, -min_val)
min_val = -max_val
scale = (max_val - min_val) / (qmax - qmin)
scale = max(scale, np.finfo(np.float32).eps)
zero_point = 0
else:
max_val = max(max_val, 0.0)
min_val = min(min_val, 0.0)
scale = (max_val - min_val) / (qmax - qmin)
scale = max(scale, np.finfo(np.float32).eps)
zero_point = qmin - round(min_val / scale)
zero_point = max(qmin, zero_point)
zero_point = min(qmax, zero_point)
return [float(scale), int(zero_point)]
def _calculate_dynamic_per_channel_qparams(X, dtype):
"""Calculate the dynamic quantization parameters (scale, zero_point)
according to the min and max element of the tensor"""
if isinstance(X, torch.Tensor):
X = X.numpy()
qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
n_levels = qmax - qmin
scale = np.zeros(X.shape[0], dtype=np.float64)
zero_point = np.zeros(X.shape[0], dtype=np.int64)
for i in range(zero_point.shape[0]):
min_val = X.min()
max_val = X.max()
if min_val == max_val:
scale[i] = 1.0
zero_point[i] = 0
else:
max_val = max(max_val, 0.0)
min_val = min(min_val, 0.0)
scale[i] = (max_val - min_val) / n_levels
scale[i] = max(scale[i], np.finfo(np.float32).eps)
zero_point[i] = qmin - round(min_val / scale[i])
zero_point[i] = max(qmin, zero_point[i])
zero_point[i] = min(qmax, zero_point[i])
return scale, zero_point
def _snr(x, x_hat):
"""Calculates the signal to noise ratio and returns the signal and noise
power, as well as the SNR in dB.
If the input is a list/tuple this function is called recursively on each
element. The result will have the same nested structure as the inputs.
Args:
x, x_hat: Either a tensor or a nested list/tuple of tensors.
Returns:
signal, noise, SNR(in dB): Either floats or a nested list of floats
"""
if isinstance(x, (list, tuple)):
assert len(x) == len(x_hat)
res = []
for idx in range(len(x)):
res.append(_snr(x[idx], x_hat[idx]))
return res
if x_hat.is_quantized:
x_hat = x_hat.dequantize()
if x.is_quantized:
x = x.dequantize()
noise = (x - x_hat).norm()
if noise == 0:
return 0.0, float('inf'), float('inf')
signal = x.norm()
snr = signal / noise
snr_db = 20 * snr.log10()
return signal, noise, snr_db
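# Illustrative example (an assumption): a smaller perturbation of the reference
# signal should yield a higher SNR in dB.
def _example_snr() -> None:
    x = torch.randn(100)
    _, _, snr_small_err = _snr(x, x + 1e-4)
    _, _, snr_large_err = _snr(x, x + 1e-1)
    assert snr_small_err > snr_large_err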
@contextmanager
def override_quantized_engine(qengine):
previous = torch.backends.quantized.engine
torch.backends.quantized.engine = qengine
try:
yield
finally:
torch.backends.quantized.engine = previous
@contextmanager
def override_cpu_allocator_for_qnnpack(qengine_is_qnnpack):
try:
if qengine_is_qnnpack:
torch._C._set_default_mobile_cpu_allocator()
yield
finally:
if qengine_is_qnnpack:
torch._C._unset_default_mobile_cpu_allocator()
# TODO: Update all quantization tests to use this decorator.
# Currently for some of the tests it seems to have inconsistent params
# for fbgemm vs qnnpack.
def override_qengines(qfunction):
def test_fn(*args, **kwargs):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
# qfunction should not return anything.
qfunction(*args, **kwargs)
return test_fn
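# Illustrative sketch (an assumption, not an existing test): the decorator runs
# the wrapped body once per supported engine, setting the engine before each call.
def _example_override_qengines() -> None:
    @override_qengines
    def check_current_engine(expected_engines):
        assert torch.backends.quantized.engine in expected_engines
    check_current_engine(supported_qengines)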
def qengine_is_fbgemm():
return torch.backends.quantized.engine == 'fbgemm'
def qengine_is_qnnpack():
return torch.backends.quantized.engine == 'qnnpack'
def qengine_is_onednn():
return torch.backends.quantized.engine == 'onednn'
def qengine_is_x86():
return torch.backends.quantized.engine == 'x86'
# Helper function used to simulate per-channel fake-quant against any axis
def _permute_to_axis_zero(X, axis):
new_axis_list = list(range(X.dim()))
new_axis_list[axis] = 0
new_axis_list[0] = axis
y = X.permute(tuple(new_axis_list))
return y, new_axis_list
# Reference method for fake quantize
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
dtype = X.dtype
X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
res = torch.zeros_like(X)
for i in range(X.size()[0]):
res[i] = (torch.clamp(torch.round(X[i] * (1.0 / per_channel_scale[i]) +
per_channel_zero_point[i]), quant_min, quant_max) - per_channel_zero_point[i]) * per_channel_scale[i]
out = res.permute(tuple(permute_axis_list))
return out.to(dtype)
# Reference method for the gradient of the fake quantize operator
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
dtype = X.dtype
X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
Xq = torch.zeros_like(X)
for i in range(X.size()[0]):
Xq[i] = torch.round(X[i] * (1.0 / per_channel_scale[i]) + per_channel_zero_point[i])
Xq = Xq.permute(tuple(permute_axis_list))
mask = (Xq >= quant_min) * (Xq <= quant_max)
res = torch.zeros_like(dY)
res[mask] = dY[mask]
return res.to(dtype)
def to_tensor(X, device):
if not isinstance(X, torch.Tensor):
X = torch.tensor(X)
else:
X = X.clone().detach()
return X.to(device=torch.device(device), dtype=torch.float32)

View File

@ -0,0 +1,266 @@
# mypy: ignore-errors
import torch
from copy import deepcopy
from torch.utils._pytree import tree_map
import torch.utils._pytree as pytree
# TODO: Move LoggingTensor here.
from torch.testing._internal.logging_tensor import LoggingTensor
# Base class for wrapper-style tensors.
class WrapperTensor(torch.Tensor):
@staticmethod
def __new__(cls, *args, **kwargs):
t, kwargs = cls.get_wrapper_properties(*args, **kwargs)
if "size" not in kwargs:
size = t.size()
else:
size = kwargs["size"]
del kwargs["size"]
if "dtype" not in kwargs:
kwargs["dtype"] = t.dtype
if "layout" not in kwargs:
kwargs["layout"] = t.layout
if "device" not in kwargs:
kwargs["device"] = t.device
if "requires_grad" not in kwargs:
kwargs["requires_grad"] = False
# Ignore memory_format and pin memory for now as I don't know how to
# safely access them on a Tensor (if possible??)
wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **kwargs)
wrapper._validate_methods()
return wrapper
@classmethod
def get_wrapper_properties(cls, *args, **kwargs):
# Should return both an example Tensor and a dictionary of kwargs
        # to override any of that example Tensor's properties.
# This is very similar to the `t.new_*(args)` API
raise NotImplementedError("You need to implement get_wrapper_properties")
def _validate_methods(self):
# Skip this if not in debug mode?
# Changing these on the python side is wrong as it would not be properly reflected
# on the c++ side
# This doesn't catch attributes set in the __init__
forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"]
for el in forbidden_overrides:
if getattr(self.__class__, el) is not getattr(torch.Tensor, el):
raise RuntimeError(f"Subclass {self.__class__.__name__} is overwriting the "
f"property {el} but this is not allowed as such change would "
"not be reflected to c++ callers.")
class DiagTensorBelow(WrapperTensor):
@classmethod
def get_wrapper_properties(cls, diag, requires_grad=False):
assert diag.ndim == 1
return diag, {"size": diag.size() + diag.size(), "requires_grad": requires_grad}
def __init__(self, diag, requires_grad=False):
self.diag = diag
handled_ops = {}
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if not all(issubclass(cls, t) for t in types):
return NotImplemented
# For everything else, call the handler:
fn = cls.handled_ops.get(func.__name__, None)
if fn:
return fn(*args, **(kwargs or {}))
else:
# Note that here, because we don't need to provide the autograd formulas
# we can have a default "fallback" that creates a plain Tensor based
# on the diag elements and calls the func again.
def unwrap(e):
return e.diag.diag() if isinstance(e, DiagTensorBelow) else e
def wrap(e):
if isinstance(e, torch.Tensor) and e.ndim == 1:
return DiagTensorBelow(e)
if isinstance(e, torch.Tensor) and e.ndim == 2 and e.count_nonzero() == e.diag().count_nonzero():
return DiagTensorBelow(e.diag())
return e
rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
return rs
def __repr__(self):
return super().__repr__(tensor_contents=f"diag={self.diag}")
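# Illustrative usage (an assumption, for documentation only): the fallback in
# DiagTensorBelow.__torch_dispatch__ densifies the diagonal, runs the op on
# plain tensors, and re-wraps results whose nonzeros all lie on the diagonal.
def _example_diag_tensor_below() -> torch.Tensor:
    d = DiagTensorBelow(torch.tensor([1.0, 2.0, 3.0]))  # acts like a 3x3 tensor
    s = d + d
    assert isinstance(s, DiagTensorBelow)
    return s.diag  # tensor([2., 4., 6.])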
class SparseTensor(WrapperTensor):
@classmethod
def get_wrapper_properties(cls, size, values, indices, requires_grad=False):
assert values.device == indices.device
return values, {"size": size, "requires_grad": requires_grad}
def __init__(self, size, values, indices, requires_grad=False):
self.values = values
self.indices = indices
def __repr__(self):
return super().__repr__(tensor_contents=f"values={self.values}, indices={self.indices}")
def sparse_to_dense(self):
res = torch.zeros(self.size(), dtype=self.values.dtype)
res[self.indices.unbind(1)] = self.values
return res
@staticmethod
def from_dense(t):
indices = t.nonzero()
values = t[indices.unbind(1)]
return SparseTensor(t.size(), values, indices)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
func_name = f"{func.__module__}.{func.__name__}"
res = cls._try_call_special_impl(func_name, args, kwargs)
if res is not NotImplemented:
return res
# Otherwise, use a default implementation that construct dense
# tensors and use that to compute values
def unwrap(e):
return e.sparse_to_dense() if isinstance(e, SparseTensor) else e
# Wrap back all Tensors into our custom class
def wrap(e):
# Check for zeros and use that to get indices
return SparseTensor.from_dense(e) if isinstance(e, torch.Tensor) else e
rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
return rs
# To show how things happen later
def __rmul__(self, other):
return super().__rmul__(other)
_SPECIAL_IMPLS = {}
@classmethod
def _try_call_special_impl(cls, func, args, kwargs):
if func not in cls._SPECIAL_IMPLS:
return NotImplemented
return cls._SPECIAL_IMPLS[func](args, kwargs)
# Example non-wrapper subclass that stores extra state.
class NonWrapperTensor(torch.Tensor):
def __new__(cls, data):
t = torch.Tensor._make_subclass(cls, data)
t.extra_state = {
'last_func_called': None
}
return t
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
result = super().__torch_function__(func, types, args, kwargs)
if isinstance(result, cls):
# Do something with the extra state. For the example here, just store the name of the
# last function called (skip for deepcopy so the copy has the same extra state).
if func is torch.Tensor.__deepcopy__:
result.extra_state = deepcopy(args[0].extra_state)
else:
result.extra_state = {
'last_func_called': func.__name__,
}
return result
# new_empty() must be defined for deepcopy to work
def new_empty(self, shape):
return type(self)(torch.empty(shape))
# Class used to store info about subclass tensors used in testing.
class SubclassInfo:
__slots__ = ['name', 'create_fn', 'closed_under_ops']
def __init__(self, name, create_fn, closed_under_ops=True):
self.name = name
self.create_fn = create_fn # create_fn(shape) -> tensor instance
self.closed_under_ops = closed_under_ops
subclass_db = {
torch.Tensor: SubclassInfo(
'base_tensor', create_fn=torch.randn
),
NonWrapperTensor: SubclassInfo(
'non_wrapper_tensor',
create_fn=lambda shape: NonWrapperTensor(torch.randn(shape))
),
LoggingTensor: SubclassInfo(
'logging_tensor',
create_fn=lambda shape: LoggingTensor(torch.randn(shape))
),
SparseTensor: SubclassInfo(
'sparse_tensor',
create_fn=lambda shape: SparseTensor.from_dense(torch.randn(shape).relu())
),
DiagTensorBelow: SubclassInfo(
'diag_tensor_below',
create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)),
closed_under_ops=False # sparse semantics
),
}
class SubclassWithTensorFactory(torch.Tensor):
@staticmethod
def __new__(cls, src):
shape = src.shape
kwargs = {}
kwargs["strides"] = src.stride()
kwargs["storage_offset"] = src.storage_offset()
kwargs["device"] = src.device
kwargs["layout"] = src.layout
kwargs["requires_grad"] = src.requires_grad
kwargs["dtype"] = src.dtype
out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
return out
def __init__(self, src):
self.src = src
def __repr__(self):
return f"{self.__class__.__name__}"
def __tensor_flatten__(self):
return ["src"], None
@classmethod
def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride):
src = inner_tensors["src"]
return cls(src)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if kwargs is None:
kwargs = {}
def _fn(x):
return x.src * torch.ones(x.src.shape) if x.src.dtype == torch.float32 else x.src
_args = pytree.tree_map_only(cls, _fn, args)
_kwargs = pytree.tree_map_only(cls, _fn, kwargs)
_out = func(*_args, **_kwargs)
_out_flat, _out_spec = pytree.tree_flatten(_out)
out_flat = [cls(o) if isinstance(o, torch.Tensor) else o for o in _out_flat]
return pytree.tree_unflatten(out_flat, _out_spec)

File diff suppressed because it is too large

View File

@ -0,0 +1,581 @@
# mypy: ignore-errors
import torch
from torch import Tensor
import itertools
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch.utils import _pytree as pytree
from functools import partial
from torch.utils._mode_utils import no_dispatch, all_same_mode
import torch.autograd.forward_ad as fwAD
from typing import Callable
import re
def check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor):
elem = wrapper_tensor.elem
metadata_wrapper_tensor = metadata_accessor(wrapper_tensor)
metadata_elem = metadata_accessor(elem)
if metadata_wrapper_tensor == metadata_elem:
return
raise RuntimeError(
f"This operator is not Composite Compliant: the "
f"{metadata_name} of the tensor was modified directly without "
f"going through the PyTorch dispatcher.")
def check_metadata_consistency(wrapper_tensor, CCT):
# CCT: CompositeCompliantTensor class which is generated using generate_cct
if not isinstance(wrapper_tensor, CCT):
return
things_to_check = {
'shape': Tensor.size,
'dtype': lambda x: x.dtype,
'device': lambda x: x.device,
'numel': Tensor.numel,
'stride': Tensor.stride,
'storage_offset': Tensor.storage_offset,
}
for metadata_name, metadata_accessor in things_to_check.items():
check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor)
def is_view_fn(func):
return func.overloadpacket.__name__ in {
'as_strided',
'detach',
'diagonal',
'expand',
'expand_as',
'movedim',
'narrow',
'permute',
'select',
'squeeze',
'transpose',
't',
'real',
'imag',
'view_as_real',
'view_as_complex',
'unflatten',
'unfold',
'unsqueeze',
'view',
'view_as',
'unbind',
'split',
'split_with_sizes',
'vsplit',
'hsplit',
'tensor_split',
'chunk',
'swapaxes',
'slice',
'_reshape_alias',
'_unsafe_view',
'_conj',
'alias',
}
# manually populated from native_functions that have inplace_view: True.
# In the future we will probably be able to grab that list directly
def is_inplace_view_fn(func):
return func.overloadpacket.__name__ in {
'as_strided_',
'detach_',
'squeeze_',
'swapaxes_',
'swapdims_',
't_',
'transpose_',
'unsqueeze_',
}
# Introspection please save us
def is_inplace(func):
name = func.overloadpacket.__name__
if re.match('__i.+__', name):
return True
if re.match('__.+__', name):
return False
return name[-1] == '_'
def generate_cct_and_mode(autograd_view_consistency=True):
# This function returns a new class CompositeCompliantTensor
    # The argument controls the behaviour described below.
# autograd_view_consistency:
# If True, alias result using `set_` if func returns a view
# (See Note [Alias Result]).
# Since Forward AD doesn't work with `set_`
# we disable it by setting alias to False.
class CompositeCompliantTensor(torch.Tensor):
elem: torch.Tensor
__slots__ = ['elem']
@staticmethod
def __new__(cls, elem, mode, *args, **kwargs):
assert type(elem) is not cls, \
"Wrapping a CompositeCompliantTensor in a CompositeCompliantTensor is not supported"
# The storage of CompositeCompliantTensor should never be used directly
# by a Composite operation; if the Composite
# operator attempts to read from the storage without dispatching then it'll
# raise a RuntimeError due to it being a meta storage.
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls, elem.size(),
dtype=elem.dtype, layout=elem.layout,
device=elem.device, requires_grad=elem.requires_grad,
strides=elem.stride(), storage_offset=elem.storage_offset())
if elem.requires_grad:
# CompositeCompliantTensor steals the "requires_grad"-ness.
# Why a new copy of `elem`? Because sometimes OpInfo shares inputs between tests...
tmp = torch.empty_strided(elem.shape, elem.stride(), dtype=elem.dtype,
device=elem.device, layout=elem.layout,
requires_grad=False)
tmp.copy_(elem.detach())
r.elem = tmp
else:
r.elem = elem
assert r.stride() == r.elem.stride()
# Propagate conjugate bits to the wrapper tensor
# Ref: https://github.com/albanD/subclass_zoo/issues/24
# Ref: https://github.com/albanD/subclass_zoo/issues/21
torch._C._set_conj(r, r.elem.is_conj())
torch._C._set_neg(r, r.elem.is_neg())
r.mode = mode
return r
def __repr__(self):
return f"CompositeCompliantTensor({self.elem})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
all_args = pytree.arg_tree_leaves(*args, **(kwargs or {}))
modes = tuple(e.mode for e in all_args if isinstance(e, CompositeCompliantTensor))
if not all_same_mode(modes):
raise RuntimeError("Multiple CompositeCompliantTensorModes NYI")
with modes[0]:
return func(*args, **kwargs)
class CompositeCompliantTensorMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
def unwrap(e):
return e.elem if isinstance(e, CompositeCompliantTensor) else e
def wrap(e):
return CompositeCompliantTensor(e, self) if isinstance(e, torch.Tensor) else e
if func == torch.ops.aten._local_scalar_dense.default:
raise RuntimeError(
".item() is not allowed to be called inside of composite "
"functions in the PyTorch library because not all backends "
"and/or Tensor subclasses (e.g. vmap, ProxyTensor) support them.")
if func.overloadpacket.__name__ in ('set_', 'resize_'):
raise RuntimeError(
f"{func.__name__} is not allowed to be called inside of "
f"Composite operators.")
if is_inplace(func):
# NB: We are making an assumption that if the function is in-place,
# then the first argument is being written to. Introspection please save us!
mutated_argument = args[0]
if not isinstance(mutated_argument, CompositeCompliantTensor) and \
any(isinstance(a, CompositeCompliantTensor) for a in args[1:]):
raise RuntimeError(
'Not composite compliant: performing in-place operation '
f'{func.__name__} where the Tensor being written to is '
'regular Tensor but the other tensors are Tensor Subclasses. '
'Please try to avoid this in-place operation.')
unwrapped_args = tree_map(unwrap, args)
unwrapped_kwargs = tree_map(unwrap, kwargs)
unwrapped_rs = func(*unwrapped_args, **unwrapped_kwargs)
rs = tree_map(wrap, unwrapped_rs)
if is_view_fn(func) and autograd_view_consistency:
# Note [Alias Result]
# Autograd asserts that for B = A.view_fn(...), B and A's storages
# are the same. Here we try to make B alias A to avoid those asserts.
# See https://github.com/pytorch/pytorch/issues/65339 for more information
# about the issue.
with no_dispatch():
# Idea: this is a weird way of getting a storage that aliases the input.
# This is a workaround for #65339.
# 1. under no_dispatch, all of the wrapper tensors look like regular
# tensors with special storage (the storage is nullptr and
                    # advertises a CPU/CUDA device).
# 2. we run func, which ends up running the view operation
# 3. All view operations reuse the input's storage and return
# result Tensor(s) with new sizes/strides/offset that alias
# the input.
# 4. we set the storage (and sizes/strides/offset) of the wrapper
# tensor results to be that of the tensors that alias the input
result = func(*args, **kwargs)
if isinstance(result, (tuple, list)):
for a, b in zip(rs, result):
a.set_(b)
else:
rs.set_(result)
# Some operations are allowed to in-place modify the metadata of the
# inputs. The only ones are the "inplace view functions"; when we
# run into these, we manually modify the metadata of the input.
with no_dispatch():
if is_inplace_view_fn(func):
func(*args, **kwargs)
# For each CompositeCompliantTensor t, we check that t and t.elem
# have consistent metadata. If they don't have consistent metadata,
# that means the operator did something fishy.
check = partial(check_metadata_consistency, CCT=CompositeCompliantTensor)
pytree.tree_map_(check, args)
pytree.tree_map_(check, kwargs)
pytree.tree_map_(check, rs)
return rs
return CompositeCompliantTensor, CompositeCompliantTensorMode()
def is_tensorlist(lst):
if not isinstance(lst, list) and not isinstance(lst, tuple):
return False
if len(lst) == 0:
return False
all_tensors = all(isinstance(elt, torch.Tensor) for elt in lst)
if all_tensors:
return True
    exists_one_tensor = any(isinstance(elt, torch.Tensor) for elt in lst)
if exists_one_tensor:
raise RuntimeError('This test assumes that PyTorch APIs cannot take '
'mixed lists of Tensor and other things')
return False
def maybe_map(fn, should_map, arg):
return fn(arg) if should_map else arg
def wrap(arg, CCT, cct_mode):
# CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
if isinstance(arg, torch.Tensor):
return CCT(arg, cct_mode)
if is_tensorlist(arg):
return [CCT(a, cct_mode) for a in arg]
raise RuntimeError("wrap assumes that the input can be wrapped")
# Given a list of flat arguments, some of which may be Tensors, return all
# possible ways some of the arguments could be CompositeCompliantTensors (CCT).
# For example, given Tensors A, B, C and flat_args = [A, 1, B],
# We would return the following 4 options:
# [CCT(A), 1, CCT(B)]
# [CCT(A), 1, B]
# [A, 1, CCT(B)]
# [A, 1, B]
# NB: Yes, this is exponential. No, we don't care too much because PyTorch ops
# don't accept that many input Tensors.
def generate_subclass_choices(flat_args, CCT, cct_mode):
# CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
is_tensor_likes = [isinstance(arg, torch.Tensor) or is_tensorlist(arg) for arg in flat_args]
subclass_options = [[False, True] if is_tensor_like else [False] for is_tensor_like in is_tensor_likes]
for which_args_are_wrapped in itertools.product(*subclass_options):
result = [maybe_map(partial(wrap, CCT=CCT, cct_mode=cct_mode), should_wrap_arg, arg)
for should_wrap_arg, arg in zip(which_args_are_wrapped, flat_args)]
yield result, which_args_are_wrapped
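# A minimal usage sketch of generate_subclass_choices (the toy arguments below are
# illustrative and not tied to any OpInfo): every tensor-like position is toggled
# between "plain" and "wrapped in CCT", while non-tensor arguments are left alone.
def _example_generate_subclass_choices():
    CCT, cct_mode = generate_cct_and_mode()
    flat_args = [torch.randn(3), 1, torch.randn(3)]
    seen = []
    for wrapped_args, which_args_are_wrapped in generate_subclass_choices(flat_args, CCT, cct_mode):
        # which_args_are_wrapped is a tuple of booleans, e.g. (True, False, False);
        # the integer in the middle is never wrapped.
        seen.append(which_args_are_wrapped)
    assert len(seen) == 4  # two tensor-like args -> 2 ** 2 combinations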
# For an operation f(*args, **kwargs), each Tensor argument may either be
# a regular Tensor or a Tensor Subclass. This iterator iterates through
# all of those options.
def generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
# CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
flat_kwargs, spec = tree_flatten(kwargs)
flat_args_kwargs = list(args) + list(flat_kwargs)
for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT, cct_mode):
new_args = choice[:len(args)]
new_kwargs = tree_unflatten(choice[len(args):], spec)
which_args_are_wrapped = debug_metadata[:len(args)]
which_kwargs_are_wrapped = tree_unflatten(debug_metadata[len(args):], spec)
yield new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped
def raise_composite_compliance_error(err, additional_info=''):
raise RuntimeError(
"Composite compliance check failed with "
"the above error.\n"
f"{additional_info}"
"If you are adding an OpInfo of an "
"existing operator, please feel free to skip this test "
"because the problem was pre-existing and file an issue. "
"Otherwise, if you added a new operator, please read "
"through the Composite Compliance section in "
"aten/src/ATen/native/README.md for how to resolve this. "
) from err
# This test checks ALL possible permutations of calling `op` with arguments
# that are individually either a regular Tensor or a Tensor subclass.
#
# The general strategy is to wrap some Tensor args and kwargs in
# CompositeCompliantTensor wrappers and call the operation.
# If some composite operation does any non-compliant behavior,
# CompositeCompliantTensor will raise an error.
def check_all_permutations(op, args, kwargs, assert_equal_fn):
CCT, cct_mode = generate_cct_and_mode()
expected = op(*args, **kwargs)
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
try:
actual = op(*new_args, **new_kwargs)
# NOTE: [What errors are Composite Compliance trying to catch?]
#
# There are two things we want to catch:
# - errors that would raise within the torch_dispatch impl
# - data_ptr accesses
# The first is easy to filter for (we could make the error a different
# error class), the second is always going to be a RuntimeError due to
# how it is implemented (if you try to access the data_ptr of the
# wrapper Tensor, it raises an internal RuntimeError).
#
# So the most general thing to catch here is RuntimeError. If you
# are here and debugging why your test failed, it's plausible that
# the operator itself is broken and that there are other tests failing.
except RuntimeError as err:
raise_composite_compliance_error(
err,
f"- wrapped_args: {which_args_are_wrapped}\n"
f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
)
def unwrap(e):
return e.elem if isinstance(e, CCT) else e
assert_equal_fn(tree_map(unwrap, actual), expected)
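# A short, self-contained sketch of driving check_all_permutations directly (the toy
# op below is illustrative, not an OpInfo entry): the op is re-run once per wrapping
# combination and every unwrapped result is compared against the plain-Tensor run.
def _example_check_all_permutations():
    def toy_op(x, y):
        return x * y + x

    a, b = torch.randn(3), torch.randn(3)
    check_all_permutations(toy_op, (a, b), {}, torch.testing.assert_close)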
# Checks, via the use of a torch dispatch mode, for certain anti-patterns that
# are not composite compliant.
#
# In particular, the anti-pattern we are trying to prevent is a user
# creating an empty tensor and then resize_-ing it. Torch Dispatch Mode helps
# here because all factory functions will create tensors that are
# CompositeCompliantTensor.
#
# The general strategy is to wrap all Tensor args and kwargs in
# CompositeCompliantTensor wrappers. If an operator that is
# Composite does any non-compliant behavior,
# CompositeCompliantTensor will raise an error.
def check_with_mode(op, args, kwargs, assert_equal_fn):
CCT, cct_mode = generate_cct_and_mode()
def wrap(e):
return CCT(e, cct_mode) if isinstance(e, torch.Tensor) else e
expected = op(*args, **kwargs)
args = tree_map(wrap, args)
kwargs = tree_map(wrap, kwargs)
try:
with cct_mode:
actual = op(*args, **kwargs)
# see NOTE: [What errors are Composite Compliance trying to catch?]
except RuntimeError as err:
raise_composite_compliance_error(err)
def unwrap(e):
return e.elem if isinstance(e, CCT) else e
assert_equal_fn(tree_map(unwrap, actual), expected)
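# A sketch of the anti-pattern check_with_mode is meant to flag (the composite below
# is hypothetical): under cct_mode the torch.empty factory call returns a
# CompositeCompliantTensor, so the subsequent resize_ on it is reported as
# non-compliant behavior.
def _example_non_compliant_composite(x):
    out = torch.empty(0, device=x.device, dtype=x.dtype)
    out.resize_(x.shape)  # not composite compliant: resizes a tensor it just created
    out.copy_(x.sin())
    return out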
def gather_leaf_tensors(args, kwargs):
leaf_tensors = []
args, args_spec = tree_flatten(args)
kwargs, kwargs_spec = tree_flatten(kwargs)
args = args + kwargs
for arg in args:
if not isinstance(arg, torch.Tensor):
continue
if arg.requires_grad:
leaf_tensors.append(arg)
return leaf_tensors
def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradcheck_wrapper=None):
if gradcheck_wrapper is None:
results = op(*args, **kwargs)
else:
results = gradcheck_wrapper(op, *args, **kwargs)
if output_process_fn_grad is not None:
results = output_process_fn_grad(results)
flat_results = pytree.tree_leaves(results)
flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
flat_diff_results = [r for r in flat_results if r.requires_grad]
assert len(flat_diff_results) > 0
grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype) for r in flat_diff_results]
leaf_tensors = gather_leaf_tensors(args, kwargs)
assert len(leaf_tensors) > 0
return torch.autograd.grad(flat_diff_results, leaf_tensors,
grads, allow_unused=True, retain_graph=True)
# Checks if the backward formula is composite compliant by testing
# all possible permutations of {inputs, grad_outputs} being
# CompositeCompliantTensor or regular Tensors.
#
# NB: it is important that op is accepted as a Callable and not an OpInfo,
# this means we can apply check_backward_formula to things that aren't OpInfos
# while debugging.
def check_backward_formula(op: Callable, args, kwargs,
output_process_fn_grad=None,
gradcheck_wrapper=None, assert_equal_fn=None):
CCT, cct_mode = generate_cct_and_mode()
expected = compute_expected_grads(op, args, kwargs, output_process_fn_grad, gradcheck_wrapper)
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
leaf_tensors = gather_leaf_tensors(new_args, new_kwargs)
assert len(leaf_tensors) > 0
try:
if gradcheck_wrapper is None:
results = op(*new_args, **new_kwargs)
else:
results = gradcheck_wrapper(op, *new_args, **new_kwargs)
if output_process_fn_grad is not None:
results = output_process_fn_grad(results)
# see NOTE: [What errors are Composite Compliance trying to catch?]
except RuntimeError as err:
raise_composite_compliance_error(
err,
f"- wrapped_args: {which_args_are_wrapped}\n"
f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
)
flat_results = pytree.tree_leaves(results)
flat_results = [r for r in flat_results if isinstance(r, torch.Tensor)]
flat_diff_results = [r for r in flat_results if r.requires_grad]
assert len(flat_diff_results) > 0
# NB: ones, not ones_like, so we get a regular Tensor here
grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype)
for r in flat_diff_results]
for flat_new_grads, which_grad_is_batched in generate_subclass_choices(grads, CCT, cct_mode):
try:
actual = torch.autograd.grad(flat_diff_results, leaf_tensors, flat_new_grads,
allow_unused=True, retain_graph=True)
# see NOTE: [What errors are Composite Compliance trying to catch?]
except RuntimeError as err:
raise_composite_compliance_error(
err,
f"- wrapped_args: {which_args_are_wrapped}\n"
f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
f"- wrapped_grads: {which_grad_is_batched}\n"
)
def unwrap(e):
return e.elem if isinstance(e, CCT) else e
assert_equal_fn(tuple(map(unwrap, actual)), expected, equal_nan=True)
# Checks if the forward AD formula is composite compliant by testing
# all possible permutations of {primals, tangents} being
# CompositeCompliantTensor or regular Tensors.
#
# NB: it is important that op is accepted as a Callable and not an OpInfo,
# this means we can apply check_forward_ad_formula to things that aren't OpInfos
# while debugging.
def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None, assert_equal_fn=None):
CCT, cct_mode = generate_cct_and_mode(autograd_view_consistency=False)
def maybe_tangent(t):
assert type(t) is not CCT
# Generate a `tangent` tensor
# if the given object is a Tensor with requires_grad set.
if isinstance(t, torch.Tensor) and t.requires_grad:
return torch.randn_like(t)
elif is_tensorlist(t):
return [torch.randn_like(e) if e.requires_grad else None for e in t]
return None
tangent_args = tuple(maybe_tangent(arg) for arg in args)
flat_kwargs, spec = tree_flatten(kwargs)
flat_tangent_kwargs = tuple(maybe_tangent(arg) for arg in flat_kwargs)
tangent_kwargs = tree_unflatten(flat_tangent_kwargs, spec)
with fwAD.dual_level():
def maybe_make_dual(dual):
# Returns dual tensor if primal is a tensor/tensor subclass
# with requires_grad set.
primal, tangent = dual
if isinstance(primal, torch.Tensor) and primal.requires_grad:
return fwAD.make_dual(primal.detach(), tangent)
elif is_tensorlist(primal):
return tuple(fwAD.make_dual(pri.detach(), tang) if tang is not None else pri
for pri, tang in zip(primal, tangent))
return primal
def compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs):
op_args = tuple(map(maybe_make_dual, zip(args, tangent_args)))
op_kwargs = {k: maybe_make_dual((v, tangent_kwargs[k])) for k, v in kwargs.items()}
if gradcheck_wrapper is None:
return op(*op_args, **op_kwargs)
return gradcheck_wrapper(op, *op_args, **op_kwargs)
expected = compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs)
expected = tree_map(fwAD.unpack_dual, expected)
expected_primals = tree_map(lambda x: x.primal, expected)
expected_tangents = tree_map(lambda x: x.tangent, expected)
# Permutations of args and kwargs in CCT.
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
# Permutations of tangent args and tangent kwargs in CCT.
for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT, cct_mode):
new_tang_args, new_tang_kwargs, \
which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice
op_args = tuple(map(maybe_make_dual, zip(new_args, new_tang_args)))
op_kwargs = {k: maybe_make_dual((v, new_tang_kwargs[k])) for k, v in new_kwargs.items()}
try:
if gradcheck_wrapper is None:
actual = op(*op_args, **op_kwargs)
else:
actual = gradcheck_wrapper(op, *op_args, **op_kwargs)
# see NOTE: [What errors are Composite Compliance trying to catch?]
except RuntimeError as err:
raise_composite_compliance_error(
err,
f"- wrapped_args: {which_args_are_wrapped}\n"
f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
f"- wrapped_tangent_args: {which_tang_args_are_wrapped}\n"
f"- wrapped_tangent_kwargs: {which_tang_kwargs_are_wrapped}\n"
)
def unwrap(e):
return e.elem if isinstance(e, CCT) else e
actual = tree_map(fwAD.unpack_dual, actual)
actual_primals = tree_map(lambda x: unwrap(x.primal), actual)
actual_tangents = tree_map(lambda x: unwrap(x.tangent), actual)
assert_equal_fn(actual_primals, expected_primals, equal_nan=True)
assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True)

View File

@@ -0,0 +1,586 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import torch
import functools
from torch.testing import make_tensor
from torch.testing._internal.opinfo.core import (
OpInfo,
SampleInput,
)
from torch.testing._internal.common_dtype import all_types_and
import numpy as np
from torch.testing._internal.autograd_function_db import (
sample_inputs_numpy_cube,
sample_inputs_numpy_mul,
sample_inputs_numpy_mul_scalar,
sample_inputs_numpy_sort,
sample_inputs_numpy_take,
)
from torch import Tensor
from torch.types import Number
from typing import * # noqa: F403
# Note: [custom op db]
#
# This is a collection of custom operator test cases written as OpInfos
# so they can easily be consumed by OpInfo-based tests to check if subsystems
# support them correctly.
def to_numpy(tensor):
return tensor.cpu().numpy()
@torch.library.custom_op("_torch_testing::numpy_cube", mutates_args=())
def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
x_np = to_numpy(x)
dx = torch.tensor(3 * x_np ** 2, device=x.device)
return torch.tensor(x_np ** 3, device=x.device), dx
@numpy_cube.register_fake
def _(x):
return x.clone(), x.clone()
def numpy_cube_setup_context(ctx, inputs, output):
x, = inputs
cube, dx = output
ctx.save_for_backward(x, dx)
def numpy_cube_backward(ctx, grad_out, grad_dx):
x, dx = ctx.saved_tensors
grad_x = numpy_mul(grad_out, dx) + 6 * numpy_mul(grad_dx, x)
return grad_x
numpy_cube.register_autograd(numpy_cube_backward, setup_context=numpy_cube_setup_context)
def numpy_cube_vmap(info, in_dims, x):
result = numpy_cube(x)
return result, (in_dims[0], in_dims[0])
numpy_cube.register_vmap(numpy_cube_vmap)
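# A small usage sketch for the custom op above (illustrative only, not part of the
# OpInfo machinery): the registered op is callable like any torch function and
# returns both the cube and its pointwise derivative.
def _example_numpy_cube_usage():
    x = torch.randn(3)
    cube, dx = numpy_cube(x)
    assert torch.allclose(cube, x ** 3)
    assert torch.allclose(dx, 3 * x ** 2)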
@torch.library.custom_op("_torch_testing::numpy_mul", mutates_args=())
def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
@numpy_mul.register_fake
def _(x, y):
assert x.device == y.device
return (x * y).contiguous()
def numpy_mul_setup_context(ctx, inputs, output):
ctx.save_for_backward(*inputs)
def numpy_mul_backward(ctx, grad_out):
x, y = ctx.saved_tensors
grad_x = grad_out * y if ctx.needs_input_grad[0] else None
grad_y = grad_out * x if ctx.needs_input_grad[1] else None
return grad_x, grad_y
numpy_mul.register_autograd(numpy_mul_backward, setup_context=numpy_mul_setup_context)
def numpy_mul_vmap(info, in_dims, x, y):
x_bdim, y_bdim = in_dims
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
result = x * y
result = result.movedim(-1, 0)
return result, 0
numpy_mul.register_vmap(numpy_mul_vmap)
@torch.library.custom_op("_torch_testing::numpy_mul_scalar", mutates_args=())
def numpy_mul_scalar(x: Tensor, *, scalar: float) -> Tensor:
return torch.tensor(to_numpy(x) * scalar, device=x.device)
@numpy_mul_scalar.register_fake
def _(x, *, scalar):
return (x * scalar).contiguous()
def numpy_mul_scalar_setup_context(ctx, inputs, keyword_only_inputs, output):
ctx.scalar = keyword_only_inputs["scalar"]
def numpy_mul_scalar_backward(ctx, grad_out):
grad_x = grad_out * ctx.scalar
return grad_x
numpy_mul_scalar.register_autograd(numpy_mul_scalar_backward, setup_context=numpy_mul_scalar_setup_context)
def numpy_mul_scalar_vmap(info, in_dims, x, *, scalar):
x_bdim, = in_dims
x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
result = x * scalar
result = result.movedim(-1, 0)
return result, 0
numpy_mul_scalar.register_vmap(numpy_mul_scalar_vmap)
@torch.library.custom_op("_torch_testing::numpy_sort", mutates_args=())
def numpy_sort(x: Tensor, dim: int) -> Tuple[Tensor, Tensor, Tensor]:
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
return (
torch.tensor(result, device=device),
torch.tensor(ind, device=device),
torch.tensor(ind_inv, device=device),
)
@numpy_sort.register_fake
def _(x, dim):
return torch.empty_like(x), torch.empty_like(x, dtype=torch.long), torch.empty_like(x, dtype=torch.long)
def numpy_sort_setup_context(ctx, inputs, output):
out, ind, ind_inv = output
ctx.dim = inputs[1]
ctx.save_for_backward(ind, ind_inv)
ctx.mark_non_differentiable(ind, ind_inv)
def numpy_sort_backward(ctx, grad_out, grad_ind, grad_ind_inv):
ind, ind_inv = ctx.saved_tensors
return numpy_take(grad_out, ind_inv, ind, ctx.dim), None
numpy_sort.register_autograd(numpy_sort_backward, setup_context=numpy_sort_setup_context)
def numpy_sort_vmap(info, in_dims, x, dim):
x_bdim, _ = in_dims
x = x.movedim(x_bdim, 0)
dim = dim if dim >= 0 else dim + x.dim() - 1
result = numpy_sort(x, dim + 1)
return result, (0, 0, 0)
numpy_sort.register_vmap(numpy_sort_vmap)
@torch.library.custom_op("_torch_testing::numpy_take", mutates_args=())
def numpy_take(x: Tensor, ind: Tensor, ind_inv: Tensor, dim: int) -> Tensor:
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@numpy_take.register_fake
def _(x, ind, ind_inv, dim):
assert x.device == ind.device
assert x.device == ind_inv.device
assert ind.dtype == torch.long
assert ind_inv.dtype == torch.long
return torch.empty_like(x)
def numpy_take_setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.dim = dim
ctx.save_for_backward(ind, ind_inv)
def numpy_take_backward(ctx, grad_out):
ind, ind_inv = ctx.saved_tensors
grad_x = numpy_take(grad_out, ind_inv, ind, ctx.dim)
return grad_x, None, None, None
numpy_take.register_autograd(numpy_take_backward, setup_context=numpy_take_setup_context)
def numpy_take_vmap(info, in_dims, x, ind, ind_inv, dim):
x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
# wrap dim
logical_dim = x.dim() if x_bdim is None else x_bdim - 1
dim = dim if dim >= 0 else dim + logical_dim
def expand_bdim(x, x_bdim):
if x_bdim is None:
return x.expand(info.batch_size, *x.shape)
return x.movedim(x_bdim, 0)
x = expand_bdim(x, x_bdim)
ind = expand_bdim(ind, ind_bdim)
ind_inv = expand_bdim(ind_inv, ind_inv_bdim)
return numpy_take(x, ind, ind_inv, dim + 1), 0
numpy_take.register_vmap(numpy_take_vmap)
@torch.library.custom_op("_torch_testing::numpy_nonzero", mutates_args=())
def numpy_nonzero(x: Tensor) -> Tensor:
x_np = to_numpy(x)
res = np.stack(np.nonzero(x_np), axis=1)
if res.shape[0] <= 1:
raise RuntimeError("not supported")
return torch.tensor(res, device=x.device)
@numpy_nonzero.register_fake
def _(x):
ctx = torch._custom_op.impl.get_ctx()
i0 = ctx.create_unbacked_symint()
shape = [i0, x.dim()]
result = x.new_empty(shape, dtype=torch.long)
return result
def sample_inputs_numpy_nonzero(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
shape = 10
result = make_arg(shape, low=0.9, high=2)
mask = make_tensor(shape, low=0, high=2, device=device, dtype=torch.long)
with torch.no_grad():
result *= mask
yield SampleInput(result, args=())
def numpy_nonzero_vmap(info, in_dims, x):
raise NotImplementedError("Operator is data-dependent and cannot be vmapped.")
numpy_nonzero.register_vmap(numpy_nonzero_vmap)
@torch.library.custom_op("_torch_testing::numpy_view_copy", mutates_args=())
def numpy_view_copy(x: Tensor, shape: Sequence[int]) -> Tensor:
return torch.tensor(np.copy(to_numpy(x).reshape(shape)), device=x.device)
@numpy_view_copy.register_fake
def _(x, shape) -> Tensor:
return x.clone().view(shape).clone()
def numpy_view_copy_setup_context(ctx, inputs, output) -> None:
ctx.x_shape = inputs[0].shape
def numpy_view_copy_backward(ctx, grad_out):
return torch.ops._torch_testing.numpy_view_copy(grad_out, ctx.x_shape), None
numpy_view_copy.register_autograd(numpy_view_copy_backward, setup_context=numpy_view_copy_setup_context)
def numpy_view_copy_vmap(info, in_dims, x, shape):
x_bdim, _ = in_dims
x = x.movedim(x_bdim, 0)
x_shape = x.shape[0]
batch_shape = (x_shape, *shape)
result = numpy_view_copy(x, batch_shape)
return result, 0
numpy_view_copy.register_vmap(numpy_view_copy_vmap)
def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
result = make_arg(2, 3, 4, low=0.9, high=2)
yield SampleInput(result, args=([2, 12],))
@torch.library.custom_op('_torch_testing::numpy_cat', mutates_args=())
def numpy_cat(xs: Sequence[Tensor], dim: int) -> Tensor:
assert len(xs) > 0
assert all(x.device == xs[0].device for x in xs)
assert all(x.dtype == xs[0].dtype for x in xs)
np_xs = [to_numpy(x) for x in xs]
np_out = np.concatenate(np_xs, axis=dim)
return torch.tensor(np_out, device=xs[0].device)
@numpy_cat.register_fake
def _(xs, dim):
assert len(xs) > 0
assert all(x.device == xs[0].device for x in xs)
assert all(x.dtype == xs[0].dtype for x in xs)
return torch.cat(xs, dim=dim)
def numpy_cat_setup_context(ctx, inputs, output):
xs, dim = inputs
ctx.dim_sizes = [x.shape[dim] for x in xs]
ctx.dim = dim
def numpy_cat_backward(ctx, grad_out):
dim_sizes = ctx.dim_sizes
dim = ctx.dim
splits = list(np.cumsum(dim_sizes)[:-1])
grad_xs = torch.ops._torch_testing.numpy_split_copy(grad_out, splits, dim)
return grad_xs, None
numpy_cat.register_autograd(numpy_cat_backward, setup_context=numpy_cat_setup_context)
def numpy_cat_vmap(info, in_dims, x, dim):
x_bdim, = in_dims
result = numpy_cat(x, dim)
return result, x_bdim
numpy_cat.register_vmap(numpy_cat_vmap)
def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
r0 = make_arg(2, 3, 4, low=0.9, high=2)
r1 = make_arg(4, 3, 4, low=0.9, high=2)
r2 = make_arg(5, 3, 4, low=0.9, high=2)
yield SampleInput([r0, r1, r2], args=(0,))
@torch.library.custom_op('_torch_testing::numpy_split_copy', mutates_args=())
def numpy_split_copy(x: Tensor, splits: Sequence[int], dim: int) -> List[Tensor]:
x_np = to_numpy(x)
arrs = np.split(x_np, splits, axis=dim)
return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs]
@numpy_split_copy.register_fake
def _(x, splits, dim):
return [xi.clone() for xi in torch.tensor_split(x, splits, dim)]
def numpy_split_copy_setup_context(ctx, inputs, output):
_, _, dim = inputs
ctx.dim = dim
def numpy_split_copy_backward(ctx, grad_out):
result = torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim)
return result, None, None
numpy_split_copy.register_autograd(numpy_split_copy_backward, setup_context=numpy_split_copy_setup_context)
def numpy_split_copy_vmap(info, in_dims, x, splits, dim):
x_bdim, _ , _ = in_dims
x = x.movedim(x_bdim, 0)
result = numpy_split_copy(x, splits, dim + 1)
return result, 0
numpy_split_copy.register_vmap(numpy_split_copy_vmap)
def sample_inputs_numpy_split_copy(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
x = make_arg(2, 9, low=0.9, high=2)
yield SampleInput(x, args=([1, 3, 6], 1))
@torch.library.custom_op('_torch_testing::numpy_split_copy_with_int', mutates_args=())
def numpy_split_copy_with_int(x: Tensor, splits: Sequence[int], dim: int) -> Tuple[List[Tensor], int]:
x_np = to_numpy(x)
arrs = np.split(x_np, splits, axis=dim)
return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs], len(splits)
@numpy_split_copy_with_int.register_fake
def _(x, splits, dim):
return [xi.clone() for xi in torch.tensor_split(x, splits, dim)], len(splits)
def numpy_split_copy_with_int_setup_context(ctx, inputs, output):
_, _, dim = inputs
ctx.dim = dim
def numpy_split_copy_with_int_backward(ctx, grad_out, _):
return torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim), None, None
numpy_split_copy_with_int.register_autograd(
numpy_split_copy_with_int_backward,
setup_context=numpy_split_copy_with_int_setup_context)
def numpy_split_copy_with_int_vmap(info, in_dims, x, splits, dim):
x_bdim, _ , _ = in_dims
x = x.movedim(x_bdim, 0)
result, len_split = numpy_split_copy_with_int(x, splits, dim + 1)
return (result, len_split), ([0 for _ in range(len(result))], None)
numpy_split_copy_with_int.register_vmap(numpy_split_copy_with_int_vmap)
@torch.library.custom_op("_torch_testing::numpy_nms", mutates_args=())
def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor:
# Adapted from Ross Girshick's fast-rcnn implementation at
# https://github.com/rbgirshick/fast-rcnn/blob/master/lib/utils/nms.py
assert boxes.device == scores.device
device = boxes.device
boxes = to_numpy(boxes)
scores = to_numpy(scores)
N = boxes.shape[0]
assert boxes.shape == (N, 4)
assert scores.shape == (N,)
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= iou_threshold)[0]
order = order[inds + 1]
result = torch.tensor(np.stack(keep), device=device)
# Needed for data-dependent condition :(
assert result.size(0) >= 2
return result
@numpy_nms.register_fake
def _(boxes, scores, iou_threshold):
assert boxes.device == scores.device
N = boxes.shape[0]
assert boxes.shape == (N, 4)
assert scores.shape == (N,)
ctx = torch._custom_op.impl.get_ctx()
i0 = ctx.create_unbacked_symint()
result = boxes.new_empty([i0], dtype=torch.int64)
return result
def numpy_nms_vmap(info, in_dims, boxes, scores, iou_threshold):
raise NotImplementedError("Operator is data-dependent and cannot be vmapped.")
numpy_nms.register_vmap(numpy_nms_vmap)
def sample_inputs_numpy_nms(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(make_tensor, device=device, dtype=dtype)
N = 64
xs = make_arg([N], low=0, high=28)
dx = make_arg([N], low=0, high=4)
ys = make_arg([N], low=0, high=28)
dy = make_arg([N], low=0, high=4)
boxes = torch.stack([xs, ys, xs + dx, ys + dy], dim=1).requires_grad_(requires_grad)
scores = make_arg([N], low=0, high=1, requires_grad=requires_grad)
iou_threshold = make_arg([], low=0, high=1).item()
yield SampleInput(boxes, args=(scores, iou_threshold))
custom_op_db = [
OpInfo(
'NumpyCubeCustomOp',
op=numpy_cube._opoverload,
sample_inputs_func=sample_inputs_numpy_cube,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyMulCustomOp',
op=numpy_mul._opoverload,
sample_inputs_func=sample_inputs_numpy_mul,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyMulScalarCustomOp',
op=numpy_mul_scalar._opoverload,
sample_inputs_func=sample_inputs_numpy_mul_scalar,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpySortCustomOp',
op=numpy_sort._opoverload,
sample_inputs_func=sample_inputs_numpy_sort,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyTakeCustomOp',
op=numpy_take._opoverload,
sample_inputs_func=sample_inputs_numpy_take,
dtypes=all_types_and(torch.bool, torch.half),
supports_out=False,
),
OpInfo(
'NumpyNonzeroCustomOp',
op=numpy_nonzero._opoverload,
sample_inputs_func=sample_inputs_numpy_nonzero,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=False,
supports_out=False,
),
OpInfo(
'NumpyNMSCustomOp',
op=torch.ops._torch_testing.numpy_nms,
sample_inputs_func=sample_inputs_numpy_nms,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=False,
supports_out=False,
),
OpInfo(
'NumpyViewCopyCustomOp',
op=torch.ops._torch_testing.numpy_view_copy,
sample_inputs_func=sample_inputs_numpy_view_copy,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=True,
supports_out=False,
),
OpInfo(
'NumpyCatCustomOp',
op=torch.ops._torch_testing.numpy_cat,
sample_inputs_func=sample_inputs_numpy_cat,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=True,
check_batched_grad=False,
check_batched_gradgrad=False,
supports_out=False,
),
OpInfo(
'NumpySplitCopyCustomOp',
op=torch.ops._torch_testing.numpy_split_copy,
sample_inputs_func=sample_inputs_numpy_split_copy,
dtypes=all_types_and(torch.bool, torch.half),
supports_autograd=True,
check_batched_grad=False,
check_batched_gradgrad=False,
supports_out=False,
),
OpInfo(
'NumpySplitCopyWithIntCustomOp',
op=torch.ops._torch_testing.numpy_split_copy_with_int,
sample_inputs_func=sample_inputs_numpy_split_copy,
dtypes=all_types_and(torch.bool, torch.half),
gradcheck_wrapper=lambda op, *args, **kwargs: op(*args, **kwargs)[0],
supports_autograd=True,
check_batched_grad=False,
check_batched_gradgrad=False,
supports_out=False,
),
]
# ==============================================================
# some mechanical test cases
# ==============================================================
lib = torch.library.Library("_torch_testing", "FRAGMENT") # noqa: TOR901
lib.define("source0(Tensor x) -> Tensor")
@torch.library.register_fake("_torch_testing::source0", lib=lib)
def _(x):
return x.clone()
lib.define("source1(Tensor x) -> Tensor")
def source1_fake(x):
return x.clone()
torch.library.register_fake("_torch_testing::source1", source1_fake, lib=lib)
lib.define("source2(Tensor x) -> Tensor")
@torch.library.register_fake("_torch_testing::source2", lib=lib)
def _(x):
return x.clone()
lib.define("source3(Tensor x) -> Tensor")
def source3_fake(x):
return x.clone()
torch.library.register_fake("_torch_testing::source3", source3_fake, lib=lib)
@torch.library.custom_op("_torch_testing::source4", mutates_args=())
def source4(x: Tensor) -> Tensor:
return x.clone()
@source4.register_fake
def _(x):
return x.clone()
@torch.library.custom_op("_torch_testing::source5", mutates_args=())
def source5(x: Tensor) -> Tensor:
return x.clone()
def source5_fake(x):
return x.clone()
source5.register_fake(source5_fake)

View File

@@ -0,0 +1,67 @@
# mypy: ignore-errors
import torch
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import return_and_correct_aliasing
# A simple tensor subclass that holds a tensor with custom metadata and a custom method
class ConstantExtraMetadataTensor(torch.Tensor):
@staticmethod
def __new__(cls, elem):
shape = elem.shape
kwargs = {}
kwargs["strides"] = elem.stride()
kwargs["storage_offset"] = elem.storage_offset()
kwargs["device"] = elem.device
kwargs["layout"] = elem.layout
kwargs["requires_grad"] = elem.requires_grad
kwargs["dtype"] = elem.dtype
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
def __init__(self, elem):
self.elem = elem
self.constant_attribute = 4
def __repr__(self):
inner_repr = repr(self.elem)
return f"CustomTensor({inner_repr})"
def __tensor_flatten__(self):
return ["elem"], self.constant_attribute
def add_constant(self, a):
self.constant_attribute += a
@staticmethod
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
assert meta is not None
elem = inner_tensors["elem"]
out = ConstantExtraMetadataTensor(elem)
out.constant_attribute = meta
return out
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if kwargs is None:
kwargs = {}
args_inner = pytree.tree_map_only(
ConstantExtraMetadataTensor, lambda x: x.elem, args
)
kwargs_inner = pytree.tree_map_only(
ConstantExtraMetadataTensor, lambda x: x.elem, kwargs
)
out_inner = func(*args_inner, **kwargs_inner)
out_inner_flat, spec = pytree.tree_flatten(out_inner)
# for aten ops that return non-tensors, just assume that
# our custom inner tensors return the same value
out_flat = [
ConstantExtraMetadataTensor(o_inner)
if isinstance(o_inner, torch.Tensor)
else o_inner
for o_inner in out_inner_flat
]
out = pytree.tree_unflatten(out_flat, spec)
return return_and_correct_aliasing(func, args, kwargs, out)
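# A brief usage sketch (illustrative only): the extra metadata travels with the
# subclass and survives a manual flatten/unflatten round trip.
def _example_constant_extra_metadata_tensor():
    t = ConstantExtraMetadataTensor(torch.randn(2, 3))
    t.add_constant(3)                     # constant_attribute: 4 -> 7
    attrs, meta = t.__tensor_flatten__()  # (["elem"], 7)
    t2 = ConstantExtraMetadataTensor.__tensor_unflatten__(
        {"elem": t.elem}, meta, t.shape, t.stride()
    )
    assert attrs == ["elem"] and t2.constant_attribute == 7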

View File

@@ -0,0 +1 @@
# mypy: ignore-errors

View File

@@ -0,0 +1,10 @@
# mypy: ignore-errors
import torch.nn as nn
class Net(nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(10, 20)

View File

@@ -0,0 +1,11 @@
# mypy: ignore-errors
import torch.nn as nn
class Net(nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(10, 20)
self.relu = nn.ReLU()

View File

@@ -0,0 +1,200 @@
# mypy: ignore-errors
import re
import sys
import time
from functools import partial, wraps
from typing import Tuple
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.distributed.rpc import _rref_context_get_debug_info
from torch.testing._internal.common_utils import FILE_SCHEMA, TEST_WITH_TSAN
if not dist.is_available():
print("c10d not available, skipping tests", file=sys.stderr)
sys.exit(0)
INIT_METHOD_TEMPLATE = FILE_SCHEMA + "{file_name}"
def dist_init(
old_test_method=None,
setup_rpc: bool = True,
clean_shutdown: bool = True,
faulty_messages=None,
messages_to_delay=None,
):
"""
We use this decorator for setting up and tearing down state since
MultiProcessTestCase runs each `test*` method in a separate process and
each process just runs the `test*` method without actually calling
'setUp' and 'tearDown' methods of unittest.
Note: pass the string representation of MessageTypes that should be used
with the faulty agent's send function. By default, all retriable messages
("RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT", "RREF_USER_DELETE",
"CLEANUP_AUTOGRAD_CONTEXT_REQ") will use the faulty send (this default is
set from faulty_rpc_agent_test_fixture.py).
"""
# If we use dist_init without arguments (ex: @dist_init), old_test_method is
# appropriately set and we return the wrapper appropriately. On the other
# hand if dist_init has arguments (ex: @dist_init(clean_shutdown=False)),
# old_test_method is None and we return a functools.partial which is the real
# decorator that is used and as a result we recursively call dist_init with
# old_test_method and the rest of the arguments appropriately set.
if old_test_method is None:
return partial(
dist_init,
setup_rpc=setup_rpc,
clean_shutdown=clean_shutdown,
faulty_messages=faulty_messages,
messages_to_delay=messages_to_delay,
)
@wraps(old_test_method)
def new_test_method(self, *arg, **kwargs):
# Setting _ignore_rref_leak to make sure OwnerRRefs are properly deleted
# in tests.
import torch.distributed.rpc.api as api
api._ignore_rref_leak = False
self.worker_id = self.rank
self.setup_fault_injection(faulty_messages, messages_to_delay)
rpc_backend_options = self.rpc_backend_options
if setup_rpc:
if TEST_WITH_TSAN:
# TSAN runs much slower.
rpc_backend_options.rpc_timeout = rpc.constants.DEFAULT_RPC_TIMEOUT_SEC * 5
rpc.constants.DEFAULT_SHUTDOWN_TIMEOUT = 60
rpc.init_rpc(
name="worker%d" % self.rank,
backend=self.rpc_backend,
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=rpc_backend_options,
)
return_value = old_test_method(self, *arg, **kwargs)
if setup_rpc:
rpc.shutdown(graceful=clean_shutdown)
return return_value
return new_test_method
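# A minimal usage sketch (the test class below is hypothetical): both decorator
# spellings work, because dist_init returns a functools.partial of itself when it
# is called with keyword arguments only.
#
#   class MyRpcTest(RpcAgentTestFixture):
#       @dist_init
#       def test_basic(self):
#           ...
#
#       @dist_init(clean_shutdown=False)
#       def test_skip_graceful_shutdown(self):
#           ...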
def noop() -> None:
pass
def wait_until_node_failure(rank: int, expected_error_regex: str = ".*") -> str:
"""
Loops until an RPC to the given rank fails. This is used to
indicate that the node has failed in unit tests.
Args:
rank (int): Rank of the node expected to fail
expected_error_regex (optional, str): Regex of exception message expected. Useful to ensure a specific failure
occurs, not just any.
"""
while True:
try:
rpc.rpc_sync(f"worker{rank}", noop, args=())
time.sleep(0.1)
except Exception as e:
if re.search(pattern=expected_error_regex, string=str(e)):
return str(e)
def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None:
"""
The RRef protocol holds forkIds of rrefs in a map until those forks are
confirmed by the owner. The message confirming the fork may arrive after
our tests check whether this map is empty, which leads to failures and
flaky tests. to_here also does not guarantee that we have finished
processing the owner's confirmation message for the RRef. This function
loops until the map is empty, which means the messages have been received
and processed. Call this function before asserting the map returned by
_get_debug_info is empty.
"""
start = time.time()
while True:
debug_info = _rref_context_get_debug_info()
num_pending_futures = int(debug_info["num_pending_futures"])
num_pending_users = int(debug_info["num_pending_users"])
if num_pending_futures == 0 and num_pending_users == 0:
break
time.sleep(0.1)
if time.time() - start > timeout:
raise ValueError(
f"Timed out waiting to flush pending futures and users, "
f"had {num_pending_futures} pending futures and {num_pending_users} pending users"
)
def get_num_owners_and_forks() -> Tuple[str, str]:
"""
Retrieves number of OwnerRRefs and forks on this node from
_rref_context_get_debug_info.
"""
rref_dbg_info = _rref_context_get_debug_info()
num_owners = rref_dbg_info["num_owner_rrefs"]
num_forks = rref_dbg_info["num_forks"]
return num_owners, num_forks
def wait_until_owners_and_forks_on_rank(
num_owners: int, num_forks: int, rank: int, timeout: int = 20
) -> None:
"""
Waits until timeout for num_forks and num_owners to exist on the rank. Used
to ensure proper deletion of RRefs in tests.
"""
start = time.time()
while True:
num_owners_on_rank, num_forks_on_rank = rpc.rpc_sync(
worker_name(rank), get_num_owners_and_forks, args=(), timeout=5
)
num_owners_on_rank = int(num_owners_on_rank)
num_forks_on_rank = int(num_forks_on_rank)
if num_owners_on_rank == num_owners and num_forks_on_rank == num_forks:
return
time.sleep(1)
if time.time() - start > timeout:
raise ValueError(
f"Timed out waiting {timeout} sec for {num_owners} owners and {num_forks} forks on rank,"
f" had {num_owners_on_rank} owners and {num_forks_on_rank} forks"
)
def initialize_pg(init_method, rank: int, world_size: int) -> None:
# This is for tests using `dist.barrier`.
if not dist.is_initialized():
dist.init_process_group(
backend="gloo",
init_method=init_method,
rank=rank,
world_size=world_size,
)
def worker_name(rank: int) -> str:
return f"worker{rank}"
def get_function_event(function_events, partial_event_name):
"""
Returns the first event that matches partial_event_name in the provided
function_events. These function_events should be the output of
torch.autograd.profiler.function_events().
Args:
function_events: function_events returned by the profiler.
partial_event_name (str): partial key that the event was profiled with.
"""
event = [event for event in function_events if partial_event_name in event.name][0] # noqa: RUF015
return event

View File

@@ -0,0 +1 @@
# mypy: allow-untyped-defs

View File

@@ -0,0 +1,98 @@
# mypy: allow-untyped-defs
import sys
from functools import wraps, partial
import torch
import torch.distributed as dist
from torch.distributed import rpc
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
TEST_SKIPS,
tp_transports,
)
TEST_GPU_NUM = 4
class ShardedTensorTestBase(MultiProcessTestCase):
@property
def world_size(self):
return TEST_GPU_NUM
def init_pg(self, backend="nccl"):
if backend not in ["nccl", "gloo", "mpi"]:
raise RuntimeError(f"Backend {backend} not supported!")
dist.init_process_group(
backend=backend,
world_size=self.world_size,
rank=self.rank,
init_method=f"file://{self.file_name}",
)
# set device for nccl pg for collectives
if backend == "nccl":
torch.cuda.set_device(self.rank)
def init_rpc(self):
rpc_backend_options = rpc.TensorPipeRpcBackendOptions(_transports=tp_transports())
rpc_backend_options.init_method = f"file://{self.file_name}"
for rank in range(self.world_size):
rpc_backend_options.set_device_map(
f"worker{rank}", {rank: self.rank, self.rank: rank}
)
rpc.init_rpc(
name="worker%d" % self.rank,
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=rpc_backend_options,
)
def init_comms(self, init_rpc=True, backend="nccl"):
if init_rpc:
self.init_rpc()
self.init_pg(backend=backend)
def destroy_comms(self, destroy_rpc=True):
# Wait for all ranks to reach here before starting shutdown.
dist.barrier()
if destroy_rpc:
rpc.shutdown()
dist.destroy_process_group()
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def assert_sharded_tensor_equal(self, st1, st2):
st1_local_shards = st1.local_shards()
st2_local_shards = st2.local_shards()
self.assertEqual(len(st1_local_shards), len(st2_local_shards))
for i, st1_local_shard in enumerate(st1_local_shards):
self.assertEqual(st1_local_shard.tensor, st2_local_shards[i].tensor)
self.assertEqual(st1_local_shard.metadata, st2_local_shards[i].metadata)
self.assertEqual(st1.metadata(), st2.metadata())
self.assertEqual(st1.sharding_spec(), st2.sharding_spec())
self.assertEqual(len(st1.remote_shards()), len(st2.remote_shards()))
# wrapper to initialize comms (processgroup + rpc)
def with_comms(func=None, init_rpc=True, backend="nccl"):
if func is None:
return partial(
with_comms,
init_rpc=init_rpc,
backend=backend,
)
@wraps(func)
def wrapper(self, *args, **kwargs):
if backend == "nccl" and torch.cuda.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
self.init_comms(init_rpc=init_rpc, backend=backend)
func(self, *args, **kwargs)
self.destroy_comms(destroy_rpc=init_rpc)
return wrapper
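# A minimal usage sketch (the test case below is hypothetical): the decorator skips
# NCCL runs when there are fewer GPUs than world_size, initializes the process
# group (and RPC by default) before the test body, and tears both down afterwards.
#
#   class MyShardedTensorTest(ShardedTensorTestBase):
#       @with_comms(init_rpc=False, backend="gloo")
#       def test_collectives_available(self):
#           ...  # dist collectives can be used here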

View File

@@ -0,0 +1,136 @@
# mypy: allow-untyped-defs
import builtins
import torch
from torch.distributed._shard.sharding_spec import (
ChunkShardingSpec,
EnumerableShardingSpec,
ShardMetadata,
)
from torch.distributed._shard.sharding_spec._internals import (
get_chunked_dim_size,
get_split_size,
)
def generate_chunk_sharding_specs_for_test(sharding_dim):
return [
ChunkShardingSpec(
dim=sharding_dim,
placements=[
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
"rank:3/cuda:3",
],
),
# Test different ordering. (Case 1)
ChunkShardingSpec(
dim=sharding_dim,
placements=[
"rank:2/cuda:2",
"rank:3/cuda:3",
"rank:0/cuda:0",
"rank:1/cuda:1",
],
),
# Test different ordering. (Case 2)
ChunkShardingSpec(
dim=sharding_dim,
placements=[
"rank:3/cuda:3",
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
],
),
]
def generate_enumerable_sharding_specs_for_test():
return [
EnumerableShardingSpec(
[
ShardMetadata(
shard_offsets=[0, 0],
shard_sizes=[5, 5],
placement="rank:0/cuda:0",
),
ShardMetadata(
shard_offsets=[5, 0],
shard_sizes=[5, 5],
placement="rank:1/cuda:1",
),
ShardMetadata(
shard_offsets=[0, 5],
shard_sizes=[5, 5],
placement="rank:2/cuda:2",
),
ShardMetadata(
shard_offsets=[5, 5],
shard_sizes=[5, 5],
placement="rank:3/cuda:3",
),
]
)
]
def generate_local_weight_sharding_params_for_test(
local_weight, sharded_dim, gpu_num, spec, rank
):
"""
Shard the local weight based on the given spec, so we can compare against
the one from sharded tensor.
Args:
local_weight: weight matrix to be sharded.
sharded_dim: The dimension which we shard on.
gpu_num: number of ranks.
spec: sharding spec.
rank: rank of the current CUDA process.
Returns:
start_pos: start position of sharded weight on the given rank.
chunk_size: chunk size of sharded weight on the given rank.
"""
sharding_dim_size = local_weight.size(sharded_dim)
split_size = get_split_size(sharding_dim_size, gpu_num)
current_offsets = 0
start_pos = current_offsets
for idx, placement in enumerate(spec.placements):
chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
if rank == placement.rank():
start_pos = current_offsets
break
current_offsets += chunk_size
return start_pos, chunk_size
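# A worked example of the helper above (illustrative numbers): for
# sharding_dim_size=10 split across gpu_num=4 ranks, get_split_size gives 3, so the
# per-rank chunk sizes are 3, 3, 3, 1. With placements ordered rank 0..3, rank 1
# gets start_pos=3 / chunk_size=3 and rank 3 gets start_pos=9 / chunk_size=1.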
def clone_module_parameter(module, param_name):
"""
Clone a parameter from a given existing module.
Args:
module (:class:`torch.nn.Module`): Module whose parameter needs to be cloned.
param_name (str): Name of the parameter of ``module`` that needs to be cloned.
Returns: cloned tensor as :class:`torch.nn.Parameter`.
"""
tensor = getattr(module, param_name)
return torch.nn.Parameter(tensor.detach().clone())
def gen_binary_op_func(python_op, inplace=False):
src_lines = ['def f(lhs, rhs):']
if "torch" in python_op:
src_lines.append(f' return {python_op}(lhs, rhs)\n')
elif inplace:
src_lines.append(f' lhs {python_op}= rhs\n return lhs\n')
else:
src_lines.append(f' return lhs {python_op} rhs\n')
code_str = '\n'.join(src_lines)
g = {'torch': torch}
builtins.exec(code_str, g)
return g["f"]

View File

@@ -0,0 +1,66 @@
# mypy: allow-untyped-defs
import copy
import random
import torch
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import (
ChunkShardingSpec,
)
PLACEMENTS = [
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
"rank:3/cuda:3",
]
DEFAULT_GPU_NUM = 4
def _chunk_sharding_specs_list_for_test(sharding_dims, seed=0):
spec_list = []
for i in range(len(sharding_dims)):
random.Random(seed + i).shuffle(PLACEMENTS)
spec_list.append(
ChunkShardingSpec(
dim=sharding_dims[i],
placements=copy.deepcopy(PLACEMENTS),
)
)
return spec_list
class MyShardedModel2(torch.nn.Module):
def __init__(
self,
spec=None,
group=None,
init_rrefs=True
) -> None:
super().__init__()
if spec is not None:
self.sharded_tensor2 = sharded_tensor.rand(
spec, 10, 20, process_group=group, init_rrefs=init_rrefs
)
else:
self.sharded_tensor2 = None
self.random_tensor2 = torch.nn.Parameter(torch.rand(2, 2))
class MyShardedModel1(torch.nn.Module):
def __init__(
self,
spec=None,
group=None,
init_rrefs=True
) -> None:
super().__init__()
if spec is not None:
self.sharded_tensor1 = sharded_tensor.rand(
spec, 10, 20, process_group=group, init_rrefs=init_rrefs
)
else:
self.sharded_tensor1 = None
self.random_tensor1 = torch.nn.Parameter(torch.rand(2, 2))
self.submodule = MyShardedModel2(spec, group, init_rrefs)

View File

@@ -0,0 +1,42 @@
# mypy: allow-untyped-defs
import torch
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
class SimpleMegatronLM(nn.Module):
def __init__(self, linear_size, rank=None, dtype=torch.float32):
super().__init__()
self.fc1 = nn.Linear(*linear_size[0], dtype=dtype)
self.gelu = nn.GELU()
self.fc2 = nn.Linear(*linear_size[1], dtype=dtype)
if rank is not None:
self.fc1.cuda(rank)
self.fc2.cuda(rank)
def forward(self, inp):
return self.fc2(self.gelu(self.fc1(inp)))
def get_weights(self):
if isinstance(self.fc1.weight, ShardedTensor):
weight1 = self.fc1.weight.local_tensor()
else:
weight1 = self.fc1.weight
if isinstance(self.fc2.weight, ShardedTensor):
weight2 = self.fc2.weight.local_tensor()
else:
weight2 = self.fc2.weight
return (weight1, weight2)
def get_biases(self):
return (self.fc1.bias, self.fc2.bias)
def get_weight_grads(self):
return (self.fc1.weight.grad, self.fc2.weight.grad)
def get_bias_grads(self):
return (self.fc1.bias.grad, self.fc2.bias.grad)

View File

@@ -0,0 +1,548 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import itertools
import sys
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, cast, Dict, Iterator, List, Sequence, Tuple, TypeVar
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
)
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
MultiThreadedTestCase,
skip_if_lt_x_gpu,
run_subtests,
TEST_SKIPS,
)
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
DEVICE_TYPE = (
"cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else "cpu"
)
NUM_DEVICES = 4
# We use this as a proxy for "multiple GPUs exist"
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
# when we actually have multiple GPUs, relax the requirement to smaller counts.
NUM_DEVICES = min(NUM_DEVICES, torch.cuda.device_count())
T = TypeVar("T")
# simple RMSNorm layer for testing
class RMSNormPython(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x)
return output * self.weight
class MLPModule(nn.Module):
def __init__(self, device, bias: bool = True):
super().__init__()
torch.manual_seed(5)
self.net1 = nn.Linear(10, 16, bias=bias, device=device)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 10, bias=bias, device=device)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def reset_parameters(self):
self.net1.reset_parameters()
self.net2.reset_parameters()
class MLPStacked(nn.Module):
def __init__(self, device, n_layers: int = 2):
super().__init__()
self.layers = nn.ModuleList([MLPModule(device) for i in range(n_layers)])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
@dataclass
class ModelArgs:
n_layers: int = 2
vocab_size: int = 8
max_seq_len: int = 16
dim: int = 16
n_heads: int = 4
dropout_p: float = 0.1
use_attn_mask: bool = True
weight_tying: bool = True
checkpoint_activations: bool = False
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert args.dim % args.n_heads == 0
self.head_dim = args.dim // args.n_heads
self.n_heads = args.n_heads
self.dropout_p = args.dropout_p
self.resid_dropout = nn.Dropout(args.dropout_p)
self.use_attn_mask = args.use_attn_mask
self.wq = nn.Linear(args.dim, args.dim, bias=False)
self.wk = nn.Linear(args.dim, args.dim, bias=False)
self.wv = nn.Linear(args.dim, args.dim, bias=False)
self.wo = nn.Linear(args.dim, args.dim, bias=False)
def forward(self, x):
bsz, seq_len, _ = x.size()
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim)
keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim)
values = values.view(bsz, seq_len, self.n_heads, self.head_dim)
queries = queries.transpose(1, 2) # (bsz, n_heads, seq_len, head_dim)
keys = keys.transpose(1, 2) # (bsz, n_heads, seq_len, head_dim)
values = values.transpose(1, 2) # (bsz, n_heads, seq_len, head_dim)
output = F.scaled_dot_product_attention(
queries,
keys,
values,
None,
self.dropout_p if self.training else 0,
self.use_attn_mask,
)
output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
return self.resid_dropout(self.wo(output))
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout_p):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim)
self.gelu = nn.GELU()
self.w2 = nn.Linear(hidden_dim, dim)
self.resid_dropout = nn.Dropout(dropout_p)
def forward(self, x):
return self.resid_dropout(self.w2(self.gelu(self.w1(x))))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.attention_norm = nn.LayerNorm(args.dim)
self.attention = Attention(args)
self.ffn_norm = nn.LayerNorm(args.dim)
self.feed_forward = FeedForward(
args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
)
def forward(self, x):
h = x + self.attention(self.attention_norm(x))
out = h + self.feed_forward(self.ffn_norm(h))
return out
# A toy transformer model, partly inspired by the nanoGPT model:
# https://github.com/karpathy/nanoGPT.
class Transformer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert args.vocab_size is not None
assert args.max_seq_len is not None
self.model_args = args
self.max_seq_len = args.max_seq_len
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
self.dropout = nn.Dropout(args.dropout_p)
self.layers = nn.ModuleList()
for _ in range(args.n_layers):
self.layers.append(TransformerBlock(args))
self.norm = nn.LayerNorm(args.dim)
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
if args.weight_tying:
self.output.weight = self.tok_embeddings.weight
self.checkpoint_activations = args.checkpoint_activations
def forward(self, tokens):
_bsz, seq_len = tokens.size()
assert seq_len <= self.max_seq_len
h = self.tok_embeddings(tokens)
pos = torch.arange(0, seq_len, device=tokens.device)
p = self.pos_embeddings(pos) # positional embeddings of shape (seq_len, dim)
h = h + p
h = self.dropout(h)
for layer in self.layers:
if self.checkpoint_activations:
h = torch.utils.checkpoint.checkpoint(layer, h, use_reentrant=False)
else:
h = layer(h)
h = self.norm(h)
output = self.output(h).float()
return output
@staticmethod
def parallelize(
module: "Transformer", device_mesh: DeviceMesh, use_seq_parallel: bool, local_output_for_attn: bool = False
) -> nn.Module:
assert isinstance(module, Transformer), f"Requires Transformer but got {module}"
# Parallelize the root submodules.
if use_seq_parallel:
root_plan = {
"tok_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1)),
"pos_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(0)),
"norm": SequenceParallel(),
}
else:
root_plan = {
"tok_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
"pos_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
}
module_tp = parallelize_module(module, device_mesh, root_plan)
# Parallelize the attention and feed forward submodules.
for layer in module_tp.layers:
layer_parallelize_plan = {}
if use_seq_parallel:
layer_parallelize_plan["attention"] = PrepareModuleInput(
input_layouts=Shard(1),
desired_input_layouts=Replicate(),
)
# shard the RMSNorms
layer_parallelize_plan["attention_norm"] = SequenceParallel()
layer_parallelize_plan["ffn_norm"] = SequenceParallel()
layer_parallelize_plan["attention.wq"] = ColwiseParallel(use_local_output=local_output_for_attn)
layer_parallelize_plan["attention.wk"] = ColwiseParallel(use_local_output=local_output_for_attn)
layer_parallelize_plan["attention.wv"] = ColwiseParallel(use_local_output=local_output_for_attn)
layer_parallelize_plan["attention.wo"] = (
RowwiseParallel(output_layouts=Shard(1))
if use_seq_parallel
else RowwiseParallel()
)
layer_parallelize_plan["feed_forward.w1"] = (
ColwiseParallel(input_layouts=Shard(1))
if use_seq_parallel
else ColwiseParallel()
)
layer_parallelize_plan["feed_forward.w2"] = (
RowwiseParallel(output_layouts=Shard(1))
if use_seq_parallel
else RowwiseParallel()
)
parallelize_module(layer, device_mesh, layer_parallelize_plan)
# Parallelize the output submodule. If weight tying is enabled, we need to
# make sure output.weight is sharded consistently as tok_embeddings.weight,
# at the cost of the all_reduce operation using RowwiseParallel.
output_parallelize_plan = (
ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Replicate(),
)
if use_seq_parallel
else ColwiseParallel(output_layouts=Replicate())
)
parallelize_module(module_tp.output, device_mesh, output_parallelize_plan)
if local_output_for_attn:
for layer in module_tp.layers:
layer.attention.n_heads = module_tp.model_args.n_heads // device_mesh.size()
# Manually set output.weight so that parameters and gradients are shared.
if module_tp.model_args.weight_tying:
module_tp.output.weight = module_tp.tok_embeddings.weight
return module_tp
def skip_unless_torch_gpu(method: T) -> T:
"""
Test decorator which skips the test unless there's a GPU available to torch.
>>> # xdoctest: +SKIP
>>> @skip_unless_torch_gpu
>>> def test_some_method(self) -> None:
>>> ...
"""
# The builtin @skip_if_no_gpu relies on os.environ['WORLD_SIZE'] being set.
return cast(T, skip_if_lt_x_gpu(NUM_DEVICES)(method))
class DTensorTestBase(MultiProcessTestCase):
@property
def world_size(self) -> int:
return NUM_DEVICES
@property
def backend(self) -> str:
backend = "nccl" if self.device_type == "cuda" else "gloo"
return backend
def build_device_mesh(self) -> DeviceMesh:
return DeviceMesh(self.device_type, list(range(self.world_size)))
def init_pg(self) -> None:
if "nccl" in self.backend and torch.cuda.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
if self.backend not in ["nccl", "gloo", "mpi", "cpu:gloo,cuda:nccl"]:
raise RuntimeError(f"Backend {self.backend} not supported!")
dist.init_process_group(
backend=self.backend,
world_size=self.world_size,
rank=self.rank, # pyre-ignore[16]
init_method=f"file://{self.file_name}", # pyre-ignore[16]
)
# set device for nccl pg for collectives
if "nccl" in self.backend:
torch.cuda.set_device(self.rank)
def destroy_pg(self) -> None:
# Wait for all ranks to reach here before starting shutdown.
# FIXME dist.barrier deadlocks with multiple threads and NCCL: https://github.com/pytorch/pytorch/issues/95895
# dist.all_reduce(torch.zeros((1,), device="cuda" if torch.cuda.is_available() else "cpu"))
# FIXME can't use the above all_reduce as it causes hangs on bionic and focal. It hangs:
# test_dtensor.py -- DTensorMeshTest.test_dtensor_device_mesh_device_conversion
dist.barrier()
dist.destroy_process_group()
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
# pyre-ignore[2]:
def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None:
out = op_call(*args, **kwargs)
dtc = DTensorConverter(mesh, args, kwargs)
for d_args, d_kwargs in dtc:
# pyre can't find assertTrue anymore?
self.assertEqual(dtc.successful(), True)
d_out = op_call(*d_args, **d_kwargs)
self.assertEqual(d_out.full_tensor(), out)
def run_subtests(self, *args, **kwargs):
return run_subtests(self, *args, **kwargs)
TestFunc = Callable[[object], object]
# wrapper to initialize comms (processgroup)
def with_comms(func: TestFunc) -> TestFunc:
assert func is not None
@wraps(func) # pyre-ignore[6]
def wrapper(
self, *args: Tuple[object], **kwargs: Dict[str, Any] # type: ignore[misc]
) -> None:
# if enough GPU we can use GPU, otherwise we fallback to CPU
if not torch.cuda.is_available() or torch.cuda.device_count() < self.world_size:
self.device_type = "cpu"
else:
self.device_type = DEVICE_TYPE
self.init_pg()
try:
func(self, *args, **kwargs) # type: ignore[misc]
except Exception as e:
dist.destroy_process_group()
raise e
self.destroy_pg()
return wrapper
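# Illustrative usage sketch (not part of the original file; the class and test
# names are hypothetical): a DTensor test typically subclasses DTensorTestBase
# and wraps each test with @with_comms so the process group is initialized
# before the test body runs and destroyed afterwards.
#
#   class MyDTensorTest(DTensorTestBase):
#       @with_comms
#       def test_shard_roundtrip(self):
#           mesh = self.build_device_mesh()
#           local = torch.randn(4, 4, device=self.device_type)
#           dtensor = distribute_tensor(local, mesh, [Shard(0)])
#           self.assertEqual(dtensor.full_tensor(), local)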
class DTensorOpTestBase(MultiThreadedTestCase):
@property
def world_size(self) -> int:
return NUM_DEVICES
@property
def device_type(self) -> str:
return DEVICE_TYPE
def build_device_mesh(self):
return DeviceMesh(self.device_type, list(range(self.world_size)))
def setUp(self) -> None:
super().setUp()
self._spawn_threads()
# This is a class for converting args/kwargs of an op into distributed args/kwargs
class DTensorConverter:
def __init__(
self,
mesh: DeviceMesh,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> None:
self.hit = 0
self.miss = 0
self.mesh = mesh
self.args = args
self.kwargs = kwargs
flatten_args, flatten_args_spec = tree_flatten(args)
flatten_kwargs, flatten_kwargs_spec = tree_flatten(kwargs)
self.flatten_args: List[object] = flatten_args
self.flatten_args_spec: TreeSpec = flatten_args_spec
self.flatten_kwargs: List[object] = flatten_kwargs
self.flatten_kwargs_spec: TreeSpec = flatten_kwargs_spec
choices_for_args = []
for arg in self.flatten_args:
if isinstance(arg, torch.Tensor):
choices_for_args.append(self.gen_sharding_choices_for_arg(arg))
for arg in self.flatten_kwargs:
if isinstance(arg, torch.Tensor):
choices_for_args.append(self.gen_sharding_choices_for_arg(arg))
self.sharding_combs: Iterator[Sequence[Placement]] = iter(
itertools.product(*choices_for_args)
)
def successful(self) -> bool:
return self.hit > 0 and self.miss == 0
def is_supported_tensor(self, t: torch.Tensor) -> bool:
# TODO: DTensor needs to support quantized and sparse tensors.
# Quantized tensors might be relatively easy, but sparse tensors
# have special layouts that we may need to handle; until that is
# clear, we don't officially support them.
return not any(
[
t.is_sparse_csr,
t.is_sparse,
t.is_mkldnn,
t.is_quantized,
t.is_nested,
torch._is_functional_tensor(t),
t.is_neg(),
t.is_conj(),
t.device.type in ("lazy", "meta"),
# We need a way to test if a tensor is batched but there
# is no official API to do it
# torch._C._is_batched(t),
]
)
def gen_sharding_choices_for_arg(self, arg: torch.Tensor) -> Sequence[Placement]:
mesh_size = self.mesh.size()
sharding_choices: List[Placement] = [Replicate()]
# c10d collectives do not support bool tensors,
# so bool tensors are treated as replicated
if arg.dtype != torch.bool:
# only generate choices of replicate, or of sharding
# evenly along a dimension that can be sharded
sharding_choices = sharding_choices + [
Shard(i)
for i, s in enumerate(arg.shape)
if s > 1 and s % mesh_size == 0
]
# TODO: add multi mesh choices
# all_choices = itertools.product(
# *(self.mesh.ndim * [sharding_choices])
# )
return sharding_choices
def __iter__(self) -> "DTensorConverter":
return self
def __next__(self) -> Tuple[Tuple[object, ...], Dict[str, object]]:
try:
next_sharding_choices = next(self.sharding_combs)
idx = 0
new_args: List[object] = []
for arg in self.flatten_args:
if isinstance(arg, torch.Tensor):
new_args.append(
self.to_dist_tensor(
arg, self.mesh, [next_sharding_choices[idx]]
)
)
idx += 1
else:
new_args.append(arg)
new_kwargs: List[object] = []
for arg in self.flatten_kwargs:
if isinstance(arg, torch.Tensor):
new_kwargs.append(
self.to_dist_tensor(
arg, self.mesh, [next_sharding_choices[idx]]
)
)
idx += 1
else:
new_kwargs.append(arg)
return (
tree_unflatten(new_args, self.flatten_args_spec),
tree_unflatten(new_kwargs, self.flatten_kwargs_spec),
)
except StopIteration as e:
raise StopIteration from e
def to_dist_tensor(
self, t: torch.Tensor, mesh: DeviceMesh, placements: List[Placement]
) -> torch.Tensor:
if type(t) is torch.Tensor or type(t) is nn.Parameter:
if self.is_supported_tensor(t):
self.hit += 1
if t.ndim == 0:
# scalar tensor by default will be replicated
r = distribute_tensor(t, mesh, [Replicate()] * mesh.ndim)
else:
# distribute non-scalar tensors
r = distribute_tensor(t, mesh, placements)
if type(t) is nn.Parameter:
r = nn.Parameter( # type: ignore[assignment]
r, requires_grad=r.requires_grad
)
return r
else:
self.miss += 1
return t
elif torch.overrides.is_tensor_like(t):
# Blindly converting tensor subclasses to DTensor can cause
# unpredictable problems, so we explicitly disable this conversion
# for now (i.e. we don't support a DTensor holding a tensor subclass
# until there's a strong reason to).
self.miss += 1
return t
else:
raise RuntimeError(f"Trying to convert to DTensor, but got {type(t)}")

View File

@ -0,0 +1,51 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import os
import shutil
import tempfile
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple
import torch.distributed as dist
def with_temp_dir(
func: Optional[Callable] = None,
) -> Optional[Callable]:
"""
Wrapper to initialize temp directory for distributed checkpoint.
"""
assert func is not None
@wraps(func)
def wrapper(self, *args: Tuple[object], **kwargs: Dict[str, Any]) -> None:
if dist.is_initialized():
# Only create temp_dir when rank is 0
if dist.get_rank() == 0:
temp_dir = tempfile.mkdtemp()
print(f"Using temp directory: {temp_dir}")
else:
temp_dir = ""
object_list = [temp_dir]
# Broadcast temp_dir to all the other ranks
os.sync()
dist.broadcast_object_list(object_list)
self.temp_dir = object_list[0]
os.sync()
else:
temp_dir = tempfile.mkdtemp()
print(f"No process group initialized, using temp directory: {temp_dir}")
self.temp_dir = temp_dir
try:
func(self, *args, **kwargs)
finally:
shutil.rmtree(self.temp_dir, ignore_errors=True)
return wrapper
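# Illustrative usage sketch (class and method names are hypothetical): the
# decorator is meant to wrap a test method on a distributed test case; the body
# can then read self.temp_dir, and every rank sees the same broadcasted path.
# A process-group setup decorator (e.g. with_comms) should be applied outermost
# so the group is initialized before the broadcast.
#
#   class MyCheckpointTest(DTensorTestBase):
#       @with_comms
#       @with_temp_dir
#       def test_save_and_load(self):
#           checkpoint_dir = self.temp_dir  # same path on all ranks
#           ...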

View File

@ -0,0 +1,122 @@
# mypy: allow-untyped-defs
# Owner(s): ["oncall: distributed"]
import copy
from itertools import chain
from typing import Any, Dict
import torch
import torch.nn as nn
from torch.distributed._sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import _gather_state_dict
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
_PG,
_STATE,
set_state_dict,
StateDictOptions,
)
class VerifyStateDictMixin:
def _compare_tensor(self, orig_tensor, dist_tensor, offload_to_cpu=False):
if isinstance(dist_tensor, (DTensor, ShardedTensor)):
dist_tensor = _gather_state_dict({"mykey": dist_tensor}).pop("mykey")
if offload_to_cpu:
orig_tensor = orig_tensor.cpu()
dist_tensor = dist_tensor.cpu()
self.assertTrue(isinstance(dist_tensor, torch.Tensor))
self.assertTrue(torch.allclose(orig_tensor, dist_tensor))
def _verify_msd(
self,
msd: Dict[str, Any],
dist_msd: Dict[str, Any],
options: StateDictOptions = StateDictOptions(),
offload_to_cpu=False,
) -> None:
if not options.ignore_frozen_params:
self.assertEqual(len(msd), len(dist_msd))
for fqn, param in msd.items():
dist_param = dist_msd.get(fqn, None)
if not options.ignore_frozen_params:
self.assertIsNotNone(dist_param, f"{fqn=}")
try:
self._compare_tensor(param, dist_param, offload_to_cpu)
except AssertionError as e:
raise AssertionError(
f"{fqn} has mismatched value {param} {dist_param}"
) from e
elif dist_param is None:
self.assertFalse(param.requires_grad, f"{fqn=}")
def _verify_osd(
self,
model: nn.Module,
optim: torch.optim.Optimizer,
osd: Dict[str, Any],
dist_osd: Dict[str, Any],
) -> None:
params = list(chain.from_iterable(g["params"] for g in optim.param_groups))
param_pid_mapping = dict(zip(params, range(len(params))))
fqn_pid_mapping = {}
for fqn, param in model.named_parameters():
pid = param_pid_mapping[param]
fqn_pid_mapping[fqn] = pid
fqn_pid_mapping[pid] = fqn
# Check optimizer_state_dict state
self.assertEqual(len(osd[_STATE]), len(dist_osd[_STATE]))
for pid, states in osd[_STATE].items():
fqn = fqn_pid_mapping[pid]
dist_states = dist_osd[_STATE].get(fqn, None)
self.assertIsNotNone(dist_states, fqn)
self.assertEqual(len(states), len(dist_states))
for key, state in states.items():
dist_state = dist_states.get(key, None)
self.assertIsNotNone(dist_state)
self._compare_tensor(state, dist_state)
# Check optimizer_state_dict param_group
old_dist_osd_pg = dist_osd[_PG]
if len(osd[_PG]) != len(dist_osd[_PG]):
self.assertTrue(len(dist_osd[_PG]) > len(osd[_PG]))
new_pg = copy.deepcopy(dist_osd[_PG][0])
new_pg["params"] = []
for dist_group in dist_osd[_PG]:
new_pg["params"].extend(dist_group["params"])
dist_osd[_PG] = [new_pg]
self.assertEqual(len(osd[_PG]), len(dist_osd[_PG]))
for group, dist_group in zip(osd[_PG], dist_osd[_PG]):
self.assertEqual(len(group), len(dist_group))
for key, value in group.items():
# Below doesn't work because param_groups can have None
# values.
# dist_value = dist_group.get(key, None)
# self.assertIsNotNone(dist_value, (dist_group, group))
dist_value = dist_group[key]
if key == "params":
fqns = [fqn_pid_mapping[pid] for pid in value]
self.assertEqual(sorted(fqns), sorted(dist_value))
else:
self.assertEqual(value, dist_value)
dist_osd[_PG] = old_dist_osd_pg
def _verify_osd_by_load(
self,
model: nn.Module,
optim: torch.optim.Optimizer,
new_optim: torch.optim.Optimizer,
dist_osd: Dict[str, Any],
) -> None:
new_dist_osd = _gather_state_dict(dist_osd)
set_state_dict(
model,
optimizers=new_optim,
model_state_dict={},
optim_state_dict=new_dist_osd,
)
self.assertEqual(optim.state_dict(), new_optim.state_dict())

View File

@ -0,0 +1,733 @@
# mypy: allow-untyped-defs
import contextlib
import enum
import logging
import os
import threading
from typing import NamedTuple
import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.nn as nn
from torch.distributed import rpc
from torch.distributed.nn import RemoteModule
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
requires_gloo,
requires_nccl,
skip_if_lt_x_gpu,
skip_if_rocm_multiprocess,
)
from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
RpcAgentTestFixture,
)
NUM_EM_ROW = 2
D_SPARSE = 3
D_DENSE = 2
D_HID = 3
D_OUT = 1
NUM_TRAINERS = 4
# Trainers + the master + the remote worker
WORLD_SIZE = NUM_TRAINERS + 2
TRAINER_RANKS = list(range(NUM_TRAINERS))
REMOTE_WORKER_RANK = TRAINER_RANKS[-1] + 1
MASTER_RANK = REMOTE_WORKER_RANK + 1
class DdpMode(enum.Enum):
# Don't apply DDP
NONE = enum.auto()
# Apply DDP to the top level nn.Module
OUTSIDE = enum.auto()
# Embed DDP inside the top level nn.Module
INSIDE = enum.auto()
def init_logger():
logger = logging.getLogger(__name__)
level = logging.DEBUG if "debug" in os.environ else logging.INFO
logger.setLevel(level)
console = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
)
console.setFormatter(formatter)
console.setLevel(level)
# add the handlers to the logger
logger.addHandler(console)
logger.propagate = False
return logger
gLogger = init_logger()
class FeatureSet(NamedTuple):
""" A feature set has 2 types of features"""
dense_features: torch.Tensor
sparse_features: torch.LongTensor
values: torch.Tensor
def _call_method(method, rref, *args, **kwargs):
return method(rref.local_value(), *args, **kwargs)
def _remote_method(method, rref, *args, **kwargs):
args_tup = tuple([method, rref] + list(args))
return rpc.rpc_sync(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)
def _remote_method_async(method, rref, *args, **kwargs):
args_tup = tuple([method, rref] + list(args))
return rpc.rpc_async(rref.owner(), _call_method, args=args_tup, kwargs=kwargs)
class RemoteEM(nn.Module):
def __init__(self, num_embeddings: int, embedding_dim: int):
gLogger.info("Initing RemoteEM with %s %s", num_embeddings, embedding_dim)
super().__init__()
init_em = [0.5] * embedding_dim
self.em = nn.EmbeddingBag(
num_embeddings,
embedding_dim,
_weight=torch.tensor([init_em] * num_embeddings),
)
def forward(self, input: torch.Tensor):
gLogger.debug("Running RemoteEM.forward() on: %s", input)
return self.em(input, offsets=torch.LongTensor(range(input.shape[0])))
# Return a linear module with predefined parameters.
def getLinear(d_in, d_out):
l = nn.Linear(d_in, d_out, bias=False)
w = torch.ones((d_out, d_in))
w[0][0] = -1
w.requires_grad_()
l.weight.data = w
return l
class RemoteNet(nn.Module):
def __init__(self, d_in: int, d_out: int):
gLogger.info("Initing RemoteNet with %s %s", d_in, d_out)
super().__init__()
self.fc = getLinear(d_in, d_out)
self.relu = nn.ReLU()
def forward(self, input: torch.Tensor):
gLogger.debug("Running RemoteNet.forward() on: %s", input)
return self.relu(self.fc(input))
class HybridModel(nn.Module):
def __init__(
self,
remote_em_rref: rpc.RRef,
remote_net_rref: rpc.RRef,
process_group_for_ddp: dist.ProcessGroup = None,
):
super().__init__()
self.remote_em_rref = remote_em_rref
self.remote_net_rref = remote_net_rref
self.fc1 = getLinear(D_DENSE, D_DENSE)
self.fc2 = getLinear(D_HID, D_OUT)
self.non_ddp_params = tuple(self.fc1.parameters()) + tuple(
self.fc2.parameters()
)
self.ddp_params = ()
if process_group_for_ddp is not None:
self.non_ddp_params, self.ddp_params = (
tuple(self.fc1.parameters()),
tuple(self.fc2.parameters()),
)
gLogger.info("Use DDP for the second local net.")
self.fc2 = DistributedDataParallel(
self.fc2, check_reduction=True, process_group=process_group_for_ddp
)
gLogger.info(
"HybridModel has %s groups of parameters.", len(list(self.parameters()))
)
def forward(self, input: FeatureSet):
gLogger.debug("Running HybridModel.forward on %s", input)
sparse = _remote_method(
RemoteEM.forward, self.remote_em_rref, input.sparse_features
)
# Sparse and dense features must have the same mini-batch size.
assert sparse.shape[0] == input.dense_features.shape[0]
dense = self.fc1(input.dense_features)
x = torch.cat((dense, sparse), 1)
gLogger.debug("Concatenated feature: %s", x)
x = _remote_method(RemoteNet.forward, self.remote_net_rref, x)
return self.fc2(x)
class Trainer:
def __init__(
self,
remote_em_rref: rpc.RRef,
remote_net_rref: rpc.RRef,
ddp_mode: DdpMode,
rank: int,
):
self.rank = rank
self.trainer_group = (
dist.new_group(TRAINER_RANKS)
if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE)
else None
)
self.remote_em_rref = remote_em_rref
self.remote_net_rref = remote_net_rref
self.hybrid_module = HybridModel(
self.remote_em_rref,
self.remote_net_rref,
self.trainer_group if ddp_mode in (DdpMode.INSIDE,) else None,
)
self.ddp_params, self.non_ddp_params = (
self.hybrid_module.ddp_params,
self.hybrid_module.non_ddp_params,
)
if ddp_mode == DdpMode.OUTSIDE:
gLogger.info("Wrapping the whole hybrid module into DDP.")
self.ddp_params += self.non_ddp_params
self.non_ddp_params = ()
self.hybrid_module = DistributedDataParallel(
self.hybrid_module,
check_reduction=True,
process_group=self.trainer_group,
)
gLogger.info(
"Succeeded in creating a HybridModel instance with "
"%s ddp params and %s other local params.",
len(self.ddp_params), len(self.non_ddp_params)
)
def destroy_pg(self):
if self.trainer_group:
dist.destroy_process_group(self.trainer_group)
def train_batch(
self,
mini_batch: FeatureSet,
trainer_has_less_inputs: bool,
simulate_uneven_inputs: bool,
):
grads_dict = None
if not simulate_uneven_inputs:
input_batches = [mini_batch]
else:
# Split into microbatches, and trim to simulate uneven inputs.
dense_features = mini_batch.dense_features
sparse_features = mini_batch.sparse_features
values = mini_batch.values
dense_microbatch = torch.split(dense_features, 2)
sparse_microbatch = torch.split(sparse_features, 2)
values_microbatch = torch.split(values, 2)
batches = []
for d, s, v in zip(dense_microbatch, sparse_microbatch, values_microbatch):
feature_set = FeatureSet(dense_features=d, sparse_features=s, values=v)
batches.append(feature_set)
if trainer_has_less_inputs:
input_batches = batches[: len(batches) // 2]
gLogger.info(
"Trainer reduced input patches from %s "
"to %s to simulate uneven inputs.",
len(batches), len(input_batches)
)
else:
input_batches = batches
with self.hybrid_module.join() if simulate_uneven_inputs else contextlib.nullcontext():
for b in input_batches:
with dist_autograd.context() as context_id:
output = self.hybrid_module.forward(b)
loss = (output * mini_batch.values).sum()
dist_autograd.backward(context_id, [loss])
grads_dict = dist_autograd.get_gradients(context_id)
gLogger.info(
"Loss is %s for mini batch: %s. "
"Grads dict has %s entries: %s", loss, mini_batch, len(grads_dict), grads_dict
)
return (
tuple(grads_dict[param] for param in self.ddp_params),
tuple(grads_dict[param] for param in self.non_ddp_params),
)
def get_training_examples():
n = 16
training_examples = FeatureSet(
dense_features=torch.zeros((n, D_DENSE)),
sparse_features=torch.zeros(n, dtype=torch.long),
values=torch.zeros(n),
)
idx = 0
# Every example has another one that has exactly the same features but an
# opposite value. Therefore, their grads cancel each other in all-reduce.
for value in (-1, 1):
for x in (-1.0 * value, 1.0 * value):
for y in (1.0 * value, -1.0 * value):
for z in (0, 1):
training_examples.dense_features[idx, :] = torch.tensor((x, y))
training_examples.sparse_features[idx] = z
training_examples.values[idx] = value
idx += 1
# Split the examples among NUM_TRAINERS trainers
assert 0 == (n % NUM_TRAINERS)
examples_per_trainer = int(n / NUM_TRAINERS)
return [
FeatureSet(
dense_features=training_examples.dense_features[
start : start + examples_per_trainer, :
],
sparse_features=training_examples.sparse_features[
start : start + examples_per_trainer
],
values=training_examples.values[start : start + examples_per_trainer],
)
for start in range(0, n, examples_per_trainer)
]
shutdown_signal = threading.Condition()
def set_shutdown_signal():
global shutdown_signal
with shutdown_signal:
shutdown_signal.notify()
class DdpUnderDistAutogradTest(RpcAgentTestFixture):
@property
def world_size(self) -> int:
return WORLD_SIZE
def remote_worker_name(self) -> str:
# The name has to be consistent with that in the 'dist_init' decorator.
return f"worker{REMOTE_WORKER_RANK}"
def trainer_name(self, rank):
# The name has to be consistent with that in the 'dist_init' decorator.
return f"worker{rank}"
def _remote_worker_process(self, ddp_mode):
gLogger.info("The remote worker is running.")
dist.init_process_group(
backend="gloo",
init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
world_size=self.world_size,
rank=self.rank,
)
if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
# new_group needs to be called on all ranks, even those not in the new group.
dist.new_group(TRAINER_RANKS)
global shutdown_signal
with shutdown_signal:
shutdown_signal.wait()
gLogger.info("Exiting remote worker.")
dist.destroy_process_group()
def _trainer_process(self, rank: int):
gLogger.info("Running the trainer #%s...", rank)
gLogger.info(
"Initing trainer process group by trainer #%s with ranks %s", rank, TRAINER_RANKS
)
dist.init_process_group(
backend="gloo",
init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
world_size=self.world_size,
rank=self.rank,
)
gLogger.info("Waiting for shutdown signal on trainer #%s...", rank)
global shutdown_signal
with shutdown_signal:
shutdown_signal.wait()
gLogger.info("Exiting the trainer #%s...", rank)
dist.destroy_process_group()
def _master_process(self, ddp_mode: DdpMode, simulate_uneven_inputs: bool):
gLogger.info("Running the master process...")
dist.init_process_group(
backend="gloo",
init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
world_size=self.world_size,
rank=self.rank,
)
remote_em_rref = rpc.remote(
self.remote_worker_name(), RemoteEM, args=(NUM_EM_ROW, D_SPARSE)
)
remote_net_rref = rpc.remote(
self.remote_worker_name(), RemoteNet, args=(D_DENSE + D_SPARSE, D_HID)
)
gLogger.info("Created remote rrefs on master")
self.do_test_on_master(
ddp_mode, simulate_uneven_inputs, remote_em_rref, remote_net_rref
)
def do_test_on_master(
self,
ddp_mode: DdpMode,
simulate_uneven_inputs: bool,
remote_em_rref: rpc.RRef,
remote_net_rref: rpc.RRef,
):
if simulate_uneven_inputs:
gLogger.info(
"Running DDP + RPC test with simulating uneven inputs across trainers."
)
trainer_rrefs = []
for rank in TRAINER_RANKS:
trainer = self.trainer_name(rank)
trainer_rrefs.append(
rpc.remote(
trainer,
Trainer,
args=(remote_em_rref, remote_net_rref, ddp_mode, rank),
)
)
if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
# new_group needs to be called on all ranks, even those not in the new group.
dist.new_group(TRAINER_RANKS)
training_examples = get_training_examples()
for _ in range(3):
futures = []
num_trainers = len(trainer_rrefs)
for idx, trainer_rref in enumerate(trainer_rrefs):
# Half the trainers will deplete inputs earlier than the rest.
trainer_has_less_inputs = (
simulate_uneven_inputs and idx < num_trainers // 2
)
futures.append(
_remote_method_async(
Trainer.train_batch,
trainer_rref,
training_examples[idx],
trainer_has_less_inputs,
simulate_uneven_inputs,
)
)
for future in futures:
ddp_grads, non_ddp_grads = future.wait()
# When there are uneven inputs, it is not necessary that grads
# cancel each other out, since some trainers contribute 0 grad.
if not simulate_uneven_inputs:
for grad in ddp_grads:
self.assertEqual(
grad,
torch.zeros_like(grad),
msg=f"The grad for any ddp parameter should be zeros, because "
"the training examples' grads cancel each other. Received "
f"gradient {grad}",
)
for grad in non_ddp_grads:
self.assertNotEqual(
grad,
torch.zeros_like(grad),
msg="The grad for any non-ddp parameter shouldn't be zeros",
)
# Destroy process groups
for idx, trainer_rref in enumerate(trainer_rrefs):
_remote_method_async(Trainer.destroy_pg, trainer_rref).wait()
# Send shutdown signals.
for rank in TRAINER_RANKS:
trainer = self.trainer_name(rank)
rpc.rpc_sync(trainer, set_shutdown_signal, args=())
rpc.rpc_sync(self.remote_worker_name(), set_shutdown_signal, args=())
def _do_test(self, ddp_mode, simulate_uneven_inputs=False):
if self.rank == MASTER_RANK:
self._master_process(ddp_mode, simulate_uneven_inputs)
elif self.rank == REMOTE_WORKER_RANK:
self._remote_worker_process(ddp_mode)
elif self.rank in TRAINER_RANKS:
self._trainer_process(self.rank)
else:
raise RuntimeError(f"Unknown process rank: {self.rank}")
@requires_gloo()
@dist_init
def test_backward_no_ddp(self):
self._do_test(DdpMode.NONE)
@requires_gloo()
@dist_init
def test_backward_ddp_outside(self):
self._do_test(DdpMode.OUTSIDE)
@requires_gloo()
@dist_init
def test_backward_ddp_outside_uneven_inputs(self):
self._do_test(DdpMode.OUTSIDE, simulate_uneven_inputs=True)
@requires_gloo()
@dist_init
def test_backward_ddp_inside(self):
self._do_test(DdpMode.INSIDE)
# Common utils for both CPU and CUDA test suites
class CommonDdpComparisonTest(RpcAgentTestFixture):
@property
def world_size(self) -> int:
return NUM_TRAINERS
def trainer_name(self, rank):
# The name has to be consistent with that in the 'dist_init' decorator.
return f"worker{rank}"
@staticmethod
def get_remote_grads(rref, context_id):
return dist_autograd.get_gradients(context_id)[rref.local_value().weight]
class DdpComparisonTest(CommonDdpComparisonTest):
def _run_test_ddp_comparision(self, simulate_uneven_inputs=False):
gLogger.info("Running trainer rank: %s", self.rank)
# Each trainer uses a different random seed. Otherwise, they are going
# to have exactly the same initial model parameters, input, and
# therefore grads. That means the grads will be the same before and
# after DDP's all-reduce.
torch.manual_seed(self.rank)
dist.init_process_group(
backend="gloo",
# Postfix file_name with "pg" since file_name is also used by RPC agent
init_method=INIT_METHOD_TEMPLATE.format(file_name=f"{self.file_name}_pg"),
world_size=self.world_size,
rank=self.rank,
)
net = nn.Linear(2, 3)
ddp_net = DistributedDataParallel(net)
# Odd ranks join early if simulate_uneven_inputs.
num_inputs = 1
if simulate_uneven_inputs:
if self.rank % 2 == 0:
num_inputs += 2
inputs_list = [torch.rand((3, 2)) for _ in range(num_inputs)]
if simulate_uneven_inputs:
gLogger.info("Rank %s training with %s inputs.", self.rank, len(inputs_list))
# Use distributed autograd. The gradients will be in RPC context map.
grads_dict = {}
with ddp_net.join(simulate_uneven_inputs):
for i, inputs in enumerate(inputs_list):
with dist_autograd.context() as context_id:
loss = ddp_net(inputs).norm()
dist_autograd.backward(context_id, [loss])
grads_dict = dist_autograd.get_gradients(context_id)
gLogger.info("Trainer #%s got grad dict: %s", self.rank, grads_dict)
# Use local autograd. The gradients will be in each variable's '.grad'.
ddp_net.zero_grad()
loss = ddp_net(inputs).norm()
loss.backward()
# The gradients should be the same
for param in net.parameters():
self.assertTrue(
param in grads_dict,
msg=f"Param {param} is not in dist_auto grad dict {grads_dict} for iteration {i}",
)
self.assertEqual(
grads_dict[param],
param.grad,
msg=f"The grads for param {param} are different under local "
f"and dist autograd: {param.grad} \n---\n {grads_dict[param]} for iteration {i}",
)
dist.destroy_process_group()
@requires_gloo()
@dist_init
def test_ddp_comparison(self):
self._run_test_ddp_comparision()
@requires_gloo()
@dist_init
def test_ddp_comparison_uneven_inputs(self):
# test with simulating uneven inputs in DDP
self._run_test_ddp_comparision(simulate_uneven_inputs=True)
@requires_gloo()
@dist_init
def test_ddp_dist_autograd_sparse_grads(self):
# Each trainer uses a different random seed. Otherwise, they are going
# to have exactly the same initial model parameters, input, and
# therefore grads. That means the grads will be the same before and
# after DDP's all-reduce.
torch.manual_seed(self.rank)
dist.init_process_group(
backend="gloo",
init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
world_size=self.world_size,
rank=self.rank,
)
model = nn.EmbeddingBag(10, 3, sparse=True)
ddp_model = DistributedDataParallel(model)
# Different inputs for each rank (seeds differ per rank)
input = torch.LongTensor(10).random_(0, 10)
offsets = torch.LongTensor([0, 4])
# Run local.
loss = ddp_model(input, offsets).sum()
loss.backward()
with dist_autograd.context() as context_id:
loss = ddp_model(input, offsets).sum()
dist_autograd.backward(context_id, [loss])
grads_dict = dist_autograd.get_gradients(context_id)
self.assertEqual(1, len(grads_dict))
self.assertEqual(model.weight.grad, grads_dict[model.weight])
@requires_gloo()
@dist_init
def test_ddp_dist_autograd_local_vs_remote(self):
# Each trainer uses a different random seed. Otherwise, they are going
# to have exactly the same initial model parameters, input, and
# therefore grads. That means the grads will be the same before and
# after DDP's all-reduce.
torch.manual_seed(self.rank)
dist.init_process_group(
backend="gloo",
init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
world_size=self.world_size,
rank=self.rank,
)
# Use two different remote device input strings, with and without the default
# device string "cpu", respectively.
for remote_device in ["worker0/cpu", "worker0"]:
remote_layer1 = RemoteModule(
remote_device=remote_device, module_cls=nn.Linear, args=(10, 5, False)
)
layer1 = nn.Linear(10, 5, False)
# Start with the same parameters for remote and local
layer1.weight = remote_layer1.module_rref.to_here().weight
# Run local case.
layer2 = nn.Linear(5, 1)
inputs = torch.rand((10, 10))
ddp_model = DistributedDataParallel(layer2)
loss = ddp_model(layer1(inputs)).sum()
loss.backward()
# Run remote case.
with dist_autograd.context() as context_id:
loss = ddp_model(remote_layer1(inputs)).sum()
dist_autograd.backward(context_id, [loss])
grads_dict = dist_autograd.get_gradients(context_id)
dist.barrier()
self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
self.assertEqual(
layer1.weight.grad,
rpc.rpc_sync(
"worker0",
CommonDdpComparisonTest.get_remote_grads,
args=(remote_layer1.module_rref, context_id),
),
)
class CudaDdpComparisonTest(CommonDdpComparisonTest):
@skip_if_lt_x_gpu(NUM_TRAINERS)
@requires_nccl()
@dist_init
@skip_if_rocm_multiprocess
def test_ddp_dist_autograd_local_vs_remote_gpu(self):
# Each trainer uses a different random seed. Otherwise, they are going
# to have exactly the same initial model parameters, input, and
# therefore grads. That means the grads will be the same before and
# after DDP's all-reduce.
torch.manual_seed(self.rank)
dist.init_process_group(
backend="gloo",
init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
world_size=self.world_size,
rank=self.rank,
)
remote_layer1 = RemoteModule(
remote_device="worker0/cpu", module_cls=nn.Linear, args=(10, 7, False)
)
layer1 = nn.Linear(10, 7, False)
# Start with the same parameters for remote and local
layer1.weight = remote_layer1.module_rref.to_here().weight
layer2 = nn.Linear(7, 5).cuda(self.rank)
ddp_layer2 = DistributedDataParallel(layer2, device_ids=[self.rank])
remote_layer3 = RemoteModule(
remote_device="worker0/cpu", module_cls=nn.Linear, args=(5, 3, False)
)
layer3 = nn.Linear(5, 3, False)
# Start with the same parameters for remote and local
layer3.weight = remote_layer3.module_rref.to_here().weight
layer4 = nn.Linear(3, 1).cuda(self.rank)
ddp_layer4 = DistributedDataParallel(layer4, device_ids=[self.rank])
# Run local case.
inputs = torch.rand((10, 10))
loss = ddp_layer4(
layer3(ddp_layer2(layer1(inputs).cuda(self.rank)).cpu()).cuda(self.rank)
).sum()
loss.backward()
# Run remote case.
with dist_autograd.context() as context_id:
loss = ddp_layer4(
remote_layer3(
ddp_layer2(remote_layer1(inputs).cuda(self.rank)).cpu()
).cuda(self.rank)
).sum()
dist_autograd.backward(context_id, [loss])
grads_dict = dist_autograd.get_gradients(context_id)
dist.barrier()
self.assertEqual(
layer1.weight.grad,
rpc.rpc_sync(
"worker0",
CommonDdpComparisonTest.get_remote_grads,
args=(remote_layer1.module_rref, context_id),
),
)
self.assertEqual(layer2.weight.grad, grads_dict[layer2.weight])
self.assertEqual(
layer3.weight.grad,
rpc.rpc_sync(
"worker0",
CommonDdpComparisonTest.get_remote_grads,
args=(remote_layer3.module_rref, context_id),
),
)
self.assertEqual(layer4.weight.grad, grads_dict[layer4.weight])

File diff suppressed because it is too large

View File

@ -0,0 +1,66 @@
# mypy: allow-untyped-defs
from contextlib import contextmanager
from datetime import timedelta
from functools import (
partial,
wraps,
)
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
class MockProcessGroup(dist.ProcessGroup):
def __init__(self, rank, world):
super().__init__(rank, world)
def getBackendName(self):
return "mock_process_group"
def create_mock_pg(prefix_store, rank, world_size, timeout):
return MockProcessGroup(rank, world_size)
dist.Backend.register_backend('mock_process_group', create_mock_pg)
def mock_init_dist(rank, world_size):
# !!! WARNING !!!
# Kids don't try this at home, this is a cute pile of hacks that
# depends on a small mountain of c10d internals
assert not dist.is_initialized()
store = dist.HashStore()
# Trick _store_based_barrier into believing everyone else already checked-in
# Zero is the group index
store.add(f"{c10d.STORE_BASED_BARRIER_PREFIX}:0", world_size - 1)
dist.init_process_group(
backend="mock_process_group",
rank=rank,
world_size=world_size,
store=store,
group_name="fake",
timeout=timedelta(seconds=1))
@contextmanager
def with_dist(rank=0, world_size=2):
"""
Context manager that initializes c10d with a fake process group.
"""
mock_init_dist(rank=rank, world_size=world_size)
try:
yield
finally:
dist.destroy_process_group()
def with_fake_comms(func=None, rank=0, world_size=2):
"""
Function wrapper that inits a fake process group designed for testing.
Right now only querying the world size is available.
"""
if func is None:
return partial(with_fake_comms, rank=rank, world_size=world_size)
@wraps(func)
def wrapper(self, *args, **kwargs):
with with_dist(rank, world_size):
func(self, *args, **kwargs)
return wrapper
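# Illustrative usage sketch (class and method names are hypothetical): tests
# that only need to query collective metadata, rather than run real
# collectives, can wrap a method with with_fake_comms and read the world size
# from c10d.
#
#   class FakeCommsTest(unittest.TestCase):
#       @with_fake_comms(rank=0, world_size=4)
#       def test_world_size(self):
#           self.assertEqual(dist.get_world_size(), 4)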

Some files were not shown because too many files have changed in this diff