I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,703 @@
# mypy: allow-untyped-defs
# The Tensor classes are added to this module by python_tensor.cpp
# A workaround to support both TorchScript and MyPy:
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch import Tensor
from torch._C import _add_docstr, _sparse # type: ignore[attr-defined]
# Semi structured sparsity support
from .semi_structured import (
SparseSemiStructuredTensor,
SparseSemiStructuredTensorCUSPARSELT,
SparseSemiStructuredTensorCUTLASS,
to_sparse_semi_structured,
)
if TYPE_CHECKING:
from torch.types import _dtype as DType
DimOrDims = Optional[Union[int, Tuple[int, ...], List[int]]]
else:
# The JIT doesn't understand Union, nor torch.dtype here
DType = int
DimOrDims = Optional[Tuple[int]]
__all__ = [
"addmm",
"check_sparse_tensor_invariants",
"mm",
"sum",
"softmax",
"solve",
"log_softmax",
"SparseSemiStructuredTensor",
"SparseSemiStructuredTensorCUTLASS",
"SparseSemiStructuredTensorCUSPARSELT",
"to_sparse_semi_structured",
"as_sparse_gradcheck",
]
addmm = _add_docstr(
_sparse._sparse_addmm,
r"""
sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor
This function does exact same thing as :func:`torch.addmm` in the forward,
except that it supports backward for sparse COO matrix :attr:`mat1`.
When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
When inputs are COO tensors, this function also supports backward for both inputs.
Supports both CSR and COO storage formats.
.. note::
This function doesn't support computing derivaties with respect to CSR matrices.
Args:
mat (Tensor): a dense matrix to be added
mat1 (Tensor): a sparse matrix to be multiplied
mat2 (Tensor): a dense matrix to be multiplied
beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
""",
)
mm = _add_docstr(
_sparse._sparse_mm,
r"""
Performs a matrix multiplication of the sparse matrix :attr:`mat1`
and the (sparse or strided) matrix :attr:`mat2`. Similar to :func:`torch.mm`, if :attr:`mat1` is a
:math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a
:math:`(n \times p)` tensor.
When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
When inputs are COO tensors, this function also supports backward for both inputs.
Supports both CSR and COO storage formats.
.. note::
This function doesn't support computing derivaties with respect to CSR matrices.
This function also additionally accepts an optional :attr:`reduce` argument that allows
specification of an optional reduction operation, mathematically performs the following operation:
.. math::
z_{ij} = \bigoplus_{k = 0}^{K - 1} x_{ik} y_{kj}
where :math:`\bigoplus` defines the reduce operator. :attr:`reduce` is implemented only for
CSR storage format on CPU device.
Args:
mat1 (Tensor): the first sparse matrix to be multiplied
mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense
reduce (str, optional): the reduction operation to apply for non-unique indices
(:obj:`"sum"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`). Default :obj:`"sum"`.
Shape:
The format of the output tensor of this function follows:
- sparse x sparse -> sparse
- sparse x dense -> dense
Example::
>>> a = torch.tensor([[1., 0, 2], [0, 3, 0]]).to_sparse().requires_grad_()
>>> a
tensor(indices=tensor([[0, 0, 1],
[0, 2, 1]]),
values=tensor([1., 2., 3.]),
size=(2, 3), nnz=3, layout=torch.sparse_coo, requires_grad=True)
>>> b = torch.tensor([[0, 1.], [2, 0], [0, 0]], requires_grad=True)
>>> b
tensor([[0., 1.],
[2., 0.],
[0., 0.]], requires_grad=True)
>>> y = torch.sparse.mm(a, b)
>>> y
tensor([[0., 1.],
[6., 0.]], grad_fn=<SparseAddmmBackward0>)
>>> y.sum().backward()
>>> a.grad
tensor(indices=tensor([[0, 0, 1],
[0, 2, 1]]),
values=tensor([1., 0., 2.]),
size=(2, 3), nnz=3, layout=torch.sparse_coo)
>>> c = a.detach().to_sparse_csr()
>>> c
tensor(crow_indices=tensor([0, 2, 3]),
col_indices=tensor([0, 2, 1]),
values=tensor([1., 2., 3.]), size=(2, 3), nnz=3,
layout=torch.sparse_csr)
>>> y1 = torch.sparse.mm(c, b, 'sum')
>>> y1
tensor([[0., 1.],
[6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
>>> y2 = torch.sparse.mm(c, b, 'max')
>>> y2
tensor([[0., 1.],
[6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
""",
)
sampled_addmm = _add_docstr(
_sparse.sparse_sampled_addmm,
r"""
sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) -> Tensor
Performs a matrix multiplication of the dense matrices :attr:`mat1` and :attr:`mat2` at the locations
specified by the sparsity pattern of :attr:`input`. The matrix :attr:`input` is added to the final result.
Mathematically this performs the following operation:
.. math::
\text{out} = \alpha\ (\text{mat1} \mathbin{@} \text{mat2})*\text{spy}(\text{input}) + \beta\ \text{input}
where :math:`\text{spy}(\text{input})` is the sparsity pattern matrix of :attr:`input`, :attr:`alpha`
and :attr:`beta` are the scaling factors.
:math:`\text{spy}(\text{input})` has value 1 at the positions where :attr:`input` has non-zero values, and 0 elsewhere.
.. note::
:attr:`input` must be a sparse CSR tensor. :attr:`mat1` and :attr:`mat2` must be dense tensors.
Args:
input (Tensor): a sparse CSR matrix of shape `(m, n)` to be added and used to compute
the sampled matrix multiplication
mat1 (Tensor): a dense matrix of shape `(m, k)` to be multiplied
mat2 (Tensor): a dense matrix of shape `(k, n)` to be multiplied
Keyword args:
beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`)
alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
Examples::
>>> input = torch.eye(3, device='cuda').to_sparse_csr()
>>> mat1 = torch.randn(3, 5, device='cuda')
>>> mat2 = torch.randn(5, 3, device='cuda')
>>> torch.sparse.sampled_addmm(input, mat1, mat2)
tensor(crow_indices=tensor([0, 1, 2, 3]),
col_indices=tensor([0, 1, 2]),
values=tensor([ 0.2847, -0.7805, -0.1900]), device='cuda:0',
size=(3, 3), nnz=3, layout=torch.sparse_csr)
>>> torch.sparse.sampled_addmm(input, mat1, mat2).to_dense()
tensor([[ 0.2847, 0.0000, 0.0000],
[ 0.0000, -0.7805, 0.0000],
[ 0.0000, 0.0000, -0.1900]], device='cuda:0')
>>> torch.sparse.sampled_addmm(input, mat1, mat2, beta=0.5, alpha=0.5)
tensor(crow_indices=tensor([0, 1, 2, 3]),
col_indices=tensor([0, 1, 2]),
values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0',
size=(3, 3), nnz=3, layout=torch.sparse_csr)
""",
)
def sum(input: Tensor, dim: DimOrDims = None, dtype: Optional[DType] = None) -> Tensor:
r"""Return the sum of each row of the given sparse tensor.
Returns the sum of each row of the sparse tensor :attr:`input` in the given
dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions,
reduce over all of them. When sum over all ``sparse_dim``, this method
returns a dense tensor instead of a sparse tensor.
All summed :attr:`dim` are squeezed (see :func:`torch.squeeze`), resulting an output
tensor having :attr:`dim` fewer dimensions than :attr:`input`.
During backward, only gradients at ``nnz`` locations of :attr:`input`
will propagate back. Note that the gradients of :attr:`input` is coalesced.
Args:
input (Tensor): the input sparse tensor
dim (int or tuple of ints): a dimension or a list of dimensions to reduce. Default: reduce
over all dims.
dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor.
Default: dtype of :attr:`input`.
Example::
>>> nnz = 3
>>> dims = [5, 5, 2, 3]
>>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)),
torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz)
>>> V = torch.randn(nnz, dims[2], dims[3])
>>> size = torch.Size(dims)
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> S = torch.sparse_coo_tensor(I, V, size)
>>> S
tensor(indices=tensor([[2, 0, 3],
[2, 4, 1]]),
values=tensor([[[-0.6438, -1.6467, 1.4004],
[ 0.3411, 0.0918, -0.2312]],
[[ 0.5348, 0.0634, -2.0494],
[-0.7125, -1.0646, 2.1844]],
[[ 0.1276, 0.1874, -0.6334],
[-1.9682, -0.5340, 0.7483]]]),
size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo)
# when sum over only part of sparse_dims, return a sparse tensor
>>> torch.sparse.sum(S, [1, 3])
tensor(indices=tensor([[0, 2, 3]]),
values=tensor([[-1.4512, 0.4073],
[-0.8901, 0.2017],
[-0.3183, -1.7539]]),
size=(5, 2), nnz=3, layout=torch.sparse_coo)
# when sum over all sparse dim, return a dense tensor
# with summed dims squeezed
>>> torch.sparse.sum(S, [0, 1, 3])
tensor([-2.6596, -1.1450])
"""
if dtype is None:
if dim is not None:
return torch._sparse_sum(input, dim)
else:
return torch._sparse_sum(input)
else:
if dim is not None:
return torch._sparse_sum(input, dim, dtype=dtype)
else:
return torch._sparse_sum(input, dtype=dtype)
softmax = _add_docstr(
_sparse._sparse_softmax,
r"""
sparse.softmax(input, dim, *, dtype=None) -> Tensor
Applies a softmax function.
Softmax is defined as:
:math:`\text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)}`
where :math:`i, j` run over sparse tensor indices and unspecified
entries are ignores. This is equivalent to defining unspecified
entries as negative infinity so that :math:`exp(x_k) = 0` when the
entry with index :math:`k` has not specified.
It is applied to all slices along `dim`, and will re-scale them so
that the elements lie in the range `[0, 1]` and sum to 1.
Args:
input (Tensor): input
dim (int): A dimension along which softmax will be computed.
dtype (:class:`torch.dtype`, optional): the desired data type
of returned tensor. If specified, the input tensor is
casted to :attr:`dtype` before the operation is
performed. This is useful for preventing data type
overflows. Default: None
""",
)
spsolve = _add_docstr(
_sparse._spsolve,
r"""
sparse.spsolve(input, other, *, left=True) -> Tensor
Computes the solution of a square system of linear equations with
a unique solution. Its purpose is similar to :func:`torch.linalg.solve`,
except that the system is defined by a sparse CSR matrix with layout
`sparse_csr`.
Args:
input (Tensor): a sparse CSR matrix of shape `(n, n)` representing the
coefficients of the linear system.
other (Tensor): a dense matrix of shape `(n, )` representing the right-hand
side of the linear system.
left (bool, optional): whether to solve the system for `input @ out = other`
(default) or `out @ input = other`. Only `left=True` is supported.
""",
)
log_softmax = _add_docstr(
_sparse._sparse_log_softmax,
r"""
sparse.log_softmax(input, dim, *, dtype=None) -> Tensor
Applies a softmax function followed by logarithm.
See :class:`~torch.sparse.softmax` for more details.
Args:
input (Tensor): input
dim (int): A dimension along which softmax will be computed.
dtype (:class:`torch.dtype`, optional): the desired data type
of returned tensor. If specified, the input tensor is
casted to :attr:`dtype` before the operation is
performed. This is useful for preventing data type
overflows. Default: None
""",
)
spdiags = _add_docstr(
_sparse._spdiags,
r"""
sparse.spdiags(diagonals, offsets, shape, layout=None) -> Tensor
Creates a sparse 2D tensor by placing the values from rows of
:attr:`diagonals` along specified diagonals of the output
The :attr:`offsets` tensor controls which diagonals are set.
- If :attr:`offsets[i]` = 0, it is the main diagonal
- If :attr:`offsets[i]` < 0, it is below the main diagonal
- If :attr:`offsets[i]` > 0, it is above the main diagonal
The number of rows in :attr:`diagonals` must match the length of :attr:`offsets`,
and an offset may not be repeated.
Args:
diagonals (Tensor): Matrix storing diagonals row-wise
offsets (Tensor): The diagonals to be set, stored as a vector
shape (2-tuple of ints): The desired shape of the result
Keyword args:
layout (:class:`torch.layout`, optional): The desired layout of the
returned tensor. ``torch.sparse_coo``, ``torch.sparse_csc`` and ``torch.sparse_csr``
are supported. Default: ``torch.sparse_coo``
Examples:
Set the main and first two lower diagonals of a matrix::
>>> diags = torch.arange(9).reshape(3, 3)
>>> diags
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3))
>>> s
tensor(indices=tensor([[0, 1, 2, 1, 2, 2],
[0, 1, 2, 0, 1, 0]]),
values=tensor([0, 1, 2, 3, 4, 6]),
size=(3, 3), nnz=6, layout=torch.sparse_coo)
>>> s.to_dense()
tensor([[0, 0, 0],
[3, 1, 0],
[6, 4, 2]])
Change the output layout::
>>> diags = torch.arange(9).reshape(3, 3)
>>> diags
tensor([[0, 1, 2],[3, 4, 5], [6, 7, 8])
>>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3), layout=torch.sparse_csr)
>>> s
tensor(crow_indices=tensor([0, 1, 3, 6]),
col_indices=tensor([0, 0, 1, 0, 1, 2]),
values=tensor([0, 3, 1, 6, 4, 2]), size=(3, 3), nnz=6,
layout=torch.sparse_csr)
>>> s.to_dense()
tensor([[0, 0, 0],
[3, 1, 0],
[6, 4, 2]])
Set partial diagonals of a large output::
>>> diags = torch.tensor([[1, 2], [3, 4]])
>>> offsets = torch.tensor([0, -1])
>>> torch.sparse.spdiags(diags, offsets, (5, 5)).to_dense()
tensor([[1, 0, 0, 0, 0],
[3, 2, 0, 0, 0],
[0, 4, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]])
.. note::
When setting the values along a given diagonal the index into the diagonal
and the index into the row of :attr:`diagonals` is taken as the
column index in the output. This has the effect that when setting a diagonal
with a positive offset `k` the first value along that diagonal will be
the value in position `k` of the row of :attr:`diagonals`
Specifying a positive offset::
>>> diags = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
>>> torch.sparse.spdiags(diags, torch.tensor([0, 1, 2]), (5, 5)).to_dense()
tensor([[1, 2, 3, 0, 0],
[0, 2, 3, 0, 0],
[0, 0, 3, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]])
""",
)
class check_sparse_tensor_invariants:
"""A tool to control checking sparse tensor invariants.
The following options exists to manage sparsr tensor invariants
checking in sparse tensor construction:
1. Using a context manager:
.. code:: python
with torch.sparse.check_sparse_tensor_invariants():
run_my_model()
2. Using a procedural approach:
.. code:: python
prev_checks_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled()
torch.sparse.check_sparse_tensor_invariants.enable()
run_my_model()
if not prev_checks_enabled:
torch.sparse.check_sparse_tensor_invariants.disable()
3. Using function decoration:
.. code:: python
@torch.sparse.check_sparse_tensor_invariants()
def run_my_model():
...
run_my_model()
4. Using ``check_invariants`` keyword argument in sparse tensor constructor call.
For example:
>>> torch.sparse_csr_tensor([0, 1, 3], [0, 1], [1, 2], check_invariants=True)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: `crow_indices[..., -1] == nnz` is not satisfied.
"""
@staticmethod
def is_enabled():
r"""Return True if the sparse tensor invariants checking is enabled.
.. note::
Use :func:`torch.sparse.check_sparse_tensor_invariants.enable` or
:func:`torch.sparse.check_sparse_tensor_invariants.disable` to
manage the state of the sparse tensor invariants checks.
"""
return torch._C._check_sparse_tensor_invariants()
@staticmethod
def enable():
r"""Enable sparse tensor invariants checking in sparse tensor constructors.
.. note::
By default, the sparse tensor invariants checks are disabled. Use
:func:`torch.sparse.check_sparse_tensor_invariants.is_enabled` to
retrieve the current state of sparse tensor invariants checking.
.. note::
The sparse tensor invariants check flag is effective to all sparse
tensor constructors, both in Python and ATen.
The flag can be locally overridden by the ``check_invariants``
optional argument of the sparse tensor constructor functions.
"""
torch._C._set_check_sparse_tensor_invariants(True)
@staticmethod
def disable():
r"""Disable sparse tensor invariants checking in sparse tensor constructors.
See :func:`torch.sparse.check_sparse_tensor_invariants.enable` for more information.
"""
torch._C._set_check_sparse_tensor_invariants(False)
# context manager support
def __init__(self, enable=True):
self.state = enable
self.saved_state: Optional[bool] = None
def __enter__(self):
if self.saved_state is not None:
raise RuntimeError(
"This context manager instance is already activated."
" Use a different context manager instance for context nesting."
)
self.saved_state = self.is_enabled()
torch._C._set_check_sparse_tensor_invariants(self.state)
def __exit__(self, type, value, traceback):
assert self.saved_state is not None
torch._C._set_check_sparse_tensor_invariants(self.saved_state)
self.saved_state = None
# decorator support
def __call__(self, mth):
def test_mth(*args, **kwargs):
with type(self)(self.state):
return mth(*args, **kwargs)
return test_mth
def as_sparse_gradcheck(gradcheck):
"""Decorate function, to extend gradcheck for sparse tensors.
Decorator for torch.autograd.gradcheck or its functools.partial
variants that extends the gradcheck function with support to input
functions that operate on or/and return sparse tensors.
The specified gradcheck function itself is guaranteed to operate
on strided tensors only.
For example:
>>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)
>>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse_coo().requires_grad_(True)
>>> gradcheck(lambda x: x.to_sparse_csr(), x)
True
"""
def gradcheck_with_sparse_support(func, inputs, **kwargs):
"""
Create gradcheck with support for sparse tensors.
Same as :func:`torch.autograd.gradcheck` but with sparse tensors inputs and outputs support.
"""
masked = kwargs.pop("masked", False)
sparse_layouts = {
torch.sparse_coo,
torch.sparse_csr,
torch.sparse_csc,
torch.sparse_bsr,
torch.sparse_bsc,
}
sparse_compressed_layouts = {
torch.sparse_csr,
torch.sparse_csc,
torch.sparse_bsr,
torch.sparse_bsc,
}
sparse_block_layouts = {torch.sparse_bsr, torch.sparse_bsc}
STRIDED_REPRESENTATION = "__STRIDED_REPRESENTATION__"
def convert_to_strided_representation(args):
"""Convert differentiable non-strided tensors to a representation containing differentiable strided tensors."""
if not isinstance(args, (list, tuple)):
args = (args,)
new_args: List[Any] = []
for obj in args:
if (
isinstance(obj, torch.Tensor)
and obj.requires_grad
and obj.layout in sparse_layouts
):
d = dict(layout=obj.layout, shape=obj.shape)
if not masked:
# Materialize unspecified elements with zero values
batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim()
blocksize = (
obj.values().shape[batch_dim + 1 : batch_dim + 3]
if obj.layout in sparse_block_layouts
else None
)
full_mask = torch.ones(
obj.shape, device=obj.device, dtype=torch.bool
).to_sparse(
layout=obj.layout,
blocksize=blocksize,
dense_dim=obj.dense_dim(),
)
obj = obj.to_dense().sparse_mask(full_mask)
if obj.layout is torch.sparse_coo:
d.update(
indices=obj._indices(), is_coalesced=obj.is_coalesced()
)
values = obj._values()
elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
d.update(
compressed_indices=obj.crow_indices(),
plain_indices=obj.col_indices(),
)
values = obj.values()
else:
d.update(
compressed_indices=obj.ccol_indices(),
plain_indices=obj.row_indices(),
)
values = obj.values()
new_args.extend(
(STRIDED_REPRESENTATION, d, values.requires_grad_(True))
)
else:
new_args.append(obj)
return tuple(new_args)
def restore_from_strided_representation(args):
"""Restore non-strided differentiable tensosr from their strided representations."""
new_args = []
args = list(args)
while args:
a = args.pop(0)
if a == STRIDED_REPRESENTATION:
d, values = args.pop(0), args.pop(0)
if d["layout"] is torch.sparse_coo:
a = torch.sparse_coo_tensor(
d["indices"],
values,
size=d["shape"],
is_coalesced=d["is_coalesced"],
)
elif d["layout"] in sparse_compressed_layouts:
a = torch.sparse_compressed_tensor(
d["compressed_indices"],
d["plain_indices"],
values,
size=d["shape"],
layout=d["layout"],
)
else:
raise NotImplementedError(
f'conversion of {d["layout"]} strided representation to tensor'
)
new_args.append(a)
return tuple(new_args)
def func_wrapper(*args, **kwargs):
restored_args = restore_from_strided_representation(args)
# convert differentiable output sparse tensors to strided
# tensors:
outputs = func(*restored_args, **kwargs)
strided_outputs = (
tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,)
)
strided_outputs = tuple(
(
o.to_dense(masked_grad=masked)
if isinstance(o, torch.Tensor)
and o.requires_grad
and o.layout in sparse_layouts
else o
)
for o in strided_outputs
)
return (
strided_outputs
if isinstance(outputs, (list, tuple))
else strided_outputs[0]
)
args = (func_wrapper, convert_to_strided_representation(inputs))
return gradcheck(*args, **kwargs)
return gradcheck_with_sparse_support

