I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,465 @@
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import SymInt, Tensor
from torch._C import _add_docstr, _nested # type: ignore[attr-defined]
from torch.types import _device as Device, _dtype as DType
__all__ = [
"to_padded_tensor",
"as_nested_tensor",
"nested_tensor",
"nested_tensor_from_jagged",
"narrow",
"masked_select",
]
# Nested Tensor constructor functions
def as_nested_tensor(
ts: Union[Tensor, List[Tensor], Tuple[Tensor, ...]],
dtype: Optional[DType] = None,
device: Optional[Device] = None,
layout=None
) -> Tensor:
r"""
Constructs a nested tensor preserving autograd history from a tensor or a list / tuple of
tensors.
If a nested tensor is passed, it will be returned directly unless the device / dtype / layout
differ. Note that converting device / dtype will result in a copy, while converting layout
is not currently supported by this function.
If a non-nested tensor is passed, it is treated as a batch of constituents of consistent size.
A copy will be incurred if the passed device / dtype differ from those of the input OR if
the input is non-contiguous. Otherwise, the input's storage will be used directly.
If a tensor list is provided, tensors in the list are always copied during construction of
the nested tensor.
Args:
ts (Tensor or List[Tensor] or Tuple[Tensor]): a tensor to treat as a nested tensor OR a
list / tuple of tensors with the same ndim
Keyword arguments:
dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor.
Default: if None, same :class:`torch.dtype` as leftmost tensor in the list.
device (:class:`torch.device`, optional): the desired device of returned nested tensor.
Default: if None, same :class:`torch.device` as leftmost tensor in the list
layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
Only strided and jagged layouts are supported. Default: if None, the strided layout.
Example::
>>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
>>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
>>> nt = torch.nested.as_nested_tensor([a, b])
>>> nt.is_leaf
False
>>> fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)])
>>> nt.backward(fake_grad)
>>> a.grad
tensor([1., 1., 1.])
>>> b.grad
tensor([0., 0., 0., 0., 0.])
>>> c = torch.randn(3, 5, requires_grad=True)
>>> nt2 = torch.nested.as_nested_tensor(c)
"""
is_tensor_list = isinstance(ts, (list, tuple)) and all(isinstance(t, Tensor) for t in ts)
if not isinstance(ts, Tensor) and not is_tensor_list:
raise TypeError(
"as_nested_tensor(): Expected first argument to be a tensor or a list / tuple of tensors "
)
# convert tuple -> list if needed
if is_tensor_list and not isinstance(ts, list):
ts = list(ts)
if isinstance(ts, Tensor) and ts.dim() < 2:
raise RuntimeError("as_nested_tensor(): Expected tensor argument to have dim() > 1")
if isinstance(ts, Tensor) and ts.is_nested:
if layout == ts.layout:
# return input directly or input copied to device / dtype
return ts.to(device=device, dtype=dtype)
else:
# TODO: Just use nt.to(layout=layout) when it exists.
raise RuntimeError(
"as_nested_tensor(): Converting between nested tensor layouts is not supported")
if layout is None:
layout = torch.strided
if layout == torch.strided:
if isinstance(ts, Tensor):
# contiguous() might be necessary to get flattened view.
# we could probably be more precise about when to do this as an optimization
buffer = ts.contiguous().view(-1).to(device=device, dtype=dtype)
nested_sizes = torch.tensor([t.shape for t in ts])
return torch._nested_view_from_buffer(
buffer,
nested_sizes,
*torch._nested_compute_contiguous_strides_offsets(nested_sizes))
else:
assert isinstance(ts, list)
return torch._nested_tensor_from_tensor_list(ts, dtype, None, device, None)
elif layout == torch.jagged:
if isinstance(ts, Tensor):
if device is None:
device = ts.device
# contiguous() might be necessary to get flattened view.
# we could probably be more precise about when to do this as an optimization
values = ts.contiguous().flatten(0, 1).to(device=device, dtype=dtype)
batch_size = ts.shape[0]
seq_len = ts.shape[1]
offsets = torch.arange(0, batch_size * seq_len + 1, seq_len,
device=device, dtype=torch.int64)
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
return nested_view_from_values_offsets(
values, offsets, min_seqlen=seq_len, max_seqlen=seq_len
)
else:
from torch.nested._internal.nested_tensor import jagged_from_list
assert isinstance(ts, list)
nt, _ = jagged_from_list(ts, offsets=None, device=device, dtype=dtype)
return nt
else:
raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")
# Note: This not only adds doc strings for the nested ops, but
# also connects the torch.nested Python namespace to the torch._C._nested builtins.
to_padded_tensor = _add_docstr(
_nested.nested_to_padded_tensor,
r"""
to_padded_tensor(input, padding, output_size=None, out=None) -> Tensor
Returns a new (non-nested) Tensor by padding the :attr:`input` nested tensor.
The leading entries will be filled with the nested data,
while the trailing entries will be padded.
.. warning::
:func:`to_padded_tensor` always copies the underlying data,
since the nested and the non-nested tensors differ in memory layout.
Args:
padding (float): The padding value for the trailing entries.
Keyword args:
output_size (Tuple[int]): The size of the output tensor.
If given, it must be large enough to contain all nested data;
        otherwise, it is inferred by taking the max size of each nested sub-tensor along each dimension.
out (Tensor, optional): the output tensor.
Example::
>>> nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))])
nested_tensor([
tensor([[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276],
[-1.9967, -1.0054, 1.8972, 0.9174, -1.4995]]),
tensor([[-1.8546, -0.7194, -0.2918, -0.1846],
[ 0.2773, 0.8793, -0.5183, -0.6447],
[ 1.8009, 1.8468, -0.9832, -1.5272]])
])
>>> pt_infer = torch.nested.to_padded_tensor(nt, 0.0)
tensor([[[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276],
[-1.9967, -1.0054, 1.8972, 0.9174, -1.4995],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[-1.8546, -0.7194, -0.2918, -0.1846, 0.0000],
[ 0.2773, 0.8793, -0.5183, -0.6447, 0.0000],
[ 1.8009, 1.8468, -0.9832, -1.5272, 0.0000]]])
>>> pt_large = torch.nested.to_padded_tensor(nt, 1.0, (2, 4, 6))
tensor([[[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276, 1.0000],
[-1.9967, -1.0054, 1.8972, 0.9174, -1.4995, 1.0000],
[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],
[[-1.8546, -0.7194, -0.2918, -0.1846, 1.0000, 1.0000],
[ 0.2773, 0.8793, -0.5183, -0.6447, 1.0000, 1.0000],
[ 1.8009, 1.8468, -0.9832, -1.5272, 1.0000, 1.0000],
[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]])
>>> pt_small = torch.nested.to_padded_tensor(nt, 2.0, (2, 2, 2))
RuntimeError: Value in output_size is less than NestedTensor padded size. Truncation is not supported.
""",
)
def nested_tensor(tensor_list, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor:
r"""
Constructs a nested tensor with no autograd history (also known as a "leaf tensor", see
    :ref:`Autograd mechanics <autograd-mechanics>`) from :attr:`tensor_list`, a list of tensors.
Args:
tensor_list (List[array_like]): a list of tensors, or anything that can be passed to torch.tensor,
where each element of the list has the same dimensionality.
Keyword arguments:
dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor.
Default: if None, same :class:`torch.dtype` as leftmost tensor in the list.
layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
Only strided and jagged layouts are supported. Default: if None, the strided layout.
device (:class:`torch.device`, optional): the desired device of returned nested tensor.
Default: if None, same :class:`torch.device` as leftmost tensor in the list
requires_grad (bool, optional): If autograd should record operations on the
returned nested tensor. Default: ``False``.
        pin_memory (bool, optional): If set, the returned nested tensor will be allocated in
            pinned memory. Works only for CPU tensors. Default: ``False``.
Example::
>>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
>>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
>>> nt = torch.nested.nested_tensor([a, b], requires_grad=True)
>>> nt.is_leaf
True
"""
if layout is None:
layout = torch.strided
if layout == torch.strided:
return _nested.nested_tensor(
tensor_list,
dtype=dtype,
device=device,
requires_grad=requires_grad,
pin_memory=pin_memory)
elif layout == torch.jagged:
# Need to wrap lists of scalars as tensors
list_of_tensors = [t if isinstance(t, Tensor) else torch.as_tensor(t) for t in tensor_list]
from torch.nested._internal.nested_tensor import jagged_from_list
with torch.no_grad():
nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype)
nt.requires_grad_(requires_grad)
if pin_memory:
nt = nt.pin_memory() # type: ignore[assignment]
return nt
else:
raise RuntimeError(f"Specified layout is unsupported for nested tensors: {layout}")
def narrow(tensor: Tensor, dim: int, start: Union[int, Tensor], length: Union[int, Tensor], layout=torch.strided) -> Tensor:
r"""
Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows
similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor
shows only the elements in the interval `[start, start+length)`. As nested representations
allow for a different `start` and `length` at each 'row' of that dimension, :attr:`start` and :attr:`length`
    can also be tensors of shape `(tensor.shape[0],)`.
    There are some differences depending on the layout used for the nested tensor. With the strided layout,
    torch.narrow copies the narrowed data into a contiguous NT with strided layout, while with the jagged
    layout, narrow() creates a non-contiguous view of the original strided tensor. This particular
    representation is really useful for representing kv-caches in Transformer models, as specialized
    SDPA kernels can deal with this format easily, resulting in performance improvements.
Args:
tensor (:class:`torch.Tensor`): a strided tensor, which will be used as the underlying data
for the nested tensor if using the jagged layout or will be copied for the strided layout.
dim (int): the dimension where narrow will be applied. Only `dim=1` is supported for the
            jagged layout, while the strided layout supports all dims.
start (Union[int, :class:`torch.Tensor`]): starting element for the narrow operation
length (Union[int, :class:`torch.Tensor`]): number of elements taken during the narrow op
Keyword arguments:
layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
Only strided and jagged layouts are supported. Default: if None, the strided layout.
Example::
>>> starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64)
>>> lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64)
>>> narrow_base = torch.randn(5, 10, 20)
>>> nt_narrowed = torch.nested.narrow(narrow_base, 1, starts, lengths, layout=torch.jagged)
>>> nt_narrowed.is_contiguous()
False
"""
if not isinstance(start, (int, SymInt, Tensor)):
raise RuntimeError("start must be an integer or a tensor")
if not isinstance(length, (int, SymInt, Tensor)):
raise RuntimeError("length must be an integer or a tensor")
if layout == torch.strided:
if isinstance(start, Tensor) or isinstance(length, Tensor):
raise RuntimeError("start and length must be integers for the strided layout NT impl")
# TODO: switch to as_nested_tensor(tensor) when it is available
nt = as_nested_tensor(torch.unbind(tensor), layout=torch.strided).narrow(dim, start, length)
elif layout == torch.jagged:
if dim != 1:
raise RuntimeError("jagged layout only supports dim=1")
from torch.nested._internal.nested_tensor import jagged_from_tensor_and_lengths
if isinstance(start, (int, SymInt)):
start = torch.tensor([start], device=tensor.device, dtype=torch.int64)
if isinstance(length, (int, SymInt)):
length = torch.tensor([length], device=tensor.device, dtype=torch.int64)
nt, _, _ = jagged_from_tensor_and_lengths(tensor, start, length)
else:
raise RuntimeError(f"Specified layout is unsupported for nested narrow: {layout}")
return nt
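# Rough sketch of the per-row semantics of the jagged path above (input values are
# hypothetical; only the shapes matter here):
#
#     >>> base = torch.randn(3, 10, 4)
#     >>> starts = torch.tensor([0, 2, 5])
#     >>> lengths = torch.tensor([4, 3, 5])
#     >>> nt = torch.nested.narrow(base, 1, starts, lengths, layout=torch.jagged)
#     >>> [t.shape[0] for t in nt.unbind()]  # each row keeps its own length
#     [4, 3, 5]
#     >>> nt.is_contiguous()  # a non-contiguous view over `base`
#     False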
def nested_tensor_from_jagged(
values: Tensor,
offsets: Optional[Tensor] = None,
lengths: Optional[Tensor] = None,
jagged_dim: Optional[int] = None,
min_seqlen: Optional[int] = None,
max_seqlen: Optional[int] = None,
) -> Tensor:
r"""
Constructs a jagged layout nested tensor from the given jagged components. The jagged layout
consists of a required values buffer with the jagged dimension packed into a single dimension.
The offsets / lengths metadata determines how this dimension is split into batch elements
and are expected to be allocated on the same device as the values buffer.
Expected metadata formats:
* offsets: Indices within the packed dimension splitting it into heterogeneously-sized
batch elements. Example: [0, 2, 3, 6] indicates that a packed jagged dim of size 6
should be conceptually split into batch elements of length [2, 1, 3]. Note that both the
beginning and ending offsets are required for kernel convenience (i.e. shape batch_size + 1).
* lengths: Lengths of the individual batch elements; shape == batch_size. Example: [2, 1, 3]
indicates that a packed jagged dim of size 6 should be conceptually split into batch
elements of length [2, 1, 3].
Note that it can be useful to provide both offsets and lengths. This describes a nested tensor
with "holes", where the offsets indicate the start position of each batch item and the length
specifies the total number of elements (see example below).
The returned jagged layout nested tensor will be a view of the input values tensor.
Args:
values (:class:`torch.Tensor`): The underlying buffer in the shape of
(sum_B(*), D_1, ..., D_N). The jagged dimension is packed into a single dimension,
with the offsets / lengths metadata used to distinguish batch elements.
offsets (optional :class:`torch.Tensor`): Offsets into the jagged dimension of shape B + 1.
lengths (optional :class:`torch.Tensor`): Lengths of the batch elements of shape B.
jagged_dim (optional int): Indicates which dimension in values is the packed jagged
dimension. If None, this is set to dim=1 (i.e. the dimension immediately following
the batch dimension). Default: None
min_seqlen (optional int): If set, uses the specified value as the cached minimum sequence
length for the returned nested tensor. This can be a useful alternative to computing
this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None
max_seqlen (optional int): If set, uses the specified value as the cached maximum sequence
length for the returned nested tensor. This can be a useful alternative to computing
this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None
Example::
>>> values = torch.randn(12, 5)
>>> offsets = torch.tensor([0, 3, 5, 6, 10, 12])
>>> nt = nested_tensor_from_jagged(values, offsets)
>>> # 3D shape with the middle dimension jagged
>>> nt.shape
torch.Size([5, j2, 5])
>>> # Length of each item in the batch:
>>> offsets.diff()
tensor([3, 2, 1, 4, 2])
>>> values = torch.randn(6, 5)
>>> offsets = torch.tensor([0, 2, 3, 6])
>>> lengths = torch.tensor([1, 1, 2])
>>> # NT with holes
>>> nt = nested_tensor_from_jagged(values, offsets, lengths)
>>> a, b, c = nt.unbind()
>>> # Batch item 1 consists of indices [0, 1)
>>> torch.equal(a, values[0:1, :])
True
>>> # Batch item 2 consists of indices [2, 3)
>>> torch.equal(b, values[2:3, :])
True
>>> # Batch item 3 consists of indices [3, 5)
>>> torch.equal(c, values[3:5, :])
True
"""
from torch.fx._symbolic_trace import is_fx_tracing
if is_fx_tracing():
raise RuntimeError(
"torch.nested.nested_tensor_from_jagged does not support tracing with fx.symbolic_trace. "
"Use fx.wrap to wrap the function that calls nested_tensor_from_jagged."
)
if offsets is None:
if lengths is None:
raise RuntimeError(
"nested_tensor_from_jagged(): At least one of offsets or lengths is required."
)
else:
# TODO: Truly support offsets=None at some point?
# For now, just convert lengths -> offsets for kernel convenience
offsets = F.pad(lengths.cumsum(0), (1, 0))
lengths = None
if jagged_dim is None:
jagged_dim = 1
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets_lengths
return nested_view_from_values_offsets_lengths(
values, offsets, lengths, ragged_idx=jagged_dim, min_seqlen=min_seqlen, max_seqlen=max_seqlen)
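# Worked sketch of the lengths-only path above (hypothetical values): lengths are
# converted to offsets by prepending a zero to their cumulative sum.
#
#     >>> lengths = torch.tensor([2, 1, 3])
#     >>> F.pad(lengths.cumsum(0), (1, 0))  # lengths -> offsets, as done above
#     tensor([0, 2, 3, 6])
#     >>> nt = torch.nested.nested_tensor_from_jagged(torch.randn(6, 8), lengths=lengths)
#     >>> nt.shape[0], nt.offsets()
#     (3, tensor([0, 2, 3, 6]))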
def masked_select(tensor: Tensor, mask: Tensor) -> Tensor:
r"""
    Constructs a nested tensor given a strided tensor input and a strided mask. The resulting jagged layout nested tensor
    retains values where the mask is equal to True. The dimensionality of the mask is preserved and is
    represented with the offsets; this is unlike :func:`masked_select` where the output is collapsed to a 1D tensor.
Args:
        tensor (:class:`torch.Tensor`): a strided tensor from which the jagged layout nested tensor is constructed.
mask (:class:`torch.Tensor`): a strided mask tensor which is applied to the tensor input
Example::
>>> tensor = torch.randn(3, 3)
>>> mask = torch.tensor([[False, False, True], [True, False, True], [False, False, True]])
>>> nt = torch.nested.masked_select(tensor, mask)
>>> nt.shape
torch.Size([3, j4])
>>> # Length of each item in the batch:
>>> nt.offsets().diff()
tensor([1, 2, 1])
>>> tensor = torch.randn(6, 5)
>>> mask = torch.tensor([False])
>>> nt = torch.nested.masked_select(tensor, mask)
>>> nt.shape
torch.Size([6, j5])
>>> # Length of each item in the batch:
>>> nt.offsets().diff()
tensor([0, 0, 0, 0, 0, 0])
"""
if tensor.layout != torch.strided:
raise RuntimeError(
f"torch.nested.masked_select requires a strided tensor, given {tensor.layout}"
)
if mask.layout != torch.strided:
raise RuntimeError(
f"torch.nested.masked_select requires a strided mask, given: {mask.layout}"
)
res_values = tensor.masked_select(mask)
expanded_mask = mask.expand(tensor.shape)
res_lengths = expanded_mask.sum(dim=tensor.ndim - 1).view(-1)
from torch.nested._internal.nested_tensor import (
nested_view_from_values_offsets,
)
return nested_view_from_values_offsets(
values=res_values,
offsets=F.pad(res_lengths.cumsum(dim=0), (1, 0)),
)
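# Sketch of how the offsets above are derived (hypothetical values): the mask is
# expanded to the input's shape, per-row True counts become the lengths, and the
# offsets are the zero-padded cumulative sum of those lengths.
#
#     >>> t = torch.randn(2, 4)
#     >>> mask = torch.tensor([[True, False, True, True], [False, False, True, False]])
#     >>> mask.sum(dim=-1)  # per-row lengths
#     tensor([3, 1])
#     >>> torch.nested.masked_select(t, mask).offsets()
#     tensor([0, 3, 4])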


@@ -0,0 +1,564 @@
# mypy: allow-untyped-defs
from typing import * # noqa: F403
from typing import Tuple
import torch
from torch._C import DispatchKey, DispatchKeySet
from torch._prims_common import is_expandable_to
from torch.utils.weak import WeakTensorKeyDictionary
_tensor_id_counter = 0
_tensor_symint_registry = WeakTensorKeyDictionary()
def get_tensor_symint(tensor, *, coeff=1):
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import mb_unwrap_functional_tensor
# NB: Only FakeTensor is associated with a memo
tensor = mb_unwrap_functional_tensor(tensor)
if isinstance(tensor, FakeTensor):
return tensor.get_nested_int(coeff=coeff)
global _tensor_id_counter
tensor_symint = _tensor_symint_registry.get(tensor)
if tensor_symint is None:
tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff)
_tensor_id_counter += 1
_tensor_symint_registry[tensor] = tensor_symint
return tensor_symint
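# Rough sketch of the memoization above (assumes eager, non-fake tensors; values are
# illustrative only):
#
#     >>> offsets = torch.tensor([0, 2, 5])
#     >>> s1 = get_tensor_symint(offsets)
#     >>> s2 = get_tensor_symint(offsets)  # served from _tensor_symint_registry
#     >>> s1 is s2
#     True
#     >>> get_tensor_symint(torch.tensor([0, 1, 4])) is s1  # new tensor -> new nested int
#     False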
# SDPA metadata; max / min seqlens are needed for e.g. flash
def _get_sdpa_extreme_seqlen(func, tensor):
return int(func(tensor).item())
def _store_val_in_tensor(val) -> torch.Tensor:
# hack to get dynamic shapes support: store in a (val, 0) shaped tensor
return torch.zeros(val, 0)
def _load_val_from_tensor(t: torch.Tensor):
return t.shape[0]
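# Worked sketch of the (val, 0)-shaped trick above: the value rides along in the shape,
# so no real storage is allocated and the size can be treated as dynamic by compile.
#
#     >>> t = _store_val_in_tensor(7)
#     >>> t.shape, t.numel()
#     (torch.Size([7, 0]), 0)
#     >>> _load_val_from_tensor(t)
#     7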
class NestedTensor(torch.Tensor):
_values: torch.Tensor # type: ignore[assignment]
_offsets: torch.Tensor
_lengths: Optional[torch.Tensor]
# NOTE [ Nested ints for ragged sizes and strides ]
#
    # Jagged layout tensors are tensors that represent an n-dim tensor with a
    # ragged dimension, but are backed by an (n-1)-dim tensor underneath, e.g.,
    # a jagged tensor with outer shape [B, x, D] is represented internally by a
    # tensor with shape [sum(x), D], where we introduce what we call a nested int
    # (denoted as "x" here, but sometimes denoted with "*") to represent the
    # ragged dimension, and sum(x) represents the dim of the inner tensor, or
    # equivalently the sum of all the sizes of the constituent tensors' varying
    # lengths.
#
# We also use nested ints to represent the strides of this tensor.
# For example, a jagged tensor with shape [B, x, D] can be strided in two
# ways: [xD, D, 1] and [x, 1, sum(x)], where xD represents x multiplied by D
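    #
    # Worked sketch for the note above (shapes are hypothetical): for values of shape
    # [sum(x), D] with strides [D, 1] and _ragged_idx == 1, __new__ below reports
    #     _size    == (B, x, D)
    #     _strides == (x * D, D, 1)
    # i.e. the nested int "x" stands in for the ragged dimension in both.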
_size: Tuple[int, ...]
_strides: Tuple[int, ...]
# Indicates that the nth dimension is ragged
_ragged_idx: int
_metadata_cache: Dict[str, Any]
@staticmethod
def __new__(
cls,
values,
offsets,
*,
lengths=None,
**kwargs,
):
ks = DispatchKeySet(DispatchKey.NestedTensor)
ks = ks.add(DispatchKey.AutogradNestedTensor)
# Only support jagged for now.
assert offsets is not None
assert offsets.ndim == 1
assert not isinstance(values, NestedTensor)
assert values.device == offsets.device
# Query cache for the symint associated with offsets or lengths
# (create a new one if needed).
ragged_source = offsets if lengths is None else lengths
ragged_size = get_tensor_symint(ragged_source, coeff=1)
_ragged_idx = kwargs.get("_ragged_idx", 1)
B = offsets.shape[0] - 1
if lengths is not None:
assert B == lengths.shape[0]
# subtract 1 to convert to values dim space
r = _ragged_idx - 1
_size = (B, *values.shape[:r], ragged_size, *values.shape[r + 1 :])
stride = values.stride()
_strides = (ragged_size * stride[r], *stride)
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
_size,
_strides,
0,
torch.contiguous_format,
values.dtype,
torch.jagged,
values.device,
False,
kwargs.get("requires_grad", False),
"sizes",
False,
True, # dispatch_layout
ks,
# don't try to calculate storage based on non-zero size
storage_size=values.untyped_storage().size(),
)
r._ragged_idx = _ragged_idx
r._size = _size
r._strides = _strides
return r
def __init__(self, values, offsets, *, lengths=None, **kwargs):
super().__init__()
self._values = values
self._offsets = offsets
self._lengths = lengths
# holds properties that are computed lazily
self._metadata_cache = kwargs.get("_metadata_cache") or {}
# collapsed ragged dim must always be dynamic
torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx)
torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1)
# min / max sequence length should be dynamic if present
max_seqlen_tensor = self._metadata_cache.get("max_seqlen", None)
if max_seqlen_tensor is not None:
torch._dynamo.mark_dynamic(max_seqlen_tensor, 0)
min_seqlen_tensor = self._metadata_cache.get("min_seqlen", None)
if min_seqlen_tensor is not None:
torch._dynamo.mark_dynamic(min_seqlen_tensor, 0)
def values(self):
# dispatch to get proper view relationship
return torch._nested_get_values(self) # type: ignore[attr-defined]
def offsets(self):
return self._offsets
def lengths(self):
return self._lengths
# Private accessor functions for min / max sequence length. They're
# purposefully not @properties because those don't work with PT2 (yet).
# These compute / cache if not present.
# TODO: Revisit this when @properties are better supported by PT2. I think the ideal
# state would be to have public @properties for min / max sequence length that compile
# (including setters).
def _get_max_seqlen(self):
max_seqlen_tensor = self._max_seqlen_tensor
if max_seqlen_tensor is None:
# compute & cache
max_val = _get_sdpa_extreme_seqlen(
torch.max,
self._offsets.diff() if self._lengths is None else self._lengths,
)
max_seqlen_tensor = _store_val_in_tensor(max_val)
self._metadata_cache["max_seqlen"] = max_seqlen_tensor
return _load_val_from_tensor(max_seqlen_tensor)
def _get_min_seqlen(self):
min_seqlen_tensor = self._min_seqlen_tensor
if min_seqlen_tensor is None:
# compute & cache
min_val = _get_sdpa_extreme_seqlen(
torch.min,
self._offsets.diff() if self._lengths is None else self._lengths,
)
min_seqlen_tensor = _store_val_in_tensor(min_val)
self._metadata_cache["min_seqlen"] = min_seqlen_tensor
return _load_val_from_tensor(min_seqlen_tensor)
# Private accessors used for treating min / max seqlen as inner tensors for
# flatten / unflatten. These must be properties to work with the traceable wrapper
# subclass logic. These do not compute / cache if not present.
@property
def _max_seqlen_tensor(self) -> Optional[torch.Tensor]:
return self._metadata_cache.get("max_seqlen", None)
@property
def _min_seqlen_tensor(self) -> Optional[torch.Tensor]:
return self._metadata_cache.get("min_seqlen", None)
# These are old private @property accessors that are kept around for internal BC
# reasons. TODO: Remove these!
@property
def _max_seqlen(self):
return self._get_max_seqlen()
@property
def _min_seqlen(self):
return self._get_min_seqlen()
def __repr__(self):
# We should implement this in torch/_tensor_str.py instead
grad_fn_str = (
f", requires_grad={self.requires_grad}" if self.requires_grad else ""
)
if self.grad_fn:
grad_fn_str = f", grad_fn={self.grad_fn}"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._lengths is None})"
def __reduce_ex__(self, proto):
state = torch._utils._get_obj_state(self)
# SymNodes are not serializable
assert "_size" in state and "_strides" in state
state = dict(state)
del state["_size"]
del state["_strides"]
# TODO: Update this to handle the other inner tensors
func = NestedTensor
args = (self._values, self._offsets)
return (torch._tensor._rebuild_from_type_v2, (func, type(self), args, state))
def __tensor_flatten__(self):
ctx = {
"requires_grad": self.requires_grad,
"ragged_idx": self._ragged_idx,
}
inner_tensors = ["_values", "_offsets"]
if self._lengths is not None:
inner_tensors.append("_lengths")
if self._min_seqlen_tensor is not None:
inner_tensors.append("_min_seqlen_tensor")
if self._max_seqlen_tensor is not None:
inner_tensors.append("_max_seqlen_tensor")
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
from torch._subclasses.fake_tensor import FakeTensor
# inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen]
assert len(inner_tensors) >= 2 and len(inner_tensors) <= 5
values = inner_tensors["_values"]
offsets = inner_tensors["_offsets"]
lengths = inner_tensors.get("_lengths", None)
min_seqlen_tensor = inner_tensors.get("_min_seqlen_tensor", None)
max_seqlen_tensor = inner_tensors.get("_max_seqlen_tensor", None)
metadata_cache = {}
if min_seqlen_tensor is not None:
metadata_cache["min_seqlen"] = min_seqlen_tensor
if max_seqlen_tensor is not None:
metadata_cache["max_seqlen"] = max_seqlen_tensor
ragged_idx = meta["ragged_idx"]
# Alternatively, we could make it the caller's responsibility to
# cache it. But this heuristic seems simple enough.
ragged_source = offsets if lengths is None else lengths
if isinstance(ragged_source, FakeTensor):
ragged_size = outer_size[ragged_idx]
ragged_source.nested_int_memo = ragged_size
return NestedTensor(
values,
offsets=offsets,
lengths=lengths,
requires_grad=meta["requires_grad"],
_ragged_idx=ragged_idx,
_metadata_cache=metadata_cache,
)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
# Lazy import to avoid circular dependency
from .ops import lookup_jagged
fn = lookup_jagged(func, *args, **kwargs)
if fn is not None:
return fn(*args, **kwargs)
raise NotImplementedError(func)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify
from .ops import jagged_torch_function
# This should be removed after
# https://github.com/pytorch/pytorch/pull/125941/ lands
with maybe_enable_thunkify():
try:
return jagged_torch_function(func, *args, **kwargs)
except NotImplementedError:
pass
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
# NB: These fake view autograd.Functions are superseded by real view ops. Don't use them!
# TODO: Remove ViewBufferFromNested, ViewNestedFromBuffer, and buffer_from_jagged once the
# internal BC period has passed.
# Not actually a view!
class ViewBufferFromNested(torch.autograd.Function):
@staticmethod
def forward(ctx, x: NestedTensor): # type: ignore[override]
ctx.save_for_backward(x.offsets())
ctx.metadata_cache = x._metadata_cache
ctx.ragged_idx = x._ragged_idx
return x._values
@staticmethod
def backward(ctx, gO: torch.Tensor): # type: ignore[override]
(offsets,) = ctx.saved_tensors
return NestedTensor(
gO,
offsets=offsets,
_metadata_cache=ctx.metadata_cache,
_ragged_idx=ctx.ragged_idx,
)
# Not actually a view!
class ViewNestedFromBuffer(torch.autograd.Function):
@staticmethod
def forward(
ctx,
values: torch.Tensor,
offsets: torch.Tensor,
metadata_cache: Optional[Dict[str, Any]] = None,
): # type: ignore[override]
        # maintain BC with usages of this where the seqlens are stuffed
# directly into the metadata cache as non-Tensors / ints
if metadata_cache is not None:
min_seqlen = metadata_cache.get("min_seqlen", None)
max_seqlen = metadata_cache.get("max_seqlen", None)
if min_seqlen is not None and not isinstance(min_seqlen, torch.Tensor):
metadata_cache["min_seqlen"] = _store_val_in_tensor(min_seqlen)
if max_seqlen is not None and not isinstance(max_seqlen, torch.Tensor):
metadata_cache["max_seqlen"] = _store_val_in_tensor(max_seqlen)
return NestedTensor(
values.detach(),
offsets=offsets,
_metadata_cache=metadata_cache,
)
@staticmethod
def backward(ctx, gO: NestedTensor): # type: ignore[override]
return gO._values, None, None
def buffer_from_jagged(jagged):
return ViewBufferFromNested.apply(jagged)
# Need to make it obvious that users should be passing in offsets
def jagged_from_list(
tensors: List[torch.Tensor],
offsets: Optional[torch.Tensor],
dtype=None,
device=None,
) -> Tuple[NestedTensor, torch.Tensor]:
"""Constructs a NestedTensor backed by jagged layout from a list of tensors"""
if not len(set(t.dtype for t in tensors)) == 1: # noqa: C401
raise RuntimeError(
"When constructing a nested tensor, all tensors in list must have the same dtype"
)
if not len(set(t.device for t in tensors)) == 1: # noqa: C401
raise RuntimeError(
"When constructing a nested tensor, all tensors in list must be on the same device"
)
# Check that the NT is representable by the jagged layout.
# Jagged layout represents (B, *, D_0, D_1, ..., D_N), where the only
# raggedness allowed is for the single dim immediately adjacent to the batch dim.
sizes = [t.shape for t in tensors]
non_first_sizes = [s[1:] for s in sizes]
at_most_first_ragged = all(s == non_first_sizes[0] for s in non_first_sizes)
if not at_most_first_ragged:
raise RuntimeError(
"Cannot represent given tensor list as a nested tensor with the jagged layout. "
"Note that the jagged layout only represents shapes of the form "
"(B, *, D_0, D_1, ..., D_N), with only * allowed to be ragged."
)
# Set properties appropriately.
values = torch.cat(tensors, dim=0)
to_kwargs = {}
if device is not None:
to_kwargs["device"] = device
if dtype is not None:
to_kwargs["dtype"] = dtype
values = values.to(**to_kwargs)
# Calculate jagged offsets if not provided.
if offsets is None:
# Jagged layout specifies that offsets are stored as int64 on the same device as values.
# TODO: An alternative way to construct offsets is to use F.pad. This avoids creating
# an extra leaf tensor during the forward, potentially resolving compatibility issues.
offsets = torch.cat(
[
torch.zeros(1, dtype=torch.int64, device=values.device),
torch.tensor([s[0] for s in sizes], device=values.device).cumsum(dim=0),
]
)
# compute this now since it's easy
min_seqlen = min(t.shape[0] for t in tensors)
max_seqlen = max(t.shape[0] for t in tensors)
ret_nt = nested_view_from_values_offsets(
values, offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen
)
return (ret_nt, offsets) # type: ignore[return-value]
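# Worked sketch of the construction above (hypothetical shapes): three constituents with
# lengths [2, 1, 3] are concatenated along dim 0 and described by padded-cumsum offsets.
#
#     >>> ts = [torch.randn(2, 4), torch.randn(1, 4), torch.randn(3, 4)]
#     >>> nt, offsets = jagged_from_list(ts, offsets=None)
#     >>> nt.values().shape, offsets
#     (torch.Size([6, 4]), tensor([0, 2, 3, 6]))
#     >>> nt._get_min_seqlen(), nt._get_max_seqlen()
#     (1, 3)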
def jagged_from_tensor_and_lengths(
tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
) -> Tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
"""Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths"""
batch_size = tensor.shape[0]
if is_expandable_to(starts.shape, (batch_size,)) and is_expandable_to(
lengths.shape, (batch_size,)
):
start_list = starts.expand(batch_size)
length_list = lengths.expand(batch_size)
else:
raise RuntimeError(
"When constructing a jagged nested tensor using narrow(), "
"your start and length must be Tensors that broadcast to input.shape[0]"
)
# Calculate jagged offsets
assert (
len(tensor.shape) >= 2
), "tensor must at least be 2D for the nested narrow op to work"
max_seq_len = tensor.shape[1]
offset_lengths = max_seq_len * torch.arange(
0, batch_size, dtype=torch.int64, device=tensor.device
)
# Jagged layout specifies that offsets are stored as int64 on the same device as values.
offsets = torch.cat(
[
start_list + offset_lengths,
(start_list[-1] + offset_lengths[-1] + length_list[-1]).unsqueeze(0),
]
)
# Reshape buffer to flatten the 1st and 2nd dimension (view used to enforce non-copy)
if len(tensor.shape) > 2:
values = tensor.view(-1, *tensor.shape[2:])
else:
values = tensor.view(-1)
# Check if offsets and lengths make it possibly contiguous and return a regular NT
is_contiguous = True
orig_dim = tensor.shape[1]
if torch.any(length_list[1:-1].ne(orig_dim)):
is_contiguous = False
if torch.any(offsets[1:-2].diff().ne(orig_dim)):
is_contiguous = False
if offsets[0] + length_list[0] != orig_dim:
is_contiguous = False
actual_max_seqlen = int(torch.max(lengths).item())
min_seqlen = int(torch.min(lengths).item())
if is_contiguous:
ret_nt = nested_view_from_values_offsets(
values[offsets[0] : offsets[-1]],
offsets - offsets[0],
min_seqlen=min_seqlen,
max_seqlen=actual_max_seqlen,
)
else:
ret_nt = nested_view_from_values_offsets_lengths(
values,
offsets,
length_list,
min_seqlen=min_seqlen,
max_seqlen=actual_max_seqlen,
)
return (ret_nt, offsets, None if is_contiguous else length_list)
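# Worked sketch of the offsets math above (hypothetical values): for a dense input of
# shape (B, S, ...) with S == 10, row i starts at starts[i] + i * S in the flattened
# buffer, and a final closing offset is appended for the last row.
#
#     >>> t = torch.randn(3, 10, 4)
#     >>> starts = torch.tensor([0, 2, 5])
#     >>> lengths = torch.tensor([4, 3, 5])
#     >>> nt, offsets, out_lengths = jagged_from_tensor_and_lengths(t, starts, lengths)
#     >>> offsets  # starts + 10 * arange(3), plus the closing offset 5 + 20 + 5
#     tensor([ 0, 12, 25, 30])
#     >>> out_lengths  # returned because the result is not contiguous
#     tensor([4, 3, 5])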
# NB: A dummy arg is required so that NestedTensor.__torch_dispatch__() is invoked
# for _nested_view_from_values_offsets(). Sizes don't matter much, but they shouldn't be
# 0/1 because the dummy can be fake-ified and we want to avoid specializing.
# This arg is otherwise unused.
_dummy_instance: Optional[torch.Tensor] = None
def _nt_view_dummy() -> torch.Tensor:
global _dummy_instance
if _dummy_instance is None:
_dummy_instance = NestedTensor(
values=torch.zeros(3, 3, device="meta"),
offsets=torch.zeros(3, device="meta", dtype=torch.int64),
).detach()
return _dummy_instance
def nested_view_from_values_offsets(
values, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None
):
min_seqlen_tensor = None
if min_seqlen is not None:
min_seqlen_tensor = _store_val_in_tensor(min_seqlen)
max_seqlen_tensor = None
if max_seqlen is not None:
max_seqlen_tensor = _store_val_in_tensor(max_seqlen)
return torch._nested_view_from_jagged( # type: ignore[attr-defined]
values,
offsets,
_nt_view_dummy(),
None,
ragged_idx,
min_seqlen_tensor,
max_seqlen_tensor,
) # type: ignore[return-value]
def nested_view_from_values_offsets_lengths(
values, offsets, lengths, ragged_idx=1, min_seqlen=None, max_seqlen=None
):
min_seqlen_tensor = None
if min_seqlen is not None:
min_seqlen_tensor = _store_val_in_tensor(min_seqlen)
max_seqlen_tensor = None
if max_seqlen is not None:
max_seqlen_tensor = _store_val_in_tensor(max_seqlen)
return torch._nested_view_from_jagged( # type: ignore[attr-defined]
values,
offsets,
_nt_view_dummy(),
lengths,
ragged_idx,
min_seqlen_tensor,
max_seqlen_tensor,
) # type: ignore[return-value]

File diff suppressed because it is too large.


@@ -0,0 +1,871 @@
# mypy: allow-untyped-defs
import logging
from typing import Optional, Tuple
import torch
import torch.nn
import torch.nn.functional as F
from torch.backends.cuda import (
can_use_efficient_attention,
can_use_flash_attention,
flash_sdp_enabled,
math_sdp_enabled,
mem_efficient_sdp_enabled,
SDPAParams,
)
from torch.nn.attention import SDPBackend
from .nested_tensor import NestedTensor
log = logging.getLogger(__name__)
def _validate_sdpa_input(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
if (
not isinstance(query, NestedTensor)
or not isinstance(key, NestedTensor)
or not isinstance(value, NestedTensor)
):
raise ValueError(
f"Expected query, key, and value to be nested tensors, "
f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, "
f"and value.is_nested: {value.is_nested} instead."
)
if query.dtype != key.dtype or query.dtype != value.dtype:
raise ValueError(
f"Expected query, key, and value to have the same dtype, "
f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
f"and value.dtype: {value.dtype} instead."
)
if query.device != key.device or query.device != value.device:
raise ValueError(
f"Expected query, key, and value to have the same device type, "
f"but got query.device: {query.device}, key.device: {key.device}, "
f"and value.device: {value.device} instead."
)
if query.dim() < 3 or key.dim() < 3 or value.dim() < 3:
raise ValueError(
f"Expected query, key, and value to all be at least 3 dimensional, but got query.dim: "
f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
)
if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
raise ValueError(
f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
)
if attn_mask is not None:
# TODO: Figure out whether masks are actually supported for this layout or not
raise ValueError("Masks are not yet supported!")
if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype:
raise ValueError(
f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: "
f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead."
)
def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
# This is expected to be called after check_tensor_shapes ensuring that the
# size() calls won't error since the inputs are all 4 dimensional
q_batch_size = params.query.size(0)
k_batch_size = params.key.size(0)
v_batch_size = params.value.size(0)
# num_heads logic for nested input is checked in
# check_for_seq_len_0_nested_tensor as there is handling there to make sure
# num_heads is not ragged
return q_batch_size == k_batch_size and q_batch_size == v_batch_size
def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
max_size = 256
query_size_last = params.query.size(-1)
key_size_last = params.key.size(-1)
value_size_last = params.value.size(-1)
same_head_dim_size = (
query_size_last == key_size_last and query_size_last == value_size_last
)
if not (
same_head_dim_size
and (query_size_last % 8 == 0)
and (query_size_last <= max_size)
):
if debug:
log.warning(
"For NestedTensor inputs, Flash attention requires q,k,v to have the same "
"last dimension and to be a multiple of 8 and less than or equal to 256. "
"Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
query_size_last,
key_size_last,
value_size_last,
)
return False
return True
def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
param: torch.Tensor, param_name: str, debug=False
) -> bool:
assert isinstance(param, NestedTensor), "param should be a jagged NT"
if param._ragged_idx == 1:
# num_head_dims is ragged
if debug:
log.warning(
"Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.",
param_name,
)
return False
# This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
if param._get_min_seqlen() == 0:
if debug:
log.warning(
"Fused kernels do not support seq_len == 0, %s has a seq len of 0.",
param_name,
)
return False
return True
def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool:
max_size = max(q_size, k_size, v_size)
if (
(q_size != max_size and q_size != 1)
or (k_size != max_size and k_size != 1)
or (v_size != max_size and v_size != 1)
):
if debug:
log.warning(
"Both fused kernels require query, key and value to have broadcastable %s, "
"got Query %s %d, Key %s %d, Value %s %d instead.",
param_name,
param_name,
q_size,
param_name,
k_size,
param_name,
v_size,
)
return False
return True
def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
# When this function is called we are assured that the nt is dim==4
q_is_safe = (
_check_for_seq_len_0_and_consistent_head_dim_nested_helper(
params.query, "query", debug
)
if params.query.is_nested
else True
)
# short circuit if any is unsafe
if not q_is_safe:
return False
k_is_safe = (
_check_for_seq_len_0_and_consistent_head_dim_nested_helper(
params.key, "key", debug
)
if params.key.is_nested
else True
)
# short circuit if any is unsafe
if not k_is_safe:
return False
v_is_safe = (
_check_for_seq_len_0_and_consistent_head_dim_nested_helper(
params.value, "value", debug
)
if params.value.is_nested
else True
)
# short circuit if any is unsafe
if not v_is_safe:
return False
# We now know none of the inputs have ragged num_heads, so we can safely
# access .size(1)
q_num_heads = params.query.size(1)
k_num_heads = params.key.size(1)
v_num_heads = params.value.size(1)
same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads
if not same_num_heads:
if (
params.query.requires_grad
or params.key.requires_grad
or params.value.requires_grad
):
if debug:
log.warning(
"Both fused kernels do not support training with broadcasted NT inputs."
)
return False
return _try_broadcast_param_size(
q_num_heads, k_num_heads, v_num_heads, "num heads", debug
)
return True
def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
constraints = (
_check_batch_size_nested,
_check_head_dim_size_flash_nested,
_check_for_seq_len_0_nested,
)
for constraint in constraints:
if not constraint(params, debug):
return False
return True
def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
constraints = (
_check_batch_size_nested,
_check_for_seq_len_0_nested,
)
for constraint in constraints:
if not constraint(params, debug):
return False
return True
def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
if (
not params.query.transpose(1, 2).is_contiguous()
or not params.key.transpose(1, 2).is_contiguous()
or not params.value.transpose(1, 2).is_contiguous()
):
if debug:
log.warning(
"If inputs are nested tensors they must be contiguous after transposing."
)
return False
if params.is_causal:
if debug:
log.warning(
"Nested tensors for query / key are not supported when is_causal=True."
)
return False
return True
def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable_gqa):
if (
not flash_sdp_enabled()
and not mem_efficient_sdp_enabled()
and not math_sdp_enabled()
):
return SDPBackend.ERROR
ordering = (
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.MATH,
)
params = SDPAParams(query, key, value, attn_mask, dropout, is_causal, enable_gqa)
for backend in ordering:
if backend == SDPBackend.FLASH_ATTENTION:
if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params):
return SDPBackend.FLASH_ATTENTION
if backend == SDPBackend.EFFICIENT_ATTENTION:
if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged(
params
):
return SDPBackend.EFFICIENT_ATTENTION
if backend == SDPBackend.MATH:
if math_sdp_enabled() and _can_use_math_sdpa_jagged(params):
return SDPBackend.MATH
log.warning("Memory efficient kernel not used because:")
can_use_efficient_attention(params, debug=True)
_can_use_efficient_sdpa_jagged(params, debug=True)
log.warning("Flash attention kernel not used because:")
can_use_flash_attention(params, debug=True)
_can_use_flash_sdpa_jagged(params, debug=True)
log.warning("Math attention kernel not used because:")
_can_use_math_sdpa_jagged(params, debug=True)
return SDPBackend.ERROR
def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
# This function is used to calculate two pieces of metadata that are needed
# for use with flash-attention and efficient_attention kernels. They are the
# cumulative sequence_length over a batch of sequences and the maximum
# sequence length.
    # It returns a tuple of the cumulative sequence lengths, the maximum sequence
    # length, and the last element of the cumulative sequence lengths (i.e. the
    # total number of elements across the batch).
if not isinstance(qkv, NestedTensor):
raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.")
if qkv.lengths() is None:
# TODO: Explore performance impact of copying
cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device)
max_seqlen = qkv._get_max_seqlen()
n_elem = qkv.values().shape[0]
else:
# TODO: Explore performance impact of copying
cumulative_seqlen = (
qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
)
batch_size = qkv.size(0)
max_seqlen = qkv._get_max_seqlen()
# TODO: Explore performance impact when compiling
n_elem = int(cumulative_seqlen[-1].item())
return cumulative_seqlen, max_seqlen, n_elem
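# Worked sketch of the metadata above (hypothetical values): for an NJT with
# offsets [0, 3, 5, 9] and no lengths,
#     cumulative_seqlen = int32 tensor [0, 3, 5, 9]
#     max_seqlen        = 4  (longest batch element)
#     n_elem            = 9  (== values().shape[0], the total number of packed rows)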
def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor):
# This function checks if a nested tensor is valid for
# use with the flash-attention and efficient_attention kernels without
# needing to call contiguous on the nested tensor input.
    # It checks that the storage offsets' adjacent differences are a constant
    # multiple of the previous tensor in the nested tensor and that the strides
    # are monotonically decreasing. This check is done after calling transpose on
    # the nested tensor, resulting in an NT of shape [bsz, {seq_len}, num_heads, dim].
    # Returns True if the storage can be used as-is, or False if contiguous needs
    # to be called on the input first.
assert isinstance(tensor, NestedTensor)
offsets = tensor.offsets()
strides = tensor._strides
n_tensors = offsets.size(0) - 1
if n_tensors <= 1:
return True
# Check initially that the tensor strides are in strictly descending order
prev_stride = strides[1]
for stride in strides[2:]:
if prev_stride <= stride:
# This would mean that the last stride is greater than the seq_len
# stride
return False
prev_stride = stride
# Congrats you made it!
return True
def _view_as_dense(
tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int
) -> torch.Tensor:
if tensor.is_nested:
return tensor.values()
return tensor.view(Nnz, num_heads, head_dim)
# TODO: Next iteration should add test cases and check it works
# def _sdpa_nested_preprocessing_with_broadcast(query, key, value):
# # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
# # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
# # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
# q_batch_size = query.size(0)
# k_batch_size = key.size(0)
# v_batch_size = value.size(0)
# output_batch_size = max(q_batch_size, k_batch_size, v_batch_size)
# q_num_heads = query.size(1)
# k_num_heads = key.size(1)
# v_num_heads = value.size(1)
# output_num_heads = max(q_num_heads, k_num_heads, v_num_heads)
# head_dim_qk = query.size(3)
# head_dim_v = value.size(3)
# q_t = query.transpose(1, 2)
# k_t = key.transpose(1, 2)
# v_t = value.transpose(1, 2)
# # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
# # output_batch_size/num_heads then they are 1
# q_batch_size_needs_broadcast = q_batch_size != output_batch_size
# k_batch_size_needs_broadcast = k_batch_size != output_batch_size
# v_batch_size_needs_broadcast = v_batch_size != output_batch_size
# # If {*}_batch_size_needs_broadcast, then
# # (1) max_seqlen_batch_{*} is given by {*}_t.size(1)
# # this is because needs_broadcast indicates that the batch_size is 1
# # and hence there is only 1 value for seq_len
# # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
# # ..., outut_batch_size * {*}_t.size(1)]
# # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)
# if q_batch_size_needs_broadcast or not q_t.is_nested:
# max_seqlen_batch_q = q_t.size(1)
# cumulative_sequence_length_q = torch.arange(
# 0,
# (output_batch_size + 1) * max_seqlen_batch_q,
# max_seqlen_batch_q,
# device=q_t.device,
# dtype=torch.int32,
# )
# Nnz_q = output_batch_size * max_seqlen_batch_q
# else:
# (
# cumulative_sequence_length_q,
# max_seqlen_batch_q,
# Nnz_q,
# ) = _cumulative_and_max_seq_len_nnz(q_t)
# if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast:
# assert k_t.size(1) == v_t.size(1)
# max_seqlen_batch_kv = k_t.size(1)
# cumulative_sequence_length_kv = torch.arange(
# 0,
# (output_batch_size + 1) * max_seqlen_batch_kv,
# max_seqlen_batch_kv,
# device=k_t.device,
# dtype=torch.int32,
# )
# Nnz_kv = output_batch_size * max_seqlen_batch_kv
# else:
# cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = (
# _cumulative_and_max_seq_len_nnz(v_t)
# if k_batch_size_needs_broadcast
# else _cumulative_and_max_seq_len_nnz(k_t)
# )
# q_num_heads_needs_broadcast = q_num_heads != output_num_heads
# k_num_heads_needs_broadcast = k_num_heads != output_num_heads
# v_num_heads_needs_broadcast = v_num_heads != output_num_heads
# if not q_t.is_nested:
# query_buffer_reshaped = q_t.expand(
# output_batch_size, q_t.size(1), output_num_heads, head_dim_qk
# )
# query_buffer_reshaped = query_buffer_reshaped.reshape(
# Nnz_q, output_num_heads, head_dim_qk
# )
# else:
# if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
# q_t = q_t.contiguous()
# # If we are broadcasting then Nnz_q will be the output_batch_size since
# # seq_len is 1
# effective_batch_size_q = (
# output_batch_size if q_batch_size_needs_broadcast else Nnz_q
# )
# query_buffer_reshaped = _view_as_dense(
# q_t, effective_batch_size_q, output_num_heads, head_dim_qk
# )
# # If the physical layout of the NestedTensor's storage
# # is not: batch, {seq_len}, num_heads, head_dim then we need
# # to call contiguous
# if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
# k_t = k_t.contiguous()
# if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
# v_t = v_t.contiguous()
# effective_batch_size_k = (
# output_batch_size if k_batch_size_needs_broadcast else Nnz_kv
# )
# key_buffer_reshaped = _view_as_dense(
# k_t, effective_batch_size_k, output_num_heads, head_dim_qk
# )
# effective_batch_size_v = (
# output_batch_size if v_batch_size_needs_broadcast else Nnz_kv
# )
# value_buffer_reshaped = _view_as_dense(
# v_t, effective_batch_size_v, output_num_heads, head_dim_v
# )
# if not q_batch_size_needs_broadcast:
# output_shape = q_t._size
# if head_dim_v != head_dim_qk:
# output_shape[-1] = head_dim_v
# if q_num_heads_needs_broadcast:
# output_shape[1] = output_num_heads
# else:
# output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu"))
# output_shape[0] = q_t.size(1)
# output_shape[1] = output_num_heads
# output_shape[2] = head_dim_v
# return (
# query_buffer_reshaped,
# key_buffer_reshaped,
# value_buffer_reshaped,
# cumulative_sequence_length_q,
# cumulative_sequence_length_kv,
# max_seqlen_batch_q,
# max_seqlen_batch_kv,
# output_shape,
# )
def _sdpa_nested_preprocessing(query, key, value):
# Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
# Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
# Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
q_batch_size = query.size(0)
k_batch_size = key.size(0)
v_batch_size = value.size(0)
q_num_heads = query.size(1)
k_num_heads = key.size(1)
v_num_heads = value.size(1)
if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not (
q_num_heads == k_num_heads and k_num_heads == v_num_heads
):
raise RuntimeError(
"This path is currently not implemented for jagged layout NT."
)
# return _sdpa_nested_preprocessing_with_broadcast(query, key, value)
num_heads = query.size(1)
head_dim_qk = query.size(3)
head_dim_v = value.size(3)
q_t = query.transpose(1, 2)
k_t = key.transpose(1, 2)
v_t = value.transpose(1, 2)
(
cumulative_sequence_length_q,
max_seqlen_batch_q,
Nnz_q,
) = _cumulative_and_max_seq_len_nnz(q_t)
(
cumulative_sequence_length_kv,
max_seqlen_batch_kv,
Nnz_kv,
) = _cumulative_and_max_seq_len_nnz(k_t)
    # [TODO] K and V have to have the same Nnz; should probably torch._check this.
    # For now, assume it holds so we don't have to iterate over v.
# If the physical layout of the NestedTensor's storage
# is not: batch, {seq_len}, num_heads, head_dim then we need
# to call contiguous
if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
q_t = q_t.contiguous()
if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
k_t = k_t.contiguous()
if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
v_t = v_t.contiguous()
query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk)
key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk)
value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v)
output_nt_info = {
"offsets": q_t.offsets(),
"_max_seqlen": q_t._get_max_seqlen(),
"_min_seqlen": q_t._get_min_seqlen(),
}
return (
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
output_nt_info,
)
def _pad_last_dim(
tensor: torch.Tensor, alignment_size: int, slice: bool
) -> torch.Tensor:
# FlashAttentionV2 requires that head dimension be a multiple of 8
# This was previously done within the kernel, however
# This causes the kernel to maybe alias query, key, value
# So instead we pad the head_dimensions to be a multiple of 8
# in the composite region
last_dim_size = tensor.size(-1)
if last_dim_size % alignment_size == 0:
return tensor
pad_count = alignment_size - (last_dim_size % alignment_size)
tensor = torch.nn.functional.pad(tensor, [0, pad_count])
if slice:
return tensor[..., 0:last_dim_size]
return tensor
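# Worked sketch (hypothetical sizes): a head dim of 37 is padded up to the next multiple
# of 8, and optionally sliced back to the original size when `slice` is True.
#
#     >>> x = torch.randn(2, 3, 37)
#     >>> _pad_last_dim(x, 8, False).shape
#     torch.Size([2, 3, 40])
#     >>> _pad_last_dim(x, 8, True).shape
#     torch.Size([2, 3, 37])
#     >>> _pad_last_dim(torch.randn(2, 3, 40), 8, False).shape  # already aligned: unchanged
#     torch.Size([2, 3, 40])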
# TODO: coalesce with torch/nn/utils/attention.py
def _calculate_scale(query, scale):
# TODO: Investigate why math.sqrt() isn't properly handled by Dynamo?
softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
return softmax_scale
def _post_process_flash_output(out: torch.Tensor, og_size):
if not out.is_nested and out.size(-1) != og_size:
out = out[..., 0:og_size]
return out
def _is_computing_meta_flops(x):
# Note: there's a use case of using meta tensors & the dispatch-based flop counter.
# We can use this function to check for this scenario in order to handle it specially.
if not torch.jit.is_scripting() and x.device.type == "meta":
torch_dispatch_mode_stack = (
torch.utils._python_dispatch._get_current_dispatch_mode_stack()
)
return any(
type(x) == torch.utils.flop_counter.FlopCounterMode
for x in torch_dispatch_mode_stack
)
return False
def _autocast(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
[Autocasting SDPA for NJT]
Normal autocasting doesn't work for NJT+SDPA right now:
* NJT intercepts the __torch_function__ call for scaled_dot_product_attention, which happens
before we get to any aten ops or dispatcher logic; then the torch_function logic calls into
efficient attention or flash attention. So, autocasting on the scaled_dot_product_attention
op won't work because we never see that aten op.
* If we put autocasting on `_flash_attention_forward`, then we'll get autocasting to run, but
      the kernel selection logic in torch_function handling (i.e. jagged_scaled_dot_product_attention)
won't work correctly: the kernel selection logic will run before autocasting, and choose
a kernel based on the un-autocasted dtypes; but then autocasting will run and the actual
attention computation will happen in a different dtype.
An alternative is to just change the backend selection logic for SDPA+NJT to be autocast-aware
and rely on autocasting to do the actual conversions for flash attention / efficient attention.
However, by manually doing the actual autocast before the backend selection, we ensure that the
autocast handling for backend selection doesn't diverge from the autocast handling for the
actual dtype conversions.
"""
device_type = query.device.type
# meta device is not supported by autocast, so break early for it
if _is_computing_meta_flops(query) or not torch.is_autocast_enabled(device_type):
return query, key, value, attn_mask
def cvt(x):
if x is None:
return x
target_dtype = torch.get_autocast_dtype(device_type)
if (
(not x.dtype.is_floating_point)
or x.dtype == target_dtype
or x.dtype == torch.float64
):
return x
return x.to(target_dtype)
return cvt(query), cvt(key), cvt(value), cvt(attn_mask)
def jagged_scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p=0.0,
is_causal=False,
scale=None,
enable_gqa=False,
):
query, key, value, attn_mask = _autocast(query, key, value, attn_mask)
_validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
# for mypy, ugh
assert (
isinstance(query, NestedTensor)
and isinstance(key, NestedTensor)
and isinstance(value, NestedTensor)
)
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
# Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
# second batch dim instead). For this case, we can just send the dense buffers through
# vanilla SDPA.
if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
output = F.scaled_dot_product_attention(
query.values(),
key.values(),
value.values(),
attn_mask=(
attn_mask.values() if isinstance(attn_mask, NestedTensor) else attn_mask
),
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)
return nested_view_from_values_offsets(output, query.offsets())
compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
backend_choice = _select_sdp_backend(
query, key, value, attn_mask, dropout_p, is_causal, enable_gqa
)
if _is_computing_meta_flops(query):
# Backend choice will probably not be correct if we have a meta device,
# because backend choice is device-aware. In this case, we mostly just
# want to avoid using math backend (which does a .item() call).
# Arbitrarily choose flash attention.
backend_choice = SDPBackend.FLASH_ATTENTION
if backend_choice == SDPBackend.FLASH_ATTENTION:
og_size = query.size(-1)
query_padded = _pad_last_dim(query, 8, False)
key_padded = _pad_last_dim(key, 8, False)
value_padded = _pad_last_dim(value, 8, False)
# We need to calculate the scale based off the OG head dim size
og_scale = _calculate_scale(query, scale)
(
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
output_nt_info,
) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)
(
attention,
logsumexp,
philox_seed,
philox_offset,
debug_attn_mask,
) = torch.ops.aten._flash_attention_forward(
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
dropout_p,
is_causal,
False,
scale=og_scale,
)
# Reshape output to convert nnz to batch_size and seq_len
attention = nested_view_from_values_offsets(
attention, # output from flash_attn is [total_q, num_heads, head_size_og]
output_nt_info["offsets"],
min_seqlen=output_nt_info["_min_seqlen"],
max_seqlen=output_nt_info["_max_seqlen"],
).transpose(1, 2)
return _post_process_flash_output(attention, og_size)
elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
(
query_reshaped,
key_reshaped,
value_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
output_nt_info,
) = _sdpa_nested_preprocessing(query, key, value)
(
attention,
log_sumexp,
seed,
offset,
max_seqlen_q,
max_seqlen_batch_kv,
) = torch.ops.aten._efficient_attention_forward(
query_reshaped.unsqueeze(0),
key_reshaped.unsqueeze(0),
value_reshaped.unsqueeze(0),
None,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
dropout_p,
int(is_causal),
compute_logsumexp,
scale=scale,
)
# Reshape output to convert nnz to batch_size and seq_len
return nested_view_from_values_offsets(
attention.squeeze(0),
output_nt_info["offsets"],
min_seqlen=output_nt_info["_min_seqlen"],
max_seqlen=output_nt_info["_max_seqlen"],
).transpose(1, 2)
elif backend_choice == SDPBackend.MATH:
# save the offsets and shape of the inputs, so we can reshape the final output
        # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D', j1] = [B, D1, j0, j1]
# attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
offsets = query.offsets()
d1 = query._size[1]
d2 = value._size[-1]
min_seqlen_tensor = query._metadata_cache.get(
"min_seqlen", None
) # type: ignore[attr-defined]
max_seqlen_tensor = query._metadata_cache.get(
"max_seqlen", None
) # type: ignore[attr-defined]
# convert jagged layout Nested Tensor to strided layout Nested Tensor
        # which supports the math implementation of SDPA
def get_strided_layout_nested_tensor(jagged_layout_nt):
lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1]
transpose = torch.transpose(jagged_layout_nt, 1, 2)
tensor_list = transpose.values().split(list(lengths), dim=0)
strided_nt = torch.nested.as_nested_tensor(list(tensor_list))
strided_nt = strided_nt.transpose(1, 2).contiguous()
return strided_nt
query = get_strided_layout_nested_tensor(query)
key = get_strided_layout_nested_tensor(key)
value = get_strided_layout_nested_tensor(value)
attn_out = torch._scaled_dot_product_attention_math(
query, key, value, attn_mask, dropout_p, is_causal, scale=scale
)[0]
from torch.nested._internal.nested_tensor import _load_val_from_tensor
# convert strided layout Nested Tensor back to jagged layout Nested Tensor
attn_out = attn_out.transpose(1, 2).contiguous().values()
attn_out = attn_out.view(-1, d1, d2)
attn_out = nested_view_from_values_offsets(
attn_out,
offsets,
min_seqlen=(
None
if min_seqlen_tensor is None
else _load_val_from_tensor(min_seqlen_tensor)
),
max_seqlen=(
None
if max_seqlen_tensor is None
else _load_val_from_tensor(max_seqlen_tensor)
),
).transpose(1, 2)
return attn_out
else:
raise RuntimeError(
"No viable backend for scaled_dot_product_attention was found."
)
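# End-to-end usage sketch (assumes CPU and the math fallback; shapes and sizes are
# hypothetical): an NJT of shape [B, n_heads, j, head_dim] built from per-sequence
# (seq_len, n_heads, head_dim) tensors routes F.scaled_dot_product_attention through
# the jagged path above.
#
#     >>> qkv = [torch.randn(s, 4, 8) for s in (2, 5, 3)]
#     >>> q = torch.nested.nested_tensor(qkv, layout=torch.jagged).transpose(1, 2)
#     >>> k = torch.nested.nested_tensor(qkv, layout=torch.jagged).transpose(1, 2)
#     >>> v = torch.nested.nested_tensor(qkv, layout=torch.jagged).transpose(1, 2)
#     >>> out = F.scaled_dot_product_attention(q, k, v)
#     >>> out.shape[0], out.shape[1]
#     (3, 4)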