I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

base_sparsifier.py
@@ -0,0 +1,351 @@
# mypy: allow-untyped-defs
import abc
import copy
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Type
import torch
from torch import nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrize import type_before_parametrizations
from .utils import (
FakeSparsity,
get_arg_info_from_tensor_fqn,
module_contains_param,
module_to_fqn,
swap_module,
)
__all__ = ["BaseSparsifier"]
SUPPORTED_MODULES = {nn.Linear}
KEYS_NOT_IN_STATE_DICT = ["module", "module_fqn", "tensor_name"]
# TODO update desc with new config args
class BaseSparsifier(abc.ABC):
r"""Base class for all sparsifiers.
Abstract methods that need to be implemented:
- update_mask: Function to compute a new mask for all keys in the
`groups`.
Args:
- model [nn.Module]: model to configure. The model itself is not saved
but used for the state_dict saving / loading.
- config [list]: configuration elements should be dicts that include the
`tensor_fqn` of the tensor to sparsify
- defaults [dict]: default configurations will be attached to the
configuration. Only the keys that don't exist in the `config` will
be updated.
Example::
>>> # xdoctest: +SKIP("Can't instantiate abstract class BaseSparsifier with abstract method update_mask")
>>> config = [{'tensor_fqn': 'layer1.weight'}, {'tensor_fqn': 'linear2.weight2', 'sparsity_level': 0.5}]
>>> defaults = {'sparsity_level': 0.7}
>>> # model.layer1.weight will have `sparsity_level` = 0.7 (getting default)
>>> sparsifier = BaseSparsifier(config, defaults)
"""
def __init__(self, defaults: Optional[Dict[str, Any]] = None):
super().__init__()
self.defaults: Dict[str, Any] = defaults or {}
self.state: Dict[str, Dict] = defaultdict(dict)
self.groups: List[Dict[str, Any]] = []
self.enable_mask_update = True
def __getstate__(self) -> Dict[str, Any]:
return {
"defaults": self.defaults,
"state": self.state,
"groups": self.groups,
}
def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
self.__dict__.update(state)
def __repr__(self):
format_string = self.__class__.__name__ + " ("
for i, sparse_args in enumerate(self.groups):
module = sparse_args["module"]
format_string += "\n"
format_string += f"\tGroup {i}\n"
format_string += f"\t module: {module}\n"
for key in sorted(sparse_args.keys()):
if key == "module":
continue
format_string += f"\t {key}: {sparse_args[key]}\n"
format_string += ")"
return format_string
def state_dict(self) -> Dict[str, Any]:
r"""Returns the state of the optimizer as a :class:`dict`.
It contains:
* state - current state of the sparsification.
* groups - a list containing all sparsity configuration groups
with the key 'tensor_fqn' specifying the path to the sparsified tensor within a model
TODO: Need a clean way of loading the state of the "prepared" module
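Example (a hedged sketch; assumes ``sparsifier`` was already prepared on a model)::
>>> # xdoctest: +SKIP("requires a prepared sparsifier")
>>> sd = sparsifier.state_dict()
>>> sorted(sd.keys())
['groups', 'state']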
"""
groups: List[Dict[str, Any]] = [
dict(
filter(
lambda key_value: key_value[0] not in KEYS_NOT_IN_STATE_DICT,
mg.items(),
)
)
for mg in self.groups
]
return {
"state": self.state,
"groups": groups,
}
def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True):
groups = copy.deepcopy(state_dict["groups"])
states = state_dict["state"]
for tensor_fqn, s in states.items():
arg_info = get_arg_info_from_tensor_fqn(self.model, tensor_fqn)
module = arg_info["module"]
tensor_name = arg_info["tensor_name"]
if strict and module is None:
raise RuntimeError(f"Error loading {tensor_fqn} into the model")
found = False
for p in module.parametrizations[tensor_name]:
if isinstance(p, FakeSparsity):
found = True
break
if not found:
p = FakeSparsity(torch.ones(getattr(module, tensor_name).shape))
parametrize.register_parametrization(module, tensor_name, p)
if s.get("mask", None) is not None:
mask = s.pop("mask")
p.mask = mask
for mg in groups:
if mg["tensor_fqn"] == tensor_fqn:
mg.update(arg_info)
self.__setstate__({"state": states, "groups": groups})
def make_config_from_model(
self,
model: nn.Module,
SUPPORTED_MODULES: Set[Type] = SUPPORTED_MODULES,
) -> None:
self.config = []
stack = [model]
while stack:
module = stack.pop()
for name, child in module.named_children():
if type(child) in SUPPORTED_MODULES:
module_fqn = module_to_fqn(model, child)
assert isinstance(module_fqn, str) # for mypy
self.config.append({"tensor_fqn": module_fqn + ".weight"})
else:
stack.append(child)
def prepare(self, model, config):
r"""Prepares a model, by adding the parametrizations.
Note::
The model is modified inplace. If you need to preserve the original
model, use copy.deepcopy.
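Example (a minimal sketch; the model and tensor_fqn below are hypothetical)::
>>> # xdoctest: +SKIP("BaseSparsifier is abstract; use a concrete subclass")
>>> model = nn.Sequential(nn.Linear(8, 8))
>>> sparsifier.prepare(model, config=[{"tensor_fqn": "0.weight"}])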
"""
self.model = model # TODO: Need to figure out how to load without this.
self.config = config
# If no config -- try getting all the supported layers
if self.config is None:
self.make_config_from_model(model)
# TODO: Remove the configuration by reference ('module')
for module_config in self.config:
assert isinstance(module_config, dict), (
"config elements should be dicts not modules i.e.:"
"[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]"
)
assert isinstance(self.defaults, Dict) # for mypy
local_args = copy.deepcopy(self.defaults)
local_args.update(module_config)
tensor_fqn = local_args.get("tensor_fqn", None)
assert tensor_fqn is not None, (
"tensor_fqn is a required argument in the sparsity config which"
"replaces previous `module` and [module]`fqn` arguments"
)
# populate all information from tensor_fqn
info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
# check that whatever was put into local_args agrees with what was obtained
# from tensor_fqn
for key in info_from_tensor_fqn.keys():
if key in local_args:
assert (
info_from_tensor_fqn[key] == local_args[key]
or (
key == "tensor_fqn"
and "." + info_from_tensor_fqn[key] == local_args[key]
)
# info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that
), f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!"
local_args.update(info_from_tensor_fqn)
self.groups.append(local_args)
self._prepare()
def _prepare(self, *args, **kwargs):
r"""Adds mask parametrization to the layer weight"""
for config in self.groups:
module = config["module"]
tensor_name = config["tensor_name"]
parametrization = config.get("parametrization", FakeSparsity)
mask = config.get("mask", torch.ones_like(getattr(module, tensor_name)))
self.state[config["tensor_fqn"]]["mask"] = mask
parametrize.register_parametrization(
module, tensor_name, parametrization(mask)
)
def squash_mask(
self,
params_to_keep: Optional[Tuple[str, ...]] = None,
params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None,
*args,
**kwargs,
):
r"""Squashes the sparse masks into the appropriate tensors.
If either the `params_to_keep` or `params_to_keep_per_layer` is set,
the module will have a `sparse_params` dict attached to it.
Args:
params_to_keep: List of keys to save in the module or a dict
representing the modules and keys that will have
sparsity parameters saved
params_to_keep_per_layer: Dict to specify the params that should be
saved for specific layers. The keys in the dict
should be the module fqn, while the values should
be a list of strings with the names of the variables
to save in the `sparse_params`
Examples:
>>> # xdoctest: +SKIP("locals are undefined")
>>> # Don't save any sparse params
>>> sparsifier.squash_mask()
>>> hasattr(model.submodule1, 'sparse_params')
False
>>> # Keep sparse params per layer
>>> sparsifier.squash_mask(
... params_to_keep_per_layer={
... 'submodule1.linear1': ('foo', 'bar'),
... 'submodule2.linear42': ('baz',)
... })
>>> print(model.submodule1.linear1.sparse_params)
{'foo': 42, 'bar': 24}
>>> print(model.submodule2.linear42.sparse_params)
{'baz': 0.1}
>>> # Keep sparse params for all layers
>>> sparsifier.squash_mask(params_to_keep=('foo', 'bar'))
>>> print(model.submodule1.linear1.sparse_params)
{'foo': 42, 'bar': 24}
>>> print(model.submodule2.linear42.sparse_params)
{'foo': 42, 'bar': 24}
>>> # Keep some sparse params for all layers, and specific ones for
>>> # some other layers
>>> sparsifier.squash_mask(
... params_to_keep=('foo', 'bar'),
... params_to_keep_per_layer={
... 'submodule2.linear42': ('baz',)
... })
>>> print(model.submodule1.linear1.sparse_params)
{'foo': 42, 'bar': 24}
>>> print(model.submodule2.linear42.sparse_params)
{'foo': 42, 'bar': 24, 'baz': 0.1}
"""
for config in self.groups:
module = config["module"]
tensor_name = config["tensor_name"]
parametrize.remove_parametrizations(
module, tensor_name, leave_parametrized=True
)
sparse_params = {}
if params_to_keep is not None:
global_params = {k: config[k] for k in params_to_keep}
sparse_params.update(global_params)
if params_to_keep_per_layer is not None:
params = params_to_keep_per_layer.get(config["module_fqn"], None)
if params is not None:
per_layer_params = {k: config[k] for k in params}
sparse_params.update(per_layer_params)
if sparse_params:
# TODO handle multiple tensor being quantized on a single module, where to store sparse_params?
module.sparse_params = sparse_params
def convert(
self,
module: nn.Module,
mapping: Optional[Dict[Type[nn.Module], Type[nn.Module]]] = None,
inplace: bool = False,
parameterization: Type[nn.Module] = FakeSparsity,
):
r"""Converts submodules in input module to a different module according to `mapping`
by calling `from_dense` method on the target module class
Args:
module: input module
mapping: a dictionary that maps from a source module type to a target
module type; can be overridden to allow swapping of user-defined
Modules
inplace: carry out model transformations in place; the original module
is mutated
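Example (a hedged sketch; ``SparseLinear`` is a hypothetical target class
implementing a ``from_dense`` classmethod)::
>>> # xdoctest: +SKIP("requires a target class implementing from_dense")
>>> sparse_model = sparsifier.convert(model, mapping={nn.Linear: SparseLinear})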
"""
if mapping is None:
raise NotImplementedError("Need to auto-generate mapping")
if not inplace:
module = copy.deepcopy(module)
reassign = {}
for name, mod in module.named_children():
# leaf node
if (
module_contains_param(mod, parameterization)
and type_before_parametrizations(mod) in mapping
):
reassign[name] = swap_module(mod, mapping)
else:
# recurse
reassign[name] = self.convert(
mod,
mapping=mapping,
inplace=True,
parameterization=parameterization,
)
for key, value in reassign.items():
module._modules[key] = value
return module
def step(self, use_path: bool = True) -> None:
if not self.enable_mask_update:
return
with torch.no_grad():
for config in self.groups:
self.update_mask(**config)
@abc.abstractmethod
def update_mask(self, module: nn.Module, tensor_name: str, **kwargs):
pass

nearly_diag_sparsifier.py
@@ -0,0 +1,58 @@
# mypy: allow-untyped-defs
import torch
from . import base_sparsifier
class NearlyDiagonalSparsifier(base_sparsifier.BaseSparsifier):
r"""Nearly Diagonal Sparsifier
This sparsifier creates a nearly diagonal mask to be applied to the weight matrix.
A nearly diagonal matrix is a matrix that contains non-zero elements near the diagonal, while the rest are zero.
Examples of nearly diagonal matrices with degree (or nearliness) 3 and 5, respectively:
1 1 0 0   1 1 1 0
1 1 1 0   1 1 1 1
0 1 1 1   1 1 1 1
0 0 1 1   0 1 1 1
Note that a nearly diagonal matrix with degree 1 is just a matrix with only the main diagonal populated.
This sparsifier is controlled by one variable:
1. `nearliness` defines the number of non-zero diagonal lines that are closest to the main diagonal.
Currently, only odd numbers are supported.
Note:
This can be accelerated (vectorized) once the Spdiagonal feature (PR: #78439) or the banded matrix
feature is landed: https://stackoverflow.com/questions/52463972/generating-banded-matrices-using-numpy
Args:
nearliness: The degree of nearliness (default = 1)
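Example (a minimal sketch; assumes ``from torch import nn``; the model is hypothetical)::
>>> # xdoctest: +SKIP("illustrative only")
>>> model = nn.Sequential(nn.Linear(16, 16))
>>> sparsifier = NearlyDiagonalSparsifier(nearliness=3)
>>> sparsifier.prepare(model, config=None)
>>> sparsifier.step()  # fills each mask with a banded (nearly diagonal) pattern
>>> sparsifier.squash_mask()  # hard-applies the masks to the weights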
"""
def __init__(self, nearliness: int = 1):
defaults = {"nearliness": nearliness}
super().__init__(defaults=defaults)
def update_mask(self, module, tensor_name, nearliness, **kwargs):
mask = getattr(module.parametrizations, tensor_name)[0].mask
mask.data = torch.zeros_like(mask)
if nearliness <= 0:
return
tensor = getattr(module, tensor_name)
height, width = tensor.shape
if nearliness % 2 == 0:
raise ValueError("nearliness can only be an odd number")
dist_to_diagonal = nearliness // 2
# the band must fit within the smaller dimension of the tensor
if dist_to_diagonal >= min(height, width):
raise ValueError(
"nearliness cannot be larger than the dimensions of the tensor."
)
for row in range(0, height):
# Bounds of the entries that need to be set to 1
low = max(0, row - dist_to_diagonal)
high = min(width, row + dist_to_diagonal + 1)
mask[row, low:high].fill_(1)

utils.py
@@ -0,0 +1,138 @@
# mypy: allow-untyped-defs
from itertools import chain
from typing import Any, Dict, Optional, Type
from torch import nn
from torch.nn.utils.parametrize import is_parametrized, type_before_parametrizations
__all__ = [
"module_contains_param",
"swap_module",
"module_to_fqn",
"fqn_to_module",
"get_arg_info_from_tensor_fqn",
"FakeSparsity",
]
def module_contains_param(module: nn.Module, parametrization: Type[nn.Module]) -> bool:
if is_parametrized(module):
# see if any of the module tensors have a parametrization attached that matches the one passed in
return any(
any(isinstance(param, parametrization) for param in param_list)
for key, param_list in module.parametrizations.items() # type: ignore[union-attr,operator]
)
return False
def swap_module(
mod: nn.Module, mapping: Dict[Type[nn.Module], Type[nn.Module]]
) -> nn.Module:
r"""Swaps the module using from_dense according to the mapping passed in.
Args:
mod: input module
mapping: a dictionary that maps from nn module to sparse nn module
Return:
The corresponding sparse module of `mod` according to mapping, created using from_dense
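Example (a hedged sketch; ``SparseLinear`` is a hypothetical class with a ``from_dense`` classmethod)::
>>> # xdoctest: +SKIP("requires a class implementing from_dense")
>>> new_mod = swap_module(mod, {nn.Linear: SparseLinear})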
"""
if type_before_parametrizations(mod) in mapping:
sparse_mod = mapping[type_before_parametrizations(mod)]
# TODO Fix this typing, as Type[Module] has no attribute "from_dense"
new_mod = sparse_mod.from_dense(mod) # type: ignore[attr-defined]
# Preserve module's pre forward hooks; they'll be called on the swapped module's input
for pre_hook_fn in mod._forward_pre_hooks.values():
new_mod.register_forward_pre_hook(pre_hook_fn)
# Preserve module's post forward hooks; after convert they'll be called
# on the swapped module's output
for hook_fn in mod._forward_hooks.values():
new_mod.register_forward_hook(hook_fn)
# respect device affinity when swapping modules
devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
assert (
len(devices) <= 1
), f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
device = next(iter(devices)) if len(devices) > 0 else None
if device:
new_mod.to(device)
return new_mod
else:
return mod
def module_to_fqn(
model: nn.Module, module: nn.Module, prefix: str = ""
) -> Optional[str]:
"""
Returns the fqn for a module, or None if the module is not a descendant of the model.
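Example (a minimal, runnable sketch)::
>>> model = nn.Sequential(nn.Linear(4, 4))
>>> module_to_fqn(model, model[0])
'0'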
"""
if module is model:
return ""
for name, child in model.named_children():
fqn = module_to_fqn(child, module, ".")
if isinstance(fqn, str):
return prefix + name + fqn
return None
def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]:
"""
Given an fqn, returns the corresponding module or tensor or None if the fqn given by `path`
doesn't correspond to anything. Similar to model.get_submodule(path) but works for tensors.
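Example (a minimal, runnable sketch)::
>>> model = nn.Sequential(nn.Linear(4, 4))
>>> fqn_to_module(model, '0') is model[0]
True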
"""
if path != "":
for name in path.split("."):
model = getattr(model, name, None)
return model
def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> Dict[str, Any]:
"""
Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name
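Example (a minimal, runnable sketch)::
>>> model = nn.Sequential(nn.Linear(4, 4))
>>> info = get_arg_info_from_tensor_fqn(model, '0.weight')
>>> info['module_fqn'], info['tensor_name']
('0', 'weight')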
"""
# string manip to split tensor_fqn into module_fqn and tensor_name
# if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight'
# if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight'
tensor_name = tensor_fqn.split(".")[-1]
module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)]
module = fqn_to_module(model, module_fqn)
return {
"module_fqn": module_fqn,
"module": module,
"tensor_name": tensor_name,
"tensor_fqn": tensor_fqn,
}
# Parametrizations
class FakeSparsity(nn.Module):
r"""Parametrization for the weights. Should be attached to the 'weight' or
any other parameter that requires a mask applied to it.
Note::
Once the mask is passed, the same mask object must be kept (its id should
not change). The contents of the mask can change, but the mask reference
itself should not.
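Example (a hedged sketch; assumes ``import torch`` and
``from torch.nn.utils import parametrize`` in scope)::
>>> # xdoctest: +SKIP("illustrative only")
>>> linear = nn.Linear(4, 4)
>>> mask = torch.ones_like(linear.weight)
>>> parametrize.register_parametrization(linear, "weight", FakeSparsity(mask))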
"""
def __init__(self, mask):
super().__init__()
self.register_buffer("mask", mask)
def forward(self, x):
assert self.mask.shape == x.shape
return self.mask * x
def state_dict(self, *args, **kwargs):
# We don't want the parametrizations to save the mask.
# That way we make sure that the linear module doesn't store the masks
# alongside its parametrizations.
return {}

weight_norm_sparsifier.py
@@ -0,0 +1,248 @@
# mypy: allow-untyped-defs
import operator
from functools import reduce
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from .base_sparsifier import BaseSparsifier
__all__ = ["WeightNormSparsifier"]
def _flat_idx_to_2d(idx, shape):
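# Maps flat indices to (row, col) pairs for a 2-d tensor of the given shape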
rows = idx // shape[1]
cols = idx % shape[1]
return rows, cols
class WeightNormSparsifier(BaseSparsifier):
r"""Weight-Norm Sparsifier
This sparsifier computes the norm of every sparse block and "zeroes-out" the
ones with the lowest norm. The level of sparsity defines how many of the
blocks are removed.
This sparsifier is controlled by three variables:
1. `sparsity_level` defines the fraction of *sparse blocks* that are zeroed out
2. `sparse_block_shape` defines the shape of the sparse blocks. Note that
the sparse blocks originate at the zero-index of the tensor.
3. `zeros_per_block` is the number of zeros that we expect in each
sparse block. By default we assume that all elements within a block are
zeroed out. Setting this variable changes the target number of
zeros per block. The zeros within each block are chosen as the *smallest
absolute values*.
Args:
sparsity_level: The target level of sparsity
sparse_block_shape: The shape of a sparse block (see note below)
zeros_per_block: Number of zeros in a sparse block
norm: Norm to use. Could be either `int` or a callable.
If `int`, only L1 and L2 are implemented.
Note::
The `sparse_block_shape` is a tuple representing (block_ROWS, block_COLS),
irrespective of what the rows / cols mean in the data tensor. That means,
if you were to sparsify a weight tensor in the nn.Linear, which has a
weight shape `(Cout, Cin)`, the `block_ROWS` would refer to the output
channels, while the `block_COLS` would refer to the input channels.
Note::
All arguments to the WeightNormSparsifier constructor are "default"
arguments and can be overridden by the configuration provided in the
`prepare` step.
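Example (a minimal sketch; assumes ``from torch import nn``; the model is hypothetical)::
>>> # xdoctest: +SKIP("illustrative only")
>>> model = nn.Sequential(nn.Linear(8, 8))
>>> sparsifier = WeightNormSparsifier(sparsity_level=0.5, sparse_block_shape=(1, 4))
>>> sparsifier.prepare(model, config=None)
>>> sparsifier.step()  # zeroes out the 1x4 blocks with the smallest L2 norm
>>> sparsifier.squash_mask()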
"""
def __init__(
self,
sparsity_level: float = 0.5,
sparse_block_shape: Tuple[int, int] = (1, 4),
zeros_per_block: Optional[int] = None,
norm: Optional[Union[Callable, int]] = None,
):
if zeros_per_block is None:
zeros_per_block = reduce(operator.mul, sparse_block_shape)
defaults = {
"sparsity_level": sparsity_level,
"sparse_block_shape": sparse_block_shape,
"zeros_per_block": zeros_per_block,
}
if norm is None:
norm = 2
if callable(norm):
self.norm_fn = norm
elif norm == 1:
self.norm_fn = lambda T: T.abs()
elif norm == 2:
self.norm_fn = lambda T: T * T
else:
raise NotImplementedError(f"L-{norm} is not yet implemented.")
super().__init__(defaults=defaults)
def _scatter_fold_block_mask(
self,
output_shape,
dim,
indices,
block_shape,
mask=None,
input_shape=None,
device=None,
):
r"""Creates patches of size `block_shape` after scattering the indices."""
if mask is None:
assert input_shape is not None
mask = torch.ones(input_shape, device=device)
mask.scatter_(dim=dim, index=indices, value=0)
mask.data = F.fold(
mask, output_size=output_shape, kernel_size=block_shape, stride=block_shape
)
return mask
def _make_tensor_mask(
self, data, input_shape, sparsity_level, sparse_block_shape, mask=None
):
r"""Creates a tensor-level mask.
Tensor-level mask is described as a mask, where the granularity of sparsification of the
smallest patch is the sparse_block_shape. That means, that for a given mask and a
sparse_block_shape, the smallest "patch" of zeros/ones could be the sparse_block_shape.
In this context, `sparsity_level` describes the fraction of sparse patches.
"""
h, w = data.shape[-2:]
block_h, block_w = sparse_block_shape
dh = (block_h - h % block_h) % block_h
dw = (block_w - w % block_w) % block_w
if mask is None:
mask = torch.ones(h + dh, w + dw, device=data.device)
if sparsity_level >= 1.0:
mask.data = torch.zeros_like(mask)
return mask
elif sparsity_level <= 0.0:
mask.data = torch.ones_like(mask)
return mask
values_per_block = reduce(operator.mul, sparse_block_shape)
if values_per_block > 1:
# Reduce the data
data = F.avg_pool2d(
data[None, None, :],
kernel_size=sparse_block_shape,
stride=sparse_block_shape,
ceil_mode=True,
)
data = data.flatten()
num_blocks = len(data)
data = data.repeat(1, values_per_block, 1)
threshold_idx = int(round(sparsity_level * num_blocks))
threshold_idx = max(0, min(num_blocks - 1, threshold_idx)) # Sanity check
_, sorted_idx = torch.topk(data, k=threshold_idx, dim=2, largest=False)
# Temp reshape for mask
mask_reshape = mask.reshape(data.shape) # data might be reshaped
self._scatter_fold_block_mask(
dim=2,
output_shape=(h + dh, w + dw),
indices=sorted_idx,
block_shape=sparse_block_shape,
mask=mask_reshape,
)
mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous()
return mask
def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None):
r"""Creates a block-level mask.
Block-level mask is described as a mask, where the granularity of sparsification of the
largest patch is the sparse_block_shape. That means that for a given mask and a
sparse_block_shape, the sparsity is computed only within a patch of a size sparse_block_shape.
In this context the `zeros_per_block` describes the number of zeroed-out elements within a patch.
"""
h, w = data.shape[-2:]
block_h, block_w = sparse_block_shape
dh = (block_h - h % block_h) % block_h
dw = (block_w - w % block_w) % block_w
values_per_block = reduce(operator.mul, sparse_block_shape)
if mask is None:
mask = torch.ones((h + dh, w + dw), device=data.device)
if values_per_block == zeros_per_block:
# Everything should be sparsified
mask.data = torch.zeros_like(mask)
return mask
# create a new padded tensor like data (to match the block_shape)
padded_data = torch.ones(h + dh, w + dw, dtype=data.dtype, device=data.device)
padded_data.fill_(torch.nan)
padded_data[:h, :w] = data
unfolded_data = F.unfold(
padded_data[None, None, :],
kernel_size=sparse_block_shape,
stride=sparse_block_shape,
)
# Temp reshape for mask
mask_reshape = mask.reshape(unfolded_data.shape)
_, sorted_idx = torch.topk(
unfolded_data, k=zeros_per_block, dim=1, largest=False
)
self._scatter_fold_block_mask(
dim=1,
indices=sorted_idx,
output_shape=padded_data.shape,
block_shape=sparse_block_shape,
mask=mask_reshape,
)
mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous()
return mask
def update_mask(
self,
module,
tensor_name,
sparsity_level,
sparse_block_shape,
zeros_per_block,
**kwargs,
):
values_per_block = reduce(operator.mul, sparse_block_shape)
if zeros_per_block > values_per_block:
raise ValueError(
"Number of zeros per block cannot be more than the total number of elements in that block."
)
if zeros_per_block < 0:
raise ValueError("Number of zeros per block must be non-negative.")
mask = getattr(module.parametrizations, tensor_name)[0].mask
if sparsity_level <= 0 or zeros_per_block == 0:
mask.data = torch.ones_like(mask)
elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block):
mask.data = torch.zeros_like(mask)
else:
ww = self.norm_fn(getattr(module, tensor_name))
tensor_mask = self._make_tensor_mask(
data=ww,
input_shape=ww.shape,
sparsity_level=sparsity_level,
sparse_block_shape=sparse_block_shape,
)
if values_per_block != zeros_per_block:
block_mask = self._make_block_mask(
data=ww,
sparse_block_shape=sparse_block_shape,
zeros_per_block=zeros_per_block,
)
tensor_mask = torch.logical_or(tensor_mask, block_mask)
mask.data = tensor_mask