I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@@ -0,0 +1,57 @@
from torch.masked._ops import (
_canonical_dim,
_combine_input_and_mask,
_generate_docstring,
_input_mask,
_output_mask,
_reduction_identity,
_where,
amax,
amin,
argmax,
argmin,
cumprod,
cumsum,
log_softmax,
logaddexp,
logsumexp,
mean,
median,
norm,
normalize,
prod,
softmax,
softmin,
std,
sum,
var,
)
from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor
from torch.masked.maskedtensor.creation import as_masked_tensor, masked_tensor
__all__ = [
"amax",
"amin",
"argmax",
"argmin",
"as_masked_tensor",
"cumprod",
"cumsum",
"is_masked_tensor",
"log_softmax",
"logaddexp",
"logsumexp",
"masked_tensor",
"MaskedTensor",
"mean",
"median",
"norm",
"normalize",
"prod",
"softmax",
"softmin",
"std",
"sum",
"var",
]
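
The re-exports above make up the public torch.masked surface. As a hedged usage sketch (the prototype API may change; behavior is assumed from the functions re-exported above):

# Minimal usage sketch of the torch.masked prototype API.
import torch
from torch.masked import masked_tensor

data = torch.arange(6, dtype=torch.float32).reshape(2, 3)
mask = torch.tensor([[True, False, True], [False, True, True]])

mt = masked_tensor(data, mask)                     # wrap data plus a boolean validity mask
row_sums = mt.sum(dim=1)                           # reductions only consider unmasked elements
means = torch.masked.mean(data, dim=1, mask=mask)  # functional form with an explicit mask kwarg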

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# flake8: noqa
from .binary import _apply_native_binary, _is_native_binary
from .core import is_masked_tensor, MaskedTensor
from .passthrough import _apply_pass_through_fn, _is_pass_through_fn
from .reductions import _apply_reduction, _is_reduction
from .unary import _apply_native_unary, _is_native_unary

View File

@@ -0,0 +1,531 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from functools import partial
from typing import Any, Callable, Dict, TYPE_CHECKING
import torch
from .binary import _apply_native_binary, NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS
from .core import (
_get_data,
_masks_match,
_maybe_get_mask,
is_masked_tensor,
MaskedTensor,
)
from .passthrough import _apply_pass_through_fn, PASSTHROUGH_FNS
from .reductions import (
_apply_reduction,
NATIVE_REDUCE_FNS,
TENSOR_REDUCE_FNS,
TORCH_REDUCE_FNS,
)
from .unary import _apply_native_unary, NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS
if TYPE_CHECKING:
from torch._ops import OpOverload
__all__ = [] # type: ignore[var-annotated]
def _check_args_kwargs_length(
args, kwargs, error_prefix, len_args=None, len_kwargs=None
):
if len_args is not None and len_args != len(args):
raise ValueError(
f"{error_prefix}: len(args) must be {len_args} but got {len(args)}"
)
if len_kwargs is not None and len_kwargs != len(kwargs):
raise ValueError(
f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}"
)
class _MaskedContiguous(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")
if input.is_contiguous():
return input
data = input.get_data()
mask = input.get_mask()
return MaskedTensor(data.contiguous(), mask.contiguous())
@staticmethod
def backward(ctx, grad_output):
return grad_output
class _MaskedToDense(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")
if input.layout == torch.strided:
return input
ctx.layout = input.layout
data = input.get_data()
mask = input.get_mask()
return MaskedTensor(data.to_dense(), mask.to_dense())
@staticmethod
def backward(ctx, grad_output):
layout = ctx.layout
if layout == torch.sparse_coo:
return grad_output.to_sparse_coo()
elif layout == torch.sparse_csr:
return grad_output.to_sparse_csr()
elif layout == torch.strided:
return grad_output.to_dense()
raise ValueError("to_dense: Unsupported input layout: ", layout)
class _MaskedToSparse(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")
        # Following the sparse tensor convention, to_sparse always converts to the sparse_coo layout
if input.layout == torch.sparse_coo:
return input
data = input.get_data()
mask = input.get_mask()
sparse_mask = mask.to_sparse_coo().coalesce()
sparse_data = data.sparse_mask(sparse_mask)
return MaskedTensor(sparse_data, sparse_mask)
@staticmethod
def backward(ctx, grad_output):
return grad_output.to_dense()
class _MaskedToSparseCsr(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")
if input._masked_data.ndim != 2:
raise ValueError(
f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}"
)
if input.layout == torch.sparse_csr:
return input
data = input.get_data()
mask = input.get_mask()
sparse_mask = mask.to_sparse_csr()
sparse_data = data.sparse_mask(sparse_mask)
return MaskedTensor(sparse_data, sparse_mask)
@staticmethod
def backward(ctx, grad_output):
return grad_output.to_dense()
class _MaskedWhere(torch.autograd.Function):
@staticmethod
def forward(ctx, cond, self, other):
ctx.mark_non_differentiable(cond)
ctx.save_for_backward(cond)
return torch.ops.aten.where(cond, self, other)
@staticmethod
def backward(ctx, grad_output):
(cond,) = ctx.saved_tensors
def masked_out_like(mt):
return MaskedTensor(mt.get_data(), torch.zeros_like(mt.get_mask()).bool())
return (
None,
torch.ops.aten.where(cond, grad_output, masked_out_like(grad_output)),
torch.ops.aten.where(cond, masked_out_like(grad_output), grad_output),
)
_MASKEDTENSOR_FUNCTION_TABLE = {}
_function_fn_apply_map = {
(
tuple(NATIVE_REDUCE_FNS),
tuple(TORCH_REDUCE_FNS),
tuple(TENSOR_REDUCE_FNS),
): _apply_reduction,
}
for fn_map_list, apply_fn in _function_fn_apply_map.items():
for fn_map in fn_map_list:
for fn in fn_map:
_MASKEDTENSOR_FUNCTION_TABLE[fn] = partial(apply_fn, fn)
def register_function_func(ops):
"""
Used for registering a new __torch_function__ function to MaskedTensor
Called via _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)
The code to register a new function looks like:
@register_function_func(list_of_ops)
def foo(func, *args, **kwargs):
<implementation>
"""
def wrapper(func):
for op in ops:
_MASKEDTENSOR_FUNCTION_TABLE[op] = partial(func, op)
return wrapper
@register_function_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
def _general_function_reductions(func, *args, **kwargs):
return _apply_reduction(func, *args, **kwargs)
@register_function_func([torch.Tensor.where, torch.where])
def _function_where(func, *args, **kwargs):
_check_args_kwargs_length(
args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0
)
return _MaskedWhere.apply(*args)
@register_function_func([torch.Tensor.contiguous])
def _function_contiguous(func, *args, **kwargs):
return _MaskedContiguous.apply(args[0])
@register_function_func([torch.Tensor.to_dense])
def _function_to_dense(func, *args, **kwargs):
return _MaskedToDense.apply(args[0])
@register_function_func([torch.Tensor.to_sparse])
def _function_to_sparse(func, *args, **kwargs):
return _MaskedToSparse.apply(args[0])
@register_function_func([torch.Tensor.to_sparse_csr])
def _function_to_sparse_csr(func, *args, **kwargs):
return _MaskedToSparseCsr.apply(args[0])
_MASKEDTENSOR_DISPATCH_TABLE: Dict["OpOverload", Callable[..., Any]] = {}
def register_dispatch_func(aten_ops):
"""
Used for registering a new __torch_dispatch__ function to MaskedTensor
Called via _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)
The code to register a new function looks like:
@register_dispatch_func(list_of_ops)
def foo(func, *args, **kwargs):
<implementation>
"""
def wrapper(func):
for aten_op in aten_ops:
_MASKEDTENSOR_DISPATCH_TABLE[aten_op] = partial(func, aten_op)
return wrapper
@register_dispatch_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
def _general_reduction(func, *args, **kwargs):
return _apply_reduction(func, *args, **kwargs)
@register_dispatch_func(PASSTHROUGH_FNS)
def _general_passthrough(func, *args, **kwargs):
return _apply_pass_through_fn(func, *args, **kwargs)
@register_dispatch_func(NATIVE_UNARY_FNS + NATIVE_INPLACE_UNARY_FNS)
def _general_unary(func, *args, **kwargs):
return _apply_native_unary(func, *args, **kwargs)
@register_dispatch_func(NATIVE_BINARY_FNS + NATIVE_INPLACE_BINARY_FNS)
def _general_binary(func, *args, **kwargs):
return _apply_native_binary(func, *args, **kwargs)
@register_dispatch_func([torch.ops.aten.stride])
def stride(func, *args, **kwargs):
return None
@register_dispatch_func([torch.ops.aten.sym_stride])
def sym_stride(func, *args, **kwargs):
return None
@register_dispatch_func([torch.ops.prim.layout])
def layout(func, *args, **kwargs):
return _get_data(args[0]).layout
@register_dispatch_func([torch.ops.aten.is_contiguous])
def is_contiguous(func, *args, **kwargs):
data = _get_data(args[0])
if data.is_sparse:
raise ValueError("MaskedTensors with sparse data do not have is_contiguous")
return func(data, *args[1:], **kwargs)
@register_dispatch_func([torch.ops.aten.is_strides_like_format])
def is_strides_like_format(func, *args, **kwargs):
data = _get_data(args[0])
if data.is_sparse:
raise ValueError(
"MaskedTensors with sparse data do not have is_strides_like_format"
)
return func(data, *args[1:], **kwargs)
@register_dispatch_func([torch.ops.aten.is_non_overlapping_and_dense])
def is_non_overlapping_and_dense(func, *args, **kwargs):
data = _get_data(args[0])
if data.is_sparse:
raise ValueError(
"MaskedTensors with sparse data do not have is_non_overlapping_and_dense"
)
return func(data, *args[1:], **kwargs)
@register_dispatch_func([torch.ops.aten.contiguous])
def contiguous(func, *args, **kwargs):
if _get_data(args[0]).is_sparse:
raise ValueError("MaskedTensors with sparse data do not have contiguous")
return _MaskedContiguous.apply(args[0])
@register_dispatch_func([torch.ops.aten.new_empty_strided])
def new_empty_strided(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3)
data = _get_data(args[0])
mask = _maybe_get_mask(args[0])
if tuple(args[1]) != tuple(data.size()):
raise ValueError(
f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()"
)
if tuple(args[2]) != tuple(data.stride()):
raise ValueError(
f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()"
)
return MaskedTensor(func(data, args[1], args[2], **kwargs), mask)
@register_dispatch_func([torch.ops.aten._local_scalar_dense])
def _local_scalar_dense(func, *args, **kwargs):
if not _maybe_get_mask(args[0]):
raise ValueError(f"__torch_dispatch__, {func}: expected a mask tensor")
return torch.ops.aten._local_scalar_dense(_get_data(args[0]))
@register_dispatch_func([torch.ops.aten.detach, torch.ops.aten.clone])
def _apply_fn_on_data(func, *args, **kwargs):
return MaskedTensor(func(_get_data(args[0])), _maybe_get_mask(args[0]))
@register_dispatch_func([torch.ops.aten._to_copy])
def _to_copy(func, *args, **kwargs):
new_data = func(_get_data(args[0]), *args[1:], **kwargs)
return MaskedTensor(new_data, _maybe_get_mask(args[0]))
@register_dispatch_func([torch.ops.aten._softmax])
def _softmax(func, *args, **kwargs):
_check_args_kwargs_length(
args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
)
data = _get_data(args[0])
mask = _maybe_get_mask(args[0])
result_data = torch.ops.aten._masked_softmax(data, ~mask, args[1], 2)
return MaskedTensor(result_data, mask)
@register_dispatch_func([torch.ops.aten.ones_like])
def ones_like(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1)
result_data = func(_get_data(args[0]), **kwargs)
return MaskedTensor(result_data, _maybe_get_mask(args[0]))
@register_dispatch_func([torch.ops.aten._softmax_backward_data])
def _softmax_backward_data(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4)
grad, output, dim, input_dtype = args
if is_masked_tensor(grad) and is_masked_tensor(output):
if not _masks_match(grad, output):
raise ValueError(
"__torch_dispatch__, {func}: expected the masks of grad and output to match"
)
grad_data = _get_data(grad)
new_grad_data = torch.ops.aten._masked_softmax_backward(
grad_data,
_get_data(output),
~_maybe_get_mask(grad),
dim % grad_data.ndim,
)
res = MaskedTensor(new_grad_data, _maybe_get_mask(grad))
return res
else:
raise ValueError(
f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors"
)
@register_dispatch_func([torch.ops.aten.copy_])
def copy_(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
if not _masks_match(_maybe_get_mask(args[0]), _maybe_get_mask(args[1])):
raise ValueError("args[0] mask and args[1] mask must match but do not")
func(_get_data(args[0]), _get_data(args[1]))
return args[0]
@register_dispatch_func([torch.ops.aten.where])
def where(func, *args, **kwargs):
_check_args_kwargs_length(
args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0
)
if not torch.is_tensor(args[0]):
raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
mx = args[1]
my = args[2]
if not is_masked_tensor(mx):
mx = MaskedTensor(mx, torch.ones_like(mx, dtype=torch.bool))
if not is_masked_tensor(my):
my = MaskedTensor(my, torch.ones_like(my, dtype=torch.bool))
new_data = func(args[0], mx.get_data(), my.get_data())
new_mask = func(args[0], mx.get_mask(), my.get_mask())
return MaskedTensor(new_data, new_mask)
@register_dispatch_func([torch.ops.aten._to_sparse])
def _to_sparse(func, *args, **kwargs):
_check_args_kwargs_length(
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
)
if not torch.is_tensor(args[0]):
raise TypeError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
mt = args[0]
if not is_masked_tensor(mt):
mt = MaskedTensor(mt, torch.ones_like(mt, dtype=torch.bool))
if mt.is_sparse_coo():
return mt
new_mask = func(_maybe_get_mask(args[0])).coalesce()
new_data = _get_data(args[0]).sparse_mask(new_mask)
return MaskedTensor(new_data, new_mask)
@register_dispatch_func([torch.ops.aten._to_sparse_csr])
def _to_sparse_csr(func, *args, **kwargs):
_check_args_kwargs_length(
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
)
if not torch.is_tensor(args[0]):
raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
mt = args[0]
if not is_masked_tensor(mt):
mt = MaskedTensor(mt, torch.ones_like(mt).bool())
if mt.is_sparse_csr():
return mt
new_mask = func(_maybe_get_mask(args[0]))
new_data = _get_data(args[0]).sparse_mask(new_mask)
return MaskedTensor(new_data, new_mask)
@register_dispatch_func([torch.ops.aten._to_dense])
def _to_dense(func, *args, **kwargs):
_check_args_kwargs_length(
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
)
if not torch.is_tensor(args[0]):
raise ValueError("__torch_dispatch__, {func}: expected args[0] to be a tensor")
mt = args[0]
if not is_masked_tensor(mt):
mt = MaskedTensor(mt, torch.ones_like(mt).bool())
new_data = func(_get_data(args[0]))
new_mask = func(_maybe_get_mask(args[0]))
return MaskedTensor(new_data, new_mask)
@register_dispatch_func([torch.ops.aten._indices])
def _indices(func, *args, **kwargs):
# Assumes data is sparse
_check_args_kwargs_length(
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
)
data = _get_data(args[0]).indices()
return MaskedTensor(data, torch.ones_like(data).bool())
@register_dispatch_func([torch.ops.aten._values])
def _values(func, *args, **kwargs):
_check_args_kwargs_length(
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
)
data = _get_data(args[0]).values()
return MaskedTensor(data, torch.ones_like(data).bool())
@register_dispatch_func([torch.ops.aten._sparse_coo_tensor_with_dims_and_tensors])
def _sparse_coo_tensor_with_dims_and_tensors(func, *args, **kwargs):
new_args = list(args)
if is_masked_tensor(args[-1]):
new_args[-1] = args[-1].get_data()
if is_masked_tensor(args[-2]):
new_args[-2] = args[-2].get_data()
new_data = func(*new_args, **kwargs)
new_args[-1] = torch.ones_like(new_args[-1])
new_mask = func(*new_args, **kwargs).bool()
return MaskedTensor(new_data, new_mask)
@register_dispatch_func([torch.ops.aten.is_same_size])
def is_same_size(func, *args, **kwargs):
_check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
return _get_data(args[0]).is_same_size(_get_data(args[1]))
@register_dispatch_func([torch.ops.aten._is_any_true])
def _is_any_true(func, *args, **kwargs):
_check_args_kwargs_length(
args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0
)
data = _get_data(args[0])
mask = _maybe_get_mask(args[0])
if mask is None:
raise ValueError(
f"__torch_dispatch__, {func}: expected args[0] to be a MaskedTensor"
)
if data.dtype != torch.bool:
raise ValueError(f"__torch_dispatch__, {func}: expected a boolean tensor")
if data.is_sparse:
raise ValueError(f"MaskedTensors with sparse data do not have {func}")
return MaskedTensor(func(data & mask), torch.tensor(True))

View File

@@ -0,0 +1,199 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
from .core import (
_map_mt_args_kwargs,
_masks_match,
_tensors_match,
_wrap_result,
is_masked_tensor,
)
__all__ = [] # type: ignore[var-annotated]
BINARY_NAMES = [
"add",
"atan2",
"arctan2",
"bitwise_and",
"bitwise_or",
"bitwise_xor",
"bitwise_left_shift",
"bitwise_right_shift",
"div",
"divide",
"floor_divide",
"fmod",
"logaddexp",
"logaddexp2",
"mul",
"multiply",
"nextafter",
"remainder",
"sub",
"subtract",
"true_divide",
"eq",
"ne",
"le",
"ge",
"greater",
"greater_equal",
"gt",
"less_equal",
"lt",
"less",
"maximum",
"minimum",
"fmax",
"fmin",
"not_equal",
]
INPLACE_BINARY_NAMES = [
n + "_"
for n in (
list(
set(BINARY_NAMES)
- {
"logaddexp",
"logaddexp2",
"equal",
"fmin",
"minimum",
"maximum",
"fmax",
}
)
)
]
def _get_at_least_one_mask(a, b):
if not is_masked_tensor(a) and not is_masked_tensor(b):
raise TypeError("At least one of `a` and `b` must be a MaskedTensor")
if not _masks_match(a, b):
raise ValueError("a and b must have matching masks")
if is_masked_tensor(a):
return a.get_mask()
return b.get_mask()
def _binary_helper(fn, args, kwargs, inplace):
if len(kwargs) != 0:
raise ValueError("len(kwargs) must equal 0")
for a in args[2:]:
if torch.is_tensor(a):
raise TypeError(
"MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs"
)
if not _masks_match(*args[:2]):
raise ValueError(
"Input masks must match. If you need support for this, please open an issue on Github."
)
data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
args0_layout = data_args[0].layout
same_layout = (
torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])
) and (args0_layout == data_args[1].layout)
if args0_layout == torch.sparse_coo:
if same_layout:
if not _tensors_match(data_args[0].indices(), data_args[1].indices()):
raise ValueError(
"sparse_coo indices must match. If you need support for this, please open an issue on Github."
)
if data_args[0].size() != data_args[1].size():
raise ValueError(
"input1 and input2 must have the same size for binary functions."
)
data_args[1] = data_args[1].values()
i = data_args[0].indices()
size = data_args[0].size()
data_args[0] = data_args[0].values()
v = fn(*data_args)
result_data = torch.sparse_coo_tensor(i, v, size)
elif args0_layout == torch.sparse_csr:
if same_layout:
if not (
_tensors_match(data_args[0].crow_indices(), data_args[1].crow_indices())
and _tensors_match(
data_args[0].col_indices(), data_args[1].col_indices()
)
):
raise ValueError(
"sparse_csr indices must match. If you need support for this, please open an issue on Github."
)
data_args[1] = data_args[1].values()
crow = data_args[0].crow_indices()
col = data_args[0].col_indices()
data_args[0] = data_args[0].values()
v = fn(*data_args)
result_data = torch.sparse_csr_tensor(crow, col, v)
else:
result_data = fn(*data_args)
if inplace:
args[0]._set_data_mask(result_data, mask_args[0])
return args[0]
else:
result_mask = _get_at_least_one_mask(*args[:2])
# sparse tensors don't have strides so we can only expand if the layout is strided
if args0_layout == torch.strided:
result_mask = result_mask.expand_as(result_data)
return _wrap_result(result_data, result_mask)
def _torch_binary(fn_name):
fn = getattr(torch.ops.aten, fn_name)
def binary_fn(*args, **kwargs):
return _binary_helper(fn, args, kwargs, inplace=False)
return binary_fn
def _torch_inplace_binary(fn_name):
fn = getattr(torch.ops.aten, fn_name)
def binary_fn(*args, **kwargs):
return _binary_helper(fn, args, kwargs, inplace=True)
return binary_fn
NATIVE_BINARY_MAP = {
getattr(torch.ops.aten, name): _torch_binary(name) for name in BINARY_NAMES
}
NATIVE_INPLACE_BINARY_MAP = {
getattr(torch.ops.aten, name): _torch_inplace_binary(name)
for name in INPLACE_BINARY_NAMES
}
NATIVE_BINARY_FNS = list(NATIVE_BINARY_MAP.keys())
NATIVE_INPLACE_BINARY_FNS = list(NATIVE_INPLACE_BINARY_MAP.keys())
def _is_native_binary(fn):
return fn in NATIVE_BINARY_FNS or fn in NATIVE_INPLACE_BINARY_FNS
def _apply_native_binary(fn, *args, **kwargs):
if fn in NATIVE_BINARY_FNS:
return NATIVE_BINARY_MAP[fn](*args, **kwargs)
if fn in NATIVE_INPLACE_BINARY_FNS:
return NATIVE_INPLACE_BINARY_MAP[fn](*args, **kwargs)
return NotImplemented
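
As a hedged usage sketch of the binary path above: both operands must carry matching masks, the underlying op runs on the data only, and one of the (identical) masks is re-attached to the result:

# Minimal sketch; mismatched masks raise a ValueError in _binary_helper.
import torch
from torch.masked import masked_tensor

mask = torch.tensor([True, False, True])
a = masked_tensor(torch.tensor([1.0, 2.0, 3.0]), mask)
b = masked_tensor(torch.tensor([10.0, 20.0, 30.0]), mask)

c = a + b  # dispatches to aten.add via _apply_native_binary; c reuses the shared mask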

View File

@@ -0,0 +1,359 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
from typing import Any
from typing_extensions import TypeGuard
import torch
from torch.overrides import get_default_nowrap_functions
__all__ = [
"MaskedTensor",
"is_masked_tensor",
]
def is_masked_tensor(obj: Any, /) -> TypeGuard["MaskedTensor"]:
r"""Returns True if the input is a MaskedTensor, else False
Args:
        obj: any input
Examples:
>>> # xdoctest: +SKIP
>>> from torch.masked import MaskedTensor
>>> data = torch.arange(6).reshape(2,3)
>>> mask = torch.tensor([[True, False, False], [True, True, False]])
>>> mt = MaskedTensor(data, mask)
>>> is_masked_tensor(mt)
True
"""
return isinstance(obj, MaskedTensor)
def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08):
if is_masked_tensor(a) or is_masked_tensor(b):
raise ValueError("Neither `a` nor `b` can be a MaskedTensor.")
if a.layout != b.layout:
raise ValueError(
f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}"
)
if a.dtype != b.dtype:
b = b.type(a.dtype)
if a.layout == b.layout == torch.sparse_coo:
return _tensors_match(a.values(), b.values(), exact) and _tensors_match(
a.indices(), b.indices(), exact
)
elif a.layout == b.layout == torch.sparse_csr:
return (
_tensors_match(a.crow_indices(), b.crow_indices(), exact)
and _tensors_match(a.col_indices(), b.col_indices(), exact)
and _tensors_match(a.values(), b.values(), exact)
)
if exact:
return (a.dim() == b.dim()) and torch.eq(a, b).all().item()
return (a.dim() == b.dim()) and torch.allclose(a, b, rtol=rtol, atol=atol)
def _masks_match(a, b):
if is_masked_tensor(a) and is_masked_tensor(b):
mask_a = a.get_mask()
mask_b = b.get_mask()
return _tensors_match(mask_a, mask_b, exact=True)
return True
def _map_mt_args_kwargs(args, kwargs, map_fn):
def _helper(a, map_fn):
if is_masked_tensor(a):
return map_fn(a)
elif torch.is_tensor(a):
return a
elif isinstance(a, list):
a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
return a_impl
elif isinstance(a, tuple):
a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
return tuple(a_impl)
else:
return a
if kwargs is None:
kwargs = {}
impl_args = []
for a in args:
impl_args.append(_helper(a, map_fn))
impl_kwargs = {}
for k in kwargs.keys():
        impl_kwargs[k] = _helper(kwargs[k], map_fn)
return impl_args, impl_kwargs
def _wrap_result(result_data, result_mask):
if isinstance(result_data, list):
return [_wrap_result(r, m) for (r, m) in zip(result_data, result_mask)]
if isinstance(result_data, tuple):
return tuple(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask))
if torch.is_tensor(result_data):
return MaskedTensor(result_data, result_mask)
# Expect result_data and result_mask to be Tensors only
return NotImplemented
def _masked_tensor_str(data, mask, formatter):
if data.layout in {torch.sparse_coo, torch.sparse_csr}:
data = data.to_dense()
mask = mask.to_dense()
if data.dim() == 1:
formatted_elements = [
formatter.format(d.item()) if isinstance(d.item(), float) else str(d.item())
for d in data
]
max_len = max(8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask))
return (
"["
+ ", ".join(
[
"--".rjust(max_len) if m else e
for (e, m) in zip(formatted_elements, ~mask)
]
)
+ "]"
)
sub_strings = [_masked_tensor_str(d, m, formatter) for (d, m) in zip(data, mask)]
sub_strings = ["\n".join([" " + si for si in s.split("\n")]) for s in sub_strings]
return "[\n" + ",\n".join(sub_strings) + "\n]"
def _get_data(a):
if is_masked_tensor(a):
return a._masked_data
return a
def _maybe_get_mask(a):
if is_masked_tensor(a):
return a.get_mask()
return None
class MaskedTensor(torch.Tensor):
@staticmethod
def __new__(cls, data, mask, requires_grad=False):
if is_masked_tensor(data) or not torch.is_tensor(data):
raise TypeError("data must be a Tensor")
if is_masked_tensor(mask) or not torch.is_tensor(mask):
raise TypeError("mask must be a Tensor")
        # Use a Tensor of the given size for the wrapper.
kwargs = {
"device": data.device,
"dtype": data.dtype,
"layout": data.layout,
"requires_grad": requires_grad,
"dispatch_sizes_strides_policy": "strides",
"dispatch_layout": True,
}
warnings.warn(
(
"The PyTorch API of MaskedTensors is in prototype stage "
"and will change in the near future. Please open a Github issue "
"for features requests and see our documentation on the torch.masked "
"module for further information about the project."
),
UserWarning,
stacklevel=2,
)
if data.requires_grad:
warnings.warn(
"It is not recommended to create a MaskedTensor with a tensor that requires_grad. "
"To avoid this, you can use data.clone().detach()",
UserWarning,
stacklevel=2,
)
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined]
def _preprocess_data(self, data, mask):
from .._ops import _sparse_coo_where, _sparse_csr_where
if data.layout != mask.layout:
raise TypeError("data and mask must have the same layout.")
if data.layout == torch.sparse_coo:
data = data.coalesce()
mask = mask.coalesce()
if data._nnz() != mask._nnz():
data = _sparse_coo_where(mask, data, torch.tensor(0))
elif data.layout == torch.sparse_csr:
if data._nnz() != mask._nnz():
data = _sparse_csr_where(mask, data, torch.tensor(0))
# Have to pick awkward names to not conflict with existing fields such as data
self._masked_data = data.clone()
self._masked_mask = mask.clone()
def _validate_members(self):
data = self._masked_data
mask = self.get_mask()
if type(data) != type(mask):
raise TypeError(
f"data and mask must have the same type. Got {type(data)} and {type(mask)}"
)
if data.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
raise TypeError(f"data layout of {data.layout} is not supported.")
if data.layout == torch.sparse_coo:
if not _tensors_match(data.indices(), mask.indices(), exact=True):
raise ValueError(
"data and mask are both sparse COO tensors but do not have the same indices."
)
elif data.layout == torch.sparse_csr:
if not _tensors_match(
data.crow_indices(), mask.crow_indices(), exact=True
) or not _tensors_match(data.col_indices(), mask.col_indices(), exact=True):
raise ValueError(
"data and mask are both sparse CSR tensors but do not share either crow or col indices."
)
if mask.dtype != torch.bool:
raise TypeError("mask must have dtype bool.")
if not (
data.dtype == torch.float16
or data.dtype == torch.float32
or data.dtype == torch.float64
or data.dtype == torch.bool
or data.dtype == torch.int8
or data.dtype == torch.int16
or data.dtype == torch.int32
or data.dtype == torch.int64
):
raise TypeError(f"{data.dtype} is not supported in MaskedTensor.")
if data.dim() != mask.dim():
raise ValueError("data.dim() must equal mask.dim()")
if data.size() != mask.size():
raise ValueError("data.size() must equal mask.size()")
def __init__(self, data, mask, requires_grad=False):
self._preprocess_data(data, mask)
self._validate_members()
@staticmethod
def _from_values(data, mask):
"""Differentiable constructor for MaskedTensor"""
class Constructor(torch.autograd.Function):
@staticmethod
def forward(ctx, data, mask):
return MaskedTensor(data, mask)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
result = Constructor.apply(data, mask)
return result
def _set_data_mask(self, data, mask):
self._masked_data = data
self._masked_mask = mask
self._validate_members()
def __repr__(self):
formatter = "{0:8.4f}"
if self.dim() == 0:
scalar_data = self.get_data().item()
data_formatted = (
formatter.format(scalar_data)
if isinstance(scalar_data, float)
else str(scalar_data)
)
if not self.get_mask().item():
data_formatted = "--"
return (
"MaskedTensor("
+ data_formatted
+ ", "
+ str(self.get_mask().item())
+ ")"
)
s = _masked_tensor_str(self.get_data(), self.get_mask(), formatter)
s = "\n".join(" " + si for si in s.split("\n"))
return "MaskedTensor(\n" + s + "\n)"
# Seems like this needs to be defined before torch_dispatch to work
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
from ._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE
if func in _MASKEDTENSOR_FUNCTION_TABLE:
return _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)
if not all(issubclass(cls, t) for t in types):
return NotImplemented
with torch._C.DisableTorchFunctionSubclass():
ret = func(*args, **kwargs)
if func in get_default_nowrap_functions():
return ret
else:
return torch._tensor._convert(ret, cls)
@classmethod
def unary(cls, fn, data, mask):
return MaskedTensor(fn(data), mask)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
func = func.overloadpacket
from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE
if func in _MASKEDTENSOR_DISPATCH_TABLE:
return _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)
msg = (
f"{func.__name__} is not implemented in __torch_dispatch__ for MaskedTensor.\n"
"If you would like this operator to be supported, please file an issue for a feature request at "
"https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
"In the case that the semantics for the operator are not trivial, it would be appreciated "
"to also include a proposal for the semantics."
)
warnings.warn(msg)
return NotImplemented
def __lt__(self, other):
if is_masked_tensor(other):
return MaskedTensor(self.get_data() < _get_data(other), self.get_mask())
return MaskedTensor(self.get_data() < other, self.get_mask())
def to_tensor(self, value):
return self.get_data().masked_fill(~self.get_mask(), value)
def get_data(self):
class GetData(torch.autograd.Function):
@staticmethod
def forward(ctx, self):
return self._masked_data
@staticmethod
def backward(ctx, grad_output):
if is_masked_tensor(grad_output):
return grad_output
return MaskedTensor(grad_output, self.get_mask())
return GetData.apply(self)
def get_mask(self):
return self._masked_mask
def is_sparse_coo(self):
return self.layout == torch.sparse_coo
def is_sparse_csr(self):
return self.layout == torch.sparse_csr
# Update later to support more sparse layouts
@property
def is_sparse(self):
return self.is_sparse_coo() or self.is_sparse_csr()
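
A hedged sketch of the core surface defined above: masked-out slots render as "--" in the repr, and to_tensor materializes them with a fill value:

# Minimal sketch of the MaskedTensor wrapper defined above.
import torch
from torch.masked import MaskedTensor

mt = MaskedTensor(torch.tensor([1.0, 2.0, 3.0]),
                  torch.tensor([True, False, True]))

print(mt)                  # masked-out entries print as "--" via _masked_tensor_str
dense = mt.to_tensor(0.0)  # fill masked-out slots with 0.0 -> plain torch.Tensor
valid = mt.get_mask()      # the boolean validity mask stored in _masked_mask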

View File

@@ -0,0 +1,23 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from .core import MaskedTensor
__all__ = [
"as_masked_tensor",
"masked_tensor",
]
# These two factory functions are intended to mirror
# torch.tensor - guaranteed to be a leaf node
# torch.as_tensor - differentiable constructor that preserves the autograd history
def masked_tensor(data, mask, requires_grad=False):
return MaskedTensor(data, mask, requires_grad)
def as_masked_tensor(data, mask):
return MaskedTensor._from_values(data, mask)
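
A hedged sketch of the distinction the comment above draws between the two factories:

# Minimal sketch; mirrors the torch.tensor vs. torch.as_tensor split described above.
import torch
from torch.masked import masked_tensor, as_masked_tensor

data = torch.arange(3.0)
mask = torch.tensor([True, True, False])

leaf = masked_tensor(data, mask)        # like torch.tensor: the result is a new leaf node
tracked = as_masked_tensor(data, mask)  # like torch.as_tensor: built through a differentiable
                                        # constructor, so any autograd history on data is kept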

View File

@@ -0,0 +1,50 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
"""
These are functions that should simply be applied to both mask and data.
Take select or stack as an example. This operation can be applied to
both the mask and data of a MaskedTensor and the result wrapped into
a new MaskedTensor as a result.
"""
import torch
from .core import _map_mt_args_kwargs, _wrap_result
__all__ = [] # type: ignore[var-annotated]
PASSTHROUGH_FNS = [
torch.ops.aten.select,
torch.ops.aten.transpose,
torch.ops.aten.split,
torch.ops.aten.t,
torch.ops.aten.slice,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.index,
torch.ops.aten.expand,
torch.ops.aten.view,
torch.ops.aten._unsafe_view,
torch.ops.aten._reshape_alias,
torch.ops.aten.cat,
torch.ops.aten.unsqueeze,
torch.ops.aten.unfold,
torch.ops.aten.unfold_backward,
torch.ops.aten.im2col,
torch.ops.aten.col2im,
torch.ops.aten.stack,
]
def _is_pass_through_fn(fn):
return fn in PASSTHROUGH_FNS
def _apply_pass_through_fn(fn, *args, **kwargs):
data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
result_data = fn(*data_args, **data_kwargs)
mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
result_mask = fn(*mask_args, **mask_kwargs)
return _wrap_result(result_data, result_mask)
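
A hedged sketch of the passthrough behavior: a listed op is applied to the data and the mask independently and the pair is rewrapped:

# Minimal sketch; aten.transpose is in PASSTHROUGH_FNS above.
import torch
from torch.masked import masked_tensor

mt = masked_tensor(torch.arange(6.0).reshape(2, 3),
                   torch.tensor([[True, False, True], [False, True, True]]))

t = mt.transpose(0, 1)               # transpose is applied to data and mask separately
assert t.get_mask().shape == (3, 2)  # the mask follows the data's new shape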

View File

@@ -0,0 +1,176 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
import torch
from .core import is_masked_tensor
from .creation import as_masked_tensor, masked_tensor
__all__ = [] # type: ignore[var-annotated]
def _masked_all_all(data, mask=None):
if mask is None:
return data.all()
return data.masked_fill(~mask, True).all()
def _masked_all_dim(data, dim, keepdim=False, mask=None):
if mask is None:
return torch.all(data, dim=dim, keepdim=keepdim)
return torch.all(data.masked_fill(~mask, True), dim=dim, keepdim=keepdim)
def _masked_all(*args, **kwargs):
if len(args) == 1 and len(kwargs) == 1:
return _masked_all_all(args[0], mask=kwargs["mask"])
return _masked_all_dim(*args, **kwargs)
def _multidim_any(mask, dim, keepdim):
if isinstance(dim, int):
return _multidim_any(mask, [dim], keepdim)
for d in sorted(dim, reverse=True):
mask = torch.any(mask, dim=d, keepdim=keepdim)
return mask
def _get_masked_fn(fn):
if fn == "all":
return _masked_all
return getattr(torch.masked, fn)
def _torch_reduce_all(fn):
def reduce_all(self):
masked_fn = _get_masked_fn(fn)
data = self.get_data()
mask = self.get_mask().values() if self.is_sparse else self.get_mask()
# When reduction is "all", then torch.argmin/torch.argmax needs to return the index of the
# element corresponding to the min/max, but this operation isn't supported correctly for sparse layouts.
# Therefore, this implementation calculates it using the strides.
if fn == "all":
result_data = masked_fn(data, mask=mask)
elif fn in {"argmin", "argmax"} and self.is_sparse_coo():
sparse_idx = masked_fn(data.values(), mask=mask).to(dtype=torch.int)
indices = (
data.to_sparse_coo().indices()
if not self.is_sparse_coo()
else data.indices()
)
idx = indices.unbind(1)[sparse_idx]
stride = data.size().numel() / torch.tensor(
data.size(), device=data.device
).cumprod(0)
result_data = torch.sum(idx * stride)
# we simply pass in the values for sparse COO/CSR tensors
elif self.is_sparse:
result_data = masked_fn(masked_tensor(data.values(), mask))
else:
result_data = masked_fn(self, mask=mask)
return as_masked_tensor(result_data, torch.any(mask))
return reduce_all
def _torch_reduce_dim(fn):
def reduce_dim(self, dim, keepdim=False, dtype=None):
if self.is_sparse:
msg = (
f"The sparse version of {fn} is not implemented in reductions.\n"
"If you would like this operator to be supported, please file an issue for a feature request at "
"https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
"In the case that the semantics for the operator are not trivial, it would be appreciated "
"to also include a proposal for the semantics."
)
warnings.warn(msg)
return NotImplemented
if not is_masked_tensor(self):
raise TypeError("Input to reduce_dim must be a MaskedTensor")
masked_fn = _get_masked_fn(fn)
data = self.get_data()
mask = self.get_mask()
if fn == "all":
result_data = masked_fn(data, dim=dim, keepdim=keepdim, mask=mask)
else:
result_data = masked_fn(
self, dim=dim, keepdim=keepdim, dtype=dtype, mask=self.get_mask()
)
return as_masked_tensor(result_data, _multidim_any(mask, dim, keepdim))
return reduce_dim
def _torch_reduce(fn):
def reduce_fn(*args, **kwargs):
if len(args) == 1 and len(kwargs) == 0:
return _torch_reduce_all(fn)(args[0])
return _torch_reduce_dim(fn)(*args, **kwargs)
return reduce_fn
def _reduce_dim_args(input, dim, keepdim=False, dtype=None):
return input, dim, keepdim, dtype
def _torch_grad_reduce(fn):
def grad_reduce(*args, **kwargs):
if len(args) == 1 and len(kwargs) == 0:
return _torch_reduce_all(fn)(args[0])
        # TODO: autograd.Function doesn't support kwargs
input, dim, keepdim, dtype = _reduce_dim_args(*args, **kwargs)
return _torch_reduce_dim(fn)(input, dim, keepdim, dtype)
return grad_reduce
REDUCE_NAMES = [
"sum",
"mean",
"amin",
"amax",
"argmin",
"argmax",
"prod",
"all",
"norm",
"var",
"std",
]
NATIVE_REDUCE_MAP = {
getattr(torch.ops.aten, name): _torch_reduce(name) for name in REDUCE_NAMES
}
TORCH_REDUCE_MAP = {
getattr(torch, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
}
TENSOR_REDUCE_MAP = {
getattr(torch.Tensor, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
}
NATIVE_REDUCE_FNS = list(NATIVE_REDUCE_MAP.keys())
TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys())
TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys())
def _is_reduction(fn):
return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP
def _apply_reduction(fn, *args, **kwargs):
if fn in NATIVE_REDUCE_MAP:
return NATIVE_REDUCE_MAP[fn](*args, **kwargs)
if fn in TORCH_REDUCE_MAP:
return TORCH_REDUCE_MAP[fn](*args, **kwargs)
if fn in TENSOR_REDUCE_MAP:
return TENSOR_REDUCE_MAP[fn](*args, **kwargs)
return NotImplemented
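
A hedged sketch of the dim-wise reduction pattern above: the data is reduced with the masked op, and the result mask is torch.any over the reduced dimension(s):

# Minimal sketch of reduce_dim; a fully masked row yields a masked-out result element.
import torch
from torch.masked import masked_tensor

mt = masked_tensor(torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
                   torch.tensor([[True, False], [False, False]]))

s = mt.sum(dim=1)  # data reduced via torch.masked.sum; mask becomes torch.any(mask, dim=1),
                   # so element 0 stays valid and element 1 is masked out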

View File

@@ -0,0 +1,190 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
from .core import _map_mt_args_kwargs, _wrap_result
__all__ = [] # type: ignore[var-annotated]
UNARY_NAMES = [
"abs",
"absolute",
"acos",
"arccos",
"acosh",
"arccosh",
"angle",
"asin",
"arcsin",
"asinh",
"arcsinh",
"atan",
"arctan",
"atanh",
"arctanh",
"bitwise_not",
"ceil",
"clamp",
"clip",
"conj_physical",
"cos",
"cosh",
"deg2rad",
"digamma",
"erf",
"erfc",
"erfinv",
"exp",
"exp2",
"expm1",
"fix",
"floor",
"frac",
"lgamma",
"log",
"log10",
"log1p",
"log2",
"logit",
"i0",
"isnan",
"nan_to_num",
"neg",
"negative",
"positive",
"pow",
"rad2deg",
"reciprocal",
"round",
"rsqrt",
"sigmoid",
"sign",
"sgn",
"signbit",
"sin",
"sinc",
"sinh",
"sqrt",
"square",
"tan",
"tanh",
"trunc",
]
INPLACE_UNARY_NAMES = [
n + "_"
for n in (list(set(UNARY_NAMES) - {"angle", "positive", "signbit", "isnan"}))
]
# Explicitly tracking functions we know are currently not supported
# This might be due to missing code gen or because of complex semantics
UNARY_NAMES_UNSUPPORTED = [
"atan2",
"arctan2",
"bitwise_left_shift",
"bitwise_right_shift",
"copysign",
"float_power",
"fmod",
"frexp",
"gradient",
"imag",
"ldexp",
"lerp",
"logical_not",
"hypot",
"igamma",
"igammac",
"mvlgamma",
"nextafter",
"polygamma",
"real",
"remainder",
"true_divide",
"xlogy",
]
def _unary_helper(fn, args, kwargs, inplace):
if len(kwargs) != 0:
raise ValueError(
"MaskedTensor unary ops require that len(kwargs) == 0. "
"If you need support for this, please open an issue on Github."
)
for a in args[1:]:
if torch.is_tensor(a):
raise TypeError(
"MaskedTensor unary ops do not support additional Tensor arguments"
)
mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_mask)
data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_data)
if args[0].layout == torch.sparse_coo:
data_args[0] = data_args[0].coalesce()
s = data_args[0].size()
i = data_args[0].indices()
        data_args[0] = data_args[0].values()
v = fn(*data_args)
result_data = torch.sparse_coo_tensor(i, v, size=s)
elif args[0].layout == torch.sparse_csr:
crow = data_args[0].crow_indices()
col = data_args[0].col_indices()
data_args[0] = data_args[0].values()
v = fn(*data_args)
result_data = torch.sparse_csr_tensor(crow, col, v)
else:
result_data = fn(*data_args)
if inplace:
args[0]._set_data_mask(result_data, mask_args[0])
return args[0]
else:
return _wrap_result(result_data, mask_args[0])
def _torch_unary(fn_name):
fn = getattr(torch.ops.aten, fn_name)
def unary_fn(*args, **kwargs):
return _unary_helper(fn, args, kwargs, inplace=False)
return unary_fn
def _torch_inplace_unary(fn_name):
fn = getattr(torch.ops.aten, fn_name)
def unary_fn(*args, **kwargs):
return _unary_helper(fn, args, kwargs, inplace=True)
return unary_fn
NATIVE_UNARY_MAP = {
getattr(torch.ops.aten, name): _torch_unary(name) for name in UNARY_NAMES
}
NATIVE_INPLACE_UNARY_MAP = {
getattr(torch.ops.aten, name): _torch_inplace_unary(name)
for name in INPLACE_UNARY_NAMES
}
NATIVE_UNARY_FNS = list(NATIVE_UNARY_MAP.keys())
NATIVE_INPLACE_UNARY_FNS = list(NATIVE_INPLACE_UNARY_MAP.keys())
def _is_native_unary(fn):
return fn in NATIVE_UNARY_FNS or fn in NATIVE_INPLACE_UNARY_FNS
def _apply_native_unary(fn, *args, **kwargs):
if fn in NATIVE_UNARY_FNS:
return NATIVE_UNARY_MAP[fn](*args, **kwargs)
if fn in NATIVE_INPLACE_UNARY_FNS:
return NATIVE_INPLACE_UNARY_MAP[fn](*args, **kwargs)
return NotImplemented
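
A hedged sketch of the unary path above: the op runs on the data (or its sparse values) and the input mask is carried through unchanged:

# Minimal sketch; abs/abs_ are in UNARY_NAMES / INPLACE_UNARY_NAMES above.
import torch
from torch.masked import masked_tensor

mt = masked_tensor(torch.tensor([-1.0, 2.0, -3.0]),
                   torch.tensor([True, False, True]))

out = torch.abs(mt)  # aten.abs is applied to the data; the mask is reused as-is
mt.abs_()            # in-place variant updates mt's data and returns mt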