commit 40e2a747cf ("I am done")
parent 720dc28c09
Date: 2024-10-30 22:14:35 +01:00

36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,62 @@
# mypy: allow-untyped-defs
from torch.nn.parameter import ( # usort: skip
Buffer as Buffer,
Parameter as Parameter,
UninitializedBuffer as UninitializedBuffer,
UninitializedParameter as UninitializedParameter,
)
from torch.nn.modules import * # usort: skip # noqa: F403
from torch.nn import (
attention as attention,
functional as functional,
init as init,
modules as modules,
parallel as parallel,
parameter as parameter,
utils as utils,
)
from torch.nn.parallel import DataParallel as DataParallel
def factory_kwargs(kwargs):
r"""Return a canonicalized dict of factory kwargs.
Given kwargs, returns a canonicalized dict of factory kwargs that can be directly passed
to factory functions like torch.empty, or errors if unrecognized kwargs are present.
This function makes it simple to write code like this::
class MyModule(nn.Module):
def __init__(self, **kwargs):
factory_kwargs = torch.nn.factory_kwargs(kwargs)
self.weight = Parameter(torch.empty(10, **factory_kwargs))
Why should you use this function instead of just passing `kwargs` along directly?
1. This function does error validation, so if there are unexpected kwargs we will
immediately report an error, instead of deferring it to the factory call
2. This function supports a special `factory_kwargs` argument, which can be used to
explicitly specify a kwarg to be used for factory functions, in the event one of the
factory kwargs conflicts with an already existing argument in the signature (e.g.
in the signature ``def f(dtype, **kwargs)``, you can specify ``dtype`` for factory
functions, as distinct from the dtype argument, by saying
``f(dtype1, factory_kwargs={"dtype": dtype2})``)
"""
if kwargs is None:
return {}
simple_keys = {"device", "dtype", "memory_format"}
expected_keys = simple_keys | {"factory_kwargs"}
if not kwargs.keys() <= expected_keys:
raise TypeError(f"unexpected kwargs {kwargs.keys() - expected_keys}")
    # copy factory_kwargs first so the caller's dict is left untouched
r = dict(kwargs.get("factory_kwargs", {}))
for k in simple_keys:
if k in kwargs:
if k in r:
raise TypeError(
f"{k} specified twice, in **kwargs and in factory_kwargs"
)
r[k] = kwargs[k]
return r
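# Illustrative usage sketch (annotation, not part of the file above): a minimal
# module that forwards device/dtype kwargs to a tensor factory via factory_kwargs.
# The class name DemoLinear is hypothetical.
import torch

class DemoLinear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, **kwargs):
        super().__init__()
        fk = factory_kwargs(kwargs)  # validates device/dtype/memory_format
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, **fk))

m = DemoLinear(4, 2, device="cpu", dtype=torch.float32)
assert m.weight.dtype == torch.float32
try:
    DemoLinear(4, 2, foo=1)
except TypeError:
    pass  # "foo" is rejected because it is not a recognized factory kwarg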


@@ -0,0 +1,60 @@
import warnings
from typing import Optional
# NB: Keep this file in sync with enums in aten/src/ATen/core/Reduction.h
def get_enum(reduction: str) -> int:
if reduction == "none":
ret = 0
elif reduction == "mean":
ret = 1
elif reduction == "elementwise_mean":
warnings.warn(
"reduction='elementwise_mean' is deprecated. "
"Please use reduction='mean' instead."
)
ret = 1
elif reduction == "sum":
ret = 2
else:
ret = -1 # TODO: remove once JIT exceptions support control flow
raise ValueError(f"{reduction} is not a valid value for reduction")
return ret
# In order to support previous versions, accept boolean size_average and reduce
# and convert them into the new constants for now
# We use these functions in torch/legacy as well, in which case we'll silence the warning
def legacy_get_string(
size_average: Optional[bool],
reduce: Optional[bool],
emit_warning: bool = True,
) -> str:
warning = "size_average and reduce args will be deprecated, please use reduction='{}' instead."
if size_average is None:
size_average = True
if reduce is None:
reduce = True
if size_average and reduce:
ret = "mean"
elif reduce:
ret = "sum"
else:
ret = "none"
if emit_warning:
warnings.warn(warning.format(ret))
return ret
def legacy_get_enum(
size_average: Optional[bool],
reduce: Optional[bool],
emit_warning: bool = True,
) -> int:
return get_enum(legacy_get_string(size_average, reduce, emit_warning))
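# Illustrative sketch (annotation, not part of the file above): how the legacy
# size_average/reduce flags map onto the new reduction strings and enum values.
assert legacy_get_string(None, None, emit_warning=False) == "mean"    # both default to True
assert legacy_get_string(False, True, emit_warning=False) == "sum"    # reduce without averaging
assert legacy_get_string(False, False, emit_warning=False) == "none"  # no reduction at all
assert get_enum("none") == 0 and get_enum("mean") == 1 and get_enum("sum") == 2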


@@ -0,0 +1,129 @@
# mypy: allow-untyped-defs
""" This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """
import contextlib
from typing import List, Union
from warnings import warn
from torch._C import _SDPBackend as SDPBackend
from torch.backends.cuda import (
can_use_efficient_attention,
can_use_flash_attention,
cudnn_sdp_enabled,
enable_cudnn_sdp,
enable_flash_sdp,
enable_math_sdp,
enable_mem_efficient_sdp,
flash_sdp_enabled,
math_sdp_enabled,
mem_efficient_sdp_enabled,
SDPAParams,
)
__all__: List[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]
# Note: [SDPA warnings]
# TODO: Consider using this for sdpa regardless of subclasses
# This only affects users of bias subclasses.
# If this is set to True, we will warn the user if they are not using the fused kernels,
# and it will also raise warnings for all the reasons why the fused kernels can't be run.
# To set this to True, run
# torch.nn.attention.WARN_FOR_UNFUSED_KERNELS = True
WARN_FOR_UNFUSED_KERNELS = False
# Hacks for Sphinx documentation:
# https://stackoverflow.com/questions/38765577/overriding-sphinx-autodoc-alias-of-for-import-of-private-class
SDPBackend = SDPBackend
r"""An enum-like class that contains the different backends for scaled dot product attention.
This backend class is designed to be used with the sdpa_kernel context manager.
The following Enums are available:
- ERROR: An error occurred when trying to determine the backend.
- MATH: The math backend for scaled dot product attention.
- FLASH_ATTENTION: The flash attention backend for scaled dot product attention.
- EFFICIENT_ATTENTION: The efficient attention backend for scaled dot product attention.
- CUDNN_ATTENTION: The cuDNN backend for scaled dot product attention.
See :func:`torch.nn.attention.sdpa_kernel` for more details.
.. warning:: This class is in beta and subject to change.
"""
SDPBackend.__module__ = __name__
SDPBackend.__name__ = "SDPBackend"
def _raise_kernel_warnings(params: SDPAParams) -> None:
"""
If WARN_FOR_UNFUSED_KERNELS is set to True, this will raise warnings
    for all the reasons why the fused kernels can't be run when using bias subclasses.
"""
if WARN_FOR_UNFUSED_KERNELS:
if not can_use_efficient_attention(params):
warn("Efficient attention can't be used because:")
can_use_efficient_attention(params, True)
if not can_use_flash_attention(params):
warn("Flash attention can't be used because:")
can_use_flash_attention(params, True)
@contextlib.contextmanager
def sdpa_kernel(backends: Union[List[SDPBackend], SDPBackend]):
r"""
Context manager to select which backend to use for scaled dot product attention.
.. warning:: This function is beta and subject to change.
Args:
        backends (Union[List[SDPBackend], SDPBackend]): A backend or list of backends for scaled dot product attention.
Example:
.. code-block:: python
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
# Only enable flash attention backend
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
scaled_dot_product_attention(...)
# Enable the Math or Efficient attention backends
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
scaled_dot_product_attention(...)
This context manager can be used to select which backend to use for scaled dot product attention.
    Upon exiting the context manager, the previous state of the flags will be restored.
"""
assert isinstance(
backends, (list, SDPBackend)
), "Backend must be an instance of SDPBackend or a list of SDPBackend instances"
if isinstance(backends, SDPBackend):
backends = [backends]
backends = set(backends)
previous_cudnn: bool = cudnn_sdp_enabled()
previous_flash: bool = flash_sdp_enabled()
previous_mem_efficient: bool = mem_efficient_sdp_enabled()
previous_math: bool = math_sdp_enabled()
try:
enable_cudnn = SDPBackend.CUDNN_ATTENTION in backends
enable_flash = SDPBackend.FLASH_ATTENTION in backends
enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION in backends
enable_math = SDPBackend.MATH in backends
enable_cudnn_sdp(enable_cudnn)
enable_flash_sdp(enable_flash)
enable_mem_efficient_sdp(enable_mem_efficient)
enable_math_sdp(enable_math)
yield {}
finally:
enable_cudnn_sdp(previous_cudnn)
enable_flash_sdp(previous_flash)
enable_mem_efficient_sdp(previous_mem_efficient)
enable_math_sdp(previous_math)
def _get_flash_version() -> str:
"""This returns the closest matching tag for the flash attention backend"""
return "2.5.7"


@@ -0,0 +1,67 @@
# mypy: allow-untyped-defs
"""Defines utilities for interacting with scaled_dot_product_attention"""
import math
from typing import List, Optional, Union
import torch
__all__: List[str] = []
def _input_requires_grad(*tensors: torch.Tensor) -> bool:
"""Returns True if any of the tensors requires grad"""
return any(t.requires_grad for t in tensors)
def _postprocess_flash_output(inpt_tensor: torch.Tensor, og_size: int) -> torch.Tensor:
"""Handles the unpad of the last dimension"""
if inpt_tensor.size(-1) != og_size:
return inpt_tensor[..., :og_size]
return inpt_tensor
def _calculate_scale(head_dim_size: int, scale: Optional[float]) -> float:
"""
For FlashAttention we pad the head dimension to be a multiple of 8 so we need to scale the output
by the original head size and not the padded.
"""
if scale is not None:
return scale
return 1.0 / math.sqrt(head_dim_size)
_SUPPORTED_HEAD_DIMS = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
def _supported_head_dim(n: Union[int, torch.SymInt]) -> bool:
"""Returns true if the head dim is supported by FlexAttention"""
return n in _SUPPORTED_HEAD_DIMS
def _validate_sdpa_input(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
if query.dtype != key.dtype or query.dtype != value.dtype:
raise ValueError(
f"Expected query, key, and value to have the same dtype, "
f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
f"and value.dtype: {value.dtype} instead."
)
if query.device != key.device or query.device != value.device:
raise ValueError(
f"Expected query, key, and value to have the same device type, "
f"but got query.device: {query.device}, key.device: {key.device}, "
f"and value.device: {value.device} instead."
)
if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
raise ValueError(
f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
)
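# Illustrative sketch (annotation, not part of the file above): the default
# softmax scale and basic validation for SDPA-style tensors.
import math
import torch

q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 32, 64)
v = torch.randn(2, 8, 32, 64)
assert _calculate_scale(q.size(-1), None) == 1.0 / math.sqrt(64)  # default 1/sqrt(head_dim)
assert _calculate_scale(q.size(-1), 0.5) == 0.5                   # an explicit scale wins
_validate_sdpa_input(q, k, v)  # passes: matching dtype/device and dims >= 2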


@@ -0,0 +1,362 @@
# mypy: allow-untyped-defs
"""Defines bias subclasses that work with scaled_dot_product_attention"""
from enum import auto, IntEnum
from typing import Optional
from warnings import warn
import torch
import torch.nn.functional as F
from torch.backends.cuda import (
can_use_efficient_attention,
can_use_flash_attention,
is_flash_attention_available,
SDPAParams,
)
from torch.nn.attention import _raise_kernel_warnings
from torch.nn.attention._utils import (
_calculate_scale,
_input_requires_grad,
_postprocess_flash_output,
_validate_sdpa_input,
)
__all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"]
torch._dynamo.allow_in_graph(is_flash_attention_available)
torch._dynamo.allow_in_graph(can_use_flash_attention)
torch._dynamo.allow_in_graph(can_use_efficient_attention)
torch._dynamo.allow_in_graph(SDPAParams)
class CausalVariant(IntEnum):
r"""
Enum for causal variants used in attention mechanisms.
Defines two types of causal biases:
`UPPER_LEFT`: Represents upper-left triangular bias for standard causal attention.
The equivalent pytorch code for constructing this bias is:
.. code-block:: python
torch.tril(torch.ones(size, dtype=torch.bool))
For instance, with `shape=(3,4)`, the materialized bias tensor will be:
.. code-block:: text
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0]]
    `LOWER_RIGHT`: Represents lower-right triangular bias; the included values are aligned to the lower
right corner of the matrix.
The equivalent pytorch code for constructing this bias is:
.. code-block:: python
diagonal_offset = size[1] - size[0]
torch.tril(
torch.ones(size, dtype=torch.bool),
diagonal=diagonal_offset,
)
For instance, with `shape=(3,4)`, the materialized bias tensor will be:
.. code-block:: text
[[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1]]
Note that these variants are equivalent to each other when the sequence lengths of the query and key/value
tensors are equal since the triangular matrix is square.
.. warning:: This enum is a prototype and subject to change.
"""
UPPER_LEFT = auto()
LOWER_RIGHT = auto()
class CausalBias(torch.Tensor):
"""
A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum.
    This class is used for defining causal (triangular) attention biases. For constructing the bias, there exist
two factory functions: :func:`causal_upper_left` and :func:`causal_lower_right`.
Example:
.. code-block:: python
from torch.nn.attention.bias import causal_lower_right
bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8
# Create a lower-right causal bias
attn_bias = causal_lower_right(seqlen_q, seqlen_kv)
q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
out = F.scaled_dot_product_attention(q, k, v, attn_bias)
.. warning:: This class is a prototype and subject to change.
"""
def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int):
"""
Initializes the CausalBias instance with a specified variant and sequence lengths.
Args:
variant (CausalVariant): The type of causal bias to use (either UPPER_LEFT or LOWER_RIGHT).
seq_len_q (int): The sequence length of the query tensor.
seq_len_kv (int): The sequence length of the key/value tensor.
Raises a warning if the LOWER_RIGHT variant is used with seq_len_q > seq_len_kv, as it may produce NaNs.
"""
assert isinstance(variant, CausalVariant)
self.variant = variant
self.seq_len_q = seq_len_q
self.seq_len_kv = seq_len_kv
if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT:
warn(
"Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!"
)
def _upper_left(self, device: torch.device) -> torch.Tensor:
"""Upper left causal bias"""
return torch.tril(
torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool)
)
def _lower_right(self, device: torch.device) -> torch.Tensor:
"""Lower right causal bias"""
diagonal_offset = self.seq_len_kv - self.seq_len_q
return torch.tril(
torch.ones(
self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool
),
diagonal=diagonal_offset,
)
def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
"""
Materializes the causal bias into a tensor form.
Depending on the variant, this method generates either an upper-left or lower-right
triangular matrix to represent the causal bias.
Args:
device (Optional[torch.device]): The device on which to create the tensor. Defaults to CPU.
Returns:
torch.Tensor: The materialized bias tensor.
"""
if device is None:
device = torch.device("cpu")
if self.variant == CausalVariant.UPPER_LEFT:
return self._upper_left(device)
elif self.variant == CausalVariant.LOWER_RIGHT:
return self._lower_right(device)
@staticmethod
def _dispatch(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: "CausalBias",
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
r"""
Handles the logic for computing attention with the specified causal bias.
Args:
query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
attn_mask (CausalBias): The type of causal attention to apply.
A boolean mask where a value of True indicates that the element *should* take part in attention.
A float mask of the same type as query, key, value that is added to the attention score.
dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
is_causal (bool): If true, assumes upper left causal attention masking and errors if both attn_mask and is_causal
are set.
scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
to :math:`\frac{1}{\sqrt{E}}`.
enable_gqa (optional bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False.
Returns:
output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
Raises:
ValueError: If the causal bias variant is not a CausalVariant type.
"""
if is_causal:
raise ValueError("CausalBias should not be used with causal=True")
if (
attn_mask.seq_len_q == attn_mask.seq_len_kv
or attn_mask.variant == CausalVariant.UPPER_LEFT
):
return F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True,
scale=scale,
enable_gqa=enable_gqa,
)
elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
_validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
sdpa_params = SDPAParams(
query, key, value, None, dropout_p, is_causal, enable_gqa
)
if can_use_flash_attention(sdpa_params):
needs_padding = query.size(-1) % 8 != 0
og_head_size = query.size(-1)
og_scale = _calculate_scale(og_head_size, scale)
if needs_padding:
query = torch.nn.functional.pad(query, (0, 8 - query.size(-1) % 8))
key = torch.nn.functional.pad(key, (0, 8 - key.size(-1) % 8))
value = torch.nn.functional.pad(value, (0, 8 - value.size(-1) % 8))
out = torch.ops.aten._scaled_dot_product_flash_attention(
query,
key,
value,
dropout_p,
is_causal=True, # TODO: Flash accepts causal = True and for this particular op it means lower right
return_debug_mask=False,
scale=og_scale,
)[0]
return _postprocess_flash_output(out, og_head_size)
if can_use_efficient_attention(sdpa_params):
compute_log_sumexp = False
if _input_requires_grad(query, key, value):
compute_log_sumexp = True
return torch.ops.aten._efficient_attention_forward(
query.transpose(1, 2),
key.transpose(1, 2),
value.transpose(1, 2),
bias=None,
cu_seqlens_q=None,
cu_seqlens_k=None,
max_seqlen_q=None,
max_seqlen_k=None,
dropout_p=dropout_p,
custom_mask_type=int(attn_mask.variant),
compute_log_sumexp=compute_log_sumexp,
scale=scale,
seqlen_k=None,
)[0].transpose(1, 2)
else:
_raise_kernel_warnings(sdpa_params)
                # We can't use efficient attention; the only support for lower right is via materialization
return F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attn_mask._materialize(query.device),
dropout_p=dropout_p,
is_causal=False,
scale=scale,
enable_gqa=enable_gqa,
)
else:
raise ValueError(
f"CausalBias.variant must be a CausalVariant type, but found: {attn_mask.variant}"
)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
"""Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is an AttnBias"""
if kwargs is None:
kwargs = {}
if func != torch.nn.functional.scaled_dot_product_attention:
raise NotImplementedError(
"CausalBias only supports scaled_dot_product_attention"
)
return cls._dispatch(*args, **kwargs)
def __repr__(self):
return self._materialize().__repr__()
def causal_upper_left(*size) -> CausalBias:
"""
Creates an upper-left triangular causal bias.
    This function generates an upper-left triangular matrix to represent causal attention bias with a
    diagonal offset set so that the included values are aligned to the upper-left corner of the matrix.
    This is equivalent to the `is_causal=True` argument in `scaled_dot_product_attention`.
The equivalent pytorch code for constructing this bias is:
.. code-block:: python
torch.tril(torch.ones(size, dtype=torch.bool))
For instance, with `shape=(3,4)`, the materialized bias tensor will be:
.. code-block:: text
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0]]
Args:
size: The size of the bias matrix.
Returns:
CausalBias: The UPPER_LEFT triangular causal bias variant.
"""
assert len(size) == 2, "causal_upper_left only supports 2D tensors"
seq_len_q, seq_len_kv = size
return CausalBias(CausalVariant.UPPER_LEFT, seq_len_q, seq_len_kv)
def causal_lower_right(*size) -> CausalBias:
"""
Creates a lower-right triangular causal bias.
This function generates a lower-right triangular matrix to represent causal attention bias with a
    diagonal offset set so that the included values are aligned to the lower-right corner of the matrix.
The equivalent pytorch code for constructing this bias is:
.. code-block:: python
diagonal_offset = size[1] - size[0]
torch.tril(
torch.ones(size, dtype=torch.bool),
diagonal=diagonal_offset,
)
For instance, with `shape=(3,4)`, the materialized bias tensor will be:
.. code-block:: text
[[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1]]
Args:
size: The size of the bias matrix.
Returns:
CausalBias: The LOWER_RIGHT triangular causal bias variant.
"""
assert len(size) == 2, "causal_lower_right only supports 2D tensors"
seq_len_q, seq_len_kv = size
return CausalBias(CausalVariant.LOWER_RIGHT, seq_len_q, seq_len_kv)
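# Illustrative usage sketch (annotation, not part of the file above): a lower-right
# causal bias for seq_len_q=3, seq_len_kv=4 and its use with SDPA on CPU (where it
# falls back to materializing the mask).
import torch
import torch.nn.functional as F

bias = causal_lower_right(3, 4)
expected = torch.tensor(
    [[1, 1, 0, 0],
     [1, 1, 1, 0],
     [1, 1, 1, 1]],
    dtype=torch.bool,
)
assert torch.equal(bias._materialize(), expected)

q = torch.randn(2, 8, 3, 16)
k = torch.randn(2, 8, 4, 16)
v = torch.randn(2, 8, 4, 16)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
assert out.shape == (2, 8, 3, 16)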

File diff suppressed because it is too large.


@@ -0,0 +1,6 @@
# mypy: allow-untyped-defs
# This is for historical pickle deserialization; it is not used otherwise.
def _get_thnn_function_backend():
pass


@@ -0,0 +1,44 @@
from typing import Optional, Tuple, TypeVar, Union
from torch import Tensor
# Create some useful type aliases
# Template for arguments which can be supplied as a tuple, or which can be a scalar which PyTorch will internally
# broadcast to a tuple.
# Comes in several variants: A tuple of unknown size, and a fixed-size tuple for 1d, 2d, or 3d operations.
T = TypeVar("T")
_scalar_or_tuple_any_t = Union[T, Tuple[T, ...]]
_scalar_or_tuple_1_t = Union[T, Tuple[T]]
_scalar_or_tuple_2_t = Union[T, Tuple[T, T]]
_scalar_or_tuple_3_t = Union[T, Tuple[T, T, T]]
_scalar_or_tuple_4_t = Union[T, Tuple[T, T, T, T]]
_scalar_or_tuple_5_t = Union[T, Tuple[T, T, T, T, T]]
_scalar_or_tuple_6_t = Union[T, Tuple[T, T, T, T, T, T]]
# For arguments which represent size parameters (e.g., kernel size, padding)
_size_any_t = _scalar_or_tuple_any_t[int]
_size_1_t = _scalar_or_tuple_1_t[int]
_size_2_t = _scalar_or_tuple_2_t[int]
_size_3_t = _scalar_or_tuple_3_t[int]
_size_4_t = _scalar_or_tuple_4_t[int]
_size_5_t = _scalar_or_tuple_5_t[int]
_size_6_t = _scalar_or_tuple_6_t[int]
# For arguments which represent optional size parameters (e.g., adaptive pool parameters)
_size_any_opt_t = _scalar_or_tuple_any_t[Optional[int]]
_size_2_opt_t = _scalar_or_tuple_2_t[Optional[int]]
_size_3_opt_t = _scalar_or_tuple_3_t[Optional[int]]
# For arguments that represent a ratio to adjust each dimension of an input with (e.g., upsampling parameters)
_ratio_2_t = _scalar_or_tuple_2_t[float]
_ratio_3_t = _scalar_or_tuple_3_t[float]
_ratio_any_t = _scalar_or_tuple_any_t[float]
_tensor_list_t = _scalar_or_tuple_any_t[Tensor]
# For the return value of max pooling operations that may or may not return indices.
# With the proposed 'Literal' feature to Python typing, it might be possible to
# eventually eliminate this.
_maybe_indices_t = _scalar_or_tuple_2_t[Tensor]
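# Illustrative sketch (annotation, not part of the file above): a parameter typed
# _size_2_t accepts either a single int or a 2-tuple; the helper _to_pair below is
# hypothetical and only shows the usual normalization pattern.
def _to_pair(v: _size_2_t) -> Tuple[int, int]:
    return v if isinstance(v, tuple) else (v, v)

assert _to_pair(3) == (3, 3)
assert _to_pair((2, 5)) == (2, 5)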


@@ -0,0 +1,89 @@
# mypy: allow-untyped-defs
"""Functionality for Python <-> C++ frontend inter-op."""
from torch import nn
class OrderedDictWrapper:
"""A wrapper around a C++ OrderedDict.
It dynamically evaluates the OrderedDict getter on a bound C++ module, such
that new changes on the C++ side are picked up. Otherwise accessing e.g.
``cpp_module._parameters`` just once would get a frozen copy of the parameters
at the time of access. ``torch.nn.Module`` accesses ``_parameters`` et al. via ``self.__dict__``
so using properties does not work.
"""
def __init__(self, cpp_module, attr):
self.cpp_module = cpp_module
self.attr = attr
@property
def cpp_dict(self):
return getattr(self.cpp_module, self.attr)
# Magic methods cannot be assigned dynamically and bypass ``getattr``, so we
# must manually override them.
def items(self):
return self.cpp_dict.items()
def keys(self):
return self.cpp_dict.keys()
def values(self):
return self.cpp_dict.values()
def __iter__(self):
return self.cpp_dict.__iter__()
def __len__(self):
return self.cpp_dict.__len__()
def __contains__(self, key):
return self.cpp_dict.__contains__(key)
def __getitem__(self, key):
return self.cpp_dict.__getitem__(key)
class ModuleWrapper(nn.Module):
"""A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access."""
def __init__(self, cpp_module):
# Assign before the super class constructor so ``self.training`` can be
# assigned to in the super class constructor.
self.cpp_module = cpp_module
super().__init__()
self._parameters = OrderedDictWrapper(cpp_module, "_parameters") # type: ignore[assignment]
self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers") # type: ignore[assignment]
self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules") # type: ignore[assignment]
for attr in dir(cpp_module):
# Skip magic methods and the three attributes above.
if not attr.startswith("_"):
setattr(self, attr, getattr(self.cpp_module, attr))
def _apply(self, fn, recurse=True):
for param in self.parameters():
# Tensors stored in modules are graph leaves, and we don't
# want to create copy nodes, so we have to unpack the data.
param.data = fn(param.data)
if param._grad is not None:
param._grad.data = fn(param._grad.data)
for buf in self.buffers():
buf.data = fn(buf.data)
return self
# nn.Module defines training as a boolean
@property # type: ignore[override]
def training(self):
return self.cpp_module.training
@training.setter
def training(self, mode):
self.cpp_module.train(mode)
def __repr__(self):
return self.cpp_module.__repr__()
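# Illustrative sketch (annotation, not part of the file above): OrderedDictWrapper
# only needs an object exposing the named dict attribute, so a plain Python
# stand-in (the hypothetical _FakeCppModule) is enough to show the pass-through.
class _FakeCppModule:
    def __init__(self):
        self._parameters = {"weight": 1.0, "bias": 0.0}

wrapped = OrderedDictWrapper(_FakeCppModule(), "_parameters")
assert list(wrapped.keys()) == ["weight", "bias"]
assert "weight" in wrapped and wrapped["weight"] == 1.0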

File diff suppressed because it is too large.


@@ -0,0 +1,691 @@
# @generated by tools/pyi/gen_pyi.py from torch/nn/functional.pyi.in
# mypy: allow-untyped-defs
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
overload,
Sequence,
Tuple,
Union,
)
from torch import Tensor
from torch.types import _dtype, _int, _size
from .common_types import (
_ratio_any_t,
_size_1_t,
_size_2_opt_t,
_size_2_t,
_size_3_opt_t,
_size_3_t,
_size_any_t,
)
# 'TypedDict' is a new accepted type that represents a dictionary with a fixed set of allowed keys.
# It is standards-track but not in `typing` yet. We leave this here to be uncommented once the feature
# is widespread.
# from mypy_extensions import TypedDict
# GRID_SAMPLE_INTERPOLATION_MODES = TypedDict('GRID_SAMPLE_INTERPOLATION_MODES', {'bilinear': int, 'nearest': int})
# GRID_SAMPLE_PADDING_MODES = TypedDict('GRID_SAMPLE_PADDING_MODES', {'zeros': int, 'border': int, 'reflection': int})
GRID_SAMPLE_INTERPOLATION_MODES = Dict[str, int]
GRID_SAMPLE_PADDING_MODES = Dict[str, int]
# These stubs were generated by running stubgen (`stubgen --parse-only functional.py`), followed by manual cleaning.
#
# The 'BroadcastingList{1,2,3}' types were replaced by `_size` or _output_ratio, as appropriate.
# This was necessary since the JIT uses BroadcastingList* types but static checking with mypy etc requires a `Sequence`
# type. There is no way to express the expected lengths of these lists in the current Python typing system.
#
# Functions created via `_add_docstr` in `functional.py` were merely typed as `Any` by `stubgen`, so those were
# deleted from the stub and replaced by generated declarations. See `gen_pyi` for the implementation of the code
# generation logic for those functions. In the future, it might be worth looking into using the mypy plugin system
# to encode the type semantics of `_add_docstr`, should that system ever become widespread.
def fractional_max_pool2d_with_indices(
input: Tensor,
kernel_size: _size,
output_size: Optional[_size] = ...,
output_ratio: Optional[_ratio_any_t] = ...,
return_indices: bool = ...,
_random_samples: Optional[Tensor] = ...,
) -> Tuple[Tensor, Tensor]: ...
def fractional_max_pool3d_with_indices(
input: Tensor,
kernel_size: _size,
output_size: Optional[_size] = ...,
output_ratio: Optional[_ratio_any_t] = ...,
return_indices: bool = ...,
_random_samples: Optional[Tensor] = ...,
) -> Tuple[Tensor, Tensor]: ...
def max_pool1d_with_indices(
input: Tensor,
kernel_size: _size,
stride: Optional[_size] = ...,
padding: _size = ...,
dilation: _size = ...,
ceil_mode: bool = ...,
return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def max_pool2d_with_indices(
input: Tensor,
kernel_size: _size,
stride: Optional[_size] = ...,
padding: _size = ...,
dilation: _size = ...,
ceil_mode: bool = ...,
return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def max_pool3d_with_indices(
input: Tensor,
kernel_size: _size,
stride: Optional[_size] = ...,
padding: _size = ...,
dilation: _size = ...,
ceil_mode: bool = ...,
return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def max_unpool1d(
input: Tensor,
indices: Tensor,
kernel_size: _size,
stride: Optional[_size] = ...,
padding: _size = ...,
output_size: Optional[_size] = ...,
) -> Tensor: ...
def max_unpool2d(
input: Tensor,
indices: Tensor,
kernel_size: _size,
stride: Optional[_size] = ...,
padding: _size = ...,
output_size: Optional[_size] = ...,
) -> Tensor: ...
def max_unpool3d(
input: Tensor,
indices: Tensor,
kernel_size: _size,
stride: Optional[_size] = ...,
padding: _size = ...,
output_size: Optional[_size] = ...,
) -> Tensor: ...
def lp_pool1d(
input: Tensor,
norm_type: float,
kernel_size: _size_1_t,
stride: Union[Optional[_size], Optional[int]] = ...,
ceil_mode: bool = ...,
) -> Tensor: ...
def lp_pool2d(
input: Tensor,
norm_type: float,
kernel_size: _size_2_t,
stride: Union[Optional[_size], Optional[int]] = ...,
ceil_mode: bool = ...,
) -> Tensor: ...
def lp_pool3d(
input: Tensor,
norm_type: float,
kernel_size: _size_3_t,
stride: Union[Optional[_size], Optional[int]] = ...,
ceil_mode: bool = ...,
) -> Tensor: ...
def adaptive_max_pool1d_with_indices(
input: Tensor,
output_size: _size,
return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def adaptive_max_pool2d_with_indices(
input: Tensor,
output_size: _size_2_opt_t,
return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def adaptive_max_pool3d_with_indices(
input: Tensor,
output_size: _size_3_opt_t,
return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
def adaptive_avg_pool2d(input: Tensor, output_size: _size_2_opt_t) -> Tensor: ...
def adaptive_avg_pool3d(input: Tensor, output_size: _size_3_opt_t) -> Tensor: ...
def dropout(
input: Tensor,
p: float = ...,
training: bool = ...,
inplace: bool = ...,
) -> Tensor: ...
def alpha_dropout(
input: Tensor,
p: float = ...,
training: bool = ...,
inplace: bool = ...,
) -> Tensor: ...
def dropout1d(
input: Tensor,
p: float = ...,
training: bool = ...,
inplace: bool = ...,
) -> Tensor: ...
def dropout2d(
input: Tensor,
p: float = ...,
training: bool = ...,
inplace: bool = ...,
) -> Tensor: ...
def dropout3d(
input: Tensor,
p: float = ...,
training: bool = ...,
inplace: bool = ...,
) -> Tensor: ...
def feature_alpha_dropout(
input: Tensor,
p: float = ...,
training: bool = ...,
inplace: bool = ...,
) -> Tensor: ...
def threshold(
input: Tensor,
threshold: float,
value: float,
inplace: bool = ...,
) -> Tensor: ...
def relu(input: Tensor, inplace: bool = ...) -> Tensor: ...
def glu(input: Tensor, dim: int = ...) -> Tensor: ...
def hardtanh(
input: Tensor,
min_val: float = ...,
max_val: float = ...,
inplace: bool = ...,
) -> Tensor: ...
def relu6(input: Tensor, inplace: bool = ...) -> Tensor: ...
def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
def selu(input: Tensor, inplace: bool = ...) -> Tensor: ...
def celu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
def leaky_relu(
input: Tensor,
negative_slope: float = ...,
inplace: bool = ...,
) -> Tensor: ...
def rrelu(
input: Tensor,
lower: float = ...,
upper: float = ...,
training: bool = ...,
inplace: bool = ...,
) -> Tensor: ...
def tanhshrink(input: Any): ...
def softsign(input: Any): ...
def softmin(
input: Tensor,
dim: Optional[int] = ...,
_stacklevel: int = ...,
dtype: Optional[_dtype] = ...,
) -> Tensor: ...
def softmax(
input: Tensor,
dim: Optional[int] = ...,
_stacklevel: int = ...,
dtype: Optional[_dtype] = ...,
) -> Tensor: ...
def gumbel_softmax(
logits: Tensor,
tau: float = ...,
hard: bool = ...,
eps: float = ...,
dim: int = ...,
) -> Tensor: ...
def log_softmax(
input: Tensor,
dim: Optional[int] = ...,
_stacklevel: int = ...,
dtype: Optional[_dtype] = ...,
) -> Tensor: ...
def tanh(input: Any): ...
def sigmoid(input: Any) -> Tensor: ...
def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: ...
def silu(input: Tensor, inplace: bool = False) -> Tensor: ...
def mish(input: Tensor, inplace: bool = False) -> Tensor: ...
def hardswish(input: Tensor, inplace: bool = False) -> Tensor: ...
def embedding(
input: Tensor,
weight: Tensor,
padding_idx: Optional[int] = ...,
max_norm: Optional[float] = ...,
norm_type: float = ...,
scale_grad_by_freq: bool = ...,
sparse: bool = ...,
) -> Tensor: ...
def embedding_bag(
input: Tensor,
weight: Tensor,
offsets: Optional[Tensor] = ...,
max_norm: Optional[float] = ...,
norm_type: float = ...,
scale_grad_by_freq: bool = ...,
mode: str = ...,
sparse: bool = ...,
per_sample_weights: Optional[Tensor] = ...,
include_last_offset: bool = ...,
padding_idx: Optional[int] = ...,
) -> Tensor: ...
def batch_norm(
input: Tensor,
running_mean: Optional[Tensor],
running_var: Optional[Tensor],
weight: Optional[Tensor] = ...,
bias: Optional[Tensor] = ...,
training: bool = ...,
momentum: float = ...,
eps: float = ...,
) -> Tensor: ...
def instance_norm(
input: Tensor,
running_mean: Optional[Tensor] = ...,
running_var: Optional[Tensor] = ...,
weight: Optional[Tensor] = ...,
bias: Optional[Tensor] = ...,
use_input_stats: bool = ...,
momentum: float = ...,
eps: float = ...,
) -> Tensor: ...
def layer_norm(
input: Tensor,
normalized_shape: Sequence[int],
weight: Optional[Tensor] = ...,
bias: Optional[Tensor] = ...,
eps: float = ...,
) -> Tensor: ...
def rms_norm(
input: Tensor,
normalized_shape: Sequence[int],
weight: Optional[Tensor] = ...,
eps: Optional[float] = ...,
) -> Tensor: ...
def group_norm(
input: Tensor,
num_groups: int,
weight: Optional[Tensor] = ...,
bias: Optional[Tensor] = ...,
eps: float = ...,
) -> Tensor: ...
def local_response_norm(
input: Tensor,
size: int,
alpha: float = ...,
beta: float = ...,
k: float = ...,
) -> Tensor: ...
def ctc_loss(
log_probs: Tensor,
targets: Tensor,
input_lengths: Tensor,
target_lengths: Tensor,
blank: int = ...,
reduction: str = ...,
zero_infinity: bool = ...,
) -> Tensor: ...
def nll_loss(
input: Tensor,
target: Tensor,
weight: Optional[Tensor] = ...,
size_average: Optional[bool] = ...,
ignore_index: int = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
) -> Tensor: ...
def poisson_nll_loss(
input: Tensor,
target: Tensor,
log_input: bool = ...,
full: bool = ...,
size_average: Optional[bool] = ...,
eps: float = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
) -> Tensor: ...
def gaussian_nll_loss(
input: Tensor,
target: Tensor,
var: Tensor,
full: Optional[bool] = ...,
eps: Optional[float] = ...,
reduction: Optional[str] = ...,
) -> Tensor: ...
def kl_div(
input: Tensor,
target: Tensor,
size_average: Optional[bool] = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
log_target: bool = ...,
) -> Tensor: ...
def cross_entropy(
input: Tensor,
target: Tensor,
weight: Optional[Tensor] = ...,
size_average: Optional[bool] = ...,
ignore_index: int = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
label_smoothing: float = ...,
) -> Tensor: ...
def binary_cross_entropy(
input: Tensor,
target: Tensor,
weight: Optional[Tensor] = ...,
size_average: Optional[bool] = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
) -> Tensor: ...
def binary_cross_entropy_with_logits(
input: Tensor,
target: Tensor,
weight: Optional[Tensor] = ...,
size_average: Optional[bool] = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
pos_weight: Optional[Tensor] = ...,
) -> Tensor: ...
def smooth_l1_loss(
input: Tensor,
target: Tensor,
size_average: Optional[bool] = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
beta: float = ...,
) -> Tensor: ...
def huber_loss(
input: Tensor,
target: Tensor,
reduction: str = ...,
delta: float = ...,
) -> Tensor: ...
def l1_loss(
input: Tensor,
target: Tensor,
size_average: Optional[bool] = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
) -> Tensor: ...
def mse_loss(
input: Tensor,
target: Tensor,
size_average: Optional[bool] = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
) -> Tensor: ...
def margin_ranking_loss(
input1: Tensor,
input2: Tensor,
target: Tensor,
margin: float = ...,
size_average: Optional[bool] = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
) -> Tensor: ...
def hinge_embedding_loss(
input: Tensor,
target: Tensor,
margin: float = ...,
size_average: Optional[bool] = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
) -> Tensor: ...
def multilabel_margin_loss(
input: Tensor,
target: Tensor,
size_average: Optional[bool] = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
) -> Tensor: ...
def soft_margin_loss(
input: Tensor,
target: Tensor,
size_average: Optional[bool] = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
) -> Tensor: ...
def multilabel_soft_margin_loss(
input: Tensor,
target: Tensor,
weight: Optional[Tensor] = ...,
size_average: Optional[bool] = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
) -> Tensor: ...
def cosine_embedding_loss(
input1: Tensor,
input2: Tensor,
target: Tensor,
margin: float = ...,
size_average: Optional[bool] = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
) -> Tensor: ...
def multi_margin_loss(
input: Tensor,
target: Tensor,
p: int = ...,
margin: float = ...,
weight: Optional[Tensor] = ...,
size_average: Optional[bool] = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
) -> Tensor: ...
def upsample(
input: Any,
size: Optional[Any] = ...,
scale_factor: Optional[Any] = ...,
mode: str = ...,
align_corners: Optional[Any] = ...,
): ...
def interpolate(
input: Any,
size: Optional[Any] = ...,
scale_factor: Optional[Any] = ...,
mode: str = ...,
align_corners: Optional[Any] = ...,
recompute_scale_factor: Optional[Any] = ...,
antialias: bool = ...,
): ...
def upsample_nearest(
input: Any,
size: Optional[Any] = ...,
scale_factor: Optional[Any] = ...,
): ...
def upsample_bilinear(
input: Any,
size: Optional[Any] = ...,
scale_factor: Optional[Any] = ...,
): ...
def grid_sample(
input: Tensor,
grid: Tensor,
mode: str = ...,
padding_mode: str = ...,
align_corners: Optional[Any] = ...,
) -> Tensor: ...
def affine_grid(
theta: Tensor,
size: List[int],
align_corners: Optional[Any] = ...,
) -> Tensor: ...
def triplet_margin_loss(
anchor: Tensor,
positive: Tensor,
negative: Tensor,
margin: float = ...,
p: float = ...,
eps: float = ...,
swap: bool = ...,
size_average: Optional[bool] = ...,
reduce: Optional[bool] = ...,
reduction: str = ...,
) -> Tensor: ...
def triplet_margin_with_distance_loss(
anchor: Tensor,
positive: Tensor,
negative: Tensor,
*,
distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ...,
margin: float = ...,
swap: bool = ...,
reduction: str = ...,
) -> Tensor: ...
def normalize(
input: Tensor,
p: float = ...,
dim: int = ...,
eps: float = ...,
out: Optional[Tensor] = ...,
) -> Tensor: ...
def assert_int_or_pair(
arg: Any,
arg_name: Any,
message: Any,
) -> None: ...
def unfold(
input: Tensor,
kernel_size: _size_any_t,
dilation: _size_any_t = ...,
padding: _size_any_t = ...,
stride: _size_any_t = ...,
) -> Tensor: ...
def fold(
input: Tensor,
output_size: _size_any_t,
kernel_size: _size_any_t,
dilation: _size_any_t = ...,
padding: _size_any_t = ...,
stride: _size_any_t = ...,
) -> Tensor: ...
def _canonical_mask(
mask: Optional[Tensor],
mask_name: str,
other_type: Optional[_dtype],
other_name: str,
target_type: _dtype,
check_other: bool = True,
) -> Optional[Tensor]: ...
def _none_or_dtype(input: Optional[Tensor]) -> Optional[_dtype]: ...
def multi_head_attention_forward(
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Optional[Tensor],
in_proj_bias: Optional[Tensor],
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Optional[Tensor],
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]: ...
from torch import conv1d as conv1d
from torch import conv2d as conv2d
from torch import conv3d as conv3d
from torch import conv_transpose1d as conv_transpose1d
from torch import conv_transpose2d as conv_transpose2d
from torch import conv_transpose3d as conv_transpose3d
from torch import conv_tbc as conv_tbc
from torch import avg_pool1d as avg_pool1d
from torch import adaptive_avg_pool1d as adaptive_avg_pool1d
from torch import relu_ as relu_
from torch import selu_ as selu_
from torch import celu_ as celu_
from torch import prelu as prelu
from torch import rrelu_ as rrelu_
from torch import hardshrink as hardshrink
from torch import bilinear as bilinear
from torch import pixel_shuffle as pixel_shuffle
from torch import pixel_unshuffle as pixel_unshuffle
from torch import channel_shuffle as channel_shuffle
from torch import native_channel_shuffle as native_channel_shuffle
from torch import pairwise_distance as pairwise_distance
from torch import pdist as pdist
from torch import cosine_similarity as cosine_similarity
from torch._C._nn import avg_pool2d as avg_pool2d
from torch._C._nn import avg_pool3d as avg_pool3d
from torch._C._nn import hardtanh_ as hardtanh_
from torch._C._nn import elu_ as elu_
from torch._C._nn import leaky_relu_ as leaky_relu_
from torch._C._nn import gelu as gelu
from torch._C._nn import softplus as softplus
from torch._C._nn import softshrink as softshrink
from torch._C._nn import linear as linear
from torch._C._nn import pad as pad
from torch._C._nn import one_hot as one_hot
from torch._C._nn import scaled_dot_product_attention as scaled_dot_product_attention
from torch._C._nn import log_sigmoid
logsigmoid = log_sigmoid
@overload
def adaptive_max_pool1d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[False] = False) -> Tensor: ...
@overload
def adaptive_max_pool1d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ...
@overload
def adaptive_max_pool1d(input: Tensor, output_size: Union[_int, _size], *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ...
@overload
def adaptive_max_pool2d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[False] = False) -> Tensor: ...
@overload
def adaptive_max_pool2d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ...
@overload
def adaptive_max_pool2d(input: Tensor, output_size: Union[_int, _size], *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ...
@overload
def adaptive_max_pool3d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[False] = False) -> Tensor: ...
@overload
def adaptive_max_pool3d(input: Tensor, output_size: Union[_int, _size], return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ...
@overload
def adaptive_max_pool3d(input: Tensor, output_size: Union[_int, _size], *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ...
@overload
def fractional_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]] = None, output_ratio: Optional[_ratio_any_t] = None, return_indices: Literal[False] = False, _random_samples: Optional[Tensor] = None) -> Tensor: ...
@overload
def fractional_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]], output_ratio: Optional[_ratio_any_t], return_indices: Literal[True], /, _random_samples: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ...
@overload
def fractional_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]] = None, output_ratio: Optional[_ratio_any_t] = None, *, return_indices: Literal[True], _random_samples: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ...
@overload
def fractional_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]] = None, output_ratio: Optional[_ratio_any_t] = None, return_indices: Literal[False] = False, _random_samples: Optional[Tensor] = None) -> Tensor: ...
@overload
def fractional_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]], output_ratio: Optional[_ratio_any_t], return_indices: Literal[True], /, _random_samples: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ...
@overload
def fractional_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], output_size: Optional[Union[_int, _size]] = None, output_ratio: Optional[_ratio_any_t] = None, *, return_indices: Literal[True], _random_samples: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: ...
@overload
def max_pool1d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, return_indices: Literal[False] = False) -> Tensor: ...
@overload
def max_pool1d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]], padding: Union[_int, _size], dilation: Union[_int, _size], ceil_mode: bool, return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ...
@overload
def max_pool1d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ...
@overload
def max_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, return_indices: Literal[False] = False) -> Tensor: ...
@overload
def max_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]], padding: Union[_int, _size], dilation: Union[_int, _size], ceil_mode: bool, return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ...
@overload
def max_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ...
@overload
def max_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, return_indices: Literal[False] = False) -> Tensor: ...
@overload
def max_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]], padding: Union[_int, _size], dilation: Union[_int, _size], ceil_mode: bool, return_indices: Literal[True], /) -> Tuple[Tensor, Tensor]: ...
@overload
def max_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, dilation: Union[_int, _size] = 1, ceil_mode: bool = False, *, return_indices: Literal[True]) -> Tuple[Tensor, Tensor]: ...
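# Illustrative sketch (annotation, not part of the stub above): the return_indices
# overloads describe runtime behavior like this.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
y = F.max_pool2d(x, kernel_size=2)                             # -> Tensor
y2, idx = F.max_pool2d(x, kernel_size=2, return_indices=True)  # -> Tuple[Tensor, Tensor]
assert y.shape == (1, 3, 4, 4) and idx.shape == (1, 3, 4, 4)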


@@ -0,0 +1,298 @@
# mypy: allow-untyped-defs
"""Gradient interface."""
import torch
from torch.nn.modules.utils import _pair, _single, _triple
def conv1d_input(
input_size,
weight,
grad_output,
stride=1,
padding=0,
dilation=1,
groups=1,
):
r"""Compute the gradient of conv1d with respect to the input of the convolution.
    This is the same as the 1D transposed convolution operator under the hood but requires
the shape of the gradient w.r.t. input to be specified explicitly.
Args:
input_size : Shape of the input gradient tensor
weight: weight tensor (out_channels x in_channels/groups x kW)
grad_output : output gradient tensor (minibatch x out_channels x oW)
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
Examples::
>>> input = torch.randn(1, 1, 3, requires_grad=True)
>>> weight = torch.randn(1, 1, 1, requires_grad=True)
>>> output = F.conv1d(input, weight)
>>> grad_output = torch.randn(output.shape)
>>> grad_input = torch.autograd.grad(output, input, grad_output)
>>> F.grad.conv1d_input(input.shape, weight, grad_output)
"""
input = grad_output.new_empty(1).expand(input_size)
return torch.ops.aten.convolution_backward(
grad_output,
input,
weight,
None,
_single(stride),
_single(padding),
_single(dilation),
False,
[0],
groups,
(True, False, False),
)[0]
def conv1d_weight(
input,
weight_size,
grad_output,
stride=1,
padding=0,
dilation=1,
groups=1,
):
r"""Compute the gradient of conv1d with respect to the weight of the convolution.
Args:
input: input tensor of shape (minibatch x in_channels x iW)
weight_size : Shape of the weight gradient tensor
grad_output : output gradient tensor (minibatch x out_channels x oW)
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
Examples::
>>> input = torch.randn(1, 1, 3, requires_grad=True)
>>> weight = torch.randn(1, 1, 1, requires_grad=True)
>>> output = F.conv1d(input, weight)
>>> grad_output = torch.randn(output.shape)
>>> # xdoctest: +SKIP
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
>>> F.grad.conv1d_weight(input, weight.shape, grad_output)
"""
weight = grad_output.new_empty(1).expand(weight_size)
return torch.ops.aten.convolution_backward(
grad_output,
input,
weight,
None,
_single(stride),
_single(padding),
_single(dilation),
False,
[0],
groups,
(False, True, False),
)[1]
def conv2d_input(
input_size,
weight,
grad_output,
stride=1,
padding=0,
dilation=1,
groups=1,
):
r"""Compute the gradient of conv2d with respect to the input of the convolution.
    This is the same as the 2D transposed convolution operator under the hood but requires
the shape of the gradient w.r.t. input to be specified explicitly.
Args:
input_size : Shape of the input gradient tensor
weight: weight tensor (out_channels x in_channels/groups x kH x kW)
grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
Examples::
>>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
>>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
>>> output = F.conv2d(input, weight)
>>> grad_output = torch.randn(output.shape)
>>> grad_input = torch.autograd.grad(output, input, grad_output)
>>> F.grad.conv2d_input(input.shape, weight, grad_output)
"""
input = grad_output.new_empty(1).expand(input_size)
return torch.ops.aten.convolution_backward(
grad_output,
input,
weight,
None,
_pair(stride),
_pair(padding),
_pair(dilation),
False,
[0],
groups,
(True, False, False),
)[0]
def conv2d_weight(
input,
weight_size,
grad_output,
stride=1,
padding=0,
dilation=1,
groups=1,
):
r"""Compute the gradient of conv2d with respect to the weight of the convolution.
Args:
input: input tensor of shape (minibatch x in_channels x iH x iW)
weight_size : Shape of the weight gradient tensor
grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
Examples::
>>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
>>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
>>> output = F.conv2d(input, weight)
>>> grad_output = torch.randn(output.shape)
>>> # xdoctest: +SKIP
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
>>> F.grad.conv2d_weight(input, weight.shape, grad_output)
"""
weight = grad_output.new_empty(1).expand(weight_size)
return torch.ops.aten.convolution_backward(
grad_output,
input,
weight,
None,
_pair(stride),
_pair(padding),
_pair(dilation),
False,
[0],
groups,
(False, True, False),
)[1]
def conv3d_input(
input_size,
weight,
grad_output,
stride=1,
padding=0,
dilation=1,
groups=1,
):
r"""Compute the gradient of conv3d with respect to the input of the convolution.
    This is the same as the 3D transposed convolution operator under the hood but requires
the shape of the gradient w.r.t. input to be specified explicitly.
Args:
input_size : Shape of the input gradient tensor
weight: weights tensor (out_channels x in_channels/groups x kT x kH x kW)
grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
Examples::
>>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
>>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
>>> output = F.conv3d(input, weight)
>>> grad_output = torch.randn(output.shape)
>>> grad_input = torch.autograd.grad(output, input, grad_output)
>>> F.grad.conv3d_input(input.shape, weight, grad_output)
"""
input = grad_output.new_empty(1).expand(input_size)
return torch.ops.aten.convolution_backward(
grad_output,
input,
weight,
None,
_triple(stride),
_triple(padding),
_triple(dilation),
False,
[0],
groups,
(True, False, False),
)[0]
def conv3d_weight(
input,
weight_size,
grad_output,
stride=1,
padding=0,
dilation=1,
groups=1,
):
r"""Compute the gradient of conv3d with respect to the weight of the convolution.
Args:
input: input tensor of shape (minibatch x in_channels x iT x iH x iW)
weight_size : Shape of the weight gradient tensor
grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
Examples::
>>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
>>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
>>> output = F.conv3d(input, weight)
>>> grad_output = torch.randn(output.shape)
>>> grad_weight = torch.autograd.grad(output, weight, grad_output)
>>> F.grad.conv3d_weight(input, weight.shape, grad_output)
"""
weight = grad_output.new_empty(1).expand(weight_size)
return torch.ops.aten.convolution_backward(
grad_output,
input,
weight,
None,
_triple(stride),
_triple(padding),
_triple(dilation),
False,
[0],
groups,
(False, True, False),
)[1]
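# Illustrative sketch (annotation, not part of the file above): checking that
# conv2d_input reproduces the input gradient computed by autograd.
import torch
import torch.nn.functional as F

inp = torch.randn(1, 1, 5, 5, requires_grad=True)
weight = torch.randn(2, 1, 3, 3, requires_grad=True)
out = F.conv2d(inp, weight)
grad_out = torch.randn_like(out)
(grad_inp,) = torch.autograd.grad(out, inp, grad_out)
manual = conv2d_input(inp.shape, weight, grad_out)
assert torch.allclose(grad_inp, manual, atol=1e-5)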


@@ -0,0 +1,697 @@
# mypy: allow-untyped-defs
"""This file contains utilities for initializing neural network parameters."""
import math
import warnings
from typing import Optional as _Optional
import torch
from torch import Tensor
# These no_grad_* functions are necessary as wrappers around the parts of these
# functions that use `with torch.no_grad()`. The JIT doesn't support context
# managers, so these need to be implemented as builtins. Using these wrappers
# lets us keep those builtins small and re-usable.
def _no_grad_uniform_(tensor, a, b, generator=None):
with torch.no_grad():
return tensor.uniform_(a, b, generator=generator)
def _no_grad_normal_(tensor, mean, std, generator=None):
with torch.no_grad():
return tensor.normal_(mean, std, generator=generator)
def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def _no_grad_fill_(tensor, val):
with torch.no_grad():
return tensor.fill_(val)
def _no_grad_zero_(tensor):
with torch.no_grad():
return tensor.zero_()
def calculate_gain(nonlinearity, param=None):
r"""Return the recommended gain value for the given nonlinearity function.
The values are as follows:
================= ====================================================
nonlinearity gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D :math:`1`
Sigmoid :math:`1`
Tanh :math:`\frac{5}{3}`
ReLU :math:`\sqrt{2}`
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
SELU :math:`\frac{3}{4}`
================= ====================================================
.. warning::
In order to implement `Self-Normalizing Neural Networks`_ ,
you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
This gives the initial weights a variance of ``1 / N``,
which is necessary to induce a stable fixed point in the forward pass.
In contrast, the default gain for ``SELU`` sacrifices the normalization
effect for more stable gradient flow in rectangular layers.
Args:
nonlinearity: the non-linear function (`nn.functional` name)
param: optional parameter for the non-linear function
Examples:
>>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
.. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
"""
linear_fns = [
"linear",
"conv1d",
"conv2d",
"conv3d",
"conv_transpose1d",
"conv_transpose2d",
"conv_transpose3d",
]
if nonlinearity in linear_fns or nonlinearity == "sigmoid":
return 1
elif nonlinearity == "tanh":
return 5.0 / 3
elif nonlinearity == "relu":
return math.sqrt(2.0)
elif nonlinearity == "leaky_relu":
if param is None:
negative_slope = 0.01
elif (
not isinstance(param, bool)
and isinstance(param, int)
or isinstance(param, float)
):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError(f"negative_slope {param} not a valid number")
return math.sqrt(2.0 / (1 + negative_slope**2))
elif nonlinearity == "selu":
return (
3.0 / 4
) # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
else:
raise ValueError(f"Unsupported nonlinearity {nonlinearity}")
def uniform_(
tensor: Tensor,
a: float = 0.0,
b: float = 1.0,
generator: _Optional[torch.Generator] = None,
) -> Tensor:
r"""Fill the input Tensor with values drawn from the uniform distribution.
:math:`\mathcal{U}(a, b)`.
Args:
tensor: an n-dimensional `torch.Tensor`
a: the lower bound of the uniform distribution
b: the upper bound of the uniform distribution
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.uniform_(w)
"""
if torch.overrides.has_torch_function_variadic(tensor):
return torch.overrides.handle_torch_function(
uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator
)
return _no_grad_uniform_(tensor, a, b, generator)
def normal_(
tensor: Tensor,
mean: float = 0.0,
std: float = 1.0,
generator: _Optional[torch.Generator] = None,
) -> Tensor:
r"""Fill the input Tensor with values drawn from the normal distribution.
:math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.normal_(w)
"""
if torch.overrides.has_torch_function_variadic(tensor):
return torch.overrides.handle_torch_function(
normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator
)
return _no_grad_normal_(tensor, mean, std, generator)
def trunc_normal_(
tensor: Tensor,
mean: float = 0.0,
std: float = 1.0,
a: float = -2.0,
b: float = 2.0,
generator: _Optional[torch.Generator] = None,
) -> Tensor:
r"""Fill the input Tensor with values drawn from a truncated normal distribution.
The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)
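# --- Editor's illustrative sketch (not part of the original module) ---------
# `_trunc_normal_example` is a hypothetical helper: because of the final
# clamp_ in _no_grad_trunc_normal_, every sample must lie inside [a, b].
def _trunc_normal_example() -> Tensor:
    t = torch.empty(10_000)
    trunc_normal_(t, mean=0.0, std=1.0, a=-2.0, b=2.0)
    assert t.min().item() >= -2.0 and t.max().item() <= 2.0
    return t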
def constant_(tensor: Tensor, val: float) -> Tensor:
r"""Fill the input Tensor with the value :math:`\text{val}`.
Args:
tensor: an n-dimensional `torch.Tensor`
val: the value to fill the tensor with
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.constant_(w, 0.3)
"""
if torch.overrides.has_torch_function_variadic(tensor):
return torch.overrides.handle_torch_function(
constant_, (tensor,), tensor=tensor, val=val
)
return _no_grad_fill_(tensor, val)
def ones_(tensor: Tensor) -> Tensor:
r"""Fill the input Tensor with the scalar value `1`.
Args:
tensor: an n-dimensional `torch.Tensor`
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.ones_(w)
"""
return _no_grad_fill_(tensor, 1.0)
def zeros_(tensor: Tensor) -> Tensor:
r"""Fill the input Tensor with the scalar value `0`.
Args:
tensor: an n-dimensional `torch.Tensor`
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.zeros_(w)
"""
return _no_grad_zero_(tensor)
def eye_(tensor):
r"""Fill the 2-dimensional input `Tensor` with the identity matrix.
Preserves the identity of the inputs in `Linear` layers, where as
many inputs are preserved as possible.
Args:
tensor: a 2-dimensional `torch.Tensor`
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.eye_(w)
"""
if tensor.ndimension() != 2:
raise ValueError("Only tensors with 2 dimensions are supported")
with torch.no_grad():
torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad)
return tensor
def dirac_(tensor, groups=1):
r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function.
Preserves the identity of the inputs in `Convolutional`
layers, where as many input channels are preserved as possible. In case
of groups > 1, each group of channels preserves identity.
Args:
tensor: a {3, 4, 5}-dimensional `torch.Tensor`
groups (int, optional): number of groups in the conv layer (default: 1)
Examples:
>>> w = torch.empty(3, 16, 5, 5)
>>> nn.init.dirac_(w)
>>> w = torch.empty(3, 24, 5, 5)
>>> nn.init.dirac_(w, 3)
"""
dimensions = tensor.ndimension()
if dimensions not in [3, 4, 5]:
raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")
sizes = tensor.size()
if sizes[0] % groups != 0:
raise ValueError("dim 0 must be divisible by groups")
out_chans_per_grp = sizes[0] // groups
min_dim = min(out_chans_per_grp, sizes[1])
with torch.no_grad():
tensor.zero_()
for g in range(groups):
for d in range(min_dim):
if dimensions == 3: # Temporal convolution
tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1
elif dimensions == 4: # Spatial convolution
tensor[
g * out_chans_per_grp + d,
d,
tensor.size(2) // 2,
tensor.size(3) // 2,
] = 1
else: # Volumetric convolution
tensor[
g * out_chans_per_grp + d,
d,
tensor.size(2) // 2,
tensor.size(3) // 2,
tensor.size(4) // 2,
] = 1
return tensor
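# --- Editor's illustrative sketch (not part of the original module) ---------
# `_dirac_example` is a hypothetical helper: a conv1d whose weight was
# initialized with dirac_ acts as the identity on its input channels
# (a same-channel delta kernel, zeros everywhere else).
def _dirac_example() -> None:
    w = torch.empty(4, 4, 3)
    dirac_(w)
    x = torch.randn(2, 4, 10)
    y = torch.nn.functional.conv1d(x, w, padding=1)
    assert torch.allclose(y, x)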
def _calculate_fan_in_and_fan_out(tensor):
dimensions = tensor.dim()
if dimensions < 2:
raise ValueError(
"Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
)
num_input_fmaps = tensor.size(1)
num_output_fmaps = tensor.size(0)
receptive_field_size = 1
if tensor.dim() > 2:
# math.prod is not always available, accumulate the product manually
# we could use functools.reduce but that is not supported by TorchScript
for s in tensor.shape[2:]:
receptive_field_size *= s
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def xavier_uniform_(
tensor: Tensor,
gain: float = 1.0,
generator: _Optional[torch.Generator] = None,
) -> Tensor:
r"""Fill the input `Tensor` with values using a Xavier uniform distribution.
The method is described in `Understanding the difficulty of training
deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010).
The resulting tensor will have values sampled from
:math:`\mathcal{U}(-a, a)` where
.. math::
a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}
Also known as Glorot initialization.
Args:
tensor: an n-dimensional `torch.Tensor`
gain: an optional scaling factor
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
Note:
Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
that the weight matrix is used in a transposed manner,
(i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
This is important for correct initialization.
If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
pass in a transposed weight matrix, i.e. ``nn.init.xavier_uniform_(w.T, ...)``.
"""
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return _no_grad_uniform_(tensor, -a, a, generator)
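# --- Editor's illustrative sketch (not part of the original module) ---------
# `_xavier_uniform_example` is a hypothetical helper: for a (64, 128) weight,
# fan_in = 128 and fan_out = 64, so every sample stays inside [-a, a] with
# a = gain * sqrt(6 / (fan_in + fan_out)).
def _xavier_uniform_example() -> None:
    w = torch.empty(64, 128)
    xavier_uniform_(w)
    a = math.sqrt(6.0 / (128 + 64))
    assert w.abs().max().item() <= a + 1e-6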
def xavier_normal_(
tensor: Tensor,
gain: float = 1.0,
generator: _Optional[torch.Generator] = None,
) -> Tensor:
r"""Fill the input `Tensor` with values using a Xavier normal distribution.
The method is described in `Understanding the difficulty of training deep feedforward
neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor
will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}
Also known as Glorot initialization.
Args:
tensor: an n-dimensional `torch.Tensor`
gain: an optional scaling factor
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.xavier_normal_(w)
Note:
Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
that the weight matrix is used in a transposed manner,
(i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
This is important for correct initialization.
If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
pass in a transposed weight matrix, i.e. ``nn.init.xavier_normal_(w.T, ...)``.
"""
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
return _no_grad_normal_(tensor, 0.0, std, generator)
def _calculate_correct_fan(tensor, mode):
mode = mode.lower()
valid_modes = ["fan_in", "fan_out"]
if mode not in valid_modes:
raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}")
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == "fan_in" else fan_out
def kaiming_uniform_(
tensor: Tensor,
a: float = 0,
mode: str = "fan_in",
nonlinearity: str = "leaky_relu",
generator: _Optional[torch.Generator] = None,
):
r"""Fill the input `Tensor` with values using a Kaiming uniform distribution.
The method is described in `Delving deep into rectifiers: Surpassing
human-level performance on ImageNet classification` - He, K. et al. (2015).
The resulting tensor will have values sampled from
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
Also known as He initialization.
Args:
tensor: an n-dimensional `torch.Tensor`
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
Note:
Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
that the weight matrix is used in a transposed manner,
(i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
This is important for correct initialization.
If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
pass in a transposed weight matrix, i.e. ``nn.init.kaiming_uniform_(w.T, ...)``.
"""
if torch.overrides.has_torch_function_variadic(tensor):
return torch.overrides.handle_torch_function(
kaiming_uniform_,
(tensor,),
tensor=tensor,
a=a,
mode=mode,
nonlinearity=nonlinearity,
generator=generator,
)
if 0 in tensor.shape:
warnings.warn("Initializing zero-element tensors is a no-op")
return tensor
fan = _calculate_correct_fan(tensor, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
with torch.no_grad():
return tensor.uniform_(-bound, bound, generator=generator)
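# --- Editor's illustrative sketch (not part of the original module) ---------
# `_kaiming_uniform_example` is a hypothetical helper: with mode="fan_in" and
# a ReLU gain, samples are uniform on [-bound, bound] with
# bound = sqrt(3) * gain / sqrt(fan_in), so their std is close to
# gain / sqrt(fan_in).
def _kaiming_uniform_example() -> None:
    w = torch.empty(512, 256)
    kaiming_uniform_(w, mode="fan_in", nonlinearity="relu")
    expected_std = calculate_gain("relu") / math.sqrt(256)
    assert abs(w.std().item() - expected_std) < 0.01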
def kaiming_normal_(
tensor: Tensor,
a: float = 0,
mode: str = "fan_in",
nonlinearity: str = "leaky_relu",
generator: _Optional[torch.Generator] = None,
):
r"""Fill the input `Tensor` with values using a Kaiming normal distribution.
The method is described in `Delving deep into rectifiers: Surpassing
human-level performance on ImageNet classification` - He, K. et al. (2015).
The resulting tensor will have values sampled from
:math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
Also known as He initialization.
Args:
tensor: an n-dimensional `torch.Tensor`
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
Note:
Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
that the weight matrix is used in a transposed manner,
(i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
This is important for correct initialization.
If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
pass in a transposed weight matrix, i.e. ``nn.init.kaiming_normal_(w.T, ...)``.
"""
if 0 in tensor.shape:
warnings.warn("Initializing zero-element tensors is a no-op")
return tensor
fan = _calculate_correct_fan(tensor, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
with torch.no_grad():
return tensor.normal_(0, std, generator=generator)
def orthogonal_(
tensor,
gain=1,
generator: _Optional[torch.Generator] = None,
):
r"""Fill the input `Tensor` with a (semi) orthogonal matrix.
Described in `Exact solutions to the nonlinear dynamics of learning in deep
linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
at least 2 dimensions, and for tensors with more than 2 dimensions the
trailing dimensions are flattened.
Args:
tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
gain: optional scaling factor
generator: the torch Generator to sample from (default: None)
Examples:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
>>> w = torch.empty(3, 5)
>>> nn.init.orthogonal_(w)
"""
if tensor.ndimension() < 2:
raise ValueError("Only tensors with 2 or more dimensions are supported")
if tensor.numel() == 0:
# no-op
return tensor
rows = tensor.size(0)
cols = tensor.numel() // rows
flattened = tensor.new_empty((rows, cols)).normal_(0, 1, generator=generator)
if rows < cols:
flattened.t_()
# Compute the qr factorization
q, r = torch.linalg.qr(flattened)
# Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
d = torch.diag(r, 0)
ph = d.sign()
q *= ph
if rows < cols:
q.t_()
with torch.no_grad():
tensor.view_as(q).copy_(q)
tensor.mul_(gain)
return tensor
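# --- Editor's illustrative sketch (not part of the original module) ---------
# `_orthogonal_example` is a hypothetical helper (it needs a LAPACK-enabled
# build for the QR factorization): when rows <= cols, the rows of the result
# are orthonormal, so w @ w.T is close to the identity.
def _orthogonal_example() -> None:
    w = torch.empty(3, 5)
    orthogonal_(w)
    assert torch.allclose(w @ w.t(), torch.eye(3), atol=1e-5)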
def sparse_(
tensor,
sparsity,
std=0.01,
generator: _Optional[torch.Generator] = None,
):
r"""Fill the 2D input `Tensor` as a sparse matrix.
The non-zero elements will be drawn from the normal distribution
:math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
Hessian-free optimization` - Martens, J. (2010).
Args:
tensor: an n-dimensional `torch.Tensor`
sparsity: The fraction of elements in each column to be set to zero
std: the standard deviation of the normal distribution used to generate
the non-zero values
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.sparse_(w, sparsity=0.1)
"""
if tensor.ndimension() != 2:
raise ValueError("Only tensors with 2 dimensions are supported")
rows, cols = tensor.shape
num_zeros = int(math.ceil(sparsity * rows))
with torch.no_grad():
tensor.normal_(0, std, generator=generator)
for col_idx in range(cols):
row_indices = torch.randperm(rows)
zero_indices = row_indices[:num_zeros]
tensor[zero_indices, col_idx] = 0
return tensor
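# --- Editor's illustrative sketch (not part of the original module) ---------
# `_sparse_example` is a hypothetical helper: sparse_ zeroes
# ceil(sparsity * rows) entries in every column, so each column of a (10, 4)
# matrix with sparsity=0.3 ends up with at least 3 zeros.
def _sparse_example() -> None:
    w = torch.empty(10, 4)
    sparse_(w, sparsity=0.3)
    zeros_per_col = (w == 0).sum(dim=0)
    assert bool((zeros_per_col >= 3).all())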
# for backward compatibility
def _make_deprecate(meth):
new_name = meth.__name__
old_name = new_name[:-1]
def deprecated_init(*args, **kwargs):
warnings.warn(
f"`nn.init.{old_name}` is now deprecated in favor of `nn.init.{new_name}`.",
FutureWarning,
stacklevel=2,
)
return meth(*args, **kwargs)
deprecated_init.__doc__ = rf"""
{old_name}(...)
.. warning::
This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.
See :func:`~torch.nn.init.{new_name}` for details."""
deprecated_init.__name__ = old_name
return deprecated_init
uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)

View File

@ -0,0 +1,36 @@
from torch.ao.nn.intrinsic import (
BNReLU2d,
BNReLU3d,
ConvBn1d,
ConvBn2d,
ConvBn3d,
ConvBnReLU1d,
ConvBnReLU2d,
ConvBnReLU3d,
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
LinearBn1d,
LinearReLU,
)
from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401
# Include the subpackages in case user imports from it directly
from torch.nn.intrinsic import modules, qat, quantized # noqa: F401
__all__ = [
"ConvBn1d",
"ConvBn2d",
"ConvBn3d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBnReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearReLU",
"BNReLU2d",
"BNReLU3d",
"LinearBn1d",
]

View File

@ -0,0 +1,33 @@
from torch.nn.intrinsic.modules.fused import (
_FusedModule,
BNReLU2d,
BNReLU3d,
ConvBn1d,
ConvBn2d,
ConvBn3d,
ConvBnReLU1d,
ConvBnReLU2d,
ConvBnReLU3d,
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
LinearBn1d,
LinearReLU,
)
__all__ = [
"BNReLU2d",
"BNReLU3d",
"ConvBn1d",
"ConvBn2d",
"ConvBn3d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBnReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearBn1d",
"LinearReLU",
]

View File

@ -0,0 +1,33 @@
from torch.ao.nn.intrinsic import (
BNReLU2d,
BNReLU3d,
ConvBn1d,
ConvBn2d,
ConvBn3d,
ConvBnReLU1d,
ConvBnReLU2d,
ConvBnReLU3d,
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
LinearBn1d,
LinearReLU,
)
from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401
__all__ = [
"BNReLU2d",
"BNReLU3d",
"ConvBn1d",
"ConvBn2d",
"ConvBn3d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBnReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearBn1d",
"LinearReLU",
]

View File

@ -0,0 +1 @@
from torch.nn.intrinsic.qat.modules import * # noqa: F403

View File

@ -0,0 +1,32 @@
from torch.nn.intrinsic.qat.modules.conv_fused import (
ConvBn1d,
ConvBn2d,
ConvBn3d,
ConvBnReLU1d,
ConvBnReLU2d,
ConvBnReLU3d,
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
freeze_bn_stats,
update_bn_stats,
)
from torch.nn.intrinsic.qat.modules.linear_fused import LinearBn1d
from torch.nn.intrinsic.qat.modules.linear_relu import LinearReLU
__all__ = [
"LinearReLU",
"LinearBn1d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"ConvBn1d",
"ConvBn2d",
"ConvBn3d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBnReLU3d",
"update_bn_stats",
"freeze_bn_stats",
]

View File

@ -0,0 +1,40 @@
# flake8: noqa: F401
r"""Intrinsic QAT Modules.
This file is in the process of being migrated to `torch/ao/nn/intrinsic/qat` and
is kept here for compatibility while the migration is ongoing.
If you are adding a new entry or functionality, please add it to the
appropriate file under `torch/ao/nn/intrinsic/qat/modules`
and add an import statement here.
"""
from torch.ao.nn.intrinsic.qat import (
ConvBn1d,
ConvBn2d,
ConvBn3d,
ConvBnReLU1d,
ConvBnReLU2d,
ConvBnReLU3d,
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
freeze_bn_stats,
update_bn_stats,
)
__all__ = [
# Modules
"ConvBn1d",
"ConvBnReLU1d",
"ConvReLU1d",
"ConvBn2d",
"ConvBnReLU2d",
"ConvReLU2d",
"ConvBn3d",
"ConvBnReLU3d",
"ConvReLU3d",
# Utilities
"freeze_bn_stats",
"update_bn_stats",
]

View File

@ -0,0 +1,16 @@
# flake8: noqa: F401
r"""Intrinsic QAT Modules.
This file is in the process of being migrated to `torch/ao/nn/intrinsic/qat` and
is kept here for compatibility while the migration is ongoing.
If you are adding a new entry or functionality, please add it to the
appropriate file under `torch/ao/nn/intrinsic/qat/modules`
and add an import statement here.
"""
from torch.ao.nn.intrinsic.qat import LinearBn1d
__all__ = [
"LinearBn1d",
]

View File

@ -0,0 +1,16 @@
# flake8: noqa: F401
r"""Intrinsic QAT Modules.
This file is in the process of being migrated to `torch/ao/nn/intrinsic/qat` and
is kept here for compatibility while the migration is ongoing.
If you are adding a new entry or functionality, please add it to the
appropriate file under `torch/ao/nn/intrinsic/qat/modules`
and add an import statement here.
"""
from torch.ao.nn.intrinsic.qat import LinearReLU
__all__ = [
"LinearReLU",
]

View File

@ -0,0 +1,14 @@
# to ensure customers can use the module below
# without importing it directly
from torch.nn.intrinsic.quantized import dynamic, modules # noqa: F401
from torch.nn.intrinsic.quantized.modules import * # noqa: F403
__all__ = [
"BNReLU2d",
"BNReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearReLU",
]

View File

@ -0,0 +1 @@
from torch.nn.intrinsic.quantized.dynamic.modules import * # noqa: F403

View File

@ -0,0 +1,6 @@
from torch.nn.intrinsic.quantized.dynamic.modules.linear_relu import LinearReLU
__all__ = [
"LinearReLU",
]

View File

@ -0,0 +1,6 @@
from torch.ao.nn.intrinsic.quantized.dynamic import LinearReLU
__all__ = [
"LinearReLU",
]

View File

@ -0,0 +1,17 @@
from torch.nn.intrinsic.quantized.modules.bn_relu import BNReLU2d, BNReLU3d
from torch.nn.intrinsic.quantized.modules.conv_relu import (
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
)
from torch.nn.intrinsic.quantized.modules.linear_relu import LinearReLU
__all__ = [
"LinearReLU",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"BNReLU2d",
"BNReLU3d",
]

View File

@ -0,0 +1,7 @@
from torch.ao.nn.intrinsic.quantized import BNReLU2d, BNReLU3d
__all__ = [
"BNReLU2d",
"BNReLU3d",
]

View File

@ -0,0 +1,8 @@
from torch.ao.nn.intrinsic.quantized import ConvReLU1d, ConvReLU2d, ConvReLU3d
__all__ = [
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
]

View File

@ -0,0 +1,6 @@
from torch.ao.nn.intrinsic.quantized import LinearReLU
__all__ = [
"LinearReLU",
]

View File

@ -0,0 +1,334 @@
from .module import Module # usort: skip
from .linear import Bilinear, Identity, LazyLinear, Linear # usort: skip
from .activation import (
CELU,
ELU,
GELU,
GLU,
Hardshrink,
Hardsigmoid,
Hardswish,
Hardtanh,
LeakyReLU,
LogSigmoid,
LogSoftmax,
Mish,
MultiheadAttention,
PReLU,
ReLU,
ReLU6,
RReLU,
SELU,
Sigmoid,
SiLU,
Softmax,
Softmax2d,
Softmin,
Softplus,
Softshrink,
Softsign,
Tanh,
Tanhshrink,
Threshold,
)
from .adaptive import AdaptiveLogSoftmaxWithLoss
from .batchnorm import (
BatchNorm1d,
BatchNorm2d,
BatchNorm3d,
LazyBatchNorm1d,
LazyBatchNorm2d,
LazyBatchNorm3d,
SyncBatchNorm,
)
from .channelshuffle import ChannelShuffle
from .container import (
Container,
ModuleDict,
ModuleList,
ParameterDict,
ParameterList,
Sequential,
)
from .conv import (
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
LazyConv1d,
LazyConv2d,
LazyConv3d,
LazyConvTranspose1d,
LazyConvTranspose2d,
LazyConvTranspose3d,
)
from .distance import CosineSimilarity, PairwiseDistance
from .dropout import (
AlphaDropout,
Dropout,
Dropout1d,
Dropout2d,
Dropout3d,
FeatureAlphaDropout,
)
from .flatten import Flatten, Unflatten
from .fold import Fold, Unfold
from .instancenorm import (
InstanceNorm1d,
InstanceNorm2d,
InstanceNorm3d,
LazyInstanceNorm1d,
LazyInstanceNorm2d,
LazyInstanceNorm3d,
)
from .loss import (
BCELoss,
BCEWithLogitsLoss,
CosineEmbeddingLoss,
CrossEntropyLoss,
CTCLoss,
GaussianNLLLoss,
HingeEmbeddingLoss,
HuberLoss,
KLDivLoss,
L1Loss,
MarginRankingLoss,
MSELoss,
MultiLabelMarginLoss,
MultiLabelSoftMarginLoss,
MultiMarginLoss,
NLLLoss,
NLLLoss2d,
PoissonNLLLoss,
SmoothL1Loss,
SoftMarginLoss,
TripletMarginLoss,
TripletMarginWithDistanceLoss,
)
from .normalization import (
CrossMapLRN2d,
GroupNorm,
LayerNorm,
LocalResponseNorm,
RMSNorm,
)
from .padding import (
CircularPad1d,
CircularPad2d,
CircularPad3d,
ConstantPad1d,
ConstantPad2d,
ConstantPad3d,
ReflectionPad1d,
ReflectionPad2d,
ReflectionPad3d,
ReplicationPad1d,
ReplicationPad2d,
ReplicationPad3d,
ZeroPad1d,
ZeroPad2d,
ZeroPad3d,
)
from .pixelshuffle import PixelShuffle, PixelUnshuffle
from .pooling import (
AdaptiveAvgPool1d,
AdaptiveAvgPool2d,
AdaptiveAvgPool3d,
AdaptiveMaxPool1d,
AdaptiveMaxPool2d,
AdaptiveMaxPool3d,
AvgPool1d,
AvgPool2d,
AvgPool3d,
FractionalMaxPool2d,
FractionalMaxPool3d,
LPPool1d,
LPPool2d,
LPPool3d,
MaxPool1d,
MaxPool2d,
MaxPool3d,
MaxUnpool1d,
MaxUnpool2d,
MaxUnpool3d,
)
from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNN, RNNBase, RNNCell, RNNCellBase
from .sparse import Embedding, EmbeddingBag
from .transformer import (
Transformer,
TransformerDecoder,
TransformerDecoderLayer,
TransformerEncoder,
TransformerEncoderLayer,
)
from .upsampling import Upsample, UpsamplingBilinear2d, UpsamplingNearest2d
__all__ = [
"AdaptiveAvgPool1d",
"AdaptiveAvgPool2d",
"AdaptiveAvgPool3d",
"AdaptiveLogSoftmaxWithLoss",
"AdaptiveMaxPool1d",
"AdaptiveMaxPool2d",
"AdaptiveMaxPool3d",
"AlphaDropout",
"AvgPool1d",
"AvgPool2d",
"AvgPool3d",
"BCELoss",
"BCEWithLogitsLoss",
"BatchNorm1d",
"BatchNorm2d",
"BatchNorm3d",
"Bilinear",
"CELU",
"CTCLoss",
"ChannelShuffle",
"CircularPad1d",
"CircularPad2d",
"CircularPad3d",
"ConstantPad1d",
"ConstantPad2d",
"ConstantPad3d",
"Container",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
"CosineEmbeddingLoss",
"CosineSimilarity",
"CrossEntropyLoss",
"CrossMapLRN2d",
"Dropout",
"Dropout1d",
"Dropout2d",
"Dropout3d",
"ELU",
"Embedding",
"EmbeddingBag",
"FeatureAlphaDropout",
"Flatten",
"Fold",
"FractionalMaxPool2d",
"FractionalMaxPool3d",
"GELU",
"GLU",
"GRU",
"GRUCell",
"GaussianNLLLoss",
"GroupNorm",
"Hardshrink",
"Hardsigmoid",
"Hardswish",
"Hardtanh",
"HingeEmbeddingLoss",
"HuberLoss",
"Identity",
"InstanceNorm1d",
"InstanceNorm2d",
"InstanceNorm3d",
"KLDivLoss",
"L1Loss",
"LPPool1d",
"LPPool2d",
"LPPool3d",
"LSTM",
"LSTMCell",
"LayerNorm",
"LazyBatchNorm1d",
"LazyBatchNorm2d",
"LazyBatchNorm3d",
"LazyConv1d",
"LazyConv2d",
"LazyConv3d",
"LazyConvTranspose1d",
"LazyConvTranspose2d",
"LazyConvTranspose3d",
"LazyInstanceNorm1d",
"LazyInstanceNorm2d",
"LazyInstanceNorm3d",
"LazyLinear",
"LeakyReLU",
"Linear",
"LocalResponseNorm",
"LogSigmoid",
"LogSoftmax",
"MSELoss",
"MarginRankingLoss",
"MaxPool1d",
"MaxPool2d",
"MaxPool3d",
"MaxUnpool1d",
"MaxUnpool2d",
"MaxUnpool3d",
"Mish",
"Module",
"ModuleDict",
"ModuleList",
"MultiLabelMarginLoss",
"MultiLabelSoftMarginLoss",
"MultiMarginLoss",
"MultiheadAttention",
"NLLLoss",
"NLLLoss2d",
"PReLU",
"PairwiseDistance",
"ParameterDict",
"ParameterList",
"PixelShuffle",
"PixelUnshuffle",
"PoissonNLLLoss",
"RMSNorm",
"RNN",
"RNNBase",
"RNNCell",
"RNNCellBase",
"RReLU",
"ReLU",
"ReLU6",
"ReflectionPad1d",
"ReflectionPad2d",
"ReflectionPad3d",
"ReplicationPad1d",
"ReplicationPad2d",
"ReplicationPad3d",
"SELU",
"Sequential",
"SiLU",
"Sigmoid",
"SmoothL1Loss",
"SoftMarginLoss",
"Softmax",
"Softmax2d",
"Softmin",
"Softplus",
"Softshrink",
"Softsign",
"SyncBatchNorm",
"Tanh",
"Tanhshrink",
"Threshold",
"Transformer",
"TransformerDecoder",
"TransformerDecoderLayer",
"TransformerEncoder",
"TransformerEncoderLayer",
"TripletMarginLoss",
"TripletMarginWithDistanceLoss",
"Unflatten",
"Unfold",
"Upsample",
"UpsamplingBilinear2d",
"UpsamplingNearest2d",
"ZeroPad1d",
"ZeroPad2d",
"ZeroPad3d",
]
# Please keep this list sorted
assert __all__ == sorted(__all__)

View File

@ -0,0 +1,319 @@
# mypy: allow-untyped-defs
import torch
import torch.distributed as dist
from torch.autograd.function import Function
class SyncBatchNorm(Function):
@staticmethod
def forward(
self,
input,
weight,
bias,
running_mean,
running_var,
eps,
momentum,
process_group,
world_size,
):
if not (
input.is_contiguous(memory_format=torch.channels_last)
or input.is_contiguous(memory_format=torch.channels_last_3d)
):
input = input.contiguous()
if weight is not None:
weight = weight.contiguous()
size = int(input.numel() // input.size(1))
if size == 1 and world_size < 2:
raise ValueError(
f"Expected more than 1 value per channel when training, got input size {size}"
)
num_channels = input.shape[1]
if input.numel() > 0:
# calculate mean/invstd for input.
mean, invstd = torch.batch_norm_stats(input, eps)
count = torch.full(
(1,),
input.numel() // input.size(1),
dtype=mean.dtype,
device=mean.device,
)
# C, C, 1 -> (2C + 1)
combined = torch.cat([mean, invstd, count], dim=0)
else:
# for empty input, set stats and the count to zero. The stats with
# zero count will be filtered out later when computing global mean
# & invstd, but they still need to participate in the all_gather
# collective communication to unblock other peer processes.
combined = torch.zeros(
2 * num_channels + 1, dtype=input.dtype, device=input.device
)
# Use allgather instead of allreduce because count could be different across
# ranks; a simple all_reduce op cannot give correct results.
# batch_norm_gather_stats_with_counts calculates global mean & invstd based on
# all gathered mean, invstd and count.
# for nccl backend, use the optimized version of all gather.
# The Gloo backend does not support `all_gather_into_tensor`.
if process_group._get_backend_name() != "gloo":
# world_size * (2C + 1)
combined_size = combined.numel()
combined_flat = torch.empty(
1,
combined_size * world_size,
dtype=combined.dtype,
device=combined.device,
)
dist.all_gather_into_tensor(
combined_flat, combined, process_group, async_op=False
)
combined = torch.reshape(combined_flat, (world_size, combined_size))
# world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
else:
# world_size * (2C + 1)
combined_list = [torch.empty_like(combined) for _ in range(world_size)]
dist.all_gather(combined_list, combined, process_group, async_op=False)
combined = torch.stack(combined_list, dim=0)
# world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()):
# The lines below force a synchronization between CUDA and CPU, because
# the shape of the result count_all depends on the values in mask tensor.
# Such synchronizations break CUDA Graph capturing.
# See https://github.com/pytorch/pytorch/issues/78549
# FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
# a better longer-term solution.
# remove stats from empty inputs
mask = count_all.squeeze(-1) >= 1
count_all = count_all[mask]
mean_all = mean_all[mask]
invstd_all = invstd_all[mask]
# calculate global mean & invstd
counts = count_all.view(-1)
if running_mean is not None and counts.dtype != running_mean.dtype:
counts = counts.to(running_mean.dtype)
mean, invstd = torch.batch_norm_gather_stats_with_counts(
input,
mean_all,
invstd_all,
running_mean,
running_var,
momentum,
eps,
counts,
)
self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
self.process_group = process_group
# apply element-wise normalization
if input.numel() > 0:
return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
else:
return torch.empty_like(input)
@staticmethod
def backward(self, grad_output):
if not (
grad_output.is_contiguous(memory_format=torch.channels_last)
or grad_output.is_contiguous(memory_format=torch.channels_last_3d)
):
grad_output = grad_output.contiguous()
saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
grad_input = grad_weight = grad_bias = None
process_group = self.process_group
if saved_input.numel() > 0:
# calculate local stats as well as grad_weight / grad_bias
(
sum_dy,
sum_dy_xmu,
grad_weight,
grad_bias,
) = torch.batch_norm_backward_reduce(
grad_output,
saved_input,
mean,
invstd,
weight,
self.needs_input_grad[0],
self.needs_input_grad[1],
self.needs_input_grad[2],
)
if self.needs_input_grad[0]:
# synchronizing stats used to calculate input gradient.
num_channels = sum_dy.shape[0]
combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
torch.distributed.all_reduce(
combined,
torch.distributed.ReduceOp.SUM,
process_group,
async_op=False,
)
sum_dy, sum_dy_xmu = torch.split(combined, num_channels)
# backward pass for gradient calculation
if weight is not None and weight.dtype != mean.dtype:
weight = weight.to(mean.dtype)
grad_input = torch.batch_norm_backward_elemt(
grad_output,
saved_input,
mean,
invstd,
weight,
sum_dy,
sum_dy_xmu,
count_tensor,
)
# synchronization of grad_weight / grad_bias is not needed, as distributed
# training handles the all-reduce.
if weight is None or not self.needs_input_grad[1]:
grad_weight = None
if weight is None or not self.needs_input_grad[2]:
grad_bias = None
else:
# This process got an empty input tensor in the forward pass.
# Although this process can directly set grad_input as an empty
# tensor of zeros, it still needs to participate in the collective
# communication to unblock its peers, as other peer processes might
# have received non-empty inputs.
num_channels = saved_input.shape[1]
if self.needs_input_grad[0]:
# launch all_reduce to unblock other peer processes
combined = torch.zeros(
2 * num_channels, dtype=saved_input.dtype, device=saved_input.device
)
torch.distributed.all_reduce(
combined,
torch.distributed.ReduceOp.SUM,
process_group,
async_op=False,
)
# Leave grad_input, grad_weight and grad_bias as None, which will be
# interpreted by the autograd engine as Tensors full of zeros.
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
class CrossMapLRN2d(Function):
@staticmethod
def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
ctx.size = size
ctx.alpha = alpha
ctx.beta = beta
ctx.k = k
ctx.scale = None
if input.dim() != 4:
raise ValueError(
f"CrossMapLRN2d: Expected input to be 4D, got {input.dim()}D instead."
)
ctx.scale = ctx.scale or input.new()
output = input.new()
batch_size = input.size(0)
channels = input.size(1)
input_height = input.size(2)
input_width = input.size(3)
output.resize_as_(input)
ctx.scale.resize_as_(input)
# use output storage as temporary buffer
input_square = output
torch.pow(input, 2, out=input_square)
pre_pad = int((ctx.size - 1) / 2 + 1)
pre_pad_crop = min(pre_pad, channels)
scale_first = ctx.scale.select(1, 0)
scale_first.zero_()
# compute first feature map normalization
for c in range(pre_pad_crop):
scale_first.add_(input_square.select(1, c))
# reuse computations for next feature maps normalization
# by adding the next feature map and removing the previous
for c in range(1, channels):
scale_previous = ctx.scale.select(1, c - 1)
scale_current = ctx.scale.select(1, c)
scale_current.copy_(scale_previous)
if c < channels - pre_pad + 1:
square_next = input_square.select(1, c + pre_pad - 1)
scale_current.add_(square_next, alpha=1)
if c > pre_pad:
square_previous = input_square.select(1, c - pre_pad)
scale_current.add_(square_previous, alpha=-1)
ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)
torch.pow(ctx.scale, -ctx.beta, out=output)
output.mul_(input)
ctx.save_for_backward(input, output)
return output
@staticmethod
def backward(ctx, grad_output):
input, output = ctx.saved_tensors
grad_input = grad_output.new()
batch_size = input.size(0)
channels = input.size(1)
input_height = input.size(2)
input_width = input.size(3)
paddded_ratio = input.new(channels + ctx.size - 1, input_height, input_width)
accum_ratio = input.new(input_height, input_width)
cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
inversePrePad = int(ctx.size - (ctx.size - 1) / 2)
grad_input.resize_as_(input)
torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)
paddded_ratio.zero_()
padded_ratio_center = paddded_ratio.narrow(0, inversePrePad, channels)
for n in range(batch_size):
torch.mul(grad_output[n], output[n], out=padded_ratio_center)
padded_ratio_center.div_(ctx.scale[n])
torch.sum(
paddded_ratio.narrow(0, 0, ctx.size - 1),
0,
keepdim=False,
out=accum_ratio,
)
for c in range(channels):
accum_ratio.add_(paddded_ratio[c + ctx.size - 1])
grad_input[n][c].addcmul_(
input[n][c], accum_ratio, value=-cache_ratio_value
)
accum_ratio.add_(paddded_ratio[c], alpha=-1)
return grad_input, None, None, None, None
class BackwardHookFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, *args):
ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
return args
@staticmethod
def backward(ctx, *args):
return args
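# --- Editor's illustrative sketch (not part of the original file) -----------
# A single-process illustration of the (2C + 1) packing used in
# SyncBatchNorm.forward above, with two fake "ranks" standing in for the
# all_gather: each rank contributes [mean (C), invstd (C), count (1)], and the
# stacked result is split back into per-rank mean / invstd / count blocks.
if __name__ == "__main__":
    num_channels = 4
    per_rank = []
    for _ in range(2):  # pretend world_size == 2
        x = torch.randn(8, num_channels, 5, 5)
        mean = x.mean(dim=(0, 2, 3))
        invstd = 1.0 / torch.sqrt(x.var(dim=(0, 2, 3), unbiased=False) + 1e-5)
        count = torch.full((1,), x.numel() // num_channels, dtype=mean.dtype)
        per_rank.append(torch.cat([mean, invstd, count], dim=0))
    combined = torch.stack(per_rank, dim=0)  # world_size x (2C + 1)
    mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
    print(mean_all.shape, invstd_all.shape, count_all.shape)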

File diff suppressed because it is too large

View File

@ -0,0 +1,330 @@
# mypy: allow-untyped-defs
from collections import namedtuple
from typing import List, Sequence
import torch
import torch.nn.functional as F
from torch import Tensor
from .container import ModuleList, Sequential
from .linear import Linear
from .module import Module
__all__ = ["AdaptiveLogSoftmaxWithLoss"]
_ASMoutput = namedtuple("_ASMoutput", ["output", "loss"])
class AdaptiveLogSoftmaxWithLoss(Module):
"""Efficient softmax approximation.
As described in
`Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin,
Moustapha Cissé, David Grangier, and Hervé Jégou
<https://arxiv.org/abs/1609.04309>`__.
""" r"""
Adaptive softmax is an approximate strategy for training models with large
output spaces. It is most effective when the label distribution is highly
imbalanced, for example in natural language modelling, where the word
frequency distribution approximately follows `Zipf's law`_.
Adaptive softmax partitions the labels into several clusters, according to
their frequency. Each cluster may contain a different number of targets.
Additionally, clusters containing less frequent labels assign lower
dimensional embeddings to those labels, which speeds up the computation.
For each minibatch, only clusters for which at least one target is
present are evaluated.
The idea is that the clusters which are accessed frequently
(like the first one, containing most frequent labels), should also be cheap
to compute -- that is, contain a small number of assigned labels.
We highly recommend taking a look at the original paper for more details.
* :attr:`cutoffs` should be an ordered Sequence of integers sorted
in the increasing order.
It controls the number of clusters and the partitioning of targets into
clusters. For example, setting ``cutoffs = [10, 100, 1000]``
means that the first `10` targets will be assigned
to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be
assigned to the first cluster, and targets `101, 102, ..., 1000` will be
assigned to the second cluster, while targets
`1001, 1002, ..., n_classes - 1` will be assigned
to the last, third cluster.
* :attr:`div_value` is used to compute the size of each additional cluster,
which is given as
:math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`,
where :math:`idx` is the cluster index (with clusters
for less frequent words having larger indices,
and indices starting from :math:`1`).
* :attr:`head_bias` if set to True, adds a bias term to the 'head' of the
adaptive softmax. See paper for details. Set to False in the official
implementation.
.. warning::
Labels passed as inputs to this module should be sorted according to
their frequency. This means that the most frequent label should be
represented by the index `0`, and the least frequent
label should be represented by the index `n_classes - 1`.
.. note::
This module returns a ``NamedTuple`` with ``output``
and ``loss`` fields. See further documentation for details.
.. note::
To compute log-probabilities for all classes, the ``log_prob``
method can be used.
Args:
in_features (int): Number of features in the input tensor
n_classes (int): Number of classes in the dataset
cutoffs (Sequence): Cutoffs used to assign targets to their buckets
div_value (float, optional): value used as an exponent to compute sizes
of the clusters. Default: 4.0
head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the
adaptive softmax. Default: ``False``
Returns:
``NamedTuple`` with ``output`` and ``loss`` fields:
* **output** is a Tensor of size ``N`` containing computed target
log probabilities for each example
* **loss** is a Scalar representing the computed negative
log likelihood loss
Shape:
- input: :math:`(N, \texttt{in\_features})` or :math:`(\texttt{in\_features})`
- target: :math:`(N)` or :math:`()` where each value satisfies :math:`0 <= \texttt{target[i]} < \texttt{n\_classes}`
- output1: :math:`(N)` or :math:`()`
- output2: ``Scalar``
.. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law
"""
in_features: int
n_classes: int
cutoffs: List[int]
div_value: float
head_bias: bool
head: Linear
tail: ModuleList
def __init__(
self,
in_features: int,
n_classes: int,
cutoffs: Sequence[int],
div_value: float = 4.0,
head_bias: bool = False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
cutoffs = list(cutoffs)
if len(cutoffs) == 0:
raise ValueError("cutoffs should be a sequence of length larger than 0")
if (
(cutoffs != sorted(cutoffs))
or (min(cutoffs) <= 0)
or (max(cutoffs) > (n_classes - 1))
or (len(set(cutoffs)) != len(cutoffs))
or any(int(c) != c for c in cutoffs)
):
raise ValueError(
"cutoffs should be a sequence of unique, positive "
"integers sorted in an increasing order, where "
"each value is between 1 and n_classes-1"
)
self.in_features = in_features
self.n_classes = n_classes
self.cutoffs = cutoffs + [n_classes]
self.div_value = div_value
self.head_bias = head_bias
self.shortlist_size = self.cutoffs[0]
self.n_clusters = len(self.cutoffs) - 1
self.head_size = self.shortlist_size + self.n_clusters
self.head = Linear(
self.in_features, self.head_size, bias=self.head_bias, **factory_kwargs
)
self.tail = ModuleList()
for i in range(self.n_clusters):
hsz = int(self.in_features // (self.div_value ** (i + 1)))
osz = self.cutoffs[i + 1] - self.cutoffs[i]
projection = Sequential(
Linear(self.in_features, hsz, bias=False, **factory_kwargs),
Linear(hsz, osz, bias=False, **factory_kwargs),
)
self.tail.append(projection)
def reset_parameters(self) -> None:
self.head.reset_parameters()
for i2h, h2o in self.tail:
i2h.reset_parameters()
h2o.reset_parameters()
def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput:
targ_dim = target_.dim()
if targ_dim == 1:
if input_.size(0) != target_.size(0):
raise RuntimeError(
"Input and target should have the same size "
"in the batch dimension."
)
if input_.dim() != 2:
raise RuntimeError(
"1D target tensor expects 2D input tensors, "
"but found inputs with size",
input_.size(),
)
elif targ_dim == 0:
if input_.dim() != 1:
raise RuntimeError(
"0D target tensor expects 1D input tensors, "
"but found inputs with size",
input_.size(),
)
else:
raise RuntimeError(
"0D or 1D target tensor expected, " "multi-target not supported"
)
is_batched = targ_dim > 0
input = input_ if is_batched else input_.unsqueeze(0)
target = target_ if is_batched else target_.unsqueeze(0)
used_rows = 0
batch_size = target.size(0)
output = input.new_zeros(batch_size)
gather_inds = target.new_empty(batch_size)
cutoff_values = [0] + self.cutoffs
for i in range(len(cutoff_values) - 1):
low_idx = cutoff_values[i]
high_idx = cutoff_values[i + 1]
target_mask = (target >= low_idx) & (target < high_idx)
row_indices = target_mask.nonzero().squeeze()
if row_indices.numel() == 0:
continue
if i == 0:
gather_inds.index_copy_(0, row_indices, target[target_mask])
else:
relative_target = target[target_mask] - low_idx
input_subset = input.index_select(0, row_indices)
cluster_output = self.tail[i - 1](input_subset)
cluster_index = self.shortlist_size + i - 1
gather_inds.index_fill_(0, row_indices, cluster_index)
cluster_logprob = F.log_softmax(cluster_output, dim=1)
local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
output.index_copy_(0, row_indices, local_logprob.squeeze(1))
used_rows += row_indices.numel()
if used_rows != batch_size:
raise RuntimeError(
f"Target values should be in [0, {self.n_classes - 1}], "
f"but values in range [{target.min().item()}, {target.max().item()}] "
"were found. "
)
head_output = self.head(input)
head_logprob = F.log_softmax(head_output, dim=1)
output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
loss = (-output).mean()
if not is_batched:
output = output.squeeze(0)
return _ASMoutput(output, loss)
def _get_full_log_prob(self, input, head_output):
"""Given input tensor, and output of ``self.head``, compute the log of the full distribution."""
out = input.new_empty((head_output.size(0), self.n_classes))
head_logprob = F.log_softmax(head_output, dim=1)
out[:, : self.shortlist_size] = head_logprob[:, : self.shortlist_size]
for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
cluster_output = self.tail[i](input)
cluster_logprob = F.log_softmax(cluster_output, dim=1)
output_logprob = cluster_logprob + head_logprob[
:, self.shortlist_size + i
].unsqueeze(1)
out[:, start_idx:stop_idx] = output_logprob
return out
def log_prob(self, input: Tensor) -> Tensor:
r"""Compute log probabilities for all :math:`\texttt{n\_classes}`.
Args:
input (Tensor): a minibatch of examples
Returns:
log-probabilities for each class :math:`c`
in range :math:`0 <= c < \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a
parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor.
Shape:
- Input: :math:`(N, \texttt{in\_features})`
- Output: :math:`(N, \texttt{n\_classes})`
"""
head_output = self.head(input)
return self._get_full_log_prob(input, head_output)
def predict(self, input: Tensor) -> Tensor:
r"""Return the class with the highest probability for each example in the input minibatch.
This is equivalent to ``self.log_prob(input).argmax(dim=1)``, but is more efficient in some cases.
Args:
input (Tensor): a minibatch of examples
Returns:
output (Tensor): a class with the highest probability for each example
Shape:
- Input: :math:`(N, \texttt{in\_features})`
- Output: :math:`(N)`
"""
head_output = self.head(input)
output = torch.argmax(head_output, dim=1)
not_in_shortlist = output >= self.shortlist_size
all_in_shortlist = not (not_in_shortlist.any())
if all_in_shortlist:
return output
elif not_in_shortlist.all():
log_prob = self._get_full_log_prob(input, head_output)
return torch.argmax(log_prob, dim=1)
else:
log_prob = self._get_full_log_prob(
input[not_in_shortlist], head_output[not_in_shortlist]
)
output[not_in_shortlist] = torch.argmax(log_prob, dim=1)
return output
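# --- Editor's illustrative sketch (not part of the original file) -----------
# A minimal usage example of the cutoff scheme described in the class
# docstring: with cutoffs=[10, 100] and n_classes=1000, targets 0-9 live in
# the head, 10-99 in the first tail cluster and 100-999 in the second.
# All sizes here are arbitrary.
if __name__ == "__main__":
    asm = AdaptiveLogSoftmaxWithLoss(in_features=64, n_classes=1000, cutoffs=[10, 100])
    x = torch.randn(32, 64)
    y = torch.randint(0, 1000, (32,))
    out, loss = asm(x, y)
    print(out.shape, loss.item())  # per-example target log-probs and mean NLL
    print(asm.predict(x).shape)    # most likely class per example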

View File

@ -0,0 +1,883 @@
# mypy: allow-untyped-defs
from typing import Any, Optional
import torch
from torch import Tensor
from torch.nn import functional as F, init
from torch.nn.parameter import Parameter, UninitializedBuffer, UninitializedParameter
from ._functions import SyncBatchNorm as sync_batch_norm
from .lazy import LazyModuleMixin
from .module import Module
__all__ = [
"BatchNorm1d",
"LazyBatchNorm1d",
"BatchNorm2d",
"LazyBatchNorm2d",
"BatchNorm3d",
"LazyBatchNorm3d",
"SyncBatchNorm",
]
class _NormBase(Module):
"""Common base of _InstanceNorm and _BatchNorm."""
_version = 2
__constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
num_features: int
eps: float
momentum: Optional[float]
affine: bool
track_running_stats: bool
# WARNING: weight and bias purposely not defined here.
# See https://github.com/pytorch/pytorch/issues/39670
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: Optional[float] = 0.1,
affine: bool = True,
track_running_stats: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
if self.track_running_stats:
self.register_buffer(
"running_mean", torch.zeros(num_features, **factory_kwargs)
)
self.register_buffer(
"running_var", torch.ones(num_features, **factory_kwargs)
)
self.running_mean: Optional[Tensor]
self.running_var: Optional[Tensor]
self.register_buffer(
"num_batches_tracked",
torch.tensor(
0,
dtype=torch.long,
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
),
)
self.num_batches_tracked: Optional[Tensor]
else:
self.register_buffer("running_mean", None)
self.register_buffer("running_var", None)
self.register_buffer("num_batches_tracked", None)
self.reset_parameters()
def reset_running_stats(self) -> None:
if self.track_running_stats:
# running_mean/running_var/num_batches... are registered at runtime depending
# on whether self.track_running_stats is on
self.running_mean.zero_() # type: ignore[union-attr]
self.running_var.fill_(1) # type: ignore[union-attr]
self.num_batches_tracked.zero_() # type: ignore[union-attr,operator]
def reset_parameters(self) -> None:
self.reset_running_stats()
if self.affine:
init.ones_(self.weight)
init.zeros_(self.bias)
def _check_input_dim(self, input):
raise NotImplementedError
def extra_repr(self):
return (
"{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
"track_running_stats={track_running_stats}".format(**self.__dict__)
)
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)
if (version is None or version < 2) and self.track_running_stats:
# at version 2: added num_batches_tracked buffer
# this should have a default value of 0
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key not in state_dict:
state_dict[num_batches_tracked_key] = (
self.num_batches_tracked
if self.num_batches_tracked is not None
and self.num_batches_tracked.device != torch.device("meta")
else torch.tensor(0, dtype=torch.long)
)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
class _BatchNorm(_NormBase):
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: Optional[float] = 0.1,
affine: bool = True,
track_running_stats: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
)
def forward(self, input: Tensor) -> Tensor:
self._check_input_dim(input)
# exponential_average_factor is set to self.momentum
# (when it is available) only so that it gets updated
# in ONNX graph when this node is exported to ONNX.
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None: # type: ignore[has-type]
self.num_batches_tracked.add_(1) # type: ignore[has-type]
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
r"""
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
"""
if self.training:
bn_training = True
else:
bn_training = (self.running_mean is None) and (self.running_var is None)
r"""
Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
used for normalization (i.e. in eval mode when buffers are not None).
"""
return F.batch_norm(
input,
# If buffers are not to be tracked, ensure that they won't be updated
self.running_mean
if not self.training or self.track_running_stats
else None,
self.running_var if not self.training or self.track_running_stats else None,
self.weight,
self.bias,
bn_training,
exponential_average_factor,
self.eps,
)
class _LazyNormBase(LazyModuleMixin, _NormBase):
weight: UninitializedParameter # type: ignore[assignment]
bias: UninitializedParameter # type: ignore[assignment]
def __init__(
self,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
# affine and track_running_stats are hardcoded to False to
# avoid creating tensors that will soon be overwritten.
0,
eps,
momentum,
False,
False,
**factory_kwargs,
)
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
self.weight = UninitializedParameter(**factory_kwargs)
self.bias = UninitializedParameter(**factory_kwargs)
if self.track_running_stats:
self.running_mean = UninitializedBuffer(**factory_kwargs)
self.running_var = UninitializedBuffer(**factory_kwargs)
self.num_batches_tracked = torch.tensor(
0,
dtype=torch.long,
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
)
def reset_parameters(self) -> None:
if not self.has_uninitialized_params() and self.num_features != 0:
super().reset_parameters()
def initialize_parameters(self, input) -> None: # type: ignore[override]
if self.has_uninitialized_params():
self.num_features = input.shape[1]
if self.affine:
assert isinstance(self.weight, UninitializedParameter)
assert isinstance(self.bias, UninitializedParameter)
self.weight.materialize((self.num_features,))
self.bias.materialize((self.num_features,))
if self.track_running_stats:
self.running_mean.materialize( # type:ignore[union-attr]
(self.num_features,)
)
self.running_var.materialize( # type:ignore[union-attr]
(self.num_features,)
)
self.reset_parameters()
class BatchNorm1d(_BatchNorm):
r"""Applies Batch Normalization over a 2D or 3D input.
Method described in the paper
`Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated per-dimension over
the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
of size `C` (where `C` is the number of features or channels of the input). By default, the
elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
At train time in the forward pass, the standard-deviation is calculated via the biased estimator,
equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the
moving average of the standard-deviation is calculated via the unbiased estimator, equivalent to
``torch.var(input, unbiased=True)``.
Also by default, during training this layer keeps running estimates of its
computed mean and variance, which are then used for normalization during
evaluation. The running estimates are kept with a default :attr:`momentum`
of 0.1.
If :attr:`track_running_stats` is set to ``False``, this layer then does not
keep running estimates, and batch statistics are instead used during
evaluation time as well.
.. note::
This :attr:`momentum` argument is different from one used in optimizer
classes and the conventional notion of momentum. Mathematically, the
update rule for running statistics here is
:math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
new observed value.
Because the Batch Normalization is done over the `C` dimension, computing statistics
on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.
Args:
num_features: number of features or channels :math:`C` of the input
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Can be set to ``None`` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics, and initializes statistics
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
When these buffers are ``None``, this module always uses batch statistics
in both training and eval modes. Default: ``True``
Shape:
- Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
:math:`C` is the number of features or channels, and :math:`L` is the sequence length
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
Examples::
>>> # With Learnable Parameters
>>> m = nn.BatchNorm1d(100)
>>> # Without Learnable Parameters
>>> m = nn.BatchNorm1d(100, affine=False)
>>> input = torch.randn(20, 100)
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.
Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
from the ``input.size(1)``.
The attributes that will be lazily initialized are `weight`, `bias`,
`running_mean` and `running_var`.
Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
on lazy modules and their limitations.
Args:
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Can be set to ``None`` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics, and initializes statistics
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
When these buffers are ``None``, this module always uses batch statistics
in both training and eval modes. Default: ``True``
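Example (illustrative usage sketch; assumes ``torch`` and ``torch.nn as nn`` are imported)::
>>> m = nn.LazyBatchNorm1d()
>>> input = torch.randn(20, 100)
>>> # num_features is inferred as 100 from input.size(1) on the first call
>>> output = m(input)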
"""
cls_to_become = BatchNorm1d # type: ignore[assignment]
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
class BatchNorm2d(_BatchNorm):
r"""Applies Batch Normalization over a 4D input.
4D is a mini-batch of 2D inputs
with additional channel dimension. Method described in the paper
`Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated per-dimension over
the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
standard-deviation is calculated via the biased estimator, equivalent to
``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
standard-deviation is calculated via the unbiased estimator, equivalent to
``torch.var(input, unbiased=True)``.
Also by default, during training this layer keeps running estimates of its
computed mean and variance, which are then used for normalization during
evaluation. The running estimates are kept with a default :attr:`momentum`
of 0.1.
If :attr:`track_running_stats` is set to ``False``, this layer then does not
keep running estimates, and batch statistics are instead used during
evaluation time as well.
.. note::
This :attr:`momentum` argument is different from one used in optimizer
classes and the conventional notion of momentum. Mathematically, the
update rule for running statistics here is
:math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
new observed value.
Because the Batch Normalization is done over the `C` dimension, computing statistics
on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.
Args:
num_features: :math:`C` from an expected input of size
:math:`(N, C, H, W)`
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Can be set to ``None`` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics, and initializes statistics
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
When these buffers are ``None``, this module always uses batch statistics
in both training and eval modes. Default: ``True``
Shape:
- Input: :math:`(N, C, H, W)`
- Output: :math:`(N, C, H, W)` (same shape as input)
Examples::
>>> # With Learnable Parameters
>>> m = nn.BatchNorm2d(100)
>>> # Without Learnable Parameters
>>> m = nn.BatchNorm2d(100, affine=False)
>>> input = torch.randn(20, 100, 35, 45)
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError(f"expected 4D input (got {input.dim()}D input)")
class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.
Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
from the ``input.size(1)``.
The attributes that will be lazily initialized are `weight`, `bias`,
`running_mean` and `running_var`.
Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
on lazy modules and their limitations.
Args:
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Can be set to ``None`` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics, and initializes statistics
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
When these buffers are ``None``, this module always uses batch statistics
in both training and eval modes. Default: ``True``
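Example (illustrative usage sketch; assumes ``torch`` and ``torch.nn as nn`` are imported)::
>>> m = nn.LazyBatchNorm2d()
>>> input = torch.randn(20, 100, 35, 45)
>>> # num_features is inferred as 100 from input.size(1) on the first call
>>> output = m(input)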
"""
cls_to_become = BatchNorm2d # type: ignore[assignment]
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError(f"expected 4D input (got {input.dim()}D input)")
class BatchNorm3d(_BatchNorm):
r"""Applies Batch Normalization over a 5D input.
5D is a mini-batch of 3D inputs with additional channel dimension as described in the paper
`Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated per-dimension over
the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
standard-deviation is calculated via the biased estimator, equivalent to
``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
standard-deviation is calculated via the unbiased estimator, equivalent to
``torch.var(input, unbiased=True)``.
Also by default, during training this layer keeps running estimates of its
computed mean and variance, which are then used for normalization during
evaluation. The running estimates are kept with a default :attr:`momentum`
of 0.1.
If :attr:`track_running_stats` is set to ``False``, this layer then does not
keep running estimates, and batch statistics are instead used during
evaluation time as well.
.. note::
This :attr:`momentum` argument is different from one used in optimizer
classes and the conventional notion of momentum. Mathematically, the
update rule for running statistics here is
:math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
new observed value.
Because the Batch Normalization is done over the `C` dimension, computing statistics
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
or Spatio-temporal Batch Normalization.
Args:
num_features: :math:`C` from an expected input of size
:math:`(N, C, D, H, W)`
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Can be set to ``None`` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics, and initializes statistics
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
When these buffers are ``None``, this module always uses batch statistics
in both training and eval modes. Default: ``True``
Shape:
- Input: :math:`(N, C, D, H, W)`
- Output: :math:`(N, C, D, H, W)` (same shape as input)
Examples::
>>> # With Learnable Parameters
>>> m = nn.BatchNorm3d(100)
>>> # Without Learnable Parameters
>>> m = nn.BatchNorm3d(100, affine=False)
>>> input = torch.randn(20, 100, 35, 45, 10)
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError(f"expected 5D input (got {input.dim()}D input)")
class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.
Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
from the ``input.size(1)``.
The attributes that will be lazily initialized are `weight`, `bias`,
`running_mean` and `running_var`.
Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
on lazy modules and their limitations.
Args:
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Can be set to ``None`` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics, and initializes statistics
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
When these buffers are ``None``, this module always uses batch statistics
in both training and eval modes. Default: ``True``
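Example (illustrative usage sketch; assumes ``torch`` and ``torch.nn as nn`` are imported)::
>>> m = nn.LazyBatchNorm3d()
>>> input = torch.randn(20, 100, 35, 45, 10)
>>> # num_features is inferred as 100 from input.size(1) on the first call
>>> output = m(input)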
"""
cls_to_become = BatchNorm3d # type: ignore[assignment]
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError(f"expected 5D input (got {input.dim()}D input)")
class SyncBatchNorm(_BatchNorm):
r"""Applies Batch Normalization over a N-Dimensional input.
The N-D input is a mini-batch of [N-2]D inputs with an additional channel dimension, as described in the paper
`Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated per-dimension over all
mini-batches of the same process group. :math:`\gamma` and :math:`\beta`
are learnable parameter vectors of size `C` (where `C` is the input size).
By default, the elements of :math:`\gamma` are sampled from
:math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0.
The standard-deviation is calculated via the biased estimator, equivalent to
``torch.var(input, unbiased=False)``.
Also by default, during training this layer keeps running estimates of its
computed mean and variance, which are then used for normalization during
evaluation. The running estimates are kept with a default :attr:`momentum`
of 0.1.
If :attr:`track_running_stats` is set to ``False``, this layer then does not
keep running estimates, and batch statistics are instead used during
evaluation time as well.
.. note::
This :attr:`momentum` argument is different from one used in optimizer
classes and the conventional notion of momentum. Mathematically, the
update rule for running statistics here is
:math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
new observed value.
Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
Normalization or Spatio-temporal Batch Normalization.
Currently :class:`SyncBatchNorm` only supports
:class:`~torch.nn.DistributedDataParallel` (DDP) with a single GPU per process. Use
:meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
:attr:`BatchNorm*D` layers to :class:`SyncBatchNorm` before wrapping
the network with DDP.
Args:
num_features: :math:`C` from an expected input of size
:math:`(N, C, +)`
eps: a value added to the denominator for numerical stability.
Default: ``1e-5``
momentum: the value used for the running_mean and running_var
computation. Can be set to ``None`` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics, and initializes statistics
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
When these buffers are ``None``, this module always uses batch statistics
in both training and eval modes. Default: ``True``
process_group: synchronization of stats happens within each process group
individually. The default behavior is synchronization across the whole
world.
Shape:
- Input: :math:`(N, C, +)`
- Output: :math:`(N, C, +)` (same shape as input)
.. note::
Synchronization of batchnorm statistics occurs only while training, i.e.
synchronization is disabled when ``model.eval()`` is set or if
``self.training`` is otherwise ``False``.
Examples::
>>> # xdoctest: +SKIP
>>> # With Learnable Parameters
>>> m = nn.SyncBatchNorm(100)
>>> # creating process group (optional)
>>> # ranks is a list of int identifying rank ids.
>>> ranks = list(range(8))
>>> r1, r2 = ranks[:4], ranks[4:]
>>> # Note: every rank calls into new_group for every
>>> # process group created, even if that rank is not
>>> # part of the group.
>>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
>>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
>>> # Without Learnable Parameters
>>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group)
>>> input = torch.randn(20, 100, 35, 45, 10)
>>> output = m(input)
>>> # network is nn.BatchNorm layer
>>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
>>> # only single gpu per process is currently supported
>>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
>>> sync_bn_network,
>>> device_ids=[args.local_rank],
>>> output_device=args.local_rank)
"""
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: Optional[float] = 0.1,
affine: bool = True,
track_running_stats: bool = True,
process_group: Optional[Any] = None,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
)
self.process_group = process_group
def _check_input_dim(self, input):
if input.dim() < 2:
raise ValueError(f"expected at least 2D input (got {input.dim()}D input)")
def _check_non_zero_input_channels(self, input):
if input.size(1) == 0:
raise ValueError(
"SyncBatchNorm number of input channels should be non-zero"
)
def forward(self, input: Tensor) -> Tensor:
self._check_input_dim(input)
self._check_non_zero_input_channels(input)
# exponential_average_factor is set to self.momentum
# (when it is available) only so that it gets updated
# in ONNX graph when this node is exported to ONNX.
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and self.track_running_stats:
assert self.num_batches_tracked is not None
self.num_batches_tracked.add_(1)
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / self.num_batches_tracked.item()
else: # use exponential moving average
exponential_average_factor = self.momentum
r"""
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
"""
if self.training:
bn_training = True
else:
bn_training = (self.running_mean is None) and (self.running_var is None)
r"""
Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
used for normalization (i.e. in eval mode when buffers are not None).
"""
# If buffers are not to be tracked, ensure that they won't be updated
running_mean = (
self.running_mean if not self.training or self.track_running_stats else None
)
running_var = (
self.running_var if not self.training or self.track_running_stats else None
)
# Don't sync batchnorm stats in inference mode (model.eval()).
need_sync = (
bn_training
and self.training
and torch.distributed.is_available()
and torch.distributed.is_initialized()
)
if need_sync:
# currently only GPU/PrivateUse1 input is supported
if input.device.type not in [
"cuda",
torch._C._get_privateuse1_backend_name(),
]:
raise ValueError(
"SyncBatchNorm expected input tensor to be on GPU or "
f"{torch._C._get_privateuse1_backend_name()}"
)
process_group = torch.distributed.group.WORLD
if self.process_group:
process_group = self.process_group
world_size = torch.distributed.get_world_size(process_group)
need_sync = world_size > 1
# fallback to framework BN when synchronization is not necessary
if not need_sync:
return F.batch_norm(
input,
running_mean,
running_var,
self.weight,
self.bias,
bn_training,
exponential_average_factor,
self.eps,
)
else:
assert bn_training
return sync_batch_norm.apply(
input,
self.weight,
self.bias,
running_mean,
running_var,
self.eps,
exponential_average_factor,
process_group, # type: ignore[possibly-undefined]
world_size, # type: ignore[possibly-undefined]
)
@classmethod
def convert_sync_batchnorm(cls, module, process_group=None):
r"""Converts all :attr:`BatchNorm*D` layers in the model to :class:`torch.nn.SyncBatchNorm` layers.
Args:
module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
process_group (optional): process group to scope synchronization,
default is the whole world
Returns:
The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
instead.
Example::
>>> # Network with nn.BatchNorm layer
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> module = torch.nn.Sequential(
>>> torch.nn.Linear(20, 100),
>>> torch.nn.BatchNorm1d(100),
>>> ).cuda()
>>> # creating process group (optional)
>>> # ranks is a list of int identifying rank ids.
>>> ranks = list(range(8))
>>> r1, r2 = ranks[:4], ranks[4:]
>>> # Note: every rank calls into new_group for every
>>> # process group created, even if that rank is not
>>> # part of the group.
>>> # xdoctest: +SKIP("distributed")
>>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
>>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
>>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
"""
module_output = module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
module_output = torch.nn.SyncBatchNorm(
module.num_features,
module.eps,
module.momentum,
module.affine,
module.track_running_stats,
process_group,
)
if module.affine:
with torch.no_grad():
module_output.weight = module.weight
module_output.bias = module.bias
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
module_output.training = module.training
if hasattr(module, "qconfig"):
module_output.qconfig = module.qconfig
for name, child in module.named_children():
module_output.add_module(
name, cls.convert_sync_batchnorm(child, process_group)
)
del module
return module_output

View File

@ -0,0 +1,56 @@
import torch.nn.functional as F
from torch import Tensor
from .module import Module
__all__ = ["ChannelShuffle"]
class ChannelShuffle(Module):
r"""Divides and rearranges the channels in a tensor.
This operation divides the channels in a tensor of shape :math:`(N, C, *)`
into g groups as :math:`(N, \frac{C}{g}, g, *)` and shuffles them,
while retaining the original tensor shape in the final output.
Args:
groups (int): number of groups to divide channels in.
Examples::
>>> channel_shuffle = nn.ChannelShuffle(2)
>>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
>>> input
tensor([[[[ 1., 2.],
[ 3., 4.]],
[[ 5., 6.],
[ 7., 8.]],
[[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.]]]])
>>> output = channel_shuffle(input)
>>> output
tensor([[[[ 1., 2.],
[ 3., 4.]],
[[ 9., 10.],
[11., 12.]],
[[ 5., 6.],
[ 7., 8.]],
[[13., 14.],
[15., 16.]]]])
"""
__constants__ = ["groups"]
groups: int
def __init__(self, groups: int) -> None:
super().__init__()
self.groups = groups
def forward(self, input: Tensor) -> Tensor:
return F.channel_shuffle(input, self.groups)
def extra_repr(self) -> str:
return f"groups={self.groups}"

View File

@ -0,0 +1,976 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import operator
from collections import abc as container_abcs, OrderedDict
from itertools import chain, islice
from typing import (
Any,
Dict,
Iterable,
Iterator,
Mapping,
Optional,
overload,
Tuple,
TypeVar,
Union,
)
from typing_extensions import deprecated, Self
import torch
from torch._jit_internal import _copy_to_script_wrapper
from torch.nn.parameter import Parameter
from .module import Module
__all__ = [
"Container",
"Sequential",
"ModuleList",
"ModuleDict",
"ParameterList",
"ParameterDict",
]
T = TypeVar("T", bound=Module)
# Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
def _addindent(s_, numSpaces):
s = s_.split("\n")
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
@deprecated(
"`nn.Container` is deprecated. "
"All of it's functionality is now implemented in `nn.Module`. Subclass that instead.",
category=FutureWarning,
)
class Container(Module):
def __init__(self, **kwargs: Any) -> None:
super().__init__()
for key, value in kwargs.items():
self.add_module(key, value)
class Sequential(Module):
r"""A sequential container.
Modules will be added to it in the order they are passed in the
constructor. Alternatively, an ``OrderedDict`` of modules can be
passed in. The ``forward()`` method of ``Sequential`` accepts any
input and forwards it to the first module it contains. It then
"chains" outputs to inputs sequentially for each subsequent module,
finally returning the output of the last module.
The value a ``Sequential`` provides over manually calling a sequence
of modules is that it allows treating the whole container as a
single module, such that performing a transformation on the
``Sequential`` applies to each of the modules it stores (which are
each a registered submodule of the ``Sequential``).
What's the difference between a ``Sequential`` and a
:class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it
sounds like--a list for storing ``Module`` s! On the other hand,
the layers in a ``Sequential`` are connected in a cascading way.
Example::
# Using Sequential to create a small model. When `model` is run,
# input will first be passed to `Conv2d(1,20,5)`. The output of
# `Conv2d(1,20,5)` will be used as the input to the first
# `ReLU`; the output of the first `ReLU` will become the input
# for `Conv2d(20,64,5)`. Finally, the output of
# `Conv2d(20,64,5)` will be used as input to the second `ReLU`
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
# Using Sequential with OrderedDict. This is functionally the
# same as the above code
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
"""
_modules: Dict[str, Module] # type: ignore[assignment]
@overload
def __init__(self, *args: Module) -> None:
...
@overload
def __init__(self, arg: "OrderedDict[str, Module]") -> None:
...
def __init__(self, *args):
super().__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def _get_item_by_idx(self, iterator, idx) -> T: # type: ignore[misc, type-var]
"""Get the idx-th item of the iterator."""
size = len(self)
idx = operator.index(idx)
if not -size <= idx < size:
raise IndexError(f"index {idx} is out of range")
idx %= size
return next(islice(iterator, idx, None))
@_copy_to_script_wrapper
def __getitem__(self, idx: Union[slice, int]) -> Union["Sequential", T]:
if isinstance(idx, slice):
return self.__class__(OrderedDict(list(self._modules.items())[idx]))
else:
return self._get_item_by_idx(self._modules.values(), idx)
def __setitem__(self, idx: int, module: Module) -> None:
key: str = self._get_item_by_idx(self._modules.keys(), idx)
return setattr(self, key, module)
def __delitem__(self, idx: Union[slice, int]) -> None:
if isinstance(idx, slice):
for key in list(self._modules.keys())[idx]:
delattr(self, key)
else:
key = self._get_item_by_idx(self._modules.keys(), idx)
delattr(self, key)
# To preserve numbering
str_indices = [str(i) for i in range(len(self._modules))]
self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
@_copy_to_script_wrapper
def __len__(self) -> int:
return len(self._modules)
def __add__(self, other) -> "Sequential":
if isinstance(other, Sequential):
ret = Sequential()
for layer in self:
ret.append(layer)
for layer in other:
ret.append(layer)
return ret
else:
raise ValueError(
"add operator supports only objects "
f"of Sequential class, but {str(type(other))} is given."
)
def pop(self, key: Union[int, slice]) -> Module:
v = self[key]
del self[key]
return v
def __iadd__(self, other) -> Self:
if isinstance(other, Sequential):
offset = len(self)
for i, module in enumerate(other):
self.add_module(str(i + offset), module)
return self
else:
raise ValueError(
"add operator supports only objects "
f"of Sequential class, but {str(type(other))} is given."
)
def __mul__(self, other: int) -> "Sequential":
if not isinstance(other, int):
raise TypeError(
f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
)
elif other <= 0:
raise ValueError(
f"Non-positive multiplication factor {other} for {type(self)}"
)
else:
combined = Sequential()
offset = 0
for _ in range(other):
for module in self:
combined.add_module(str(offset), module)
offset += 1
return combined
def __rmul__(self, other: int) -> "Sequential":
return self.__mul__(other)
def __imul__(self, other: int) -> Self:
if not isinstance(other, int):
raise TypeError(
f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
)
elif other <= 0:
raise ValueError(
f"Non-positive multiplication factor {other} for {type(self)}"
)
else:
len_original = len(self)
offset = len(self)
for _ in range(other - 1):
for i in range(len_original):
self.add_module(str(i + offset), self._modules[str(i)])
offset += len_original
return self
@_copy_to_script_wrapper
def __dir__(self):
keys = super().__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
@_copy_to_script_wrapper
def __iter__(self) -> Iterator[Module]:
return iter(self._modules.values())
# NB: We can't really type check this function as the type of input
# may change dynamically (as is tested in
# TestScript.test_sequential_intermediary_types). Cannot annotate
# with Any as TorchScript expects a more precise type
def forward(self, input):
for module in self:
input = module(input)
return input
def append(self, module: Module) -> "Sequential":
r"""Append a given module to the end.
Args:
module (nn.Module): module to append
"""
self.add_module(str(len(self)), module)
return self
def insert(self, index: int, module: Module) -> "Sequential":
if not isinstance(module, Module):
raise AssertionError(f"module should be of type: {Module}")
n = len(self._modules)
if not (-n <= index <= n):
raise IndexError(f"Index out of range: {index}")
if index < 0:
index += n
for i in range(n, index, -1):
self._modules[str(i)] = self._modules[str(i - 1)]
self._modules[str(index)] = module
return self
def extend(self, sequential) -> "Sequential":
for layer in sequential:
self.append(layer)
return self
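# Illustrative usage sketch (not part of the class above), assuming
# ``import torch.nn as nn``:
#
#   model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
#   head = model[:2]                  # slicing returns a new Sequential
#   model.append(nn.Softmax(dim=-1))  # append/insert/pop edit the container in place
#   model.insert(0, nn.Flatten())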
class ModuleList(Module):
r"""Holds submodules in a list.
:class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
modules it contains are properly registered, and will be visible by all
:class:`~torch.nn.Module` methods.
Args:
modules (iterable, optional): an iterable of modules to add
Example::
class MyModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
"""
_modules: Dict[str, Module] # type: ignore[assignment]
def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
super().__init__()
if modules is not None:
self += modules
def _get_abs_string_index(self, idx):
"""Get the absolute index for the list of modules."""
idx = operator.index(idx)
if not (-len(self) <= idx < len(self)):
raise IndexError(f"index {idx} is out of range")
if idx < 0:
idx += len(self)
return str(idx)
@overload
def __getitem__(self, idx: slice) -> "ModuleList":
...
@overload
def __getitem__(self, idx: int) -> Module:
...
@_copy_to_script_wrapper
def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleList"]:
if isinstance(idx, slice):
return self.__class__(list(self._modules.values())[idx])
else:
return self._modules[self._get_abs_string_index(idx)]
def __setitem__(self, idx: int, module: Module) -> None:
idx = self._get_abs_string_index(idx)
return setattr(self, str(idx), module)
def __delitem__(self, idx: Union[int, slice]) -> None:
if isinstance(idx, slice):
for k in range(len(self._modules))[idx]:
delattr(self, str(k))
else:
delattr(self, self._get_abs_string_index(idx))
# To preserve numbering, self._modules is being reconstructed with modules after deletion
str_indices = [str(i) for i in range(len(self._modules))]
self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
@_copy_to_script_wrapper
def __len__(self) -> int:
return len(self._modules)
@_copy_to_script_wrapper
def __iter__(self) -> Iterator[Module]:
return iter(self._modules.values())
def __iadd__(self, modules: Iterable[Module]) -> Self:
return self.extend(modules)
def __add__(self, other: Iterable[Module]) -> "ModuleList":
combined = ModuleList()
for i, module in enumerate(chain(self, other)):
combined.add_module(str(i), module)
return combined
def __repr__(self):
"""Return a custom repr for ModuleList that compresses repeated module representations."""
list_of_reprs = [repr(item) for item in self]
if len(list_of_reprs) == 0:
return self._get_name() + "()"
start_end_indices = [[0, 0]]
repeated_blocks = [list_of_reprs[0]]
for i, r in enumerate(list_of_reprs[1:], 1):
if r == repeated_blocks[-1]:
start_end_indices[-1][1] += 1
continue
start_end_indices.append([i, i])
repeated_blocks.append(r)
lines = []
main_str = self._get_name() + "("
for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
local_repr = f"({start_id}): {b}" # default repr
if start_id != end_id:
n = end_id - start_id + 1
local_repr = f"({start_id}-{end_id}): {n} x {b}"
local_repr = _addindent(local_repr, 2)
lines.append(local_repr)
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
return main_str
@_copy_to_script_wrapper
def __dir__(self):
keys = super().__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
def insert(self, index: int, module: Module) -> None:
r"""Insert a given module before a given index in the list.
Args:
index (int): index to insert.
module (nn.Module): module to insert
"""
for i in range(len(self._modules), index, -1):
self._modules[str(i)] = self._modules[str(i - 1)]
self._modules[str(index)] = module
def append(self, module: Module) -> "ModuleList":
r"""Append a given module to the end of the list.
Args:
module (nn.Module): module to append
"""
self.add_module(str(len(self)), module)
return self
def pop(self, key: Union[int, slice]) -> Module:
v = self[key]
del self[key]
return v
def extend(self, modules: Iterable[Module]) -> Self:
r"""Append modules from a Python iterable to the end of the list.
Args:
modules (iterable): iterable of modules to append
"""
if not isinstance(modules, container_abcs.Iterable):
raise TypeError(
"ModuleList.extend should be called with an "
"iterable, but got " + type(modules).__name__
)
offset = len(self)
for i, module in enumerate(modules):
self.add_module(str(offset + i), module)
return self
# remove forward altogether to fall back on Module's _forward_unimplemented
class ModuleDict(Module):
r"""Holds submodules in a dictionary.
:class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
but modules it contains are properly registered, and will be visible by all
:class:`~torch.nn.Module` methods.
:class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects
* the order of insertion, and
* in :meth:`~torch.nn.ModuleDict.update`, the order of the merged
``OrderedDict``, ``dict`` (started from Python 3.6) or another
:class:`~torch.nn.ModuleDict` (the argument to
:meth:`~torch.nn.ModuleDict.update`).
Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
types (e.g., Python's plain ``dict`` before Python version 3.6) does not
preserve the order of the merged mapping.
Args:
modules (iterable, optional): a mapping (dictionary) of (string: module)
or an iterable of key-value pairs of type (string, module)
Example::
class MyModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.choices = nn.ModuleDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
})
self.activations = nn.ModuleDict([
['lrelu', nn.LeakyReLU()],
['prelu', nn.PReLU()]
])
def forward(self, x, choice, act):
x = self.choices[choice](x)
x = self.activations[act](x)
return x
"""
_modules: Dict[str, Module] # type: ignore[assignment]
def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
super().__init__()
if modules is not None:
self.update(modules)
@_copy_to_script_wrapper
def __getitem__(self, key: str) -> Module:
return self._modules[key]
def __setitem__(self, key: str, module: Module) -> None:
self.add_module(key, module)
def __delitem__(self, key: str) -> None:
del self._modules[key]
@_copy_to_script_wrapper
def __len__(self) -> int:
return len(self._modules)
@_copy_to_script_wrapper
def __iter__(self) -> Iterator[str]:
return iter(self._modules)
@_copy_to_script_wrapper
def __contains__(self, key: str) -> bool:
return key in self._modules
def clear(self) -> None:
"""Remove all items from the ModuleDict."""
self._modules.clear()
def pop(self, key: str) -> Module:
r"""Remove key from the ModuleDict and return its module.
Args:
key (str): key to pop from the ModuleDict
"""
v = self[key]
del self[key]
return v
@_copy_to_script_wrapper
def keys(self) -> Iterable[str]:
r"""Return an iterable of the ModuleDict keys."""
return self._modules.keys()
@_copy_to_script_wrapper
def items(self) -> Iterable[Tuple[str, Module]]:
r"""Return an iterable of the ModuleDict key/value pairs."""
return self._modules.items()
@_copy_to_script_wrapper
def values(self) -> Iterable[Module]:
r"""Return an iterable of the ModuleDict values."""
return self._modules.values()
def update(self, modules: Mapping[str, Module]) -> None:
r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.
.. note::
If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
an iterable of key-value pairs, the order of new elements in it is preserved.
Args:
modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
"""
if not isinstance(modules, container_abcs.Iterable):
raise TypeError(
"ModuleDict.update should be called with an "
"iterable of key/value pairs, but got " + type(modules).__name__
)
if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
for key, module in modules.items():
self[key] = module
else:
# modules here can be a list with two items
for j, m in enumerate(modules):
if not isinstance(m, container_abcs.Iterable):
raise TypeError(
"ModuleDict update sequence element "
"#" + str(j) + " should be Iterable; is" + type(m).__name__
)
if not len(m) == 2:
raise ValueError(
"ModuleDict update sequence element "
"#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
)
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
# that's too cumbersome to type correctly with overloads, so we add an ignore here
self[m[0]] = m[1] # type: ignore[assignment]
# remove forward altogether to fall back on Module's _forward_unimplemented
class ParameterList(Module):
r"""Holds parameters in a list.
:class:`~torch.nn.ParameterList` can be used like a regular Python
list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered,
and will be visible by all :class:`~torch.nn.Module` methods.
Note that the constructor, assigning an element of the list, the
:meth:`~torch.nn.ParameterList.append` method and the :meth:`~torch.nn.ParameterList.extend`
method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`.
Args:
parameters (iterable, optional): an iterable of elements to add to the list.
Example::
class MyModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
def forward(self, x):
# ParameterList can act as an iterable, or be indexed using ints
for i, p in enumerate(self.params):
x = self.params[i // 2].mm(x) + p.mm(x)
return x
"""
def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
super().__init__()
self._size = 0
if values is not None:
self += values
def _get_abs_string_index(self, idx):
"""Get the absolute index for the list of modules."""
idx = operator.index(idx)
if not (-len(self) <= idx < len(self)):
raise IndexError(f"index {idx} is out of range")
if idx < 0:
idx += len(self)
return str(idx)
@overload
def __getitem__(self, idx: int) -> Any:
...
@overload
def __getitem__(self: T, idx: slice) -> T:
...
def __getitem__(self, idx):
if isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
out = self.__class__()
for i in range(start, stop, step):
out.append(self[i])
return out
else:
idx = self._get_abs_string_index(idx)
return getattr(self, str(idx))
def __setitem__(self, idx: int, param: Any) -> None:
# Note that all other function that add an entry to the list part of
# the ParameterList end up here. So this is the only place where we need
# to wrap things into Parameter if needed.
# Objects added via setattr() are not in the list part and thus won't
# call into this function.
idx = self._get_abs_string_index(idx)
if isinstance(param, torch.Tensor) and not isinstance(param, Parameter):
param = Parameter(param)
return setattr(self, str(idx), param)
def __len__(self) -> int:
return self._size
def __iter__(self) -> Iterator[Any]:
return iter(self[i] for i in range(len(self)))
def __iadd__(self, parameters: Iterable[Any]) -> Self:
return self.extend(parameters)
def __dir__(self):
keys = super().__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
def append(self, value: Any) -> "ParameterList":
"""Append a given value at the end of the list.
Args:
value (Any): value to append
"""
new_idx = len(self)
self._size += 1
self[new_idx] = value
return self
def extend(self, values: Iterable[Any]) -> Self:
"""Append values from a Python iterable to the end of the list.
Args:
values (iterable): iterable of values to append
"""
# Tensor is an iterable but we never want to unpack it here
if not isinstance(values, container_abcs.Iterable) or isinstance(
values, torch.Tensor
):
raise TypeError(
"ParameterList.extend should be called with an "
"iterable, but got " + type(values).__name__
)
for value in values:
self.append(value)
return self
def extra_repr(self) -> str:
child_lines = []
for k, p in enumerate(self):
if isinstance(p, torch.Tensor):
size_str = "x".join(str(size) for size in p.size())
if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
device_str = f" ({p.device})"
else:
device_str = ""
parastr = "{} containing: [{} of size {}{}]".format(
"Parameter" if isinstance(p, Parameter) else "Tensor",
p.dtype,
size_str,
device_str,
)
child_lines.append(" (" + str(k) + "): " + parastr)
else:
child_lines.append(
" (" + str(k) + "): Object of type: " + type(p).__name__
)
tmpstr = "\n".join(child_lines)
return tmpstr
def __call__(self, *args, **kwargs):
raise RuntimeError("ParameterList should not be called.")
class ParameterDict(Module):
r"""Holds parameters in a dictionary.
ParameterDict can be indexed like a regular Python dictionary, but Parameters it
contains are properly registered, and will be visible by all Module methods.
Other objects are treated as would be done by a regular Python dictionary.
:class:`~torch.nn.ParameterDict` is an **ordered** dictionary.
:meth:`~torch.nn.ParameterDict.update` with other unordered mapping
types (e.g., Python's plain ``dict``) does not preserve the order of the
merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict`
will preserve their ordering.
Note that the constructor, assigning an element of the dictionary and the
:meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into
:class:`~torch.nn.Parameter`.
Args:
values (iterable, optional): a mapping (dictionary) of
(string : Any) or an iterable of key-value pairs
of type (string, Any)
Example::
class MyModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.params = nn.ParameterDict({
'left': nn.Parameter(torch.randn(5, 10)),
'right': nn.Parameter(torch.randn(5, 10))
})
def forward(self, x, choice):
x = self.params[choice].mm(x)
return x
"""
def __init__(self, parameters: Any = None) -> None:
super().__init__()
self._keys: Dict[str, None] = {}
if parameters is not None:
self.update(parameters)
def _key_to_attr(self, key: str) -> str:
if not isinstance(key, str):
raise TypeError(
"Index given to ParameterDict cannot be used as a key as it is "
f"not a string (type is '{type(key).__name__}'). Open an issue on "
"github if you need non-string keys."
)
else:
# Use the key as-is so that `.named_parameters()` returns the right thing
return key
def __getitem__(self, key: str) -> Any:
attr = self._key_to_attr(key)
return getattr(self, attr)
def __setitem__(self, key: str, value: Any) -> None:
# Note that all other function that add an entry to the dictionary part of
# the ParameterDict end up here. So this is the only place where we need
# to wrap things into Parameter if needed.
# Objects added via setattr() are not in the dictionary part and thus won't
# call into this function.
self._keys[key] = None
attr = self._key_to_attr(key)
if isinstance(value, torch.Tensor) and not isinstance(value, Parameter):
value = Parameter(value)
setattr(self, attr, value)
def __delitem__(self, key: str) -> None:
del self._keys[key]
attr = self._key_to_attr(key)
delattr(self, attr)
def __len__(self) -> int:
return len(self._keys)
def __iter__(self) -> Iterator[str]:
return iter(self._keys)
def __reversed__(self) -> Iterator[str]:
return reversed(list(self._keys))
def copy(self) -> "ParameterDict":
"""Return a copy of this :class:`~torch.nn.ParameterDict` instance."""
# We have to use an OrderedDict because the ParameterDict constructor
# behaves differently on plain dict vs OrderedDict
return ParameterDict(OrderedDict((k, self[k]) for k in self._keys))
def __contains__(self, key: str) -> bool:
return key in self._keys
def setdefault(self, key: str, default: Optional[Any] = None) -> Any:
"""Set the default for a key in the Parameterdict.
If key is in the ParameterDict, return its value.
If not, insert `key` with a parameter `default` and return `default`.
`default` defaults to `None`.
Args:
key (str): key to set default for
default (Any): the parameter set to the key
"""
if key not in self:
self[key] = default
return self[key]
def clear(self) -> None:
"""Remove all items from the ParameterDict."""
for k in self._keys.copy():
del self[k]
def pop(self, key: str) -> Any:
r"""Remove key from the ParameterDict and return its parameter.
Args:
key (str): key to pop from the ParameterDict
"""
v = self[key]
del self[key]
return v
def popitem(self) -> Tuple[str, Any]:
"""Remove and return the last inserted `(key, parameter)` pair from the ParameterDict."""
k, _ = self._keys.popitem()
# We need the key in the _keys to be able to access/del
self._keys[k] = None
val = self[k]
del self[k]
return k, val
def get(self, key: str, default: Optional[Any] = None) -> Any:
r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not.
Args:
key (str): key to get from the ParameterDict
default (Parameter, optional): value to return if key not present
"""
return self[key] if key in self else default
def fromkeys(
self, keys: Iterable[str], default: Optional[Any] = None
) -> "ParameterDict":
r"""Return a new ParameterDict with the keys provided.
Args:
keys (iterable, string): keys to make the new ParameterDict from
default (Parameter, optional): value to set for all keys
"""
return ParameterDict((k, default) for k in keys)
def keys(self) -> Iterable[str]:
r"""Return an iterable of the ParameterDict keys."""
return self._keys.keys()
def items(self) -> Iterable[Tuple[str, Any]]:
r"""Return an iterable of the ParameterDict key/value pairs."""
return ((k, self[k]) for k in self._keys)
def values(self) -> Iterable[Any]:
r"""Return an iterable of the ParameterDict values."""
return (self[k] for k in self._keys)
def update(self, parameters: Union[Mapping[str, Any], "ParameterDict"]) -> None:
r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys.
.. note::
If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
an iterable of key-value pairs, the order of new elements in it is preserved.
Args:
parameters (iterable): a mapping (dictionary) from string to
:class:`~torch.nn.Parameter`, or an iterable of
key-value pairs of type (string, :class:`~torch.nn.Parameter`)
"""
if not isinstance(parameters, container_abcs.Iterable):
raise TypeError(
"ParametersDict.update should be called with an "
"iterable of key/value pairs, but got " + type(parameters).__name__
)
if isinstance(parameters, (OrderedDict, ParameterDict)):
for key, parameter in parameters.items():
self[key] = parameter
elif isinstance(parameters, container_abcs.Mapping):
for key, parameter in sorted(parameters.items()):
self[key] = parameter
else:
for j, p in enumerate(parameters):
if not isinstance(p, container_abcs.Iterable):
raise TypeError(
"ParameterDict update sequence element "
"#" + str(j) + " should be Iterable; is" + type(p).__name__
)
if not len(p) == 2:
raise ValueError(
"ParameterDict update sequence element "
"#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
)
# parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
self[p[0]] = p[1] # type: ignore[assignment]
def extra_repr(self) -> str:
child_lines = []
for k, p in self.items():
if isinstance(p, torch.Tensor):
size_str = "x".join(str(size) for size in p.size())
if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
device_str = f" ({p.device})"
else:
device_str = ""
parastr = "{} containing: [{} of size {}{}]".format(
"Parameter" if isinstance(p, Parameter) else "Tensor",
torch.typename(p),
size_str,
device_str,
)
child_lines.append(" (" + str(k) + "): " + parastr)
else:
child_lines.append(
" (" + str(k) + "): Object of type: " + type(p).__name__
)
tmpstr = "\n".join(child_lines)
return tmpstr
def __call__(self, input):
raise RuntimeError("ParameterDict should not be called.")
def __or__(self, other: "ParameterDict") -> "ParameterDict":
copy = self.copy()
copy.update(other)
return copy
def __ror__(self, other: "ParameterDict") -> "ParameterDict":
copy = other.copy()
copy.update(self)
return copy
def __ior__(self, other: "ParameterDict") -> Self:
self.update(other)
return self

File diff suppressed because it is too large

View File

@ -0,0 +1,93 @@
import torch.nn.functional as F
from torch import Tensor
from .module import Module
__all__ = ["PairwiseDistance", "CosineSimilarity"]
class PairwiseDistance(Module):
r"""
Computes the pairwise distance between input vectors, or between columns of input matrices.
Distances are computed using ``p``-norm, with constant ``eps`` added to avoid division by zero
if ``p`` is negative, i.e.:
.. math ::
\mathrm{dist}\left(x, y\right) = \left\Vert x-y + \epsilon e \right\Vert_p,
where :math:`e` is the vector of ones and the ``p``-norm is given by
.. math ::
\Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}.
Args:
p (real, optional): the norm degree. Can be negative. Default: 2
eps (float, optional): Small value to avoid division by zero.
Default: 1e-6
keepdim (bool, optional): Determines whether or not to keep the vector dimension.
Default: False
Shape:
- Input1: :math:`(N, D)` or :math:`(D)` where `N = batch dimension` and `D = vector dimension`
- Input2: :math:`(N, D)` or :math:`(D)`, same shape as the Input1
- Output: :math:`(N)` or :math:`()` based on input dimension.
If :attr:`keepdim` is ``True``, then :math:`(N, 1)` or :math:`(1)` based on input dimension.
Examples::
>>> pdist = nn.PairwiseDistance(p=2)
>>> input1 = torch.randn(100, 128)
>>> input2 = torch.randn(100, 128)
>>> output = pdist(input1, input2)
"""
__constants__ = ["norm", "eps", "keepdim"]
norm: float
eps: float
keepdim: bool
def __init__(
self, p: float = 2.0, eps: float = 1e-6, keepdim: bool = False
) -> None:
super().__init__()
self.norm = p
self.eps = eps
self.keepdim = keepdim
def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim)
class CosineSimilarity(Module):
r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along `dim`.
.. math ::
\text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}.
Args:
dim (int, optional): Dimension where cosine similarity is computed. Default: 1
eps (float, optional): Small value to avoid division by zero.
Default: 1e-8
Shape:
- Input1: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`
- Input2: :math:`(\ast_1, D, \ast_2)`, same number of dimensions as x1, matching x1 size at dimension `dim`,
and broadcastable with x1 at other dimensions.
- Output: :math:`(\ast_1, \ast_2)`
Examples::
>>> input1 = torch.randn(100, 128)
>>> input2 = torch.randn(100, 128)
>>> cos = nn.CosineSimilarity(dim=1, eps=1e-6)
>>> output = cos(input1, input2)
"""
__constants__ = ["dim", "eps"]
dim: int
eps: float
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
return F.cosine_similarity(x1, x2, self.dim, self.eps)

View File

@ -0,0 +1,305 @@
import torch.nn.functional as F
from torch import Tensor
from .module import Module
__all__ = [
"Dropout",
"Dropout1d",
"Dropout2d",
"Dropout3d",
"AlphaDropout",
"FeatureAlphaDropout",
]
class _DropoutNd(Module):
__constants__ = ["p", "inplace"]
p: float
inplace: bool
def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
super().__init__()
if p < 0 or p > 1:
raise ValueError(
f"dropout probability has to be between 0 and 1, but got {p}"
)
self.p = p
self.inplace = inplace
def extra_repr(self) -> str:
return f"p={self.p}, inplace={self.inplace}"
class Dropout(_DropoutNd):
r"""During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p`.
The zeroed elements are chosen independently for each forward call and are sampled from a Bernoulli distribution.
Each channel will be zeroed out independently on every forward call.
This has proven to be an effective technique for regularization and
preventing the co-adaptation of neurons as described in the paper
`Improving neural networks by preventing co-adaptation of feature
detectors`_ .
Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during
training. This means that during evaluation the module simply computes an
identity function.
Args:
p: probability of an element to be zeroed. Default: 0.5
inplace: If set to ``True``, will do this operation in-place. Default: ``False``
Shape:
- Input: :math:`(*)`. Input can be of any shape
- Output: :math:`(*)`. Output is of the same shape as input
Examples::
>>> m = nn.Dropout(p=0.2)
>>> input = torch.randn(20, 16)
>>> output = m(input)
.. _Improving neural networks by preventing co-adaptation of feature
detectors: https://arxiv.org/abs/1207.0580
"""
def forward(self, input: Tensor) -> Tensor:
return F.dropout(input, self.p, self.training, self.inplace)
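# Illustrative sketch (editorial addition): demonstrates the 1/(1-p) scaling during training
# and the identity behaviour in eval mode described in the docstring above.
def _dropout_scaling_example():
    import torch

    m = Dropout(p=0.5)
    x = torch.ones(8)
    train_out = m(x)                # surviving elements are scaled to 1 / (1 - p) == 2.0
    m.eval()
    eval_out = m(x)                 # identity: equal to x
    return train_out, torch.equal(eval_out, x)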
class Dropout1d(_DropoutNd):
r"""Randomly zero out entire channels.
A channel is a 1D feature map,
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
batched input is a 1D tensor :math:`\text{input}[i, j]`.
Each channel will be zeroed out independently on every forward call with
probability :attr:`p` using samples from a Bernoulli distribution.
Usually the input comes from :class:`nn.Conv1d` modules.
As described in the paper
`Efficient Object Localization Using Convolutional Networks`_ ,
if adjacent pixels within feature maps are strongly correlated
(as is normally the case in early convolution layers) then i.i.d. dropout
will not regularize the activations and will otherwise just result
in an effective learning rate decrease.
In this case, :func:`nn.Dropout1d` will help promote independence between
feature maps and should be used instead.
Args:
p (float, optional): probability of an element to be zeroed. Default: 0.5
inplace (bool, optional): If set to ``True``, will do this operation
in-place
Shape:
- Input: :math:`(N, C, L)` or :math:`(C, L)`.
- Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
Examples::
>>> m = nn.Dropout1d(p=0.2)
>>> input = torch.randn(20, 16, 32)
>>> output = m(input)
.. _Efficient Object Localization Using Convolutional Networks:
https://arxiv.org/abs/1411.4280
"""
def forward(self, input: Tensor) -> Tensor:
return F.dropout1d(input, self.p, self.training, self.inplace)
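# Illustrative sketch (editorial addition): entire channels are either zeroed or kept and
# rescaled, as described above, rather than individual elements.
def _dropout1d_channelwise_example():
    import torch

    m = Dropout1d(p=0.5)
    x = torch.ones(4, 16, 32)                 # (N, C, L)
    out = m(x).view(4, 16, -1)
    # Every (sample, channel) slice is all zeros or all 1/(1-p) == 2.0; never a mixture.
    all_or_nothing = ((out == 0).all(-1) | (out == 2.0).all(-1)).all()
    return bool(all_or_nothing)               # True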
class Dropout2d(_DropoutNd):
r"""Randomly zero out entire channels.
A channel is a 2D feature map,
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
batched input is a 2D tensor :math:`\text{input}[i, j]`.
Each channel will be zeroed out independently on every forward call with
probability :attr:`p` using samples from a Bernoulli distribution.
Usually the input comes from :class:`nn.Conv2d` modules.
As described in the paper
`Efficient Object Localization Using Convolutional Networks`_ ,
if adjacent pixels within feature maps are strongly correlated
(as is normally the case in early convolution layers) then i.i.d. dropout
will not regularize the activations and will otherwise just result
in an effective learning rate decrease.
In this case, :func:`nn.Dropout2d` will help promote independence between
feature maps and should be used instead.
Args:
p (float, optional): probability of an element to be zeroed. Default: 0.5
inplace (bool, optional): If set to ``True``, will do this operation
in-place
.. warning ::
Due to historical reasons, this class will perform 1D channel-wise dropout
for 3D inputs (as done by :class:`nn.Dropout1d`). Thus, it currently does NOT
support inputs without a batch dimension of shape :math:`(C, H, W)`. This
behavior will change in a future release to interpret 3D inputs as no-batch-dim
inputs. To maintain the old behavior, switch to :class:`nn.Dropout1d`.
Shape:
- Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`.
- Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input).
Examples::
>>> m = nn.Dropout2d(p=0.2)
>>> input = torch.randn(20, 16, 32, 32)
>>> output = m(input)
.. _Efficient Object Localization Using Convolutional Networks:
https://arxiv.org/abs/1411.4280
"""
def forward(self, input: Tensor) -> Tensor:
return F.dropout2d(input, self.p, self.training, self.inplace)
class Dropout3d(_DropoutNd):
r"""Randomly zero out entire channels.
A channel is a 3D feature map,
e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
batched input is a 3D tensor :math:`\text{input}[i, j]`.
Each channel will be zeroed out independently on every forward call with
probability :attr:`p` using samples from a Bernoulli distribution.
Usually the input comes from :class:`nn.Conv3d` modules.
As described in the paper
`Efficient Object Localization Using Convolutional Networks`_ ,
if adjacent pixels within feature maps are strongly correlated
(as is normally the case in early convolution layers) then i.i.d. dropout
will not regularize the activations and will otherwise just result
in an effective learning rate decrease.
In this case, :func:`nn.Dropout3d` will help promote independence between
feature maps and should be used instead.
Args:
p (float, optional): probability of an element to be zeroed. Default: 0.5
inplace (bool, optional): If set to ``True``, will do this operation
in-place
Shape:
- Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
- Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
Examples::
>>> m = nn.Dropout3d(p=0.2)
>>> input = torch.randn(20, 16, 4, 32, 32)
>>> output = m(input)
.. _Efficient Object Localization Using Convolutional Networks:
https://arxiv.org/abs/1411.4280
"""
def forward(self, input: Tensor) -> Tensor:
return F.dropout3d(input, self.p, self.training, self.inplace)
class AlphaDropout(_DropoutNd):
r"""Applies Alpha Dropout over the input.
Alpha Dropout is a type of Dropout that maintains the self-normalizing
property.
For an input with zero mean and unit standard deviation, the output of
Alpha Dropout maintains the original mean and standard deviation of the
input.
Alpha Dropout goes hand-in-hand with the SELU activation function, which ensures
that the outputs have zero mean and unit standard deviation.
During training, it randomly masks some of the elements of the input
tensor with probability *p* using samples from a Bernoulli distribution.
The elements to be masked are randomized on every forward call, and scaled
and shifted to maintain zero mean and unit standard deviation.
During evaluation the module simply computes an identity function.
More details can be found in the paper `Self-Normalizing Neural Networks`_ .
Args:
p (float): probability of an element to be dropped. Default: 0.5
inplace (bool, optional): If set to ``True``, will do this operation
in-place
Shape:
- Input: :math:`(*)`. Input can be of any shape
- Output: :math:`(*)`. Output is of the same shape as input
Examples::
>>> m = nn.AlphaDropout(p=0.2)
>>> input = torch.randn(20, 16)
>>> output = m(input)
.. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
"""
def forward(self, input: Tensor) -> Tensor:
return F.alpha_dropout(input, self.p, self.training)
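# Illustrative sketch (editorial addition): for roughly zero-mean, unit-variance input,
# Alpha Dropout approximately preserves those statistics during training. The exact values
# fluctuate with the random mask, so this is a statistical check, not an exact equality.
def _alpha_dropout_statistics_example():
    import torch

    torch.manual_seed(0)
    x = torch.randn(100_000)
    y = AlphaDropout(p=0.2)(x)
    return y.mean().item(), y.std().item()    # both close to 0.0 and 1.0 respectively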
class FeatureAlphaDropout(_DropoutNd):
r"""Randomly masks out entire channels.
A channel is a feature map,
e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input
is a tensor :math:`\text{input}[i, j]` of the input tensor. Instead of
setting activations to zero, as in regular Dropout, the activations are set
to the negative saturation value of the SELU activation function. More details
can be found in the paper `Self-Normalizing Neural Networks`_ .
Each element will be masked independently for each sample on every forward
call with probability :attr:`p` using samples from a Bernoulli distribution.
The elements to be masked are randomized on every forward call, and scaled
and shifted to maintain zero mean and unit variance.
Usually the input comes from :class:`nn.AlphaDropout` modules.
As described in the paper
`Efficient Object Localization Using Convolutional Networks`_ ,
if adjacent pixels within feature maps are strongly correlated
(as is normally the case in early convolution layers) then i.i.d. dropout
will not regularize the activations and will otherwise just result
in an effective learning rate decrease.
In this case, :func:`nn.FeatureAlphaDropout` will help promote independence between
feature maps and should be used instead.
Args:
p (float, optional): probability of an element to be zeroed. Default: 0.5
inplace (bool, optional): If set to ``True``, will do this operation
in-place
Shape:
- Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
- Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
Examples::
>>> m = nn.FeatureAlphaDropout(p=0.2)
>>> input = torch.randn(20, 16, 4, 32, 32)
>>> output = m(input)
.. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
.. _Efficient Object Localization Using Convolutional Networks:
https://arxiv.org/abs/1411.4280
"""
def forward(self, input: Tensor) -> Tensor:
return F.feature_alpha_dropout(input, self.p, self.training)

View File

@ -0,0 +1,158 @@
# mypy: allow-untyped-defs
from typing import Tuple, Union
from torch import Tensor
from torch.types import _size
from .module import Module
__all__ = ["Flatten", "Unflatten"]
class Flatten(Module):
r"""
Flattens a contiguous range of dims into a tensor.
For use with :class:`~nn.Sequential`, see :meth:`torch.flatten` for details.
Shape:
- Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
number of dimensions including none.
- Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.
Args:
start_dim: first dim to flatten (default = 1).
end_dim: last dim to flatten (default = -1).
Examples::
>>> input = torch.randn(32, 1, 5, 5)
>>> # With default parameters
>>> m = nn.Flatten()
>>> output = m(input)
>>> output.size()
torch.Size([32, 25])
>>> # With non-default parameters
>>> m = nn.Flatten(0, 2)
>>> output = m(input)
>>> output.size()
torch.Size([160, 5])
"""
__constants__ = ["start_dim", "end_dim"]
start_dim: int
end_dim: int
def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
super().__init__()
self.start_dim = start_dim
self.end_dim = end_dim
def forward(self, input: Tensor) -> Tensor:
return input.flatten(self.start_dim, self.end_dim)
def extra_repr(self) -> str:
return f"start_dim={self.start_dim}, end_dim={self.end_dim}"
class Unflatten(Module):
r"""
Unflattens a tensor dim, expanding it to a desired shape. For use with :class:`~nn.Sequential`.
* :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can
be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.
* :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape`
(tuple of `(name, size)` tuples) for `NamedTensor` input.
Shape:
- Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
- Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
:math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.
Args:
dim (Union[int, str]): Dimension to be unflattened
unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension
Examples::
>>> input = torch.randn(2, 50)
>>> # With tuple of ints
>>> m = nn.Sequential(
>>> nn.Linear(50, 50),
>>> nn.Unflatten(1, (2, 5, 5))
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With torch.Size
>>> m = nn.Sequential(
>>> nn.Linear(50, 50),
>>> nn.Unflatten(1, torch.Size([2, 5, 5]))
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With namedshape (tuple of tuples)
>>> input = torch.randn(2, 50, names=('N', 'features'))
>>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
>>> output = unflatten(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
"""
NamedShape = Tuple[Tuple[str, int]]
__constants__ = ["dim", "unflattened_size"]
dim: Union[int, str]
unflattened_size: Union[_size, NamedShape]
def __init__(
self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]
) -> None:
super().__init__()
if isinstance(dim, int):
self._require_tuple_int(unflattened_size)
elif isinstance(dim, str):
self._require_tuple_tuple(unflattened_size)
else:
raise TypeError("invalid argument type for dim parameter")
self.dim = dim
self.unflattened_size = unflattened_size
def _require_tuple_tuple(self, input):
if isinstance(input, tuple):
for idx, elem in enumerate(input):
if not isinstance(elem, tuple):
raise TypeError(
"unflattened_size must be tuple of tuples, "
+ f"but found element of type {type(elem).__name__} at pos {idx}"
)
return
raise TypeError(
"unflattened_size must be a tuple of tuples, "
+ f"but found type {type(input).__name__}"
)
def _require_tuple_int(self, input):
if isinstance(input, (tuple, list)):
for idx, elem in enumerate(input):
if not isinstance(elem, int):
raise TypeError(
"unflattened_size must be tuple of ints, "
+ f"but found element of type {type(elem).__name__} at pos {idx}"
)
return
raise TypeError(
f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}"
)
def forward(self, input: Tensor) -> Tensor:
return input.unflatten(self.dim, self.unflattened_size)
def extra_repr(self) -> str:
return f"dim={self.dim}, unflattened_size={self.unflattened_size}"

View File

@ -0,0 +1,315 @@
import torch.nn.functional as F
from torch import Tensor
from torch.nn.common_types import _size_any_t
from .module import Module
__all__ = ["Fold", "Unfold"]
class Fold(Module):
r"""Combines an array of sliding local blocks into a large containing tensor.
Consider a batched :attr:`input` tensor containing sliding local blocks,
e.g., patches of images, of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`,
where :math:`N` is batch dimension, :math:`C \times \prod(\text{kernel\_size})`
is the number of values within a block (a block has :math:`\prod(\text{kernel\_size})`
spatial locations each containing a :math:`C`-channeled vector), and
:math:`L` is the total number of blocks. (This is exactly the
same specification as the output shape of :class:`~torch.nn.Unfold`.) This
operation combines these local blocks into the large :attr:`output` tensor
of shape :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)`
by summing the overlapping values. Similar to :class:`~torch.nn.Unfold`, the
arguments must satisfy
.. math::
L = \prod_d \left\lfloor\frac{\text{output\_size}[d] + 2 \times \text{padding}[d] %
- \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor,
where :math:`d` is over all spatial dimensions.
* :attr:`output_size` describes the spatial shape of the large containing
tensor of the sliding local blocks. It is useful to resolve the ambiguity
when multiple input shapes map to the same number of sliding blocks, e.g.,
with ``stride > 0``.
The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify
how the sliding blocks are retrieved.
* :attr:`stride` controls the stride for the sliding blocks.
* :attr:`padding` controls the amount of implicit zero-paddings on both
sides for :attr:`padding` number of points for each dimension before
reshaping.
""" """
* :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
""" r"""
Args:
output_size (int or tuple): the shape of the spatial dimensions of the
output (i.e., ``output.sizes()[2:]``)
kernel_size (int or tuple): the size of the sliding blocks
dilation (int or tuple, optional): a parameter that controls the
stride of elements within the
neighborhood. Default: 1
padding (int or tuple, optional): implicit zero padding to be added on
both sides of input. Default: 0
stride (int or tuple): the stride of the sliding blocks in the input
spatial dimensions. Default: 1
* If :attr:`output_size`, :attr:`kernel_size`, :attr:`dilation`,
:attr:`padding` or :attr:`stride` is an int or a tuple of length 1 then
their values will be replicated across all spatial dimensions.
* For the case of two output spatial dimensions this operation is sometimes
called ``col2im``.
.. note::
:class:`~torch.nn.Fold` calculates each combined value in the resulting
large tensor by summing all values from all containing blocks.
:class:`~torch.nn.Unfold` extracts the values in the local blocks by
copying from the large tensor. So, if the blocks overlap, they are not
inverses of each other.
In general, folding and unfolding operations are related as
follows. Consider :class:`~torch.nn.Fold` and
:class:`~torch.nn.Unfold` instances created with the same
parameters:
>>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...)
>>> fold = nn.Fold(output_size=..., **fold_params)
>>> unfold = nn.Unfold(**fold_params)
Then for any (supported) ``input`` tensor the following
equality holds:
::
fold(unfold(input)) == divisor * input
where ``divisor`` is a tensor that depends only on the shape
and dtype of the ``input``:
>>> # xdoctest: +SKIP
>>> input_ones = torch.ones(input.shape, dtype=input.dtype)
>>> divisor = fold(unfold(input_ones))
When the ``divisor`` tensor contains no zero elements, then
``fold`` and ``unfold`` operations are inverses of each
other (up to constant divisor).
.. warning::
Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported.
Shape:
- Input: :math:`(N, C \times \prod(\text{kernel\_size}), L)` or :math:`(C \times \prod(\text{kernel\_size}), L)`
- Output: :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)`
or :math:`(C, \text{output\_size}[0], \text{output\_size}[1], \dots)` as described above
Examples::
>>> fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2))
>>> input = torch.randn(1, 3 * 2 * 2, 12)
>>> output = fold(input)
>>> output.size()
torch.Size([1, 3, 4, 5])
.. _link:
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""
__constants__ = ["output_size", "kernel_size", "dilation", "padding", "stride"]
output_size: _size_any_t
kernel_size: _size_any_t
dilation: _size_any_t
padding: _size_any_t
stride: _size_any_t
def __init__(
self,
output_size: _size_any_t,
kernel_size: _size_any_t,
dilation: _size_any_t = 1,
padding: _size_any_t = 0,
stride: _size_any_t = 1,
) -> None:
super().__init__()
self.output_size = output_size
self.kernel_size = kernel_size
self.dilation = dilation
self.padding = padding
self.stride = stride
def forward(self, input: Tensor) -> Tensor:
return F.fold(
input,
self.output_size,
self.kernel_size,
self.dilation,
self.padding,
self.stride,
)
def extra_repr(self) -> str:
return (
"output_size={output_size}, kernel_size={kernel_size}, "
"dilation={dilation}, padding={padding}, stride={stride}".format(
**self.__dict__
)
)
class Unfold(Module):
r"""Extracts sliding local blocks from a batched input tensor.
Consider a batched :attr:`input` tensor of shape :math:`(N, C, *)`,
where :math:`N` is the batch dimension, :math:`C` is the channel dimension,
and :math:`*` represent arbitrary spatial dimensions. This operation flattens
each sliding :attr:`kernel_size`-sized block within the spatial dimensions
of :attr:`input` into a column (i.e., last dimension) of a 3-D :attr:`output`
tensor of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`, where
:math:`C \times \prod(\text{kernel\_size})` is the total number of values
within each block (a block has :math:`\prod(\text{kernel\_size})` spatial
locations each containing a :math:`C`-channeled vector), and :math:`L` is
the total number of such blocks:
.. math::
L = \prod_d \left\lfloor\frac{\text{spatial\_size}[d] + 2 \times \text{padding}[d] %
- \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor,
where :math:`\text{spatial\_size}` is formed by the spatial dimensions
of :attr:`input` (:math:`*` above), and :math:`d` is over all spatial
dimensions.
Therefore, indexing :attr:`output` at the last dimension (column dimension)
gives all values within a certain block.
The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify
how the sliding blocks are retrieved.
* :attr:`stride` controls the stride for the sliding blocks.
* :attr:`padding` controls the amount of implicit zero-paddings on both
sides for :attr:`padding` number of points for each dimension before
reshaping.
""" """
* :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
""" r"""
Args:
kernel_size (int or tuple): the size of the sliding blocks
dilation (int or tuple, optional): a parameter that controls the
stride of elements within the
neighborhood. Default: 1
padding (int or tuple, optional): implicit zero padding to be added on
both sides of input. Default: 0
stride (int or tuple, optional): the stride of the sliding blocks in the input
spatial dimensions. Default: 1
* If :attr:`kernel_size`, :attr:`dilation`, :attr:`padding` or
:attr:`stride` is an int or a tuple of length 1, their values will be
replicated across all spatial dimensions.
* For the case of two input spatial dimensions this operation is sometimes
called ``im2col``.
.. note::
:class:`~torch.nn.Fold` calculates each combined value in the resulting
large tensor by summing all values from all containing blocks.
:class:`~torch.nn.Unfold` extracts the values in the local blocks by
copying from the large tensor. So, if the blocks overlap, they are not
inverses of each other.
In general, folding and unfolding operations are related as
follows. Consider :class:`~torch.nn.Fold` and
:class:`~torch.nn.Unfold` instances created with the same
parameters:
>>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...)
>>> fold = nn.Fold(output_size=..., **fold_params)
>>> unfold = nn.Unfold(**fold_params)
Then for any (supported) ``input`` tensor the following
equality holds:
::
fold(unfold(input)) == divisor * input
where ``divisor`` is a tensor that depends only on the shape
and dtype of the ``input``:
>>> # xdoctest: +SKIP
>>> input_ones = torch.ones(input.shape, dtype=input.dtype)
>>> divisor = fold(unfold(input_ones))
When the ``divisor`` tensor contains no zero elements, then
``fold`` and ``unfold`` operations are inverses of each
other (up to constant divisor).
.. warning::
Currently, only 4-D input tensors (batched image-like tensors) are
supported.
Shape:
- Input: :math:`(N, C, *)`
- Output: :math:`(N, C \times \prod(\text{kernel\_size}), L)` as described above
Examples::
>>> unfold = nn.Unfold(kernel_size=(2, 3))
>>> input = torch.randn(2, 5, 3, 4)
>>> output = unfold(input)
>>> # each patch contains 30 values (2x3=6 vectors, each of 5 channels)
>>> # 4 blocks (2x3 kernels) in total in the 3x4 input
>>> output.size()
torch.Size([2, 30, 4])
>>> # xdoctest: +IGNORE_WANT
>>> # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape)
>>> inp = torch.randn(1, 3, 10, 12)
>>> w = torch.randn(2, 3, 4, 5)
>>> inp_unf = torch.nn.functional.unfold(inp, (4, 5))
>>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
>>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
>>> # or equivalently (and avoiding a copy),
>>> # out = out_unf.view(1, 2, 7, 8)
>>> (torch.nn.functional.conv2d(inp, w) - out).abs().max()
tensor(1.9073e-06)
.. _link:
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""
__constants__ = ["kernel_size", "dilation", "padding", "stride"]
kernel_size: _size_any_t
dilation: _size_any_t
padding: _size_any_t
stride: _size_any_t
def __init__(
self,
kernel_size: _size_any_t,
dilation: _size_any_t = 1,
padding: _size_any_t = 0,
stride: _size_any_t = 1,
) -> None:
super().__init__()
self.kernel_size = kernel_size
self.dilation = dilation
self.padding = padding
self.stride = stride
def forward(self, input: Tensor) -> Tensor:
return F.unfold(
input, self.kernel_size, self.dilation, self.padding, self.stride
)
def extra_repr(self) -> str:
return (
"kernel_size={kernel_size}, dilation={dilation}, padding={padding},"
" stride={stride}".format(**self.__dict__)
)
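# Illustrative sketch (editorial addition): concretely checks the
# ``fold(unfold(input)) == divisor * input`` relation described in the docstrings above,
# for a small 4-D input.
def _fold_unfold_divisor_example():
    import torch

    fold_params = dict(kernel_size=(2, 2), dilation=1, padding=0, stride=1)
    fold = Fold(output_size=(4, 5), **fold_params)
    unfold = Unfold(**fold_params)

    x = torch.randn(1, 3, 4, 5)
    divisor = fold(unfold(torch.ones_like(x)))    # counts how often each position is covered
    return torch.allclose(fold(unfold(x)), divisor * x)          # True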

View File

@ -0,0 +1,471 @@
# mypy: allow-untyped-defs
import warnings
import torch.nn.functional as F
from torch import Tensor
from .batchnorm import _LazyNormBase, _NormBase
__all__ = [
"InstanceNorm1d",
"InstanceNorm2d",
"InstanceNorm3d",
"LazyInstanceNorm1d",
"LazyInstanceNorm2d",
"LazyInstanceNorm3d",
]
class _InstanceNorm(_NormBase):
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = False,
track_running_stats: bool = False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
)
def _check_input_dim(self, input):
raise NotImplementedError
def _get_no_batch_dim(self):
raise NotImplementedError
def _handle_no_batch_input(self, input):
return self._apply_instance_norm(input.unsqueeze(0)).squeeze(0)
def _apply_instance_norm(self, input):
return F.instance_norm(
input,
self.running_mean,
self.running_var,
self.weight,
self.bias,
self.training or not self.track_running_stats,
self.momentum if self.momentum is not None else 0.0,
self.eps,
)
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)
# at version 1: removed running_mean and running_var when
# track_running_stats=False (default)
if version is None and not self.track_running_stats:
running_stats_keys = []
for name in ("running_mean", "running_var"):
key = prefix + name
if key in state_dict:
running_stats_keys.append(key)
if len(running_stats_keys) > 0:
error_msgs.append(
"Unexpected running stats buffer(s) {names} for {klass} "
"with track_running_stats=False. If state_dict is a "
"checkpoint saved before 0.4.0, this may be expected "
"because {klass} does not track running stats by default "
"since 0.4.0. Please remove these keys from state_dict. If "
"the running stats are actually needed, instead set "
"track_running_stats=True in {klass} to enable them. See "
"the documentation of {klass} for details.".format(
names=" and ".join(f'"{k}"' for k in running_stats_keys),
klass=self.__class__.__name__,
)
)
for key in running_stats_keys:
state_dict.pop(key)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def forward(self, input: Tensor) -> Tensor:
self._check_input_dim(input)
feature_dim = input.dim() - self._get_no_batch_dim()
if input.size(feature_dim) != self.num_features:
if self.affine:
raise ValueError(
f"expected input's size at dim={feature_dim} to match num_features"
f" ({self.num_features}), but got: {input.size(feature_dim)}."
)
else:
warnings.warn(
f"input's size at dim={feature_dim} does not match num_features. "
"You can silence this warning by not passing in num_features, "
"which is not used because affine=False"
)
if input.dim() == self._get_no_batch_dim():
return self._handle_no_batch_input(input)
return self._apply_instance_norm(input)
class InstanceNorm1d(_InstanceNorm):
r"""Applies Instance Normalization.
This operation applies Instance Normalization
over a 2D (unbatched) or 3D (batched) input as described in the paper
`Instance Normalization: The Missing Ingredient for Fast Stylization
<https://arxiv.org/abs/1607.08022>`__.
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated per-dimension separately
for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
of size `C` (where `C` is the number of features or channels of the input) if :attr:`affine` is ``True``.
The standard-deviation is calculated via the biased estimator, equivalent to
`torch.var(input, unbiased=False)`.
By default, this layer uses instance statistics computed from input data in
both training and evaluation modes.
If :attr:`track_running_stats` is set to ``True``, during training this
layer keeps running estimates of its computed mean and variance, which are
then used for normalization during evaluation. The running estimates are
kept with a default :attr:`momentum` of 0.1.
.. note::
This :attr:`momentum` argument is different from one used in optimizer
classes and the conventional notion of momentum. Mathematically, the
update rule for running statistics here is
:math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
new observed value.
.. note::
:class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but
have some subtle differences. :class:`InstanceNorm1d` is applied
on each channel of channeled data like multidimensional time series, but
:class:`LayerNorm` is usually applied on an entire sample and often in NLP
tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
transform, while :class:`InstanceNorm1d` usually does not apply an affine
transform.
Args:
num_features: number of features or channels :math:`C` of the input
eps: a value added to the denominator for numerical stability. Default: 1e-5
momentum: the value used for the running_mean and running_var computation. Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters, initialized the same way as done for batch normalization.
Default: ``False``.
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics and always uses batch
statistics in both training and eval modes. Default: ``False``
Shape:
- Input: :math:`(N, C, L)` or :math:`(C, L)`
- Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input)
Examples::
>>> # Without Learnable Parameters
>>> m = nn.InstanceNorm1d(100)
>>> # With Learnable Parameters
>>> m = nn.InstanceNorm1d(100, affine=True)
>>> input = torch.randn(20, 100, 40)
>>> output = m(input)
"""
def _get_no_batch_dim(self):
return 2
def _check_input_dim(self, input):
if input.dim() not in (2, 3):
raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
class LazyInstanceNorm1d(_LazyNormBase, _InstanceNorm):
r"""A :class:`torch.nn.InstanceNorm1d` module with lazy initialization of the ``num_features`` argument.
The ``num_features`` argument of the :class:`InstanceNorm1d` is inferred from the ``input.size(1)``.
The attributes that will be lazily initialized are `weight`, `bias`, `running_mean` and `running_var`.
Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
on lazy modules and their limitations.
Args:
num_features: :math:`C` from an expected input of size
:math:`(N, C, L)` or :math:`(C, L)`
eps: a value added to the denominator for numerical stability. Default: 1e-5
momentum: the value used for the running_mean and running_var computation. Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters, initialized the same way as done for batch normalization.
Default: ``False``.
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics and always uses batch
statistics in both training and eval modes. Default: ``False``
Shape:
- Input: :math:`(N, C, L)` or :math:`(C, L)`
- Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input)
"""
cls_to_become = InstanceNorm1d # type: ignore[assignment]
def _get_no_batch_dim(self):
return 2
def _check_input_dim(self, input):
if input.dim() not in (2, 3):
raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
class InstanceNorm2d(_InstanceNorm):
r"""Applies Instance Normalization.
This operation applies Instance Normalization
over a 4D input (a mini-batch of 2D inputs
with additional channel dimension) as described in the paper
`Instance Normalization: The Missing Ingredient for Fast Stylization
<https://arxiv.org/abs/1607.08022>`__.
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated per-dimension separately
for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.
The standard-deviation is calculated via the biased estimator, equivalent to
`torch.var(input, unbiased=False)`.
By default, this layer uses instance statistics computed from input data in
both training and evaluation modes.
If :attr:`track_running_stats` is set to ``True``, during training this
layer keeps running estimates of its computed mean and variance, which are
then used for normalization during evaluation. The running estimates are
kept with a default :attr:`momentum` of 0.1.
.. note::
This :attr:`momentum` argument is different from one used in optimizer
classes and the conventional notion of momentum. Mathematically, the
update rule for running statistics here is
:math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
new observed value.
.. note::
:class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but
have some subtle differences. :class:`InstanceNorm2d` is applied
on each channel of channeled data like RGB images, but
:class:`LayerNorm` is usually applied on an entire sample and often in NLP
tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
transform, while :class:`InstanceNorm2d` usually does not apply an affine
transform.
Args:
num_features: :math:`C` from an expected input of size
:math:`(N, C, H, W)` or :math:`(C, H, W)`
eps: a value added to the denominator for numerical stability. Default: 1e-5
momentum: the value used for the running_mean and running_var computation. Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters, initialized the same way as done for batch normalization.
Default: ``False``.
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics and always uses batch
statistics in both training and eval modes. Default: ``False``
Shape:
- Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`
- Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
Examples::
>>> # Without Learnable Parameters
>>> m = nn.InstanceNorm2d(100)
>>> # With Learnable Parameters
>>> m = nn.InstanceNorm2d(100, affine=True)
>>> input = torch.randn(20, 100, 35, 45)
>>> output = m(input)
"""
def _get_no_batch_dim(self):
return 3
def _check_input_dim(self, input):
if input.dim() not in (3, 4):
raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)")
class LazyInstanceNorm2d(_LazyNormBase, _InstanceNorm):
r"""A :class:`torch.nn.InstanceNorm2d` module with lazy initialization of the ``num_features`` argument.
The ``num_features`` argument of the :class:`InstanceNorm2d` is inferred from the ``input.size(1)``.
The attributes that will be lazily initialized are `weight`, `bias`,
`running_mean` and `running_var`.
Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
on lazy modules and their limitations.
Args:
num_features: :math:`C` from an expected input of size
:math:`(N, C, H, W)` or :math:`(C, H, W)`
eps: a value added to the denominator for numerical stability. Default: 1e-5
momentum: the value used for the running_mean and running_var computation. Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters, initialized the same way as done for batch normalization.
Default: ``False``.
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics and always uses batch
statistics in both training and eval modes. Default: ``False``
Shape:
- Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`
- Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
"""
cls_to_become = InstanceNorm2d # type: ignore[assignment]
def _get_no_batch_dim(self):
return 3
def _check_input_dim(self, input):
if input.dim() not in (3, 4):
raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)")
class InstanceNorm3d(_InstanceNorm):
r"""Applies Instance Normalization.
This operation applies Instance Normalization
over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper
`Instance Normalization: The Missing Ingredient for Fast Stylization
<https://arxiv.org/abs/1607.08022>`__.
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated per-dimension separately
for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
of size C (where C is the input size) if :attr:`affine` is ``True``.
The standard-deviation is calculated via the biased estimator, equivalent to
`torch.var(input, unbiased=False)`.
By default, this layer uses instance statistics computed from input data in
both training and evaluation modes.
If :attr:`track_running_stats` is set to ``True``, during training this
layer keeps running estimates of its computed mean and variance, which are
then used for normalization during evaluation. The running estimates are
kept with a default :attr:`momentum` of 0.1.
.. note::
This :attr:`momentum` argument is different from one used in optimizer
classes and the conventional notion of momentum. Mathematically, the
update rule for running statistics here is
:math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
new observed value.
.. note::
:class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but
have some subtle differences. :class:`InstanceNorm3d` is applied
on each channel of channeled data like 3D models with RGB color, but
:class:`LayerNorm` is usually applied on an entire sample and often in NLP
tasks. Additionally, :class:`LayerNorm` applies an elementwise affine
transform, while :class:`InstanceNorm3d` usually does not apply an affine
transform.
Args:
num_features: :math:`C` from an expected input of size
:math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
eps: a value added to the denominator for numerical stability. Default: 1e-5
momentum: the value used for the running_mean and running_var computation. Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters, initialized the same way as done for batch normalization.
Default: ``False``.
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics and always uses batch
statistics in both training and eval modes. Default: ``False``
Shape:
- Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
- Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input)
Examples::
>>> # Without Learnable Parameters
>>> m = nn.InstanceNorm3d(100)
>>> # With Learnable Parameters
>>> m = nn.InstanceNorm3d(100, affine=True)
>>> input = torch.randn(20, 100, 35, 45, 10)
>>> output = m(input)
"""
def _get_no_batch_dim(self):
return 4
def _check_input_dim(self, input):
if input.dim() not in (4, 5):
raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)")
class LazyInstanceNorm3d(_LazyNormBase, _InstanceNorm):
r"""A :class:`torch.nn.InstanceNorm3d` module with lazy initialization of the ``num_features`` argument.
The ``num_features`` argument of the :class:`InstanceNorm3d` is inferred from the ``input.size(1)``.
The attributes that will be lazily initialized are `weight`, `bias`,
`running_mean` and `running_var`.
Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
on lazy modules and their limitations.
Args:
num_features: :math:`C` from an expected input of size
:math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
eps: a value added to the denominator for numerical stability. Default: 1e-5
momentum: the value used for the running_mean and running_var computation. Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters, initialized the same way as done for batch normalization.
Default: ``False``.
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics and always uses batch
statistics in both training and eval modes. Default: ``False``
Shape:
- Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
- Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input)
"""
cls_to_become = InstanceNorm3d # type: ignore[assignment]
def _get_no_batch_dim(self):
return 4
def _check_input_dim(self, input):
if input.dim() not in (4, 5):
raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)")

View File

@ -0,0 +1,289 @@
# mypy: allow-untyped-defs
import itertools
from typing import Any, Optional, Protocol, Type
import torch
from torch.nn.parameter import is_lazy
__all__ = ["LazyModuleMixin"]
class _LazyProtocol(Protocol):
"""This class is used to avoid errors with mypy checks for the attributes in a mixin.
https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes
"""
def _register_load_state_dict_pre_hook(self, hook):
...
def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False):
...
def _lazy_load_hook(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
...
def _get_name(self):
...
def _infer_parameters(self, module, input):
...
@property
def _parameters(self):
...
@property
def _buffers(self):
...
@property
def _non_persistent_buffers_set(self):
...
@property
def _load_hook(self):
...
@property
def _initialize_hook(self):
...
class LazyModuleMixin:
r"""A mixin for modules that lazily initialize parameters, also known as "lazy modules".
.. warning::
Lazy modules are an experimental new feature under active development,
and their API is likely to change.
Modules that lazily initialize parameters, or "lazy modules",
derive the shapes of their parameters from the first input(s)
to their forward method. Until that first forward they contain
:class:`torch.nn.UninitializedParameter` s that should not be accessed
or used, and afterward they contain regular :class:`torch.nn.Parameter` s.
Lazy modules are convenient since they don't require computing some
module arguments, like the :attr:`in_features` argument of a
typical :class:`torch.nn.Linear`.
After construction, networks with lazy modules should first
be converted to the desired dtype and placed on the expected device.
This is because lazy modules only perform shape inference so the usual dtype
and device placement behavior applies.
The lazy modules should then perform "dry runs" to initialize all the components in the module.
These "dry runs" send inputs of the correct size, dtype, and device through
the network and to each one of its lazy modules. After this the network can be used as usual.
>>> # xdoctest: +SKIP
>>> class LazyMLP(torch.nn.Module):
... def __init__(self) -> None:
... super().__init__()
... self.fc1 = torch.nn.LazyLinear(10)
... self.relu1 = torch.nn.ReLU()
... self.fc2 = torch.nn.LazyLinear(1)
... self.relu2 = torch.nn.ReLU()
...
... def forward(self, input):
... x = self.relu1(self.fc1(input))
... y = self.relu2(self.fc2(x))
... return y
>>> # constructs a network with lazy modules
>>> lazy_mlp = LazyMLP()
>>> # transforms the network's device and dtype
>>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs'
>>> lazy_mlp = lazy_mlp.cuda().double()
>>> lazy_mlp
LazyMLP(
(fc1): LazyLinear(in_features=0, out_features=10, bias=True)
(relu1): ReLU()
(fc2): LazyLinear(in_features=0, out_features=1, bias=True)
(relu2): ReLU()
)
>>> # performs a dry run to initialize the network's lazy modules
>>> lazy_mlp(torch.ones(10,10).cuda())
>>> # after initialization, LazyLinear modules become regular Linear modules
>>> lazy_mlp
LazyMLP(
(fc1): Linear(in_features=10, out_features=10, bias=True)
(relu1): ReLU()
(fc2): Linear(in_features=10, out_features=1, bias=True)
(relu2): ReLU()
)
>>> # attaches an optimizer, since parameters can now be used as usual
>>> optim = torch.optim.SGD(lazy_mlp.parameters(), lr=0.01)
A final caveat when using lazy modules is that the order of initialization of a network's
parameters may change, since the lazy modules are always initialized after other modules.
For example, if the LazyMLP class defined above had a :class:`torch.nn.LazyLinear` module
first and then a regular :class:`torch.nn.Linear` second, the second module would be
initialized on construction and the first module would be initialized during the first dry run.
This can cause the parameters of a network using lazy modules to be initialized differently
than the parameters of a network without lazy modules as the order of parameter initializations,
which often depends on a stateful random number generator, is different.
Check :doc:`/notes/randomness` for more details.
Lazy modules can be serialized with a state dict like other modules. For example:
>>> lazy_mlp = LazyMLP()
>>> # The state dict shows the uninitialized parameters
>>> lazy_mlp.state_dict()
OrderedDict([('fc1.weight', Uninitialized parameter),
('fc1.bias',
tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30,
4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])),
('fc2.weight', Uninitialized parameter),
('fc2.bias', tensor([0.0019]))])
Lazy modules can load regular :class:`torch.nn.Parameter` s (i.e. you can serialize/deserialize
initialized LazyModules and they will remain initialized)
>>> full_mlp = LazyMLP()
>>> # Dry run to initialize another module
>>> full_mlp.forward(torch.ones(10, 1))
>>> # Load an initialized state into a lazy module
>>> lazy_mlp.load_state_dict(full_mlp.state_dict())
>>> # The state dict now holds valid values
>>> lazy_mlp.state_dict()
OrderedDict([('fc1.weight',
tensor([[-0.3837],
[ 0.0907],
[ 0.6708],
[-0.5223],
[-0.9028],
[ 0.2851],
[-0.4537],
[ 0.6813],
[ 0.5766],
[-0.8678]])),
('fc1.bias',
tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30,
4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])),
('fc2.weight',
tensor([[ 0.1320, 0.2938, 0.0679, 0.2793, 0.1088, -0.1795, -0.2301, 0.2807,
0.2479, 0.1091]])),
('fc2.bias', tensor([0.0019]))])
Note, however, that the loaded parameters will not be replaced when doing a "dry run" if they are initialized
when the state is loaded. This prevents using initialized modules in different contexts.
"""
# modules inheriting from this will change their __class__ to the specified
# one after they are fully initialized
cls_to_become: Optional[Type[Any]] = None
def __init__(self: _LazyProtocol, *args, **kwargs):
# Mypy doesn't like this super call in a mixin
super().__init__(*args, **kwargs) # type: ignore[misc]
self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
self._initialize_hook = self.register_forward_pre_hook(
self._infer_parameters, with_kwargs=True
)
def _save_to_state_dict(self: _LazyProtocol, destination, prefix, keep_vars):
# This would ideally be implemented as a hook, but that would require overriding
# `detach` in UninitializedParameter to return itself, which is not clean
for name, param in self._parameters.items():
if param is not None:
if not (is_lazy(param) or keep_vars):
param = param.detach()
destination[prefix + name] = param
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
if not (is_lazy(buf) or keep_vars):
buf = buf.detach()
destination[prefix + name] = buf
def _lazy_load_hook(
self: _LazyProtocol,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""load_state_dict pre-hook function for lazy buffers and parameters.
The purpose of this hook is to adjust the current state and/or
``state_dict`` being loaded so that a module instance serialized in
both un/initialized state can be deserialized onto both un/initialized
module instance.
See comment in ``torch.nn.Module._register_load_state_dict_pre_hook``
for the details of the hook specification.
"""
for name, param in itertools.chain(
self._parameters.items(), self._buffers.items()
):
key = prefix + name
if key in state_dict and param is not None:
input_param = state_dict[key]
if is_lazy(param):
# The current parameter is not initialized, but the one being loaded is:
# create a new parameter based on the uninitialized one
if not is_lazy(input_param):
with torch.no_grad():
param.materialize(input_param.shape)
def initialize_parameters(self: _LazyProtocol, *args, **kwargs):
r"""Initialize parameters according to the input batch properties.
This adds an interface to isolate parameter initialization from the
forward pass when doing parameter shape inference.
"""
raise NotImplementedError(
f"initialize_parameters is not implemented for {self.__class__.__name__}"
)
def has_uninitialized_params(self: _LazyProtocol):
r"""Check if a module has parameters that are not initialized."""
# This avoids the JIT tracking this parameter and forcing
# custom modules' __setstate__ to add it
params = self._parameters.values()
buffers = self._buffers.values()
for param in itertools.chain(params, buffers):
if is_lazy(param):
return True
return False
# torchrec tests the code consistency with the following code
# fmt: off
def _infer_parameters(self: _LazyProtocol, module, args, kwargs=None):
r"""Infers the size and initializes the parameters according to the provided input batch.
Given a module that contains parameters that were declared inferrable
as :class:`torch.nn.parameter.UninitializedParameter`, runs a forward pass
in the complete module using the provided input to initialize all the parameters
as needed.
The module is set into evaluation mode before running the forward pass in order
to avoid saving statistics or calculating gradients
"""
kwargs = kwargs if kwargs else {}
module.initialize_parameters(*args, **kwargs)
if module.has_uninitialized_params():
raise RuntimeError(f'module {self._get_name()} has not been fully initialized')
module._initialize_hook.remove()
module._load_hook.remove()
delattr(module, '_initialize_hook')
delattr(module, '_load_hook')
if module.cls_to_become is not None:
module.__class__ = module.cls_to_become
# fmt: on
def _replicate_for_data_parallel(self: _LazyProtocol):
raise RuntimeError(
"Modules with uninitialized parameters can't be used with `DataParallel`. "
"Run a dummy forward pass to correctly initialize the modules"
)
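# Illustrative sketch (editorial addition): a minimal, hypothetical lazy module built on the
# mixin above. The name ``_LazyScale`` and its behaviour are invented for this example; the
# pattern (UninitializedParameter + initialize_parameters + materialize) mirrors the built-in
# lazy modules such as torch.nn.LazyLinear.
def _lazy_module_mixin_example():
    class _LazyScale(LazyModuleMixin, torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = torch.nn.UninitializedParameter()

        def initialize_parameters(self, input) -> None:
            # Shape inference happens here, on the first forward pre-hook call.
            if self.has_uninitialized_params():
                with torch.no_grad():
                    self.weight.materialize((input.shape[-1],))
                    torch.nn.init.ones_(self.weight)

        def forward(self, input):
            return input * self.weight

    m = _LazyScale()
    out = m(torch.randn(4, 6))                # dry run: weight materialized with shape (6,)
    return m.weight.shape, out.shape          # torch.Size([6]), torch.Size([4, 6])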
