I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

@@ -0,0 +1,22 @@
# mypy: allow-untyped-defs
# We are exposing all subpackages to the end-user.
# Because of possible inter-dependencies, we want to avoid
# cyclic imports, so we implement lazy imports as described in
# https://peps.python.org/pep-0562/
import importlib
__all__ = [
"intrinsic",
"qat",
"quantizable",
"quantized",
"sparse",
]
def __getattr__(name):
if name in __all__:
return importlib.import_module("." + name, __name__)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
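For illustration, a minimal usage sketch (assuming this first file is the standard torch/ao/nn/__init__.py and a normal torch install): with the PEP 562 hook above, a listed subpackage is imported only on first attribute access, after which the import system binds it as a real attribute on the package, so __getattr__ is not hit again for that name.

import sys
import torch.ao.nn as ao_nn

sub = ao_nn.quantizable                         # first access: __getattr__ -> importlib.import_module
assert "torch.ao.nn.quantizable" in sys.modules  # the submodule is now registered
assert ao_nn.quantizable is sub                  # later accesses are plain attribute lookups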

@@ -0,0 +1,40 @@
# mypy: allow-untyped-defs
from .modules import * # noqa: F403
from .modules.fused import _FusedModule # noqa: F403
# # Subpackages
# from . import qat # noqa: F403
# from . import quantized # noqa: F403
__all__ = [
"ConvBn1d",
"ConvBn2d",
"ConvBn3d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBnReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearReLU",
"BNReLU2d",
"BNReLU3d",
"LinearBn1d",
"LinearLeakyReLU",
"LinearTanh",
"ConvAdd2d",
"ConvAddReLU2d",
]
# We are exposing all subpackages to the end-user.
# Because of possible inter-dependencies, we want to avoid
# cyclic imports, so we implement lazy imports as described in
# https://peps.python.org/pep-0562/
def __getattr__(name):
if name in __all__:
import importlib
return importlib.import_module("." + name, __name__)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

@@ -0,0 +1,41 @@
from .fused import ( # noqa: F401
_FusedModule,
BNReLU2d,
BNReLU3d,
ConvAdd2d,
ConvAddReLU2d,
ConvBn1d,
ConvBn2d,
ConvBn3d,
ConvBnReLU1d,
ConvBnReLU2d,
ConvBnReLU3d,
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
LinearBn1d,
LinearLeakyReLU,
LinearReLU,
LinearTanh,
)
__all__ = [
"ConvBn1d",
"ConvBn2d",
"ConvBn3d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBnReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearReLU",
"BNReLU2d",
"BNReLU3d",
"LinearBn1d",
"LinearLeakyReLU",
"LinearTanh",
"ConvAdd2d",
"ConvAddReLU2d",
]

@@ -0,0 +1,245 @@
# mypy: allow-untyped-defs
import torch
from torch.nn import (
BatchNorm1d,
BatchNorm2d,
BatchNorm3d,
Conv1d,
Conv2d,
Conv3d,
Linear,
ReLU,
)
from torch.nn.utils.parametrize import type_before_parametrizations
__all__ = [
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearReLU",
"ConvBn1d",
"ConvBn2d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBn3d",
"ConvBnReLU3d",
"BNReLU2d",
"BNReLU3d",
"LinearBn1d",
"LinearLeakyReLU",
"LinearTanh",
"ConvAdd2d",
"ConvAddReLU2d",
]
# Used for identifying intrinsic modules used in quantization
class _FusedModule(torch.nn.Sequential):
pass
class ConvReLU1d(_FusedModule):
r"""This is a sequential container which calls the Conv1d and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, relu):
assert (
type_before_parametrizations(conv) == Conv1d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}"
super().__init__(conv, relu)
class ConvReLU2d(_FusedModule):
r"""This is a sequential container which calls the Conv2d and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, relu):
assert (
type_before_parametrizations(conv) == Conv2d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}"
super().__init__(conv, relu)
class ConvReLU3d(_FusedModule):
r"""This is a sequential container which calls the Conv3d and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, relu):
assert (
type_before_parametrizations(conv) == Conv3d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(relu)}"
super().__init__(conv, relu)
class LinearReLU(_FusedModule):
r"""This is a sequential container which calls the Linear and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, linear, relu):
assert (
type_before_parametrizations(linear) == Linear
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules{type_before_parametrizations(linear)}{type_before_parametrizations(relu)}"
super().__init__(linear, relu)
class ConvBn1d(_FusedModule):
r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn):
assert (
type_before_parametrizations(conv) == Conv1d
and type_before_parametrizations(bn) == BatchNorm1d
), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}"
super().__init__(conv, bn)
class ConvBn2d(_FusedModule):
r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn):
assert (
type_before_parametrizations(conv) == Conv2d
and type_before_parametrizations(bn) == BatchNorm2d
), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}"
super().__init__(conv, bn)
class ConvBnReLU1d(_FusedModule):
r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn, relu):
assert (
type_before_parametrizations(conv) == Conv1d
and type_before_parametrizations(bn) == BatchNorm1d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}{type_before_parametrizations(relu)}" # noqa: B950
super().__init__(conv, bn, relu)
class ConvBnReLU2d(_FusedModule):
r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn, relu):
assert (
type_before_parametrizations(conv) == Conv2d
and type_before_parametrizations(bn) == BatchNorm2d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}{type_before_parametrizations(relu)}" # noqa: B950
super().__init__(conv, bn, relu)
class ConvBn3d(_FusedModule):
r"""This is a sequential container which calls the Conv 3d and Batch Norm 3d modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn):
assert (
type_before_parametrizations(conv) == Conv3d
and type_before_parametrizations(bn) == BatchNorm3d
), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}"
super().__init__(conv, bn)
class ConvBnReLU3d(_FusedModule):
r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn, relu):
assert (
type_before_parametrizations(conv) == Conv3d
and type_before_parametrizations(bn) == BatchNorm3d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules{type_before_parametrizations(conv)}{type_before_parametrizations(bn)}{type_before_parametrizations(relu)}" # noqa: B950
super().__init__(conv, bn, relu)
class BNReLU2d(_FusedModule):
r"""This is a sequential container which calls the BatchNorm 2d and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, batch_norm, relu):
assert (
type_before_parametrizations(batch_norm) == BatchNorm2d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}{type_before_parametrizations(relu)}"
super().__init__(batch_norm, relu)
class BNReLU3d(_FusedModule):
r"""This is a sequential container which calls the BatchNorm 3d and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, batch_norm, relu):
assert (
type_before_parametrizations(batch_norm) == BatchNorm3d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules{type_before_parametrizations(batch_norm)}{type_before_parametrizations(relu)}"
super().__init__(batch_norm, relu)
class LinearBn1d(_FusedModule):
r"""This is a sequential container which calls the Linear and BatchNorm1d modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, linear, bn):
assert (
type_before_parametrizations(linear) == Linear
and type_before_parametrizations(bn) == BatchNorm1d
), f"Incorrect types for input modules{type_before_parametrizations(linear)}{type_before_parametrizations(bn)}"
super().__init__(linear, bn)
class LinearLeakyReLU(_FusedModule):
r"""This is a sequential container which calls the Linear and LeakyReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, linear, leaky_relu):
assert (
type(linear) == Linear and type(leaky_relu) == torch.nn.LeakyReLU
), f"Incorrect types for input modules{type(linear)}{type(leaky_relu)}"
super().__init__(linear, leaky_relu)
class LinearTanh(_FusedModule):
r"""This is a sequential container which calls the Linear and Tanh modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, linear, tanh):
assert (
type(linear) == Linear and type(tanh) == torch.nn.Tanh
), f"Incorrect types for input modules{type(linear)}{type(tanh)}"
super().__init__(linear, tanh)
class ConvAdd2d(_FusedModule):
r"""This is a sequential container which calls the Conv2d modules with extra Add.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, add):
super().__init__(conv)
self.add = add
def forward(self, x1, x2):
return self.add(self[0](x1), x2)
class ConvAddReLU2d(_FusedModule):
r"""This is a sequential container which calls the Conv2d, add, Relu.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, add, relu):
super().__init__(conv)
self.add = add
self.relu = relu
def forward(self, x1, x2):
return self.relu(self.add(self[0](x1), x2))
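A brief usage sketch (a minimal example assuming a standard torch install, importing the containers defined above via torch.ao.nn.intrinsic): in float mode these intrinsic containers are plain nn.Sequential wrappers; the actual kernel fusion only happens later, when quantization swaps them for the fused QAT/quantized modules.

import torch
from torch import nn
from torch.ao.nn.intrinsic import ConvAdd2d, ConvReLU2d

fused = ConvReLU2d(nn.Conv2d(3, 8, 3), nn.ReLU())
y = fused(torch.randn(1, 3, 16, 16))          # same result as relu(conv(x))

# ConvAdd2d / ConvAddReLU2d take a second tensor that is added to the conv output:
conv_add = ConvAdd2d(nn.Conv2d(8, 8, 3, padding=1), torch.add)
z = conv_add(torch.randn(1, 8, 16, 16), torch.randn(1, 8, 16, 16))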

@@ -0,0 +1 @@
from .modules import * # noqa: F403

@@ -0,0 +1,32 @@
from .conv_fused import (
ConvBn1d,
ConvBn2d,
ConvBn3d,
ConvBnReLU1d,
ConvBnReLU2d,
ConvBnReLU3d,
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
freeze_bn_stats,
update_bn_stats,
)
from .linear_fused import LinearBn1d
from .linear_relu import LinearReLU
__all__ = [
"LinearReLU",
"LinearBn1d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"ConvBn1d",
"ConvBn2d",
"ConvBn3d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBnReLU3d",
"update_bn_stats",
"freeze_bn_stats",
]

File diff suppressed because it is too large

@@ -0,0 +1,193 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.utils.fusion import fuse_linear_bn_weights
__all__ = [
"LinearBn1d",
]
class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
r"""
A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
with FakeQuantize modules for weight, used in quantization aware training.
We combine the interface of :class:`torch.nn.Linear` and
:class:`torch.nn.BatchNorm1d`.
Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized
to default.
Attributes:
freeze_bn: whether the BatchNorm running statistics are frozen during training
weight_fake_quant: fake quant module for weight
"""
def __init__(
self,
# Linear args
in_features,
out_features,
bias=True,
# BatchNorm1d args
# num_features: out_features
eps=1e-05,
momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None,
):
nn.modules.linear.Linear.__init__(self, in_features, out_features, bias)
assert qconfig, "qconfig must be provided for QAT module"
self.qconfig = qconfig
self.freeze_bn = freeze_bn if self.training else True
self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
self.weight_fake_quant = self.qconfig.weight()
if bias:
self.bias = Parameter(torch.empty(out_features))
else:
self.register_parameter("bias", None)
self.reset_bn_parameters()
# this needs to be called after reset_bn_parameters,
# as they modify the same state
if self.training:
if freeze_bn:
self.freeze_bn_stats()
else:
self.update_bn_stats()
else:
self.freeze_bn_stats()
def reset_running_stats(self):
self.bn.reset_running_stats()
def reset_bn_parameters(self):
self.bn.reset_running_stats()
init.uniform_(self.bn.weight)
init.zeros_(self.bn.bias)
def reset_parameters(self):
super().reset_parameters()
def update_bn_stats(self):
self.freeze_bn = False
self.bn.training = True
return self
def freeze_bn_stats(self):
self.freeze_bn = True
self.bn.training = False
return self
def forward(self, input):
assert self.bn.running_var is not None
# Scale the linear weights by BN's running statistics to reduce
# weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
# for motivation.
#
# Instead of
#
# x1 = F.linear(x0, fq(w), b)
# x2 = self.bn(x1)
#
# We have
#
# # scale the weight by previous batch's running statistics
# scale_factor = bn.w / bn.running_std_from_prev_batch
# # do the linear transformation without bias
# x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
# # reverse the scaling and add original bias
# x1_orig = x1_scaled / scale_factor + b
# x2 = self.bn(x1_orig)
running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
scale_factor = self.bn.weight / running_std
weight_shape = [1] * len(self.weight.shape)
weight_shape[0] = -1
bias_shape = [1] * len(self.weight.shape)
bias_shape[1] = -1
scaled_weight = self.weight_fake_quant(
self.weight * scale_factor.reshape(weight_shape)
)
if self.bias is not None:
zero_bias = torch.zeros_like(self.bias)
else:
zero_bias = torch.zeros(self.out_features, device=scaled_weight.device)
linear_out = F.linear(input, scaled_weight, zero_bias)
linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
if self.bias is not None:
linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
bn_out = self.bn(linear_out_orig)
return bn_out
def train(self, mode=True):
"""
BatchNorm's training behavior is controlled by the self.training flag. Prevent
changing it if BN is frozen. This makes sure that calling `model.train()`
on a model with a frozen BN behaves properly.
"""
self.training = mode
if not self.freeze_bn:
for module in self.children():
module.train(mode)
return self
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module or qparams_dict
Args: `mod`: a float module, either produced by torch.ao.quantization
utilities or directly from the user
"""
assert type(mod) == nni.LinearBn1d, (
"qat."
+ cls.__name__
+ ".from_float only works for "
+ nni.LinearBn1d.__name__
)
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
assert mod.qconfig, "Input float module must have a valid config"
qconfig = mod.qconfig
linear, bn = mod[0], mod[1]
qat_linearbn = cls(
linear.in_features,
linear.out_features,
linear.bias is not None,
bn.eps,
bn.momentum,
False,
qconfig,
)
qat_linearbn.weight = linear.weight
qat_linearbn.bias = linear.bias
qat_linearbn.bn.weight = bn.weight
qat_linearbn.bn.bias = bn.bias
qat_linearbn.bn.running_mean = bn.running_mean
qat_linearbn.bn.running_var = bn.running_var
qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
return qat_linearbn
def to_float(self):
linear = torch.nn.Linear(self.in_features, self.out_features)
assert self.bn.running_var is not None and self.bn.running_mean is not None
linear.weight, linear.bias = fuse_linear_bn_weights(
self.weight,
self.bias,
self.bn.running_mean,
self.bn.running_var,
self.bn.eps,
self.bn.weight,
self.bn.bias,
)
return linear
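A minimal numeric check (a sketch, standard torch assumed) of the identity behind the weight-scaling trick described in LinearBn1d.forward above: scaling every output row of the weight by a per-channel factor and dividing the output by the same factor leaves the linear result unchanged.

import torch
import torch.nn.functional as F

x = torch.randn(4, 10)
w = torch.randn(6, 10)
s = torch.rand(6) + 0.1                         # stands in for bn.weight / running_std
lhs = F.linear(x, w * s.reshape(-1, 1)) / s     # scale the rows, then undo it on the output
rhs = F.linear(x, w)
assert torch.allclose(lhs, rhs, atol=1e-5)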

@@ -0,0 +1,51 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.nn.functional as F
class LinearReLU(nnqat.Linear, nni._FusedModule):
r"""
A LinearReLU module fused from Linear and ReLU modules, attached with
FakeQuantize modules for weight, used in
quantization aware training.
We adopt the same interface as :class:`torch.nn.Linear`.
Similar to `torch.ao.nn.intrinsic.LinearReLU`, with FakeQuantize modules initialized to
default.
Attributes:
weight: fake quant module for weight
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.qat.LinearReLU(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment]
def __init__(self, in_features, out_features, bias=True, qconfig=None):
super().__init__(in_features, out_features, bias, qconfig)
def forward(self, input):
return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant)
def to_float(self):
linear = torch.nn.Linear(
self.in_features, self.out_features, self.bias is not None
)
linear.weight = torch.nn.Parameter(self.weight.detach())
if self.bias is not None:
linear.bias = torch.nn.Parameter(self.bias.detach())
relu = torch.nn.ReLU()
return torch.ao.nn.intrinsic.LinearReLU(linear, relu)

@@ -0,0 +1,15 @@
from .modules import * # noqa: F403
__all__ = [
"BNReLU2d",
"BNReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearReLU",
"LinearLeakyReLU",
"LinearTanh",
"ConvAdd2d",
"ConvAddReLU2d",
]

@@ -0,0 +1 @@
from .modules import * # noqa: F403

@@ -0,0 +1,6 @@
from .linear_relu import LinearReLU
__all__ = [
"LinearReLU",
]

@@ -0,0 +1,60 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized.dynamic as nnqd
__all__ = ["LinearReLU"]
class LinearReLU(nnqd.Linear):
r"""
A LinearReLU module fused from Linear and ReLU modules that can be used
for dynamic quantization.
Supports both FP16 and INT8 quantization.
We adopt the same interface as :class:`torch.ao.nn.quantized.dynamic.Linear`.
Attributes:
Same as torch.ao.nn.quantized.dynamic.Linear
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.intrinsic.quantized.dynamic.LinearReLU(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment]
def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
super().__init__(in_features, out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self._packed_params.dtype == torch.qint8:
# TODO check if we should set reduce_range = True by default here
Y = torch.ops.quantized.linear_relu_dynamic(
x, self._packed_params._packed_params, reduce_range=True
)
elif self._packed_params.dtype == torch.float16:
Y = torch.ops.quantized.linear_relu_dynamic_fp16(
x, self._packed_params._packed_params
)
else:
raise RuntimeError("Unsupported dtype on dynamic quantized linear relu!")
return Y.to(x.dtype)
def _get_name(self):
return "DynamicQuantizedLinearReLU"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, ref_qlinear_relu):
return super().from_reference(ref_qlinear_relu[0])

@@ -0,0 +1,18 @@
from .bn_relu import BNReLU2d, BNReLU3d
from .conv_add import ConvAdd2d, ConvAddReLU2d
from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
from .linear_relu import LinearLeakyReLU, LinearReLU, LinearTanh
__all__ = [
"LinearReLU",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"BNReLU2d",
"BNReLU3d",
"LinearLeakyReLU",
"LinearTanh",
"ConvAdd2d",
"ConvAddReLU2d",
]

@@ -0,0 +1,105 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq
__all__ = ["BNReLU2d", "BNReLU3d"]
class BNReLU2d(nnq.BatchNorm2d):
r"""
A BNReLU2d module is a fused module of BatchNorm2d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm2d`.
Attributes:
Same as torch.ao.nn.quantized.BatchNorm2d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU2d
def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
super().__init__(
num_features, eps=eps, momentum=momentum, device=device, dtype=dtype
)
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
return torch.ops.quantized.batch_norm2d_relu(
input,
self.weight,
self.bias,
self.running_mean,
self.running_var,
self.eps,
self.scale,
self.zero_point,
)
def _get_name(self):
return "QuantizedBNReLU2d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
# TODO: Add qat support for BNReLU2d
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, bn_relu, output_scale, output_zero_point):
return super().from_reference(bn_relu[0], output_scale, output_zero_point)
class BNReLU3d(nnq.BatchNorm3d):
r"""
A BNReLU3d module is a fused module of BatchNorm3d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm3d`.
Attributes:
Same as torch.ao.nn.quantized.BatchNorm3d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU3d
def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
super().__init__(
num_features, eps=eps, momentum=momentum, device=device, dtype=dtype
)
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 5:
raise ValueError("Input shape must be `(N, C, D, H, W)`!")
return torch.ops.quantized.batch_norm3d_relu(
input,
self.weight,
self.bias,
self.running_mean,
self.running_var,
self.eps,
self.scale,
self.zero_point,
)
def _get_name(self):
return "QuantizedBNReLU3d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
# TODO: Add qat support for BNReLU3d
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, bn_relu, output_scale, output_zero_point):
return super().from_reference(bn_relu[0], output_scale, output_zero_point)

@@ -0,0 +1,145 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq
import torch.nn.functional as F
_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
class ConvAdd2d(nnq.Conv2d):
r"""
A ConvAdd2d module is a fused module of Conv2d and Add
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
Attributes:
Same as torch.ao.nn.quantized.Conv2d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAdd2d # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
def forward(self, input, extra_input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
if self.padding_mode != "zeros":
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return torch.ops.quantized.conv2d_add(
input, extra_input, self._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedConvAdd2d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
class ConvAddReLU2d(nnq.Conv2d):
r"""
A ConvAddReLU2d module is a fused module of Conv2d, Add and Relu
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
Attributes:
Same as torch.ao.nn.quantized.Conv2d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAddReLU2d # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
def forward(self, input, extra_input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
if self.padding_mode != "zeros":
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return torch.ops.quantized.conv2d_add_relu(
input, extra_input, self._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedConvAddReLU2d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)

@@ -0,0 +1,263 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq
import torch.nn.functional as F
from torch.nn.utils import fuse_conv_bn_weights
__all__ = [
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
]
_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
# TODO: factor out the common parts to ConvNd
class ConvReLU1d(nnq.Conv1d):
r"""
A ConvReLU1d module is a fused module of Conv1d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`.
Attributes:
Same as torch.ao.nn.quantized.Conv1d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 3:
raise ValueError("Input shape must be `(N, C, L)`!")
if self.padding_mode != "zeros":
# Padding in Conv1d is stored as (p, p), need to get (p,)
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return torch.ops.quantized.conv1d_relu(
input, self._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedConvReLU1d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight,
mod.bias,
mod.bn.running_mean,
mod.bn.running_var,
mod.bn.eps,
mod.bn.weight,
mod.bn.bias,
)
return super().from_float(mod, use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
assert (
type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d
), "BatchNorm1d should be fused into Conv1d before converting to reference module"
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
class ConvReLU2d(nnq.Conv2d):
r"""
A ConvReLU2d module is a fused module of Conv2d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
Attributes:
Same as torch.ao.nn.quantized.Conv2d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
if self.padding_mode != "zeros":
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return torch.ops.quantized.conv2d_relu(
input, self._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedConvReLU2d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight,
mod.bias,
mod.bn.running_mean,
mod.bn.running_var,
mod.bn.eps,
mod.bn.weight,
mod.bn.bias,
)
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
assert (
type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d
), "BatchNorm2d should be fused into Conv2d before converting to reference module"
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
class ConvReLU3d(nnq.Conv3d):
r"""
A ConvReLU3d module is a fused module of Conv3d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`.
Attributes:
Same as torch.ao.nn.quantized.Conv3d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
assert padding_mode != "reflect", "Conv3d does not support reflection padding"
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 5:
raise ValueError("Input shape must be `(N, C, D, H, W)`!")
if self.padding_mode != "zeros":
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return torch.ops.quantized.conv3d_relu(
input, self._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedConvReLU3d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight,
mod.bias,
mod.bn.running_mean,
mod.bn.running_var,
mod.bn.eps,
mod.bn.weight,
mod.bn.bias,
)
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
assert (
type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d
), "BatchNorm3d should be fused into Conv3d before converting to reference module"
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
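For reference, a small sketch (standard torch assumed) of the fuse_conv_bn_weights step used in the from_float methods above: in eval mode, the returned weight and bias make a single conv reproduce bn(conv(x)).

import torch
from torch.nn.utils import fuse_conv_bn_weights

conv = torch.nn.Conv2d(3, 8, 3)
bn = torch.nn.BatchNorm2d(8).eval()
fused_w, fused_b = fuse_conv_bn_weights(
    conv.weight, conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias
)

x = torch.randn(1, 3, 16, 16)
ref = bn(conv(x))                                    # conv followed by BN (eval mode)
fused = torch.nn.functional.conv2d(x, fused_w, fused_b)
assert torch.allclose(ref, fused, atol=1e-5)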

@@ -0,0 +1,187 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized as nnq
from torch.ao.nn.quantized.modules.utils import _quantize_weight
__all__ = [
"LinearReLU",
"LinearLeakyReLU",
"LinearTanh",
]
class LinearReLU(nnq.Linear):
r"""
A LinearReLU module fused from Linear and ReLU modules
We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
Attributes:
Same as torch.ao.nn.quantized.Linear
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.intrinsic.LinearReLU(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment]
def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
super().__init__(in_features, out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.quantized.linear_relu(
x, self._packed_params._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedLinearReLU"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_linear_relu, output_scale, output_zero_point):
return super().from_reference(
ref_linear_relu[0], output_scale, output_zero_point
)
class LinearLeakyReLU(nnq.Linear):
r"""
For onednn backend only
A LinearLeakyReLU module fused from Linear and LeakyReLU modules
We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
Attributes:
Same as torch.ao.nn.quantized.Linear
+ negative_slope
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.intrinsic.LinearLeakyReLU(20, 30, 0.01)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_FLOAT_MODULE = nni.LinearLeakyReLU # type: ignore[assignment]
def __init__(
self, in_features, out_features, negative_slope, bias=True, dtype=torch.qint8
):
super().__init__(in_features, out_features, bias, dtype)
self.negative_slope = negative_slope
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.quantized.linear_leaky_relu(
x,
self._packed_params._packed_params,
self.scale,
self.zero_point,
self.negative_slope,
)
def _get_name(self):
return "QuantizedLinearLeakyReLU"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
assert (
type(mod) == nni.LinearLeakyReLU
), "Input float module should be LinearLeakyReLU"
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
activation_post_process = mod.activation_post_process
leaky_relu = mod[1]
mod = mod[0]
weight_post_process = mod.qconfig.weight()
weight_post_process(mod.weight)
dtype = weight_post_process.dtype
act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator]
assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
qlinear_leaky_relu = cls(
mod.in_features, mod.out_features, leaky_relu.negative_slope, dtype=dtype
)
qlinear_leaky_relu.set_weight_bias(qweight, mod.bias)
qlinear_leaky_relu.scale = float(act_scale)
qlinear_leaky_relu.zero_point = int(act_zp)
return qlinear_leaky_relu
@classmethod
def from_reference(cls, ref_mod, output_scale, output_zero_point):
linear = ref_mod[0]
leaky_relu = ref_mod[1]
qlinear_leaky_relu = cls(
linear.in_features, linear.out_features, leaky_relu.negative_slope
)
qweight = linear.get_quantized_weight()
qlinear_leaky_relu.set_weight_bias(qweight, linear.bias)
qlinear_leaky_relu.scale = float(output_scale)
qlinear_leaky_relu.zero_point = int(output_zero_point)
return qlinear_leaky_relu
class LinearTanh(nnq.Linear):
r"""
A LinearTanh module fused from Linear and Tanh modules
We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
Attributes:
Same as torch.ao.nn.quantized.Linear
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.intrinsic.LinearTanh(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_FLOAT_MODULE = nni.LinearTanh # type: ignore[assignment]
def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
super().__init__(in_features, out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.quantized.linear_tanh(
x, self._packed_params._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedLinearTanh"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
assert type(mod) == nni.LinearTanh, "Input float module should be LinearTanh"
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
activation_post_process = mod.activation_post_process
mod = mod[0]
weight_post_process = mod.qconfig.weight()
weight_post_process(mod.weight)
dtype = weight_post_process.dtype
act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator]
assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
qlinear_tanh = cls(mod.in_features, mod.out_features, dtype=dtype)
qlinear_tanh.set_weight_bias(qweight, mod.bias)
qlinear_tanh.scale = float(act_scale)
qlinear_tanh.zero_point = int(act_zp)
return qlinear_tanh
@classmethod
def from_reference(cls, ref_mod, output_scale, output_zero_point):
linear = ref_mod[0]
qlinear_tanh = cls(linear.in_features, linear.out_features)
qweight = linear.get_quantized_weight()
qlinear_tanh.set_weight_bias(qweight, linear.bias)
qlinear_tanh.scale = float(output_scale)
qlinear_tanh.zero_point = int(output_zero_point)
return qlinear_tanh

@@ -0,0 +1 @@
from .modules import * # noqa: F403

@@ -0,0 +1 @@
from .modules import * # noqa: F403

@@ -0,0 +1,4 @@
from .linear import Linear
__all__ = ["Linear"]

@@ -0,0 +1,35 @@
# mypy: allow-untyped-defs
import torch
__all__ = ["Linear"]
class Linear(torch.ao.nn.qat.Linear):
r"""
A linear module attached with FakeQuantize modules for weight,
used for dynamic quantization aware training.
We adopt the same interface as `torch.nn.Linear`, please see
https://pytorch.org/docs/stable/nn.html#torch.nn.Linear
for documentation.
Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to
default.
"""
def __init__(
self,
in_features,
out_features,
bias=True,
qconfig=None,
device=None,
dtype=None,
) -> None:
super().__init__(in_features, out_features, bias, qconfig, device, dtype)
if not torch.ao.quantization.qconfig._activation_is_memoryless(qconfig):
raise ValueError(
"Dynamic QAT requires a memoryless observer."
+ "This means a MovingAverage observer with averaging constant equal to 1"
)
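A hedged configuration sketch of the requirement above (names are from torch.ao.quantization; the specific observer and dtype choices here are assumptions, any MovingAverage observer with averaging_constant=1 satisfies the check):

import torch
from torch.ao.nn.qat.dynamic import Linear
from torch.ao.quantization import (
    FakeQuantize,
    MovingAverageMinMaxObserver,
    QConfig,
    default_weight_fake_quant,
)

memoryless_act = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    averaging_constant=1,          # "memoryless": only the current batch is observed
    dtype=torch.quint8,
)
qconfig = QConfig(activation=memoryless_act, weight=default_weight_fake_quant)
dyn_qat_linear = Linear(16, 8, qconfig=qconfig)   # passes the memoryless check above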

@@ -0,0 +1,13 @@
from .conv import Conv1d, Conv2d, Conv3d
from .embedding_ops import Embedding, EmbeddingBag
from .linear import Linear
__all__ = [
"Linear",
"Conv1d",
"Conv2d",
"Conv3d",
"Embedding",
"EmbeddingBag",
]

@@ -0,0 +1,310 @@
# mypy: allow-untyped-defs
from typing import Tuple, TypeVar, Union
import torch
import torch.nn as nn
from torch.ao.nn.intrinsic import _FusedModule
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from torch.nn.modules.utils import _pair, _single, _triple
__all__ = ["Conv1d", "Conv2d", "Conv3d"]
MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd)
class _ConvNd(nn.modules.conv._ConvNd):
_FLOAT_MODULE = MOD
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, ...],
stride: Tuple[int, ...],
padding: Tuple[int, ...],
dilation: Tuple[int, ...],
transposed: bool,
output_padding: Tuple[int, ...],
groups: int,
bias: bool,
padding_mode: str,
qconfig=None,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
nn.modules.conv._ConvNd.__init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
bias,
padding_mode,
**factory_kwargs,
)
assert qconfig, "qconfig must be provided for QAT module"
self.qconfig = qconfig
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
def forward(self, input):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
@staticmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module
Args:
`mod`: a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, (
"qat."
+ cls.__name__
+ ".from_float only works for "
+ cls._FLOAT_MODULE.__name__ # type: ignore[attr-defined]
)
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
assert mod.qconfig, "Input float module must have a valid qconfig"
if issubclass(type(mod), _FusedModule):
mod = mod[0] # type: ignore[index]
qconfig = mod.qconfig
qat_conv = cls(
mod.in_channels,
mod.out_channels,
mod.kernel_size,
stride=mod.stride,
padding=mod.padding,
dilation=mod.dilation,
groups=mod.groups,
bias=mod.bias is not None,
padding_mode=mod.padding_mode,
qconfig=qconfig,
)
qat_conv.weight = mod.weight
qat_conv.bias = mod.bias
return qat_conv
def to_float(self):
"""This works for both single qat conv, and the qat conv - relu modules
to convert the qat module to a floating point module
"""
cls = type(self)
conv = cls._FLOAT_CONV_MODULE( # type: ignore[attr-defined, operator]
self.in_channels,
self.out_channels,
self.kernel_size, # type: ignore[arg-type]
self.stride, # type: ignore[arg-type]
self.padding, # type: ignore[arg-type]
self.dilation, # type: ignore[arg-type]
self.groups,
self.bias is not None,
self.padding_mode,
)
conv.weight = torch.nn.Parameter(self.weight.detach())
if self.bias is not None:
conv.bias = torch.nn.Parameter(self.bias.detach())
# conv relu
if issubclass(cls, _FusedModule):
modules = [conv]
assert hasattr(cls, "_FLOAT_RELU_MODULE")
relu = cls._FLOAT_RELU_MODULE() # type: ignore[attr-defined]
modules.append(relu)
fused = cls._FLOAT_MODULE(*modules) # type: ignore[arg-type, attr-defined, operator]
fused.train(self.training)
return fused
else:
return conv
class Conv1d(_ConvNd, nn.Conv1d):
r"""
A Conv1d module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as :class:`~torch.nn.Conv1d`
Similar to :class:`~torch.nn.Conv1d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
"""
_FLOAT_MODULE = nn.Conv1d
_FLOAT_CONV_MODULE = nn.Conv1d
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
qconfig=None,
device=None,
dtype=None,
) -> None:
kernel_size_ = _single(kernel_size)
stride_ = _single(stride)
padding_ = padding if isinstance(padding, str) else _single(padding)
dilation_ = _single(dilation)
super().__init__(
in_channels,
out_channels,
kernel_size_,
stride=stride_,
padding=padding_,
dilation=dilation_,
transposed=False,
output_padding=_single(0),
groups=groups,
bias=bias,
padding_mode=padding_mode,
qconfig=qconfig,
device=device,
dtype=dtype,
)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(
cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
class Conv2d(_ConvNd, nn.Conv2d):
r"""
A Conv2d module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as `torch.nn.Conv2d`, please see
https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d
for documentation.
Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
"""
_FLOAT_MODULE = nn.Conv2d
_FLOAT_CONV_MODULE = nn.Conv2d
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t = 1,
padding: Union[str, _size_2_t] = 0,
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
qconfig=None,
device=None,
dtype=None,
) -> None:
kernel_size_ = _pair(kernel_size)
stride_ = _pair(stride)
padding_ = padding if isinstance(padding, str) else _pair(padding)
dilation_ = _pair(dilation)
super().__init__(
in_channels,
out_channels,
kernel_size_,
stride=stride_,
padding=padding_,
dilation=dilation_,
transposed=False,
output_padding=_pair(0),
groups=groups,
bias=bias,
padding_mode=padding_mode,
qconfig=qconfig,
device=device,
dtype=dtype,
)
def forward(self, input):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(
cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
class Conv3d(_ConvNd, nn.Conv3d):
r"""
A Conv3d module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as `torch.nn.Conv3d`, please see
https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d
for documentation.
Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
default.
Attributes:
weight_fake_quant: fake quant module for weight
"""
_FLOAT_MODULE = nn.Conv3d
_FLOAT_CONV_MODULE = nn.Conv3d
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_3_t,
stride: _size_3_t = 1,
padding: Union[str, _size_3_t] = 0,
dilation: _size_3_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
qconfig=None,
device=None,
dtype=None,
) -> None:
kernel_size_ = _triple(kernel_size)
stride_ = _triple(stride)
padding_ = padding if isinstance(padding, str) else _triple(padding)
dilation_ = _triple(dilation)
super().__init__(
in_channels,
out_channels,
kernel_size_,
stride=stride_,
padding=padding_,
dilation=dilation_,
transposed=False,
output_padding=_triple(0),
groups=groups,
bias=bias,
padding_mode=padding_mode,
qconfig=qconfig,
device=device,
dtype=dtype,
)
def forward(self, input):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(
cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)

@@ -0,0 +1,248 @@
# mypy: allow-untyped-defs
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
__all__ = ["Embedding", "EmbeddingBag"]
class Embedding(nn.Embedding):
r"""
An embedding module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as `torch.nn.Embedding`, please see
https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding
for documentation.
Similar to `torch.nn.Embedding`, with FakeQuantize modules initialized to
default.
Attributes:
weight: fake quant module for weight
"""
_FLOAT_MODULE = nn.Embedding
def __init__(
self,
num_embeddings,
embedding_dim,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
device=None,
dtype=None,
qconfig=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
num_embeddings,
embedding_dim,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
_weight,
**factory_kwargs,
)
assert qconfig, "qconfig must be provided for QAT module"
assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, (
"Embedding weights requires a qscheme of torch.per_channel_affine_float_qparams Got "
+ str(qconfig.weight().qscheme)
)
self.qconfig = qconfig
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
def forward(self, input) -> Tensor:
return F.embedding(
input,
self.weight_fake_quant(self.weight),
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, (
" qat."
+ cls.__name__
+ ".from_float only works for "
+ cls._FLOAT_MODULE.__name__
)
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
assert mod.qconfig, "Input float module must have a valid qconfig"
weight_qscheme = mod.qconfig.weight().qscheme # type: ignore[union-attr, operator]
assert weight_qscheme == torch.per_channel_affine_float_qparams, (
"Embedding weights requires a qscheme of torch.per_channel_affine_float_qparams Got "
+ str(weight_qscheme)
)
qconfig = mod.qconfig
qat_embedding_bag = cls(
mod.num_embeddings,
mod.embedding_dim,
mod.padding_idx,
mod.max_norm,
mod.norm_type,
mod.scale_grad_by_freq,
mod.sparse,
mod.weight,
qconfig=qconfig,
)
return qat_embedding_bag
def to_float(self):
embedding_bag = torch.nn.Embedding(
self.num_embeddings,
self.embedding_dim,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
None,
)
embedding_bag.weight = torch.nn.Parameter(self.weight.detach())
embedding_bag.train(self.training)
return embedding_bag
class EmbeddingBag(nn.EmbeddingBag):
r"""
An embedding bag module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as `torch.nn.EmbeddingBag`, please see
https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag
for documentation.
Similar to `torch.nn.EmbeddingBag`, with FakeQuantize modules initialized to
default.
Attributes:
weight: fake quant module for weight
"""
_FLOAT_MODULE = nn.EmbeddingBag
def __init__(
self,
num_embeddings,
embedding_dim,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
mode="mean",
sparse=False,
_weight=None,
include_last_offset=False,
padding_idx=None,
qconfig=None,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
num_embeddings,
embedding_dim,
max_norm,
norm_type,
scale_grad_by_freq,
mode,
sparse,
_weight,
include_last_offset,
padding_idx,
**factory_kwargs,
)
assert qconfig, "qconfig must be provided for QAT module"
assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, (
"Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got "
+ str(qconfig.weight().qscheme)
)
self.qconfig = qconfig
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor:
return F.embedding_bag(
input,
self.weight_fake_quant(self.weight),
offsets,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.mode,
self.sparse,
per_sample_weights,
self.include_last_offset,
self.padding_idx,
)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, (
" qat."
+ cls.__name__
+ ".from_float only works for "
+ cls._FLOAT_MODULE.__name__
)
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
assert mod.qconfig, "Input float module must have a valid qconfig"
weight_qscheme = mod.qconfig.weight().qscheme # type: ignore[union-attr, operator]
assert weight_qscheme == torch.per_channel_affine_float_qparams, (
"Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got "
+ str(weight_qscheme)
)
qconfig = mod.qconfig
qat_embedding_bag = cls(
mod.num_embeddings,
mod.embedding_dim,
mod.max_norm,
mod.norm_type,
mod.scale_grad_by_freq,
mod.mode,
mod.sparse,
mod.weight,
mod.include_last_offset,
mod.padding_idx,
qconfig=qconfig,
)
return qat_embedding_bag
def to_float(self):
embedding_bag = torch.nn.EmbeddingBag(
self.num_embeddings,
self.embedding_dim,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.mode,
self.sparse,
None,
self.include_last_offset,
self.padding_idx,
)
embedding_bag.weight = torch.nn.Parameter(self.weight.detach())
embedding_bag.train(self.training)
return embedding_bag
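A brief sketch (standard torch assumed; the qconfig name is taken from torch.ao.quantization.qconfig) of converting a float nn.Embedding with the required per_channel_affine_float_qparams weight observer:

import torch
import torch.nn as nn
from torch.ao.nn.qat import Embedding
from torch.ao.quantization.qconfig import default_embedding_qat_qconfig

float_emb = nn.Embedding(1000, 64)
float_emb.qconfig = default_embedding_qat_qconfig   # per_channel_affine_float_qparams weights

qat_emb = Embedding.from_float(float_emb)
out = qat_emb(torch.randint(0, 1000, (8, 16)))      # forward fake-quantizes the weight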

@@ -0,0 +1,96 @@
# mypy: allow-untyped-defs
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.nn.intrinsic import LinearReLU
from torch.nn.utils.parametrize import (
is_parametrized,
transfer_parametrizations_and_params,
type_before_parametrizations,
)
__all__ = ["Linear"]
class Linear(nn.Linear):
r"""
A linear module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as `torch.nn.Linear`, please see
https://pytorch.org/docs/stable/nn.html#torch.nn.Linear
for documentation.
Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to
default.
Attributes:
weight: fake quant module for weight
"""
_FLOAT_MODULE = nn.Linear
def __init__(
self,
in_features,
out_features,
bias=True,
qconfig=None,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(in_features, out_features, bias, **factory_kwargs)
assert qconfig, "qconfig must be provided for QAT module"
self.qconfig = qconfig
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
def forward(self, input):
return F.linear(input, self.weight_fake_quant(self.weight), self.bias)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert type_before_parametrizations(mod) == cls._FLOAT_MODULE, (
" qat."
+ cls.__name__
+ ".from_float only works for "
+ cls._FLOAT_MODULE.__name__
)
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
assert mod.qconfig, "Input float module must have a valid qconfig"
if type_before_parametrizations(mod) == LinearReLU:
mod = mod[0]
qconfig = mod.qconfig
qat_linear = cls(
mod.in_features,
mod.out_features,
bias=mod.bias is not None,
qconfig=qconfig,
)
if is_parametrized(mod, "weight"):
transfer_parametrizations_and_params(mod, qat_linear, "weight")
else:
qat_linear.weight = mod.weight
if is_parametrized(mod, "bias"):
transfer_parametrizations_and_params(mod, qat_linear, "bias")
else:
qat_linear.bias = mod.bias
return qat_linear
def to_float(self):
linear = torch.nn.Linear(
self.in_features, self.out_features, self.bias is not None
)
linear.weight = torch.nn.Parameter(self.weight.detach())
if self.bias is not None:
linear.bias = torch.nn.Parameter(self.bias.detach())
linear.train(self.training)
return linear
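And the analogous sketch for the QAT Linear above, including the round trip back to float (standard torch assumed; the "fbgemm" backend string is an assumption):

import torch
import torch.nn as nn
from torch.ao.nn.qat import Linear
from torch.ao.quantization import get_default_qat_qconfig

float_linear = nn.Linear(20, 30)
float_linear.qconfig = get_default_qat_qconfig("fbgemm")

qat_linear = Linear.from_float(float_linear)
out = qat_linear(torch.randn(128, 20))
restored = qat_linear.to_float()              # detached weights back in a plain nn.Linear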

@@ -0,0 +1 @@
from .modules import * # noqa: F403

@@ -0,0 +1,9 @@
from .activation import MultiheadAttention
from .rnn import LSTM, LSTMCell
__all__ = [
"LSTM",
"LSTMCell",
"MultiheadAttention",
]

@@ -0,0 +1,550 @@
# mypy: allow-untyped-defs
import warnings
from typing import Optional, Tuple
import torch
import torch.jit # this is needed to avoid a circular import
import torch.nn.functional as F
from torch import nn, Tensor
__all__ = ["MultiheadAttention"]
class MultiheadAttention(nn.MultiheadAttention):
_FLOAT_MODULE = nn.MultiheadAttention
r"""Quantizable implementation of the MultiheadAttention.
Note::
Please refer to :class:`~torch.nn.MultiheadAttention` for more
information
Allows the model to jointly attend to information from different
representation subspaces.
See reference: Attention Is All You Need
The original MHA module is not quantizable.
This reimplements it by explicitly instantiating the linear layers.
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
bias: add bias as module parameter. Default: True.
add_bias_kv: add bias to the key and value sequences at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
kdim: total number of features in key. Default: None.
vdim: total number of features in value. Default: None.
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
to :attr:`embed_dim` such that query, key, and value have the same
number of features.
Examples::
>>> import torch.ao.nn.quantizable as nnqa
>>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
Note::
Please, follow the quantization flow to convert the quantizable MHA.
"""
__constants__ = ["batch_first"]
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
bias: bool = True,
add_bias_kv: bool = False,
add_zero_attn: bool = False,
kdim: Optional[int] = None,
vdim: Optional[int] = None,
batch_first: bool = False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
embed_dim,
num_heads,
dropout,
bias,
add_bias_kv,
add_zero_attn,
kdim,
vdim,
batch_first,
**factory_kwargs,
)
self.linear_Q = nn.Linear(
self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs
)
self.linear_K = nn.Linear(
self.kdim, self.embed_dim, bias=bias, **factory_kwargs
)
self.linear_V = nn.Linear(
self.vdim, self.embed_dim, bias=bias, **factory_kwargs
)
# for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs) # type: ignore[assignment]
# Functionals
self.q_scaling_product = torch.ao.nn.quantized.FloatFunctional()
# note: importing torch.ao.nn.quantized at top creates a circular import
# Quant/Dequant
self.quant_attn_output = torch.ao.quantization.QuantStub()
self.quant_attn_output_weights = torch.ao.quantization.QuantStub()
self.dequant_q = torch.ao.quantization.DeQuantStub()
self.dequant_k = torch.ao.quantization.DeQuantStub()
self.dequant_v = torch.ao.quantization.DeQuantStub()
def _get_name(self):
return "QuantizableMultiheadAttention"
@classmethod
def from_float(cls, other):
assert type(other) == cls._FLOAT_MODULE
assert hasattr(other, "qconfig"), "The float module must have 'qconfig'"
# Setting the dropout to 0.0!
observed = cls(
other.embed_dim,
other.num_heads,
other.dropout,
(other.in_proj_bias is not None),
(other.bias_k is not None),
other.add_zero_attn,
other.kdim,
other.vdim,
other.batch_first,
)
observed.bias_k = other.bias_k
observed.bias_v = other.bias_v
observed.qconfig = other.qconfig
# Set the linear weights
# for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969
observed.out_proj.weight = other.out_proj.weight # type: ignore[has-type]
observed.out_proj.bias = other.out_proj.bias # type: ignore[has-type]
if other._qkv_same_embed_dim:
# Use separate params
bias = other.in_proj_bias
_start = 0
_end = _start + other.embed_dim
weight = other.in_proj_weight[_start:_end, :]
if bias is not None:
bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
observed.linear_Q.weight = torch.nn.Parameter(weight, weight.requires_grad)
observed.linear_Q.bias = bias
bias = other.in_proj_bias
_start = _end
_end = _start + other.embed_dim
weight = other.in_proj_weight[_start:_end, :]
if bias is not None:
bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
observed.linear_K.weight = torch.nn.Parameter(weight, weight.requires_grad)
observed.linear_K.bias = bias
bias = other.in_proj_bias
_start = _end
weight = other.in_proj_weight[_start:, :]
if bias is not None:
bias = torch.nn.Parameter(bias[_start:], bias.requires_grad)
observed.linear_V.weight = torch.nn.Parameter(weight, weight.requires_grad)
observed.linear_V.bias = bias
else:
observed.linear_Q.weight = nn.Parameter(other.q_proj_weight)
observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
if other.in_proj_bias is None:
observed.linear_Q.bias = None # type: ignore[assignment]
observed.linear_K.bias = None # type: ignore[assignment]
observed.linear_V.bias = None # type: ignore[assignment]
else:
observed.linear_Q.bias = nn.Parameter(
other.in_proj_bias[0 : other.embed_dim]
)
observed.linear_K.bias = nn.Parameter(
other.in_proj_bias[other.embed_dim : (other.embed_dim * 2)]
)
observed.linear_V.bias = nn.Parameter(
other.in_proj_bias[(other.embed_dim * 2) :]
)
observed.eval()
# Explicit prepare
observed = torch.ao.quantization.prepare(observed, inplace=True)
return observed
@torch.jit.unused
def dequantize(self):
r"""Utility to convert the quantized MHA back to float.
        The motivation for this is that it is not trivial to convert the weights
        from the format that is used in the quantized version back to the
        float format.
"""
fp = self._FLOAT_MODULE(
self.embed_dim,
self.num_heads,
self.dropout,
(self.linear_Q._weight_bias()[1] is not None),
(self.bias_k is not None),
self.add_zero_attn,
self.kdim,
self.vdim,
self.batch_first,
)
assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim
if self.bias_k is not None:
fp.bias_k = nn.Parameter(self.bias_k.dequantize())
if self.bias_v is not None:
fp.bias_v = nn.Parameter(self.bias_v.dequantize())
# Set the linear weights
        # Note: Because the linear layers are quantized, mypy does not know how
# to deal with them -- might need to ignore the typing checks.
# for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
w, b = self.out_proj._weight_bias() # type: ignore[operator, has-type]
fp.out_proj.weight = nn.Parameter(w.dequantize())
if b is not None:
fp.out_proj.bias = nn.Parameter(b)
wQ, bQ = self.linear_Q._weight_bias() # type: ignore[operator]
wQ = wQ.dequantize()
wK, bK = self.linear_K._weight_bias() # type: ignore[operator]
wK = wK.dequantize()
wV, bV = self.linear_V._weight_bias() # type: ignore[operator]
wV = wV.dequantize()
if fp._qkv_same_embed_dim:
# Use separate params
_start = 0
_end = _start + fp.embed_dim
fp.in_proj_weight[_start:_end, :] = wQ
if fp.in_proj_bias is not None:
assert all(bQ == 0)
fp.in_proj_bias[_start:_end] = bQ
_start = _end
_end = _start + fp.embed_dim
fp.in_proj_weight[_start:_end, :] = wK
if fp.in_proj_bias is not None:
assert all(bK == 0)
fp.in_proj_bias[_start:_end] = bK
_start = _end
fp.in_proj_weight[_start:, :] = wV
if fp.in_proj_bias is not None:
assert all(bV == 0)
fp.in_proj_bias[_start:] = bV
else:
fp.q_proj_weight = nn.Parameter(wQ)
fp.k_proj_weight = nn.Parameter(wK)
fp.v_proj_weight = nn.Parameter(wV)
if fp.in_proj_bias is None:
self.linear_Q.bias = None
self.linear_K.bias = None
self.linear_V.bias = None
else:
fp.in_proj_bias[0 : fp.embed_dim] = bQ
fp.in_proj_bias[fp.embed_dim : (fp.embed_dim * 2)] = bK
fp.in_proj_bias[(fp.embed_dim * 2) :] = bV
return fp
@classmethod
def from_observed(cls, other):
# The whole flow is float -> observed -> quantized
# This class does float -> observed only
# See nn.quantized.MultiheadAttention
raise NotImplementedError(
"It looks like you are trying to prepare an "
"MHA module. Please, see "
"the examples on quantizable MHAs."
)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Note::
Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more
information
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. When given a binary mask and a value is True,
the corresponding value on the attention layer will be ignored.
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
Shape:
- Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
              positions. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- is_causal: If specified, applies a causal mask as attention mask. Mutually exclusive with providing attn_mask.
Default: ``False``.
- average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
              effect when ``need_weights=True``. Default: True (i.e. average weights across heads)
- Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
- attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged
across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length,
S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(N, num_heads, L, S)`.
"""
return self._forward_impl(
query,
key,
value,
key_padding_mask,
need_weights,
attn_mask,
average_attn_weights,
is_causal,
)
def _forward_impl(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
# This version will not deal with the static key/value pairs.
# Keeping it here for future changes.
#
# TODO: This method has some duplicate lines with the
        # `torch.nn.functional.multi_head_attention_forward`. Will need to refactor.
static_k = None
static_v = None
if attn_mask is not None and is_causal:
raise AssertionError("Only allow causal mask or attn_mask")
if is_causal:
raise AssertionError("causal mask not supported by AO MHA module")
if self.batch_first:
query, key, value = (x.transpose(0, 1) for x in (query, key, value))
tgt_len, bsz, embed_dim_to_check = query.size()
assert self.embed_dim == embed_dim_to_check
# allow MHA to have different sizes for the feature dimension
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = self.embed_dim // self.num_heads
assert (
head_dim * self.num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
q = self.linear_Q(query)
k = self.linear_K(key)
v = self.linear_V(value)
q = self.q_scaling_product.mul_scalar(q, scaling)
if attn_mask is not None:
if attn_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for `attn_mask` in `nn.MultiheadAttention` is deprecated. "
"Use bool tensor instead.",
stacklevel=3,
)
attn_mask = attn_mask.to(torch.bool)
assert (
attn_mask.is_floating_point() or attn_mask.dtype == torch.bool
), f"Only float and bool types are supported for attn_mask, not {attn_mask.dtype}"
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError("The size of the 2D attn_mask is not correct.")
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * self.num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError("The size of the 3D attn_mask is not correct.")
else:
raise RuntimeError(
f"attn_mask's dimension {attn_mask.dim()} is not supported"
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for `key_padding_mask` in `nn.MultiheadAttention` is deprecated. "
"Use bool tensor instead.",
stacklevel=3,
)
key_padding_mask = key_padding_mask.to(torch.bool)
if self.bias_k is not None and self.bias_v is not None:
if static_k is None and static_v is None:
# Explicitly assert that bias_k and bias_v are not None
# in a way that TorchScript can understand.
bias_k = self.bias_k
assert bias_k is not None
bias_v = self.bias_v
assert bias_v is not None
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = F.pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = F.pad(key_padding_mask, (0, 1))
else:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
else:
assert self.bias_k is None
assert self.bias_v is None
q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
if static_k is not None:
assert static_k.size(0) == bsz * self.num_heads
assert static_k.size(2) == head_dim
k = static_k
if static_v is not None:
assert static_v.size(0) == bsz * self.num_heads
assert static_v.size(2) == head_dim
v = static_v
src_len = k.size(1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
src_len += 1
k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
if k.is_quantized:
k_zeros = torch.quantize_per_tensor(
k_zeros, k.q_scale(), k.q_zero_point(), k.dtype
)
k = torch.cat([k, k_zeros], dim=1)
v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
if v.is_quantized:
v_zeros = torch.quantize_per_tensor(
v_zeros, v.q_scale(), v.q_zero_point(), v.dtype
)
v = torch.cat([v, v_zeros], dim=1)
if attn_mask is not None:
attn_mask = F.pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = F.pad(key_padding_mask, (0, 1))
# Leaving the quantized zone here
q = self.dequant_q(q)
k = self.dequant_k(k)
v = self.dequant_v(v)
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_output_weights.size()) == [
bsz * self.num_heads,
tgt_len,
src_len,
]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
bsz * self.num_heads, tgt_len, src_len
)
attn_output_weights = F.softmax(attn_output_weights, dim=-1)
attn_output_weights = F.dropout(
attn_output_weights, p=self.dropout, training=self.training
)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
if self.batch_first:
attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
else:
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, self.embed_dim)
)
# Reentering the quantized zone
attn_output = self.quant_attn_output(attn_output)
# for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
attn_output = self.out_proj(attn_output) # type: ignore[has-type]
attn_output_weights = self.quant_attn_output_weights(attn_output_weights)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
if average_attn_weights:
attn_output_weights = attn_output_weights.mean(dim=1)
return attn_output, attn_output_weights
else:
return attn_output, None
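
# A minimal sketch of the float -> observed step implemented by `from_float` above
# (assumptions: eager-mode flow with the default "fbgemm" qconfig; converting the
# observed module to the quantized MHA needs the custom-module mapping and is not shown).
import torch
import torch.nn as nn
import torch.ao.nn.quantizable as nnqa
from torch.ao.quantization import get_default_qconfig

embed_dim, num_heads = 8, 2
float_mha = nn.MultiheadAttention(embed_dim, num_heads)
float_mha.qconfig = get_default_qconfig("fbgemm")
observed = nnqa.MultiheadAttention.from_float(float_mha)  # runs prepare() internally
q = k = v = torch.randn(5, 3, embed_dim)                  # (seq, batch, feature)
attn_out, attn_weights = observed(q, k, v)                # calibration pass through the observers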

View File

@ -0,0 +1,499 @@
"""
We will recreate all the RNN modules here because we require the modules to be
decomposed into their building blocks so that they can be observed.
"""
# mypy: allow-untyped-defs
import numbers
import warnings
from typing import Optional, Tuple
import torch
from torch import Tensor
__all__ = ["LSTMCell", "LSTM"]
class LSTMCell(torch.nn.Module):
r"""A quantizable long short-term memory (LSTM) cell.
For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`
Examples::
>>> import torch.ao.nn.quantizable as nnqa
>>> rnn = nnqa.LSTMCell(10, 20)
>>> input = torch.randn(6, 10)
>>> hx = torch.randn(3, 20)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
... hx, cx = rnn(input[i], (hx, cx))
... output.append(hx)
"""
_FLOAT_MODULE = torch.nn.LSTMCell
def __init__(
self,
input_dim: int,
hidden_dim: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.input_size = input_dim
self.hidden_size = hidden_dim
self.bias = bias
self.igates = torch.nn.Linear(
input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
)
self.hgates = torch.nn.Linear(
hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
)
self.gates = torch.ao.nn.quantized.FloatFunctional()
self.input_gate = torch.nn.Sigmoid()
self.forget_gate = torch.nn.Sigmoid()
self.cell_gate = torch.nn.Tanh()
self.output_gate = torch.nn.Sigmoid()
self.fgate_cx = torch.ao.nn.quantized.FloatFunctional()
self.igate_cgate = torch.ao.nn.quantized.FloatFunctional()
self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional()
self.ogate_cy = torch.ao.nn.quantized.FloatFunctional()
self.initial_hidden_state_qparams: Tuple[float, int] = (1.0, 0)
self.initial_cell_state_qparams: Tuple[float, int] = (1.0, 0)
self.hidden_state_dtype: torch.dtype = torch.quint8
self.cell_state_dtype: torch.dtype = torch.quint8
def forward(
self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[Tensor, Tensor]:
if hidden is None or hidden[0] is None or hidden[1] is None:
hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
hx, cx = hidden
igates = self.igates(x)
hgates = self.hgates(hx)
gates = self.gates.add(igates, hgates)
input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)
input_gate = self.input_gate(input_gate)
forget_gate = self.forget_gate(forget_gate)
cell_gate = self.cell_gate(cell_gate)
out_gate = self.output_gate(out_gate)
fgate_cx = self.fgate_cx.mul(forget_gate, cx)
igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
cy = fgate_cx_igate_cgate
# TODO: make this tanh a member of the module so its qparams can be configured
tanh_cy = torch.tanh(cy)
hy = self.ogate_cy.mul(out_gate, tanh_cy)
return hy, cy
def initialize_hidden(
self, batch_size: int, is_quantized: bool = False
) -> Tuple[Tensor, Tensor]:
h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros(
(batch_size, self.hidden_size)
)
if is_quantized:
(h_scale, h_zp) = self.initial_hidden_state_qparams
(c_scale, c_zp) = self.initial_cell_state_qparams
h = torch.quantize_per_tensor(
h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype
)
c = torch.quantize_per_tensor(
c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype
)
return h, c
def _get_name(self):
return "QuantizableLSTMCell"
@classmethod
def from_params(cls, wi, wh, bi=None, bh=None):
"""Uses the weights and biases to create a new LSTM cell.
Args:
wi, wh: Weights for the input and hidden layers
bi, bh: Biases for the input and hidden layers
"""
assert (bi is None) == (bh is None) # Either both None or both have values
input_size = wi.shape[1]
hidden_size = wh.shape[1]
cell = cls(input_dim=input_size, hidden_dim=hidden_size, bias=(bi is not None))
cell.igates.weight = torch.nn.Parameter(wi)
if bi is not None:
cell.igates.bias = torch.nn.Parameter(bi)
cell.hgates.weight = torch.nn.Parameter(wh)
if bh is not None:
cell.hgates.bias = torch.nn.Parameter(bh)
return cell
@classmethod
def from_float(cls, other, use_precomputed_fake_quant=False):
assert type(other) == cls._FLOAT_MODULE
assert hasattr(other, "qconfig"), "The float module must have 'qconfig'"
observed = cls.from_params(
other.weight_ih, other.weight_hh, other.bias_ih, other.bias_hh
)
observed.qconfig = other.qconfig
observed.igates.qconfig = other.qconfig
observed.hgates.qconfig = other.qconfig
return observed
class _LSTMSingleLayer(torch.nn.Module):
r"""A single one-directional LSTM layer.
The difference between a layer and a cell is that the layer can process a
sequence, while the cell only expects an instantaneous value.
"""
def __init__(
self,
input_dim: int,
hidden_dim: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
result = []
seq_len = x.shape[0]
for i in range(seq_len):
hidden = self.cell(x[i], hidden)
result.append(hidden[0]) # type: ignore[index]
result_tensor = torch.stack(result, 0)
return result_tensor, hidden
@classmethod
def from_params(cls, *args, **kwargs):
cell = LSTMCell.from_params(*args, **kwargs)
layer = cls(cell.input_size, cell.hidden_size, cell.bias)
layer.cell = cell
return layer
class _LSTMLayer(torch.nn.Module):
r"""A single bi-directional LSTM layer."""
def __init__(
self,
input_dim: int,
hidden_dim: int,
bias: bool = True,
batch_first: bool = False,
bidirectional: bool = False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.batch_first = batch_first
self.bidirectional = bidirectional
self.layer_fw = _LSTMSingleLayer(
input_dim, hidden_dim, bias=bias, **factory_kwargs
)
if self.bidirectional:
self.layer_bw = _LSTMSingleLayer(
input_dim, hidden_dim, bias=bias, **factory_kwargs
)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
if self.batch_first:
x = x.transpose(0, 1)
if hidden is None:
hx_fw, cx_fw = (None, None)
else:
hx_fw, cx_fw = hidden
hidden_bw: Optional[Tuple[Tensor, Tensor]] = None
if self.bidirectional:
if hx_fw is None:
hx_bw = None
else:
hx_bw = hx_fw[1]
hx_fw = hx_fw[0]
if cx_fw is None:
cx_bw = None
else:
cx_bw = cx_fw[1]
cx_fw = cx_fw[0]
if hx_bw is not None and cx_bw is not None:
hidden_bw = hx_bw, cx_bw
if hx_fw is None and cx_fw is None:
hidden_fw = None
else:
hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(
cx_fw
)
result_fw, hidden_fw = self.layer_fw(x, hidden_fw)
if hasattr(self, "layer_bw") and self.bidirectional:
x_reversed = x.flip(0)
result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
result_bw = result_bw.flip(0)
result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
if hidden_fw is None and hidden_bw is None:
h = None
c = None
elif hidden_fw is None:
(h, c) = torch.jit._unwrap_optional(hidden_bw)
elif hidden_bw is None:
(h, c) = torch.jit._unwrap_optional(hidden_fw)
else:
h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore[list-item]
c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore[list-item]
else:
result = result_fw
h, c = torch.jit._unwrap_optional(hidden_fw) # type: ignore[assignment]
if self.batch_first:
result.transpose_(0, 1)
return result, (h, c)
@classmethod
def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
r"""
There is no FP equivalent of this class. This function is here just to
mimic the behavior of the `prepare` within the `torch.ao.quantization`
flow.
"""
assert hasattr(other, "qconfig") or (qconfig is not None)
input_size = kwargs.get("input_size", other.input_size)
hidden_size = kwargs.get("hidden_size", other.hidden_size)
bias = kwargs.get("bias", other.bias)
batch_first = kwargs.get("batch_first", other.batch_first)
bidirectional = kwargs.get("bidirectional", other.bidirectional)
layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
layer.qconfig = getattr(other, "qconfig", qconfig)
wi = getattr(other, f"weight_ih_l{layer_idx}")
wh = getattr(other, f"weight_hh_l{layer_idx}")
bi = getattr(other, f"bias_ih_l{layer_idx}", None)
bh = getattr(other, f"bias_hh_l{layer_idx}", None)
layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
if other.bidirectional:
wi = getattr(other, f"weight_ih_l{layer_idx}_reverse")
wh = getattr(other, f"weight_hh_l{layer_idx}_reverse")
bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None)
bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None)
layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
return layer
class LSTM(torch.nn.Module):
r"""A quantizable long short-term memory (LSTM).
For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`
Attributes:
layers : instances of the `_LSTMLayer`
.. note::
To access the weights and biases, you need to access them per layer.
See examples below.
Examples::
>>> import torch.ao.nn.quantizable as nnqa
>>> rnn = nnqa.LSTM(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> # To get the weights:
>>> # xdoctest: +SKIP
>>> print(rnn.layers[0].weight_ih)
tensor([[...]])
>>> print(rnn.layers[0].weight_hh)
AssertionError: There is no reverse path in the non-bidirectional layer
"""
_FLOAT_MODULE = torch.nn.LSTM
def __init__(
self,
input_size: int,
hidden_size: int,
num_layers: int = 1,
bias: bool = True,
batch_first: bool = False,
dropout: float = 0.0,
bidirectional: bool = False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.batch_first = batch_first
self.dropout = float(dropout)
self.bidirectional = bidirectional
self.training = False # Default to eval mode. If we want to train, we will explicitly set to training.
num_directions = 2 if bidirectional else 1
if (
not isinstance(dropout, numbers.Number)
or not 0 <= dropout <= 1
or isinstance(dropout, bool)
):
raise ValueError(
"dropout should be a number in range [0, 1] "
"representing the probability of an element being "
"zeroed"
)
if dropout > 0:
warnings.warn(
"dropout option for quantizable LSTM is ignored. "
"If you are training, please, use nn.LSTM version "
"followed by `prepare` step."
)
if num_layers == 1:
warnings.warn(
"dropout option adds dropout after all but last "
"recurrent layer, so non-zero dropout expects "
f"num_layers greater than 1, but got dropout={dropout} "
f"and num_layers={num_layers}"
)
layers = [
_LSTMLayer(
self.input_size,
self.hidden_size,
self.bias,
batch_first=False,
bidirectional=self.bidirectional,
**factory_kwargs,
)
]
for layer in range(1, num_layers):
layers.append(
_LSTMLayer(
self.hidden_size,
self.hidden_size,
self.bias,
batch_first=False,
bidirectional=self.bidirectional,
**factory_kwargs,
)
)
self.layers = torch.nn.ModuleList(layers)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
if self.batch_first:
x = x.transpose(0, 1)
max_batch_size = x.size(1)
num_directions = 2 if self.bidirectional else 1
if hidden is None:
zeros = torch.zeros(
num_directions,
max_batch_size,
self.hidden_size,
dtype=torch.float,
device=x.device,
)
zeros.squeeze_(0)
if x.is_quantized:
zeros = torch.quantize_per_tensor(
zeros, scale=1.0, zero_point=0, dtype=x.dtype
)
hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
else:
hidden_non_opt = torch.jit._unwrap_optional(hidden)
if isinstance(hidden_non_opt[0], Tensor):
hx = hidden_non_opt[0].reshape(
self.num_layers, num_directions, max_batch_size, self.hidden_size
)
cx = hidden_non_opt[1].reshape(
self.num_layers, num_directions, max_batch_size, self.hidden_size
)
hxcx = [
(hx[idx].squeeze(0), cx[idx].squeeze(0))
for idx in range(self.num_layers)
]
else:
hxcx = hidden_non_opt
hx_list = []
cx_list = []
for idx, layer in enumerate(self.layers):
x, (h, c) = layer(x, hxcx[idx])
hx_list.append(torch.jit._unwrap_optional(h))
cx_list.append(torch.jit._unwrap_optional(c))
hx_tensor = torch.stack(hx_list)
cx_tensor = torch.stack(cx_list)
# We are creating another dimension for bidirectional case
# need to collapse it
hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1])
cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1])
if self.batch_first:
x = x.transpose(0, 1)
return x, (hx_tensor, cx_tensor)
def _get_name(self):
return "QuantizableLSTM"
@classmethod
def from_float(cls, other, qconfig=None):
assert isinstance(other, cls._FLOAT_MODULE)
assert hasattr(other, "qconfig") or qconfig
observed = cls(
other.input_size,
other.hidden_size,
other.num_layers,
other.bias,
other.batch_first,
other.dropout,
other.bidirectional,
)
observed.qconfig = getattr(other, "qconfig", qconfig)
for idx in range(other.num_layers):
observed.layers[idx] = _LSTMLayer.from_float(
other, idx, qconfig, batch_first=False
)
# Prepare the model
if other.training:
observed.train()
observed = torch.ao.quantization.prepare_qat(observed, inplace=True)
else:
observed.eval()
observed = torch.ao.quantization.prepare(observed, inplace=True)
return observed
@classmethod
def from_observed(cls, other):
# The whole flow is float -> observed -> quantized
# This class does float -> observed only
raise NotImplementedError(
"It looks like you are trying to convert a "
"non-quantizable LSTM module. Please, see "
"the examples on quantizable LSTMs."
)
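
# A minimal sketch of the float -> observed step implemented by `LSTM.from_float` above
# (assumptions: eager-mode flow with the default "fbgemm" qconfig; the final convert
# step to the quantized LSTM is not shown).
import torch
import torch.ao.nn.quantizable as nnqa
from torch.ao.quantization import get_default_qconfig

float_lstm = torch.nn.LSTM(10, 20, num_layers=2).eval()
float_lstm.qconfig = get_default_qconfig("fbgemm")
observed = nnqa.LSTM.from_float(float_lstm)  # rebuilds each layer from the LSTMCell building blocks
x = torch.randn(5, 3, 10)                    # (seq, batch, feature)
out, (hn, cn) = observed(x)                  # calibration pass through the observers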

View File

@ -0,0 +1,39 @@
from . import functional
from .modules import * # noqa: F403
from .modules import MaxPool2d
__all__ = [
"BatchNorm2d",
"BatchNorm3d",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
"DeQuantize",
"ELU",
"Embedding",
"EmbeddingBag",
"GroupNorm",
"Hardswish",
"InstanceNorm1d",
"InstanceNorm2d",
"InstanceNorm3d",
"LayerNorm",
"LeakyReLU",
"Linear",
"LSTM",
"MultiheadAttention",
"Quantize",
"ReLU6",
"Sigmoid",
"Softmax",
"Dropout",
"PReLU",
# Wrapper modules
"FloatFunctional",
"FXFloatFunctional",
"QFunctional",
]

View File

@ -0,0 +1 @@
from .modules import * # noqa: F403

View File

@ -0,0 +1,26 @@
from .conv import (
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
)
from .linear import Linear
from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNNCell
__all__ = [
"Linear",
"LSTM",
"GRU",
"LSTMCell",
"RNNCell",
"GRUCell",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
]

View File

@ -0,0 +1,520 @@
# mypy: allow-untyped-defs
r"""Dynamically quantized convolution modules."""
import warnings
import torch
import torch.ao.nn.quantized as nnq
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch._ops import ops
from torch.ao.nn.quantized.modules.conv import _reverse_repeat_padding
from torch.nn.common_types import _size_1_t
from torch.nn.modules.utils import _pair, _single, _triple
__all__ = [
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
]
class Conv1d(nnq.Conv1d):
r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.
For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv1d` and :class:`~torch.ao.nn.quantized.dynamic.Conv1d`.
Attributes:
weight (Tensor): packed tensor derived from the learnable weight
parameter.
scale (Tensor): scalar for the output scale
zero_point (Tensor): scalar for the output zero point
See :class:`~torch.nn.Conv1d` for other attributes.
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.quantized.dynamic.Conv1d(16, 33, 3, stride=2)
>>> input = torch.randn(20, 16, 100)
>>> output = m(input)
"""
_FLOAT_MODULE = nn.Conv1d
_NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment]
_NNI_CONV_RELU_MODULE = None # type: ignore[assignment]
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: _size_1_t = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device=None,
dtype=None,
reduce_range=True,
):
warnings.warn(
f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950
)
factory_kwargs = {"device": device, "dtype": dtype}
kernel_size = _single(kernel_size)
stride = _single(stride)
padding = padding if isinstance(padding, str) else _single(padding)
dilation = _single(dilation)
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
**factory_kwargs,
)
def _get_name(self):
return "DynamicQuantizedConv1d"
def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 3:
raise ValueError("Input shape must be `(N, C, L)`!")
if self.padding_mode != "zeros":
# Padding in Conv1d is stored as (p, p), need to get (p,)
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return ops.quantized.conv1d_dynamic(input, self._packed_params, reduce_range)
class Conv2d(nnq.Conv2d):
r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.
For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv2d` and :class:`~torch.ao.nn.quantized.dynamic.Conv2d`.
Attributes:
weight (Tensor): packed tensor derived from the learnable weight
parameter.
scale (Tensor): scalar for the output scale
zero_point (Tensor): scalar for the output zero point
See :class:`~torch.nn.Conv2d` for other attributes.
Examples::
>>> # xdoctest: +SKIP
>>> # With square kernels and equal stride
>>> m = nn.quantized.dynamic.Conv2d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding
>>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
>>> # non-square kernels and unequal stride and with padding and dilation
>>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
>>> input = torch.randn(20, 16, 50, 100)
>>> output = m(input)
"""
_FLOAT_MODULE = nn.Conv2d
_NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment]
_NNI_CONV_RELU_MODULE = None # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
warnings.warn(
f"The current implementation of the {self._get_name()} module "
"has poor numerical accuracy and its use is not recommended"
)
factory_kwargs = {"device": device, "dtype": dtype}
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
**factory_kwargs,
)
def _get_name(self):
return "DynamicQuantizedConv2d"
def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
if self.padding_mode != "zeros":
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return ops.quantized.conv2d_dynamic(input, self._packed_params, reduce_range)
class Conv3d(nnq.Conv3d):
r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.
For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv3d` and :class:`~torch.ao.nn.quantized.dynamic.Conv3d`.
Attributes:
weight (Tensor): packed tensor derived from the learnable weight
parameter.
scale (Tensor): scalar for the output scale
zero_point (Tensor): scalar for the output zero point
See :class:`~torch.nn.Conv3d` for other attributes.
Examples::
>>> # xdoctest: +SKIP
>>> # With square kernels and equal stride
>>> m = nn.quantized.dynamic.Conv3d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding
>>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
>>> # non-square kernels and unequal stride and with padding and dilation
>>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
>>> input = torch.randn(20, 16, 56, 56, 56)
>>> output = m(input)
"""
_FLOAT_MODULE = nn.Conv3d
_NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment]
_NNI_CONV_RELU_MODULE = None # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
warnings.warn(
f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950
)
assert padding_mode != "reflect", "Conv3d does not support reflection padding"
factory_kwargs = {"device": device, "dtype": dtype}
kernel_size = _triple(kernel_size)
stride = _triple(stride)
padding = _triple(padding)
dilation = _triple(dilation)
super()._init(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
False,
_triple(0),
groups,
bias,
padding_mode,
**factory_kwargs,
)
def _get_name(self):
return "DynamicQuantizedConv3d"
def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 5:
raise ValueError("Input shape must be `(N, C, D, H, W)`!")
if self.padding_mode != "zeros":
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return ops.quantized.conv3d_dynamic(input, self._packed_params, reduce_range)
class ConvTranspose1d(nnq.ConvTranspose1d):
r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs.
For details on input arguments, parameters, and implementation see
:class:`~torch.nn.ConvTranspose1d`.
For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv1d`
Attributes:
weight (Tensor): packed tensor derived from the learnable weight
parameter.
scale (Tensor): scalar for the output scale
zero_point (Tensor): scalar for the output zero point
See :class:`~torch.nn.ConvTranspose1d` for other attributes.
Examples::
>>> # xdoctest: +SKIP
>>> # With square kernels and equal stride
>>> m = nndq.ConvTranspose1d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding
>>> m = nndq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
>>> output = m(input)
>>> # exact output size can be also specified as an argument
>>> downsample = nndq.Conv1d(16, 16, 3, stride=2, padding=1)
>>> upsample = nndq.ConvTranspose1d(16, 16, 3, stride=2, padding=1)
>>> h = downsample(input)
>>> h.size()
torch.Size([1, 16, 6])
>>> output = upsample(h, output_size=input.size())
>>> output.size()
torch.Size([1, 16, 12])
"""
_FLOAT_MODULE = nn.ConvTranspose1d
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
output_padding=0,
groups=1,
bias=True,
dilation=1,
padding_mode="zeros",
device=None,
dtype=None,
):
warnings.warn(
f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
output_padding,
groups,
bias,
dilation,
padding_mode,
**factory_kwargs,
)
def _get_name(self):
return "DynamicQuantizedConvTranspose1d"
def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 3:
raise ValueError("Input shape must be `(N, C, L)`!")
return torch.ops.quantized.conv_transpose1d_dynamic(
input, self._packed_params, reduce_range
)
class ConvTranspose2d(nnq.ConvTranspose2d):
r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs.
For details on input arguments, parameters, and implementation see
:class:`~torch.nn.ConvTranspose2d`.
For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv2d`
Attributes:
weight (Tensor): packed tensor derived from the learnable weight
parameter.
scale (Tensor): scalar for the output scale
zero_point (Tensor): scalar for the output zero point
See :class:`~torch.nn.ConvTranspose2d` for other attributes.
Examples::
>>> # xdoctest: +SKIP
>>> # With square kernels and equal stride
>>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding
>>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
>>> output = m(input)
>>> # exact output size can be also specified as an argument
>>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1)
>>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
>>> h = downsample(input)
>>> h.size()
torch.Size([1, 16, 6, 6])
>>> output = upsample(h, output_size=input.size())
>>> output.size()
torch.Size([1, 16, 12, 12])
"""
_FLOAT_MODULE = nn.ConvTranspose2d
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
output_padding=0,
groups=1,
bias=True,
dilation=1,
padding_mode="zeros",
device=None,
dtype=None,
):
warnings.warn(
f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
output_padding,
groups,
bias,
dilation,
padding_mode,
**factory_kwargs,
)
def _get_name(self):
return "DynamicQuantizedConvTranspose2d"
def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
return ops.quantized.conv_transpose2d_dynamic(
input, self._packed_params, reduce_range
)
class ConvTranspose3d(nnq.ConvTranspose3d):
r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs.
For details on input arguments, parameters, and implementation see
:class:`~torch.nn.ConvTranspose3d`.
For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv3d`
Attributes:
weight (Tensor): packed tensor derived from the learnable weight
parameter.
scale (Tensor): scalar for the output scale
zero_point (Tensor): scalar for the output zero point
See :class:`~torch.nn.ConvTranspose3d` for other attributes.
Examples::
>>> # xdoctest: +SKIP
>>> # With cubic kernels and equal stride
>>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2)
>>> # non-cubic kernels and unequal stride and with padding
>>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
>>> output = m(input)
>>> # exact output size can be also specified as an argument
>>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1)
>>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1)
>>> h = downsample(input)
>>> h.size()
torch.Size([1, 16, 6, 6, 6])
>>> output = upsample(h, output_size=input.size())
>>> output.size()
torch.Size([1, 16, 12, 12, 12])
"""
_FLOAT_MODULE = nn.ConvTranspose3d
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
output_padding=0,
groups=1,
bias=True,
dilation=1,
padding_mode="zeros",
device=None,
dtype=None,
):
warnings.warn(
f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
output_padding,
groups,
bias,
dilation,
padding_mode,
**factory_kwargs,
)
def _get_name(self):
return "DynamicQuantizedConvTranspose3d"
def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 5:
raise ValueError("Input shape must be `(N, C, T, H, W)`!")
return ops.quantized.conv_transpose3d_dynamic(
input, self._packed_params, reduce_range
)
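
# A minimal sketch mirroring the docstring examples above (assumption: a backend with
# the dynamic quantized conv ops available, e.g. fbgemm): the dynamic conv modules take
# ordinary float tensors, quantize activations on the fly, and return float outputs,
# so no QuantStub/DeQuantStub pair is needed around them.
import torch
import torch.ao.nn.quantized.dynamic as nnqd

m = nnqd.Conv2d(16, 33, 3, stride=2)  # emits the poor-numerical-accuracy warning from __init__
x = torch.randn(20, 16, 50, 100)      # float NCHW input
y = m(x)                              # float output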

View File

@ -0,0 +1,165 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized as nnq
from torch.ao.nn.quantized.modules.utils import _quantize_weight
__all__ = [
"Linear",
]
class Linear(nnq.Linear):
r"""
A dynamic quantized linear module with floating point tensor as inputs and outputs.
    We adopt the same interface as `torch.nn.Linear`; please see
https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.
Similar to :class:`torch.nn.Linear`, attributes will be randomly
initialized at module creation time and will be overwritten later
Attributes:
weight (Tensor): the non-learnable quantized weights of the module which are of
shape :math:`(\text{out\_features}, \text{in\_features})`.
bias (Tensor): the non-learnable floating point bias of the module of shape
:math:`(\text{out\_features})`. If :attr:`bias` is ``True``,
the values are initialized to zero.
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.quantized.dynamic.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
# version used in this class is different from the parent class nnq.Linear
_version = 4
def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
super().__init__(in_features, out_features, bias_, dtype=dtype)
# We don't muck around with buffers or attributes or anything here
# to keep the module simple. *everything* is simply a Python attribute.
# Serialization logic is explicitly handled in the below serialization and
# deserialization modules
self.version = 4
def forward(self, x):
# Note that we can handle self.bias == None case.
if self._packed_params.dtype == torch.qint8:
if self.version is None or self.version < 4:
Y = torch.ops.quantized.linear_dynamic(
x, self._packed_params._packed_params
)
else:
Y = torch.ops.quantized.linear_dynamic(
x, self._packed_params._packed_params, reduce_range=True
)
elif self._packed_params.dtype == torch.float16:
Y = torch.ops.quantized.linear_dynamic_fp16(
x, self._packed_params._packed_params
)
else:
raise RuntimeError("Unsupported dtype on dynamic quantized linear!")
return Y.to(x.dtype)
def _get_name(self):
return "DynamicQuantizedLinear"
def extra_repr(self):
extra_repr_str = f"in_features={self.in_features}, out_features={self.out_features}, dtype={self._packed_params.dtype}"
if self._packed_params.dtype == torch.qint8:
extra_repr_str += f", qscheme={self.weight().qscheme()}"
return extra_repr_str
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)
self.version = version
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
False,
missing_keys,
unexpected_keys,
error_msgs,
)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a dynamic quantized module from a float module or qparams_dict
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
"""
float_modules = [
torch.nn.Linear,
torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
torch.ao.nn.intrinsic.modules.fused.LinearReLU,
torch.ao.nn.qat.dynamic.Linear,
]
assert (
type(mod) in float_modules
), "nn.quantized.dynamic.Linear.from_float only works for one of" + str(
[float_mod.__name__ for float_mod in float_modules]
)
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
if type(mod) == nni.LinearReLU:
mod = mod[0]
if mod.qconfig is not None and mod.qconfig.weight is not None:
weight_observer = mod.qconfig.weight()
else:
# We have the circular import issues if we import the qconfig in the beginning of this file:
# https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
# import until we need it.
from torch.ao.quantization.qconfig import default_dynamic_qconfig
weight_observer = default_dynamic_qconfig.weight()
dtype = weight_observer.dtype
assert dtype in [torch.qint8, torch.float16], (
"The only supported dtypes for "
f"dynamic quantized linear are qint8 and float16 got: {dtype}"
)
weight_observer(mod.weight)
if dtype == torch.qint8:
qweight = _quantize_weight(mod.weight.float(), weight_observer)
elif dtype == torch.float16:
qweight = mod.weight.float()
else:
raise RuntimeError(
"Unsupported dtype specified for dynamic quantized Linear!"
)
qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
qlinear.set_weight_bias(qweight, mod.bias)
return qlinear
@classmethod
def from_reference(cls, ref_qlinear):
"""Create a (fbgemm/qnnpack) dynamic quantized module from a reference quantized
module
Args:
ref_qlinear (Module): a reference quantized module, either produced by
torch.ao.quantization functions or provided by the user
"""
qlinear = cls(
ref_qlinear.in_features,
ref_qlinear.out_features,
dtype=ref_qlinear.weight_dtype,
)
qweight = ref_qlinear.get_quantized_weight()
bias = ref_qlinear.bias
qlinear.set_weight_bias(qweight, bias)
return qlinear
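
# A minimal sketch of the usual entry point that produces this module (assumption: the
# default dynamic qconfig with qint8 weights; `TinyModel` is an illustrative name).
import torch
from torch.ao.quantization import quantize_dynamic

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(20, 30)

    def forward(self, x):
        return self.fc(x)

model = TinyModel().eval()
qmodel = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
print(qmodel.fc)                # repr reports DynamicQuantizedLinear(in_features=20, out_features=30, ...)
y = qmodel(torch.randn(4, 20))  # float in, float out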

File diff suppressed because it is too large

View File

@ -0,0 +1,778 @@
# mypy: allow-untyped-defs
r""" Functional interface (quantized)."""
import warnings
from typing import List, Optional
import torch
from torch import Tensor
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair, _triple
from .modules.utils import _pair_from_first
# Although some of the functions and docstrings are mirrored from torch.nn,
# we want to have them here for future changes.
__all__ = [
"avg_pool2d",
"avg_pool3d",
"adaptive_avg_pool2d",
"adaptive_avg_pool3d",
"conv1d",
"conv2d",
"conv3d",
"interpolate",
"linear",
"max_pool1d",
"max_pool2d",
"celu",
"leaky_relu",
"hardtanh",
"hardswish",
"threshold",
"elu",
"hardsigmoid",
"clamp",
"upsample",
"upsample_bilinear",
"upsample_nearest",
]
def avg_pool2d(
input,
kernel_size,
stride=None,
padding=0,
ceil_mode=False,
count_include_pad=True,
divisor_override=None,
):
r"""
Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size
:math:`sH \times sW` steps. The number of output features is equal to the number of
input planes.
.. note:: The input quantization parameters propagate to the output.
See :class:`~torch.ao.nn.quantized.AvgPool2d` for details and output shape.
Args:
input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
kernel_size: size of the pooling region. Can be a single number or a
tuple `(kH, kW)`
stride: stride of the pooling operation. Can be a single number or a
tuple `(sH, sW)`. Default: :attr:`kernel_size`
padding: implicit zero paddings on both sides of the input. Can be a
single number or a tuple `(padH, padW)`. Default: 0
ceil_mode: when True, will use `ceil` instead of `floor` in the formula
to compute the output shape. Default: ``False``
count_include_pad: when True, will include the zero-padding in the
averaging calculation. Default: ``True``
divisor_override: if specified, it will be used as divisor, otherwise
size of the pooling region will be used. Default: None
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.avg_pool2d' must be quantized!")
return torch.nn.functional.avg_pool2d(
input,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override,
)
def avg_pool3d(
input,
kernel_size,
stride=None,
padding=0,
ceil_mode=False,
count_include_pad=True,
divisor_override=None,
):
r"""
    Applies 3D average-pooling operation in :math:`kD \times kH \times kW` regions by step size
:math:`sD \times sH \times sW` steps. The number of output features is equal to the number of
input planes.
.. note:: The input quantization parameters propagate to the output.
Args:
        input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)`
kernel_size: size of the pooling region. Can be a single number or a
tuple `(kD, kH, kW)`
stride: stride of the pooling operation. Can be a single number or a
tuple `(sD, sH, sW)`. Default: :attr:`kernel_size`
padding: implicit zero paddings on both sides of the input. Can be a
single number or a tuple `(padD, padH, padW)`. Default: 0
ceil_mode: when True, will use `ceil` instead of `floor` in the formula
to compute the output shape. Default: ``False``
count_include_pad: when True, will include the zero-padding in the
averaging calculation. Default: ``True``
divisor_override: if specified, it will be used as divisor, otherwise
size of the pooling region will be used. Default: None
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.avg_pool3d' must be quantized!")
return torch.nn.functional.avg_pool3d(
input,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override,
)
def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
r"""
Applies a 2D adaptive average pooling over a quantized input signal composed
of several quantized input planes.
.. note:: The input quantization parameters propagate to the output.
See :class:`~torch.ao.nn.quantized.AdaptiveAvgPool2d` for details and output shape.
Args:
output_size: the target output size (single integer or
double-integer tuple)
"""
if not input.is_quantized:
raise ValueError(
"Input to 'quantized.functional.adaptive_avg_pool2d' must be quantized!"
)
return torch.nn.functional.adaptive_avg_pool2d(input, output_size)
def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
r"""
Applies a 3D adaptive average pooling over a quantized input signal composed
of several quantized input planes.
.. note:: The input quantization parameters propagate to the output.
See :class:`~torch.ao.nn.quantized.AdaptiveAvgPool3d` for details and output shape.
Args:
output_size: the target output size (single integer or
double-integer tuple)
"""
if not input.is_quantized:
raise ValueError(
"Input to 'quantized.functional.adaptive_avg_pool3d' must be quantized!"
)
return torch.nn.functional.adaptive_avg_pool3d(input, output_size)
def conv1d(
input,
weight,
bias,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode="zeros",
scale=1.0,
zero_point=0,
dtype=torch.quint8,
):
r"""
Applies a 1D convolution over a quantized 1D input composed of several input
planes.
See :class:`~torch.ao.nn.quantized.Conv1d` for details and output shape.
Args:
input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , iW)`
bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
stride: the stride of the convolving kernel. Can be a single number or a
tuple `(sW,)`. Default: 1
padding: implicit paddings on both sides of the input. Can be a
single number or a tuple `(padW,)`. Default: 0
dilation: the spacing between kernel elements. Can be a single number or
a tuple `(dW,)`. Default: 1
groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
number of groups. Default: 1
padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros"
scale: quantization scale for the output. Default: 1.0
zero_point: quantization zero_point for the output. Default: 0
dtype: quantization data type to use. Default: ``torch.quint8``
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> from torch.ao.nn.quantized import functional as qF
>>> filters = torch.randn(33, 16, 3, dtype=torch.float)
>>> inputs = torch.randn(20, 16, 50, dtype=torch.float)
>>> bias = torch.randn(33, dtype=torch.float)
>>>
>>> scale, zero_point = 1.0, 0
>>> dtype_inputs = torch.quint8
>>> dtype_filters = torch.qint8
>>>
>>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
>>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
>>> qF.conv1d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
""" # noqa: E501
if padding_mode != "zeros":
raise NotImplementedError("Only zero-padding is supported!")
if input.dtype != torch.quint8:
raise NotImplementedError(
"Only torch.quint8 is supported for activation tensor!"
)
if weight.dtype != torch.qint8:
raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
if input.ndim != 3:
raise ValueError("Input shape must be `(N, C, L)`!")
stride = _pair_from_first(stride)
padding = _pair_from_first(padding)
dilation = _pair_from_first(dilation)
packed_params = torch.ops.quantized.conv1d_prepack(
weight, bias, stride, padding, dilation, groups
)
return torch.ops.quantized.conv1d(input, packed_params, scale, zero_point)
def conv2d(
input,
weight,
bias,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode="zeros",
scale=1.0,
zero_point=0,
dtype=torch.quint8,
):
r"""
Applies a 2D convolution over a quantized 2D input composed of several input
planes.
See :class:`~torch.ao.nn.quantized.Conv2d` for details and output shape.
Args:
input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)`
bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
stride: the stride of the convolving kernel. Can be a single number or a
tuple `(sH, sW)`. Default: 1
padding: implicit paddings on both sides of the input. Can be a
single number or a tuple `(padH, padW)`. Default: 0
dilation: the spacing between kernel elements. Can be a single number or
a tuple `(dH, dW)`. Default: 1
groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
number of groups. Default: 1
padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros"
scale: quantization scale for the output. Default: 1.0
zero_point: quantization zero_point for the output. Default: 0
dtype: quantization data type to use. Default: ``torch.quint8``
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> from torch.ao.nn.quantized import functional as qF
>>> filters = torch.randn(8, 4, 3, 3, dtype=torch.float)
>>> inputs = torch.randn(1, 4, 5, 5, dtype=torch.float)
>>> bias = torch.randn(8, dtype=torch.float)
>>>
>>> scale, zero_point = 1.0, 0
>>> dtype_inputs = torch.quint8
>>> dtype_filters = torch.qint8
>>>
>>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
>>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
>>> qF.conv2d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
""" # noqa: E501
if padding_mode != "zeros":
raise NotImplementedError("Only zero-padding is supported!")
if input.dtype != torch.quint8:
raise NotImplementedError(
"Only torch.quint8 is supported for activation tensor!"
)
if weight.dtype != torch.qint8:
raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
if input.ndim != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
packed_params = torch.ops.quantized.conv2d_prepack(
weight, bias, stride, padding, dilation, groups
)
return torch.ops.quantized.conv2d(input, packed_params, scale, zero_point)
def conv3d(
input,
weight,
bias,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode="zeros",
scale=1.0,
zero_point=0,
dtype=torch.quint8,
):
r"""
Applies a 3D convolution over a quantized 3D input composed of several input
planes.
See :class:`~torch.ao.nn.quantized.Conv3d` for details and output shape.
Args:
input: quantized input tensor of shape
:math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)`
weight: quantized filters of shape
:math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kD , kH , kW)`
bias: **non-quantized** bias tensor of shape
:math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
stride: the stride of the convolving kernel. Can be a single number or a
tuple `(sD, sH, sW)`. Default: 1
padding: implicit paddings on both sides of the input. Can be a
single number or a tuple `(padD, padH, padW)`. Default: 0
dilation: the spacing between kernel elements. Can be a single number or
a tuple `(dD, dH, dW)`. Default: 1
groups: split input into groups, :math:`\text{in\_channels}` should be
divisible by the number of groups. Default: 1
padding_mode: the padding mode to use. Only "zeros" is supported for
quantized convolution at the moment. Default: "zeros"
scale: quantization scale for the output. Default: 1.0
zero_point: quantization zero_point for the output. Default: 0
dtype: quantization data type to use. Default: ``torch.quint8``
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> from torch.ao.nn.quantized import functional as qF
>>> filters = torch.randn(8, 4, 3, 3, 3, dtype=torch.float)
>>> inputs = torch.randn(1, 4, 5, 5, 5, dtype=torch.float)
>>> bias = torch.randn(8, dtype=torch.float)
>>>
>>> scale, zero_point = 1.0, 0
>>> dtype_inputs = torch.quint8
>>> dtype_filters = torch.qint8
>>>
>>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
>>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
>>> qF.conv3d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
""" # noqa: E501
if padding_mode != "zeros":
raise NotImplementedError("Only zero-padding is supported!")
if input.dtype != torch.quint8:
raise NotImplementedError(
"Only torch.quint8 is supported for activation tensor!"
)
if weight.dtype != torch.qint8:
raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
if input.ndim != 5:
raise ValueError("Input shape must be `(N, C, D, H, W)`!")
stride = _triple(stride)
padding = _triple(padding)
dilation = _triple(dilation)
packed_params = torch.ops.quantized.conv3d_prepack(
weight, bias, stride, padding, dilation, groups
)
return torch.ops.quantized.conv3d(input, packed_params, scale, zero_point)
def interpolate(
input, size=None, scale_factor=None, mode="nearest", align_corners=None
):
r"""Down/up samples the input to either the given :attr:`size` or the given
:attr:`scale_factor`
See :func:`torch.nn.functional.interpolate` for implementation details.
The input dimensions are interpreted in the form:
`mini-batch x channels x [optional depth] x [optional height] x width`.
.. note:: The input quantization parameters propagate to the output.
.. note:: Only 2D/3D input is supported for quantized inputs
.. note:: Only the following modes are supported for the quantized inputs:
- `bilinear`
- `nearest`
Args:
input (Tensor): the input tensor
size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
output spatial size.
scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple.
mode (str): algorithm used for upsampling:
``'nearest'`` | ``'bilinear'``
align_corners (bool, optional): Geometrically, we consider the pixels of the
input and output as squares rather than points.
If set to ``True``, the input and output tensors are aligned by the
center points of their corner pixels, preserving the values at the corner pixels.
If set to ``False``, the input and output tensors are aligned by the corner
points of their corner pixels, and the interpolation uses edge value padding
for out-of-boundary values, making this operation *independent* of input size
when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
is ``'bilinear'``.
Default: ``False``
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.interpolate' must be quantized!")
return torch.nn.functional.interpolate(
input, size, scale_factor, mode, align_corners
)
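# Illustrative sketch (not part of the original file): upscaling a quantized NCHW
# tensor with the quantized interpolate above. Only 'nearest' and 'bilinear' are
# supported for quantized inputs; the helper name is hypothetical.
def _example_interpolate():
    qx = torch.quantize_per_tensor(
        torch.randn(1, 3, 4, 4), scale=0.1, zero_point=0, dtype=torch.quint8
    )
    # Doubles the spatial size; the input scale/zero_point propagate to the output.
    return interpolate(qx, scale_factor=2.0, mode="bilinear", align_corners=False)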
def linear(
input: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
scale: Optional[float] = None,
zero_point: Optional[int] = None,
) -> Tensor:
r"""
Applies a linear transformation to the incoming quantized data:
:math:`y = xA^T + b`.
See :class:`~torch.ao.nn.quantized.Linear`
.. note::
The current implementation packs the weight on every call, which incurs a performance penalty.
If you want to avoid the overhead, use :class:`~torch.ao.nn.quantized.Linear`.
Args:
input (Tensor): Quantized input of type `torch.quint8`
weight (Tensor): Quantized weight of type `torch.qint8`
bias (Tensor): None or fp32 bias of type `torch.float`
scale (double): output scale. If None, derived from the input scale
zero_point (long): output zero point. If None, derived from the input zero_point
Shape:
- Input: :math:`(N, *, in\_features)` where `*` means any number of
additional dimensions
- Weight: :math:`(out\_features, in\_features)`
- Bias: :math:`(out\_features)`
- Output: :math:`(N, *, out\_features)`
"""
if scale is None:
scale = input.q_scale()
if zero_point is None:
zero_point = input.q_zero_point()
_packed_params = torch.ops.quantized.linear_prepack(weight, bias)
return torch.ops.quantized.linear(input, _packed_params, scale, zero_point)
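# Illustrative sketch (not part of the original file): the functional linear above
# re-packs the weight on every call, so it is mainly useful for quick experiments.
# Assumes a quantized engine is available; the helper name is hypothetical.
def _example_linear():
    qx = torch.quantize_per_tensor(
        torch.randn(4, 16), scale=0.1, zero_point=0, dtype=torch.quint8
    )
    qw = torch.quantize_per_tensor(
        torch.randn(8, 16), scale=0.05, zero_point=0, dtype=torch.qint8
    )
    bias = torch.zeros(8)
    # Output scale/zero_point are given explicitly instead of being derived from
    # the input quantization parameters.
    return linear(qx, qw, bias, scale=0.2, zero_point=0)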
def max_pool1d(
input,
kernel_size,
stride=None,
padding=0,
dilation=1,
ceil_mode=False,
return_indices=False,
):
r"""Applies a 1D max pooling over a quantized input signal composed of
several quantized input planes.
.. note:: The input quantization parameters are propagated to the output.
See :class:`~torch.ao.nn.quantized.MaxPool1d` for details.
"""
if return_indices:
raise NotImplementedError("return_indices is not yet implemented!")
if stride is None:
stride = torch.jit.annotate(List[int], [])
return torch.nn.functional.max_pool1d(
input,
kernel_size,
stride,
padding,
dilation,
ceil_mode=ceil_mode,
return_indices=return_indices,
)
def max_pool2d(
input,
kernel_size,
stride=None,
padding=0,
dilation=1,
ceil_mode=False,
return_indices=False,
):
r"""Applies a 2D max pooling over a quantized input signal composed of
several quantized input planes.
.. note:: The input quantization parameters are propagated to the output.
See :class:`~torch.ao.nn.quantized.MaxPool2d` for details.
"""
if return_indices:
raise NotImplementedError("return_indices is not yet implemented!")
if stride is None:
stride = torch.jit.annotate(List[int], [])
return torch.nn.functional.max_pool2d(
input,
kernel_size,
stride,
padding,
dilation,
ceil_mode=ceil_mode,
return_indices=return_indices,
)
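# Illustrative sketch (not part of the original file): 2D max pooling on a quantized
# tensor through the wrapper above; the input quantization parameters carry over to
# the output. The helper name is hypothetical.
def _example_max_pool2d():
    qx = torch.quantize_per_tensor(
        torch.randn(1, 3, 8, 8), scale=0.1, zero_point=0, dtype=torch.quint8
    )
    return max_pool2d(qx, kernel_size=2, stride=2)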
def celu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.0) -> Tensor:
r"""celu(input, scale, zero_point, alpha=1.) -> Tensor
Applies the quantized CELU function element-wise.
.. math::
\text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x / \alpha) - 1))
Args:
input: quantized input
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.celu' must be quantized!")
return torch.ops.quantized.celu(input, scale, zero_point, alpha)
def leaky_relu(
input: Tensor,
negative_slope: float = 0.01,
inplace: bool = False,
scale: Optional[float] = None,
zero_point: Optional[int] = None,
):
r"""
Quantized version of
leaky_relu(input, negative_slope=0.01, inplace=False, scale=None, zero_point=None) -> Tensor
Applies element-wise,
:math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)`
Args:
input: Quantized input
negative_slope: The slope of the negative input
inplace: Inplace modification of the input tensor
scale, zero_point: Scale and zero point of the output tensor.
See :class:`~torch.nn.LeakyReLU` for more details.
"""
if scale is not None and zero_point is not None:
assert not inplace, "Cannot rescale with `inplace`"
output = torch._empty_affine_quantized(
input.shape, scale=scale, zero_point=int(zero_point), dtype=input.dtype
)
torch._C._nn.leaky_relu(input, negative_slope, out=output)
return output
if inplace:
result = torch._C._nn.leaky_relu_(input, negative_slope)
else:
result = torch._C._nn.leaky_relu(input, negative_slope)
return result
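# Illustrative sketch (not part of the original file): quantized leaky_relu either
# keeps the input quantization parameters (scale/zero_point omitted) or requantizes
# into a freshly allocated output when both are given. The helper name and the
# example quantization parameters are hypothetical.
def _example_leaky_relu():
    qx = torch.quantize_per_tensor(
        torch.randn(4), scale=0.1, zero_point=128, dtype=torch.quint8
    )
    same_params = leaky_relu(qx, negative_slope=0.1)
    requantized = leaky_relu(qx, negative_slope=0.1, scale=0.05, zero_point=64)
    return same_params, requantized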
def hardtanh(
input: Tensor, min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False
) -> Tensor:
r"""This is the quantized version of :func:`~torch.nn.functional.hardtanh`."""
if not input.is_quantized:
raise ValueError("Input to 'quantized.hardtanh' must be quantized!")
if inplace:
return torch._C._nn.hardtanh_(input, min_val, max_val)
return torch._C._nn.hardtanh(input, min_val, max_val)
def hardswish(input: Tensor, scale: float, zero_point: int) -> Tensor:
r"""This is the quantized version of :func:`~torch.nn.functional.hardswish`.
Args:
input: quantized input
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.hardswish' must be quantized!")
return torch._ops.ops.quantized.hardswish(input, scale, zero_point)
def threshold(input: Tensor, threshold: float, value: float) -> Tensor:
r"""Applies the quantized version of the threshold function element-wise:
.. math::
x = \begin{cases}
x & \text{if~} x > \text{threshold} \\
\text{value} & \text{otherwise}
\end{cases}
See :class:`~torch.nn.Threshold` for more details.
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.threshold' must be quantized!")
if threshold is None:
raise ValueError("Input to 'threshold' must be specified!")
if value is None:
raise ValueError("Input to 'value' must be specified!")
return torch._ops.ops.quantized.threshold(input, threshold, value)
def elu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.0) -> Tensor:
r"""This is the quantized version of :func:`~torch.nn.functional.elu`.
Args:
input: quantized input
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
alpha: the alpha constant
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.elu' must be quantized!")
return torch.ops.quantized.elu(input, scale, zero_point, alpha)
def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor:
r"""This is the quantized version of :func:`~torch.nn.functional.hardsigmoid`."""
if not input.is_quantized:
raise ValueError("Input to 'quantized.hardsigmoid' must be quantized!")
if inplace:
return torch._C._nn.hardsigmoid_(input) # type: ignore[attr-defined]
return torch._C._nn.hardsigmoid(input)
def clamp(input: Tensor, min_: float, max_: float) -> Tensor:
r"""float(input, min\_, max\_) -> Tensor
Applies the clamp function element-wise.
See :class:`~torch.ao.nn.quantized.clamp` for more details.
Args:
input: quantized input
min_: minimum value for clamping
max_: maximum value for clamping
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.clamp' must be quantized!")
return torch.clamp(input, min_, max_)
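# Illustrative sketch (not part of the original file): clamping a quantized tensor;
# the result keeps the input's quantization parameters. The helper name is
# hypothetical.
def _example_clamp():
    qx = torch.quantize_per_tensor(
        torch.randn(4), scale=0.1, zero_point=128, dtype=torch.quint8
    )
    return clamp(qx, -0.5, 0.5)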
def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
r"""Upsamples the input to either the given :attr:`size` or the given
:attr:`scale_factor`
.. warning::
This function is deprecated in favor of
:func:`torch.ao.nn.quantized.functional.interpolate`.
This is equivalent to ``nn.quantized.functional.interpolate(...)``.
See :func:`torch.nn.functional.interpolate` for implementation details.
The input dimensions are interpreted in the form:
`mini-batch x channels x [optional depth] x [optional height] x width`.
.. note:: The input quantization parameters propagate to the output.
.. note:: Only 2D input is supported for quantized inputs
.. note:: Only the following modes are supported for the quantized inputs:
- `bilinear`
- `nearest`
Args:
input (Tensor): quantized input tensor
size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
output spatial size.
scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple.
mode (str): algorithm used for upsampling:
``'nearest'`` | ``'bilinear'``
align_corners (bool, optional): Geometrically, we consider the pixels of the
input and output as squares rather than points.
If set to ``True``, the input and output tensors are aligned by the
center points of their corner pixels, preserving the values at the corner pixels.
If set to ``False``, the input and output tensors are aligned by the corner
points of their corner pixels, and the interpolation uses edge value padding
for out-of-boundary values, making this operation *independent* of input size
when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
is ``'bilinear'``.
Default: ``False``
.. warning::
With ``align_corners = True``, the linearly interpolating modes
(`bilinear`) don't proportionally align the
output and input pixels, and thus the output values can depend on the
input size. This was the default behavior for these modes up to version
0.3.1. Since then, the default behavior is ``align_corners = False``.
See :class:`~torch.nn.Upsample` for concrete examples on how this
affects the outputs.
"""
warnings.warn(
"nn.quantized.functional.upsample is deprecated. Use nn.quantized.functional.interpolate instead."
)
return interpolate(input, size, scale_factor, mode, align_corners)
def upsample_bilinear(input, size=None, scale_factor=None):
r"""Upsamples the input, using bilinear upsampling.
.. warning::
This function is deprecated in favor of
:func:`torch.ao.nn.quantized.functional.interpolate`.
This is equivalent to
``nn.quantized.functional.interpolate(..., mode='bilinear', align_corners=True)``.
.. note:: The input quantization parameters propagate to the output.
.. note:: Only 2D inputs are supported
Args:
input (Tensor): quantized input
size (int or Tuple[int, int]): output spatial size.
scale_factor (int or Tuple[int, int]): multiplier for spatial size
"""
# DeprecationWarning is ignored by default
warnings.warn(
"nn.quantized.functional.upsample_bilinear is deprecated. Use nn.quantized.functional.interpolate instead."
)
return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True)
def upsample_nearest(input, size=None, scale_factor=None):
r"""Upsamples the input, using nearest neighbours' pixel values.
.. warning::
This function is deprecated in favor of
:func:`torch.ao.nn.quantized.functional.interpolate`.
This is equivalent to ``nn.quantized.functional.interpolate(..., mode='nearest')``.
.. note:: The input quantization parameters propagate to the output.
.. note:: Only 2D inputs are supported
Args:
input (Tensor): quantized input
size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial
size.
scale_factor (int): multiplier for spatial size. Has to be an integer.
"""
# DeprecationWarning is ignored by default
warnings.warn(
"nn.quantized.functional.upsample_nearest is deprecated. Use nn.quantized.functional.interpolate instead."
)
return interpolate(input, size, scale_factor, mode="nearest")

View File

@ -0,0 +1,162 @@
# mypy: allow-untyped-defs
import torch
# The quantized modules use `torch.nn` and `torch.ao.nn.quantizable`
# packages. However, the `quantizable` package uses "lazy imports"
# to avoid circular dependency.
# Hence we need to include it here to make sure it is resolved before
# they are used in the modules.
import torch.ao.nn.quantizable
from torch.nn.modules.pooling import MaxPool2d
from .activation import (
ELU,
Hardswish,
LeakyReLU,
MultiheadAttention,
PReLU,
ReLU6,
Sigmoid,
Softmax,
)
from .batchnorm import BatchNorm2d, BatchNorm3d
from .conv import (
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
)
from .dropout import Dropout
from .embedding_ops import Embedding, EmbeddingBag
from .functional_modules import FloatFunctional, FXFloatFunctional, QFunctional
from .linear import Linear
from .normalization import (
GroupNorm,
InstanceNorm1d,
InstanceNorm2d,
InstanceNorm3d,
LayerNorm,
)
from .rnn import LSTM
__all__ = [
"BatchNorm2d",
"BatchNorm3d",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
"DeQuantize",
"ELU",
"Embedding",
"EmbeddingBag",
"GroupNorm",
"Hardswish",
"InstanceNorm1d",
"InstanceNorm2d",
"InstanceNorm3d",
"LayerNorm",
"LeakyReLU",
"Linear",
"LSTM",
"MultiheadAttention",
"Quantize",
"ReLU6",
"Sigmoid",
"Softmax",
"Dropout",
"PReLU",
# Wrapper modules
"FloatFunctional",
"FXFloatFunctional",
"QFunctional",
]
class Quantize(torch.nn.Module):
r"""Quantizes an incoming tensor
Args:
`scale`: scale of the output Quantized Tensor
`zero_point`: zero_point of output Quantized Tensor
`dtype`: data type of output Quantized Tensor
`factory_kwargs`: Dictionary of kwargs used for configuring initialization
of internal buffers. Currently, `device` and `dtype` are supported.
Example: `factory_kwargs={'device': 'cuda', 'dtype': torch.float64}`
will initialize internal buffers as type `torch.float64` on the current CUDA device.
Note that `dtype` only applies to floating-point buffers.
Examples::
>>> t = torch.tensor([[1., -1.], [1., -1.]])
>>> scale, zero_point, dtype = 1.0, 2, torch.qint8
>>> qm = Quantize(scale, zero_point, dtype)
>>> # xdoctest: +SKIP
>>> qt = qm(t)
>>> print(qt)
tensor([[ 1., -1.],
[ 1., -1.]], size=(2, 2), dtype=torch.qint8, scale=1.0, zero_point=2)
"""
scale: torch.Tensor
zero_point: torch.Tensor
def __init__(self, scale, zero_point, dtype, factory_kwargs=None):
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
super().__init__()
self.register_buffer("scale", torch.tensor([scale], **factory_kwargs))
self.register_buffer(
"zero_point",
torch.tensor(
[zero_point],
dtype=torch.long,
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
),
)
self.dtype = dtype
def forward(self, X):
return torch.quantize_per_tensor(
X, float(self.scale), int(self.zero_point), self.dtype
)
@staticmethod
def from_float(mod, use_precomputed_fake_quant=False):
assert hasattr(mod, "activation_post_process")
scale, zero_point = mod.activation_post_process.calculate_qparams()
return Quantize(
scale.float().item(),
zero_point.long().item(),
mod.activation_post_process.dtype,
)
def extra_repr(self):
return f"scale={self.scale}, zero_point={self.zero_point}, dtype={self.dtype}"
class DeQuantize(torch.nn.Module):
r"""Dequantizes an incoming tensor
Examples::
>>> input = torch.tensor([[1., -1.], [1., -1.]])
>>> scale, zero_point, dtype = 1.0, 2, torch.qint8
>>> qm = Quantize(scale, zero_point, dtype)
>>> # xdoctest: +SKIP
>>> quantized_input = qm(input)
>>> dqm = DeQuantize()
>>> dequantized = dqm(quantized_input)
>>> print(dequantized)
tensor([[ 1., -1.],
[ 1., -1.]], dtype=torch.float32)
"""
def forward(self, Xq):
return Xq.dequantize()
@staticmethod
def from_float(mod, use_precomputed_fake_quant=False):
return DeQuantize()

View File

@ -0,0 +1,343 @@
# mypy: allow-untyped-defs
from warnings import warn
import torch
__all__ = [
"ReLU6",
"Hardswish",
"ELU",
"LeakyReLU",
"Sigmoid",
"Softmax",
"MultiheadAttention",
"PReLU",
]
class ReLU6(torch.nn.ReLU):
r"""Applies the element-wise function:
:math:`\text{ReLU6}(x) = \min(\max(x_0, x), q(6))`, where :math:`x_0` is the
zero_point, and :math:`q(6)` is the quantized representation of number 6.
Args:
inplace: can optionally do the operation in-place. Default: ``False``
Shape:
- Input: :math:`(N, *)` where `*` means any number of additional
dimensions
- Output: :math:`(N, *)`, same shape as the input
.. image:: ../scripts/activation_images/ReLU6.png
Examples::
>>> m = nn.quantized.ReLU6()
>>> input = torch.randn(2)
>>> # xdoctest: +SKIP
>>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32)
>>> output = m(input)
"""
def __init__(self, inplace=False):
super().__init__(inplace)
self.inplace = inplace
def forward(self, input):
return torch.ops.quantized.relu6(input, self.inplace)
def _get_name(self):
return "QuantizedReLU6"
@staticmethod
def from_float(mod, use_precomputed_fake_quant=False):
return ReLU6(mod.inplace)
class Hardswish(torch.nn.Hardswish):
r"""This is the quantized version of :class:`~torch.nn.Hardswish`.
Args:
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
"""
def __init__(self, scale, zero_point, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
return torch.ops.quantized.hardswish(input, self.scale, self.zero_point)
def _get_name(self):
return "QuantizedHardswish"
@staticmethod
def from_float(mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return Hardswish(float(scale), int(zero_point))
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(float(scale), int(zero_point))
class ELU(torch.nn.ELU):
r"""This is the quantized equivalent of :class:`~torch.nn.ELU`.
Args:
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
alpha: the alpha constant
"""
def __init__(self, scale, zero_point, alpha=1.0):
super().__init__(alpha)
self.scale = scale
self.zero_point = zero_point
def forward(self, input):
return torch.ao.nn.quantized.functional.elu(
input, self.scale, self.zero_point, self.alpha
)
def _get_name(self):
return "QuantizedELU"
@staticmethod
def from_float(mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return ELU(float(scale), int(zero_point), mod.alpha)
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(float(scale), int(zero_point), mod.alpha)
class LeakyReLU(torch.nn.LeakyReLU):
r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`.
Args:
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
negative_slope: Controls the angle of the negative slope. Default: 1e-2
"""
def __init__(
self,
scale: float,
zero_point: int,
negative_slope: float = 1e-2,
inplace: bool = False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(negative_slope, inplace)
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
return torch.ops.quantized.leaky_relu(
input, self.negative_slope, self.inplace, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedLeakyReLU"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
class Sigmoid(torch.nn.Sigmoid):
r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`.
Args:
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
"""
def __init__(self, output_scale: float, output_zero_point: int):
super().__init__()
self.output_scale = output_scale
self.output_zero_point = output_zero_point
def forward(self, input):
return torch.ops.quantized.sigmoid(
input, self.output_scale, self.output_zero_point
)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
(
output_scale,
output_zero_point,
) = mod.activation_post_process.calculate_qparams()
return cls(float(output_scale), int(output_zero_point))
class Softmax(torch.nn.Softmax):
r"""This is the quantized version of :class:`~torch.nn.Softmax`.
Args:
dim: A dimension along which Softmax will be computed (so every slice along dim will sum to 1).
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
"""
def __init__(self, dim=None, scale=1.0, zero_point=0):
super().__init__()
self.dim = dim
self.scale = scale
self.zero_point = zero_point
def forward(self, input):
dim = self.dim
if dim is None:
stacklevel = 3
# Note: adding the mypy ignore on _get_softmax_dim seems less bad
# than making `_get_softmax_dim` an official API.
dim = torch.nn.functional._get_softmax_dim( # type: ignore[attr-defined]
"softmax", input.dim(), stacklevel
)
return torch.ops.quantized.softmax(input, dim, self.scale, self.zero_point)
def _get_name(self):
return "QuantizedSoftmax"
@staticmethod
def from_float(mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return Softmax(mod.dim, float(scale), int(zero_point))
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(mod.dim, float(scale), int(zero_point))
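# Illustrative sketch (not part of the original file): running the quantized Softmax
# module directly on a quantized tensor. The output scale 1/256 with zero_point 0 is
# a natural choice for softmax outputs in [0, 1]; the helper name and the example
# values are hypothetical.
def _example_quantized_softmax():
    qx = torch.quantize_per_tensor(
        torch.randn(2, 5), scale=0.1, zero_point=0, dtype=torch.quint8
    )
    m = Softmax(dim=1, scale=1.0 / 256.0, zero_point=0)
    return m(qx)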
class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
_FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention
def _get_name(self):
return "QuantizedMultiheadAttention"
@classmethod
def from_float(cls, other):
# The whole flow is float -> observed -> quantized
# This class does observed -> quantized only
raise NotImplementedError(
"It looks like you are trying to convert a "
"non-observed MHA module. Please, see "
"the examples on quantizable MHAs."
)
@classmethod
def from_observed(cls, other):
converted = torch.ao.quantization.convert(
other,
mapping=None,
inplace=False,
remove_qconfig=True,
convert_custom_config_dict=None,
)
converted.__class__ = cls
# Remove the parameters for the bias_k and bias_v to quantize them
# TODO: This is a potential source of accuracy drop.
# quantized cat takes the scale and zp of the first
# element, which might lose the precision in the bias_k
# and the bias_v (which are cat'ed with k/v being first).
if converted.bias_k is not None:
bias_k = converted._parameters.pop("bias_k")
sc, zp = torch._choose_qparams_per_tensor(bias_k, reduce_range=False)
bias_k = torch.quantize_per_tensor(bias_k, sc, zp, torch.quint8)
setattr(converted, "bias_k", bias_k) # noqa: B010
if converted.bias_v is not None:
bias_v = converted._parameters.pop("bias_v")
sc, zp = torch._choose_qparams_per_tensor(
bias_k, reduce_range=False # type: ignore[possibly-undefined]
)
bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8)
setattr(converted, "bias_v", bias_v) # noqa: B010
del converted.in_proj_weight
del converted.in_proj_bias
return converted
class PReLU(torch.nn.Module):
r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`.
Args:
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
num_parameters: number of parameters: 1, or the number of channels at input. Default: 1
"""
def __init__(
self, output_scale: float, output_zero_point: int, num_parameters: int = 1
) -> None:
super().__init__()
self.num_parameters = num_parameters
self.scale = output_scale
self.zero_point = output_zero_point
w = torch.randn(num_parameters, dtype=torch.float)
qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8)
self.set_weight(qw)
def set_weight(self, w: torch.Tensor) -> None:
self.weight = w
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.ops.quantized.prelu(
input, self.weight, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedPReLU"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
float_wt = mod.weight.float()
observer = mod.qconfig.weight()
observer(float_wt)
if observer.dtype != torch.quint8:
warn(
f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
)
wt_scale, wt_zp = observer.calculate_qparams()
qweight = torch.quantize_per_tensor(
float_wt, float(wt_scale), int(wt_zp), torch.quint8
)
qprelu.set_weight(qweight)
return qprelu
@classmethod
def from_reference(cls, mod, scale, zero_point):
qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
float_wt = mod.weight.float()
observer = mod.qconfig.weight()
observer(float_wt)
if observer.dtype != torch.quint8:
warn(
f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
)
wt_scale, wt_zp = observer.calculate_qparams()
qweight = torch.quantize_per_tensor(
float_wt, float(wt_scale), int(wt_zp), torch.quint8
)
qprelu.set_weight(qweight)
return qprelu

View File

@ -0,0 +1,128 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
__all__ = ["BatchNorm2d", "BatchNorm3d"]
class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
def __init__(
self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(num_features, eps, momentum, True, True, **factory_kwargs)
self.register_buffer("scale", torch.tensor(1.0, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(0, **factory_kwargs))
@staticmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
activation_post_process = mod.activation_post_process
if type(mod) == cls._NNI_BN_RELU_MODULE:
mod = mod[0]
scale, zero_point = activation_post_process.calculate_qparams()
new_mod = cls(mod.num_features, mod.eps)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
new_mod.running_mean = mod.running_mean
new_mod.running_var = mod.running_var
new_mod.scale = scale
new_mod.zero_point = zero_point
return new_mod
@classmethod
def from_reference(cls, bn, output_scale, output_zero_point):
qbn = cls(
bn.num_features,
bn.eps,
bn.momentum,
device=bn.weight.device,
dtype=bn.weight.dtype,
)
qbn.weight = bn.weight
qbn.bias = bn.bias
qbn.running_mean = bn.running_mean
qbn.running_var = bn.running_var
qbn.scale = output_scale
qbn.zero_point = output_zero_point
return qbn
class BatchNorm2d(_BatchNorm):
r"""This is the quantized version of :class:`~torch.nn.BatchNorm2d`."""
_NNI_BN_RELU_MODULE = nni.BNReLU2d
def __init__(
self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(num_features, eps, momentum, **factory_kwargs)
def _get_name(self):
return "QuantizedBatchNorm2d"
def _check_input_dim(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
def forward(self, input: torch.Tensor) -> torch.Tensor:
# disabling this since this is not symbolically traceable
# self._check_input_dim(input)
return torch.ops.quantized.batch_norm2d(
input,
self.weight,
self.bias,
self.running_mean,
self.running_var,
self.eps,
self.scale,
self.zero_point,
)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return _BatchNorm.from_float(
cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
class BatchNorm3d(_BatchNorm):
r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`."""
_NNI_BN_RELU_MODULE = nni.BNReLU3d
def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(num_features, eps, momentum, **factory_kwargs)
def _get_name(self):
return "QuantizedBatchNorm3d"
def _check_input_dim(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 5:
raise ValueError("Input shape must be `(N, C, D, H, W)`!")
def forward(self, input: torch.Tensor) -> torch.Tensor:
# disabling this since this is not symbolically traceable
# self._check_input_dim(input)
return torch.ops.quantized.batch_norm3d(
input,
self.weight,
self.bias,
self.running_mean,
self.running_var,
self.eps,
self.scale,
self.zero_point,
)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return _BatchNorm.from_float(
cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
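# Illustrative sketch (not part of the original file): applying the quantized
# BatchNorm2d above to a quantized NCHW tensor. The affine parameters and running
# statistics stay at their defaults; the output scale/zero_point buffers are set to
# example values, and the helper name is hypothetical.
def _example_quantized_batch_norm2d():
    qx = torch.quantize_per_tensor(
        torch.randn(1, 3, 4, 4), scale=0.1, zero_point=0, dtype=torch.quint8
    )
    bn = BatchNorm2d(3)
    bn.scale = torch.tensor(0.1)
    bn.zero_point = torch.tensor(0)
    return bn(qx)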

File diff suppressed because it is too large

View File

@ -0,0 +1,30 @@
# mypy: allow-untyped-defs
import torch
__all__ = ["Dropout"]
class Dropout(torch.nn.Dropout):
r"""This is the quantized equivalent of :class:`~torch.nn.Dropout`.
It is a placeholder that lets models which applied dropout to fp32 tensors
also work with quantized tensors in both train and eval mode.
Args:
p: probability of an element to be zeroed
inplace: can optionally do the operation in-place. Default: ``False``
"""
def forward(self, input):
return input
def _get_name(self):
return "QuantizedDropout"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return cls(mod.p, mod.inplace)
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(mod.p, mod.inplace)

View File

@ -0,0 +1,405 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import torch
import torch.nn as nn
from torch import Tensor # noqa: F401
from torch._jit_internal import List, Optional # noqa: F401
from .utils import _hide_packed_params_repr, _quantize_weight
__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"]
class EmbeddingPackedParams(torch.nn.Module):
_version = 1
def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
super().__init__()
self.dtype = dtype
if self.dtype in [torch.quint8, torch.quint4x2]:
scales = torch.ones(num_embeddings, dtype=torch.float)
zero_points = torch.zeros(num_embeddings, dtype=torch.float)
wq = torch._empty_per_channel_affine_quantized(
[num_embeddings, embedding_dim],
scales=scales,
zero_points=zero_points,
axis=0,
dtype=self.dtype,
)
self.set_weight(wq)
else:
raise NotImplementedError(
f"Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}"
)
@torch.jit.export
def set_weight(self, weight: torch.Tensor) -> None:
if self.dtype in [torch.quint8, torch.quint4x2]:
self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
else:
raise NotImplementedError(
"Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2."
)
@torch.jit.export
def _weight(self):
if self.dtype in [torch.quint8, torch.quint4x2]:
return torch.ops.quantized.embedding_bag_unpack(self._packed_weight)
else:
raise NotImplementedError(
"Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2."
)
def forward(self, x):
return x
# Version 1
# self
# |--- _packed_weight : Tensor representing weight of EmbeddingPackedParamsBase
# |--- dtype : torch.dtype
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + "dtype"] = self.dtype
destination[prefix + "_packed_weight"] = self._weight()
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
self.dtype = state_dict[prefix + "dtype"]
state_dict.pop(prefix + "dtype")
weight = state_dict[prefix + "_packed_weight"]
state_dict.pop(prefix + "_packed_weight")
self.set_weight(weight)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
False,
missing_keys,
unexpected_keys,
error_msgs,
)
def __repr__(self):
return self._weight().__repr__()
class Embedding(torch.nn.Module):
r"""
A quantized Embedding module with quantized packed weights as inputs.
We adopt the same interface as `torch.nn.Embedding`; please see
https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html for documentation.
Similar to :class:`~torch.nn.Embedding`, attributes will be randomly
initialized at module creation time and will be overwritten later
Attributes:
weight (Tensor): the non-learnable quantized weights of the module of
shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.
Examples::
>>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12)
>>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
>>> output = m(indices)
>>> print(output.size())
torch.Size([9, 12])
"""
_version = 1
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
_weight: Optional[Tensor] = None,
dtype=torch.quint8,
) -> None:
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.dtype = dtype
if _weight is None:
scales = torch.ones(num_embeddings, dtype=torch.float)
zero_points = torch.zeros(num_embeddings, dtype=torch.float)
qweight = torch._empty_per_channel_affine_quantized(
[num_embeddings, embedding_dim],
scales=scales,
zero_points=zero_points,
axis=0,
dtype=torch.quint8,
)
else:
assert list(_weight.shape) == [
num_embeddings,
embedding_dim,
], "Shape of weight does not match num_embeddings and embedding_dim"
qweight = _weight
self._packed_params = EmbeddingPackedParams(
num_embeddings, embedding_dim, dtype
)
self._packed_params.set_weight(qweight)
def forward(self, indices: Tensor) -> Tensor:
if self.dtype == torch.quint4x2:
return torch.ops.quantized.embedding_4bit(
self._packed_params._packed_weight, indices
)
else:
return torch.ops.quantized.embedding_byte(
self._packed_params._packed_weight, indices
)
def _get_name(self):
return "QuantizedEmbedding"
def __repr__(self):
return _hide_packed_params_repr(self, EmbeddingPackedParams)
def extra_repr(self):
extra_repr_str = (
f"num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, "
f"dtype={self._packed_params.dtype}, qscheme={self.weight().qscheme()}"
)
return extra_repr_str
def set_weight(self, w: torch.Tensor) -> None:
self._packed_params.set_weight(w)
def weight(self):
return self._packed_params._weight()
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized embedding module from a float module
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by user
"""
if hasattr(mod, "weight_fake_quant"):
assert type(mod) == torch.ao.nn.qat.Embedding, (
"nnq."
+ cls.__name__
+ ".from_float "
+ "with fake quant only works for "
+ torch.ao.nn.qat.Embedding.__name__
)
weight_observer = mod.weight_fake_quant
activation_post_process = mod.activation_post_process
else:
assert type(mod) == nn.Embedding, (
"nnq."
+ cls.__name__
+ ".from_float only works for "
+ nn.Embedding.__name__
)
assert hasattr(
mod, "qconfig"
), "Embedding input float module must have qconfig defined"
from torch.ao.quantization import float_qparams_weight_only_qconfig
if mod.qconfig is not None and mod.qconfig.weight is not None: # type: ignore[union-attr]
weight_observer = mod.qconfig.weight() # type: ignore[union-attr, operator]
else:
weight_observer = float_qparams_weight_only_qconfig.weight()
dtype = weight_observer.dtype
is_float_qparams_qconfig = (
weight_observer.qscheme == torch.per_channel_affine_float_qparams
)
assert (
is_float_qparams_qconfig
), "Embedding quantization is only supported with float_qparams_weight_only_qconfig."
assert (
dtype == torch.quint8 or dtype == torch.quint4x2
), f"The only supported dtype for nnq.Embedding is torch.quint8 and torch.quint4x2, got {dtype}"
# Run the observer to calculate qparams.
weight_observer(mod.weight)
qweight = _quantize_weight(mod.weight.float(), weight_observer)
# Create quantized Embedding module and pass in the quantized weight
qembedding = Embedding(mod.num_embeddings, mod.embedding_dim)
qembedding.set_weight(qweight)
return qembedding
@classmethod
def from_reference(cls, ref_embedding):
qembedding = cls(
ref_embedding.num_embeddings,
ref_embedding.embedding_dim,
ref_embedding.padding_idx,
ref_embedding.max_norm,
ref_embedding.norm_type,
ref_embedding.scale_grad_by_freq,
ref_embedding.sparse,
ref_embedding.get_quantized_weight(),
ref_embedding.weight_dtype,
)
return qembedding
class EmbeddingBag(Embedding):
r"""
A quantized EmbeddingBag module with quantized packed weights as inputs.
We adopt the same interface as `torch.nn.EmbeddingBag`; please see
https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html for documentation.
Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly
initialized at module creation time and will be overwritten later
Attributes:
weight (Tensor): the non-learnable quantized weights of the module of
shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.
Examples::
>>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum')
>>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
>>> offsets = torch.tensor([0, 19, 20, 28, 28, 32])
>>> output = m(indices, offsets)
>>> print(output.size())
torch.Size([5, 12])
"""
_version = 1
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
mode: str = "sum",
sparse: bool = False,
_weight: Optional[Tensor] = None,
include_last_offset: bool = False,
dtype=torch.quint8,
) -> None:
super().__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype)
self.mode = mode
self.pruned_weights = False
self.include_last_offset = include_last_offset
self.dtype = dtype
def forward(
self,
indices: Tensor,
offsets: Optional[Tensor] = None,
per_sample_weights: Optional[Tensor] = None,
compressed_indices_mapping: Optional[Tensor] = None,
) -> Tensor:
if self.dtype == torch.quint4x2:
return torch.ops.quantized.embedding_bag_4bit(
self._packed_params._packed_weight,
indices,
offsets,
False,
0,
self.pruned_weights,
per_sample_weights,
compressed_indices_mapping,
self.include_last_offset,
)
else:
return torch.ops.quantized.embedding_bag_byte(
self._packed_params._packed_weight,
indices,
offsets,
False,
0,
self.pruned_weights,
per_sample_weights,
compressed_indices_mapping,
self.include_last_offset,
)
def _get_name(self):
return "QuantizedEmbeddingBag"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized embedding_bag module from a float module
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by user
"""
if hasattr(mod, "weight_fake_quant"):
weight_observer = mod.weight_fake_quant
else:
assert type(mod) == nn.EmbeddingBag, (
"nnq."
+ cls.__name__
+ ".from_float only works for "
+ nn.EmbeddingBag.__name__
)
assert hasattr(
mod, "qconfig"
), "EmbeddingBag input float module must have qconfig defined"
from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig
if mod.qconfig is not None and mod.qconfig.weight is not None: # type: ignore[union-attr]
weight_observer = mod.qconfig.weight() # type: ignore[union-attr, operator]
else:
weight_observer = float_qparams_weight_only_qconfig.weight()
dtype = weight_observer.dtype
is_float_qparams_qconfig = (
weight_observer.qscheme == torch.per_channel_affine_float_qparams
)
assert (
is_float_qparams_qconfig
), "EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig."
assert (
dtype == torch.quint8 or dtype == torch.quint4x2
), f"The only supported dtype for nnq.EmbeddingBag is torch.quint8 and torch.quint4x2, got {dtype}"
# Run the observer to calculate qparams.
weight_observer(mod.weight)
qweight = _quantize_weight(mod.weight.float(), weight_observer)
# Create quantized EmbeddingBag module and pass in the quantized weight
qembedding_bag = EmbeddingBag(
mod.num_embeddings, mod.embedding_dim, dtype=dtype
)
qembedding_bag.set_weight(qweight)
return qembedding_bag
@classmethod
def from_reference(cls, ref_embedding_bag):
qembedding_bag = cls(
ref_embedding_bag.num_embeddings,
ref_embedding_bag.embedding_dim,
ref_embedding_bag.max_norm,
ref_embedding_bag.norm_type,
ref_embedding_bag.scale_grad_by_freq,
ref_embedding_bag.mode,
ref_embedding_bag.sparse,
ref_embedding_bag.get_quantized_weight(),
ref_embedding_bag.include_last_offset,
ref_embedding_bag.weight_dtype,
)
return qembedding_bag
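# Illustrative sketch (not part of the original file): converting a float
# nn.Embedding into the quantized Embedding above via `from_float`. Embedding
# quantization requires float_qparams_weight_only_qconfig; the helper name is
# hypothetical.
def _example_embedding_from_float():
    from torch.ao.quantization import float_qparams_weight_only_qconfig

    float_embedding = nn.Embedding(num_embeddings=10, embedding_dim=12)
    float_embedding.qconfig = float_qparams_weight_only_qconfig
    qembedding = Embedding.from_float(float_embedding)
    indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
    return qembedding(indices)  # shape: (9, 12)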

View File

@ -0,0 +1,299 @@
# mypy: allow-untyped-defs
from typing import List
import torch
from torch import Tensor
from torch._ops import ops
__all__ = ["FloatFunctional", "FXFloatFunctional", "QFunctional"]
class FloatFunctional(torch.nn.Module):
r"""State collector class for float operations.
The instance of this class can be used instead of the ``torch.`` prefix for
some operations. See example usage below.
.. note::
This class does not provide a ``forward`` hook. Instead, you must use
one of the underlying functions (e.g. ``add``).
Examples::
>>> f_add = FloatFunctional()
>>> a = torch.tensor(3.0)
>>> b = torch.tensor(4.0)
>>> f_add.add(a, b) # Equivalent to ``torch.add(a, b)``
Valid operation names:
- add
- cat
- mul
- add_relu
- add_scalar
- mul_scalar
"""
def __init__(self) -> None:
super().__init__()
self.activation_post_process = torch.nn.Identity()
def forward(self, x):
raise RuntimeError(
"FloatFunctional is not intended to use the "
+ "'forward'. Please use the underlying operation"
)
r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
def add(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.add(x, y)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``torch.add(Tensor, float)``"""
def add_scalar(self, x: Tensor, y: float) -> Tensor:
r = torch.add(x, y)
# Note: this operation is not observed because the observation is not
# needed for the quantized op.
return r
r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
def mul(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.mul(x, y)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
def mul_scalar(self, x: Tensor, y: float) -> Tensor:
r = torch.mul(x, y)
# Note: this operation is not observed because the observation is not
# needed for the quantized op.
return r
r"""Operation equivalent to ``torch.cat``"""
def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
r = torch.cat(x, dim=dim)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``relu(torch.add(x,y))``"""
def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.add(x, y)
r = torch.nn.functional.relu(r)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
def matmul(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.matmul(x, y)
r = self.activation_post_process(r)
return r
class FXFloatFunctional(torch.nn.Module):
r"""module to replace FloatFunctional module before FX graph mode quantization,
since activation_post_process will be inserted in top level module directly
Valid operation names:
- add
- cat
- mul
- add_relu
- add_scalar
- mul_scalar
"""
def forward(self, x):
raise RuntimeError(
"FloatFunctional is not intended to use the "
+ "'forward'. Please use the underlying operation"
)
r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
def add(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.add(x, y)
return r
r"""Operation equivalent to ``torch.add(Tensor, float)``"""
def add_scalar(self, x: Tensor, y: float) -> Tensor:
r = torch.add(x, y)
return r
r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
def mul(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.mul(x, y)
return r
r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
def mul_scalar(self, x: Tensor, y: float) -> Tensor:
r = torch.mul(x, y)
return r
r"""Operation equivalent to ``torch.cat``"""
def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
r = torch.cat(x, dim=dim)
return r
r"""Operation equivalent to ``relu(torch.add(x,y))``"""
def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.add(x, y)
r = torch.nn.functional.relu(r)
return r
r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
def matmul(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.matmul(x, y)
return r
class QFunctional(torch.nn.Module):
r"""Wrapper class for quantized operations.
The instance of this class can be used instead of the
``torch.ops.quantized`` prefix. See example usage below.
.. note::
This class does not provide a ``forward`` hook. Instead, you must use
one of the underlying functions (e.g. ``add``).
Examples::
>>> q_add = QFunctional()
>>> # xdoctest: +SKIP
>>> a = torch.quantize_per_tensor(torch.tensor(3.0), 1.0, 0, torch.qint32)
>>> b = torch.quantize_per_tensor(torch.tensor(4.0), 1.0, 0, torch.qint32)
>>> q_add.add(a, b) # Equivalent to ``torch.ops.quantized.add(a, b, 1.0, 0)``
Valid operation names:
- add
- cat
- mul
- add_relu
- add_scalar
- mul_scalar
"""
def __init__(self) -> None:
super().__init__()
self.scale = 1.0
self.zero_point = 0
self.activation_post_process = torch.nn.Identity()
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + "scale"] = torch.tensor(self.scale)
destination[prefix + "zero_point"] = torch.tensor(self.zero_point)
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
self.scale = float(state_dict.pop(prefix + "scale"))
self.zero_point = int(state_dict.pop(prefix + "zero_point"))
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
False,
missing_keys,
unexpected_keys,
error_msgs,
)
def _get_name(self):
return "QFunctional"
def extra_repr(self):
return f"scale={self.scale}, zero_point={self.zero_point}"
def forward(self, x):
raise RuntimeError(
"Functional is not intended to use the "
+ "'forward'. Please use the underlying operation"
)
r"""Operation equivalent to ``torch.ops.quantized.add``"""
def add(self, x: Tensor, y: Tensor) -> Tensor:
r = ops.quantized.add(x, y, scale=self.scale, zero_point=self.zero_point)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``torch.ops.quantized.add(Tensor, float)``"""
def add_scalar(self, x: Tensor, y: float) -> Tensor:
r = ops.quantized.add_scalar(x, y)
# Note: this operation is not observed because the observation is not
# needed for the quantized op.
return r
r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``"""
def mul(self, x: Tensor, y: Tensor) -> Tensor:
r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, float)``"""
def mul_scalar(self, x: Tensor, y: float) -> Tensor:
r = ops.quantized.mul_scalar(x, y)
# Note: this operation is not observed because the observation is not
# needed for the quantized op.
return r
r"""Operation equivalent to ``torch.ops.quantized.cat``"""
def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``torch.ops.quantized.add_relu``"""
def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
r = ops.quantized.add_relu(x, y, scale=self.scale, zero_point=self.zero_point)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``torch.ops.quantized.matmul(Tensor, Tensor)``"""
def matmul(self, x: Tensor, y: Tensor) -> Tensor:
r = ops.quantized.matmul(x, y, scale=self.scale, zero_point=self.zero_point)
# Note: this operation is not observed because the observation is not
# needed for the quantized op.
return r
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
assert (
type(mod) == FloatFunctional
), "QFunctional.from_float expects an instance of FloatFunctional"
scale, zero_point = mod.activation_post_process.calculate_qparams() # type: ignore[operator]
new_mod = QFunctional()
new_mod.scale = float(scale)
new_mod.zero_point = int(zero_point)
return new_mod
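# Illustrative sketch (not part of the original file): the usual flow is to observe
# a FloatFunctional and then convert it with QFunctional.from_float. MinMaxObserver
# stands in here for the observer normally attached by the prepare step; the helper
# name is hypothetical.
def _example_qfunctional_from_float():
    from torch.ao.quantization import MinMaxObserver

    ff = FloatFunctional()
    ff.activation_post_process = MinMaxObserver()
    ff.add(torch.randn(4), torch.randn(4))  # records min/max of the float output
    qf = QFunctional.from_float(ff)
    a = torch.quantize_per_tensor(torch.randn(4), 0.1, 0, torch.quint8)
    b = torch.quantize_per_tensor(torch.randn(4), 0.1, 0, torch.quint8)
    return qf.add(a, b)  # requantized with qf.scale / qf.zero_point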

View File

@ -0,0 +1,358 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from collections.abc import Iterable
from typing import Optional
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
import torch.nn as nn
from torch.nn.utils.fusion import fuse_linear_bn_weights
from torch.nn.utils.parametrize import type_before_parametrizations
from .utils import _hide_packed_params_repr, _quantize_weight, WeightedQuantizedModule
__all__ = ["LinearPackedParams", "Linear"]
class LinearPackedParams(torch.nn.Module):
_version = 3
def __init__(self, dtype=torch.qint8):
super().__init__()
self.dtype = dtype
if self.dtype == torch.qint8:
wq = torch._empty_affine_quantized(
[1, 1], scale=1.0, zero_point=0, dtype=torch.qint8
)
elif self.dtype == torch.float16:
wq = torch.zeros([1, 1], dtype=torch.float)
self.set_weight_bias(wq, None) # type: ignore[possibly-undefined]
@torch.jit.export
def set_weight_bias(
self, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> None:
if self.dtype == torch.qint8:
self._packed_params = torch.ops.quantized.linear_prepack(weight, bias)
elif self.dtype == torch.float16:
self._packed_params = torch.ops.quantized.linear_prepack_fp16(weight, bias)
else:
raise RuntimeError("Unsupported dtype on dynamic quantized linear!")
@torch.jit.export
def _weight_bias(self):
if self.dtype == torch.qint8:
return torch.ops.quantized.linear_unpack(self._packed_params)
elif self.dtype == torch.float16:
return torch.ops.quantized.linear_unpack_fp16(self._packed_params)
else:
raise RuntimeError("Unsupported dtype on dynamic quantized linear!")
def forward(self, x):
return x
# Version 1
# self
# |--- weight : Tensor
# |--- bias : Tensor
#
# Version 2
# self
# |--- weight : Tensor
# |--- bias : Tensor
# |--- dtype : torch.dtype
#
# Version 3
# self
# |--- _packed_params : (Tensor, Tensor) representing (weight, bias)
# of LinearPackedParams
# |--- dtype : torch.dtype
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + "dtype"] = self.dtype
destination[prefix + "_packed_params"] = self._weight_bias()
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)
if version is None or version < 2:
self.dtype = torch.qint8
else:
self.dtype = state_dict[prefix + "dtype"]
state_dict.pop(prefix + "dtype")
if version is None or version < 3:
self.set_weight_bias(
state_dict[prefix + "weight"], state_dict[prefix + "bias"]
)
state_dict.pop(prefix + "weight")
state_dict.pop(prefix + "bias")
if version == 3:
weight, bias = state_dict[prefix + "_packed_params"]
state_dict.pop(prefix + "_packed_params")
self.set_weight_bias(weight, bias)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
False,
missing_keys,
unexpected_keys,
error_msgs,
)
def __repr__(self):
return self._weight_bias().__repr__()
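# Editor's note: a small round-trip sketch (not part of this file) showing how
# LinearPackedParams packs and unpacks a qint8 weight; it assumes the
# fbgemm/qnnpack quantized engine is available.
import torch
from torch.ao.nn.quantized.modules.linear import LinearPackedParams

w = torch.quantize_per_tensor(torch.randn(30, 20), 0.05, 0, torch.qint8)
b = torch.zeros(30)
params = LinearPackedParams(dtype=torch.qint8)
params.set_weight_bias(w, b)             # torch.ops.quantized.linear_prepack
w_back, b_back = params._weight_bias()   # torch.ops.quantized.linear_unpack
assert w_back.shape == (30, 20) and b_back.shape == (30,)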
class Linear(WeightedQuantizedModule):
r"""
A quantized linear module with quantized tensor as inputs and outputs.
We adopt the same interface as `torch.nn.Linear`, please see
https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.
Similar to :class:`~torch.nn.Linear`, attributes will be randomly
initialized at module creation time and will be overwritten later
Attributes:
weight (Tensor): the non-learnable quantized weights of the module of
shape :math:`(\text{out\_features}, \text{in\_features})`.
bias (Tensor): the non-learnable bias of the module of shape :math:`(\text{out\_features})`.
If :attr:`bias` is ``True``, the values are initialized to zero.
scale: `scale` parameter of output Quantized Tensor, type: double
zero_point: `zero_point` parameter for output Quantized Tensor, type: long
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> m = nn.quantized.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> # xdoctest: +SKIP
>>> input = torch.quantize_per_tensor(input, 1.0, 0, torch.quint8)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_version = 3
_FLOAT_MODULE = (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear)
def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
super().__init__()
# We don't muck around with buffers or attributes or anything here
# to keep the module simple. *everything* is simply a Python attribute.
# Serialization logic is explicitly handled in the below serialization and
# deserialization modules
self.in_features = in_features
self.out_features = out_features
bias = None
if bias_:
bias = torch.zeros(out_features, dtype=torch.float)
if dtype == torch.qint8:
qweight = torch._empty_affine_quantized(
[out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8
)
elif dtype == torch.float16:
qweight = torch.zeros([out_features, in_features], dtype=torch.float)
else:
raise RuntimeError("Unsupported dtype specified for quantized Linear!")
self._packed_params = LinearPackedParams(dtype)
self._packed_params.set_weight_bias(qweight, bias)
self.scale = 1.0
self.zero_point = 0
def _get_name(self):
return "QuantizedLinear"
def extra_repr(self):
return (
f"in_features={self.in_features}, out_features={self.out_features}, scale={self.scale}, "
f"zero_point={self.zero_point}, qscheme={self.weight().qscheme()}"
)
def __repr__(self):
return _hide_packed_params_repr(self, LinearPackedParams)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.quantized.linear(
x, self._packed_params._packed_params, self.scale, self.zero_point
)
# ===== Serialization methods =====
# The special consideration here is that we have to unpack the weights into their
# regular QTensor form for serialization. Packed weights should not live
    # outside the process in which they were created; rather, they should be derived
# from the QTensor weight.
#
# Version 1
# self
# |--- scale : float
# |--- zero_point : int
# |--- weight : Tensor
# |--- bias : Tensor
#
# Version 2
# self
# |--- scale : float
# |--- zero_point : int
# |--- _packed_params : Module
# |--- weight : Tensor
# |--- bias : Tensor
#
# Version 3
# self
# |--- scale : float
# |--- zero_point : int
# |--- _packed_params : Module
# |--- _packed_params : (Tensor, Tensor) representing weight, bias
# of LinearPackedParams C++ struct
#
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + "scale"] = torch.tensor(self.scale)
destination[prefix + "zero_point"] = torch.tensor(self.zero_point)
# ===== Deserialization methods =====
# Counterpart to the serialization methods, we must pack the serialized QTensor
# weight into its packed format for use by the FBGEMM ops.
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
self.scale = float(state_dict[prefix + "scale"])
state_dict.pop(prefix + "scale")
self.zero_point = int(state_dict[prefix + "zero_point"])
state_dict.pop(prefix + "zero_point")
version = local_metadata.get("version", None)
if version is None or version == 1:
# We moved the parameters into a LinearPackedParameters submodule
weight = state_dict.pop(prefix + "weight")
bias = state_dict.pop(prefix + "bias")
state_dict.update(
{
prefix + "_packed_params.weight": weight,
prefix + "_packed_params.bias": bias,
}
)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
False,
missing_keys,
unexpected_keys,
error_msgs,
)
# Function rather than property to make sure that JIT serialization doesn't
# register this as an attribute
def _weight_bias(self):
return self._packed_params._weight_bias()
def weight(self):
return self._weight_bias()[0]
def bias(self):
return self._weight_bias()[1]
def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
self._packed_params.set_weight_bias(w, b)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized module from an observed float module
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
use_precomputed_fake_quant (bool): if True, the module will reuse min/max
values from the precomputed fake quant module.
"""
if hasattr(mod, "weight_fake_quant"):
if type_before_parametrizations(mod) == nniqat.LinearBn1d:
mod.weight, mod.bias = fuse_linear_bn_weights(
mod.weight,
mod.bias,
mod.bn.running_mean,
mod.bn.running_var,
mod.bn.eps,
mod.bn.weight,
mod.bn.bias,
)
weight_post_process = mod.weight_fake_quant
activation_post_process = mod.activation_post_process
else:
# This function does not participate in JIT, so it is OK to ignore
# the type mismatch in assignment. Also, mypy has an issue with
# iterables not being implemented, so we are ignoring those too.
if not isinstance(cls._FLOAT_MODULE, Iterable):
cls._FLOAT_MODULE = [cls._FLOAT_MODULE] # type: ignore[assignment]
supported_modules = ", ".join([float_mod.__name__ for float_mod in cls._FLOAT_MODULE]) # type: ignore[attr-defined]
error_msg = f"nnq.{cls.__name__}.from_float only works for {supported_modules}, but got: {type(mod)}"
            assert type_before_parametrizations(mod) in cls._FLOAT_MODULE, error_msg  # type: ignore[attr-defined]
assert hasattr(
mod, "qconfig"
), "Input float module must have qconfig defined"
activation_post_process = mod.activation_post_process
if type_before_parametrizations(mod) == nni.LinearReLU:
mod = mod[0]
weight_post_process = (
mod.qconfig.weight()
if not hasattr(mod, "weight_fake_quant")
else mod.weight_fake_quant
)
if not use_precomputed_fake_quant:
            # The weight observer may not have been called yet. When
            # use_precomputed_fake_quant is set, it was already run in a
            # previous stage (e.g. by a PTQ algorithm such as AdaRound),
            # so we skip re-observing here.
weight_post_process(mod.weight)
dtype = weight_post_process.dtype
act_scale, act_zp = activation_post_process.calculate_qparams()
assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
qlinear.set_weight_bias(qweight, mod.bias)
qlinear.scale = float(act_scale)
qlinear.zero_point = int(act_zp)
return qlinear
@classmethod
def from_reference(cls, ref_qlinear, output_scale, output_zero_point):
r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
Args:
ref_qlinear (Module): a reference quantized linear module, either produced by torch.ao.quantization
utilities or provided by the user
output_scale (float): scale for output Tensor
output_zero_point (int): zero point for output Tensor
"""
qlinear = cls(ref_qlinear.in_features, ref_qlinear.out_features)
qweight = ref_qlinear.get_quantized_weight()
qlinear.set_weight_bias(qweight, ref_qlinear.bias)
qlinear.scale = float(output_scale)
qlinear.zero_point = int(output_zero_point)
return qlinear
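# Editor's note: a hedged sketch (not part of this file) of the eager-mode flow
# that ends in this quantized Linear: prepare() observes a float nn.Linear, a
# calibration pass fills the observers, and convert() swaps in nnq.Linear via
# from_float. It assumes the fbgemm backend; the wrapper class name is only for
# illustration.
import torch
import torch.ao.nn.quantized as nnq
from torch.ao.quantization import DeQuantStub, QuantStub, convert, get_default_qconfig, prepare


class _ExampleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.fc = torch.nn.Linear(20, 30)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))


m = _ExampleModel().eval()
m.qconfig = get_default_qconfig("fbgemm")
prepared = prepare(m)
prepared(torch.randn(4, 20))       # calibration pass feeds the observers
quantized = convert(prepared)
assert isinstance(quantized.fc, nnq.Linear)
quantized(torch.randn(4, 20))      # float in, float out (stubs handle quant/dequant)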

View File

@ -0,0 +1,346 @@
# mypy: allow-untyped-defs
import torch
__all__ = [
"LayerNorm",
"GroupNorm",
"InstanceNorm1d",
"InstanceNorm2d",
"InstanceNorm3d",
]
class LayerNorm(torch.nn.LayerNorm):
r"""This is the quantized version of :class:`~torch.nn.LayerNorm`.
Additional args:
* **scale** - quantization scale of the output, type: double.
* **zero_point** - quantization zero point of the output, type: long.
"""
def __init__(
self,
normalized_shape,
weight,
bias,
scale,
zero_point,
eps=1e-5,
elementwise_affine=True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
**factory_kwargs,
)
self.weight = weight
self.bias = bias
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
return torch.ops.quantized.layer_norm(
input,
self.normalized_shape,
weight=self.weight,
bias=self.bias,
eps=self.eps,
output_scale=self.scale,
output_zero_point=self.zero_point,
)
def _get_name(self):
return "QuantizedLayerNorm"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.normalized_shape,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
mod.eps,
mod.elementwise_affine,
)
return new_mod
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(
mod.normalized_shape,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
mod.eps,
mod.elementwise_affine,
)
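# Editor's note: a minimal direct-construction sketch (not part of this file).
# Weight and bias are passed as Parameters, the way from_float would pass them,
# and the input must already be a per-tensor quantized tensor.
import torch
from torch.ao.nn.quantized import LayerNorm as QLayerNorm

w = torch.nn.Parameter(torch.ones(8))
b = torch.nn.Parameter(torch.zeros(8))
qln = QLayerNorm([8], w, b, scale=0.1, zero_point=0)
x = torch.quantize_per_tensor(torch.randn(4, 8), 0.1, 0, torch.quint8)
y = qln(x)                          # dispatches to torch.ops.quantized.layer_norm
assert y.is_quantized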
class GroupNorm(torch.nn.GroupNorm):
r"""This is the quantized version of :class:`~torch.nn.GroupNorm`.
Additional args:
* **scale** - quantization scale of the output, type: double.
* **zero_point** - quantization zero point of the output, type: long.
"""
__constants__ = ["num_groups", "num_channels", "eps", "affine"]
def __init__(
self,
num_groups,
num_channels,
weight,
bias,
scale,
zero_point,
eps=1e-5,
affine=True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
self.weight = weight
self.bias = bias
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
return torch.ops.quantized.group_norm(
input,
self.num_groups,
self.weight,
self.bias,
self.eps,
self.scale,
self.zero_point,
)
def _get_name(self):
return "QuantizedGroupNorm"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_groups,
mod.num_channels,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
mod.eps,
mod.affine,
)
return new_mod
class InstanceNorm1d(torch.nn.InstanceNorm1d):
r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`.
Additional args:
* **scale** - quantization scale of the output, type: double.
* **zero_point** - quantization zero point of the output, type: long.
"""
def __init__(
self,
num_features,
weight,
bias,
scale,
zero_point,
eps=1e-5,
momentum=0.1,
affine=False,
track_running_stats=False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
)
self.weight = weight
self.bias = bias
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
return torch.ops.quantized.instance_norm(
input, self.weight, self.bias, self.eps, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedInstanceNorm1d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_features,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
            # pass these by keyword so `mod.affine` is not consumed by the
            # `momentum` parameter of __init__
            eps=mod.eps,
            affine=mod.affine,
)
return new_mod
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(
mod.num_features,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
            eps=mod.eps,
            affine=mod.affine,
)
class InstanceNorm2d(torch.nn.InstanceNorm2d):
r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`.
Additional args:
* **scale** - quantization scale of the output, type: double.
* **zero_point** - quantization zero point of the output, type: long.
"""
def __init__(
self,
num_features,
weight,
bias,
scale,
zero_point,
eps=1e-5,
momentum=0.1,
affine=False,
track_running_stats=False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
)
self.weight = weight
self.bias = bias
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
return torch.ops.quantized.instance_norm(
input, self.weight, self.bias, self.eps, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedInstanceNorm2d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_features,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
            eps=mod.eps,
            affine=mod.affine,
)
return new_mod
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(
mod.num_features,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
            eps=mod.eps,
            affine=mod.affine,
)
class InstanceNorm3d(torch.nn.InstanceNorm3d):
r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`.
Additional args:
* **scale** - quantization scale of the output, type: double.
* **zero_point** - quantization zero point of the output, type: long.
"""
def __init__(
self,
num_features,
weight,
bias,
scale,
zero_point,
eps=1e-5,
momentum=0.1,
affine=False,
track_running_stats=False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
)
self.weight = weight
self.bias = bias
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
return torch.ops.quantized.instance_norm(
input, self.weight, self.bias, self.eps, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedInstanceNorm3d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_features,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
            eps=mod.eps,
            affine=mod.affine,
)
return new_mod
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(
mod.num_features,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
            eps=mod.eps,
            affine=mod.affine,
)
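# Editor's note: a hedged from_float sketch (not part of this file). prepare()
# would normally attach the activation observer; here one is attached by hand
# just to show where the output scale and zero_point come from.
import torch
from torch.ao.nn.quantized import InstanceNorm2d as QInstanceNorm2d
from torch.ao.quantization import MinMaxObserver

float_in = torch.nn.InstanceNorm2d(3, affine=True).eval()
float_in.activation_post_process = MinMaxObserver(dtype=torch.quint8)
float_in.activation_post_process(float_in(torch.randn(2, 3, 8, 8)))  # calibrate
q_in = QInstanceNorm2d.from_float(float_in)
print(q_in.scale, q_in.zero_point)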

View File

@ -0,0 +1,57 @@
# mypy: allow-untyped-defs
import torch
__all__ = [
"LSTM",
]
class LSTM(torch.ao.nn.quantizable.LSTM):
r"""A quantized long short-term memory (LSTM).
For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`
Attributes:
layers : instances of the `_LSTMLayer`
.. note::
To access the weights and biases, you need to access them per layer.
See examples in :class:`~torch.ao.nn.quantizable.LSTM`
Examples::
>>> # xdoctest: +SKIP
>>> custom_module_config = {
... 'float_to_observed_custom_module_class': {
... nn.LSTM: nn.quantizable.LSTM,
... },
... 'observed_to_quantized_custom_module_class': {
... nn.quantizable.LSTM: nn.quantized.LSTM,
... }
... }
>>> tq.prepare(model, prepare_custom_module_class=custom_module_config)
>>> tq.convert(model, convert_custom_module_class=custom_module_config)
"""
_FLOAT_MODULE = torch.ao.nn.quantizable.LSTM # type: ignore[assignment]
def _get_name(self):
return "QuantizedLSTM"
@classmethod
def from_float(cls, *args, **kwargs):
# The whole flow is float -> observed -> quantized
# This class does observed -> quantized only
raise NotImplementedError(
"It looks like you are trying to convert a "
"non-observed LSTM module. Please, see "
"the examples on quantizable LSTMs."
)
@classmethod
def from_observed(cls, other):
assert type(other) == cls._FLOAT_MODULE # type: ignore[has-type]
converted = torch.ao.quantization.convert(
other, inplace=False, remove_qconfig=True
)
converted.__class__ = cls
return converted
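# Editor's note: a hedged end-to-end sketch (not part of this file) of the
# float -> observed -> quantized flow that from_observed expects. The
# custom-module config keys follow the eager-mode prepare()/convert() API; the
# wrapper model below is only for illustration.
import torch
import torch.ao.nn.quantizable as nnqable
import torch.ao.nn.quantized as nnq
import torch.ao.quantization as tq


class _LSTMWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(4, 8)

    def forward(self, x):
        return self.lstm(x)


m = _LSTMWrapper().eval()
m.qconfig = tq.get_default_qconfig("fbgemm")
observed = tq.prepare(
    m,
    prepare_custom_config_dict={
        "float_to_observed_custom_module_class": {torch.nn.LSTM: nnqable.LSTM}
    },
)
observed(torch.randn(5, 3, 4))  # calibration with float inputs
quantized = tq.convert(
    observed,
    convert_custom_config_dict={
        "observed_to_quantized_custom_module_class": {nnqable.LSTM: nnq.LSTM}
    },
)
assert isinstance(quantized.lstm, nnq.LSTM)  # QuantizedLSTM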

View File

@ -0,0 +1,144 @@
# mypy: allow-untyped-defs
import abc
import collections
import itertools
import torch
from torch.nn.modules.module import _addindent
__all__ = [
"WeightedQuantizedModule",
]
class WeightedQuantizedModule(torch.nn.Module, metaclass=abc.ABCMeta):
"""Wrapper for quantized modules than can be lowered from reference modules."""
@classmethod
@abc.abstractmethod
def from_reference(cls, ref_module, output_scale, output_zero_point):
raise NotImplementedError
def _get_weight_observer(observer):
# FakeQuantize observer
if hasattr(observer, "activation_post_process"):
observer = observer.activation_post_process
# UniformQuantizationObserverBase observer
return observer
def _needs_weight_clamping(observer, dtype):
observer = _get_weight_observer(observer)
if dtype in [torch.qint8, torch.quint8, torch.qint32]:
info = torch.iinfo(dtype)
return observer.quant_min > info.min or observer.quant_max < info.max
return False
def _clamp_weights(qweight, observer, scale, zp):
if not _needs_weight_clamping(observer, qweight.dtype):
return qweight
observer = _get_weight_observer(observer)
min_, max_ = observer.quant_min, observer.quant_max
    # Doing this because torch.ops.quantized.clamp() can't be used with a per_channel qscheme yet.
qw_int_max = torch.clone(qweight.int_repr()).fill_(max_)
qw_int_min = torch.clone(qweight.int_repr()).fill_(min_)
qw_int = torch.minimum(torch.maximum(qweight.int_repr(), qw_int_min), qw_int_max)
if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
qweight = torch._make_per_tensor_quantized_tensor(
qw_int, scale.item(), zp.item()
)
elif observer.qscheme in [
torch.per_channel_symmetric,
torch.per_channel_affine,
torch.per_channel_affine_float_qparams,
]:
qweight = torch._make_per_channel_quantized_tensor(
qw_int, scale, zp, axis=observer.ch_axis
)
else:
raise ValueError("Unexpected qscheme " + observer.qscheme)
return qweight
def _quantize_weight(float_wt, observer):
wt_scale, wt_zp = observer.calculate_qparams()
if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
qweight = torch.quantize_per_tensor(
float_wt, float(wt_scale), int(wt_zp), torch.qint8
)
qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
elif observer.qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]:
wt_axis = observer.ch_axis
qweight = torch.quantize_per_channel(
float_wt,
wt_scale.to(torch.double),
wt_zp.to(torch.int64),
wt_axis,
torch.qint8,
)
qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
elif observer.qscheme in [torch.per_channel_affine_float_qparams]:
qweight = torch.quantize_per_channel(
float_wt,
wt_scale.to(torch.float),
wt_zp.to(torch.float),
observer.ch_axis,
observer.dtype,
)
qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
else:
raise ValueError("Unexpected qscheme " + observer.qscheme)
return qweight
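# Editor's note: a minimal sketch (not part of this file) of _quantize_weight
# with a per-tensor symmetric weight observer, the same configuration used by
# the default qint8 weight observers.
import torch
from torch.ao.nn.quantized.modules.utils import _quantize_weight
from torch.ao.quantization import MinMaxObserver

weight_observer = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
float_weight = torch.randn(30, 20)
weight_observer(float_weight)      # record min/max so calculate_qparams() works
qweight = _quantize_weight(float_weight, weight_observer)
assert qweight.dtype == torch.qint8 and qweight.q_zero_point() == 0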
def _ntuple_from_first(n):
"""Converts the argument to a tuple of size n
with the first element repeated."""
def parse(x):
while isinstance(x, collections.abc.Sequence):
if len(x) == n:
break
x = x[0]
return tuple(itertools.repeat(x, n))
return parse
def _hide_packed_params_repr(self, params):
    # We don't want to show `PackedParams` children, hence the custom
    # `__repr__`. This is the same as nn.Module.__repr__, except for the
    # check that skips child modules of type `params`.
extra_lines = []
extra_repr = self.extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split("\n")
child_lines = []
for key, module in self._modules.items():
if isinstance(module, params):
continue
mod_str = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append("(" + key + "): " + mod_str)
lines = extra_lines + child_lines
main_str = self._get_name() + "("
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
return main_str
_pair_from_first = _ntuple_from_first(2)
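# Editor's note: a tiny sketch (not part of this file) of what _pair_from_first
# returns: a scalar or a 1-element sequence is expanded to a pair built from
# its first element.
assert _pair_from_first(3) == (3, 3)
assert _pair_from_first((3,)) == (3, 3)
assert _pair_from_first([5]) == (5, 5)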

View File

@ -0,0 +1,19 @@
from .modules import * # noqa: F403
__all__ = [
"Linear",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
"RNNCell",
"LSTMCell",
"GRUCell",
"LSTM",
"GRU",
"Embedding",
"EmbeddingBag",
]

View File

@ -0,0 +1,29 @@
from .conv import (
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
)
from .linear import Linear
from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNNCell
from .sparse import Embedding, EmbeddingBag
__all__ = [
"Linear",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
"RNNCell",
"LSTMCell",
"GRUCell",
"LSTM",
"GRU",
"Embedding",
"EmbeddingBag",
]
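# Editor's note: a hedged sketch (not part of this file) of how these dynamic
# modules are usually produced: torch.ao.quantization.quantize_dynamic swaps
# float layers for their dynamic counterparts (qint8 weights, float activations).
import torch
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.quantization as tq

float_model = torch.nn.Sequential(torch.nn.Linear(20, 30), torch.nn.ReLU())
dyn_model = tq.quantize_dynamic(float_model, {torch.nn.Linear}, dtype=torch.qint8)
assert isinstance(dyn_model[0], nnqd.Linear)  # dynamic quantized Linear
out = dyn_model(torch.randn(4, 20))           # float in, float out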

Some files were not shown because too many files have changed in this diff