I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,564 @@
# mypy: allow-untyped-defs
from typing import * # noqa: F403
from typing import Tuple
import torch
from torch._C import DispatchKey, DispatchKeySet
from torch._prims_common import is_expandable_to
from torch.utils.weak import WeakTensorKeyDictionary
_tensor_id_counter = 0
_tensor_symint_registry = WeakTensorKeyDictionary()
def get_tensor_symint(tensor, *, coeff=1):
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import mb_unwrap_functional_tensor
# NB: Only FakeTensor is associated with a memo
tensor = mb_unwrap_functional_tensor(tensor)
if isinstance(tensor, FakeTensor):
return tensor.get_nested_int(coeff=coeff)
global _tensor_id_counter
tensor_symint = _tensor_symint_registry.get(tensor)
if tensor_symint is None:
tensor_symint = torch._C._get_nested_int(_tensor_id_counter, coeff)
_tensor_id_counter += 1
_tensor_symint_registry[tensor] = tensor_symint
return tensor_symint
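# Illustrative sketch of the caching above (hypothetical offsets tensor):
#
#   offsets = torch.tensor([0, 2, 5, 6])
#   s1 = get_tensor_symint(offsets)  # allocates a fresh nested int
#   s2 = get_tensor_symint(offsets)  # hits _tensor_symint_registry -> same object
#   assert s1 is s2
#   # The registry is weak-keyed, so the entry disappears once `offsets` is collected.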
# SDPA metadata; max / min seqlens are needed for e.g. flash
def _get_sdpa_extreme_seqlen(func, tensor):
return int(func(tensor).item())
def _store_val_in_tensor(val) -> torch.Tensor:
# hack to get dynamic shapes support: store in a (val, 0) shaped tensor
return torch.zeros(val, 0)
def _load_val_from_tensor(t: torch.Tensor):
return t.shape[0]
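# Round-trip sketch for the two helpers above: the value is encoded in the first
# dim of an empty (val, 0) tensor so it can be treated as a dynamic dim under
# torch.compile (hypothetical value shown):
#
#   t = _store_val_in_tensor(5)   # tensor of shape (5, 0); holds no actual data
#   _load_val_from_tensor(t)      # -> 5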
class NestedTensor(torch.Tensor):
_values: torch.Tensor # type: ignore[assignment]
_offsets: torch.Tensor
_lengths: Optional[torch.Tensor]
# NOTE [ Nested ints for ragged sizes and strides ]
#
# Jagged layout tensors represent an n-dim tensor with a ragged dimension but
# are backed by an (n-1)-dim tensor underneath. For example, a jagged tensor
# with outer shape [B, x, D] is represented internally by a tensor with shape
# [sum(x), D], where "x" is what we call a nested int (sometimes denoted "*")
# standing in for the ragged dimension, and sum(x) is the size of the inner
# tensor's first dim, i.e. the sum of the constituent tensors' varying
# lengths. Concretely, three sequences of lengths 2, 3, and 1 with D = 4 give
# an outer shape [3, x, 4] backed by a values tensor of shape [6, 4].
#
# We also use nested ints to represent the strides of this tensor. For
# example, a jagged tensor with shape [B, x, D] can be strided in two ways:
# [xD, D, 1] and [x, 1, sum(x)], where xD represents x multiplied by D.
_size: Tuple[int, ...]
_strides: Tuple[int, ...]
# Indicates that the nth dimension is ragged
_ragged_idx: int
_metadata_cache: Dict[str, Any]
@staticmethod
def __new__(
cls,
values,
offsets,
*,
lengths=None,
**kwargs,
):
ks = DispatchKeySet(DispatchKey.NestedTensor)
ks = ks.add(DispatchKey.AutogradNestedTensor)
# Only support jagged for now.
assert offsets is not None
assert offsets.ndim == 1
assert not isinstance(values, NestedTensor)
assert values.device == offsets.device
# Query cache for the symint associated with offsets or lengths
# (create a new one if needed).
ragged_source = offsets if lengths is None else lengths
ragged_size = get_tensor_symint(ragged_source, coeff=1)
_ragged_idx = kwargs.get("_ragged_idx", 1)
B = offsets.shape[0] - 1
if lengths is not None:
assert B == lengths.shape[0]
# subtract 1 to convert to values dim space
r = _ragged_idx - 1
_size = (B, *values.shape[:r], ragged_size, *values.shape[r + 1 :])
stride = values.stride()
_strides = (ragged_size * stride[r], *stride)
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
_size,
_strides,
0,
torch.contiguous_format,
values.dtype,
torch.jagged,
values.device,
False,
kwargs.get("requires_grad", False),
"sizes",
False,
True, # dispatch_layout
ks,
# don't try to calculate storage based on non-zero size
storage_size=values.untyped_storage().size(),
)
r._ragged_idx = _ragged_idx
r._size = _size
r._strides = _strides
return r
def __init__(self, values, offsets, *, lengths=None, **kwargs):
super().__init__()
self._values = values
self._offsets = offsets
self._lengths = lengths
# holds properties that are computed lazily
self._metadata_cache = kwargs.get("_metadata_cache") or {}
# collapsed ragged dim must always be dynamic
torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx)
torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1)
# min / max sequence length should be dynamic if present
max_seqlen_tensor = self._metadata_cache.get("max_seqlen", None)
if max_seqlen_tensor is not None:
torch._dynamo.mark_dynamic(max_seqlen_tensor, 0)
min_seqlen_tensor = self._metadata_cache.get("min_seqlen", None)
if min_seqlen_tensor is not None:
torch._dynamo.mark_dynamic(min_seqlen_tensor, 0)
def values(self):
# dispatch to get proper view relationship
return torch._nested_get_values(self) # type: ignore[attr-defined]
def offsets(self):
return self._offsets
def lengths(self):
return self._lengths
# Private accessor functions for min / max sequence length. They're
# purposefully not @properties because those don't work with PT2 (yet).
# These compute / cache if not present.
# TODO: Revisit this when @properties are better supported by PT2. I think the ideal
# state would be to have public @properties for min / max sequence length that compile
# (including setters).
def _get_max_seqlen(self):
max_seqlen_tensor = self._max_seqlen_tensor
if max_seqlen_tensor is None:
# compute & cache
max_val = _get_sdpa_extreme_seqlen(
torch.max,
self._offsets.diff() if self._lengths is None else self._lengths,
)
max_seqlen_tensor = _store_val_in_tensor(max_val)
self._metadata_cache["max_seqlen"] = max_seqlen_tensor
return _load_val_from_tensor(max_seqlen_tensor)
def _get_min_seqlen(self):
min_seqlen_tensor = self._min_seqlen_tensor
if min_seqlen_tensor is None:
# compute & cache
min_val = _get_sdpa_extreme_seqlen(
torch.min,
self._offsets.diff() if self._lengths is None else self._lengths,
)
min_seqlen_tensor = _store_val_in_tensor(min_val)
self._metadata_cache["min_seqlen"] = min_seqlen_tensor
return _load_val_from_tensor(min_seqlen_tensor)
# Private accessors used for treating min / max seqlen as inner tensors for
# flatten / unflatten. These must be properties to work with the traceable wrapper
# subclass logic. These do not compute / cache if not present.
@property
def _max_seqlen_tensor(self) -> Optional[torch.Tensor]:
return self._metadata_cache.get("max_seqlen", None)
@property
def _min_seqlen_tensor(self) -> Optional[torch.Tensor]:
return self._metadata_cache.get("min_seqlen", None)
# These are old private @property accessors that are kept around for internal BC
# reasons. TODO: Remove these!
@property
def _max_seqlen(self):
return self._get_max_seqlen()
@property
def _min_seqlen(self):
return self._get_min_seqlen()
def __repr__(self):
# We should implement this in torch/_tensor_str.py instead
grad_fn_str = (
f", requires_grad={self.requires_grad}" if self.requires_grad else ""
)
if self.grad_fn:
grad_fn_str = f", grad_fn={self.grad_fn}"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._lengths is None})"
def __reduce_ex__(self, proto):
state = torch._utils._get_obj_state(self)
# SymNodes are not serializable
assert "_size" in state and "_strides" in state
state = dict(state)
del state["_size"]
del state["_strides"]
# TODO: Update this to handle the other inner tensors
func = NestedTensor
args = (self._values, self._offsets)
return (torch._tensor._rebuild_from_type_v2, (func, type(self), args, state))
def __tensor_flatten__(self):
ctx = {
"requires_grad": self.requires_grad,
"ragged_idx": self._ragged_idx,
}
inner_tensors = ["_values", "_offsets"]
if self._lengths is not None:
inner_tensors.append("_lengths")
if self._min_seqlen_tensor is not None:
inner_tensors.append("_min_seqlen_tensor")
if self._max_seqlen_tensor is not None:
inner_tensors.append("_max_seqlen_tensor")
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors: Dict, meta, outer_size, outer_stride):
from torch._subclasses.fake_tensor import FakeTensor
# inner tensors: _values, _offsets, [_lengths], [_min_seqlen], [_max_seqlen]
assert len(inner_tensors) >= 2 and len(inner_tensors) <= 5
values = inner_tensors["_values"]
offsets = inner_tensors["_offsets"]
lengths = inner_tensors.get("_lengths", None)
min_seqlen_tensor = inner_tensors.get("_min_seqlen_tensor", None)
max_seqlen_tensor = inner_tensors.get("_max_seqlen_tensor", None)
metadata_cache = {}
if min_seqlen_tensor is not None:
metadata_cache["min_seqlen"] = min_seqlen_tensor
if max_seqlen_tensor is not None:
metadata_cache["max_seqlen"] = max_seqlen_tensor
ragged_idx = meta["ragged_idx"]
# Alternatively, we could make it the caller's responsibility to
# cache it. But this heuristic seems simple enough.
ragged_source = offsets if lengths is None else lengths
if isinstance(ragged_source, FakeTensor):
ragged_size = outer_size[ragged_idx]
ragged_source.nested_int_memo = ragged_size
return NestedTensor(
values,
offsets=offsets,
lengths=lengths,
requires_grad=meta["requires_grad"],
_ragged_idx=ragged_idx,
_metadata_cache=metadata_cache,
)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
# Lazy import to avoid circular dependency
from .ops import lookup_jagged
fn = lookup_jagged(func, *args, **kwargs)
if fn is not None:
return fn(*args, **kwargs)
raise NotImplementedError(func)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
from torch.fx.experimental.proxy_tensor import maybe_enable_thunkify
from .ops import jagged_torch_function
# This should be removed after
# https://github.com/pytorch/pytorch/pull/125941/ lands
with maybe_enable_thunkify():
try:
return jagged_torch_function(func, *args, **kwargs)
except NotImplementedError:
pass
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
# NB: These fake view autograd.Functions are superseded by real view ops. Don't use them!
# TODO: Remove ViewBufferFromNested, ViewNestedFromBuffer, and buffer_from_jagged once the
# internal BC period has passed.
# Not actually a view!
class ViewBufferFromNested(torch.autograd.Function):
@staticmethod
def forward(ctx, x: NestedTensor): # type: ignore[override]
ctx.save_for_backward(x.offsets())
ctx.metadata_cache = x._metadata_cache
ctx.ragged_idx = x._ragged_idx
return x._values
@staticmethod
def backward(ctx, gO: torch.Tensor): # type: ignore[override]
(offsets,) = ctx.saved_tensors
return NestedTensor(
gO,
offsets=offsets,
_metadata_cache=ctx.metadata_cache,
_ragged_idx=ctx.ragged_idx,
)
# Not actually a view!
class ViewNestedFromBuffer(torch.autograd.Function):
@staticmethod
def forward(
ctx,
values: torch.Tensor,
offsets: torch.Tensor,
metadata_cache: Optional[Dict[str, Any]] = None,
): # type: ignore[override]
# maintain BC with usages of this where the seqlens are stuffed
# directly into the metadata cache as non-Tensors / ints
if metadata_cache is not None:
min_seqlen = metadata_cache.get("min_seqlen", None)
max_seqlen = metadata_cache.get("max_seqlen", None)
if min_seqlen is not None and not isinstance(min_seqlen, torch.Tensor):
metadata_cache["min_seqlen"] = _store_val_in_tensor(min_seqlen)
if max_seqlen is not None and not isinstance(max_seqlen, torch.Tensor):
metadata_cache["max_seqlen"] = _store_val_in_tensor(max_seqlen)
return NestedTensor(
values.detach(),
offsets=offsets,
_metadata_cache=metadata_cache,
)
@staticmethod
def backward(ctx, gO: NestedTensor): # type: ignore[override]
return gO._values, None, None
def buffer_from_jagged(jagged):
return ViewBufferFromNested.apply(jagged)
# Need to make it obvious that users should be passing in offsets
def jagged_from_list(
tensors: List[torch.Tensor],
offsets: Optional[torch.Tensor],
dtype=None,
device=None,
) -> Tuple[NestedTensor, torch.Tensor]:
"""Constructs a NestedTensor backed by jagged layout from a list of tensors"""
if not len(set(t.dtype for t in tensors)) == 1: # noqa: C401
raise RuntimeError(
"When constructing a nested tensor, all tensors in list must have the same dtype"
)
if not len(set(t.device for t in tensors)) == 1: # noqa: C401
raise RuntimeError(
"When constructing a nested tensor, all tensors in list must be on the same device"
)
# Check that the NT is representable by the jagged layout.
# Jagged layout represents (B, *, D_0, D_1, ..., D_N), where the only
# raggedness allowed is for the single dim immediately adjacent to the batch dim.
sizes = [t.shape for t in tensors]
non_first_sizes = [s[1:] for s in sizes]
at_most_first_ragged = all(s == non_first_sizes[0] for s in non_first_sizes)
if not at_most_first_ragged:
raise RuntimeError(
"Cannot represent given tensor list as a nested tensor with the jagged layout. "
"Note that the jagged layout only represents shapes of the form "
"(B, *, D_0, D_1, ..., D_N), with only * allowed to be ragged."
)
# Set properties appropriately.
values = torch.cat(tensors, dim=0)
to_kwargs = {}
if device is not None:
to_kwargs["device"] = device
if dtype is not None:
to_kwargs["dtype"] = dtype
values = values.to(**to_kwargs)
# Calculate jagged offsets if not provided.
if offsets is None:
# Jagged layout specifies that offsets are stored as int64 on the same device as values.
# TODO: An alternative way to construct offsets is to use F.pad. This avoids creating
# an extra leaf tensor during the forward, potentially resolving compatibility issues.
offsets = torch.cat(
[
torch.zeros(1, dtype=torch.int64, device=values.device),
torch.tensor([s[0] for s in sizes], device=values.device).cumsum(dim=0),
]
)
# compute this now since it's easy
min_seqlen = min(t.shape[0] for t in tensors)
max_seqlen = max(t.shape[0] for t in tensors)
ret_nt = nested_view_from_values_offsets(
values, offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen
)
return (ret_nt, offsets) # type: ignore[return-value]
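# Usage sketch for jagged_from_list (hypothetical inputs): three sequences of
# lengths 2, 3, and 1 with a common trailing dim of 4.
#
#   t0, t1, t2 = torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)
#   nt, offsets = jagged_from_list([t0, t1, t2], offsets=None)
#   offsets           # tensor([0, 2, 5, 6])
#   nt._values.shape  # torch.Size([6, 4]) == (sum of lengths, 4)
#   nt.shape          # (3, j0, 4), where j0 is the nested int for the ragged dim
#   nt._get_min_seqlen(), nt._get_max_seqlen()  # (1, 3)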
def jagged_from_tensor_and_lengths(
tensor: torch.Tensor, starts: torch.Tensor, lengths: torch.Tensor
) -> Tuple[NestedTensor, torch.Tensor, Optional[torch.Tensor]]:
"""Constructs a NestedTensor backed by jagged layout from a tensor, starts of sequences, and sequence lengths"""
batch_size = tensor.shape[0]
if is_expandable_to(starts.shape, (batch_size,)) and is_expandable_to(
lengths.shape, (batch_size,)
):
start_list = starts.expand(batch_size)
length_list = lengths.expand(batch_size)
else:
raise RuntimeError(
"When constructing a jagged nested tensor using narrow(), "
"your start and length must be Tensors that broadcast to input.shape[0]"
)
# Calculate jagged offsets
assert (
len(tensor.shape) >= 2
), "tensor must at least be 2D for the nested narrow op to work"
max_seq_len = tensor.shape[1]
offset_lengths = max_seq_len * torch.arange(
0, batch_size, dtype=torch.int64, device=tensor.device
)
# Jagged layout specifies that offsets are stored as int64 on the same device as values.
offsets = torch.cat(
[
start_list + offset_lengths,
(start_list[-1] + offset_lengths[-1] + length_list[-1]).unsqueeze(0),
]
)
# Reshape buffer to flatten the 1st and 2nd dimension (view used to enforce non-copy)
if len(tensor.shape) > 2:
values = tensor.view(-1, *tensor.shape[2:])
else:
values = tensor.view(-1)
# Check if offsets and lengths make it possibly contiguous and return a regular NT
is_contiguous = True
orig_dim = tensor.shape[1]
if torch.any(length_list[1:-1].ne(orig_dim)):
is_contiguous = False
if torch.any(offsets[1:-2].diff().ne(orig_dim)):
is_contiguous = False
if offsets[0] + length_list[0] != orig_dim:
is_contiguous = False
actual_max_seqlen = int(torch.max(lengths).item())
min_seqlen = int(torch.min(lengths).item())
if is_contiguous:
ret_nt = nested_view_from_values_offsets(
values[offsets[0] : offsets[-1]],
offsets - offsets[0],
min_seqlen=min_seqlen,
max_seqlen=actual_max_seqlen,
)
else:
ret_nt = nested_view_from_values_offsets_lengths(
values,
offsets,
length_list,
min_seqlen=min_seqlen,
max_seqlen=actual_max_seqlen,
)
return (ret_nt, offsets, None if is_contiguous else length_list)
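# Usage sketch for jagged_from_tensor_and_lengths (hypothetical inputs): batch
# item i becomes dense[i, starts[i] : starts[i] + lengths[i]].
#
#   dense = torch.randn(2, 5, 4)   # [B, max_seq_len, D]
#   starts, lengths = torch.tensor([0, 1]), torch.tensor([3, 2])
#   nt, offsets, out_lengths = jagged_from_tensor_and_lengths(dense, starts, lengths)
#   offsets       # tensor([0, 6, 8]) -- indices into the flattened (10, 4) buffer
#   out_lengths   # tensor([3, 2]) -- returned because the result is non-contiguous here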
# NB: A dummy arg is required so that NestedTensor.__torch_dispatch__() is invoked
# for _nested_view_from_values_offsets(). Sizes don't matter much, but they shouldn't be
# 0/1 because the dummy can be fake-ified and we want to avoid specializing.
# This arg is otherwise unused.
_dummy_instance: Optional[torch.Tensor] = None
def _nt_view_dummy() -> torch.Tensor:
global _dummy_instance
if _dummy_instance is None:
_dummy_instance = NestedTensor(
values=torch.zeros(3, 3, device="meta"),
offsets=torch.zeros(3, device="meta", dtype=torch.int64),
).detach()
return _dummy_instance
def nested_view_from_values_offsets(
values, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None
):
min_seqlen_tensor = None
if min_seqlen is not None:
min_seqlen_tensor = _store_val_in_tensor(min_seqlen)
max_seqlen_tensor = None
if max_seqlen is not None:
max_seqlen_tensor = _store_val_in_tensor(max_seqlen)
return torch._nested_view_from_jagged( # type: ignore[attr-defined]
values,
offsets,
_nt_view_dummy(),
None,
ragged_idx,
min_seqlen_tensor,
max_seqlen_tensor,
) # type: ignore[return-value]
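# View-construction sketch (hypothetical inputs): the resulting NJT aliases
# `values` rather than copying it.
#
#   values = torch.randn(6, 4)
#   offsets = torch.tensor([0, 2, 5, 6])
#   nt = nested_view_from_values_offsets(values, offsets, min_seqlen=1, max_seqlen=3)
#   nt.shape      # (3, j0, 4) -- j0 is the nested int tied to `offsets`
#   nt.values()   # a view sharing storage with `values`, not a copy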
def nested_view_from_values_offsets_lengths(
values, offsets, lengths, ragged_idx=1, min_seqlen=None, max_seqlen=None
):
min_seqlen_tensor = None
if min_seqlen is not None:
min_seqlen_tensor = _store_val_in_tensor(min_seqlen)
max_seqlen_tensor = None
if max_seqlen is not None:
max_seqlen_tensor = _store_val_in_tensor(max_seqlen)
return torch._nested_view_from_jagged( # type: ignore[attr-defined]
values,
offsets,
_nt_view_dummy(),
lengths,
ragged_idx,
min_seqlen_tensor,
max_seqlen_tensor,
) # type: ignore[return-value]

File diff suppressed because it is too large.


@@ -0,0 +1,871 @@
# mypy: allow-untyped-defs
import logging
from typing import Optional, Tuple
import torch
import torch.nn
import torch.nn.functional as F
from torch.backends.cuda import (
can_use_efficient_attention,
can_use_flash_attention,
flash_sdp_enabled,
math_sdp_enabled,
mem_efficient_sdp_enabled,
SDPAParams,
)
from torch.nn.attention import SDPBackend
from .nested_tensor import NestedTensor
log = logging.getLogger(__name__)
def _validate_sdpa_input(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
if (
not isinstance(query, NestedTensor)
or not isinstance(key, NestedTensor)
or not isinstance(value, NestedTensor)
):
raise ValueError(
f"Expected query, key, and value to be nested tensors, "
f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, "
f"and value.is_nested: {value.is_nested} instead."
)
if query.dtype != key.dtype or query.dtype != value.dtype:
raise ValueError(
f"Expected query, key, and value to have the same dtype, "
f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
f"and value.dtype: {value.dtype} instead."
)
if query.device != key.device or query.device != value.device:
raise ValueError(
f"Expected query, key, and value to have the same device type, "
f"but got query.device: {query.device}, key.device: {key.device}, "
f"and value.device: {value.device} instead."
)
if query.dim() < 3 or key.dim() < 3 or value.dim() < 3:
raise ValueError(
f"Expected query, key, and value to all be at least 3 dimensional, but got query.dim: "
f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
)
if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
raise ValueError(
f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
)
if attn_mask is not None:
# TODO: Figure out whether masks are actually supported for this layout or not
raise ValueError("Masks are not yet supported!")
if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype:
raise ValueError(
f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: "
f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead."
)
def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
# This is expected to be called after check_tensor_shapes ensuring that the
# size() calls won't error since the inputs are all 4 dimensional
q_batch_size = params.query.size(0)
k_batch_size = params.key.size(0)
v_batch_size = params.value.size(0)
# num_heads logic for nested input is checked in
# check_for_seq_len_0_nested_tensor as there is handling there to make sure
# num_heads is not ragged
return q_batch_size == k_batch_size and q_batch_size == v_batch_size
def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
max_size = 256
query_size_last = params.query.size(-1)
key_size_last = params.key.size(-1)
value_size_last = params.value.size(-1)
same_head_dim_size = (
query_size_last == key_size_last and query_size_last == value_size_last
)
if not (
same_head_dim_size
and (query_size_last % 8 == 0)
and (query_size_last <= max_size)
):
if debug:
log.warning(
"For NestedTensor inputs, Flash attention requires q,k,v to have the same "
"last dimension and to be a multiple of 8 and less than or equal to 256. "
"Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
query_size_last,
key_size_last,
value_size_last,
)
return False
return True
def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
param: torch.Tensor, param_name: str, debug=False
) -> bool:
assert isinstance(param, NestedTensor), "param should be a jagged NT"
if param._ragged_idx == 1:
# num_head_dims is ragged
if debug:
log.warning(
"Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.",
param_name,
)
return False
# This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
if param._get_min_seqlen() == 0:
if debug:
log.warning(
"Fused kernels do not support seq_len == 0, %s has a seq len of 0.",
param_name,
)
return False
return True
def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool:
max_size = max(q_size, k_size, v_size)
if (
(q_size != max_size and q_size != 1)
or (k_size != max_size and k_size != 1)
or (v_size != max_size and v_size != 1)
):
if debug:
log.warning(
"Both fused kernels require query, key and value to have broadcastable %s, "
"got Query %s %d, Key %s %d, Value %s %d instead.",
param_name,
param_name,
q_size,
param_name,
k_size,
param_name,
v_size,
)
return False
return True
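# Worked example for the broadcast check above (hypothetical sizes):
#
#   _try_broadcast_param_size(8, 1, 8, "num heads")  # True  -- 1 broadcasts to 8
#   _try_broadcast_param_size(8, 4, 8, "num heads")  # False -- 4 is neither 8 nor 1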
def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
# When this function is called we are assured that the nt is dim==4
q_is_safe = (
_check_for_seq_len_0_and_consistent_head_dim_nested_helper(
params.query, "query", debug
)
if params.query.is_nested
else True
)
# short circuit if any is unsafe
if not q_is_safe:
return False
k_is_safe = (
_check_for_seq_len_0_and_consistent_head_dim_nested_helper(
params.key, "key", debug
)
if params.key.is_nested
else True
)
# short circuit if any is unsafe
if not k_is_safe:
return False
v_is_safe = (
_check_for_seq_len_0_and_consistent_head_dim_nested_helper(
params.value, "value", debug
)
if params.value.is_nested
else True
)
# short circuit if any is unsafe
if not v_is_safe:
return False
# We now know none of the inputs have ragged num_heads, so we can safely
# access .size(1)
q_num_heads = params.query.size(1)
k_num_heads = params.key.size(1)
v_num_heads = params.value.size(1)
same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads
if not same_num_heads:
if (
params.query.requires_grad
or params.key.requires_grad
or params.value.requires_grad
):
if debug:
log.warning(
"Both fused kernels do not support training with broadcasted NT inputs."
)
return False
return _try_broadcast_param_size(
q_num_heads, k_num_heads, v_num_heads, "num heads", debug
)
return True
def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
constraints = (
_check_batch_size_nested,
_check_head_dim_size_flash_nested,
_check_for_seq_len_0_nested,
)
for constraint in constraints:
if not constraint(params, debug):
return False
return True
def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
constraints = (
_check_batch_size_nested,
_check_for_seq_len_0_nested,
)
for constraint in constraints:
if not constraint(params, debug):
return False
return True
def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
if (
not params.query.transpose(1, 2).is_contiguous()
or not params.key.transpose(1, 2).is_contiguous()
or not params.value.transpose(1, 2).is_contiguous()
):
if debug:
log.warning(
"If inputs are nested tensors they must be contiguous after transposing."
)
return False
if params.is_causal:
if debug:
log.warning(
"Nested tensors for query / key are not supported when is_causal=True."
)
return False
return True
def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable_gqa):
if (
not flash_sdp_enabled()
and not mem_efficient_sdp_enabled()
and not math_sdp_enabled()
):
return SDPBackend.ERROR
ordering = (
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.MATH,
)
params = SDPAParams(query, key, value, attn_mask, dropout, is_causal, enable_gqa)
for backend in ordering:
if backend == SDPBackend.FLASH_ATTENTION:
if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params):
return SDPBackend.FLASH_ATTENTION
if backend == SDPBackend.EFFICIENT_ATTENTION:
if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged(
params
):
return SDPBackend.EFFICIENT_ATTENTION
if backend == SDPBackend.MATH:
if math_sdp_enabled() and _can_use_math_sdpa_jagged(params):
return SDPBackend.MATH
log.warning("Memory efficient kernel not used because:")
can_use_efficient_attention(params, debug=True)
_can_use_efficient_sdpa_jagged(params, debug=True)
log.warning("Flash attention kernel not used because:")
can_use_flash_attention(params, debug=True)
_can_use_flash_sdpa_jagged(params, debug=True)
log.warning("Math attention kernel not used because:")
_can_use_math_sdpa_jagged(params, debug=True)
return SDPBackend.ERROR
def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
# This function is used to calculate two pieces of metadata that are needed
# for use with flash-attention and efficient_attention kernels. They are the
# cumulative sequence_length over a batch of sequences and the maximum
# sequence length.
# It returns a tuple of the cumulative sequence lengths, the maximum sequence
# length, and the total number of elements (which equals the last entry of the
# cumulative sequence lengths).
if not isinstance(qkv, NestedTensor):
raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.")
if qkv.lengths() is None:
# TODO: Explore performance impact of copying
cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device)
max_seqlen = qkv._get_max_seqlen()
n_elem = qkv.values().shape[0]
else:
# TODO: Explore performance impact of copying
cumulative_seqlen = (
qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
)
batch_size = qkv.size(0)
max_seqlen = qkv._get_max_seqlen()
# TODO: Explore performance impact when compiling
n_elem = int(cumulative_seqlen[-1].item())
return cumulative_seqlen, max_seqlen, n_elem
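# Sketch of the returned metadata for a hypothetical jagged q_t with offsets
# [0, 2, 5, 6] and no lengths:
#
#   cu_seqlens, max_seqlen, n_elem = _cumulative_and_max_seq_len_nnz(q_t)
#   cu_seqlens   # tensor([0, 2, 5, 6], dtype=torch.int32), copied from q_t.offsets()
#   max_seqlen   # 3
#   n_elem       # 6 == q_t.values().shape[0]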
def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor):
# This function checks if a nested tensor is valid for
# use with the flash-attention and efficient_attention kernels without
# needing to call contiguous on the nested tensor input.
# It checks that the storage offsets' adjacent differences are a constant
# multiple of the previous tensor in the nested tensor and that the strides
# are monotonically decreasing. This check is done after calling transpose on
# the nested tensor, resulting in an NT of shape [bsz, {seq_len}, num_heads, dim].
# Returns a boolean indicating whether it is safe to use the storage directly
# (i.e. whether contiguous() can be skipped for the input).
assert isinstance(tensor, NestedTensor)
offsets = tensor.offsets()
strides = tensor._strides
n_tensors = offsets.size(0) - 1
if n_tensors <= 1:
return True
# Check initially that the tensor strides are in strictly descending order
prev_stride = strides[1]
for stride in strides[2:]:
if prev_stride <= stride:
# This would mean that the last stride is greater than the seq_len
# stride
return False
prev_stride = stride
# Congrats you made it!
return True
def _view_as_dense(
tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int
) -> torch.Tensor:
if tensor.is_nested:
return tensor.values()
return tensor.view(Nnz, num_heads, head_dim)
# TODO: Next iteration should add test cases and check it works
# def _sdpa_nested_preprocessing_with_broadcast(query, key, value):
# # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
# # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
# # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
# q_batch_size = query.size(0)
# k_batch_size = key.size(0)
# v_batch_size = value.size(0)
# output_batch_size = max(q_batch_size, k_batch_size, v_batch_size)
# q_num_heads = query.size(1)
# k_num_heads = key.size(1)
# v_num_heads = value.size(1)
# output_num_heads = max(q_num_heads, k_num_heads, v_num_heads)
# head_dim_qk = query.size(3)
# head_dim_v = value.size(3)
# q_t = query.transpose(1, 2)
# k_t = key.transpose(1, 2)
# v_t = value.transpose(1, 2)
# # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
# # output_batch_size/num_heads then they are 1
# q_batch_size_needs_broadcast = q_batch_size != output_batch_size
# k_batch_size_needs_broadcast = k_batch_size != output_batch_size
# v_batch_size_needs_broadcast = v_batch_size != output_batch_size
# # If {*}_batch_size_needs_broadcast, then
# # (1) max_seqlen_batch_{*} is given by {*}_t.size(1)
# # this is because needs_broadcast indicates that the batch_size is 1
# # and hence there is only 1 value for seq_len
# # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
# # ..., output_batch_size * {*}_t.size(1)]
# # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)
# if q_batch_size_needs_broadcast or not q_t.is_nested:
# max_seqlen_batch_q = q_t.size(1)
# cumulative_sequence_length_q = torch.arange(
# 0,
# (output_batch_size + 1) * max_seqlen_batch_q,
# max_seqlen_batch_q,
# device=q_t.device,
# dtype=torch.int32,
# )
# Nnz_q = output_batch_size * max_seqlen_batch_q
# else:
# (
# cumulative_sequence_length_q,
# max_seqlen_batch_q,
# Nnz_q,
# ) = _cumulative_and_max_seq_len_nnz(q_t)
# if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast:
# assert k_t.size(1) == v_t.size(1)
# max_seqlen_batch_kv = k_t.size(1)
# cumulative_sequence_length_kv = torch.arange(
# 0,
# (output_batch_size + 1) * max_seqlen_batch_kv,
# max_seqlen_batch_kv,
# device=k_t.device,
# dtype=torch.int32,
# )
# Nnz_kv = output_batch_size * max_seqlen_batch_kv
# else:
# cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = (
# _cumulative_and_max_seq_len_nnz(v_t)
# if k_batch_size_needs_broadcast
# else _cumulative_and_max_seq_len_nnz(k_t)
# )
# q_num_heads_needs_broadcast = q_num_heads != output_num_heads
# k_num_heads_needs_broadcast = k_num_heads != output_num_heads
# v_num_heads_needs_broadcast = v_num_heads != output_num_heads
# if not q_t.is_nested:
# query_buffer_reshaped = q_t.expand(
# output_batch_size, q_t.size(1), output_num_heads, head_dim_qk
# )
# query_buffer_reshaped = query_buffer_reshaped.reshape(
# Nnz_q, output_num_heads, head_dim_qk
# )
# else:
# if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
# q_t = q_t.contiguous()
# # If we are broadcasting then Nnz_q will be the output_batch_size since
# # seq_len is 1
# effective_batch_size_q = (
# output_batch_size if q_batch_size_needs_broadcast else Nnz_q
# )
# query_buffer_reshaped = _view_as_dense(
# q_t, effective_batch_size_q, output_num_heads, head_dim_qk
# )
# # If the physical layout of the NestedTensor's storage
# # is not: batch, {seq_len}, num_heads, head_dim then we need
# # to call contiguous
# if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
# k_t = k_t.contiguous()
# if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
# v_t = v_t.contiguous()
# effective_batch_size_k = (
# output_batch_size if k_batch_size_needs_broadcast else Nnz_kv
# )
# key_buffer_reshaped = _view_as_dense(
# k_t, effective_batch_size_k, output_num_heads, head_dim_qk
# )
# effective_batch_size_v = (
# output_batch_size if v_batch_size_needs_broadcast else Nnz_kv
# )
# value_buffer_reshaped = _view_as_dense(
# v_t, effective_batch_size_v, output_num_heads, head_dim_v
# )
# if not q_batch_size_needs_broadcast:
# output_shape = q_t._size
# if head_dim_v != head_dim_qk:
# output_shape[-1] = head_dim_v
# if q_num_heads_needs_broadcast:
# output_shape[1] = output_num_heads
# else:
# output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu"))
# output_shape[0] = q_t.size(1)
# output_shape[1] = output_num_heads
# output_shape[2] = head_dim_v
# return (
# query_buffer_reshaped,
# key_buffer_reshaped,
# value_buffer_reshaped,
# cumulative_sequence_length_q,
# cumulative_sequence_length_kv,
# max_seqlen_batch_q,
# max_seqlen_batch_kv,
# output_shape,
# )
def _sdpa_nested_preprocessing(query, key, value):
# Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
# Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
# Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
q_batch_size = query.size(0)
k_batch_size = key.size(0)
v_batch_size = value.size(0)
q_num_heads = query.size(1)
k_num_heads = key.size(1)
v_num_heads = value.size(1)
if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not (
q_num_heads == k_num_heads and k_num_heads == v_num_heads
):
raise RuntimeError(
"This path is currently not implemented for jagged layout NT."
)
# return _sdpa_nested_preprocessing_with_broadcast(query, key, value)
num_heads = query.size(1)
head_dim_qk = query.size(3)
head_dim_v = value.size(3)
q_t = query.transpose(1, 2)
k_t = key.transpose(1, 2)
v_t = value.transpose(1, 2)
(
cumulative_sequence_length_q,
max_seqlen_batch_q,
Nnz_q,
) = _cumulative_and_max_seq_len_nnz(q_t)
(
cumulative_sequence_length_kv,
max_seqlen_batch_kv,
Nnz_kv,
) = _cumulative_and_max_seq_len_nnz(k_t)
# [TODO] K and V have to have the same Nnz; we should probably torch._check
# this. For now, assume it so we don't have to iterate over v.
# If the physical layout of the NestedTensor's storage
# is not: batch, {seq_len}, num_heads, head_dim then we need
# to call contiguous
if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
q_t = q_t.contiguous()
if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
k_t = k_t.contiguous()
if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
v_t = v_t.contiguous()
query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk)
key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk)
value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v)
output_nt_info = {
"offsets": q_t.offsets(),
"_max_seqlen": q_t._get_max_seqlen(),
"_min_seqlen": q_t._get_min_seqlen(),
}
return (
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
output_nt_info,
)
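# Shape sketch for the preprocessing above (hypothetical jagged inputs of logical
# shape [B, H, j, D]): after transposing to [B, j, H, D], each buffer is flattened
# to [Nnz, H, D], where Nnz is the total number of tokens across the batch.
#
#   (q_r, k_r, v_r, cu_q, cu_kv, max_q, max_kv, info) = _sdpa_nested_preprocessing(q, k, v)
#   q_r.shape   # (Nnz_q, H, D)
#   cu_q        # int32 tensor of length B + 1 (cumulative sequence lengths)
#   info        # {"offsets": ..., "_max_seqlen": ..., "_min_seqlen": ...}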
def _pad_last_dim(
tensor: torch.Tensor, alignment_size: int, slice: bool
) -> torch.Tensor:
# FlashAttentionV2 requires the head dimension to be a multiple of 8.
# This padding was previously done within the kernel; however, that can cause
# the kernel to alias query, key, and value. So instead we pad the head
# dimension to a multiple of 8 here in the composite region.
last_dim_size = tensor.size(-1)
if last_dim_size % alignment_size == 0:
return tensor
pad_count = alignment_size - (last_dim_size % alignment_size)
tensor = torch.nn.functional.pad(tensor, [0, pad_count])
if slice:
return tensor[..., 0:last_dim_size]
return tensor
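# Padding arithmetic example (hypothetical head dim of 57, alignment_size=8):
# pad_count = 8 - (57 % 8) = 7, so the padded last dim becomes 64.
#
#   x = torch.randn(6, 8, 57)
#   _pad_last_dim(x, 8, False).shape  # torch.Size([6, 8, 64])
#   _pad_last_dim(x, 8, True).shape   # torch.Size([6, 8, 57]) -- sliced back after padding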
# TODO: coalesce with torch/nn/utils/attention.py
def _calculate_scale(query, scale):
# TODO: Investigate why math.sqrt() isn't properly handled by Dynamo.
softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
return softmax_scale
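# Scale example (hypothetical head dim of 64): with scale=None this yields
# sqrt(1 / 64) == 0.125, matching the usual 1/sqrt(E) softmax scaling; an
# explicit scale passes through unchanged.
#
#   _calculate_scale(query, None)  # sqrt(1.0 / query.size(-1)), e.g. 0.125 for E=64
#   _calculate_scale(query, 0.2)   # 0.2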
def _post_process_flash_output(out: torch.Tensor, og_size):
if not out.is_nested and out.size(-1) != og_size:
out = out[..., 0:og_size]
return out
def _is_computing_meta_flops(x):
# Note: there's a use case of using meta tensors & the dispatch-based flop counter.
# We can use this function to check for this scenario in order to handle it specially.
if not torch.jit.is_scripting() and x.device.type == "meta":
torch_dispatch_mode_stack = (
torch.utils._python_dispatch._get_current_dispatch_mode_stack()
)
return any(
type(x) == torch.utils.flop_counter.FlopCounterMode
for x in torch_dispatch_mode_stack
)
return False
def _autocast(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
[Autocasting SDPA for NJT]
Normal autocasting doesn't work for NJT+SDPA right now:
* NJT intercepts the __torch_function__ call for scaled_dot_product_attention, which happens
before we get to any aten ops or dispatcher logic; then the torch_function logic calls into
efficient attention or flash attention. So, autocasting on the scaled_dot_product_attention
op won't work because we never see that aten op.
* If we put autocasting on `_flash_attention_forward`, then we'll get autocasting to run, but
the kernel selection logic in torch_function handling (i.e. jagged_scaled_dot_product_attention)
won't work correctly: the kernel selection logic will run before autocasting, and choose
a kernel based on the un-autocasted dtypes; but then autocasting will run and the actual
attention computation will happen in a different dtype.
An alternative is to just change the backend selection logic for SDPA+NJT to be autocast-aware
and rely on autocasting to do the actual conversions for flash attention / efficient attention.
However, by manually doing the actual autocast before the backend selection, we ensure that the
autocast handling for backend selection doesn't diverge from the autocast handling for the
actual dtype conversions.
"""
device_type = query.device.type
# meta device is not supported by autocast, so break early for it
if _is_computing_meta_flops(query) or not torch.is_autocast_enabled(device_type):
return query, key, value, attn_mask
def cvt(x):
if x is None:
return x
target_dtype = torch.get_autocast_dtype(device_type)
if (
(not x.dtype.is_floating_point)
or x.dtype == target_dtype
or x.dtype == torch.float64
):
return x
return x.to(target_dtype)
return cvt(query), cvt(key), cvt(value), cvt(attn_mask)
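# Conversion sketch (assuming CUDA autocast is active with dtype=torch.bfloat16,
# hypothetical float32 inputs): floating-point tensors are cast to the autocast
# dtype before backend selection, while bool masks, integer tensors, and float64
# tensors pass through unchanged.
#
#   with torch.autocast("cuda", dtype=torch.bfloat16):
#       q2, k2, v2, m2 = _autocast(q, k, v, attn_mask=None)
#   q2.dtype  # torch.bfloat16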
def jagged_scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p=0.0,
is_causal=False,
scale=None,
enable_gqa=False,
):
query, key, value, attn_mask = _autocast(query, key, value, attn_mask)
_validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
# for mypy, ugh
assert (
isinstance(query, NestedTensor)
and isinstance(key, NestedTensor)
and isinstance(value, NestedTensor)
)
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
# Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
# second batch dim instead). For this case, we can just send the dense buffers through
# vanilla SDPA.
if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
output = F.scaled_dot_product_attention(
query.values(),
key.values(),
value.values(),
attn_mask=(
attn_mask.values() if isinstance(attn_mask, NestedTensor) else attn_mask
),
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
)
return nested_view_from_values_offsets(output, query.offsets())
compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
backend_choice = _select_sdp_backend(
query, key, value, attn_mask, dropout_p, is_causal, enable_gqa
)
if _is_computing_meta_flops(query):
# Backend choice will probably not be correct if we have a meta device,
# because backend choice is device-aware. In this case, we mostly just
# want to avoid using math backend (which does a .item() call).
# Arbitrarily choose flash attention.
backend_choice = SDPBackend.FLASH_ATTENTION
if backend_choice == SDPBackend.FLASH_ATTENTION:
og_size = query.size(-1)
query_padded = _pad_last_dim(query, 8, False)
key_padded = _pad_last_dim(key, 8, False)
value_padded = _pad_last_dim(value, 8, False)
# We need to calculate the scale based off the OG head dim size
og_scale = _calculate_scale(query, scale)
(
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
output_nt_info,
) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)
(
attention,
logsumexp,
philox_seed,
philox_offset,
debug_attn_mask,
) = torch.ops.aten._flash_attention_forward(
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
dropout_p,
is_causal,
False,
scale=og_scale,
)
# Reshape output to convert nnz to batch_size and seq_len
attention = nested_view_from_values_offsets(
attention, # output from flash_attn is [total_q, num_heads, head_size_og]
output_nt_info["offsets"],
min_seqlen=output_nt_info["_min_seqlen"],
max_seqlen=output_nt_info["_max_seqlen"],
).transpose(1, 2)
return _post_process_flash_output(attention, og_size)
elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
(
query_reshaped,
key_reshaped,
value_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
output_nt_info,
) = _sdpa_nested_preprocessing(query, key, value)
(
attention,
log_sumexp,
seed,
offset,
max_seqlen_q,
max_seqlen_batch_kv,
) = torch.ops.aten._efficient_attention_forward(
query_reshaped.unsqueeze(0),
key_reshaped.unsqueeze(0),
value_reshaped.unsqueeze(0),
None,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
max_seqlen_batch_q,
max_seqlen_batch_kv,
dropout_p,
int(is_causal),
compute_logsumexp,
scale=scale,
)
# Reshape output to convert nnz to batch_size and seq_len
return nested_view_from_values_offsets(
attention.squeeze(0),
output_nt_info["offsets"],
min_seqlen=output_nt_info["_min_seqlen"],
max_seqlen=output_nt_info["_max_seqlen"],
).transpose(1, 2)
elif backend_choice == SDPBackend.MATH:
# save the offsets and shape of the inputs, so we can reshape the final output
# query @ key = attn: [B, D1, j0, D'] @ [B, D1, D', j1] = [B, D1, j0, j1]
# attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
offsets = query.offsets()
d1 = query._size[1]
d2 = value._size[-1]
min_seqlen_tensor = query._metadata_cache.get(
"min_seqlen", None
) # type: ignore[attr-defined]
max_seqlen_tensor = query._metadata_cache.get(
"max_seqlen", None
) # type: ignore[attr-defined]
# convert jagged layout Nested Tensor to strided layout Nested Tensor
# which supports the math implementation of SDPA
def get_strided_layout_nested_tensor(jagged_layout_nt):
lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1]
transpose = torch.transpose(jagged_layout_nt, 1, 2)
tensor_list = transpose.values().split(list(lengths), dim=0)
strided_nt = torch.nested.as_nested_tensor(list(tensor_list))
strided_nt = strided_nt.transpose(1, 2).contiguous()
return strided_nt
query = get_strided_layout_nested_tensor(query)
key = get_strided_layout_nested_tensor(key)
value = get_strided_layout_nested_tensor(value)
attn_out = torch._scaled_dot_product_attention_math(
query, key, value, attn_mask, dropout_p, is_causal, scale=scale
)[0]
from torch.nested._internal.nested_tensor import _load_val_from_tensor
# convert strided layout Nested Tensor back to jagged layout Nested Tensor
attn_out = attn_out.transpose(1, 2).contiguous().values()
attn_out = attn_out.view(-1, d1, d2)
attn_out = nested_view_from_values_offsets(
attn_out,
offsets,
min_seqlen=(
None
if min_seqlen_tensor is None
else _load_val_from_tensor(min_seqlen_tensor)
),
max_seqlen=(
None
if max_seqlen_tensor is None
else _load_val_from_tensor(max_seqlen_tensor)
),
).transpose(1, 2)
return attn_out
else:
raise RuntimeError(
"No viable backend for scaled_dot_product_attention was found."
)
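# End-to-end usage sketch (a minimal illustration assuming the public
# jagged-layout NJT constructor; shapes and values are hypothetical):
#
#   import torch
#   import torch.nn.functional as F
#
#   seqs = [torch.randn(2, 8, 16), torch.randn(5, 8, 16)]        # per-sequence [seq_len, H, D]
#   njt = torch.nested.nested_tensor(seqs, layout=torch.jagged)  # [B, j, H, D]
#   q = k = v = njt.transpose(1, 2)                              # [B, H, j, D]
#   out = F.scaled_dot_product_attention(q, k, v)
#   # NestedTensor.__torch_function__ routes this call to
#   # jagged_scaled_dot_product_attention above; `out` is again a jagged NJT.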