View File

@ -0,0 +1,356 @@
# mypy: allow-untyped-defs
import torch
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
"""
This is PyTorch implementation of main part of reorder_meta()
function, from tools/util/include/cutlass/util/host_reorder.h file
of CUTLASS source tree. Furthermore, CUTLASS template for sparse
GEMM decides upon layout of this matrix, and at the moment for the
sparse GEMM executed on tensor cores, this is layout described by
ColumnMajorInterleaved<2> data structure, in
include/cutlass/layout/matrix.h of CUTLASS source tree. The
reordering of meta matrix into meta_reordered matrix calculated
according to these segments of CUTLASS code is re-implemented here.
Note that this calculation produces offsets for scattering metadata
matrix elements into reordered metadata matrix elements (or,
equivalently, for gathering reordered metadata matrix element back
into metadata matrix elements).
"""
dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
# Reorder the rows, then swizzle the 2x2 blocks.
group = 32 if meta_dtype.itemsize == 2 else 16
interweave = 4 if meta_dtype.itemsize == 2 else 2
dst_rows = (
dst_rows // group * group
+ (dst_rows % 8) * interweave
+ (dst_rows % group) // 8
)
topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
dst_rows += topright - bottomleft
dst_cols -= topright - bottomleft
# Assumed that meta tensor is to be stored in CUTLASS
# InterleavedColumnMajor layout, and reverse engineered
# corresponding code to store values into this tensor.
interleave = 2
cols_maj = dst_cols // interleave
cols_min = dst_cols % interleave
return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
def sparse_semi_structured_from_dense_cutlass(dense):
"""
This function converts dense matrix into sparse semi-structured
representation, producing "compressed" matrix, in the layout used by
CUTLASS backend, and corresponding metadata matrix.
"""
if dense.dim() != 2:
raise RuntimeError(
f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"
)
m, k = dense.shape
device = dense.device
meta_dtype = torch.int8
if dense.dtype == torch.int8:
meta_dtype = torch.int32
elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
meta_dtype = torch.int16
else:
raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
if quadbits_per_meta_elem not in (4, 8):
raise RuntimeError("Invalid number of elements per meta element calculated")
if meta_dtype == torch.int32:
if m % 16 != 0:
raise RuntimeError(
f"Number of rows of dense matrix {m} must be divisible by 16"
)
else:
if m % 32 != 0:
raise RuntimeError(
f"Number of rows of dense matrix {m} must be divisible by 32"
)
if k % (4 * quadbits_per_meta_elem) != 0:
raise RuntimeError(
f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"
)
if dense.dtype != torch.float:
ksparse = 4
dense_4 = dense.view(-1, k // ksparse, ksparse)
m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
else:
ksparse = 2
dense_2 = dense.view(-1, k // ksparse, ksparse)
m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
meta_ncols = k // (ksparse * quadbits_per_meta_elem)
# Encoding quadruples of True/False values as follows:
# [True, True, False, False] -> 0b0100
# [True, False, True, False] -> 0b1000
# [False, True, True, False] -> 0b1001
# [True, False, False, True ] -> 0b1100
# [False, True, False, True ] -> 0b1101
# [False, False, True, True ] -> 0b1110
# Thus, lower two bits in the encoding are index of the True value
# at the lowest index in the quadruple, and the higher two bits in
# the encoding are index of the other True value in the quadruple.
# In case there are less than two True values, than False value or
# values at some index or indices are considered True for the
# encoding. In case there are more than two True values, then the
# excess True value(s) at some indices are considered False for
# the encoding. The exact encodings used for these cases are as
# follows:
# [False, False, False, False] -> 0b1110
# [False, False, False, True ] -> 0b1110
# [False, False, True, False] -> 0b1110
# [False, True, False, False] -> 0b1001
# [False, True, True, True ] -> 0b1101
# [True, False, False, False] -> 0b1000
# [True, False, True, True ] -> 0b1100
# [True, True, False, True ] -> 0b0100
# [True, True, True, False] -> 0b0100
# [True, True, True, True ] -> 0b0100
# These particular encodings are chosen, with the help of Espresso
# logic minimizer software, for the purpose of minimization of
# corresponding Boolean functions, that translate non-zero flags
# into encoding bits. Note also possible choices for the first
# and last of these encodings were limited only to (0b0100,
# 0b1110), in order to produce valid encodings for 1:2 sparsity
# case.
expr0 = m0 & m1
expr1 = ~m0 & m1
expr2 = ~m0 & ~m1
bit0 = expr1
bit1 = expr2
bit2 = expr0 | expr2 | m3
bit3 = expr1 | ~m1
idxs0 = bit0 | (bit1.to(torch.int64) << 1)
idxs1 = bit2 | (bit3.to(torch.int64) << 1)
if dense.dtype != torch.float:
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
else:
sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined]
meta_4 = idxs0 | (idxs1 << 2)
meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
if quadbits_per_meta_elem == 4:
meta = (
meta_n[:, :, 0]
| (meta_n[:, :, 1] << 4)
| (meta_n[:, :, 2] << 8)
| (meta_n[:, :, 3] << 12)
)
elif quadbits_per_meta_elem == 8:
meta = (
meta_n[:, :, 0]
| (meta_n[:, :, 1] << 4)
| (meta_n[:, :, 2] << 8)
| (meta_n[:, :, 3] << 12)
| (meta_n[:, :, 4] << 16)
| (meta_n[:, :, 5] << 20)
| (meta_n[:, :, 6] << 24)
| (meta_n[:, :, 7] << 28)
)
# Reorder meta tensor elements.
meta_reordered = meta.new_empty((m * meta_ncols,)) # type: ignore[possibly-undefined]
meta_offsets = _calculate_meta_reordering_scatter_offsets(
m, meta_ncols, meta_dtype, device
)
meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
return (sparse, meta_reordered.view(m, meta_ncols))
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
"""
This function performs reverse of the function above - it
reconstructs dense matrix from a pair of "compressed" matrix, given
in the layout used by CUTLASS backend, and accompanying metadata
matrix.
"""
if sparse.dim() != 2:
raise RuntimeError(
f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"
)
m, k = sparse.shape
device = sparse.device
if meta_reordered.dim() != 2:
raise RuntimeError(
f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"
)
if meta_reordered.device != device:
raise RuntimeError(
f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"
)
meta_dtype = meta_reordered.dtype
if meta_dtype not in (torch.int16, torch.int32):
raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
if sparse.dtype != torch.float:
ksparse = 4
else:
ksparse = 2
meta_nrows, meta_ncols = meta_reordered.shape
if meta_nrows != m:
raise RuntimeError(
f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}"
)
if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
raise RuntimeError(
f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "
"expected according to the number of columns of meta matrix"
)
# Undo meta tensor elements reordering.
meta_offsets = _calculate_meta_reordering_scatter_offsets(
m, meta_ncols, meta_dtype, device
)
meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
# Unpack sparse tensor back to original dense tensor, using
# information provided by meta tensor. Note that torch.float
# datatype is handled pretty much the same as
# torch.half/torch.bfloat16, as metadata for a pair of torch.float
# value is encoded as if underlying 8 bytes contain four
# torch.half/torch.bfloat16 values, where either first two or last
# two are zeros.
meta_2 = torch.empty(
(m, meta_ncols, 2 * quadbits_per_meta_elem),
dtype=meta_dtype,
device=device,
)
if quadbits_per_meta_elem == 4:
meta_2[:, :, 0] = meta & 0b11
meta_2[:, :, 1] = (meta >> 2) & 0b11
meta_2[:, :, 2] = (meta >> 4) & 0b11
meta_2[:, :, 3] = (meta >> 6) & 0b11
meta_2[:, :, 4] = (meta >> 8) & 0b11
meta_2[:, :, 5] = (meta >> 10) & 0b11
meta_2[:, :, 6] = (meta >> 12) & 0b11
meta_2[:, :, 7] = (meta >> 14) & 0b11
elif quadbits_per_meta_elem == 8:
meta_2[:, :, 0] = meta & 0b11
meta_2[:, :, 1] = (meta >> 2) & 0b11
meta_2[:, :, 2] = (meta >> 4) & 0b11
meta_2[:, :, 3] = (meta >> 6) & 0b11
meta_2[:, :, 4] = (meta >> 8) & 0b11
meta_2[:, :, 5] = (meta >> 10) & 0b11
meta_2[:, :, 6] = (meta >> 12) & 0b11
meta_2[:, :, 7] = (meta >> 14) & 0b11
meta_2[:, :, 8] = (meta >> 16) & 0b11
meta_2[:, :, 9] = (meta >> 18) & 0b11
meta_2[:, :, 10] = (meta >> 20) & 0b11
meta_2[:, :, 11] = (meta >> 22) & 0b11
meta_2[:, :, 12] = (meta >> 24) & 0b11
meta_2[:, :, 13] = (meta >> 26) & 0b11
meta_2[:, :, 14] = (meta >> 28) & 0b11
meta_2[:, :, 15] = (meta >> 30) & 0b11
dense_offsets = meta_2.view(-1) + (
torch.arange(0, 2 * m * k // ksparse, device=device) * 4
).view(-1, 1).repeat(1, 2).view(-1)
dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
if sparse.dtype != torch.float:
dense.scatter_(0, dense_offsets, sparse.view(-1))
else:
dense.view(torch.half).scatter_(
0, dense_offsets, sparse.view(torch.half).view(-1)
)
return dense.view(m, 2 * k)
def _sparse_semi_structured_tile(dense):
"""
This function computes a 2:4 sparse tile by greedily taking the largest values.
Since we take the largest values greedily, how the sorting algorithm handles duplicates affects
the ultimate sparsity pattern.
Note that this function does not have the same sorting semantics as our CUDA backend,
which is exposed via `torch._sparse_semi_structured_tile` and thus returns a different pattern.
"""
def greedy_prune_tile(tile):
num_kept_row = [0, 0, 0, 0]
num_kept_col = [0, 0, 0, 0]
for x in tile.flatten().sort(descending=True, stable=True).indices:
r, c = x // 4, x % 4
if num_kept_row[r] < 2 and num_kept_col[c] < 2:
num_kept_row[r] += 1
num_kept_col[c] += 1
else:
tile[r, c] = 0
for batch in dense.unfold(0, 4, 4).unfold(1, 4, 4):
for tile in batch:
greedy_prune_tile(tile)
return dense
def _compute_compressed_swizzled_bitmask(dense):
"""
Calculates the compressed swizzled bitmask from a dense tensor
"""
# first we need to convert the dense tensor to a bitmask
int_bitmask = dense.bool().to(torch.uint8)
# Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles:
# A, B, C and D, as displayed in the following schema:
# +---+---+
# | A | B |
# +---+---+
# | C | D |
# +---+---+
# we first need to split into the 8x8 tiles
bitmask_8x8_chunks = int_bitmask.unfold(0, 8, 8).unfold(1, 8, 8)
# then we unfold again to get our indivdual 4x4 tiles
bitmask_4x4_chunks = bitmask_8x8_chunks.unfold(2, 4, 4).unfold(3, 4, 4)
# Each 4x4 bitmask defines two 8-bit integers, which encode the sparsity pattern
# of that tile. Note that the least siginificant bit is stored first.
# [1 1 0 0]
# [1 1 0 0] -> 0011 0011 -> 51
# [0 0 1 1] 1100 1100 204
# [0 0 1 1]
# reshape tensor to expand tiles into 8-bit vectors
bitmask_binary_representation = bitmask_4x4_chunks.reshape(
*bitmask_4x4_chunks.shape[:2], 4, 2, 8
)
# to convert from binary representaiton, we can do a matmul with powers of two
powers_of_two = 2 ** torch.arange(8, dtype=torch.float, device="cuda")
# To run on GPU: cast to float to do matmul and then cast back
compressed_swizzled_bitmask = (
bitmask_binary_representation.to(torch.float) @ powers_of_two
).to(torch.uint8)
return compressed_swizzled_bitmask

View File

@ -0,0 +1,168 @@
# mypy: allow-untyped-defs
import contextlib
import torch
__all__ = [
"fallback_dispatcher",
"semi_sparse_values",
"semi_sparse_indices",
"semi_sparse_t",
"semi_sparse_view",
"semi_sparse_detach",
"semi_sparse_mm",
"semi_sparse_addmm",
"semi_sparse_linear",
]
@contextlib.contextmanager
def no_dispatch():
guard = torch._C._DisableTorchDispatch()
try:
yield
finally:
del guard
def fallback_dispatcher(func, types, args, kwargs):
with no_dispatch():
return func(*args)
def semi_sparse_values(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 1
A = args[0]
assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
assert A.packed is not None
if A.meta is None:
m, k = A.shape
num_kept_elements = m * k // 2
return A.packed[:num_kept_elements:].view(m, -1)
else:
return A.packed.detach()
def semi_sparse_indices(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 1
A = args[0]
assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
assert A.packed is not None
if A.meta is None:
m, k = A.shape
num_kept_elements = m * k // 2
metadata = A.packed[num_kept_elements:].view(m, -1)
return metadata.view(torch.int32 if A.dtype == torch.int32 else torch.int16)
else:
return A.meta
def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 1
self = args[0]
assert isinstance(self, torch.sparse.SparseSemiStructuredTensor)
assert len(self.shape) == 2
# Because we cannot go from the compressed representation back to the dense representation currently,
# we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
# is the first or second argument, we expect an even / odd number of calls to transpose respectively.
return self.__class__(
torch.Size([self.shape[-1], self.shape[0]]),
packed=self.packed_t,
meta=self.meta_t,
packed_t=self.packed,
meta_t=self.meta,
compressed_swizzled_bitmask=self.compressed_swizzled_bitmask.transpose(0, 1)
if self.compressed_swizzled_bitmask is not None
else None,
fuse_transpose_cusparselt=args[0].fuse_transpose_cusparselt,
alg_id_cusparselt=args[0].alg_id_cusparselt,
)
def semi_sparse_view(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 2
self, shape = args
if tuple(shape) != self.shape:
raise NotImplementedError(
f"`view` is not implemented for SparseSemiStructuredTensor, except for the dummy case (shape={shape})"
)
return self
def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
assert len(args) == 1
self = args[0]
return self.__class__(
shape=self.shape,
packed=self.packed,
meta=self.meta,
packed_t=self.packed_t,
meta_t=self.meta_t,
compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
requires_grad=False,
)
def semi_sparse_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 2
A, B = args
if A.ndim != 2 or B.ndim != 2:
raise NotImplementedError(
"`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
)
if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
row, col = B.shape
B_padded = A._pad_dense_input(B)
res = A._mm(B_padded)
return res[:, :col]
else:
B_t = B.t()
assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
row, col = A.shape
A_padded = B._pad_dense_input(A)
res = B_t._mm(A_padded.t()).t()
return res[:row, :]
def semi_sparse_addmm(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 3
bias, A, B = args
if A.ndim != 2 or B.ndim != 2:
raise NotImplementedError(
"`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
)
if bias.ndim != 1:
raise NotImplementedError(
f"`SparseSemiStructuredTensor` matmul: only bias dim=1 supported. Shape={bias.shape}"
)
if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
raise NotImplementedError(
"`SparseSemiStructuredTensor` matmul: only operand B of `addmm` can be sparse"
)
B_t = B.t()
assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
row, col = A.shape
A_padded = B_t._pad_dense_input(A)
result = B_t._mm(A_padded.t(), bias=bias).t()
return result[:row, :]
def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) in [2, 3]
A, B = args[:2]
bias = args[2] if len(args) == 3 else None
shape = A.shape
A_2d = A.view(-1, shape[-1])
if bias is None:
res = A_2d @ B.t()
else:
res = semi_sparse_addmm(
func=None,
types=None,
args=[bias, A_2d, B.t()],
)
return res.view(*shape[:-1], -1)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,648 @@
# mypy: allow-untyped-defs
import warnings
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch.sparse._semi_structured_conversions import (
sparse_semi_structured_from_dense_cutlass,
sparse_semi_structured_to_dense_cutlass,
)
from torch.sparse._semi_structured_ops import (
fallback_dispatcher,
semi_sparse_addmm,
semi_sparse_detach,
semi_sparse_indices,
semi_sparse_linear,
semi_sparse_mm,
semi_sparse_t,
semi_sparse_values,
semi_sparse_view,
)
__all__ = [
"SparseSemiStructuredTensor",
"SparseSemiStructuredTensorCUTLASS",
"SparseSemiStructuredTensorCUSPARSELT",
"to_sparse_semi_structured",
]
_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
"_SEMI_STRUCTURED_SPARSE_CONFIG",
"sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols",
)
class SparseSemiStructuredTensor(torch.Tensor):
"""
This class implementes semi-structured sparsity as a Tensor subclass.
Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
structured sparsity.
There are two backends available for semi_structred sparsity, either cuSPARSELt or CUTLASS.
This class is meant to serve as a base class for both implementations. SparseSemiStructuredCUTLASS
and SparseSemiStructuredCUSPARSELT both inherit from this class and define three backend-specific items.
Note that as such, this class cannot be insantiated directly.
-`_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints
- `def from_dense()` - backend specific compression routines
- `def _mm()` - backend specifc mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
"""
_DEFAULT_ALG_ID: int = 0
_DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
_FORCE_CUTLASS: bool = True
_FUSE_TRANSPOSE: bool = False
_PROTOTYPE_WARNING_SHOWN: bool = False
BACKEND: str
SPARSE_DISPATCH: Dict[Callable, Callable]
packed: Optional[torch.Tensor]
meta: Optional[torch.Tensor]
packed_t: Optional[torch.Tensor]
meta_t: Optional[torch.Tensor]
compressed_swizzled_bitmask: Optional[torch.Tensor]
fuse_transpose_cusparselt: bool
alg_id_cusparselt: int
__slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"]
@staticmethod
def __new__( # noqa: PYI034
cls,
shape: torch.Size,
packed: Optional[torch.Tensor],
meta: Optional[torch.Tensor],
packed_t: Optional[torch.Tensor],
meta_t: Optional[torch.Tensor],
compressed_swizzled_bitmask: Optional[torch.Tensor],
fuse_transpose_cusparselt: bool = False,
alg_id_cusparselt: int = 0,
requires_grad: bool = False,
):
"""
Create a new instance of the tensor subclass from the compressed sparse representation.
We have the option to create the subclass with the compressed representations of both X and X', for training.
For inference, we only need a single representation (either X or X'), while the corresponding other set will be None.
Depending on the backend selected, certain fields will be set to None. (CUSPARSELT vs CUTLASS)
Args:
shape: The shape of the original dense tensor
packed: The compressed representation of the original dense tensor
meta: The metadata of the original dense tensor, if it is stored separately
packed_t: The compressed representation of the transposed original dense tensor
meta_t: The metadata of the transposed original dense tensor, if it is stored separately
compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should
participate in the computation. Used for pointwise ops.
fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
with a matmul, which is useful in the case of 2:4 sparse training.
alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have effect on performance
Returns:
torch.Tensor: A torch.Tensor wrapper subclass.
Raises:
ValueError: If all of the tensor arguments are None.
"""
if not cls._PROTOTYPE_WARNING_SHOWN:
warnings.warn(
(
"The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
"and will change in the near future. Please open a Github issue "
"for features requests and see our documentation on the torch.sparse "
"module for further information about the project."
),
UserWarning,
)
cls._PROTOTYPE_WARNING_SHOWN = True
# Because this only runs onces, we also load the dispatch table here as well.
# We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead
# But this is useful since it allows users to overload the dispatch table for debugging / testing.
cls._load_dispatch_table()
# we can also register the classes with dynamo when the warning is shown.
torch._dynamo.allow_in_graph(cls)
if packed is not None:
previous_tensor = packed
elif packed_t is not None:
previous_tensor = packed_t
else:
raise ValueError("At least one of packed or packed_t must be provided")
kwargs = {
"device": previous_tensor.device,
"dtype": previous_tensor.dtype,
"layout": previous_tensor.layout,
"requires_grad": requires_grad,
}
tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
tensor.packed = packed
tensor.meta = meta
tensor.packed_t = packed_t
tensor.meta_t = meta_t
tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask
tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
tensor.alg_id_cusparselt = alg_id_cusparselt
return tensor
def __repr__(self) -> str: # type: ignore[override]
assert hasattr(self, "shape")
return f"{self.__class__.__name__}(shape={self.shape})"
def __tensor_flatten__(
self,
) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]:
inner_tensors = list(
filter(lambda x: getattr(self, x) is not None, self.__slots__)
)
tensor_meta = (
self.shape,
self.fuse_transpose_cusparselt,
self.alg_id_cusparselt,
self.requires_grad,
)
return inner_tensors, tensor_meta
@classmethod
def __tensor_unflatten__(
cls,
inner_tensors,
tensor_meta: Tuple[torch.Size, bool, int, bool],
outer_size,
outer_stride,
) -> torch.Tensor:
shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
return cls(
shape=shape,
packed=inner_tensors.get("packed", None),
meta=inner_tensors.get("meta", None),
packed_t=inner_tensors.get("packed_t", None),
meta_t=inner_tensors.get("meta_t", None),
compressed_swizzled_bitmask=inner_tensors.get(
"compressed_swizzled_bitmask", None
),
fuse_transpose_cusparselt=fuse_transpose_cusparselt,
alg_id_cusparselt=alg_id_cusparselt,
requires_grad=requires_grad,
)
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
if func._overloadpacket not in cls.SPARSE_DISPATCH:
raise NotImplementedError(
f"{cls.__name__} only supports a specific set of operations, "
f"can't perform requested op ({func.__name__})"
)
return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs)
@classmethod
def _load_dispatch_table(cls, custom_dispatch_table=None) -> None:
"""
Loads the op overload sparse dispatch table for the current class.
"""
if getattr(cls, "SPARSE_DISPATCH", None) is None:
cls.SPARSE_DISPATCH = {
torch.ops.aten.values: semi_sparse_values,
torch.ops.aten.indices: semi_sparse_indices,
torch.ops.aten.is_same_size: fallback_dispatcher,
torch.ops.aten.detach_: fallback_dispatcher,
torch.ops.aten.detach: semi_sparse_detach,
torch.ops.aten.t: semi_sparse_t,
torch.ops.aten.view: semi_sparse_view,
torch.ops.aten.mm: semi_sparse_mm,
torch.ops.aten.matmul: semi_sparse_mm,
torch.ops.aten.addmm: semi_sparse_addmm,
torch.ops.aten.linear: semi_sparse_linear,
torch.ops.aten._to_copy: fallback_dispatcher,
}
if custom_dispatch_table is not None:
cls.SPARSE_DISPATCH.update(custom_dispatch_table)
@classmethod
def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None:
"""
Assert that the given tensor is valid for semi-structured sparse compression.
"""
# check device
if not original_tensor.is_cuda:
raise RuntimeError(
f"Error original_tensor.device= {original_tensor.device} is not supported! "
"Only CUDA tensors are currently supported."
)
# check dim
if original_tensor.dim() != 2:
raise RuntimeError(
f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
"Only 2d tensors are currently supported."
)
# check contiguous
if not original_tensor.is_contiguous():
raise RuntimeError(
"Error original_tensor is not contiguous!"
"Only contiguous tensors are currently supported."
)
# check dtype
if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
raise RuntimeError(
f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
"dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}"
)
# check shape
m, n = original_tensor.shape
min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows
min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols
if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
# TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
raise RuntimeError(
f"Error original_tensor.shape {original_tensor.shape} is not supported! "
f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
)
@classmethod
def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor:
"""
Calculates padding for dense tensor and pads tensor if necessary.
If padding is not required, this function returns the original tensor.
"""
# only 2d matmul
assert dense_input.dim() == 2
# check shape
m, n = dense_input.shape
min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows
min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols
# calculate padding
to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0
if to_pad_m or to_pad_n:
return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m))
else:
return dense_input
def to_dense(self):
col = self.shape[-1]
return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))
@classmethod
def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor":
raise NotImplementedError
def _mm(
self,
B: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
raise NotImplementedError
def to_sparse_semi_structured(
original_tensor: torch.Tensor,
transposed: bool = False,
) -> SparseSemiStructuredTensor:
"""
This function converts a dense tensor into a sparse semi-structured tensor.
It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.
This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
We currently only support semi-structured sparse tensors for 2d CUDA tensors.
Additionally, your tensor must be a positive multiple of the mininum sparse block size, given in
`_DTYPE_TO_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8).
Args:
original_tensor (Tensor): the dense tensor to convert
transposed (bool, optional): deprecated arg to be removed in another release. Do not use.
Returns:
SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor
Raises:
None
Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
tensor([[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.],
...,
[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
>>> A_sparse = to_sparse_semi_structured(A)
SparseSemiStructuredTensor(shape=torch.Size([128, 128]))
>>> A_sparse.values()
tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
>>> A_sparse.indices()
tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
...,
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16))
"""
if transposed:
warnings.warn(
"Setting transpose from `to_sparse_semi_structured` is deprecated "
"and will be removed in a future release. "
"`SparseSemiStructuredTensor` only support contiguous input tensors.",
FutureWarning,
stacklevel=2,
)
# set from _FORCE_CUTLASS flag
SPARSE_SUBCLASS = (
torch.sparse.SparseSemiStructuredTensorCUTLASS
if SparseSemiStructuredTensor._FORCE_CUTLASS
else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
)
return SPARSE_SUBCLASS.from_dense(original_tensor)
class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
"""
This class implements semi-structured sparsity for the CUTLASS backend.
In this implementation, the specified elements and metadata are stored seprately,
in packed and meta respectively.
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
sparse_semi_structured_from_dense for conversion to the compressed format.
"""
BACKEND = "cutlass"
_DTYPE_SHAPE_CONSTRAINTS = {
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4),
}
@classmethod
def from_dense(
cls, original_tensor: torch.Tensor
) -> "SparseSemiStructuredTensorCUTLASS":
cls._validate_device_dim_dtype_shape(original_tensor)
(
sparse_tensor_cutlass,
meta_tensor_cutlass,
) = sparse_semi_structured_from_dense_cutlass(original_tensor)
return cls(
original_tensor.shape,
packed=sparse_tensor_cutlass,
meta=meta_tensor_cutlass,
packed_t=None,
meta_t=None,
compressed_swizzled_bitmask=None,
requires_grad=original_tensor.requires_grad,
)
def to_dense(self):
assert self.meta is not None and self.packed is not None
return (
sparse_semi_structured_to_dense_cutlass(
self.packed,
self.meta,
)
if self.meta.ndim == 2
else super().to_dense()
)
@classmethod
def prune_dense_static_sort(
cls, original_tensor: torch.Tensor, algorithm=""
) -> "SparseSemiStructuredTensor":
"""
This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.
It greedily picks the largest values in the tile, upholding the 2:4 sparsity constraint across both rows and columns.
The algorithm used to prune the matrix is implemented in `_sparse_semi_structured_tile`.
Then it creates the packed and meta tensors for the compressed sparse representation of the pruned dense tensor.
It also calculates the packed_t and meta_t tensors for the compressed sparse representation of the transposed
pruned dense tensor.
Since we cannot transpose the compressed representations, we store both for the fw/bw pass respectively.
Finally, this function also computes a compressed swizzled bitmask that encodes the sparsity pattern
This can be used in the backward pass to mask the gradients.
[9 1 7 4] [9 0 7 0]
[1 2 3 0] [0 2 0 0]
[8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to CUTLASS semi-structured -> packed
[1 2 6 2] [0 0 6 2] -> metadata
-> pack to transposed CUTLASS -> packed_t
semi-structured representation -> metadata_t
-> compute swizzled bitmask -> compressed_swizzled_bitmask
The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
```
from torch.sparse import SparseSemiStructuredTensorCUTLASS
from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
pruned = _sparse_semi_structured_tile(dense)
packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
bitmask = _compute_compressed_swizzled_bitmask(pruned)
SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask)
```
"""
# We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
(
packed,
meta,
packed_t,
meta_t,
compressed_swizzled_bitmask,
) = torch._sparse_semi_structured_tile(
original_tensor, algorithm=algorithm, use_cutlass=True
)
return cls(
original_tensor.shape,
packed=packed,
meta=meta,
packed_t=packed_t,
meta_t=meta_t,
compressed_swizzled_bitmask=compressed_swizzled_bitmask,
requires_grad=False,
)
def _mm(
self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
if isinstance(B, SparseSemiStructuredTensor):
raise ValueError(
"`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
)
cls_name = self.__class__.__name__
if self.ndim != 2 or B.ndim != 2:
raise NotImplementedError(
f"`{cls_name}` matmul: Broadcasting is not implemented"
)
if self.packed is None or self.meta is None:
raise NotImplementedError(
f"`{cls_name}` matmul: operation is not supported"
)
else:
if bias is None:
res = torch._sparse_semi_structured_mm(self.packed, self.meta, B)
else:
res = torch._sparse_semi_structured_addmm(
bias, self.packed, self.meta, B
)
return res[: self.shape[0]]
class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
"""
The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:
packed = [ specified elements of original tensor | metadata ]
For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements
The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t
attributes respectively.
cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
"""
BACKEND = "cusparselt"
_DTYPE_SHAPE_CONSTRAINTS = {
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
}
@classmethod
def from_dense(
cls, original_tensor: torch.Tensor
) -> "SparseSemiStructuredTensorCUSPARSELT":
cls._validate_device_dim_dtype_shape(original_tensor)
return cls(
shape=original_tensor.shape,
packed=torch._cslt_compress(original_tensor),
meta=None,
packed_t=None,
meta_t=None,
compressed_swizzled_bitmask=None,
fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
requires_grad=original_tensor.requires_grad,
)
@classmethod
def prune_dense_static_sort(
cls, original_tensor: torch.Tensor, algorithm=""
) -> "SparseSemiStructuredTensor":
"""
This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPASRELt metadata
layout and sparse matmul.
The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor.
[9 1 7 4] [9 0 7 0]
[1 2 3 0] [0 2 0 0]
[8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to cuSPARSELT semi-structured -> packed
[1 2 6 2] [0 0 6 2]
-> pack to transposed cuSPARSELt -> packed_t
semi-structured representation
-> compute swizzled bitmask -> compressed_swizzled_bitmask
The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
```
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask
pruned = _sparse_semi_structured_tile(dense)
packed_cusparselt = torch._cslt_compress(pruned)
packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
bitmask = _compute_compressed_swizzled_bitmask(pruned)
SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask)
```
"""
(
packed,
meta,
packed_t,
meta_t,
compressed_swizzled_bitmask,
) = torch._sparse_semi_structured_tile(
original_tensor, algorithm=algorithm, use_cutlass=False
)
return cls(
original_tensor.shape,
packed=packed,
meta=meta,
packed_t=packed_t,
meta_t=meta_t,
compressed_swizzled_bitmask=compressed_swizzled_bitmask,
requires_grad=False,
)
def _mm(
self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
if isinstance(B, SparseSemiStructuredTensor):
raise ValueError(
"`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
)
if self.ndim != 2 or B.ndim != 2:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented"
)
if B.dtype != self.dtype:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
f"with A.dtype={self.dtype} and B.dtype={B.dtype}. "
"This operation is only supported when A and B have the same data type."
)
if bias is not None and bias.dtype != self.dtype:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
"with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
"This operation is only supported when A, B and C have the same data type."
)
if self.packed is None:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: operation is not supported"
)
else:
res = torch._cslt_sparse_mm(
self.packed,
B,
bias=bias,
transpose_result=self.fuse_transpose_cusparselt,
alg_id=self.alg_id_cusparselt,
)
return res.t() if self.fuse_transpose_cusparselt else res