I am done

commit 40e2a747cf
parent 720dc28c09
2024-10-30 22:14:35 +01:00
36901 changed files with 5011519 additions and 0 deletions

File diff suppressed because it is too large


@@ -0,0 +1,119 @@
# mypy: allow-untyped-defs
import torch
import torch._prims_common as utils
# Utilities should come BEFORE this import
from torch._decomp import register_decomposition
from torch._prims_common import TensorLikeType
from torch._prims_common.wrappers import out_wrapper
from torch._refs import _broadcast_shapes
# Data conversion references.
#
# Note: this module breaks the usual _refs to torch naming scheme where
# _refs.foo.bar is a ref for torch.foo.bar. The following definitions are not
# part of _refs/__init__.py to avoid name clashes with Python builtin types
# (like int).
__all__ = [
# dtypes
"bfloat16",
"bool",
"byte",
"cdouble",
"cfloat",
"chalf",
"char",
"double",
"float",
"half",
"int",
"long",
"short",
# misc
"complex",
"polar",
]
def _make_conversion_method(name: str, dtype: torch.dtype):
def fn(
self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format
) -> TensorLikeType:
return self.to(dtype, memory_format=memory_format) # type: ignore[call-overload]
fn.__name__ = name
return fn
bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16)
bool = _make_conversion_method("bool", torch.bool)
byte = _make_conversion_method("byte", torch.uint8)
cdouble = _make_conversion_method("cdouble", torch.cdouble)
cfloat = _make_conversion_method("cfloat", torch.cfloat)
chalf = _make_conversion_method("chalf", torch.complex32)
char = _make_conversion_method("char", torch.int8)
double = _make_conversion_method("double", torch.double)
float = _make_conversion_method("float", torch.float)
half = _make_conversion_method("half", torch.half)
int = _make_conversion_method("int", torch.int)
long = _make_conversion_method("long", torch.long)
short = _make_conversion_method("short", torch.short)
@register_decomposition(torch._ops.ops.aten.complex)
# Note: complex has type promotion tests disabled due to different semantics.
# exact_dtype is for compat with complex_check_dtype from core.
@out_wrapper(exact_dtype=True)
def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
allowed_dtypes = (torch.float32, torch.float64, torch.float16)
torch._check(
real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes,
lambda: (
f"Expected both inputs to be Half, Float or Double tensors but got "
f"{real.dtype} and {imag.dtype}"
),
)
torch._check(
real.dtype == imag.dtype,
lambda: (
f"Expected object of scalar type {real.dtype} but got "
f"scalar type {imag.dtype} for second argument"
),
)
result_dtype = utils.corresponding_complex_dtype(real.dtype) # type: ignore[arg-type]
common_shape = _broadcast_shapes(real.shape, imag.shape)
result = real.new_empty(
common_shape,
dtype=result_dtype,
layout=real.layout,
device=real.device,
# pin_memory=real.is_pinned(), # NYI
)
result.real = real
result.imag = imag
return result
@register_decomposition(torch._ops.ops.aten.polar)
# Note: polar has type promotion tests disabled due to different semantics.
# exact_dtype is for compat with complex_check_dtype from core.
@out_wrapper(exact_dtype=True)
def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType:
result = torch.complex(abs, angle)
result.real = abs * torch.cos(angle)
result.imag = abs * torch.sin(angle)
return result
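
The conversion methods above all forward to Tensor.to with a fixed dtype, while complex and polar assemble a complex tensor from its Cartesian or polar parts. A minimal sanity-check sketch against the public torch API that these references mirror; the values and tolerances are illustrative, not taken from the diff:

import math

import torch

x = torch.arange(4, dtype=torch.float32)
assert x.half().dtype == torch.half      # what the generated `half` method returns
assert x.int().dtype == torch.int32      # what the generated `int` method returns

real = torch.tensor([1.0, 0.0])
imag = torch.tensor([0.0, 1.0])
z = torch.complex(real, imag)            # tensor([1.+0.j, 0.+1.j])
assert z.dtype == torch.complex64        # corresponding_complex_dtype(torch.float32)

p = torch.polar(torch.ones(2), torch.tensor([0.0, math.pi / 2]))
assert torch.allclose(p, z, atol=1e-6)   # abs*cos(angle) + 1j*abs*sin(angle)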


@@ -0,0 +1,590 @@
import math
from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union
import torch
import torch._prims as prims
import torch._prims_common as utils
from torch._decomp import register_decomposition
from torch._prims_common import DimsType, ShapeType, TensorLikeType
from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper
__all__ = [
# Transforms
"fft",
"fft2",
"fftn",
"hfft",
"hfft2",
"hfftn",
"rfft",
"rfft2",
"rfftn",
"ifft",
"ifft2",
"ifftn",
"ihfft",
"ihfft2",
"ihfftn",
"irfft",
"irfft2",
"irfftn",
# Helpers
"fftshift",
"ifftshift",
]
NormType = Union[None, Literal["forward", "backward", "ortho"]]
_NORM_VALUES = {None, "forward", "backward", "ortho"}
aten = torch._ops.ops.aten
def _apply_norm(
x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
) -> TensorLikeType:
"""Apply normalization to the un-normalized FFT result"""
torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")
if norm == "ortho":
return x * (1 / math.sqrt(signal_numel))
normalize = (not forward and (norm is None or norm == "backward")) or (
forward and norm == "forward"
)
return x * (1 / signal_numel) if normalize else x
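# _apply_norm implements the torch.fft normalization conventions: with the
# default norm (None or "backward") only the inverse transform is scaled, by
# 1/n; "forward" scales only the forward transform by 1/n; "ortho" scales both
# directions by 1/sqrt(n), making the transform unitary.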
def _promote_type_fft(
dtype: torch.dtype, require_complex: bool, device: torch.device
) -> torch.dtype:
"""Helper to promote a dtype to one supported by the FFT primitives"""
if dtype.is_complex:
return dtype
# Promote integral to default float type
if not dtype.is_floating_point:
dtype = torch.get_default_dtype()
allowed_types = [torch.float32, torch.float64]
maybe_support_half = device.type in ["cuda", "meta"]
if maybe_support_half:
allowed_types.append(torch.float16)
torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}")
if require_complex:
dtype = utils.corresponding_complex_dtype(dtype)
return dtype
def _maybe_promote_tensor_fft(
t: TensorLikeType, require_complex: bool = False
) -> TensorLikeType:
"""Helper to promote a tensor to a dtype supported by the FFT primitives"""
cur_type = t.dtype
new_type = _promote_type_fft(cur_type, require_complex, t.device)
return _maybe_convert_to_dtype(t, new_type) # type: ignore[return-value]
def _resize_fft_input(
x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...]
) -> TensorLikeType:
"""
Fixes the shape of x such that x.size(dims[i]) == sizes[i],
either by zero-padding, or by slicing x starting from 0.
"""
assert len(dims) == len(sizes)
must_copy = False
x_sizes = x.shape
pad_amount = [0] * len(x_sizes) * 2
for i in range(len(dims)):
if sizes[i] == -1:
continue
if x_sizes[dims[i]] < sizes[i]:
must_copy = True
pad_idx = len(pad_amount) - 2 * dims[i] - 1
pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]]
if x_sizes[dims[i]] > sizes[i]:
x = x.narrow(dims[i], 0, sizes[i])
return torch.constant_pad_nd(x, pad_amount) if must_copy else x
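# In _resize_fft_input, pad_amount follows the torch.constant_pad_nd
# convention (last dimension first). For x of shape (4, 10) with dims=(1,) and
# sizes=(16,), pad_idx = 4 - 2*1 - 1 = 1 and pad_amount becomes [0, 6, 0, 0],
# zero-padding dim 1 from 10 to 16; requesting sizes=(8,) instead narrows
# dim 1 to its first 8 elements.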
def _fft_c2r(
func_name: str,
input: TensorLikeType,
n: Optional[int],
dim: int,
norm: NormType,
forward: bool,
) -> TensorLikeType:
"""Common code for performing any complex to real FFT (irfft or hfft)"""
input = _maybe_promote_tensor_fft(input, require_complex=True)
dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
torch._check(
last_dim_size >= 1,
lambda: f"Invalid number of data points ({last_dim_size}) specified",
)
if n is not None:
input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))
if forward:
input = torch.conj(input)
output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size)
return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward)
def _fft_r2c(
func_name: str,
input: TensorLikeType,
n: Optional[int],
dim: int,
norm: NormType,
forward: bool,
onesided: bool,
) -> TensorLikeType:
"""Common code for performing any real to complex FFT (rfft or ihfft)"""
torch._check(
not input.dtype.is_complex,
lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
)
input = _maybe_promote_tensor_fft(input)
dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
dim_size = n if n is not None else input.shape[dim]
torch._check(
dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
)
if n is not None:
input = _resize_fft_input(input, dims, (n,))
ret = prims.fft_r2c(input, dim=dims, onesided=onesided)
ret = _apply_norm(ret, norm, dim_size, forward)
return ret if forward else torch.conj(ret)
def _fft_c2c(
func_name: str,
input: TensorLikeType,
n: Optional[int],
dim: int,
norm: NormType,
forward: bool,
) -> TensorLikeType:
"""Common code for performing any complex to complex FFT (fft or ifft)"""
torch._check(
input.dtype.is_complex,
lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
)
dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
dim_size = n if n is not None else input.shape[dim]
torch._check(
dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
)
if n is not None:
input = _resize_fft_input(input, dims, (n,))
ret = prims.fft_c2c(input, dim=dims, forward=forward)
return _apply_norm(ret, norm, dim_size, forward)
@register_decomposition(aten.fft_fft)
@out_wrapper()
def fft(
input: TensorLikeType,
n: Optional[int] = None,
dim: int = -1,
norm: NormType = None,
) -> TensorLikeType:
if input.dtype.is_complex:
return _fft_c2c("fft", input, n, dim, norm, forward=True)
else:
return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False)
@register_decomposition(aten.fft_ifft)
@out_wrapper()
def ifft(
input: TensorLikeType,
n: Optional[int] = None,
dim: int = -1,
norm: NormType = None,
) -> TensorLikeType:
if input.dtype.is_complex:
return _fft_c2c("ifft", input, n, dim, norm, forward=False)
else:
return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False)
@register_decomposition(aten.fft_rfft)
@out_wrapper()
def rfft(
input: TensorLikeType,
n: Optional[int] = None,
dim: int = -1,
norm: NormType = None,
) -> TensorLikeType:
return _fft_r2c("rfft", input, n, dim, norm, forward=True, onesided=True)
@register_decomposition(aten.fft_irfft)
@out_wrapper()
def irfft(
input: TensorLikeType,
n: Optional[int] = None,
dim: int = -1,
norm: NormType = None,
) -> TensorLikeType:
return _fft_c2r("irfft", input, n, dim, norm, forward=False)
@register_decomposition(aten.fft_hfft)
@out_wrapper()
def hfft(
input: TensorLikeType,
n: Optional[int] = None,
dim: int = -1,
norm: NormType = None,
) -> TensorLikeType:
return _fft_c2r("hfft", input, n, dim, norm, forward=True)
@register_decomposition(aten.fft_ihfft)
@out_wrapper()
def ihfft(
input: TensorLikeType,
n: Optional[int] = None,
dim: int = -1,
norm: NormType = None,
) -> TensorLikeType:
return _fft_r2c("ihfft", input, n, dim, norm, forward=False, onesided=True)
class _ShapeAndDims(NamedTuple):
shape: Tuple[int, ...]
dims: Tuple[int, ...]
def _canonicalize_fft_shape_and_dim_args(
input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType]
) -> _ShapeAndDims:
"""Convert the shape and dim arguments into a canonical form where neither are optional"""
input_dim = input.ndim
input_sizes = input.shape
if dim is not None:
if not isinstance(dim, Sequence):
dim = (dim,)
ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)
# Check dims are unique
torch._check(
len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique"
)
if shape is not None:
if not isinstance(shape, Sequence):
shape = (shape,)
# Has shape, might have dim
torch._check(
dim is None or len(dim) == len(shape),
lambda: "When given, dim and shape arguments must have the same length",
)
transform_ndim = len(shape)
torch._check(
transform_ndim <= input_dim,
lambda: f"Got shape with {transform_ndim} values but input tensor "
f"only has {input_dim} dimensions.",
)
# If shape is given, dims defaults to the last len(shape) dimensions
if dim is None:
ret_dims = tuple(range(input_dim - transform_ndim, input_dim))
# Translate any -1 values in shape to the default length
ret_shape = tuple(
s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims) # type: ignore[possibly-undefined]
)
elif dim is None:
# No shape, no dim
ret_dims = tuple(range(input_dim))
ret_shape = tuple(input_sizes)
else:
# No shape, has dim
ret_shape = tuple(input_sizes[d] for d in ret_dims) # type: ignore[possibly-undefined]
for n in ret_shape:
torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")
return _ShapeAndDims(shape=ret_shape, dims=ret_dims) # type: ignore[possibly-undefined]
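# Example: for an input of shape (4, 5, 6),
# _canonicalize_fft_shape_and_dim_args(input, shape=(8, -1), dim=None) returns
# shape=(8, 6) and dims=(1, 2): dims default to the last len(shape)
# dimensions, and -1 entries are replaced by the corresponding input size.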
def _prod(xs: Iterable[int]) -> int:
"""Compute product of a list"""
prod = 1
for x in xs:
prod *= x
return prod
def _fftn_c2c(
function_name: str,
input: TensorLikeType,
shape: Tuple[int, ...],
dim: Tuple[int, ...],
norm: NormType,
forward: bool,
) -> TensorLikeType:
"""Common code for n-dimensional complex to complex FFTs (fftn or ifftn)"""
torch._check(
input.dtype.is_complex,
lambda: f"{function_name} expects a complex input tensor, "
f"but got {input.dtype}",
)
x = _resize_fft_input(input, dim, shape)
output = prims.fft_c2c(x, dim=dim, forward=forward)
return _apply_norm(output, norm=norm, signal_numel=_prod(shape), forward=forward)
@register_decomposition(aten.fft_fftn)
@out_wrapper()
def fftn(
input: TensorLikeType,
s: Optional[ShapeType] = None,
dim: Optional[DimsType] = None,
norm: NormType = None,
) -> TensorLikeType:
(shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
x = _maybe_promote_tensor_fft(input, require_complex=True)
return _fftn_c2c("fftn", x, shape, dim, norm, forward=True)
@register_decomposition(aten.fft_ifftn)
@out_wrapper()
def ifftn(
input: TensorLikeType,
s: Optional[ShapeType] = None,
dim: Optional[DimsType] = None,
norm: NormType = None,
) -> TensorLikeType:
(shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
x = _maybe_promote_tensor_fft(input, require_complex=True)
return _fftn_c2c("ifftn", x, shape, dim, norm, forward=False)
@register_decomposition(aten.fft_rfftn)
@out_wrapper()
def rfftn(
input: TensorLikeType,
s: Optional[ShapeType] = None,
dim: Optional[DimsType] = None,
norm: NormType = None,
) -> TensorLikeType:
torch._check(
not input.dtype.is_complex,
lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
)
shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
input = _maybe_promote_tensor_fft(input, require_complex=False)
input = _resize_fft_input(input, dim, shape)
out = prims.fft_r2c(input, dim=dim, onesided=True)
return _apply_norm(out, norm=norm, signal_numel=_prod(shape), forward=True)
@register_decomposition(aten.fft_ihfftn)
@out_wrapper()
def ihfftn(
input: TensorLikeType,
s: Optional[ShapeType] = None,
dim: Optional[DimsType] = None,
norm: NormType = None,
) -> TensorLikeType:
torch._check(
not input.dtype.is_complex,
lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
)
shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
input = _maybe_promote_tensor_fft(input, require_complex=False)
input = _resize_fft_input(input, dim, shape)
tmp = prims.fft_r2c(input, dim=dim[-1:], onesided=True)
if len(dim) == 1:
tmp = _apply_norm(tmp, norm=norm, signal_numel=shape[0], forward=False)
return prims.conj(tmp)
tmp = prims.conj_physical(tmp)
tmp = prims.fft_c2c(tmp, dim=dim[:-1], forward=False)
return _apply_norm(tmp, norm=norm, signal_numel=_prod(shape), forward=False)
class _CanonicalizeC2rReturn(NamedTuple):
shape: Tuple[int, ...]
dim: Tuple[int, ...]
last_dim_size: int
def _canonicalize_fft_c2r_shape_and_dim_args(
fname: str,
input: TensorLikeType,
s: Optional[ShapeType],
dim: Optional[DimsType],
) -> _CanonicalizeC2rReturn:
"""Canonicalize shape and dim arguments for n-dimensional c2r transforms,
as well as calculating the last_dim_size which is shape[dim[-1]] for the output"""
(shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")
if s is None or s[-1] == -1:
last_dim_size = 2 * (input.shape[dim[-1]] - 1)
else:
last_dim_size = shape[-1]
torch._check(
last_dim_size >= 1,
lambda: f"Invalid number of data points ({last_dim_size}) specified",
)
shape_list = list(shape)
shape_list[-1] = last_dim_size // 2 + 1
return _CanonicalizeC2rReturn(
shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size
)
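# In _canonicalize_fft_c2r_shape_and_dim_args, when s is not given,
# last_dim_size is reconstructed as 2 * (input.shape[dim[-1]] - 1), inverting
# the onesided length n // 2 + 1 produced by rfft. The returned shape keeps
# the onesided length (last_dim_size // 2 + 1) for resizing the complex input,
# while last_dim_size is the length of the real output.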
@register_decomposition(aten.fft_irfftn)
@out_wrapper()
def irfftn(
input: TensorLikeType,
s: Optional[ShapeType] = None,
dim: Optional[DimsType] = None,
norm: NormType = None,
) -> TensorLikeType:
shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
"irfftn", input, s, dim
)
input = _maybe_promote_tensor_fft(input, require_complex=True)
input = _resize_fft_input(input, dim, shape)
out = prims.fft_c2r(input, dim=dim, last_dim_size=last_dim_size)
return _apply_norm(out, norm, _prod(out.shape[d] for d in dim), forward=False)
@register_decomposition(aten.fft_hfftn)
@out_wrapper()
def hfftn(
input: TensorLikeType,
s: Optional[ShapeType] = None,
dim: Optional[DimsType] = None,
norm: NormType = None,
) -> TensorLikeType:
shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
"hfftn", input, s, dim
)
input = _maybe_promote_tensor_fft(input, require_complex=True)
input = _resize_fft_input(input, dim, shape)
tmp = prims.fft_c2c(input, dim=dim[:-1], forward=True) if len(dim) > 1 else input
tmp = _apply_norm(tmp, norm, _prod(shape[:-1]), forward=True)
tmp = prims.conj_physical(tmp)
out = prims.fft_c2r(tmp, dim=dim[-1:], last_dim_size=last_dim_size)
return _apply_norm(out, norm, last_dim_size, forward=True)
@register_decomposition(aten.fft_fft2)
@out_wrapper()
def fft2(
input: TensorLikeType,
s: Optional[ShapeType] = None,
dim: Optional[DimsType] = (-2, -1),
norm: NormType = None,
) -> TensorLikeType:
return torch.fft.fftn(input, s=s, dim=dim, norm=norm)
@register_decomposition(aten.fft_ifft2)
@out_wrapper()
def ifft2(
input: TensorLikeType,
s: Optional[ShapeType] = None,
dim: Optional[DimsType] = (-2, -1),
norm: NormType = None,
) -> TensorLikeType:
return torch.fft.ifftn(input, s=s, dim=dim, norm=norm)
@register_decomposition(aten.fft_rfft2)
@out_wrapper()
def rfft2(
input: TensorLikeType,
s: Optional[ShapeType] = None,
dim: Optional[DimsType] = (-2, -1),
norm: NormType = None,
) -> TensorLikeType:
return torch.fft.rfftn(input, s=s, dim=dim, norm=norm)
@register_decomposition(aten.fft_irfft2)
@out_wrapper()
def irfft2(
input: TensorLikeType,
s: Optional[ShapeType] = None,
dim: Optional[DimsType] = (-2, -1),
norm: NormType = None,
) -> TensorLikeType:
return torch.fft.irfftn(input, s=s, dim=dim, norm=norm)
@register_decomposition(aten.fft_hfft2)
@out_wrapper()
def hfft2(
input: TensorLikeType,
s: Optional[ShapeType] = None,
dim: Optional[DimsType] = (-2, -1),
norm: NormType = None,
) -> TensorLikeType:
return torch.fft.hfftn(input, s=s, dim=dim, norm=norm)
@register_decomposition(aten.fft_ihfft2)
@out_wrapper()
def ihfft2(
input: TensorLikeType,
s: Optional[ShapeType] = None,
dim: Optional[DimsType] = (-2, -1),
norm: NormType = None,
) -> TensorLikeType:
return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm)
def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]:
"""Convert Optional[DimsType] to a simple list, defaulting to all dimensions"""
if dim is None:
return list(range(x.ndim))
elif not isinstance(dim, Sequence):
return [dim]
else:
return list(dim)
@register_decomposition(aten.fft_fftshift)
def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
dims = _default_alldims(dim, input)
shift = [input.shape[d] // 2 for d in dims]
return torch.roll(input, shift, dims)
@register_decomposition(aten.fft_ifftshift)
def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
dims = _default_alldims(dim, input)
shift = [(input.shape[d] + 1) // 2 for d in dims]
return torch.roll(input, shift, dims)
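
A short usage sketch of the corresponding public torch.fft entry points, illustrating the normalization conventions handled by _apply_norm and the fftshift/ifftshift round trip; the shapes and tolerances are illustrative only:

import torch

x = torch.randn(8)

# None/"backward": the inverse is scaled by 1/n, so ifft(fft(x)) recovers x
assert torch.allclose(torch.fft.ifft(torch.fft.fft(x)), x.to(torch.complex64), atol=1e-6)

# "ortho" makes the transform unitary, so energy is preserved (Parseval)
X = torch.fft.fft(x, norm="ortho")
assert torch.allclose(X.abs().square().sum(), x.square().sum(), atol=1e-5)

# fftshift centers the zero-frequency bin; ifftshift is its exact inverse
freqs = torch.fft.fftfreq(5)              # tensor([ 0.0,  0.2,  0.4, -0.4, -0.2])
centered = torch.fft.fftshift(freqs)      # tensor([-0.4, -0.2,  0.0,  0.2,  0.4])
assert torch.equal(torch.fft.ifftshift(centered), freqs)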


@@ -0,0 +1,309 @@
# mypy: allow-untyped-defs
from functools import partial
from typing import Optional, Tuple, Union
import torch
import torch._prims as prims
import torch._prims_common as utils
import torch._refs as refs
import torch._refs.linalg as linalg
from torch import Tensor
from torch._prims_common import (
check_fp_or_complex,
check_is_matrix,
Dim,
DimsType,
ELEMENTWISE_TYPE_PROMOTION_KIND,
IntLike,
TensorLikeType,
)
from torch._prims_common.wrappers import (
_maybe_convert_to_dtype,
elementwise_type_promotion_wrapper,
out_wrapper,
)
__all__ = [
"diagonal",
"matrix_norm",
"norm",
"svd",
"svdvals",
"vector_norm",
"vecdot",
"cross",
]
def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str):
"""
Checks related to the dtype kwarg in `linalg.*norm` functions
"""
if dtype is not None:
torch._check(
utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
)
torch._check(
utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype),
lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
fn_name=fn_name,
d="complex" if utils.is_complex_dtype(x_dtype) else "real",
dtype=dtype,
),
)
torch._check(
utils.get_higher_dtype(dtype, x_dtype) == dtype,
lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
"without narrowing to the specified dtype ({dtype})",
)
import operator
# Utilities should come BEFORE this import
from torch._decomp import register_decomposition
from torch._decomp.decompositions import pw_cast_for_opmath
@register_decomposition(torch._ops.ops.aten.linalg_cross)
@out_wrapper()
@pw_cast_for_opmath
def cross(a: Tensor, b: Tensor, dim: int = -1):
torch._check(
a.ndim == b.ndim,
lambda: "linalg.cross: inputs must have the same number of dimensions.",
)
torch._check(
a.size(dim) == 3 and b.size(dim) == 3,
lambda: f"linalg.cross: inputs dim {dim} must have length 3, got {a.size(dim)} and {b.size(dim)}",
)
a, b = torch.broadcast_tensors(a, b)
dim = utils.canonicalize_dim(a.ndim, dim)
idx = torch.arange(3, device=a.device)
return a.index_select(dim, (idx + 1) % 3) * b.index_select(
dim, (idx + 2) % 3
) - a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3)
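# The two index_select terms above implement the cyclic cross-product formula
# (a x b)_i = a[i+1] * b[i+2] - a[i+2] * b[i+1], with indices taken mod 3
# along `dim`.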
def diagonal(
input: TensorLikeType,
*,
offset: int = 0,
dim1: int = -2,
dim2: int = -1,
) -> TensorLikeType:
return torch.diagonal(input, offset=offset, dim1=dim1, dim2=dim2)
@register_decomposition(torch._ops.ops.aten.linalg_vector_norm)
@out_wrapper(exact_dtype=True)
def vector_norm(
x: TensorLikeType,
ord: Union[float, int] = 2,
dim: Optional[DimsType] = None,
keepdim: bool = False,
*,
dtype: Optional[torch.dtype] = None,
) -> Tensor:
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
# Checks
check_fp_or_complex(x.dtype, "linalg.vector_norm")
if isinstance(dim, Dim):
dim = [dim] # type: ignore[assignment]
if guard_size_oblivious(x.numel() == 0) and (ord < 0.0 or ord == float("inf")):
torch._check(
dim is not None and len(dim) != 0,
lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
"because the operation does not have an identity",
)
shape = x.shape
assert dim is not None # mypy does not seem to be able to see through check?
for d in dim:
torch._check(
shape[d] != 0,
lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
f"dimension {d} because this dimension is empty and the "
"operation does not have an identity",
)
_check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")
computation_dtype, result_dtype = utils.reduction_dtypes(
x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype
)
to_result_dtype = partial(_maybe_convert_to_dtype, dtype=result_dtype)
# Implementation
if ord == 0.0:
return torch.sum(torch.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype)
elif ord == float("inf"):
return to_result_dtype(torch.amax(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type]
elif ord == float("-inf"):
return to_result_dtype(torch.amin(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type]
else:
# From here on the computation dtype is important as the reduction is non-trivial
x = _maybe_convert_to_dtype(x, computation_dtype) # type: ignore[assignment]
reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim)
is_ord_even = ord % 2 == 0 if isinstance(ord, IntLike) else ord % 2.0 == 0.0
if not (is_ord_even and utils.is_float_dtype(x.dtype)):
x = torch.abs(x)
return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord)) # type: ignore[return-value]
def _backshift_permutation(dim0, dim1, ndim):
# Auxiliary function for matrix_norm
# Computes the permutation that moves the two given dimensions to the back
ret = [i for i in range(ndim) if i != dim0 and i != dim1]
ret.extend((dim0, dim1))
return ret
def _inverse_permutation(perm):
# Given a permutation, returns its inverse. It's equivalent to argsort on an array
return [i for i, j in sorted(enumerate(perm), key=operator.itemgetter(1))]
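# Example: _backshift_permutation(0, 2, ndim=4) == [1, 3, 0, 2] (dims 0 and 2
# moved to the back) and _inverse_permutation([1, 3, 0, 2]) == [2, 0, 3, 1],
# which restores the original dimension order when used with prims.transpose.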
# CompositeImplicitAutograd
@out_wrapper(exact_dtype=True)
def matrix_norm(
A: TensorLikeType,
ord: Union[float, str] = "fro",
dim: DimsType = (-2, -1),
keepdim: bool = False,
*,
dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
# shape
check_is_matrix(A, "linalg.matrix_norm")
# dim
dim = utils.canonicalize_dims(A.ndim, dim)
if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment]
torch._check(
len(dim) == 2, lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
)
torch._check(
dim[0] != dim[1],
lambda: "linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
)
# dtype arg
_check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm")
if isinstance(ord, str):
# ord
torch._check(
ord in ("fro", "nuc"),
lambda: "linalg.matrix_norm: Order {ord} not supported.",
)
# dtype
check_fp_or_complex(
A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != "nuc"
)
if ord == "fro":
return vector_norm(A, 2, dim, keepdim, dtype=dtype)
else: # ord == "nuc"
if dtype is not None:
A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment]
perm = _backshift_permutation(dim[0], dim[1], A.ndim)
result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim)
if keepdim:
inv_perm = _inverse_permutation(perm)
result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
return result
else:
# ord
abs_ord = abs(ord)
torch._check(
abs_ord in (2, 1, float("inf")),
lambda: "linalg.matrix_norm: Order {ord} not supported.",
)
# dtype
check_fp_or_complex(
A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != 2
)
max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim)
if abs_ord == 2.0:
if dtype is not None:
A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment]
perm = _backshift_permutation(dim[0], dim[1], A.ndim)
result = max_min(svdvals(prims.transpose(A, perm)), dim=-1)
if keepdim:
inv_perm = _inverse_permutation(perm)
result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
return result
else: # 1, -1, inf, -inf
dim0, dim1 = dim
if abs_ord == float("inf"):
dim0, dim1 = dim1, dim0
if not keepdim and (dim0 < dim1):
dim1 -= 1
return max_min(
vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1
)
# CompositeImplicitAutograd
@out_wrapper(exact_dtype=True)
def norm(
A: TensorLikeType,
ord: Optional[Union[float, str]] = None,
dim: Optional[DimsType] = None,
keepdim: bool = False,
*,
dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
if dim is not None:
if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment]
torch._check(
len(dim) in (1, 2),
lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
)
elif ord is not None:
torch._check(
A.ndim in (1, 2),
lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
)
if ord is not None and (
(dim is not None and len(dim) == 2) or (dim is None and A.ndim == 2)
):
if dim is None:
dim = (0, 1)
return matrix_norm(A, ord, dim, keepdim, dtype=dtype)
else:
if ord is None:
ord = 2.0
return vector_norm(A, ord, dim, keepdim, dtype=dtype) # type: ignore[arg-type]
# CompositeImplicitAutograd
@out_wrapper("U", "S", "Vh", exact_dtype=True)
def svd(A: TensorLikeType, full_matrices: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
return prims.svd(A, full_matrices=full_matrices)
# CompositeImplicitAutograd
@out_wrapper(exact_dtype=True)
def svdvals(A: TensorLikeType) -> Tensor:
return svd(A, full_matrices=False)[1]
# CompositeImplicitAutograd
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("x", "y"),
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def vecdot(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
check_fp_or_complex(x.dtype, "linalg.vecdot")
return (x.conj() * y).sum(dim=dim)
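
A brief usage sketch of the matching public torch.linalg entry points; the numbers are small hand-checked illustrations, not taken from the diff:

import torch

v = torch.tensor([3.0, 4.0])
A = torch.tensor([[3.0, 0.0], [0.0, 4.0]])

assert torch.allclose(torch.linalg.vector_norm(v), torch.tensor(5.0))                    # default ord=2
assert torch.allclose(torch.linalg.vector_norm(v, ord=float("inf")), torch.tensor(4.0))  # max |v_i|
assert torch.allclose(torch.linalg.matrix_norm(A), torch.tensor(5.0))                    # Frobenius = sqrt(9 + 16)
assert torch.allclose(torch.linalg.matrix_norm(A, ord=2), torch.tensor(4.0))             # largest singular value
assert torch.allclose(torch.linalg.vecdot(v, v), torch.tensor(25.0))                     # conj(v) . v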


@@ -0,0 +1,4 @@
from typing import List
__all__: List[str] = []

File diff suppressed because it is too large


@@ -0,0 +1,236 @@
# mypy: allow-untyped-defs
import math
from typing import Optional, Union
import torch
import torch._prims as prims
import torch._prims_common as utils
import torch._refs as refs
from torch import Tensor
from torch._decomp import register_decomposition
from torch._prims_common import (
ELEMENTWISE_TYPE_PROMOTION_KIND,
Number,
NumberType,
TensorLike,
TensorLikeType,
)
from torch._prims_common.wrappers import elementwise_type_promotion_wrapper, out_wrapper
from torch._refs import (
_make_alias,
_make_elementwise_binary_reference,
_make_elementwise_unary_reference,
)
__all__ = [
"bessel_j0",
"bessel_j1",
"entr",
"erfcx",
"expit",
"i0e",
"i1",
"i1e",
"log_ndtr",
"logit",
"log_softmax",
"multigammaln",
"ndtr",
"ndtri",
"softmax",
"spherical_bessel_j0",
"xlog1py",
"zeta",
]
aten = torch._ops.ops.aten
@_make_elementwise_unary_reference(
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def bessel_j0(a: TensorLikeType) -> TensorLikeType:
return prims.bessel_j0(a)
@_make_elementwise_unary_reference(
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def bessel_j1(a: TensorLikeType) -> TensorLikeType:
return prims.bessel_j1(a)
@register_decomposition(aten.special_entr)
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def entr(a: TensorLikeType) -> TensorLikeType:
return torch.where(
torch.isnan(a),
a,
torch.where(a > 0, -a * torch.log(a), torch.where(a == 0, 0, -torch.inf)),
)
@register_decomposition(aten.special_erfcx)
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def erfcx(a: TensorLikeType) -> TensorLikeType:
return prims.erfcx(a)
# alias for sigmoid
expit = _make_alias(torch.sigmoid, "expit")
@_make_elementwise_unary_reference(
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def i0e(a: TensorLikeType) -> TensorLikeType:
return prims.bessel_i0e(a)
@_make_elementwise_unary_reference(
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def i1(a: TensorLikeType) -> TensorLikeType:
return prims.bessel_i1(a)
@_make_elementwise_unary_reference(
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def i1e(a: TensorLikeType) -> TensorLikeType:
return prims.bessel_i1e(a)
@register_decomposition(aten.special_log_ndtr)
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def log_ndtr(a: TensorLikeType) -> TensorLikeType:
# Note: M_SQRT1_2 is the value of 1 / sqrt(2)
M_SQRT1_2 = 0.707106781186547524400844362104849039
t = a * M_SQRT1_2
return torch.where(
a < 1.0,
torch.log(torch.special.erfcx(-t) / 2) - t * t,
torch.log1p(-torch.erfc(t) / 2),
)
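# In log_ndtr, ndtr(a) = erfc(-a / sqrt(2)) / 2, so the a < 1 branch evaluates
# log(erfc(-t) / 2) as log(erfcx(-t) / 2) - t*t, which stays finite where
# erfc(-t) itself underflows; for a >= 1, log1p(-erfc(t) / 2) is the
# better-conditioned form since erfc(t) / 2 is small.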
@register_decomposition(aten.logit)
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("self",),
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
if eps is None:
eps = -1.0
lo = eps
hi = 1 - eps
self = torch.clamp(self, lo, hi)
return torch.log(torch.true_divide(self, torch.sub(1, self)))
@register_decomposition(aten.special_xlog1py)
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("a", "b"),
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
torch._check(
isinstance(a, TensorLike) or isinstance(b, TensorLike),
lambda: "Expected either argument a or b to be a Tensor",
)
# Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
if isinstance(a, TensorLike) and isinstance(b, Number):
b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device)
elif isinstance(b, TensorLike) and isinstance(a, Number):
a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device)
# mypy: expected "Tensor"
assert isinstance(a, TensorLike)
assert isinstance(b, TensorLike)
rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log1p(b)))
return torch.where(torch.isnan(b), float("nan"), rhs)
@register_decomposition(aten.mvlgamma)
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType:
c = 0.25 * p * (p - 1) * math.log(math.pi)
b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device)
return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c
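# multigammaln evaluates the log multivariate gamma function
# log Gamma_p(a) = (p * (p - 1) / 4) * log(pi) + sum_{j=1}^{p} lgamma(a + (1 - j) / 2),
# with the sum over j vectorized through the offsets in `b`.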
@register_decomposition(aten.special_ndtr)
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def ndtr(a: TensorLikeType) -> TensorLikeType:
# Note: M_SQRT1_2 is the value of 1 / sqrt(2)
M_SQRT1_2 = 0.707106781186547524400844362104849039
a_sqrt_2 = a * M_SQRT1_2
return (1 + torch.erf(a_sqrt_2)) * 0.5
@register_decomposition(aten.special_ndtri)
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def ndtri(a: TensorLikeType) -> TensorLikeType:
return prims.ndtri(a)
# Forwarding alias: the special variant doesn't support the out kwarg
# CompositeImplicitAutograd - don't register decomp
def log_softmax(
a: TensorLikeType,
dim: int,
dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
# Forwarding alias: the special variant doesn't support the out kwarg
# CompositeImplicitAutograd - don't register decomp
def softmax(
a: TensorLikeType,
dim: int,
dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload]
@_make_elementwise_unary_reference(
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType:
return prims.spherical_bessel_j0(a)
# TODO: add docstring
@_make_elementwise_binary_reference(
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def zeta(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
return prims.zeta(a, b)
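
A short usage sketch of the matching public torch.special functions defined above; the outputs shown in comments are rounded and purely illustrative:

import torch

x = torch.tensor([0.0, 0.5, 1.0])
torch.special.entr(x)                          # tensor([0.0000, 0.3466, 0.0000]); -x*log(x), with entr(0) == 0

torch.special.xlog1py(torch.tensor([0.0, 2.0]), torch.tensor([5.0, 1.0]))
# tensor([0.0000, 1.3863]); x*log1p(y), with the x == 0 case defined as 0

torch.special.ndtr(torch.tensor([0.0, 1.0]))   # tensor([0.5000, 0.8413]); standard normal CDF

p = torch.tensor([0.25])
torch.special.expit(torch.special.logit(p))    # tensor([0.2500]); expit inverts logit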