I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,39 @@
from . import parametrizations, rnn, stateless
from .clip_grad import clip_grad_norm, clip_grad_norm_, clip_grad_value_
from .convert_parameters import parameters_to_vector, vector_to_parameters
from .fusion import (
fuse_conv_bn_eval,
fuse_conv_bn_weights,
fuse_linear_bn_eval,
fuse_linear_bn_weights,
)
from .init import skip_init
from .memory_format import (
convert_conv2d_weight_memory_format,
convert_conv3d_weight_memory_format,
)
from .spectral_norm import remove_spectral_norm, spectral_norm
from .weight_norm import remove_weight_norm, weight_norm
__all__ = [
"clip_grad_norm",
"clip_grad_norm_",
"clip_grad_value_",
"convert_conv2d_weight_memory_format",
"convert_conv3d_weight_memory_format",
"fuse_conv_bn_eval",
"fuse_conv_bn_weights",
"fuse_linear_bn_eval",
"fuse_linear_bn_weights",
"parameters_to_vector",
"parametrizations",
"remove_spectral_norm",
"remove_weight_norm",
"rnn",
"skip_init",
"spectral_norm",
"stateless",
"vector_to_parameters",
"weight_norm",
]


@@ -0,0 +1,54 @@
# mypy: allow-untyped-defs
import importlib
import warnings
from typing import Callable, List
_MESSAGE_TEMPLATE = (
r"Usage of '{old_location}' is deprecated; please use '{new_location}' instead."
)
def lazy_deprecated_import(
all: List[str],
old_module: str,
new_module: str,
) -> Callable:
r"""Import utility to lazily import deprecated packages / modules / functional.
The old_module and new_module are also used in the deprecation warning defined
by the `_MESSAGE_TEMPLATE`.
Args:
all: The list of functions to import. Generally, the ``__all__`` list
of the module.
old_module: Old module location
new_module: New module location / Migrated location
Returns:
Callable to assign to the `__getattr__`
Usage:
# In the `torch/nn/quantized/functional.py`
from torch.nn.utils._deprecation_utils import lazy_deprecated_import
_MIGRATED_TO = "torch.ao.nn.quantized.functional"
__getattr__ = lazy_deprecated_import(
all=__all__,
old_module=__name__,
new_module=_MIGRATED_TO)
"""
warning_message = _MESSAGE_TEMPLATE.format(
old_location=old_module, new_location=new_module
)
def getattr_dunder(name):
if name in all:
# We are using the "RuntimeWarning" to make sure it is not
# ignored by default.
warnings.warn(warning_message, RuntimeWarning)
package = importlib.import_module(new_module)
return getattr(package, name)
raise AttributeError(f"Module {new_module!r} has no attribute {name!r}.")
return getattr_dunder
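# A hedged usage sketch of the helper above; the module layout mirrors the
# docstring's example, and `torch.ao.nn.quantized.functional` is the only real
# module referenced here.
#
# >>> # xdoctest: +SKIP
# >>> # torch/nn/quantized/functional.py
# >>> __all__ = ["conv2d"]
# >>> __getattr__ = lazy_deprecated_import(
# ...     all=__all__,
# ...     old_module=__name__,
# ...     new_module="torch.ao.nn.quantized.functional",
# ... )
# >>> # Accessing torch.nn.quantized.functional.conv2d now emits a
# >>> # RuntimeWarning and resolves to the migrated implementation.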


@@ -0,0 +1,10 @@
from .conv_expanded_weights import ConvPerSampleGrad
from .embedding_expanded_weights import EmbeddingPerSampleGrad
from .expanded_weights_impl import ExpandedWeight
from .group_norm_expanded_weights import GroupNormPerSampleGrad
from .instance_norm_expanded_weights import InstanceNormPerSampleGrad
from .layer_norm_expanded_weights import LayerNormPerSampleGrad
from .linear_expanded_weights import LinearPerSampleGrad
__all__ = ["ExpandedWeight"]


@@ -0,0 +1,68 @@
# mypy: allow-untyped-defs
import torch
import torch.nn.functional as F
from .conv_utils import (
conv_args_and_kwargs,
conv_backward,
conv_input_for_string_padding,
conv_picker,
)
from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads
from .expanded_weights_utils import forward_helper
@implements_per_sample_grads(F.conv1d)
@implements_per_sample_grads(F.conv2d)
@implements_per_sample_grads(F.conv3d)
class ConvPerSampleGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, kwarg_names, conv_fn, *expanded_args_and_kwargs):
expanded_args, expanded_kwargs = conv_args_and_kwargs(
kwarg_names, expanded_args_and_kwargs
)
orig_input = expanded_args[0]
was_same_padding = expanded_kwargs["padding"] == "same"
if isinstance(expanded_kwargs["padding"], str):
# if padding is a string, we'll do the necessary padding (slowly) using F.pad
kernel_size = expanded_args[1].shape[2:]
padding, dilation = expanded_kwargs["padding"], expanded_kwargs["dilation"]
input = conv_input_for_string_padding(
conv_fn, padding, expanded_args[0], dilation, kernel_size
)
expanded_args = (input, expanded_args[1])
# since we've already done the padding, don't need any more
expanded_kwargs["padding"] = 0
output = forward_helper(conv_fn, expanded_args, expanded_kwargs)
input, weight = expanded_args
batched_dim_size = conv_picker(conv_fn, 3, 4, 5)
if input.dim() != batched_dim_size:
raise RuntimeError(
f"Expanded Weights only support convolution with batched input, got {conv_fn} with an"
f"unbatched input of dim {input.dim()}, expected input of dim {batched_dim_size}"
)
ctx.conv_fn = conv_fn
ctx.batch_size = orig_input.shape[0]
ctx.input_required_grad = orig_input.requires_grad
ctx.orig_input_shape = orig_input.shape
ctx.was_same_padding = was_same_padding
ctx.stride, ctx.padding = expanded_kwargs["stride"], expanded_kwargs["padding"]
ctx.dilation, ctx.groups = (
expanded_kwargs["dilation"],
expanded_kwargs["groups"],
)
if isinstance(weight, ExpandedWeight):
ctx.input = input
ctx.weight = weight
ctx.bias = expanded_kwargs["bias"]
return output
@staticmethod
def backward(ctx, grad_output):
return conv_backward(ctx.conv_fn, ctx, grad_output)


@@ -0,0 +1,353 @@
# mypy: allow-untyped-defs
from typing import List, Optional
import numpy as np
import torch
import torch.nn.functional as F
from .expanded_weights_utils import (
set_grad_sample_if_exists,
unpack_expanded_weight_or_tensor,
)
THRESHOLD = 32
def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt):
if func == F.conv1d:
return conv1dOpt
if func == F.conv2d:
return conv2dOpt
else:
assert func == F.conv3d
return conv3dOpt
def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs):
args = expanded_args_and_kwargs[: len(expanded_args_and_kwargs) - len(kwarg_names)]
kwargs = expanded_args_and_kwargs[
len(expanded_args_and_kwargs) - len(kwarg_names) :
]
kwargs = dict(zip(kwarg_names, kwargs))
return conv_normalizer(*args, **kwargs)
def conv_normalizer(
input,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
):
return (input, weight), {
"bias": bias,
"stride": stride,
"padding": padding,
"dilation": dilation,
"groups": groups,
}
def conv_input_for_string_padding(func, padding_style, input, dilation, kernel_size):
if padding_style == "valid":
return input
else:
padding = int_padding_for_string_padding(
func, padding_style, dilation, kernel_size
)
return F.pad(input, padding)
def int_padding_for_string_padding(func, padding_style, dilation, kernel_size):
def get_dilation(i):
return dilation[i] if isinstance(dilation, tuple) else dilation
if padding_style == "same":
padding: List[int] = []
# F.pad needs the padding in reverse order from what conv expects
for i in range(conv_picker(func, 0, 1, 2), -1, -1):
padding += conv_padding_for_same(get_dilation(i), kernel_size[i])
return padding
elif padding_style == "valid":
return conv_picker(func, 2, 4, 6) * (0,)
else:
raise RuntimeError(
f"got padding type of {padding_style}, only accept 'same' or 'valid'"
)
def conv_padding_for_same(dilation, kernel_size):
total_pad = dilation * (kernel_size - 1)
left_pad = total_pad // 2
right_pad = total_pad - left_pad
return left_pad, right_pad
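# Worked illustration of the two helpers above (not part of the module): for
# F.conv2d with a 3x3 kernel and dilation=1,
#   conv_padding_for_same(1, 3)                                  -> (1, 1)
#   int_padding_for_string_padding(F.conv2d, "same", 1, (3, 3))  -> [1, 1, 1, 1]
# i.e. (W_left, W_right, H_left, H_right), which is the reversed order that
# F.pad expects relative to the (H, W) order the convolution itself uses.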
def conv_backward(func, ctx, grad_output):
def weight_grad_sample(weight):
if batch_size < THRESHOLD and groups == 1:
return conv_group_weight_grad_sample(
ctx.input,
grad_output,
weight_shape,
stride,
padding,
dilation,
batch_size,
func,
)
else:
return conv_unfold_weight_grad_sample(
ctx.input,
grad_output,
weight_shape,
kernel_size,
stride,
padding,
dilation,
groups,
func,
)
def expand(param):
if isinstance(param, int):
return conv_picker(func, (param,), (param, param), (param, param, param))
else:
return param
def calc_total_padding(func, was_same, padding, dilation, kernel_size):
if was_same:
all_padding = int_padding_for_string_padding(
func, "same", dilation, kernel_size
)
# F.pad needs the padding in reverse order from what conv expects
total_padding = tuple(
all_padding[i] + all_padding[i - 1]
for i in range(len(all_padding) - 1, -1, -2)
)
return total_padding
else:
return tuple(2 * pad for pad in padding)
weight_shape = ctx.weight.shape
stride, padding, dilation, groups = (
expand(ctx.stride),
expand(ctx.padding),
expand(ctx.dilation),
ctx.groups,
)
kernel_size = []
for i in range(2, conv_picker(func, 3, 4, 5)):
kernel_size.append(weight_shape[i])
batch_size = ctx.batch_size
results: List[Optional[torch.Tensor]] = []
results.append(None) # for kwarg names
results.append(None) # for op reference
# "same" padding may give uneven padding on either side so we need to separate the "padding" attr and total padding
total_padding = calc_total_padding(
func, ctx.was_same_padding, padding, dilation, kernel_size
)
if ctx.input_required_grad:
output_padding = []
input_dims = conv_picker(func, 1, 2, 3)
for i in range(input_dims):
input_dim = ctx.orig_input_shape[2 + i]
output_padding.append(
(
total_padding[i]
+ input_dim
- (kernel_size[i] * dilation[i] - dilation[i] + 1)
)
% stride[i]
)
weight_ = unpack_expanded_weight_or_tensor(ctx.weight)
transpose_func = conv_picker(
func, F.conv_transpose1d, F.conv_transpose2d, F.conv_transpose3d
)
out = transpose_func(
grad_output,
weight_,
None,
stride,
padding,
tuple(output_padding),
groups,
dilation,
)
if ctx.was_same_padding:
for i in range(len(total_padding)):
out = torch.narrow(
out, 2 + i, total_padding[i] // 2, ctx.orig_input_shape[2 + i]
)
results.append(out)
else:
results.append(None)
# weight and bias don't compute batched gradients; no other arguments are differentiable
results = results + [None] * 6
# set grad_sample field for weight and bias with per sample gradients
set_grad_sample_if_exists(ctx.weight, weight_grad_sample)
set_grad_sample_if_exists(
ctx.bias, lambda _: grad_output.reshape(*grad_output.shape[:2], -1).sum(dim=2)
)
return tuple(results)
def conv_unfold_weight_grad_sample(
input,
grad_output,
weight_shape,
kernel_size,
stride,
padding,
dilation,
groups,
func,
):
n = input.shape[0]
in_channels = input.shape[1]
unfold_func = conv_picker(
func,
lambda: F.unfold(
input.unsqueeze(-2),
kernel_size=(1, kernel_size[0]),
dilation=(1, dilation[0]),
padding=(0, padding[0]),
stride=(1, stride[0]),
),
lambda: F.unfold(
input, kernel_size, dilation=dilation, padding=padding, stride=stride
),
lambda: unfold3d(input, kernel_size, padding, stride, dilation),
)
input = unfold_func()
grad_output = grad_output.reshape(n, -1, input.shape[-1])
# n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input)
# rearrange the above tensor and extract diagonals.
weight_grad_sample = weight_grad_sample.view(
n,
groups,
-1,
groups,
int(in_channels / groups),
np.prod(kernel_size),
)
weight_grad_sample = torch.einsum(
"ngrg...->ngr...", weight_grad_sample
).contiguous()
shape = [n] + list(weight_shape)
weight_grad_sample = weight_grad_sample.view(shape)
return weight_grad_sample
def conv_group_weight_grad_sample(
input,
grad_output,
weight_shape,
stride,
padding,
dilation,
batch_size,
func,
):
I = input.shape[1]
O = grad_output.shape[1]
input_ = input.transpose(0, 1)
grad_output_ = grad_output.view(
grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:]
)
weight_grad_sample = func(
input_,
grad_output_,
None,
stride=dilation,
padding=padding,
dilation=stride,
groups=batch_size,
)
input_dims = conv_picker(func, 3, 4, 5)
for i in range(2, input_dims):
weight_grad_sample = weight_grad_sample.narrow(i, 0, weight_shape[i])
weight_grad_sample = weight_grad_sample.view(
I, batch_size, O, *weight_grad_sample.shape[2:]
)
weight_grad_sample = weight_grad_sample.movedim(0, 2)
return weight_grad_sample
def unfold3d(
tensor,
kernel_size,
padding,
stride,
dilation,
):
r"""
Extract sliding local blocks from a batched input tensor.
:class:`torch.nn.Unfold` only supports 4D inputs (batched image-like tensors).
This method implements the same operation for 5D inputs.
Args:
tensor: An input tensor of shape ``(B, C, D, H, W)``.
kernel_size: the size of the sliding blocks
padding: implicit zero padding to be added on both sides of input
stride: the stride of the sliding blocks in the input spatial dimensions
dilation: the spacing between the kernel points.
Returns:
A tensor of shape ``(B, C * np.prod(kernel_size), L)``, where L is the product of the output spatial dimensions.
See :class:`torch.nn.Unfold` for more details
Example:
>>> # xdoctest: +SKIP
>>> B, C, D, H, W = 3, 4, 5, 6, 7
>>> tensor = torch.arange(1, B * C * D * H * W + 1.).view(B, C, D, H, W)
>>> unfold3d(tensor, kernel_size=2, padding=0, stride=1).shape
torch.Size([3, 32, 120])
"""
if len(tensor.shape) != 5:
raise ValueError(
f"Input tensor must be of the shape [B, C, D, H, W]. Got{tensor.shape}"
)
if dilation != (1, 1, 1):
raise NotImplementedError(f"dilation={dilation} not supported.")
batch_size, channels, _, _, _ = tensor.shape
# Input shape: (B, C, D, H, W)
tensor = F.pad(
tensor, (padding[2], padding[2], padding[1], padding[1], padding[0], padding[0])
)
# Output shape: (B, C, D+2*padding[0], H+2*padding[1], W+2*padding[2])
tensor = tensor.unfold(dimension=2, size=kernel_size[0], step=stride[0])
tensor = tensor.unfold(dimension=3, size=kernel_size[1], step=stride[1])
tensor = tensor.unfold(dimension=4, size=kernel_size[2], step=stride[2])
# Output shape: (B, C, D_out, H_out, W_out, kernel_size[0], kernel_size[1], kernel_size[2])
# For D_out, H_out, W_out definitions see :class:`torch.nn.Unfold`
tensor = tensor.permute(0, 2, 3, 4, 1, 5, 6, 7)
# Output shape: (B, D_out, H_out, W_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
tensor = tensor.reshape(batch_size, -1, channels * np.prod(kernel_size)).transpose(
1, 2
)
# Output shape: (B, C * kernel_size[0] * kernel_size[1] * kernel_size[2], D_out * H_out * W_out)
return tensor


@@ -0,0 +1,83 @@
# mypy: allow-untyped-defs
from typing import List, Optional
import torch
import torch.nn.functional as F
from .expanded_weights_impl import implements_per_sample_grads
from .expanded_weights_utils import (
forward_helper,
set_grad_sample_if_exists,
standard_kwargs,
)
@implements_per_sample_grads(F.embedding)
class EmbeddingPerSampleGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
expanded_args, expanded_kwargs = standard_kwargs(
kwarg_names, expanded_args_and_kwargs
)
if len(expanded_args[0].shape) == 1:
raise RuntimeError(
f"Expanded Weights needs an input with a batch size, got a 1D tensor, {expanded_args[0]}"
)
output = forward_helper(F.embedding, expanded_args, expanded_kwargs)
ctx.input, ctx.weight = expanded_args
ctx.padding_idx, ctx.scale_grad_by_freq = (
expanded_kwargs["padding_idx"],
expanded_kwargs["scale_grad_by_freq"],
)
ctx.sparse = expanded_kwargs["sparse"]
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.input, ctx.weight
padding_idx, scale_grad_by_freq, sparse = (
ctx.padding_idx,
ctx.scale_grad_by_freq,
ctx.sparse,
)
def weight_per_sample_grad(weight):
batch_size = input.shape[0]
embedding_dim = weight.shape[1]
index = (
input.unsqueeze(-1)
.expand(*input.shape, embedding_dim)
.reshape(batch_size, -1, embedding_dim)
)
grad_sample = torch.zeros(
batch_size, *weight.shape, device=weight.device, dtype=grad_output.dtype
)
return grad_sample.scatter_add_(
1, index, grad_output.reshape(batch_size, -1, embedding_dim)
)
results: List[Optional[torch.Tensor]] = []
results.append(None) # for kwarg names
results.append(None) # for op reference
if input.requires_grad:
bw_fn = torch.ops.aten.embedding_backward
results.append(
bw_fn(
grad_output,
input,
weight.shape[0],
padding_idx,
scale_grad_by_freq,
sparse,
)
)
else:
results.append(None)
# weight doesn't compute batched gradients; no other arguments are differentiable (2 not saved from forward)
results = results + [None] * 6
# set grad_sample field for weight with per sample gradients
set_grad_sample_if_exists(weight, weight_per_sample_grad)
return tuple(results)
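# Shape sketch for the per-sample gradient computed above (illustrative
# values): with integer indices of shape (B, T) and a weight of shape (V, D),
# grad_output has shape (B, T, D) and the scatter_add_ produces a
# weight.grad_sample of shape (B, V, D), i.e. one dense embedding gradient per
# example in the batch.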


@@ -0,0 +1,182 @@
# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from typing import Callable, Dict
import torch
from torch._decomp import decomposition_table
from torch.utils._pytree import tree_map_only
HANDLED_FUNCTIONS: Dict[Callable, torch.autograd.Function] = {}
aten = torch._ops.ops.aten
# __torch_function__ runs before the pydispatcher so we need to manually use the same
# decompositions indexed by their torch equivalent
expanded_weights_rnn_decomps = {
# func: (input_decomp, data_decomp)
torch.rnn_relu: (
decomposition_table[aten.rnn_relu.input],
decomposition_table[aten.rnn_relu.data],
),
torch.rnn_tanh: (
decomposition_table[aten.rnn_tanh.input],
decomposition_table[aten.rnn_tanh.data],
),
torch.lstm: (
decomposition_table[aten.lstm.input],
decomposition_table[aten.lstm.data],
),
torch.gru: (
decomposition_table[aten.gru.input],
decomposition_table[aten.gru.data],
),
}
# all of the RNN decomps run linear with the batch dimension second, even if batch_first was set
@contextmanager
def batch_second(args, kwargs):
def set_batch_second(ew):
ew.set_batch_first(False)
def reset_batch_first(ew):
ew.set_batch_first(True)
tree_map_only(ExpandedWeight, set_batch_second, args)
tree_map_only(ExpandedWeight, set_batch_second, kwargs)
try:
yield
finally:
tree_map_only(ExpandedWeight, reset_batch_first, args)
tree_map_only(ExpandedWeight, reset_batch_first, kwargs)
# to support packed sequences, we need to allow for smaller batches. Expanded weights represents the largest batch
@contextmanager
def allow_smaller_batches(args, kwargs):
def allow(ew):
ew.set_allow_smaller_batches(True)
def reset(ew):
ew.set_allow_smaller_batches(False)
tree_map_only(ExpandedWeight, allow, args)
tree_map_only(ExpandedWeight, allow, kwargs)
try:
yield
finally:
tree_map_only(ExpandedWeight, reset, args)
tree_map_only(ExpandedWeight, reset, kwargs)
@contextmanager
def setup_rnn(use_input_variant, args, kwargs):
with batch_second(args, kwargs) if use_input_variant else allow_smaller_batches(
args, kwargs
):
yield
def implements_per_sample_grads(torch_function):
@functools.wraps(torch_function)
def decorator(autograd_func):
HANDLED_FUNCTIONS[torch_function] = autograd_func
return autograd_func
return decorator
# ExpandedWeight represents a weight (parameter) Tensor that has an expanded
# batch dimension. Operations on the ExpandedWeight Tensor act exactly like
# those without an expanded batch dimension, but a call to .backward() populates
# the original (unexpanded) tensor's grad_sample field with per-sample gradients
#
# ExpandedWeight has a fallback that always fails since we cannot know what the batch
# dimension of the input tensor is and therefore cannot know if this is a valid call
#
# This is a __torch_function__ object but it could have also been a Tensor Extension
# with a dispatch key.
#
# Needs to be a tensor subclass to allow reparametrization
class ExpandedWeight(torch.Tensor):
def __init__(self, orig_weight, batch_size, loss_reduction):
self.batch_size = batch_size
self.batch_first = True
self.allow_smaller_batches = False
self.orig_weight = orig_weight
self.loss_reduction = loss_reduction
handled_functions = HANDLED_FUNCTIONS
def __new__(cls, orig_weight, batch_size, loss_reduction):
if not isinstance(orig_weight, torch.Tensor):
raise RuntimeError(
f"Can only make Expanded Weights of Tensors, got {type(orig_weight).__name__}"
)
if not orig_weight.requires_grad:
raise RuntimeError(
"Can only build ExpandedWeights objects of tensors that require_grad"
)
ret = torch.Tensor._make_subclass(cls, orig_weight, True)
return ret
@classmethod
def __torch_function__(cls, func, _, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func in expanded_weights_rnn_decomps:
# in aten, choosing the input or data variants is done by parsing logic. This mimics some of that
decomp_opts = expanded_weights_rnn_decomps[func]
use_input_variant = isinstance(
args[2], list
) # data variant uses a list here
decomp = decomp_opts[0] if use_input_variant else decomp_opts[1]
if decomp is not None:
with setup_rnn(use_input_variant, args, kwargs):
return decomp(*args, **kwargs)
if func == torch._cudnn_rnn_flatten_weight:
# since we aren't using the fused cuda kernels for RNNs, don't do this
return
if func in cls.handled_functions:
return cls.handled_functions[func].apply(
tuple(kwargs.keys()), func, *(args + tuple(kwargs.values()))
)
# We cannot use a fallback here because we do not know the batch dimension for any regular tensor inputs,
# i.e. torch.add(torch.Tensor, ExpandedWeight)
raise RuntimeError(
f"Expanded Weights encountered but cannot handle function {func.__name__}"
)
@property
def dtype(self):
return self.orig_weight.dtype
@property
def data(self):
return self.orig_weight.data
@property
def shape(self):
return self.orig_weight.shape
@property
def device(self):
return self.orig_weight.device
@property
def is_cuda(self):
return self.orig_weight.is_cuda
def data_ptr(self):
return self.orig_weight.data_ptr()
def get_device(self):
return self.orig_weight.get_device()
def set_allow_smaller_batches(self, is_allow_smaller_batches):
self.allow_smaller_batches = is_allow_smaller_batches
def set_batch_first(self, is_batch_first=True):
self.batch_first = is_batch_first
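# A hedged end-to-end sketch of the __torch_function__ dispatch above; it
# assumes the per-sample-grad handler for F.linear registered elsewhere in
# this package (LinearPerSampleGrad), and the sizes are illustrative.
#
# >>> # xdoctest: +SKIP
# >>> weight = torch.randn(3, 4, requires_grad=True)
# >>> ew = ExpandedWeight(weight, batch_size=5, loss_reduction="sum")
# >>> x = torch.randn(5, 4)
# >>> out = torch.nn.functional.linear(x, ew)  # routed through handled_functions
# >>> out.sum().backward()
# >>> weight.grad_sample.shape  # per-sample grads land on the original weight
# torch.Size([5, 3, 4])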


@@ -0,0 +1,188 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
from .expanded_weights_impl import ExpandedWeight
def is_batch_first(expanded_args_and_kwargs):
batch_first = None
for arg in expanded_args_and_kwargs:
if not isinstance(arg, ExpandedWeight):
continue
if not batch_first:
batch_first = arg.batch_first
elif arg.batch_first != batch_first:
raise RuntimeError(
"Got conflicting batch_first arguments in the same layer"
)
return batch_first
def standard_kwargs(kwarg_names, expanded_args):
r"""Separate args and kwargs from `__torch_function__`s that standardize kwargs.
Most `__torch_function__`s standardize the kwargs that they give, so this will separate
the args and kwargs they pass. Functions that don't are linear and convND.
"""
kwarg_values = expanded_args[len(expanded_args) - len(kwarg_names) :]
expanded_args_without_kwargs = expanded_args[
: len(expanded_args) - len(kwarg_names)
]
expanded_kwargs = dict(zip(kwarg_names, kwarg_values))
return expanded_args_without_kwargs, expanded_kwargs
def forward_helper(func, expanded_args, expanded_kwargs):
r"""Compute the forward pass for a function that has expanded weight(s) passed to it.
It will run the forward pass where all ExpandedWeights are their original
weight. It runs checks on the given arguments and detaches the outputs.
.. note:: First argument in :attr:`expanded_args` must be the input with the batch
dimension as the first element of the shape
.. note:: :attr:`func` must return a Tensor or tuple of Tensors
Args:
func: The function to be called
expanded_args: Arguments to be passed to :attr:`func`. Will include arguments
that need to be unpacked because they are ExpandedWeights
expanded_kwargs: Keyword arguments to be passed to :attr:`func`.
Similar to :attr:`expanded_args`.
"""
unexpanded_args, unexpanded_kwargs = _check_and_unexpand_args(
func, expanded_args, expanded_kwargs
)
return func(*unexpanded_args, **unexpanded_kwargs)
def _check_and_unexpand_args(func, expanded_args, expanded_kwargs):
# input must be the first argument passed
input = expanded_args[0]
if isinstance(input, ExpandedWeight):
raise RuntimeError(
"Expanded Weights do not support inputs that are also ExpandedWeights. "
f"Input must be a Tensor, got {type(input).__name__} in function {func.__name__}"
)
if not isinstance(input, torch.Tensor):
raise RuntimeError(
"Expanded Weights requires a Tensor as the first input to get the batch dimension, "
f"got {type(input).__name__} in function {func.__name__}"
)
if len(input.shape) == 0:
raise RuntimeError(
f"Expanded Weights requires a batch dimension but got an input of size 0 in function {func.__name__}"
)
if input.shape[0] == 0:
raise RuntimeError(
"0 is not a valid batch size for Expanded Weights but got input tensor of "
f"{input} in function {func.__name__}"
)
for arg in expanded_args + tuple(expanded_kwargs.values()):
if not isinstance(arg, ExpandedWeight):
continue
batch_size = input.shape[0] if arg.batch_first else input.shape[1]
if (arg.allow_smaller_batches and batch_size > arg.batch_size) or (
not arg.allow_smaller_batches and arg.batch_size != batch_size
):
raise RuntimeError(
"Expected ExpandedWeights to have batch size matching input but got "
f"input batch size of {batch_size} with ExpandedWeight of batch size {arg.batch_size}"
)
loss_reduction: Optional[str] = None
for arg in expanded_args + tuple(expanded_kwargs.values()):
if isinstance(arg, ExpandedWeight):
if loss_reduction is None:
loss_reduction = arg.loss_reduction
elif loss_reduction != arg.loss_reduction:
raise RuntimeError(
"Expected ExpandedWeights to all have the same loss_reduction argument but got one"
f"with {loss_reduction} and one with {arg.loss_reduction}"
)
unexpanded_args = tuple(
arg.orig_weight if isinstance(arg, ExpandedWeight) else arg
for arg in expanded_args
)
unexpanded_kwargs = {
name: arg.orig_weight if isinstance(arg, ExpandedWeight) else arg
for (name, arg) in expanded_kwargs.items()
}
return unexpanded_args, unexpanded_kwargs
def maybe_scale_by_batch_size(grad_sample, expanded_weight):
if expanded_weight.loss_reduction == "mean":
return grad_sample * expanded_weight.batch_size
else:
return grad_sample
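# Worked example of the scaling above (illustrative numbers): with a batch of
# B = 5 and a loss that is a mean over the batch, autograd sees grad_outputs
# that are 1/5 of what a sum-reduced loss would produce, so multiplying each
# per-sample gradient by batch_size = 5 recovers the gradient of that sample's
# own, unreduced loss.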
def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn):
unpacked = unpack_expanded_weight_or_tensor(maybe_expanded_weight)
if isinstance(maybe_expanded_weight, ExpandedWeight):
grad_sample_contribution = maybe_scale_by_batch_size(
per_sample_grad_fn(unpacked), maybe_expanded_weight
)
if maybe_expanded_weight.batch_size > grad_sample_contribution.shape[0]:
# this only passes the other checks if the arg allows smaller batch sizes
intermediate = torch.zeros(
maybe_expanded_weight.batch_size,
*grad_sample_contribution.shape[1:],
dtype=grad_sample_contribution.dtype,
device=grad_sample_contribution.device,
)
intermediate[: grad_sample_contribution.shape[0]] = grad_sample_contribution
grad_sample_contribution = intermediate
if hasattr(unpacked, "grad_sample") and unpacked.grad_sample is not None:
unpacked.grad_sample = unpacked.grad_sample + grad_sample_contribution
else:
unpacked.grad_sample = grad_sample_contribution
def unpack_expanded_weight_or_tensor(maybe_expanded_weight, func=lambda x: x):
if isinstance(maybe_expanded_weight, ExpandedWeight):
orig_weight = maybe_expanded_weight.orig_weight
return func(orig_weight)
elif (
isinstance(maybe_expanded_weight, torch.Tensor)
and not maybe_expanded_weight.requires_grad
):
return func(maybe_expanded_weight)
elif isinstance(maybe_expanded_weight, torch.Tensor):
raise RuntimeError(
"ExpandedWeights currently does not support a mixture of ExpandedWeight parameters "
"and normal Parameters. Please file and issue with pytorch/pytorch"
)
def sum_over_all_but_batch_and_last_n(
tensor: torch.Tensor,
n_dims: int,
) -> torch.Tensor:
r"""
Calculate the sum over all dimensions, except the first (batch dimension), and excluding the last n_dims.
This function will ignore the first dimension and it will
not aggregate over the last n_dims dimensions.
Args:
tensor: An input tensor of shape ``(B, ..., X[n_dims-1])``.
n_dims: Number of dimensions to keep.
Example:
>>> tensor = torch.ones(1, 2, 3, 4, 5)
>>> sum_over_all_but_batch_and_last_n(tensor, n_dims=2).shape
torch.Size([1, 4, 5])
Returns:
A tensor of shape ``(B, ..., X[n_dims-1])``
"""
if tensor.dim() == n_dims + 1:
return tensor
else:
dims = list(range(1, tensor.dim() - n_dims))
return tensor.sum(dim=dims)


@@ -0,0 +1,104 @@
# mypy: allow-untyped-defs
import operator
from functools import reduce
from typing import List, Optional
import torch
import torch.nn.functional as F
from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads
from .expanded_weights_utils import (
forward_helper,
set_grad_sample_if_exists,
standard_kwargs,
unpack_expanded_weight_or_tensor,
)
@implements_per_sample_grads(F.group_norm)
class GroupNormPerSampleGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
expanded_args, expanded_kwargs = standard_kwargs(
kwarg_names, expanded_args_and_kwargs
)
input, num_groups = expanded_args
N = input.shape[0]
C = input.shape[1]
HxW = reduce(operator.mul, input.shape[2:], 1)
weight, bias, eps = (
expanded_kwargs["weight"],
expanded_kwargs["bias"],
expanded_kwargs["eps"],
)
output, mean, rstd = forward_helper(
torch.native_group_norm,
(input, weight, bias, N, C, HxW, num_groups, eps),
{},
)
ctx.input, ctx.num_groups = input, num_groups
ctx.weight, ctx.eps = weight, eps
ctx.mean, ctx.rstd = mean, rstd
if isinstance(bias, ExpandedWeight):
ctx.bias = bias
if input.requires_grad and isinstance(weight, ExpandedWeight):
ctx.weight = weight
return output
@staticmethod
def backward(ctx, grad_output):
input, num_groups = ctx.input, ctx.num_groups
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
mean, rstd = ctx.mean, ctx.rstd
results: List[Optional[torch.Tensor]] = []
results.append(None) # for kwarg names
results.append(None) # for op reference
if input.requires_grad:
weight_c = unpack_expanded_weight_or_tensor(
weight, lambda t: t.contiguous()
)
input_c = input.contiguous()
grad_output_c = (
grad_output.contiguous() if grad_output is not None else None
)
N = input.shape[0]
C = input.shape[1]
HxW = 1
for s in input.shape[2:]:
HxW *= s
bw_fn = torch.ops.aten.native_group_norm_backward
results.append(
bw_fn(
grad_output_c,
input_c,
mean,
rstd,
weight_c,
N,
C,
HxW,
num_groups,
(True, False, False),
)[0]
)
else:
results.append(None)
# weight and bias don't compute batched gradients; no other arguments are differentiable
results = results + [None] * 4
# set grad_sample field for weight and bias with per sample gradients
if hasattr(ctx, "weight"):
set_grad_sample_if_exists(
weight,
lambda _: torch.einsum(
"ni...->ni", F.group_norm(input, num_groups, eps=eps) * grad_output
),
)
if hasattr(ctx, "bias"):
set_grad_sample_if_exists(
bias, lambda _: torch.einsum("ni...->ni", grad_output)
)
return tuple(results)


@@ -0,0 +1,100 @@
# mypy: allow-untyped-defs
from functools import partial
from typing import List, Optional
import torch
import torch.nn.functional as F
from .expanded_weights_impl import implements_per_sample_grads
from .expanded_weights_utils import (
forward_helper,
set_grad_sample_if_exists,
standard_kwargs,
unpack_expanded_weight_or_tensor,
)
@implements_per_sample_grads(F.instance_norm)
class InstanceNormPerSampleGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
instance_norm = partial(torch.instance_norm, cudnn_enabled=True)
expanded_args, expanded_kwargs = standard_kwargs(
kwarg_names, expanded_args_and_kwargs
)
output = forward_helper(instance_norm, expanded_args, expanded_kwargs)
ctx.input = expanded_args[0]
ctx.running_mean, ctx.running_var = (
expanded_kwargs["running_mean"],
expanded_kwargs["running_var"],
)
ctx.weight, ctx.bias, ctx.eps = (
expanded_kwargs["weight"],
expanded_kwargs["bias"],
expanded_kwargs["eps"],
)
return output
@staticmethod
def backward(ctx, grad_output):
input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
results: List[Optional[torch.Tensor]] = []
results.append(None) # for kwarg names
results.append(None) # for op reference
if input.requires_grad:
b = input.shape[0]
c = input.shape[1]
new_shape = (1, b * c, *input.shape[2:])
weight_ = unpack_expanded_weight_or_tensor(
weight, lambda orig_weight: orig_weight.repeat(b)
)
running_mean_ = running_mean.repeat(b) if running_mean is not None else None
running_var_ = running_var.repeat(b) if running_var is not None else None
input_reshaped = input.contiguous().view(new_shape)
grad_output_reshaped = grad_output.contiguous().view(new_shape)
mean = torch.mean(
input_reshaped, (0,) + tuple(range(2, input.dim())), False
)
var = torch.var(
input_reshaped,
(0,) + tuple(range(2, input.dim())),
keepdim=False,
unbiased=False,
)
rstd = 1 / torch.sqrt(var + eps)
# must use native batch norm since it supports all inputs. This may have used cudnn or miopen during the forward but
# it didn't save the metadata, so we don't know during the backward
res = torch.ops.aten.native_batch_norm_backward(
grad_output_reshaped,
input_reshaped,
weight_,
running_mean_,
running_var_,
mean,
rstd,
True,
eps,
(True, False, False),
)
results.append(res[0].reshape(input.shape))
else:
results.append(None)
# weight and bias don't compute batched gradients; no other arguments are differentiable (2 are not saved from the forward)
results = results + [None] * 7
# set grad_sample field for weight and bias with per sample gradients
set_grad_sample_if_exists(
weight,
lambda _: torch.einsum(
"ni...->ni", F.instance_norm(input, eps=eps) * grad_output
),
)
set_grad_sample_if_exists(
bias, lambda _: torch.einsum("ni...->ni", grad_output)
)
return tuple(results)


@@ -0,0 +1,87 @@
# mypy: allow-untyped-defs
from typing import List, Optional
import torch
import torch.nn.functional as F
from .expanded_weights_impl import ExpandedWeight, implements_per_sample_grads
from .expanded_weights_utils import (
forward_helper,
set_grad_sample_if_exists,
standard_kwargs,
sum_over_all_but_batch_and_last_n,
unpack_expanded_weight_or_tensor,
)
@implements_per_sample_grads(F.layer_norm)
class LayerNormPerSampleGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
expanded_args, expanded_kwargs = standard_kwargs(
kwarg_names, expanded_args_and_kwargs
)
input = expanded_args[0]
normalized_shape = expanded_args[1]
if len(input.shape) <= len(normalized_shape):
raise RuntimeError(
"Expanded Weights: Layer norm should not normalize over batch dimension for per sample gradient"
f"computations but got that normalized shape, {normalized_shape}, matched input shape."
)
output, mean, rstd = forward_helper(
torch.native_layer_norm, expanded_args, expanded_kwargs
)
ctx.args = expanded_args
if input.requires_grad or isinstance(expanded_kwargs["weight"], ExpandedWeight):
ctx.weight = expanded_kwargs["weight"]
if input.requires_grad or isinstance(expanded_kwargs["bias"], ExpandedWeight):
ctx.bias = expanded_kwargs["bias"]
ctx.eps = expanded_kwargs["eps"]
ctx.mean, ctx.rstd = mean, rstd
return output
@staticmethod
def backward(ctx, grad_output):
def weight_per_sample_grad(weight):
return sum_over_all_but_batch_and_last_n(
F.layer_norm(input, normalized_shape, eps=ctx.eps) * grad_output,
weight.dim(),
)
input, normalized_shape = ctx.args
mean, rstd = ctx.mean, ctx.rstd
results: List[Optional[torch.Tensor]] = []
results.append(None) # for kwarg names
results.append(None) # for op reference
if input.requires_grad:
weight_ = unpack_expanded_weight_or_tensor(ctx.weight)
bias_ = unpack_expanded_weight_or_tensor(ctx.bias)
results.append(
torch.ops.aten.native_layer_norm_backward(
grad_output,
input,
normalized_shape,
mean,
rstd,
weight_,
bias_,
(True, False, False),
)[0]
)
else:
results.append(None)
# weight and bias don't compute batched gradients; no other arguments are differentiable
results = results + [None] * 4
# set grad_sample field for weight and bias with per sample gradients
if hasattr(ctx, "weight"):
set_grad_sample_if_exists(ctx.weight, weight_per_sample_grad)
if hasattr(ctx, "bias"):
set_grad_sample_if_exists(
ctx.bias,
lambda bias: sum_over_all_but_batch_and_last_n(grad_output, bias.dim()),
)
return tuple(results)


@@ -0,0 +1,62 @@
# mypy: allow-untyped-defs
from typing import List, Optional
import torch
import torch.nn.functional as F
from .expanded_weights_impl import implements_per_sample_grads
from .expanded_weights_utils import (
forward_helper,
is_batch_first,
set_grad_sample_if_exists,
unpack_expanded_weight_or_tensor,
)
@implements_per_sample_grads(F.linear)
class LinearPerSampleGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, _, __, *expanded_args_and_kwargs):
if len(expanded_args_and_kwargs[0].shape) <= 1:
raise RuntimeError(
"Input does not have a batch dimension. Expanded Weights expected input "
f"of at least rank 2, got of rank {len(expanded_args_and_kwargs[0].shape)}"
)
expanded_kwargs = {
"bias": expanded_args_and_kwargs[2]
if len(expanded_args_and_kwargs) == 3
else None
}
expanded_args = expanded_args_and_kwargs[:2]
ctx.batch_first = is_batch_first(expanded_args_and_kwargs)
output = forward_helper(F.linear, expanded_args, expanded_kwargs)
ctx.args = expanded_args
ctx.kwargs = expanded_kwargs
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.args
bias = ctx.kwargs["bias"]
results: List[Optional[torch.Tensor]] = []
results.append(None) # for kwarg_names
results.append(None) # for op reference
if input.requires_grad:
results.append(grad_output.matmul(unpack_expanded_weight_or_tensor(weight)))
else:
results.append(None)
results.extend([None] * 2) # weight and bias don't compute batched gradients
if not ctx.batch_first:
grad_output = grad_output.transpose(0, 1)
input = input.transpose(0, 1)
# weight and bias get their grad_sample fields set directly if they exist
set_grad_sample_if_exists(
weight, lambda _: torch.einsum("n...i,n...j->nij", grad_output, input)
)
set_grad_sample_if_exists(
bias, lambda _: torch.einsum("n...k->nk", grad_output)
)
return tuple(results)
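# Shape sketch for the einsums above (illustrative, batch-first case): with an
# input of shape (B, in_features) and grad_output of shape (B, out_features),
# "n...i,n...j->nij" gives weight.grad_sample of shape
# (B, out_features, in_features) and "n...k->nk" gives bias.grad_sample of
# shape (B, out_features).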


@@ -0,0 +1,372 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Iterable, List, Tuple
import torch
_MISSING: torch.Tensor = object() # type: ignore[assignment]
def set_tensor(module: "torch.nn.Module", name: str, tensor: torch.Tensor) -> None:
if not isinstance(module, torch.nn.Module):
raise TypeError(f"{module} is not an instance of torch.nn.Module")
if not isinstance(tensor, torch.Tensor) and tensor is not None:
raise TypeError(f"{tensor} is not an instance of torch.Tensor")
if "." in name:
raise KeyError('tensor name can\'t contain "."')
if name == "":
raise KeyError('tensor name can\'t be empty string ""')
if name in module._parameters:
module._parameters[name] = tensor # type: ignore[assignment]
elif name in module._buffers:
module._buffers[name] = tensor
else:
setattr(module, name, tensor)
def swap_tensor(
module: "torch.nn.Module",
name: str,
tensor: torch.Tensor,
allow_missing: bool = False,
) -> torch.Tensor:
if not isinstance(module, torch.nn.Module):
raise TypeError(f"{module} is not an instance of torch.nn.Module")
if (
tensor is not _MISSING
and not isinstance(tensor, torch.Tensor)
and tensor is not None
):
raise TypeError(f"{tensor} is not an instance of torch.Tensor")
if "." in name:
raise KeyError('tensor name can\'t contain "."')
if name == "":
raise KeyError('tensor name can\'t be empty string ""')
orig_tensor: torch.Tensor
if name in module._parameters:
orig_tensor = module._parameters[name] # type: ignore[assignment]
if tensor is not _MISSING:
module._parameters[name] = tensor # type: ignore[assignment]
else:
del module._parameters[name]
elif name in module._buffers:
orig_tensor = module._buffers[name] # type: ignore[assignment]
if tensor is not _MISSING:
module._buffers[name] = tensor
else:
del module._buffers[name]
else:
if hasattr(module, name):
orig_tensor = getattr(module, name)
else:
if not allow_missing:
raise AttributeError(f"{module._get_name()} has no attribute `{name}`")
orig_tensor = _MISSING
if (
orig_tensor is not _MISSING
and not isinstance(orig_tensor, torch.Tensor)
and orig_tensor is not None
):
raise TypeError(
f"attribute `{name}`: {orig_tensor} is not an instance of torch.Tensor"
)
if tensor is not _MISSING:
setattr(module, name, tensor)
elif hasattr(module, name):
delattr(module, name)
return orig_tensor
def swap_submodule(
module: "torch.nn.Module",
name: str,
submodule: "torch.nn.Module",
) -> "torch.nn.Module":
if not isinstance(module, torch.nn.Module):
raise TypeError(f"{module} is not an instance of torch.nn.Module")
if not isinstance(submodule, torch.nn.Module):
raise TypeError(f"{submodule} is not an instance of torch.nn.Module")
if "." in name:
raise KeyError('submodule name can\'t contain "."')
if name == "":
raise KeyError('submodule name can\'t be empty string ""')
if name not in module._modules:
raise KeyError(f"submodule {name} does not exist")
orig_submodule = module._modules[name]
if not isinstance(orig_submodule, torch.nn.Module):
raise TypeError(f"{name} attribute is not an instance of torch.nn.Module")
module._modules[name] = submodule
return orig_submodule
class NamedMemberAccessor:
"""
A class that provides a way to access the submodules and parameters/buffers of a module.
It provides a caching mechanism to speed up submodule lookups.
This is useful for functional programming to manipulate the module state.
"""
def __init__(self, module: "torch.nn.Module") -> None:
self.module = module
self.memo: Dict[str, torch.nn.Module] = {}
# Nested attribute access
def get_submodule(self, name: str) -> "torch.nn.Module":
"""
Return the submodule specified by the given path.
For example, to get the submodule mod.layer1.conv1,
use accessor.get_submodule("layer1.conv1")
Compared to mod.get_submodule("layer1.conv1"), this method will cache the
intermediate submodule access to speed up future lookups.
"""
if not name:
return self.module
if name in self.memo:
return self.memo[name]
else:
prefix, dot, attr = name.rpartition(".")
if dot:
module = self.get_submodule(prefix)
else:
module = self.module
try:
submodule = getattr(module, attr)
except AttributeError as ex:
raise AttributeError(
f"{module._get_name()} has no attribute `{attr}`"
) from ex
if not isinstance(submodule, torch.nn.Module):
raise TypeError( # noqa: B904
f"submodule `{name}`: {submodule} is not an instance of torch.nn.Module"
)
self.memo[name] = submodule
return submodule
def swap_submodule(self, path: str, value: "torch.nn.Module") -> "torch.nn.Module":
"""
Swap the submodule specified by the given ``path`` to ``value``.
For example, to swap the attribute mod.layer1.conv1 use
``accessor.swap_submodule("layer1.conv1", conv2)``.
"""
prefix, _, attr = path.rpartition(".")
return swap_submodule(self.get_submodule(prefix), attr, value)
def get_tensor(self, name: str) -> torch.Tensor:
"""
Get the tensor specified by the given path.
For example, to get the attribute mod.layer1.conv1.weight,
use accessor.get_tensor('layer1.conv1.weight')
Compared to mod.get_parameter("layer1.conv1.weight"), this method will
cache the intermediate submodule access to speed up future lookups.
"""
prefix, _, attr = name.rpartition(".")
submodule = self.get_submodule(prefix)
try:
tensor = getattr(submodule, attr)
except AttributeError as ex:
raise AttributeError(
f"{submodule._get_name()} has no attribute `{name}`"
) from ex
if not isinstance(tensor, torch.Tensor) and tensor is not None:
raise TypeError(f"{tensor} is not an instance of torch.Tensor")
return tensor # type: ignore[return-value]
def set_tensor(self, name: str, value: torch.Tensor) -> None:
"""
Set the attribute specified by the given path to value.
For example, to set the attribute mod.layer1.conv1.weight,
use accessor.set_tensor("layer1.conv1.weight", value)
"""
prefix, _, attr = name.rpartition(".")
set_tensor(self.get_submodule(prefix), attr, value)
def del_tensor(self, name: str) -> None:
"""
Delete the attribute specified by the given path.
For example, to delete the attribute mod.layer1.conv1.weight,
use accessor.del_tensor("layer1.conv1.weight")
"""
prefix, _, attr = name.rpartition(".")
submodule = self.get_submodule(prefix)
try:
delattr(submodule, attr)
except AttributeError as ex:
raise AttributeError(
f"{submodule._get_name()} has no attribute `{name}`"
) from ex
def swap_tensor(
self, name: str, value: torch.Tensor, allow_missing: bool = False
) -> torch.Tensor:
"""
Swap the attribute specified by the given path to value.
For example, to swap the attribute mod.layer1.conv1.weight,
use accessor.swap_tensor("layer1.conv1.weight", value)
"""
prefix, _, attr = name.rpartition(".")
return swap_tensor(
self.get_submodule(prefix), attr, value, allow_missing=allow_missing
)
# Batched operations
def get_tensors(self, names: Iterable[str]) -> List[torch.Tensor]:
"""
Get the tensors specified by the given paths.
For example, to get the attributes mod.layer1.conv1.weight and
mod.layer1.conv1.bias, use accessor.get_tensors(["layer1.conv1.weight",
"layer1.conv1.bias"])
"""
return [self.get_tensor(name) for name in names]
def set_tensors(self, names: Iterable[str], values: Iterable[torch.Tensor]) -> None:
"""
Set the attributes specified by the given paths to values.
For example, to set the attributes mod.layer1.conv1.weight and
mod.layer1.conv1.bias, use accessor.set_tensors(["layer1.conv1.weight",
"layer1.conv1.bias"], [weight, bias])
"""
if not isinstance(names, (list, tuple)):
names = list(names)
if not isinstance(values, (list, tuple)):
values = list(values)
assert len(names) == len(values), "names and values must have the same length"
for name, value in zip(names, values):
self.set_tensor(name, value)
def set_tensors_dict(self, named_tensors: Dict[str, torch.Tensor]) -> None:
"""
Set the attributes specified by the given paths to values.
For example, to set the attributes mod.layer1.conv1.weight and
mod.layer1.conv1.bias, use accessor.set_tensors_dict({
"layer1.conv1.weight": weight,
"layer1.conv1.bias": bias,
})
"""
for name, value in named_tensors.items():
self.set_tensor(name, value)
def del_tensors(self, names: Iterable[str]) -> None:
"""
Delete the attributes specified by the given paths.
For example, to delete the attributes mod.layer1.conv1.weight and
mod.layer1.conv1.bias, use accessor.del_tensors(["layer1.conv1.weight",
"layer1.conv1.bias"])
"""
for name in names:
self.del_tensor(name)
def swap_tensors(
self,
names: Iterable[str],
values: Iterable[torch.Tensor],
allow_missing: bool = False,
) -> List[torch.Tensor]:
"""
Swap the attributes specified by the given paths to values.
For example, to swap the attributes mod.layer1.conv1.weight and
mod.layer1.conv1.bias, use accessor.swap_tensors(["layer1.conv1.weight",
"layer1.conv1.bias"], [weight, bias])
"""
if not isinstance(names, (list, tuple)):
names = list(names)
if not isinstance(values, (list, tuple)):
values = list(values)
assert len(names) == len(values), "names and values must have the same length"
return [
self.swap_tensor(name, value, allow_missing=allow_missing)
for name, value in zip(names, values)
]
def swap_tensors_dict(
self, named_tensors: Dict[str, torch.Tensor], allow_missing: bool = False
) -> Tuple[Dict[str, torch.Tensor], List[str]]:
"""
Swap the attributes specified by the given paths to values.
For example, to swap the attributes mod.layer1.conv1.weight and
mod.layer1.conv1.bias, use accessor.swap_tensors_dict({
"layer1.conv1.weight": weight,
"layer1.conv1.bias": bias,
})
"""
orig_named_tensors = {}
missing_keys = []
try:
for name, tensor in named_tensors.items():
orig_tensor = self.swap_tensor(name, tensor, allow_missing=True)
if orig_tensor is _MISSING:
missing_keys.append(name)
orig_named_tensors[name] = orig_tensor
except Exception:
# Swap back if any exception occurs
for name, orig_tensor in orig_named_tensors.items():
self.swap_tensor(name, orig_tensor, allow_missing=True)
raise
if missing_keys and not allow_missing:
# Swap back if any key is missing when allow_missing is False
for name, orig_tensor in orig_named_tensors.items():
self.swap_tensor(name, orig_tensor, allow_missing=True)
raise RuntimeError(f"Missing key(s): {', '.join(map(repr, missing_keys))}.")
return orig_named_tensors, missing_keys
def check_keys(self, keys: Iterable[str]) -> Tuple[List[str], List[str]]:
"""Check that the given keys are valid."""
keys = set(keys)
valid_keys = {name for name, _ in self.named_tensors(remove_duplicate=False)}
missing_keys = valid_keys - keys
unexpected_keys = keys - valid_keys
return sorted(missing_keys), sorted(unexpected_keys)
# Shortcut methods
def named_parameters(
self,
remove_duplicate: bool = True,
) -> Iterable[Tuple[str, torch.Tensor]]:
"""Iterate over all the parameters in the module."""
yield from self.module.named_parameters(remove_duplicate=remove_duplicate)
def named_buffers(
self,
remove_duplicate: bool = True,
) -> Iterable[Tuple[str, torch.Tensor]]:
"""Iterate over all the buffers in the module."""
yield from self.module.named_buffers(remove_duplicate=remove_duplicate)
def named_tensors(
self,
remove_duplicate: bool = True,
) -> Iterable[Tuple[str, torch.Tensor]]:
"""Iterate over all the tensors in the module."""
yield from self.module.named_parameters(remove_duplicate=remove_duplicate)
yield from self.module.named_buffers(remove_duplicate=remove_duplicate)
def named_modules(
self,
remove_duplicate: bool = True,
) -> Iterable[Tuple[str, "torch.nn.Module"]]:
"""Iterate over all the modules in the module."""
yield from self.module.named_modules(remove_duplicate=remove_duplicate)
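# A hedged usage sketch; the module and attribute names are illustrative.
#
# >>> # xdoctest: +SKIP
# >>> model = torch.nn.Sequential(torch.nn.Linear(4, 3))
# >>> accessor = NamedMemberAccessor(model)
# >>> new_weight = torch.zeros(3, 4)
# >>> old_weight = accessor.swap_tensor("0.weight", new_weight)
# >>> accessor.get_tensor("0.weight") is new_weight
# True
# >>> _ = accessor.swap_tensor("0.weight", old_weight)  # restore the original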


@@ -0,0 +1,124 @@
# mypy: allow-untyped-defs
import functools
import torch
from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight
from torch.utils import _pytree as pytree
# dependency on `functional_call` means that this can't be exposed in utils
# without creating a circular dependency
def call_for_per_sample_grads(
module,
*,
batch_size=None,
loss_reduction="sum",
batch_first=True,
):
r"""
Return a forward function for a module, populating grad_sample with per sample gradients on backward invocation.
Args:
module: The ``nn.Module`` to get per sample gradients with respect to. All trainable
parameters will compute per sample gradients, located in a ``grad_sample``
field when ``backward`` is invoked
batch_size: The batch size of the input. If None is passed, all tensor arguments in args and kwargs must have
the same batch size, which is the size of the first dimension. Otherwise, it must be passed manually.
Default: None
loss_reduction: Indicates if the loss reduction (for aggregating the gradients) is a sum or a mean operation. If
"mean", per sample gradients will be scaled by the batch size to offset the crossbatch interaction from
running mean across a batch. Must be "mean" or "sum". Default: "sum"
batch_first: Indicates if the batch dimension is the first dimension. If True, the batch dimension is the first
dimension. If False, it's the second dimension. Default: True.
Examples::
>>> # xdoctest: +SKIP
>>> model = nn.Linear(4, 3)
>>> batched_input = torch.randn(5, 4) # batch size of 5
>>> res = call_for_per_sample_grads(model)(batched_input).sum()
>>> res.backward()
>>> assert model.weight.shape == (3, 4)
>>> assert model.weight.grad_sample.shape == (5, 3, 4)
>>> assert model.weight.grad is None
>>> assert model.bias.shape == (3,)
>>> assert model.bias.grad_sample.shape == (5, 3)
>>> assert model.bias.grad is None
An example using "mean" loss reduction. The grad_sample fields will be scaled by batch_size from what they would be
if we ran the same code with loss_reduction="sum". This is because the mean at the end will scale all
grad_outputs by 1 / batch_size due to the cross-batch interaction.
>>> model = nn.Linear(4, 3)
>>> batched_input = torch.randn(5, 4) # batch size of 5
>>> res = call_for_per_sample_grads(model, batch_size=5, loss_reduction="mean")(batched_input).mean()
>>> res.backward()
Note::
Does not work with any `nn.RNN`, including `nn.GRU` or `nn.LSTM`. Please use custom
rewrites that wrap an `nn.Linear` module. See Opacus for an example
"""
def maybe_build_expanded_weight(og_tensor, batch_size):
if og_tensor.requires_grad:
return ExpandedWeight(og_tensor, batch_size, loss_reduction)
else:
return og_tensor
def compute_batch_size(*args, **kwargs):
args_and_kwargs = pytree.arg_tree_leaves(*args, **kwargs)
batch_size = None
for arg in args_and_kwargs:
if not isinstance(arg, torch.Tensor):
continue
arg_batch_size = arg.shape[0] if batch_first else arg.shape[1]
if batch_size is not None and batch_size != arg_batch_size:
raise RuntimeError(
"When computing batch size, found at least one input with batch size "
f"{batch_size} and one with batch size {arg_batch_size}. Please specify it "
"explicitly using the batch size kwarg in call_for_per_sample_grads"
)
batch_size = arg_batch_size
if batch_size is None:
raise RuntimeError(
"Unable to find a tensor in the passed args and kwargs. They may not be pytree-able "
"and so ExpandedWeights cannot compute the batch size from the inputs. Please specify "
"it explicitly"
)
return batch_size
if loss_reduction not in ["sum", "mean"]:
raise RuntimeError(
f"Expected loss_reduction argument to be sum or mean, got {loss_reduction}"
)
if not isinstance(module, torch.nn.Module):
raise RuntimeError(
f"Module passed must be nn.Module, got {type(module).__name__}"
)
if not (batch_size is None or isinstance(batch_size, int)):
raise RuntimeError(
f"Batch size passed must be None or an integer, got {type(batch_size).__name__}"
)
if batch_size is not None and batch_size < 1:
raise RuntimeError(f"Batch size must be positive, got {batch_size}")
for weight in module.parameters():
if hasattr(weight, "grad_sample") and weight.grad_sample is not None: # type: ignore[attr-defined]
raise RuntimeError(
"Current Expanded Weights accumulates the gradients, which will be incorrect for multiple "
f"calls without clearing gradients. Please clear out the grad_sample parameter of {weight} or "
"post an issue to pytorch/pytorch to prioritize correct behavior"
)
@functools.wraps(module.forward)
def wrapper(*args, **kwargs):
wrapper_batch_size = batch_size
if wrapper_batch_size is None:
wrapper_batch_size = compute_batch_size(*args, **kwargs)
params = {
name: maybe_build_expanded_weight(value, wrapper_batch_size)
for (name, value) in module.named_parameters()
}
return torch.func.functional_call(module, params, args, kwargs)
return wrapper


@@ -0,0 +1,189 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import cast, Dict, Iterable, List, Optional, Tuple, Union
from typing_extensions import deprecated
import torch
from torch import Tensor
from torch.utils._foreach_utils import (
_device_has_foreach_support,
_group_tensors_by_device_and_dtype,
_has_foreach_support,
)
__all__ = ["clip_grad_norm_", "clip_grad_norm", "clip_grad_value_"]
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
def _no_grad(func):
"""
This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
clip_grad_norm_ and clip_grad_value_ themselves.
"""
def _no_grad_wrapper(*args, **kwargs):
with torch.no_grad():
return func(*args, **kwargs)
functools.update_wrapper(_no_grad_wrapper, func)
return _no_grad_wrapper
@_no_grad
def clip_grad_norm_(
parameters: _tensor_or_tensors,
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: Optional[bool] = None,
) -> torch.Tensor:
r"""Clip the gradient norm of an iterable of parameters.
The norm is computed over the norms of the individual gradients of all parameters,
as if the norms of the individual gradients were concatenated into a single vector.
Gradients are modified in-place.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float): max norm of the gradients
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
error_if_nonfinite (bool): if True, an error is thrown if the total
norm of the gradients from :attr:`parameters` is ``nan``,
``inf``, or ``-inf``. Default: False (will switch to True in the future)
foreach (bool): use the faster foreach-based implementation.
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
fall back to the slow implementation for other device types.
Default: ``None``
Returns:
Total norm of the parameter gradients (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
grads = [p.grad for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)
if len(grads) == 0:
return torch.tensor(0.0)
first_device = grads[0].device
grouped_grads: Dict[
Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
] = _group_tensors_by_device_and_dtype(
[grads]
) # type: ignore[assignment]
norms: List[Tensor] = []
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None and _has_foreach_support(device_grads, device)) or (
foreach and _device_has_foreach_support(device)
):
norms.extend(torch._foreach_norm(device_grads, norm_type))
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
total_norm = torch.linalg.vector_norm(
torch.stack([norm.to(first_device) for norm in norms]), norm_type
)
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f"The total norm of order {norm_type} for gradients from "
"`parameters` is non-finite, so it cannot be clipped. To disable "
"this error and scale the gradients by the non-finite norm anyway, "
"set `error_if_nonfinite=False`"
)
clip_coef = max_norm / (total_norm + 1e-6)
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
# when the gradients do not reside in CPU memory.
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None and _has_foreach_support(device_grads, device)) or (
foreach and _device_has_foreach_support(device)
):
torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
clip_coef_clamped_device = clip_coef_clamped.to(device)
for g in device_grads:
g.mul_(clip_coef_clamped_device)
return total_norm
@deprecated(
"`torch.nn.utils.clip_grad_norm` is now deprecated "
"in favor of `torch.nn.utils.clip_grad_norm_`.",
category=FutureWarning,
)
def clip_grad_norm(
parameters: _tensor_or_tensors,
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: Optional[bool] = None,
) -> torch.Tensor:
r"""Clip the gradient norm of an iterable of parameters.
.. warning::
This method is now deprecated in favor of
:func:`torch.nn.utils.clip_grad_norm_`.
"""
return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach)
@_no_grad
def clip_grad_value_(
parameters: _tensor_or_tensors,
clip_value: float,
foreach: Optional[bool] = None,
) -> None:
r"""Clip the gradients of an iterable of parameters at specified value.
Gradients are modified in-place.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor whose gradients will be clipped
clip_value (float): maximum allowed value of the gradients.
The gradients are clipped in the range
:math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
foreach (bool): use the faster foreach-based implementation
If ``None``, use the foreach implementation for CUDA and CPU native tensors and
silently fall back to the slow implementation for other device types.
Default: ``None``
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
clip_value = float(clip_value)
grads = [p.grad for p in parameters if p.grad is not None]
grouped_grads = _group_tensors_by_device_and_dtype([grads])
for (device, _), ([grads], _) in grouped_grads.items(): # type: ignore[assignment]
if (
foreach is None
and _has_foreach_support(cast(List[Tensor], grads), device=device)
) or (foreach and _device_has_foreach_support(device)):
torch._foreach_clamp_min_(cast(List[Tensor], grads), -clip_value)
torch._foreach_clamp_max_(cast(List[Tensor], grads), clip_value)
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
for grad in grads:
cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value)
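
# Usage sketch (not part of the original file): clipping gradients by total norm
# inside a training step. The model, optimizer, and sizes are illustrative only.
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(4, 10)).sum()
loss.backward()

# Rescales all gradients in place so that their combined 2-norm is at most 1.0
# and returns the pre-clipping total norm.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
opt.step()
opt.zero_grad()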

View File

@ -0,0 +1,90 @@
from typing import Iterable, Optional
import torch
def parameters_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor:
r"""Flatten an iterable of parameters into a single vector.
Args:
parameters (Iterable[Tensor]): an iterable of Tensors that are the
parameters of a model.
Returns:
The parameters represented by a single vector
"""
# Flag for the device where the parameter is located
param_device = None
vec = []
for param in parameters:
# Ensure the parameters are located in the same device
param_device = _check_param_device(param, param_device)
vec.append(param.view(-1))
return torch.cat(vec)
def vector_to_parameters(vec: torch.Tensor, parameters: Iterable[torch.Tensor]) -> None:
r"""Copy slices of a vector into an iterable of parameters.
Args:
vec (Tensor): a single vector representing the parameters of a model.
parameters (Iterable[Tensor]): an iterable of Tensors that are the
parameters of a model.
"""
# Ensure vec of type Tensor
if not isinstance(vec, torch.Tensor):
raise TypeError(f"expected torch.Tensor, but got: {torch.typename(vec)}")
# Flag for the device where the parameter is located
param_device = None
# Pointer for slicing the vector for each parameter
pointer = 0
for param in parameters:
# Ensure the parameters are located in the same device
param_device = _check_param_device(param, param_device)
# The length of the parameter
num_param = param.numel()
# Slice the vector, reshape it, and replace the old data of the parameter
param.data = vec[pointer : pointer + num_param].view_as(param).data
# Increment the pointer
pointer += num_param
def _check_param_device(param: torch.Tensor, old_param_device: Optional[int]) -> int:
r"""Check if the parameters are located on the same device.
Currently, the conversion between model parameters and single vector form is not supported
for multiple allocations, e.g. parameters in different GPUs/PrivateUse1s, or mixture of CPU/GPU/PrivateUse1.
Args:
param ([Tensor]): a Tensor of a parameter of a model
old_param_device (int): the device where the first parameter of a
model is allocated.
Returns:
old_param_device (int): report device for the first time
"""
# Meet the first parameter
support_device_types = ["cuda", torch._C._get_privateuse1_backend_name()]
if old_param_device is None:
old_param_device = (
param.get_device() if param.device.type in support_device_types else -1
)
else:
warn = False
if (
param.device.type in support_device_types
): # Check if in same GPU/PrivateUse1
warn = param.get_device() != old_param_device
else: # Check if in CPU
warn = old_param_device != -1
if warn:
raise TypeError(
"Found two parameters on different devices, "
"this is currently not supported."
)
return old_param_device
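
# Usage sketch (not part of the original file): a round trip between a model's
# parameters and a single flat vector. The module and sizes are illustrative only.
import torch
import torch.nn as nn

model = nn.Linear(3, 2)
vec = torch.nn.utils.parameters_to_vector(model.parameters())
print(vec.shape)  # torch.Size([8]): 3*2 weight entries + 2 bias entries, in parameter order

# Write a modified flat vector back into the individual parameters.
torch.nn.utils.vector_to_parameters(vec * 0.5, model.parameters())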

View File

@ -0,0 +1,190 @@
from __future__ import annotations
import copy
from typing import Optional, Tuple, TypeVar
import torch
__all__ = [
"fuse_conv_bn_eval",
"fuse_conv_bn_weights",
"fuse_linear_bn_eval",
"fuse_linear_bn_weights",
]
ConvT = TypeVar("ConvT", bound="torch.nn.modules.conv._ConvNd")
LinearT = TypeVar("LinearT", bound="torch.nn.Linear")
def fuse_conv_bn_eval(
conv: ConvT,
bn: torch.nn.modules.batchnorm._BatchNorm,
transpose: bool = False,
) -> ConvT:
r"""Fuse a convolutional module and a BatchNorm module into a single, new convolutional module.
Args:
conv (torch.nn.modules.conv._ConvNd): A convolutional module.
bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.
transpose (bool, optional): If True, transpose the convolutional weight. Defaults to False.
Returns:
torch.nn.modules.conv._ConvNd: The fused convolutional module.
.. note::
Both ``conv`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
"""
assert not (conv.training or bn.training), "Fusion only for eval!"
fused_conv = copy.deepcopy(conv)
assert bn.running_mean is not None and bn.running_var is not None
fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
fused_conv.weight,
fused_conv.bias,
bn.running_mean,
bn.running_var,
bn.eps,
bn.weight,
bn.bias,
transpose,
)
return fused_conv
def fuse_conv_bn_weights(
conv_w: torch.Tensor,
conv_b: Optional[torch.Tensor],
bn_rm: torch.Tensor,
bn_rv: torch.Tensor,
bn_eps: float,
bn_w: Optional[torch.Tensor],
bn_b: Optional[torch.Tensor],
transpose: bool = False,
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
r"""Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters.
Args:
conv_w (torch.Tensor): Convolutional weight.
conv_b (Optional[torch.Tensor]): Convolutional bias.
bn_rm (torch.Tensor): BatchNorm running mean.
bn_rv (torch.Tensor): BatchNorm running variance.
bn_eps (float): BatchNorm epsilon.
bn_w (Optional[torch.Tensor]): BatchNorm weight.
bn_b (Optional[torch.Tensor]): BatchNorm bias.
transpose (bool, optional): If True, transpose the conv weight. Defaults to False.
Returns:
Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused convolutional weight and bias.
"""
conv_weight_dtype = conv_w.dtype
conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype
if conv_b is None:
conv_b = torch.zeros_like(bn_rm)
if bn_w is None:
bn_w = torch.ones_like(bn_rm)
if bn_b is None:
bn_b = torch.zeros_like(bn_rm)
bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
if transpose:
shape = [1, -1] + [1] * (len(conv_w.shape) - 2)
else:
shape = [-1, 1] + [1] * (len(conv_w.shape) - 2)
fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to(
dtype=conv_weight_dtype
)
fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to(
dtype=conv_bias_dtype
)
return (
torch.nn.Parameter(fused_conv_w, conv_w.requires_grad),
torch.nn.Parameter(fused_conv_b, conv_b.requires_grad),
)
def fuse_linear_bn_eval(
linear: LinearT,
bn: torch.nn.modules.batchnorm._BatchNorm,
) -> LinearT:
r"""Fuse a linear module and a BatchNorm module into a single, new linear module.
Args:
linear (torch.nn.Linear): A Linear module.
bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.
Returns:
torch.nn.Linear: The fused linear module.
.. note::
Both ``linear`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
"""
assert not (linear.training or bn.training), "Fusion only for eval!"
fused_linear = copy.deepcopy(linear)
"""
Linear-BN needs to be fused while preserving the shapes of linear weight/bias.
To preserve the shapes of linear weight/bias, the channel dim of bn needs to be broadcastable with the last dim of linear,
because bn operates over the channel dim, (N, C_in, H, W) while linear operates over the last dim, (*, H_in).
To be broadcastable, the number of features in bn and
the number of output features from linear must satisfy the following condition:
1. they are equal, or
2. the number of features in bn is 1
Otherwise, skip the folding path
"""
assert (
linear.out_features == bn.num_features or bn.num_features == 1
), "To fuse, linear.out_features == bn.num_features or bn.num_features == 1"
assert bn.running_mean is not None and bn.running_var is not None
fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights(
fused_linear.weight,
fused_linear.bias,
bn.running_mean,
bn.running_var,
bn.eps,
bn.weight,
bn.bias,
)
return fused_linear
def fuse_linear_bn_weights(
linear_w: torch.Tensor,
linear_b: Optional[torch.Tensor],
bn_rm: torch.Tensor,
bn_rv: torch.Tensor,
bn_eps: float,
bn_w: torch.Tensor,
bn_b: torch.Tensor,
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
r"""Fuse linear module parameters and BatchNorm module parameters into new linear module parameters.
Args:
linear_w (torch.Tensor): Linear weight.
linear_b (Optional[torch.Tensor]): Linear bias.
bn_rm (torch.Tensor): BatchNorm running mean.
bn_rv (torch.Tensor): BatchNorm running variance.
bn_eps (float): BatchNorm epsilon.
bn_w (torch.Tensor): BatchNorm weight.
bn_b (torch.Tensor): BatchNorm bias.
Returns:
Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused linear weight and bias.
"""
linear_weight_dtype = linear_w.dtype
linear_bias_dtype = linear_b.dtype if linear_b is not None else linear_weight_dtype
if linear_b is None:
linear_b = torch.zeros_like(bn_rm)
bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps)
fused_w = linear_w * bn_scale.unsqueeze(-1).to(dtype=linear_weight_dtype)
fused_b = ((linear_b - bn_rm) * bn_scale + bn_b).to(dtype=linear_bias_dtype)
return torch.nn.Parameter(fused_w, linear_w.requires_grad), torch.nn.Parameter(
fused_b, linear_b.requires_grad
)
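
# Usage sketch (not part of the original file): folding a BatchNorm into the
# preceding convolution for inference. The shapes and the fabricated running
# statistics below are illustrative only.
import torch
import torch.nn as nn
from torch.nn.utils import fuse_conv_bn_eval

conv = nn.Conv2d(3, 8, kernel_size=3).eval()
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-0.1, 0.1)   # pretend these statistics were learned during training
bn.running_var.uniform_(0.5, 1.5)

fused = fuse_conv_bn_eval(conv, bn)
x = torch.randn(1, 3, 16, 16)
print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # True, up to float32 round-off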

View File

@ -0,0 +1,55 @@
# mypy: allow-untyped-defs
import inspect
import torch
def skip_init(module_cls, *args, **kwargs):
r"""
Given a module class object and args / kwargs, instantiate the module without initializing parameters / buffers.
This can be useful if initialization is slow or if custom initialization will
be performed, making the default initialization unnecessary. There are some caveats to this, due to
the way this function is implemented:
1. The module must accept a `device` arg in its constructor that is passed to any parameters
or buffers created during construction.
2. The module must not perform any computation on parameters in its constructor except
initialization (i.e. functions from :mod:`torch.nn.init`).
If these conditions are satisfied, the module can be instantiated with parameter / buffer values
uninitialized, as if having been created using :func:`torch.empty`.
Args:
module_cls: Class object; should be a subclass of :class:`torch.nn.Module`
args: args to pass to the module's constructor
kwargs: kwargs to pass to the module's constructor
Returns:
Instantiated module with uninitialized parameters / buffers
Example::
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> import torch
>>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
>>> m.weight
Parameter containing:
tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
requires_grad=True)
>>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
>>> m2.weight
Parameter containing:
tensor([[-1.4677e+24, 4.5915e-41, 1.4013e-45, 0.0000e+00, -1.4677e+24,
4.5915e-41]], requires_grad=True)
"""
if not issubclass(module_cls, torch.nn.Module):
raise RuntimeError(f"Expected a Module; got {module_cls}")
if "device" not in inspect.signature(module_cls).parameters:
raise RuntimeError("Module must support a 'device' arg to skip initialization")
final_device = kwargs.pop("device", "cpu")
kwargs["device"] = "meta"
return module_cls(*args, **kwargs).to_empty(device=final_device)
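
# Usage sketch (not part of the original file): skip the default initialization
# and apply a custom one afterwards. The layer size and init scheme are illustrative.
import torch
import torch.nn as nn

m = torch.nn.utils.skip_init(nn.Linear, 256, 256)  # built on meta, then materialized uninitialized
nn.init.orthogonal_(m.weight)
nn.init.zeros_(m.bias)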

View File

@ -0,0 +1,152 @@
# mypy: allow-untyped-defs
import torch
def convert_conv2d_weight_memory_format(module, memory_format):
r"""Convert ``memory_format`` of ``nn.Conv2d.weight`` to ``memory_format``.
The conversion recursively applies to nested ``nn.Module``, including ``module``.
Note that it only changes the memory_format, but not the semantics of each dimension.
This function is used to facilitate the computation to adopt NHWC kernels, which
provides considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0
.. note::
Calling ``model.to(memory_format=torch.channels_last)`` is more aggressive
than the utility function ``convert_conv2d_weight_memory_format``. Any
layer with a 4d weight will be affected by ``model.to``, and such layers do not
necessarily benefit from conversion to the specified ``memory_format``.
One case we are confident about is the NHWC (channels_last) conversion for
convolution in cuDNN, as it is beneficial to run convolution in NHWC,
even in cases where we have to apply a permutation to input tensors.
Hence our strategy here is to convert only the weight of convolution to
channels_last. This ensures that:
1. Fast convolution kernels will be used, the benefit of which could
outweigh overhead of permutation (if input is not in the same format).
2. No unnecessary permutations are applied on layers that do not benefit
from memory_format conversion.
The optimal case is that the layers between convolution layers are channels-last
compatible. The input tensor would be permuted to channels last when it
encounters the first convolution layer and stay in that memory format, so the
following convolutions will not need to permute their input tensors.
In the case where a channels-last-incompatible layer sits between convolution
layers, we need to permute the input tensor back to contiguous format
for that layer. The input tensor will go through the remaining layers in
contiguous format and be permuted to channels last when it encounters
another convolution layer. There's no point in propagating that
permutation to an earlier layer, as most layers are quite agnostic to
``memory_format``.
This claim might change when PyTorch supports fusion of permutation, as
there might have been a better spot to fuse the permutation other than
immediately before a convolution.
Args:
module (nn.Module): ``nn.Conv2d`` & ``nn.ConvTranspose2d`` or container
``nn.Module``
memory_format: user specified ``memory_format``,
e.g. ``torch.channels_last`` or ``torch.contiguous_format``
Returns:
The original module with updated ``nn.Conv2d``
Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
>>> input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda")
>>> model = nn.Sequential(
>>> nn.Conv2d(8, 4, 3)).cuda().half()
>>> # This is identical to:
>>> # nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)
>>> model = nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)
>>> out = model(input)
"""
# TODO: expand this to `_ConvNd` when channels_last support is extended
# beyond only 4d tensors.
if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
weight_data = (
module.weight.detach().clone().contiguous(memory_format=memory_format)
)
module.weight.data = weight_data.resize_(
weight_data.size(), memory_format=memory_format
)
for child in module.children():
convert_conv2d_weight_memory_format(child, memory_format)
return module
def convert_conv3d_weight_memory_format(module, memory_format):
r"""Convert ``memory_format`` of ``nn.Conv3d.weight`` to ``memory_format``
The conversion recursively applies to nested ``nn.Module``, including ``module``.
Note that it only changes the memory_format, but not the semantics of each dimensions.
This function is used to facilitate the computation to adopt NHWC kernels, which
provides considerable speed up for fp16 data on CUDA devices with compute capability >= 7.0
.. note::
Calling ``model.to(memory_format=torch.channels_last_3d)`` is more aggressive
than the utility function ``convert_conv3d_weight_memory_format``. Any
layer with a 5d weight will be affected by ``model.to``, and such layers do not
necessarily benefit from conversion to the specified ``memory_format``.
One case we are confident about is the NDHWC (channels_last_3d) conversion for
convolution in cuDNN, as it is beneficial to run convolution in NDHWC,
even in cases where we have to apply a permutation to input tensors.
Hence our strategy here is to convert only the weight of convolution to
channels_last_3d. This ensures that:
1. Fast convolution kernels will be used, the benefit of which could
outweigh overhead of permutation (if input is not in the same format).
2. No unnecessary permutations are applied on layers that do not benefit
from memory_format conversion.
The optimal case is that the layers between convolution layers are channels-last
compatible. The input tensor would be permuted to channels last when it
encounters the first convolution layer and stay in that memory format, so the
following convolutions will not need to permute their input tensors.
In the case where a channels-last-incompatible layer sits between convolution
layers, we need to permute the input tensor back to contiguous format
for that layer. The input tensor will go through the remaining layers in
contiguous format and be permuted to channels last when it encounters
another convolution layer. There's no point in propagating that
permutation to an earlier layer, as most layers are quite agnostic to
``memory_format``.
This claim might change when PyTorch supports fusion of permutation, as
there might have been a better spot to fuse the permutation other than
immediately before a convolution.
Args:
module (nn.Module): ``nn.Conv3d`` & ``nn.ConvTranspose3d`` or container
``nn.Module``
memory_format: user specified ``memory_format``,
e.g. ``torch.channels_last`` or ``torch.contiguous_format``
Returns:
The original module with updated ``nn.Conv3d``
Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
>>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda")
>>> model = nn.Sequential(
>>> nn.Conv3d(8, 4, 3)).cuda().half()
>>> # This is identical to:
>>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
>>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
>>> out = model(input)
"""
# TODO: expand this to `_ConvNd` when channels_last support is extended
# beyond only 4d tensors.
if isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
weight_data = (
module.weight.detach().clone().contiguous(memory_format=memory_format)
)
module.weight.data = weight_data.resize_(
weight_data.size(), memory_format=memory_format
)
for child in module.children():
convert_conv3d_weight_memory_format(child, memory_format)
return module
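
# Usage sketch (not part of the original file): converting only the convolution
# weights of a small model to channels_last. The model is illustrative only.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(8, 16, 3), nn.ReLU(), nn.Conv2d(16, 16, 3))
model = torch.nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)

# Only the convolution weights were converted; the rest of the model is untouched.
print(model[0].weight.is_contiguous(memory_format=torch.channels_last))  # True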

View File

@ -0,0 +1,628 @@
# mypy: allow-untyped-defs
from enum import auto, Enum
from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import Module
from torch.nn.utils import parametrize
__all__ = ["orthogonal", "spectral_norm", "weight_norm"]
def _is_orthogonal(Q, eps=None):
n, k = Q.size(-2), Q.size(-1)
Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
# A reasonable eps, but not too large; only computed when the caller does not provide one
eps = 10.0 * n * torch.finfo(Q.dtype).eps if eps is None else eps
return torch.allclose(Q.mH @ Q, Id, atol=eps)
def _make_orthogonal(A):
"""Assume that A is a tall matrix.
Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative.
"""
X, tau = torch.geqrf(A)
Q = torch.linalg.householder_product(X, tau)
# The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs
Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
return Q
class _OrthMaps(Enum):
matrix_exp = auto()
cayley = auto()
householder = auto()
class _Orthogonal(Module):
base: Tensor
def __init__(
self, weight, orthogonal_map: _OrthMaps, *, use_trivialization=True
) -> None:
super().__init__()
# Note [Householder complex]
# For complex tensors, it is not possible to compute the tensor `tau` necessary for
# linalg.householder_product from the reflectors.
# To see this, note that the reflectors have a shape like:
# 0 0 0
# * 0 0
# * * 0
# which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters
# to parametrize the unitary matrices. Saving tau on its own does not work either, because
# not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise
# them as independent tensors we would not maintain the constraint
# An equivalent reasoning holds for rectangular matrices
if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
raise ValueError(
"The householder parametrization does not support complex tensors."
)
self.shape = weight.shape
self.orthogonal_map = orthogonal_map
if use_trivialization:
self.register_buffer("base", None)
def forward(self, X: torch.Tensor) -> torch.Tensor:
n, k = X.size(-2), X.size(-1)
transposed = n < k
if transposed:
X = X.mT
n, k = k, n
# Here n > k and X is a tall matrix
if (
self.orthogonal_map == _OrthMaps.matrix_exp
or self.orthogonal_map == _OrthMaps.cayley
):
# We just need n x k - k(k-1)/2 parameters
X = X.tril()
if n != k:
# Embed into a square matrix
X = torch.cat(
[X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1
)
A = X - X.mH
# A is skew-symmetric (or skew-hermitian)
if self.orthogonal_map == _OrthMaps.matrix_exp:
Q = torch.matrix_exp(A)
elif self.orthogonal_map == _OrthMaps.cayley:
# Computes the Cayley retraction (I+A/2)(I-A/2)^{-1}
Id = torch.eye(n, dtype=A.dtype, device=A.device)
Q = torch.linalg.solve(
torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5)
)
# Q is now orthogonal (or unitary) of size (..., n, n)
if n != k:
Q = Q[..., :k]
# Q is now the size of the X (albeit perhaps transposed)
else:
# X is real here, as we do not support householder with complex numbers
A = X.tril(diagonal=-1)
tau = 2.0 / (1.0 + (A * A).sum(dim=-2))
Q = torch.linalg.householder_product(A, tau)
# The diagonal of X is 1's and -1's
# We do not want to differentiate through this or update the diagonal of X hence the casting
Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)
if hasattr(self, "base"):
Q = self.base @ Q
if transposed:
Q = Q.mT
return Q # type: ignore[possibly-undefined]
@torch.autograd.no_grad()
def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:
if Q.shape != self.shape:
raise ValueError(
f"Expected a matrix or batch of matrices of shape {self.shape}. "
f"Got a tensor of shape {Q.shape}."
)
Q_init = Q
n, k = Q.size(-2), Q.size(-1)
transpose = n < k
if transpose:
Q = Q.mT
n, k = k, n
# We make sure to always copy Q in every path
if not hasattr(self, "base"):
# Note [right_inverse expm cayley]
# If we do not have use_trivialization=True, we just implement the inverse of the forward
# map for the Householder. To see why, think that for the Cayley map,
# we would need to find the matrix X \in R^{n x k} such that:
# Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
# A = Y - Y.mH
# cayley(A)[:, :k]
# gives the original tensor. It is not clear how to do this.
# Perhaps via some algebraic manipulation involving the QR like that of
# Corollary 2.2 in Edelman, Arias and Smith?
if (
self.orthogonal_map == _OrthMaps.cayley
or self.orthogonal_map == _OrthMaps.matrix_exp
):
raise NotImplementedError(
"It is not possible to assign to the matrix exponential "
"or the Cayley parametrizations when use_trivialization=False."
)
# If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition.
# Here Q is always real because we do not support householder and complex matrices.
# See note [Householder complex]
A, tau = torch.geqrf(Q)
# We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could
# decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition
# The diagonal of A now holds the diagonal of R from the QR decomposition
A.diagonal(dim1=-2, dim2=-1).sign_()
# Equality with zero is ok because LAPACK returns exactly zero when it does not want
# to use a particular reflection
A.diagonal(dim1=-2, dim2=-1)[tau == 0.0] *= -1
return A.mT if transpose else A
else:
if n == k:
# We check whether Q is orthogonal
if not _is_orthogonal(Q):
Q = _make_orthogonal(Q)
else: # Is orthogonal
Q = Q.clone()
else:
# Complete Q into a full n x n orthogonal matrix
N = torch.randn(
*(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device
)
Q = torch.cat([Q, N], dim=-1)
Q = _make_orthogonal(Q)
self.base = Q
# It is necessary to return the -Id, as we use the diagonal for the
# Householder parametrization. Using -Id makes:
# householder(torch.zeros(m,n)) == torch.eye(m,n)
# Poor man's version of eye_like
neg_Id = torch.zeros_like(Q_init)
neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.0)
return neg_Id
def orthogonal(
module: Module,
name: str = "weight",
orthogonal_map: Optional[str] = None,
*,
use_trivialization: bool = True,
) -> Module:
r"""Apply an orthogonal or unitary parametrization to a matrix or a batch of matrices.
Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized
matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as
.. math::
\begin{align*}
Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\
QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n}
\end{align*}
where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex
and the transpose when :math:`Q` is real-valued, and
:math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n`
and orthonormal rows otherwise.
If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`.
The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` in terms of the original tensor:
- ``"matrix_exp"``/``"cayley"``:
the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_
:math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric
:math:`A` to give an orthogonal matrix.
- ``"householder"``: computes a product of Householder reflectors
(:func:`~torch.linalg.householder_product`).
``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than
``"householder"``, but they are slower to compute for very thin or very wide matrices.
If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework",
where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under
``module.parametrizations.weight[0].base``. This helps the
convergence of the parametrized layer at the expense of some extra memory use.
See `Trivializations for Gradient-Based Optimization on Manifolds`_ .
Initial value of :math:`Q`:
If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value
of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case)
and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`).
The same happens when it is not parametrized and ``orthogonal_map="householder"``, even when ``use_trivialization=False``.
Otherwise, the initial value is the result of the composition of all the registered
parametrizations applied to the original tensor.
.. note::
This function is implemented using the parametrization functionality
in :func:`~torch.nn.utils.parametrize.register_parametrization`.
.. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
.. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501
Args:
module (nn.Module): module on which to register the parametrization.
name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``.
orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``.
Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise.
use_trivialization (bool, optional): whether to use the dynamic trivialization framework.
Default: ``True``.
Returns:
The original module with an orthogonal parametrization registered to the specified
weight
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
>>> orth_linear = orthogonal(nn.Linear(20, 40))
>>> orth_linear
ParametrizedLinear(
in_features=20, out_features=40, bias=True
(parametrizations): ModuleDict(
(weight): ParametrizationList(
(0): _Orthogonal()
)
)
)
>>> # xdoctest: +IGNORE_WANT
>>> Q = orth_linear.weight
>>> torch.dist(Q.T @ Q, torch.eye(20))
tensor(4.9332e-07)
"""
weight = getattr(module, name, None)
if not isinstance(weight, Tensor):
raise ValueError(
f"Module '{module}' has no parameter or buffer with name '{name}'"
)
# We could implement this for 1-dim tensors as the maps on the sphere
# but I believe it'd bite more people than it'd help
if weight.ndim < 2:
raise ValueError(
"Expected a matrix or batch of matrices. "
f"Got a tensor of {weight.ndim} dimensions."
)
if orthogonal_map is None:
orthogonal_map = (
"matrix_exp"
if weight.size(-2) == weight.size(-1) or weight.is_complex()
else "householder"
)
orth_enum = getattr(_OrthMaps, orthogonal_map, None)
if orth_enum is None:
raise ValueError(
'orthogonal_map has to be one of "matrix_exp", "cayley", "householder". '
f"Got: {orthogonal_map}"
)
orth = _Orthogonal(weight, orth_enum, use_trivialization=use_trivialization)
parametrize.register_parametrization(module, name, orth, unsafe=True)
return module
class _WeightNorm(Module):
def __init__(
self,
dim: Optional[int] = 0,
) -> None:
super().__init__()
if dim is None:
dim = -1
self.dim = dim
def forward(self, weight_g, weight_v):
return torch._weight_norm(weight_v, weight_g, self.dim)
def right_inverse(self, weight):
weight_g = torch.norm_except_dim(weight, 2, self.dim)
weight_v = weight
return weight_g, weight_v
def weight_norm(module: Module, name: str = "weight", dim: int = 0):
r"""Apply weight normalization to a parameter in the given module.
.. math::
\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
Weight normalization is a reparameterization that decouples the magnitude
of a weight tensor from its direction. This replaces the parameter specified
by :attr:`name` with two parameters: one specifying the magnitude
and one specifying the direction.
By default, with ``dim=0``, the norm is computed independently per output
channel/plane. To compute a norm over the entire weight tensor, use
``dim=None``.
See https://arxiv.org/abs/1602.07868
Args:
module (Module): containing module
name (str, optional): name of weight parameter
dim (int, optional): dimension over which to compute the norm
Returns:
The original module with the weight norm hook
Example::
>>> m = weight_norm(nn.Linear(20, 40), name='weight')
>>> m
ParametrizedLinear(
in_features=20, out_features=40, bias=True
(parametrizations): ModuleDict(
(weight): ParametrizationList(
(0): _WeightNorm()
)
)
)
>>> m.parametrizations.weight.original0.size()
torch.Size([40, 1])
>>> m.parametrizations.weight.original1.size()
torch.Size([40, 20])
"""
_weight_norm = _WeightNorm(dim)
parametrize.register_parametrization(module, name, _weight_norm, unsafe=True)
def _weight_norm_compat_hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
g_key = f"{prefix}{name}_g"
v_key = f"{prefix}{name}_v"
if g_key in state_dict and v_key in state_dict:
original0 = state_dict.pop(g_key)
original1 = state_dict.pop(v_key)
state_dict[f"{prefix}parametrizations.{name}.original0"] = original0
state_dict[f"{prefix}parametrizations.{name}.original1"] = original1
module._register_load_state_dict_pre_hook(_weight_norm_compat_hook)
return module
class _SpectralNorm(Module):
def __init__(
self,
weight: torch.Tensor,
n_power_iterations: int = 1,
dim: int = 0,
eps: float = 1e-12,
) -> None:
super().__init__()
ndim = weight.ndim
if dim >= ndim or dim < -ndim:
raise IndexError(
"Dimension out of range (expected to be in range of "
f"[-{ndim}, {ndim - 1}] but got {dim})"
)
if n_power_iterations <= 0:
raise ValueError(
"Expected n_power_iterations to be positive, but "
f"got n_power_iterations={n_power_iterations}"
)
self.dim = dim if dim >= 0 else dim + ndim
self.eps = eps
if ndim > 1:
# For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward)
self.n_power_iterations = n_power_iterations
weight_mat = self._reshape_weight_to_matrix(weight)
h, w = weight_mat.size()
u = weight_mat.new_empty(h).normal_(0, 1)
v = weight_mat.new_empty(w).normal_(0, 1)
self.register_buffer("_u", F.normalize(u, dim=0, eps=self.eps))
self.register_buffer("_v", F.normalize(v, dim=0, eps=self.eps))
# Start with u, v initialized to some reasonable values by performing a number
# of iterations of the power method
self._power_method(weight_mat, 15)
def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
# Precondition
assert weight.ndim > 1
if self.dim != 0:
# permute dim to front
weight = weight.permute(
self.dim, *(d for d in range(weight.dim()) if d != self.dim)
)
return weight.flatten(1)
@torch.autograd.no_grad()
def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None:
# See original note at torch/nn/utils/spectral_norm.py
# NB: If `do_power_iteration` is set, the `u` and `v` vectors are
# updated in power iteration **in-place**. This is very important
# because in `DataParallel` forward, the vectors (being buffers) are
# broadcast from the parallelized module to each module replica,
# which is a new module object created on the fly. And each replica
# runs its own spectral norm power iteration. So simply assigning
# the updated vectors to the module this function runs on will cause
# the update to be lost forever. And the next time the parallelized
# module is replicated, the same randomly initialized vectors are
# broadcast and used!
#
# Therefore, to make the change propagate back, we rely on two
# important behaviors (also enforced via tests):
# 1. `DataParallel` doesn't clone storage if the broadcast tensor
# is already on correct device; and it makes sure that the
# parallelized module is already on `device[0]`.
# 2. If the out tensor in `out=` kwarg has correct shape, it will
# just fill in the values.
# Therefore, since the same power iteration is performed on all
# devices, simply updating the tensors in-place will make sure that
# the module replica on `device[0]` will update the _u vector on the
# parallelized module (by shared storage).
#
# However, after we update `u` and `v` in-place, we need to **clone**
# them before using them to normalize the weight. This is to support
# backproping through two forward passes, e.g., the common pattern in
# GAN training: loss = D(real) - D(fake). Otherwise, engine will
# complain that variables needed to do backward for the first forward
# (i.e., the `u` and `v` vectors) are changed in the second forward.
# Precondition
assert weight_mat.ndim > 1
for _ in range(n_power_iterations):
# Spectral norm of weight equals to `u^T W v`, where `u` and `v`
# are the first left and right singular vectors.
# This power iteration produces approximations of `u` and `v`.
self._u = F.normalize(
torch.mv(weight_mat, self._v), # type: ignore[has-type]
dim=0,
eps=self.eps,
out=self._u, # type: ignore[has-type]
)
self._v = F.normalize(
torch.mv(weight_mat.H, self._u), # type: ignore[has-type]
dim=0,
eps=self.eps,
out=self._v, # type: ignore[has-type]
)
def forward(self, weight: torch.Tensor) -> torch.Tensor:
if weight.ndim == 1:
# Faster and more exact path, no need to approximate anything
return F.normalize(weight, dim=0, eps=self.eps)
else:
weight_mat = self._reshape_weight_to_matrix(weight)
if self.training:
self._power_method(weight_mat, self.n_power_iterations)
# See above on why we need to clone
u = self._u.clone(memory_format=torch.contiguous_format)
v = self._v.clone(memory_format=torch.contiguous_format)
# The proper way of computing this should be through F.bilinear, but
# it seems to have some efficiency issues:
# https://github.com/pytorch/pytorch/issues/58093
sigma = torch.vdot(u, torch.mv(weight_mat, v))
return weight / sigma
def right_inverse(self, value: torch.Tensor) -> torch.Tensor:
# we may want to assert here that the passed value already
# satisfies constraints
return value
def spectral_norm(
module: Module,
name: str = "weight",
n_power_iterations: int = 1,
eps: float = 1e-12,
dim: Optional[int] = None,
) -> Module:
r"""Apply spectral normalization to a parameter in the given module.
.. math::
\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
When applied on a vector, it simplifies to
.. math::
\mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}
Spectral normalization stabilizes the training of discriminators (critics)
in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant
of the model. :math:`\sigma` is approximated by performing one iteration of the
`power method`_ every time the weight is accessed. If the dimension of the
weight tensor is greater than 2, it is reshaped to 2D in the power iteration
to compute the spectral norm.
See `Spectral Normalization for Generative Adversarial Networks`_ .
.. _`power method`: https://en.wikipedia.org/wiki/Power_iteration
.. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
.. note::
This function is implemented using the parametrization functionality
in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a
reimplementation of :func:`torch.nn.utils.spectral_norm`.
.. note::
When this constraint is registered, the singular vectors associated with the largest
singular value are estimated rather than sampled at random. These are then updated
by performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor
is accessed with the module in `training` mode.
.. note::
If the `_SpectralNorm` module, i.e., `module.parametrizations.weight[idx]`,
is in training mode on removal, it will perform another power iteration.
If you'd like to avoid this iteration, set the module to eval mode
before its removal.
Args:
module (nn.Module): containing module
name (str, optional): name of weight parameter. Default: ``"weight"``.
n_power_iterations (int, optional): number of power iterations to
calculate spectral norm. Default: ``1``.
eps (float, optional): epsilon for numerical stability in
calculating norms. Default: ``1e-12``.
dim (int, optional): dimension corresponding to number of outputs.
Default: ``0``, except for modules that are instances of
ConvTranspose{1,2,3}d, when it is ``1``
Returns:
The original module with a new parametrization registered to the specified
weight
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> snm = spectral_norm(nn.Linear(20, 40))
>>> snm
ParametrizedLinear(
in_features=20, out_features=40, bias=True
(parametrizations): ModuleDict(
(weight): ParametrizationList(
(0): _SpectralNorm()
)
)
)
>>> torch.linalg.matrix_norm(snm.weight, 2)
tensor(1.0081, grad_fn=<AmaxBackward0>)
"""
weight = getattr(module, name, None)
if not isinstance(weight, Tensor):
raise ValueError(
f"Module '{module}' has no parameter or buffer with name '{name}'"
)
if dim is None:
if isinstance(
module,
(
torch.nn.ConvTranspose1d,
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d,
),
):
dim = 1
else:
dim = 0
parametrize.register_parametrization(
module, name, _SpectralNorm(weight, n_power_iterations, dim, eps)
)
return module
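
# Usage sketch (not part of the original file): registering the parametrizations
# defined above on small Linear layers. Sizes are illustrative; the orthogonal
# example needs a LAPACK-enabled build.
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal, spectral_norm

layer = orthogonal(nn.Linear(20, 40))  # weight is (40, 20), so its columns become orthonormal
Q = layer.weight                       # recomputed from the unconstrained tensor on every access
print(torch.allclose(Q.T @ Q, torch.eye(20), atol=1e-5))  # True

snm = spectral_norm(nn.Linear(20, 40))
print(torch.linalg.matrix_norm(snm.weight, 2))  # roughly 1, estimated via power iteration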

View File

@ -0,0 +1,819 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import collections
import copyreg
from contextlib import contextmanager
from copy import deepcopy
from typing import Dict, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor
from torch.__future__ import get_swap_module_params_on_conversion
from torch.nn.modules.container import Module, ModuleDict, ModuleList
from torch.nn.parameter import Parameter
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
__all__ = [
"cached",
"ParametrizationList",
"register_parametrization",
"is_parametrized",
"remove_parametrizations",
"type_before_parametrizations",
"transfer_parametrizations_and_params",
]
_cache_enabled = 0
_cache: Dict[Tuple[int, str], Optional[Tensor]] = {}
@contextmanager
def cached():
r"""Context manager that enables the caching system within parametrizations registered with :func:`register_parametrization`.
The value of the parametrized objects is computed and cached the first time
they are required when this context manager is active. The cached values are
discarded when leaving the context manager.
This is useful when using a parametrized parameter more than once in the forward pass.
An example of this is when parametrizing the recurrent kernel of an RNN or when
sharing weights.
The simplest way to activate the cache is by wrapping the forward pass of the neural network
.. code-block:: python
import torch.nn.utils.parametrize as P
...
with P.cached():
output = model(inputs)
in training and evaluation. One may also wrap the parts of the modules that use
the parametrized tensors several times. For example, the loop of an RNN with a
parametrized recurrent kernel:
.. code-block:: python
with P.cached():
for x in xs:
out_rnn = self.rnn_cell(x, out_rnn)
"""
global _cache
global _cache_enabled
_cache_enabled += 1
try:
yield
finally:
_cache_enabled -= 1
if not _cache_enabled:
_cache = {}
def _register_parameter_or_buffer(module, name, X):
if isinstance(X, Parameter):
module.register_parameter(name, X)
else:
module.register_buffer(name, X)
def _maybe_set(dest: Tensor, src: Tensor) -> None:
should_swap = (
get_swap_module_params_on_conversion() or is_traceable_wrapper_subclass(dest)
)
if should_swap:
if isinstance(dest, Parameter) and not isinstance(src, Parameter):
src = Parameter(src, requires_grad=dest.requires_grad)
torch.utils.swap_tensors(dest, src)
else:
dest.set_(src) # type: ignore[call-overload]
class ParametrizationList(ModuleList):
r"""A sequential container that holds and manages the original parameters or buffers of a parametrized :class:`torch.nn.Module`.
It is the type of ``module.parametrizations[tensor_name]`` when ``module[tensor_name]``
has been parametrized with :func:`register_parametrization`.
If the first registered parametrization has a ``right_inverse`` that returns one tensor or
does not have a ``right_inverse`` (in which case we assume that ``right_inverse`` is the identity),
it will hold the tensor under the name ``original``.
If it has a ``right_inverse`` that returns more than one tensor, these will be registered as
``original0``, ``original1``, ...
.. warning::
This class is used internally by :func:`register_parametrization`. It is documented
here for completeness. It shall not be instantiated by the user.
Args:
modules (sequence): sequence of modules representing the parametrizations
original (Parameter or Tensor): parameter or buffer that is parametrized
unsafe (bool): a boolean flag that denotes whether the parametrization
may change the dtype and shape of the tensor. Default: `False`
Warning: the parametrization is not checked for consistency upon registration.
Enable this flag at your own risk.
"""
original: Tensor
unsafe: bool
def __init__(
self,
modules: Sequence[Module],
original: Union[Tensor, Parameter],
unsafe: bool = False,
) -> None:
# We require this because we need to treat differently the first parametrization
# This should never throw, unless this class is used from the outside
if len(modules) == 0:
raise ValueError("ParametrizationList requires one or more modules.")
super().__init__(modules)
self.unsafe = unsafe
# In plain words:
# module.weight must keep its dtype and shape.
# Furthermore, if there is no right_inverse or the right_inverse returns a tensor,
# this should be of the same dtype as the original tensor
#
# We check that the following invariants hold:
# X = module.weight
# Y = param.right_inverse(X)
# assert isinstance(Y, Tensor) or
# (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y))
# Z = param(Y) if isinstance(Y, Tensor) else param(*Y)
# # Consistency checks
# assert X.dtype == Z.dtype and X.shape == Z.shape
# # If it has one input, this allows us to use set_ to move data to/from
# # the original tensor without changing its id (which is what the
# # optimizer uses to track parameters)
# if isinstance(Y, Tensor)
# assert X.dtype == Y.dtype
# Below we use original = X, new = Y
original_shape = original.shape
original_dtype = original.dtype
# Compute new
with torch.no_grad():
new = original
for module in reversed(self): # type: ignore[call-overload]
if hasattr(module, "right_inverse"):
try:
new = module.right_inverse(new)
except NotImplementedError:
pass
# else, or if it throws, we assume that right_inverse is the identity
if not isinstance(new, Tensor) and not isinstance(
new, collections.abc.Sequence
):
raise ValueError(
"'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
f"Got {type(new).__name__}"
)
# Set the number of original tensors
self.is_tensor = isinstance(new, Tensor)
self.ntensors = 1 if self.is_tensor else len(new)
# Register the tensor(s)
if self.is_tensor:
if original.dtype != new.dtype:
raise ValueError(
"When `right_inverse` outputs one tensor, it may not change the dtype.\n"
f"original.dtype: {original.dtype}\n"
f"right_inverse(original).dtype: {new.dtype}"
)
# Write the new value into `original` in place so that the user does not need to re-register
# the parameter manually in the optimizer
with torch.no_grad():
_maybe_set(original, new)
_register_parameter_or_buffer(self, "original", original)
else:
for i, originali in enumerate(new):
if not isinstance(originali, Tensor):
raise ValueError(
"'right_inverse' must return a Tensor or a Sequence of tensors "
"(list, tuple...). "
f"Got element {i} of the sequence with type {type(originali).__name__}."
)
# If the original tensor was a Parameter that required grad, we expect the user to
# add the new parameters to the optimizer after registering the parametrization
# (this is documented)
if isinstance(original, Parameter):
originali = Parameter(originali, original.requires_grad)
originali.requires_grad_(original.requires_grad)
_register_parameter_or_buffer(self, f"original{i}", originali)
if not self.unsafe:
# Consistency checks:
# Since f : A -> B, right_inverse : B -> A, Z and original should live in B
# Z = forward(right_inverse(original))
Z = self()
if not isinstance(Z, Tensor):
raise ValueError(
f"A parametrization must return a tensor. Got {type(Z).__name__}."
)
if Z.dtype != original_dtype:
raise ValueError(
"Registering a parametrization may not change the dtype of the tensor, unless `unsafe` flag is enabled.\n"
f"unparametrized dtype: {original_dtype}\n"
f"parametrized dtype: {Z.dtype}"
)
if Z.shape != original_shape:
raise ValueError(
"Registering a parametrization may not change the shape of the tensor, unless `unsafe` flag is enabled.\n"
f"unparametrized shape: {original_shape}\n"
f"parametrized shape: {Z.shape}"
)
def right_inverse(self, value: Tensor) -> None:
r"""Call the ``right_inverse`` methods of the parametrizations in the inverse registration order.
Then, it stores the result in ``self.original`` if ``right_inverse`` outputs one tensor
or in ``self.original0``, ``self.original1``, ... if it outputs several.
Args:
value (Tensor): Value to which initialize the module
"""
# All the exceptions in this function should almost never throw.
# They could throw if, for example, right_inverse function returns a different
# dtype when given a different input, which should most likely be caused by a
# bug in the user's code
with torch.no_grad():
# See https://github.com/pytorch/pytorch/issues/53103
for module in reversed(self): # type: ignore[call-overload]
if hasattr(module, "right_inverse"):
value = module.right_inverse(value)
else:
raise RuntimeError(
f"parametrization {type(module).__name__} does not implement "
"right_inverse."
)
if self.is_tensor:
# These exceptions should only throw when a right_inverse function does not
# return the same dtype for every input, which should most likely be caused by a bug
if not isinstance(value, Tensor):
raise ValueError(
f"`right_inverse` should return a tensor. Got {type(value).__name__}"
)
if value.dtype != self.original.dtype:
raise ValueError(
f"The tensor returned by `right_inverse` has dtype {value.dtype} "
f"while `original` has dtype {self.original.dtype}"
)
# We know that the result is going to have the same dtype
_maybe_set(self.original, value)
else:
if not isinstance(value, collections.abc.Sequence):
raise ValueError(
"'right_inverse' must return a sequence of tensors. "
f"Got {type(value).__name__}."
)
if len(value) != self.ntensors:
raise ValueError(
"'right_inverse' must return a sequence of tensors of length "
f"{self.ntensors}. Got a sequence of length {len(value)}."
)
for i, tensor in enumerate(value):
original_i = getattr(self, f"original{i}")
if not isinstance(tensor, Tensor):
raise ValueError(
f"`right_inverse` must return a sequence of tensors. "
f"Got element {i} of type {type(tensor).__name__}"
)
if original_i.dtype != tensor.dtype:
raise ValueError(
f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} "
f"while `original{i}` has dtype {original_i.dtype}"
)
_maybe_set(original_i, tensor)
def forward(self) -> Tensor:
if torch.jit.is_scripting():
raise RuntimeError("Parametrization is not working with scripting.")
# Unpack the originals for the first parametrization
if self.is_tensor:
x = self[0](self.original)
else:
originals = (getattr(self, f"original{i}") for i in range(self.ntensors))
x = self[0](*originals)
# It's not possible to call self[1:] here, so we have to be a bit more cryptic
# Also we want to skip all non-integer keys
curr_idx = 1
while hasattr(self, str(curr_idx)):
x = self[curr_idx](x)
curr_idx += 1
return x
def _inject_new_class(module: Module) -> None:
r"""Set up a module to be parametrized.
This works by substituting the class of the module by a class
that extends it to be able to inject a property
Args:
module (nn.Module): module into which to inject the property
"""
cls = module.__class__
def default_deepcopy(self, memo):
# Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class.
obj = memo.get(id(self), None)
if obj is not None:
return obj
replica = self.__new__(self.__class__)
memo[id(self)] = replica
replica.__dict__ = deepcopy(self.__dict__, memo)
# Also save all slots if they exist.
slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined]
for slot in slots_to_save:
if hasattr(self, slot):
setattr(replica, slot, deepcopy(getattr(self, slot), memo))
return replica
def getstate(self):
raise RuntimeError(
"Serialization of parametrized modules is only "
"supported through state_dict(). See:\n"
"https://pytorch.org/tutorials/beginner/saving_loading_models.html"
"#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
)
dct = {"__getstate__": getstate}
# We don't allow serialization of parametrized modules but should still allow deepcopying.
# Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists.
if not hasattr(cls, "__deepcopy__"):
dct["__deepcopy__"] = default_deepcopy # type: ignore[assignment]
param_cls = type(
f"Parametrized{cls.__name__}",
(cls,),
dct,
)
module.__class__ = param_cls
def _inject_property(module: Module, tensor_name: str) -> None:
r"""Injects a property into module[tensor_name].
It assumes that the class in the module has already been modified from its
original one using _inject_new_class and that the tensor under :attr:`tensor_name`
has already been moved out
Args:
module (nn.Module): module into which to inject the property
tensor_name (str): name of the property to create
"""
# We check the precondition.
# This should never fire if register_parametrization is correctly implemented
assert not hasattr(module, tensor_name)
@torch.jit.unused
def get_cached_parametrization(parametrization) -> Tensor:
global _cache
key = (id(module), tensor_name)
tensor = _cache.get(key)
if tensor is None:
tensor = parametrization()
_cache[key] = tensor
return tensor
def get_parametrized(self) -> Tensor:
if torch.jit.is_scripting():
raise RuntimeError("Parametrization is not working with scripting.")
parametrization = self.parametrizations[tensor_name]
if _cache_enabled:
if torch.jit.is_scripting():
# Scripting
raise RuntimeError(
"Caching is not implemented for scripting. "
"Either disable caching or avoid scripting."
)
elif torch._C._get_tracing_state() is not None:
# Tracing
raise RuntimeError(
"Cannot trace a model while caching parametrizations."
)
else:
return get_cached_parametrization(parametrization)
else:
# If caching is not active, this function just evaluates the parametrization
return parametrization()
def set_original(self, value: Tensor) -> None:
if torch.jit.is_scripting():
raise RuntimeError("Parametrization is not working with scripting.")
self.parametrizations[tensor_name].right_inverse(value)
setattr(module.__class__, tensor_name, property(get_parametrized, set_original))
def register_parametrization(
module: Module,
tensor_name: str,
parametrization: Module,
*,
unsafe: bool = False,
) -> Module:
r"""Register a parametrization to a tensor in a module.
Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``,
the module will return the parametrized version ``parametrization(module.weight)``.
If the original tensor requires a gradient, the backward pass will differentiate
through :attr:`parametrization`, and the optimizer will update the tensor accordingly.
The first time that a module registers a parametrization, this function will add an attribute
``parametrizations`` to the module of type :class:`~ParametrizationList`.
The list of parametrizations on the tensor ``weight`` will be accessible under
``module.parametrizations.weight``.
The original tensor will be accessible under
``module.parametrizations.weight.original``.
Parametrizations may be concatenated by registering several parametrizations
on the same attribute.
The training mode of a registered parametrization is updated on registration
to match the training mode of the host module.
Parametrized parameters and buffers have an inbuilt caching system that can be activated
using the context manager :func:`cached`.
A :attr:`parametrization` may optionally implement a method with signature
.. code-block:: python
def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]
This method is called on the unparametrized tensor when the first parametrization
is registered to compute the initial value of the original tensor.
If this method is not implemented, the original tensor will be just the unparametrized tensor.
If all the parametrizations registered on a tensor implement `right_inverse` it is possible
to initialize a parametrized tensor by assigning to it, as shown in the example below.
It is possible for the first parametrization to depend on several inputs.
This may be implemented returning a tuple of tensors from ``right_inverse``
(see the example implementation of a ``RankOne`` parametrization below).
In this case, the unconstrained tensors are also located under ``module.parametrizations.weight``
with names ``original0``, ``original1``,...
.. note::
        If ``unsafe=False`` (default) both the ``forward`` and ``right_inverse`` methods will be called
        once to perform a number of consistency checks.
        If ``unsafe=True``, then ``right_inverse`` will be called if the tensor is not parametrized,
        and nothing will be called otherwise.
.. note::
In most situations, ``right_inverse`` will be a function such that
``forward(right_inverse(X)) == X`` (see
`right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
Sometimes, when the parametrization is not surjective, it may be reasonable
to relax this.
.. warning::
If a parametrization depends on several inputs, :func:`~register_parametrization`
will register a number of new parameters. If such parametrization is registered
after the optimizer is created, these new parameters will need to be added manually
        to the optimizer. See :meth:`torch.optim.Optimizer.add_param_group`.
Args:
module (nn.Module): module on which to register the parametrization
tensor_name (str): name of the parameter or buffer on which to register
the parametrization
parametrization (nn.Module): the parametrization to register
Keyword args:
unsafe (bool): a boolean flag that denotes whether the parametrization
may change the dtype and shape of the tensor. Default: `False`
Warning: the parametrization is not checked for consistency upon registration.
Enable this flag at your own risk.
Raises:
ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name`
Examples:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.utils.parametrize as P
>>>
>>> class Symmetric(nn.Module):
>>> def forward(self, X):
>>> return X.triu() + X.triu(1).T # Return a symmetric matrix
>>>
>>> def right_inverse(self, A):
>>> return A.triu()
>>>
>>> m = nn.Linear(5, 5)
>>> P.register_parametrization(m, "weight", Symmetric())
>>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric
True
>>> A = torch.rand(5, 5)
>>> A = A + A.T # A is now symmetric
>>> m.weight = A # Initialize the weight to be the symmetric matrix A
>>> print(torch.allclose(m.weight, A))
True
>>> class RankOne(nn.Module):
>>> def forward(self, x, y):
>>> # Form a rank 1 matrix multiplying two vectors
>>> return x.unsqueeze(-1) @ y.unsqueeze(-2)
>>>
>>> def right_inverse(self, Z):
>>> # Project Z onto the rank 1 matrices
>>> U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
>>> # Return rescaled singular vectors
>>> s0_sqrt = S[0].sqrt().unsqueeze(-1)
>>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
>>>
>>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
>>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
1
"""
parametrization.train(module.training)
if is_parametrized(module, tensor_name):
# Correctness checks.
# If A is the space of tensors with shape and dtype equal to module.weight
# we check that parametrization.forward and parametrization.right_inverse are
# functions from A to A
if not unsafe:
Y = getattr(module, tensor_name)
X = parametrization(Y)
if not isinstance(X, Tensor):
raise ValueError(
f"A parametrization must return a tensor. Got {type(X).__name__}."
)
if X.dtype != Y.dtype:
raise ValueError(
"Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n"
f"module.{tensor_name}.dtype: {Y.dtype}\n"
f"parametrization(module.{tensor_name}).dtype: {X.dtype}"
)
if X.shape != Y.shape:
raise ValueError(
"Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n"
f"module.{tensor_name}.shape: {Y.shape}\n"
f"parametrization(module.{tensor_name}).shape: {X.shape}"
)
if hasattr(parametrization, "right_inverse"):
try:
Z = parametrization.right_inverse(X) # type: ignore[operator]
except NotImplementedError:
pass
else:
if not isinstance(Z, Tensor):
raise ValueError(
f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}"
)
if Z.dtype != Y.dtype:
raise ValueError(
"The tensor returned by parametrization.right_inverse must have the same dtype "
f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
f"module.{tensor_name}.dtype: {Y.dtype}\n"
f"returned dtype: {Z.dtype}"
)
if Z.shape != Y.shape:
raise ValueError(
"The tensor returned by parametrization.right_inverse must have the same shape "
f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
f"module.{tensor_name}.shape: {Y.shape}\n"
f"returned shape: {Z.shape}"
)
# else right_inverse is assumed to be the identity
# add the new parametrization to the parametrization list
assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
module.parametrizations[tensor_name].append(parametrization)
# If unsafe was True in previous parametrization, keep it enabled
module.parametrizations[tensor_name].unsafe |= unsafe # type: ignore[index, union-attr]
elif tensor_name in module._buffers or tensor_name in module._parameters:
# Set the parametrization mechanism
# Fetch the original buffer or parameter
original = getattr(module, tensor_name)
# We create this early to check for possible errors
parametrizations = ParametrizationList(
[parametrization], original, unsafe=unsafe
)
# Delete the previous parameter or buffer
delattr(module, tensor_name)
# If this is the first parametrization registered on the module,
# we prepare the module to inject the property
if not is_parametrized(module):
# Change the class
_inject_new_class(module)
# Inject a ``ModuleDict`` into the instance under module.parametrizations
module.parametrizations = ModuleDict()
# Add a property into the class
_inject_property(module, tensor_name)
# Add a ParametrizationList
assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
module.parametrizations[tensor_name] = parametrizations
else:
raise ValueError(
f"Module '{module}' does not have a parameter, a buffer, or a "
f"parametrized element with name '{tensor_name}'"
)
return module
def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool:
r"""Determine if a module has a parametrization.
Args:
module (nn.Module): module to query
tensor_name (str, optional): name of the parameter in the module
Default: ``None``
Returns:
``True`` if :attr:`module` has a parametrization for the parameter named :attr:`tensor_name`,
or if it has any parametrization when :attr:`tensor_name` is ``None``;
otherwise ``False``
"""
parametrizations = getattr(module, "parametrizations", None)
if parametrizations is None or not isinstance(parametrizations, ModuleDict):
return False
if tensor_name is None:
# Check that there is at least one parametrized buffer or Parameter
return len(parametrizations) > 0
else:
return tensor_name in parametrizations
def remove_parametrizations(
module: Module,
tensor_name: str,
leave_parametrized: bool = True,
) -> Module:
r"""Remove the parametrizations on a tensor in a module.
- If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to
its current output. In this case, the parametrization shall not change the ``dtype``
of the tensor.
- If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
      the unparametrized tensor in ``module.parametrizations[tensor_name].original``.
This is only possible when the parametrization depends on just one tensor.
Args:
        module (nn.Module): module from which to remove the parametrization
tensor_name (str): name of the parametrization to be removed
leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized.
Default: ``True``
Returns:
Module: module
Raises:
ValueError: if ``module[tensor_name]`` is not parametrized
ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors
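    Example (a minimal sketch, reusing ``Symmetric`` and the imports from the example in
    :func:`register_parametrization`):
        >>> m = nn.Linear(5, 5)
        >>> _ = P.register_parametrization(m, "weight", Symmetric())
        >>> _ = P.remove_parametrizations(m, "weight")  # m.weight keeps its current (symmetric) value
        >>> P.is_parametrized(m)
        False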
"""
if not is_parametrized(module, tensor_name):
raise ValueError(
f"Module {module} does not have a parametrization on {tensor_name}"
)
# Fetch the original tensor
assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
parametrizations = module.parametrizations[tensor_name]
if parametrizations.is_tensor:
original = parametrizations.original
if leave_parametrized:
with torch.no_grad():
t = getattr(module, tensor_name)
# We know they have the same dtype because we have checked this when registering the
            # parametrizations. As such, we can use set_().
            # We do this so that the parameter does not change its id().
            # This way the user does not need to update the optimizer.
with torch.no_grad():
if type(original) is torch.Tensor:
_maybe_set(original, t)
else:
try:
_maybe_set(original, t)
except RuntimeError as e:
# TODO: Fix this for tensor subclasses that are parameters:
# RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().
raise RuntimeError(
"Calling remove_parametrizations() with leave_parametrized=True "
"for a parameter that is an instance of a tensor subclass requires "
"set_() to be implemented correctly for the tensor subclass."
"Alternatively, one can opt into the swap_tensors path"
"Either set leave_parametrized=False or provide a working implementation"
"for set_() in the tensor subclass or set "
"torch.__future__.set_swap_module_params_on_conversion(True)."
) from e
else:
if leave_parametrized:
# We cannot use no_grad because we need to know whether one or more
# original tensors required grad
t = getattr(module, tensor_name)
# We'll have to trust the user to add it to the optimizer
original = Parameter(t) if t.requires_grad else t
else:
raise ValueError(
"Cannot leave unparametrized (`leave_parametrized=False`) a tensor "
"that is parametrized in terms of a sequence of tensors."
)
# Delete the property that manages the parametrization
delattr(module.__class__, tensor_name)
# Delete the ParametrizationList
del module.parametrizations[tensor_name]
# Restore the parameter / buffer into the main class
_register_parameter_or_buffer(module, tensor_name, original)
# Roll back the parametrized class if no other buffer or parameter
# is currently parametrized in this class
if not is_parametrized(module):
delattr(module, "parametrizations")
# Restore class
orig_cls = module.__class__.__bases__[0]
module.__class__ = orig_cls
return module
def type_before_parametrizations(module: Module) -> type:
r"""Return the module type before parametrizations were applied and if not, then it returns the module type.
Args:
module (nn.Module): module to get type of
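    Example (a minimal sketch, reusing ``Symmetric`` and the imports from the example in
    :func:`register_parametrization`):
        >>> m = nn.Linear(5, 5)
        >>> _ = P.register_parametrization(m, "weight", Symmetric())
        >>> type(m).__name__
        'ParametrizedLinear'
        >>> P.type_before_parametrizations(m) is nn.Linear
        True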
"""
if is_parametrized(module):
return module.__class__.__bases__[0]
else:
return type(module)
def transfer_parametrizations_and_params(
from_module: Module,
to_module: Module,
tensor_name: Optional[str] = None,
) -> Module:
r"""Transfer parametrizations and the parameters they parametrize from :attr:`from_module` to :attr:`to_module`.
If :attr:`tensor_name` is specified, only transfers the specified parameter, otherwise
transfers all parametrized parameters. If those parameters do not exist in to_module, it will create them.
Does nothing if from_module is not parametrized.
Args:
from_module (nn.Module): module to transfer from
to_module (nn.Module): module to transfer to
tensor_name (str, optional): parameter to transfer
Returns:
Module: to_module
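    Example (a minimal sketch, reusing ``Symmetric`` and the imports from the example in
    :func:`register_parametrization`):
        >>> src, dst = nn.Linear(5, 5), nn.Linear(5, 5)
        >>> _ = P.register_parametrization(src, "weight", Symmetric())
        >>> _ = P.transfer_parametrizations_and_params(src, dst)
        >>> torch.allclose(src.weight, dst.weight)
        True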
"""
if is_parametrized(from_module):
assert isinstance(from_module.parametrizations, ModuleDict) # for mypy
# get list of all params or the single param to transfer
parameters_to_transfer: Union[list, ModuleDict] = (
from_module.parametrizations if tensor_name is None else [tensor_name]
)
assert hasattr(parameters_to_transfer, "__iter__") # for mypy
for parameter_name in parameters_to_transfer:
# initialize the to-be-transferred param in to_module if it doesn't exist already
if not hasattr(to_module, parameter_name):
setattr(
to_module,
parameter_name,
Parameter(getattr(from_module, parameter_name)),
)
            # apply the param's parametrizations to to_module
for param_func in from_module.parametrizations[parameter_name]:
register_parametrization(to_module, parameter_name, param_func)
assert isinstance(to_module.parametrizations, ModuleDict) # for mypy
# make values match, original values can be stored in either original or
# original0, original1..., need to check both cases
if hasattr(from_module.parametrizations[parameter_name], "original"):
to_module.parametrizations[
parameter_name
].original = from_module.parametrizations[parameter_name].original
else:
num = 0
orig_num = "original" + str(num)
# loop through each original# until all values have been set
while hasattr(from_module.parametrizations[parameter_name], orig_num):
setattr(
to_module.parametrizations[parameter_name],
orig_num,
getattr(from_module.parametrizations[parameter_name], orig_num),
)
num = num + 1
orig_num = "original" + str(num)
return to_module

File diff suppressed because it is too large

View File

@ -0,0 +1,599 @@
import warnings
from collections.abc import Iterable
from typing import (
Any,
Callable,
List,
NamedTuple,
Optional,
overload,
Tuple,
TypeVar,
Union,
)
from typing_extensions import Self
import torch
from torch import _VF, Tensor
__all__ = [
"PackedSequence",
"invert_permutation",
"pack_padded_sequence",
"pad_packed_sequence",
"pad_sequence",
"unpad_sequence",
"pack_sequence",
"unpack_sequence",
]
_T = TypeVar("_T")
_R = TypeVar("_R")
class PackedSequence_(NamedTuple):
data: torch.Tensor
batch_sizes: torch.Tensor
sorted_indices: Optional[torch.Tensor]
unsorted_indices: Optional[torch.Tensor]
def bind(optional: Optional[_T], fn: Callable[[_T], _R]) -> Optional[_R]:
if optional is None:
return None
return fn(optional)
class PackedSequence(PackedSequence_):
r"""Holds the data and list of :attr:`batch_sizes` of a packed sequence.
All RNN modules accept packed sequences as inputs.
Note:
Instances of this class should never be created manually. They are meant
to be instantiated by functions like :func:`pack_padded_sequence`.
        Batch sizes represent the number of elements at each sequence step in
the batch, not the varying sequence lengths passed to
:func:`pack_padded_sequence`. For instance, given data ``abc`` and ``x``
the :class:`PackedSequence` would contain data ``axbc`` with
``batch_sizes=[2,1,1]``.
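        The following is an illustrative sketch of that same example, with integer tensors
        standing in for ``abc`` and ``x``:
            >>> from torch.nn.utils.rnn import pack_sequence
            >>> packed = pack_sequence([torch.tensor([1, 2, 3]), torch.tensor([4])])
            >>> packed.data, packed.batch_sizes
            (tensor([1, 4, 2, 3]), tensor([2, 1, 1]))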
Attributes:
data (Tensor): Tensor containing packed sequence
batch_sizes (Tensor): Tensor of integers holding
information about the batch size at each sequence step
sorted_indices (Tensor, optional): Tensor of integers holding how this
:class:`PackedSequence` is constructed from sequences.
        unsorted_indices (Tensor, optional): Tensor of integers holding how to
            recover the original sequences in the correct order.
.. note::
:attr:`data` can be on arbitrary device and of arbitrary dtype.
:attr:`sorted_indices` and :attr:`unsorted_indices` must be ``torch.int64``
tensors on the same device as :attr:`data`.
However, :attr:`batch_sizes` should always be a CPU ``torch.int64`` tensor.
        This invariant is maintained throughout the :class:`PackedSequence` class,
        and by all functions that construct a :class:`PackedSequence` in PyTorch
(i.e., they only pass in tensors conforming to this constraint).
"""
def __new__(
cls,
data: Tensor,
batch_sizes: Optional[Tensor] = None,
sorted_indices: Optional[Tensor] = None,
unsorted_indices: Optional[Tensor] = None,
) -> Self:
return super().__new__(
cls,
*_packed_sequence_init_args(
data, batch_sizes, sorted_indices, unsorted_indices
),
)
# NOTE [ device and dtype of a PackedSequence ]
#
# See the note above in doc string (starting with ":attr:`data` can be on
# arbitrary device...").
def pin_memory(self) -> Self:
# Why not convert `batch_sizes`?
# See NOTE [ device and dtype of a PackedSequence ]
return type(self)(
self.data.pin_memory(),
self.batch_sizes,
bind(self.sorted_indices, lambda t: t.pin_memory()),
bind(self.unsorted_indices, lambda t: t.pin_memory()),
)
@overload
def to(
self,
dtype: torch.dtype,
non_blocking: bool = ...,
copy: bool = ...,
) -> Self:
...
@overload
def to(
self,
device: Optional[Union[str, torch.device, int]] = ...,
dtype: Optional[torch.dtype] = ...,
non_blocking: bool = ...,
copy: bool = ...,
) -> Self:
...
@overload
def to(
self,
other: Tensor,
non_blocking: bool = ...,
copy: bool = ...,
) -> Self:
...
def to(self, *args: Any, **kwargs: Any) -> Self:
r"""Perform dtype and/or device conversion on `self.data`.
It has similar signature as :meth:`torch.Tensor.to`, except optional
arguments like `non_blocking` and `copy` should be passed as kwargs,
not args, or they will not apply to the index tensors.
.. note::
If the ``self.data`` Tensor already has the correct :class:`torch.dtype`
and :class:`torch.device`, then ``self`` is returned.
Otherwise, returns a copy with the desired configuration.
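        Example (illustrative):
            >>> from torch.nn.utils.rnn import pack_sequence
            >>> packed = pack_sequence([torch.tensor([1.0, 2.0]), torch.tensor([3.0])])
            >>> packed.to(torch.float64).data.dtype
            torch.float64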
"""
# Why not convert `batch_sizes`?
# See NOTE [ device and dtype of a PackedSequence ]
data = self.data.to(*args, **kwargs)
if data is self.data:
return self
else:
# Does not forward device or dtype arg/kwargs, device is set from data.device
kwargs = dict(
filter(lambda t: t[0] != "device" and t[0] != "dtype", kwargs.items())
)
sorted_indices = bind(
self.sorted_indices, lambda t: t.to(data.device, **kwargs)
)
unsorted_indices = bind(
self.unsorted_indices, lambda t: t.to(data.device, **kwargs)
)
return type(self)(data, self.batch_sizes, sorted_indices, unsorted_indices)
def cuda(self, *args: Any, **kwargs: Any) -> Self:
# Tests to see if 'cuda' should be added to kwargs
ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
*args, **kwargs
)
if ex.is_cuda:
return self.to(*args, **kwargs)
kwargs["device"] = "cuda"
return self.to(*args, **kwargs)
def cpu(self, *args: Any, **kwargs: Any) -> Self:
ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
*args, **kwargs
)
if ex.device.type == "cpu":
return self.to(*args, **kwargs)
kwargs["device"] = "cpu"
return self.to(*args, **kwargs)
def double(self) -> Self:
return self.to(dtype=torch.double)
def float(self) -> Self:
return self.to(dtype=torch.float)
def half(self) -> Self:
return self.to(dtype=torch.half)
def long(self) -> Self:
return self.to(dtype=torch.long)
def int(self) -> Self:
return self.to(dtype=torch.int)
def short(self) -> Self:
return self.to(dtype=torch.short)
def char(self) -> Self:
return self.to(dtype=torch.int8)
def byte(self) -> Self:
return self.to(dtype=torch.uint8)
@property
def is_cuda(self) -> bool:
r"""Return true if `self.data` stored on a gpu."""
return self.data.is_cuda
def is_pinned(self) -> bool:
r"""Return true if `self.data` stored on in pinned memory."""
return self.data.is_pinned()
# TorchScript doesn't support constructors on named tuples, so we use this helper
# method to construct PackedSequence
def _packed_sequence_init_args(
data: Tensor,
batch_sizes: Optional[Tensor] = None,
sorted_indices: Optional[Tensor] = None,
unsorted_indices: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
# NB: if unsorted_indices is provided, it should be the inverse permutation
# to sorted_indices. Don't assert it here because the PackedSequence ctor
# should only be used internally.
if unsorted_indices is None:
unsorted_indices = invert_permutation(sorted_indices)
# support being called as `PackedSequence(data, batch_sizes, sorted_indices)`
if batch_sizes is not None:
# TODO: Re-enable this check (.type isn't supported in TorchScript)
if batch_sizes.device.type != "cpu":
raise ValueError(
"batch_sizes should always be on CPU. "
"Instances of PackedSequence should never be created manually. "
"They should be instantiated by functions like pack_sequence "
"and pack_padded_sequences in nn.utils.rnn. "
"https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence"
)
return data, batch_sizes, sorted_indices, unsorted_indices
# support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)`
else:
assert isinstance(data, (list, tuple)) and len(data) == 2
return data[0], data[1], sorted_indices, unsorted_indices
def _packed_sequence_init(
data: Tensor,
batch_sizes: Optional[Tensor] = None,
sorted_indices: Optional[Tensor] = None,
unsorted_indices: Optional[Tensor] = None,
) -> PackedSequence:
data, batch_sizes, sorted_indices, unsorted_indices = _packed_sequence_init_args(
data, batch_sizes, sorted_indices, unsorted_indices
)
return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices)
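# Computes the inverse of a permutation given as an index tensor: entry ``i`` of the
# result is the index ``j`` such that ``permutation[j] == i``. For example,
# invert_permutation(torch.tensor([2, 0, 1])) returns tensor([1, 2, 0]).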
def invert_permutation(permutation: Optional[Tensor]) -> Optional[Tensor]:
if permutation is None:
return None
output = torch.empty_like(permutation, memory_format=torch.legacy_contiguous_format)
output.scatter_(
0, permutation, torch.arange(0, permutation.numel(), device=permutation.device)
)
return output
def pack_padded_sequence(
input: Tensor,
lengths: Union[Tensor, List[int]],
batch_first: bool = False,
enforce_sorted: bool = True,
) -> PackedSequence:
r"""Packs a Tensor containing padded sequences of variable length.
:attr:`input` can be of size ``T x B x *`` (if :attr:`batch_first` is ``False``)
or ``B x T x *`` (if :attr:`batch_first` is ``True``) where ``T`` is the length
of the longest sequence, ``B`` is the batch size, and ``*`` is any number of dimensions
(including 0).
For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is
``True``, the sequences should be sorted by length in a decreasing order, i.e.
``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest
one. `enforce_sorted = True` is only necessary for ONNX export.
Note:
This function accepts any input that has at least two dimensions. You
can apply it to pack the labels, and use the output of the RNN with
them to compute the loss directly. A Tensor can be retrieved from
a :class:`PackedSequence` object by accessing its ``.data`` attribute.
Args:
input (Tensor): padded batch of variable length sequences.
lengths (Tensor or list(int)): list of sequence lengths of each batch
element (must be on the CPU if provided as a tensor).
batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *``
format, ``T x B x *`` otherwise.
enforce_sorted (bool, optional): if ``True``, the input is expected to
contain sequences sorted by length in a decreasing order. If
``False``, the input will get sorted unconditionally. Default: ``True``.
Returns:
a :class:`PackedSequence` object
"""
if not isinstance(lengths, torch.Tensor):
if torch._C._get_tracing_state():
warnings.warn(
"pack_padded_sequence has been called with a Python list of "
"sequence lengths. The tracer cannot track the data flow of Python "
"values, and it will treat them as constants, likely rendering "
"the trace incorrect for any other combination of lengths.",
stacklevel=2,
)
lengths = torch.as_tensor(lengths, dtype=torch.int64, device="cpu")
else:
lengths = lengths.to(dtype=torch.int64)
if enforce_sorted:
sorted_indices = None
else:
lengths, sorted_indices = torch.sort(lengths, descending=True)
sorted_indices = sorted_indices.to(input.device)
batch_dim = 0 if batch_first else 1
input = input.index_select(batch_dim, sorted_indices)
data, batch_sizes = _VF._pack_padded_sequence(input, lengths, batch_first)
return _packed_sequence_init(data, batch_sizes, sorted_indices, None)
def pad_packed_sequence(
sequence: PackedSequence,
batch_first: bool = False,
padding_value: float = 0.0,
total_length: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
r"""Pad a packed batch of variable length sequences.
It is an inverse operation to :func:`pack_padded_sequence`.
The returned Tensor's data will be of size ``T x B x *`` (if :attr:`batch_first` is ``False``)
or ``B x T x *`` (if :attr:`batch_first` is ``True``) , where ``T`` is the length of the longest
sequence and ``B`` is the batch size.
Example:
>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
>>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
>>> lens = [2, 1, 3]
>>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
>>> packed
PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
>>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
>>> seq_unpacked
tensor([[1, 2, 0],
[3, 0, 0],
[4, 5, 6]])
>>> lens_unpacked
tensor([2, 1, 3])
.. note::
:attr:`total_length` is useful to implement the
``pack sequence -> recurrent network -> unpack sequence`` pattern in a
:class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
See :ref:`this FAQ section <pack-rnn-unpack-with-data-parallelism>` for
details.
Args:
sequence (PackedSequence): batch to pad
batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
format, ``T x B x *`` otherwise.
padding_value (float, optional): values for padded elements.
total_length (int, optional): if not ``None``, the output will be padded to
have length :attr:`total_length`. This method will throw :class:`ValueError`
if :attr:`total_length` is less than the max sequence length in
:attr:`sequence`.
Returns:
Tuple of Tensor containing the padded sequence, and a Tensor
containing the list of lengths of each sequence in the batch.
Batch elements will be re-ordered as they were ordered originally when
the batch was passed to ``pack_padded_sequence`` or ``pack_sequence``.
"""
max_seq_length = sequence.batch_sizes.size(0)
if total_length is not None:
if total_length < max_seq_length:
raise ValueError(
"Expected total_length to be at least the length "
"of the longest sequence in input, but got "
f"total_length={total_length} and max sequence length being {max_seq_length}"
)
max_seq_length = total_length
padded_output, lengths = _VF._pad_packed_sequence(
sequence.data, sequence.batch_sizes, batch_first, padding_value, max_seq_length
)
unsorted_indices = sequence.unsorted_indices
if unsorted_indices is not None:
batch_dim = 0 if batch_first else 1
return (
padded_output.index_select(batch_dim, unsorted_indices),
lengths[unsorted_indices.cpu()],
)
return padded_output, lengths
# NOTE: for JIT-compatibility, we need to be more restrictive here and use specific types instead of Iterable.
def pad_sequence(
sequences: Union[Tensor, List[Tensor]],
batch_first: bool = False,
padding_value: float = 0.0,
padding_side: str = "right",
) -> Tensor:
r"""Pad a list of variable length Tensors with :attr:`padding_value`.
``pad_sequence`` stacks a list of Tensors along a new dimension, and pads them
to equal length. :attr:`sequences` can be list of sequences with size ``L x *``,
where `L` is length of the sequence and ``*`` is any number of dimensions
(including 0). If :attr:`batch_first` is ``False``, the output is of size
``T x B x *``, and ``B x T x *`` otherwise, where ``B`` is the batch size
(the number of elements in :attr:`sequences`), ``T`` is the length of the longest
sequence.
Example:
>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300)
>>> b = torch.ones(22, 300)
>>> c = torch.ones(15, 300)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300])
Note:
This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
        where `T` is the length of the longest sequence. This function assumes that the
        trailing dimensions and type of all the Tensors in sequences are the same.
Args:
sequences (list[Tensor]): list of variable length sequences.
batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
format, ``T x B x *`` otherwise.
padding_value (float, optional): value for padded elements. Default: 0.
padding_side (str, optional): the side to pad the sequences on.
Default: "right".
Returns:
Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
Tensor of size ``B x T x *`` otherwise
"""
if not (torch.jit.is_tracing() or torch.jit.is_scripting()):
# JIT doesn't support `Iterable`
if not isinstance(sequences, Iterable):
msg = (
"pad_sequence: Expected iterable for input sequences, but got arg of type: "
f"{type(sequences)}"
)
raise RuntimeError(msg)
# In JIT context this leads to,
# RuntimeError: cannot statically infer the expected size of a list in this context
sequences = tuple(sequences) # type: ignore[assignment]
else:
# For JIT, we only support Union[Tensor, Tuple[Tensor]]
if isinstance(sequences, torch.Tensor):
sequences = sequences.unbind(0) # type: ignore[assignment]
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
return torch._C._nn.pad_sequence(
sequences, batch_first, padding_value, padding_side # type: ignore[arg-type]
)
def unpad_sequence(
padded_sequences: Tensor,
lengths: Tensor,
batch_first: bool = False,
) -> List[Tensor]:
r"""Unpad padded Tensor into a list of variable length Tensors.
``unpad_sequence`` unstacks padded Tensor into a list of variable length Tensors.
Example:
>>> from torch.nn.utils.rnn import pad_sequence, unpad_sequence
>>> a = torch.ones(25, 300)
>>> b = torch.ones(22, 300)
>>> c = torch.ones(15, 300)
>>> sequences = [a, b, c]
>>> padded_sequences = pad_sequence(sequences)
>>> lengths = torch.as_tensor([v.size(0) for v in sequences])
>>> unpadded_sequences = unpad_sequence(padded_sequences, lengths)
>>> torch.allclose(sequences[0], unpadded_sequences[0])
True
>>> torch.allclose(sequences[1], unpadded_sequences[1])
True
>>> torch.allclose(sequences[2], unpadded_sequences[2])
True
Args:
padded_sequences (Tensor): padded sequences.
lengths (Tensor): length of original (unpadded) sequences.
batch_first (bool, optional): whether batch dimension first or not. Default: False.
Returns:
a list of :class:`Tensor` objects
"""
unpadded_sequences = []
if not batch_first:
padded_sequences.transpose_(0, 1)
max_length = padded_sequences.shape[1]
idx = torch.arange(max_length, device=lengths.device)
for seq, length in zip(padded_sequences, lengths):
mask = idx < length
unpacked_seq = seq[mask]
unpadded_sequences.append(unpacked_seq)
return unpadded_sequences
def pack_sequence(
sequences: List[Tensor],
enforce_sorted: bool = True,
) -> PackedSequence:
r"""Packs a list of variable length Tensors.
    Equivalent to calling ``pad_sequence`` followed by ``pack_padded_sequence``.
``sequences`` should be a list of Tensors of size ``L x *``, where `L` is
the length of a sequence and `*` is any number of trailing dimensions,
including zero.
For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted``
is ``True``, the sequences should be sorted in the order of decreasing length.
``enforce_sorted = True`` is only necessary for ONNX export.
Example:
>>> from torch.nn.utils.rnn import pack_sequence
>>> a = torch.tensor([1, 2, 3])
>>> b = torch.tensor([4, 5])
>>> c = torch.tensor([6])
>>> pack_sequence([a, b, c])
PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)
Args:
sequences (list[Tensor]): A list of sequences of decreasing length.
enforce_sorted (bool, optional): if ``True``, checks that the input
contains sequences sorted by length in a decreasing order. If
``False``, this condition is not checked. Default: ``True``.
Returns:
a :class:`PackedSequence` object
"""
lengths = torch.as_tensor([v.size(0) for v in sequences])
return pack_padded_sequence(
pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted
)
def unpack_sequence(packed_sequences: PackedSequence) -> List[Tensor]:
r"""Unpack PackedSequence into a list of variable length Tensors.
``packed_sequences`` should be a PackedSequence object.
Example:
>>> from torch.nn.utils.rnn import pack_sequence, unpack_sequence
>>> a = torch.tensor([1, 2, 3])
>>> b = torch.tensor([4, 5])
>>> c = torch.tensor([6])
>>> sequences = [a, b, c]
>>> print(sequences)
[tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]
>>> packed_sequences = pack_sequence(sequences)
>>> print(packed_sequences)
PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)
>>> unpacked_sequences = unpack_sequence(packed_sequences)
>>> print(unpacked_sequences)
[tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]
Args:
packed_sequences (PackedSequence): A PackedSequence object.
Returns:
a list of :class:`Tensor` objects
"""
padded_sequences, lengths = pad_packed_sequence(packed_sequences, batch_first=True)
unpacked_sequences = unpad_sequence(padded_sequences, lengths, batch_first=True)
return unpacked_sequences

View File

@ -0,0 +1,366 @@
# mypy: allow-untyped-defs
"""Spectral Normalization from https://arxiv.org/abs/1802.05957."""
from typing import Any, Optional, TypeVar
import torch
import torch.nn.functional as F
from torch.nn.modules import Module
__all__ = [
"SpectralNorm",
"SpectralNormLoadStateDictPreHook",
"SpectralNormStateDictHook",
"spectral_norm",
"remove_spectral_norm",
]
class SpectralNorm:
# Invariant before and after each forward call:
# u = F.normalize(W @ v)
# NB: At initialization, this invariant is not enforced
_version: int = 1
# At version 1:
# made `W` not a buffer,
# added `v` as a buffer, and
# made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
name: str
dim: int
n_power_iterations: int
eps: float
def __init__(
self,
name: str = "weight",
n_power_iterations: int = 1,
dim: int = 0,
eps: float = 1e-12,
) -> None:
self.name = name
self.dim = dim
if n_power_iterations <= 0:
raise ValueError(
"Expected n_power_iterations to be positive, but "
f"got n_power_iterations={n_power_iterations}"
)
self.n_power_iterations = n_power_iterations
self.eps = eps
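    # Collapse the weight into a 2-D matrix so the spectral norm can be estimated by
    # power iteration. With the default dim=0, e.g. a Conv2d weight of shape
    # (out_channels, in_channels, kH, kW) is reshaped to (out_channels, in_channels * kH * kW).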
def reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
weight_mat = weight
if self.dim != 0:
# permute dim to front
weight_mat = weight_mat.permute(
self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim]
)
height = weight_mat.size(0)
return weight_mat.reshape(height, -1)
def compute_weight(self, module: Module, do_power_iteration: bool) -> torch.Tensor:
# NB: If `do_power_iteration` is set, the `u` and `v` vectors are
# updated in power iteration **in-place**. This is very important
# because in `DataParallel` forward, the vectors (being buffers) are
# broadcast from the parallelized module to each module replica,
# which is a new module object created on the fly. And each replica
# runs its own spectral norm power iteration. So simply assigning
# the updated vectors to the module this function runs on will cause
# the update to be lost forever. And the next time the parallelized
# module is replicated, the same randomly initialized vectors are
# broadcast and used!
#
# Therefore, to make the change propagate back, we rely on two
# important behaviors (also enforced via tests):
# 1. `DataParallel` doesn't clone storage if the broadcast tensor
# is already on correct device; and it makes sure that the
# parallelized module is already on `device[0]`.
# 2. If the out tensor in `out=` kwarg has correct shape, it will
# just fill in the values.
# Therefore, since the same power iteration is performed on all
# devices, simply updating the tensors in-place will make sure that
# the module replica on `device[0]` will update the _u vector on the
# parallelized module (by shared storage).
#
# However, after we update `u` and `v` in-place, we need to **clone**
# them before using them to normalize the weight. This is to support
# backproping through two forward passes, e.g., the common pattern in
# GAN training: loss = D(real) - D(fake). Otherwise, engine will
# complain that variables needed to do backward for the first forward
# (i.e., the `u` and `v` vectors) are changed in the second forward.
weight = getattr(module, self.name + "_orig")
u = getattr(module, self.name + "_u")
v = getattr(module, self.name + "_v")
weight_mat = self.reshape_weight_to_matrix(weight)
if do_power_iteration:
with torch.no_grad():
for _ in range(self.n_power_iterations):
# Spectral norm of weight equals to `u^T W v`, where `u` and `v`
# are the first left and right singular vectors.
# This power iteration produces approximations of `u` and `v`.
v = F.normalize(
torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v
)
u = F.normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u)
if self.n_power_iterations > 0:
# See above on why we need to clone
u = u.clone(memory_format=torch.contiguous_format)
v = v.clone(memory_format=torch.contiguous_format)
sigma = torch.dot(u, torch.mv(weight_mat, v))
weight = weight / sigma
return weight
def remove(self, module: Module) -> None:
with torch.no_grad():
weight = self.compute_weight(module, do_power_iteration=False)
delattr(module, self.name)
delattr(module, self.name + "_u")
delattr(module, self.name + "_v")
delattr(module, self.name + "_orig")
module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))
def __call__(self, module: Module, inputs: Any) -> None:
setattr(
module,
self.name,
self.compute_weight(module, do_power_iteration=module.training),
)
def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` s.t. `u = F.normalize(W @ v)`
# (the invariant at top of this class) and `u @ W @ v = sigma`.
# This uses pinverse in case W^T W is not invertible.
v = torch.linalg.multi_dot(
[weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)]
).squeeze(1)
return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
@staticmethod
def apply(
module: Module, name: str, n_power_iterations: int, dim: int, eps: float
) -> "SpectralNorm":
for hook in module._forward_pre_hooks.values():
if isinstance(hook, SpectralNorm) and hook.name == name:
raise RuntimeError(
f"Cannot register two spectral_norm hooks on the same parameter {name}"
)
fn = SpectralNorm(name, n_power_iterations, dim, eps)
weight = module._parameters[name]
if weight is None:
raise ValueError(
f"`SpectralNorm` cannot be applied as parameter `{name}` is None"
)
if isinstance(weight, torch.nn.parameter.UninitializedParameter):
raise ValueError(
"The module passed to `SpectralNorm` can't have uninitialized parameters. "
"Make sure to run the dummy forward before applying spectral normalization"
)
with torch.no_grad():
weight_mat = fn.reshape_weight_to_matrix(weight)
h, w = weight_mat.size()
# randomly initialize `u` and `v`
u = F.normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
v = F.normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)
delattr(module, fn.name)
module.register_parameter(fn.name + "_orig", weight)
# We still need to assign weight back as fn.name because all sorts of
# things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # would get added as a parameter. Instead, we register weight.data as a plain
# attribute.
setattr(module, fn.name, weight.data)
module.register_buffer(fn.name + "_u", u)
module.register_buffer(fn.name + "_v", v)
module.register_forward_pre_hook(fn)
module._register_state_dict_hook(SpectralNormStateDictHook(fn))
module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
return fn
# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormLoadStateDictPreHook:
# See docstring of SpectralNorm._version on the changes to spectral_norm.
def __init__(self, fn) -> None:
self.fn = fn
# For state_dict with version None, (assuming that it has gone through at
# least one training forward), we have
#
# u = F.normalize(W_orig @ v)
# W = W_orig / sigma, where sigma = u @ W_orig @ v
#
# To compute `v`, we solve `W_orig @ x = u`, and let
# v = x / (u @ W_orig @ x) * (W / W_orig).
def __call__(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
) -> None:
fn = self.fn
version = local_metadata.get("spectral_norm", {}).get(
fn.name + ".version", None
)
if version is None or version < 1:
weight_key = prefix + fn.name
if (
version is None
and all(weight_key + s in state_dict for s in ("_orig", "_u", "_v"))
and weight_key not in state_dict
):
# Detect if it is the updated state dict and just missing metadata.
# This could happen if the users are crafting a state dict themselves,
# so we just pretend that this is the newest.
return
has_missing_keys = False
for suffix in ("_orig", "", "_u"):
key = weight_key + suffix
if key not in state_dict:
has_missing_keys = True
if strict:
missing_keys.append(key)
if has_missing_keys:
return
with torch.no_grad():
weight_orig = state_dict[weight_key + "_orig"]
weight = state_dict.pop(weight_key)
sigma = (weight_orig / weight).mean()
weight_mat = fn.reshape_weight_to_matrix(weight_orig)
u = state_dict[weight_key + "_u"]
v = fn._solve_v_and_rescale(weight_mat, u, sigma)
state_dict[weight_key + "_v"] = v
# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormStateDictHook:
# See docstring of SpectralNorm._version on the changes to spectral_norm.
def __init__(self, fn) -> None:
self.fn = fn
def __call__(self, module, state_dict, prefix, local_metadata) -> None:
if "spectral_norm" not in local_metadata:
local_metadata["spectral_norm"] = {}
key = self.fn.name + ".version"
if key in local_metadata["spectral_norm"]:
raise RuntimeError(f"Unexpected key in metadata['spectral_norm']: {key}")
local_metadata["spectral_norm"][key] = self.fn._version
T_module = TypeVar("T_module", bound=Module)
def spectral_norm(
module: T_module,
name: str = "weight",
n_power_iterations: int = 1,
eps: float = 1e-12,
dim: Optional[int] = None,
) -> T_module:
r"""Apply spectral normalization to a parameter in the given module.
.. math::
\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
Spectral normalization stabilizes the training of discriminators (critics)
in Generative Adversarial Networks (GANs) by rescaling the weight tensor
with spectral norm :math:`\sigma` of the weight matrix calculated using
power iteration method. If the dimension of the weight tensor is greater
than 2, it is reshaped to 2D in power iteration method to get spectral
norm. This is implemented via a hook that calculates spectral norm and
rescales weight before every :meth:`~Module.forward` call.
See `Spectral Normalization for Generative Adversarial Networks`_ .
.. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
Args:
module (nn.Module): containing module
name (str, optional): name of weight parameter
n_power_iterations (int, optional): number of power iterations to
calculate spectral norm
eps (float, optional): epsilon for numerical stability in
calculating norms
dim (int, optional): dimension corresponding to number of outputs,
the default is ``0``, except for modules that are instances of
ConvTranspose{1,2,3}d, when it is ``1``
Returns:
The original module with the spectral norm hook
.. note::
This function has been reimplemented as
:func:`torch.nn.utils.parametrizations.spectral_norm` using the new
parametrization functionality in
:func:`torch.nn.utils.parametrize.register_parametrization`. Please use
the newer version. This function will be deprecated in a future version
of PyTorch.
Example::
>>> m = spectral_norm(nn.Linear(20, 40))
>>> m
Linear(in_features=20, out_features=40, bias=True)
>>> m.weight_u.size()
torch.Size([40])
"""
if dim is None:
if isinstance(
module,
(
torch.nn.ConvTranspose1d,
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d,
),
):
dim = 1
else:
dim = 0
SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
return module
def remove_spectral_norm(module: T_module, name: str = "weight") -> T_module:
r"""Remove the spectral normalization reparameterization from a module.
Args:
module (Module): containing module
name (str, optional): name of weight parameter
Example:
>>> m = spectral_norm(nn.Linear(40, 10))
>>> remove_spectral_norm(m)
"""
for k, hook in module._forward_pre_hooks.items():
if isinstance(hook, SpectralNorm) and hook.name == name:
hook.remove(module)
del module._forward_pre_hooks[k]
break
else:
raise ValueError(f"spectral_norm of '{name}' not found in {module}")
for k, hook in module._state_dict_hooks.items():
if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name:
del module._state_dict_hooks[k]
break
for k, hook in module._load_state_dict_pre_hooks.items():
if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name:
del module._load_state_dict_pre_hooks[k]
break
return module

View File

@ -0,0 +1,298 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, Optional, Set, Tuple, Union
from typing_extensions import deprecated
import torch
from torch import Tensor
from torch.nn.utils._named_member_accessor import NamedMemberAccessor
__all__ = ["functional_call"]
def _untie_named_tensors_map(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
) -> Dict[str, Tensor]:
"""
Unties all tied tensors in the module to parameters_and_buffers.
    This function returns a new untied_parameters_and_buffers dictionary and leaves the original
    parameters_and_buffers dictionary unchanged. It adds new (missing) keys for tied tensors
in the module to untied_parameters_and_buffers. The value of the new key is the user-given value
in the original parameters_and_buffers dictionary.
    If more than one user-given value is passed for the same tied tensor, an error is raised.
For example, if the module has two tied weights self.foo and self.tied_foo and the user passes
{'foo': foo_value, ...}, this will return {'foo': foo_value, 'tied_foo': foo_value, ...}. If the
user passes {'foo': foo_value, 'tied_foo': tied_foo_value, ...}, it will raise an error. If the
user passes {'foo': foo_value, 'tied_foo': foo_value, ...}, it will not raise an error.
Args:
module (torch.nn.Module): the module to determine which tensors are tied.
        parameters_and_buffers (Dict[str, Tensor]): a map of {name: tensor} for reparameterizing the module.
Returns:
A new untied version of the parameters_and_buffers dictionary.
Raises:
ValueError: if there are more than one user-given values for the same tied tensor.
"""
# A map of {name: tensor} for all tensors (including tied ones) in the module.
all_named_tensors: Dict[str, Tensor] = {}
all_named_tensors.update(module.named_parameters(remove_duplicate=False))
all_named_tensors.update(module.named_buffers(remove_duplicate=False))
# A map of {tensor: set(all_tied_names)} for all tensor names in the module.
tensor_to_tied_names_map: Dict[Tensor, Set[str]] = {}
for name, tensor in all_named_tensors.items():
if tensor not in tensor_to_tied_names_map:
tensor_to_tied_names_map[tensor] = set()
tensor_to_tied_names_map[tensor].add(name)
# A map of {tied_name: set(all_tied_names)} for all tensor names in the module.
# If a name is not tied, it will not be in this map.
tied_names_map: Dict[str, Set[str]] = {}
for tied_names in tensor_to_tied_names_map.values():
if len(tied_names) > 1:
for tied_name in tied_names:
tied_names_map[tied_name] = tied_names
# Make sure the user didn't pass multiple values for the same tied tensor.
given_names = set(parameters_and_buffers.keys())
# same as given_names.intersection(tied_names_map.keys()) but dynamo can't
# handle that
given_names_for_tied_tensors: set[str] = set()
for name in given_names:
if name in tied_names_map:
given_names_for_tied_tensors.add(name)
for given_name in given_names_for_tied_tensors:
tied_names = tied_names_map[given_name]
if (
# Detect if there are multiple keys present for the same tied tensor.
len(tied_names.intersection(given_names_for_tied_tensors)) > 1
# Only raise an error if the user passed multiple values for the same tied tensor.
# If all given values are the same, don't raise.
and len({parameters_and_buffers[tied_name] for tied_name in tied_names})
!= 1
):
raise ValueError(
f"functional_call got multiple values for keys {sorted(tied_names)}, "
f"which are tied. Consider using tie_weights=False"
)
# Untie the given named tensor map
# Make a copy for not modifying the original dict
untied_parameters_and_buffers = parameters_and_buffers.copy()
for given_name in given_names_for_tied_tensors:
for tied_name in tied_names_map[given_name]:
untied_parameters_and_buffers[tied_name] = parameters_and_buffers[
given_name
]
return untied_parameters_and_buffers
class _ReparametrizeModule:
def __init__(
self,
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
tie_weights: bool = False,
strict: bool = False,
stack_weights: bool = False,
):
self.parameters_and_buffers = parameters_and_buffers
self.stack_weights = stack_weights
if tie_weights:
self.untied_parameters_and_buffers = _untie_named_tensors_map(
module, parameters_and_buffers
)
else:
self.untied_parameters_and_buffers = parameters_and_buffers
self.accessor = NamedMemberAccessor(module)
if strict:
missing_keys, unexpected_keys = self.accessor.check_keys(
self.untied_parameters_and_buffers
)
error_msgs = []
if len(unexpected_keys) > 0:
error_msgs.append(
f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}."
)
if len(missing_keys) > 0:
error_msgs.append(
f"Missing key(s): {', '.join(map(repr, missing_keys))}."
)
if len(error_msgs) > 0:
raise RuntimeError(
"Error(s) in reparametrizing for {}:\n\t{}".format(
module._get_name(), "\n\t".join(error_msgs)
)
)
def __enter__(self):
self.orig_parameters_and_buffers, _ = self.accessor.swap_tensors_dict(
self.untied_parameters_and_buffers, allow_missing=True
)
def __exit__(self, exception_type, exception_value, traceback):
if self.stack_weights:
# When stacking is enabled, we will restore the weights in LIFO order.
self.orig_parameters_and_buffers = dict(
reversed(self.orig_parameters_and_buffers.items())
)
new_parameters_and_buffers, _ = self.accessor.swap_tensors_dict(
self.orig_parameters_and_buffers, allow_missing=True
)
# Sometimes the module is not completely stateless and has some in-place modifications on
# the _parameters and _buffers dictionaries.
# Write the changed parameters and buffers back to the original dict.
self.parameters_and_buffers.update(
{
k: new_parameters_and_buffers[k]
for k in self.parameters_and_buffers
if k in new_parameters_and_buffers
}
)
def _reparametrize_module(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
*,
tie_weights: bool = False,
strict: bool = False,
stack_weights: bool = False,
) -> _ReparametrizeModule:
return _ReparametrizeModule(
module,
parameters_and_buffers,
tie_weights=tie_weights,
strict=strict,
stack_weights=stack_weights,
)
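# A minimal usage sketch (`linear`, `overrides` and `x` are illustrative names): the
# replacement tensors are swapped in on __enter__ and the originals restored on __exit__,
# so the module's own parameters are untouched after the block.
#
#     overrides = {"weight": torch.zeros_like(linear.weight)}
#     with _reparametrize_module(linear, overrides):
#         out = linear(x)  # runs with the replacement tensors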
@deprecated(
"`torch.nn.utils.stateless.functional_call` is deprecated as of PyTorch 2.0 "
"and will be removed in a future version of PyTorch. "
"Please use `torch.func.functional_call` instead which is a drop-in replacement.",
category=FutureWarning,
)
def functional_call(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
args: Union[Any, Tuple],
kwargs: Optional[Dict[str, Any]] = None,
*,
tie_weights: bool = True,
strict: bool = False,
):
r"""Perform a functional call on the module by replacing the module parameters and buffers with the provided ones.
.. warning::
This API is deprecated as of PyTorch 2.0 and will be removed in a future
version of PyTorch. Please use :func:`torch.func.functional_call` instead,
which is a drop-in replacement for this API.
.. note:: If the module has active parametrizations, passing a value in the
:attr:`parameters_and_buffers` argument with the name set to the regular parameter
name will completely disable the parametrization.
If you want to apply the parametrization function to the value passed
please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.
.. note:: If the module performs in-place operations on parameters/buffers, these will be reflected
in the `parameters_and_buffers` input.
Example::
>>> a = {'foo': torch.zeros(())}
>>> # xdoctest: +SKIP
>>> mod = Foo() # does self.foo = self.foo + 1
>>> print(mod.foo) # tensor(0.)
>>> functional_call(mod, a, torch.ones(()))
>>> print(mod.foo) # tensor(0.)
>>> print(a['foo']) # tensor(1.)
.. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
tie_weights flag.
Example::
>>> a = {'foo': torch.zeros(())}
>>> # xdoctest: +SKIP
>>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
>>> print(mod.foo) # tensor(1.)
>>> mod(torch.zeros(())) # tensor(2.)
>>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated
>>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
>>> functional_call(mod, new_a, torch.zeros()) # tensor(0.)
Args:
module (torch.nn.Module): the module to call
parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
the module call.
args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
kwargs (dict): keyword arguments to be passed to the module call
tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
            tied in the reparameterized version. Therefore, if True and different values are passed for the tied
parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
buffers unless the values passed for both weights are the same. Default: True.
strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and
buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will
error. Default: False.
Returns:
Any: the result of calling ``module``.
"""
return _functional_call(
module,
parameters_and_buffers,
args,
kwargs,
tie_weights=tie_weights,
strict=strict,
)
def _functional_call(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
args: Union[Any, Tuple],
kwargs: Optional[Dict[str, Any]] = None,
*,
tie_weights: bool = True,
strict: bool = False,
):
# TODO allow kwargs such as unsafe and others for parametrization
if (
torch.jit.is_tracing()
or torch.jit.is_scripting()
or isinstance(
module,
(
torch.jit.RecursiveScriptModule,
torch.jit.ScriptModule,
torch.jit.ScriptFunction,
),
)
):
raise RuntimeError("The stateless API can't be used with Jitted modules")
if isinstance(module, torch.nn.DataParallel):
raise RuntimeError(
"The stateless API can't be used with nn.DataParallel module"
)
if kwargs is None:
kwargs = {}
if not isinstance(args, tuple):
args = (args,)
with _reparametrize_module(
module, parameters_and_buffers, tie_weights=tie_weights, strict=strict
):
return module(*args, **kwargs)

View File

@ -0,0 +1,164 @@
# mypy: allow-untyped-defs
r"""Weight Normalization from https://arxiv.org/abs/1602.07868."""
from typing import Any, TypeVar
from typing_extensions import deprecated
from torch import _weight_norm, norm_except_dim
from torch.nn.modules import Module
from torch.nn.parameter import Parameter, UninitializedParameter
__all__ = ["WeightNorm", "weight_norm", "remove_weight_norm"]
class WeightNorm:
name: str
dim: int
def __init__(self, name: str, dim: int) -> None:
if dim is None:
dim = -1
self.name = name
self.dim = dim
# TODO Make return type more specific
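    # Recomputes w = g * v / ||v||, with the norm taken over every dimension except
    # ``self.dim``; __call__ (a forward pre-hook) invokes this before each forward pass.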
def compute_weight(self, module: Module) -> Any:
g = getattr(module, self.name + "_g")
v = getattr(module, self.name + "_v")
return _weight_norm(v, g, self.dim)
@staticmethod
@deprecated(
"`torch.nn.utils.weight_norm` is deprecated "
"in favor of `torch.nn.utils.parametrizations.weight_norm`.",
category=FutureWarning,
)
def apply(module, name: str, dim: int) -> "WeightNorm":
for hook in module._forward_pre_hooks.values():
if isinstance(hook, WeightNorm) and hook.name == name:
raise RuntimeError(
f"Cannot register two weight_norm hooks on the same parameter {name}"
)
if dim is None:
dim = -1
fn = WeightNorm(name, dim)
weight = getattr(module, name)
if isinstance(weight, UninitializedParameter):
raise ValueError(
"The module passed to `WeightNorm` can't have uninitialized parameters. "
"Make sure to run the dummy forward before applying weight normalization"
)
# remove w from parameter list
del module._parameters[name]
# add g and v as new parameters and express w as g/||v|| * v
module.register_parameter(
name + "_g", Parameter(norm_except_dim(weight, 2, dim).data)
)
module.register_parameter(name + "_v", Parameter(weight.data))
setattr(module, name, fn.compute_weight(module))
# recompute weight before every forward()
module.register_forward_pre_hook(fn)
return fn
def remove(self, module: Module) -> None:
weight = self.compute_weight(module)
delattr(module, self.name)
del module._parameters[self.name + "_g"]
del module._parameters[self.name + "_v"]
setattr(module, self.name, Parameter(weight.data))
def __call__(self, module: Module, inputs: Any) -> None:
setattr(module, self.name, self.compute_weight(module))
T_module = TypeVar("T_module", bound=Module)
def weight_norm(module: T_module, name: str = "weight", dim: int = 0) -> T_module:
r"""Apply weight normalization to a parameter in the given module.
.. math::
\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
Weight normalization is a reparameterization that decouples the magnitude
of a weight tensor from its direction. This replaces the parameter specified
by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude
(e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``).
Weight normalization is implemented via a hook that recomputes the weight
tensor from the magnitude and direction before every :meth:`~Module.forward`
call.
By default, with ``dim=0``, the norm is computed independently per output
channel/plane. To compute a norm over the entire weight tensor, use
``dim=None``.
See https://arxiv.org/abs/1602.07868
.. warning::
This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm`
which uses the modern parametrization API. The new ``weight_norm`` is compatible
with ``state_dict`` generated from old ``weight_norm``.
Migration guide:
* The magnitude (``weight_g``) and direction (``weight_v``) are now expressed
as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1``
respectively. If this is bothering you, please comment on
https://github.com/pytorch/pytorch/issues/102999
* To remove the weight normalization reparametrization, use
:func:`torch.nn.utils.parametrize.remove_parametrizations`.
* The weight is no longer recomputed once at module forward; instead, it will
be recomputed on every access. To restore the old behavior, use
:func:`torch.nn.utils.parametrize.cached` before invoking the module
in question.
Args:
module (Module): containing module
name (str, optional): name of weight parameter
dim (int, optional): dimension over which to compute the norm
Returns:
The original module with the weight norm hook
Example::
>>> m = weight_norm(nn.Linear(20, 40), name='weight')
>>> m
Linear(in_features=20, out_features=40, bias=True)
>>> m.weight_g.size()
torch.Size([40, 1])
>>> m.weight_v.size()
torch.Size([40, 20])
"""
WeightNorm.apply(module, name, dim)
return module
def remove_weight_norm(module: T_module, name: str = "weight") -> T_module:
r"""Remove the weight normalization reparameterization from a module.
Args:
module (Module): containing module
name (str, optional): name of weight parameter
Example:
>>> m = weight_norm(nn.Linear(20, 40))
>>> remove_weight_norm(m)
"""
for k, hook in module._forward_pre_hooks.items():
if isinstance(hook, WeightNorm) and hook.name == name:
hook.remove(module)
del module._forward_pre_hooks[k]
return module
raise ValueError(f"weight_norm of '{name}' not found in {module}")