I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

@ -0,0 +1,39 @@
from . import functional
from .modules import * # noqa: F403
from .modules import MaxPool2d
__all__ = [
"BatchNorm2d",
"BatchNorm3d",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
"DeQuantize",
"ELU",
"Embedding",
"EmbeddingBag",
"GroupNorm",
"Hardswish",
"InstanceNorm1d",
"InstanceNorm2d",
"InstanceNorm3d",
"LayerNorm",
"LeakyReLU",
"Linear",
"LSTM",
"MultiheadAttention",
"Quantize",
"ReLU6",
"Sigmoid",
"Softmax",
"Dropout",
"PReLU",
# Wrapper modules
"FloatFunctional",
"FXFloatFunctional",
"QFunctional",
]

@ -0,0 +1 @@
from .modules import * # noqa: F403

@ -0,0 +1,26 @@
from .conv import (
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
)
from .linear import Linear
from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNNCell
__all__ = [
"Linear",
"LSTM",
"GRU",
"LSTMCell",
"RNNCell",
"GRUCell",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
]

@ -0,0 +1,520 @@
# mypy: allow-untyped-defs
r"""Dynamically quantized convolution modules."""
import warnings
import torch
import torch.ao.nn.quantized as nnq
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch._ops import ops
from torch.ao.nn.quantized.modules.conv import _reverse_repeat_padding
from torch.nn.common_types import _size_1_t
from torch.nn.modules.utils import _pair, _single, _triple
__all__ = [
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
]
class Conv1d(nnq.Conv1d):
r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.
For details on input arguments, parameters, and implementation see
:class:`~torch.nn.Conv1d` and :class:`~torch.ao.nn.quantized.dynamic.Conv1d`.
Attributes:
weight (Tensor): packed tensor derived from the learnable weight
parameter.
scale (Tensor): scalar for the output scale
zero_point (Tensor): scalar for the output zero point
See :class:`~torch.nn.Conv1d` for other attributes.
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.quantized.dynamic.Conv1d(16, 33, 3, stride=2)
>>> input = torch.randn(20, 16, 100)
>>> output = m(input)
"""
_FLOAT_MODULE = nn.Conv1d
_NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment]
_NNI_CONV_RELU_MODULE = None # type: ignore[assignment]
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: _size_1_t = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device=None,
dtype=None,
reduce_range=True,
):
warnings.warn(
f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950
)
factory_kwargs = {"device": device, "dtype": dtype}
kernel_size = _single(kernel_size)
stride = _single(stride)
padding = padding if isinstance(padding, str) else _single(padding)
dilation = _single(dilation)
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
**factory_kwargs,
)
def _get_name(self):
return "DynamicQuantizedConv1d"
def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 3:
raise ValueError("Input shape must be `(N, C, L)`!")
if self.padding_mode != "zeros":
# Padding in Conv1d is stored as (p, p), need to get (p,)
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return ops.quantized.conv1d_dynamic(input, self._packed_params, reduce_range)
class Conv2d(nnq.Conv2d):
r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.
For details on input arguments, parameters, and implementation see
:class:`~torch.nn.Conv2d` and :class:`~torch.ao.nn.quantized.dynamic.Conv2d`.
Attributes:
weight (Tensor): packed tensor derived from the learnable weight
parameter.
scale (Tensor): scalar for the output scale
zero_point (Tensor): scalar for the output zero point
See :class:`~torch.nn.Conv2d` for other attributes.
Examples::
>>> # xdoctest: +SKIP
>>> # With square kernels and equal stride
>>> m = nn.quantized.dynamic.Conv2d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding
>>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
>>> # non-square kernels and unequal stride and with padding and dilation
>>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
>>> input = torch.randn(20, 16, 50, 100)
>>> output = m(input)
"""
_FLOAT_MODULE = nn.Conv2d
_NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment]
_NNI_CONV_RELU_MODULE = None # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
warnings.warn(
f"The current implementation of the {self._get_name()} module "
"has poor numerical accuracy and its use is not recommended"
)
factory_kwargs = {"device": device, "dtype": dtype}
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
**factory_kwargs,
)
def _get_name(self):
return "DynamicQuantizedConv2d"
def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
if self.padding_mode != "zeros":
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return ops.quantized.conv2d_dynamic(input, self._packed_params, reduce_range)
class Conv3d(nnq.Conv3d):
r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.
For details on input arguments, parameters, and implementation see
:class:`~torch.nn.Conv3d` and :class:`~torch.ao.nn.quantized.dynamic.Conv3d`.
Attributes:
weight (Tensor): packed tensor derived from the learnable weight
parameter.
scale (Tensor): scalar for the output scale
zero_point (Tensor): scalar for the output zero point
See :class:`~torch.nn.Conv3d` for other attributes.
Examples::
>>> # xdoctest: +SKIP
>>> # With square kernels and equal stride
>>> m = nn.quantized.dynamic.Conv3d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding
>>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
>>> # non-square kernels and unequal stride and with padding and dilation
>>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
>>> input = torch.randn(20, 16, 56, 56, 56)
>>> output = m(input)
"""
_FLOAT_MODULE = nn.Conv3d
_NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment]
_NNI_CONV_RELU_MODULE = None # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
warnings.warn(
f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950
)
assert padding_mode != "reflect", "Conv3d does not support reflection padding"
factory_kwargs = {"device": device, "dtype": dtype}
kernel_size = _triple(kernel_size)
stride = _triple(stride)
padding = _triple(padding)
dilation = _triple(dilation)
super()._init(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
False,
_triple(0),
groups,
bias,
padding_mode,
**factory_kwargs,
)
def _get_name(self):
return "DynamicQuantizedConv3d"
def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 5:
raise ValueError("Input shape must be `(N, C, D, H, W)`!")
if self.padding_mode != "zeros":
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return ops.quantized.conv3d_dynamic(input, self._packed_params, reduce_range)
class ConvTranspose1d(nnq.ConvTranspose1d):
r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs.
For details on input arguments, parameters, and implementation see
:class:`~torch.nn.ConvTranspose1d`.
For special notes, please see :class:`~torch.ao.nn.quantized.dynamic.Conv1d`.
Attributes:
weight (Tensor): packed tensor derived from the learnable weight
parameter.
scale (Tensor): scalar for the output scale
zero_point (Tensor): scalar for the output zero point
See :class:`~torch.nn.ConvTranspose1d` for other attributes.
Examples::
>>> # xdoctest: +SKIP
>>> # With a kernel of size 3 and stride 2
>>> m = nndq.ConvTranspose1d(16, 33, 3, stride=2)
>>> # with a larger kernel, stride, and padding
>>> m = nndq.ConvTranspose1d(16, 33, 5, stride=2, padding=4)
>>> input = torch.randn(1, 16, 12)
>>> output = m(input)
>>> # the exact output size can also be specified as an argument
>>> downsample = nndq.Conv1d(16, 16, 3, stride=2, padding=1)
>>> upsample = nndq.ConvTranspose1d(16, 16, 3, stride=2, padding=1)
>>> h = downsample(input)
>>> h.size()
torch.Size([1, 16, 6])
>>> output = upsample(h, output_size=input.size())
>>> output.size()
torch.Size([1, 16, 12])
"""
_FLOAT_MODULE = nn.ConvTranspose1d
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
output_padding=0,
groups=1,
bias=True,
dilation=1,
padding_mode="zeros",
device=None,
dtype=None,
):
warnings.warn(
f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
output_padding,
groups,
bias,
dilation,
padding_mode,
**factory_kwargs,
)
def _get_name(self):
return "DynamicQuantizedConvTranspose1d"
def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 3:
raise ValueError("Input shape must be `(N, C, L)`!")
return torch.ops.quantized.conv_transpose1d_dynamic(
input, self._packed_params, reduce_range
)
class ConvTranspose2d(nnq.ConvTranspose2d):
r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs.
For details on input arguments, parameters, and implementation see
:class:`~torch.nn.ConvTranspose2d`.
For special notes, please see :class:`~torch.ao.nn.quantized.dynamic.Conv2d`.
Attributes:
weight (Tensor): packed tensor derived from the learnable weight
parameter.
scale (Tensor): scalar for the output scale
zero_point (Tensor): scalar for the output zero point
See :class:`~torch.nn.ConvTranspose2d` for other attributes.
Examples::
>>> # xdoctest: +SKIP
>>> # With square kernels and equal stride
>>> m = nndq.ConvTranspose2d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding
>>> m = nndq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
>>> input = torch.randn(1, 16, 12, 12)
>>> output = m(input)
>>> # the exact output size can also be specified as an argument
>>> downsample = nndq.Conv2d(16, 16, 3, stride=2, padding=1)
>>> upsample = nndq.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
>>> h = downsample(input)
>>> h.size()
torch.Size([1, 16, 6, 6])
>>> output = upsample(h, output_size=input.size())
>>> output.size()
torch.Size([1, 16, 12, 12])
"""
_FLOAT_MODULE = nn.ConvTranspose2d
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
output_padding=0,
groups=1,
bias=True,
dilation=1,
padding_mode="zeros",
device=None,
dtype=None,
):
warnings.warn(
f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
output_padding,
groups,
bias,
dilation,
padding_mode,
**factory_kwargs,
)
def _get_name(self):
return "DynamicQuantizedConvTranspose2d"
def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
return ops.quantized.conv_transpose2d_dynamic(
input, self._packed_params, reduce_range
)
class ConvTranspose3d(nnq.ConvTranspose3d):
r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs.
For details on input arguments, parameters, and implementation see
:class:`~torch.nn.ConvTranspose3d`.
For special notes, please see :class:`~torch.ao.nn.quantized.dynamic.Conv3d`.
Attributes:
weight (Tensor): packed tensor derived from the learnable weight
parameter.
scale (Tensor): scalar for the output scale
zero_point (Tensor): scalar for the output zero point
See :class:`~torch.nn.ConvTranspose3d` for other attributes.
Examples::
>>> # xdoctest: +SKIP
>>> # With cubic kernels and equal stride
>>> m = nndq.ConvTranspose3d(16, 33, 3, stride=2)
>>> # non-cubic kernels and unequal stride and with padding
>>> m = nndq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
>>> input = torch.randn(1, 16, 12, 12, 12)
>>> output = m(input)
>>> # the exact output size can also be specified as an argument
>>> downsample = nndq.Conv3d(16, 16, 3, stride=2, padding=1)
>>> upsample = nndq.ConvTranspose3d(16, 16, 3, stride=2, padding=1)
>>> h = downsample(input)
>>> h.size()
torch.Size([1, 16, 6, 6, 6])
>>> output = upsample(h, output_size=input.size())
>>> output.size()
torch.Size([1, 16, 12, 12, 12])
"""
_FLOAT_MODULE = nn.ConvTranspose3d
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
output_padding=0,
groups=1,
bias=True,
dilation=1,
padding_mode="zeros",
device=None,
dtype=None,
):
warnings.warn(
f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended" # noqa: B950
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
output_padding,
groups,
bias,
dilation,
padding_mode,
**factory_kwargs,
)
def _get_name(self):
return "DynamicQuantizedConvTranspose3d"
def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 5:
raise ValueError("Input shape must be `(N, C, T, H, W)`!")
return ops.quantized.conv_transpose3d_dynamic(
input, self._packed_params, reduce_range
)

@ -0,0 +1,165 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized as nnq
from torch.ao.nn.quantized.modules.utils import _quantize_weight
__all__ = [
"Linear",
]
class Linear(nnq.Linear):
r"""
A dynamic quantized linear module with floating point tensor as inputs and outputs.
We adopt the same interface as `torch.nn.Linear`, please see
https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.
Similar to :class:`torch.nn.Linear`, attributes will be randomly
initialized at module creation time and will be overwritten later
Attributes:
weight (Tensor): the non-learnable quantized weights of the module which are of
shape :math:`(\text{out\_features}, \text{in\_features})`.
bias (Tensor): the non-learnable floating point bias of the module of shape
:math:`(\text{out\_features})`. If :attr:`bias` is ``True``,
the values are initialized to zero.
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.quantized.dynamic.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
# version used in this class is different from the parent class nnq.Linear
_version = 4
def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
super().__init__(in_features, out_features, bias_, dtype=dtype)
# We don't muck around with buffers or attributes or anything here
# to keep the module simple. *everything* is simply a Python attribute.
# Serialization logic is explicitly handled in the below serialization and
# deserialization modules
self.version = 4
def forward(self, x):
# Note that we can handle self.bias == None case.
if self._packed_params.dtype == torch.qint8:
if self.version is None or self.version < 4:
Y = torch.ops.quantized.linear_dynamic(
x, self._packed_params._packed_params
)
else:
Y = torch.ops.quantized.linear_dynamic(
x, self._packed_params._packed_params, reduce_range=True
)
elif self._packed_params.dtype == torch.float16:
Y = torch.ops.quantized.linear_dynamic_fp16(
x, self._packed_params._packed_params
)
else:
raise RuntimeError("Unsupported dtype on dynamic quantized linear!")
return Y.to(x.dtype)
def _get_name(self):
return "DynamicQuantizedLinear"
def extra_repr(self):
extra_repr_str = f"in_features={self.in_features}, out_features={self.out_features}, dtype={self._packed_params.dtype}"
if self._packed_params.dtype == torch.qint8:
extra_repr_str += f", qscheme={self.weight().qscheme()}"
return extra_repr_str
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)
self.version = version
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
False,
missing_keys,
unexpected_keys,
error_msgs,
)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a dynamic quantized module from a float module or qparams_dict
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
"""
float_modules = [
torch.nn.Linear,
torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
torch.ao.nn.intrinsic.modules.fused.LinearReLU,
torch.ao.nn.qat.dynamic.Linear,
]
assert (
type(mod) in float_modules
), "nn.quantized.dynamic.Linear.from_float only works for one of" + str(
[float_mod.__name__ for float_mod in float_modules]
)
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
if type(mod) == nni.LinearReLU:
mod = mod[0]
if mod.qconfig is not None and mod.qconfig.weight is not None:
weight_observer = mod.qconfig.weight()
else:
# We have the circular import issues if we import the qconfig in the beginning of this file:
# https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
# import until we need it.
from torch.ao.quantization.qconfig import default_dynamic_qconfig
weight_observer = default_dynamic_qconfig.weight()
dtype = weight_observer.dtype
assert dtype in [torch.qint8, torch.float16], (
"The only supported dtypes for "
f"dynamic quantized linear are qint8 and float16 got: {dtype}"
)
weight_observer(mod.weight)
if dtype == torch.qint8:
qweight = _quantize_weight(mod.weight.float(), weight_observer)
elif dtype == torch.float16:
qweight = mod.weight.float()
else:
raise RuntimeError(
"Unsupported dtype specified for dynamic quantized Linear!"
)
qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
qlinear.set_weight_bias(qweight, mod.bias)
return qlinear
@classmethod
def from_reference(cls, ref_qlinear):
"""Create a (fbgemm/qnnpack) dynamic quantized module from a reference quantized
module
Args:
ref_qlinear (Module): a reference quantized module, either produced by
torch.ao.quantization functions or provided by the user
"""
qlinear = cls(
ref_qlinear.in_features,
ref_qlinear.out_features,
dtype=ref_qlinear.weight_dtype,
)
qweight = ref_qlinear.get_quantized_weight()
bias = ref_qlinear.bias
qlinear.set_weight_bias(qweight, bias)
return qlinear
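# Editor's note: the sketch below is an illustrative addition, not part of the
# original file. It shows the usual way a DynamicQuantizedLinear is produced in
# practice, via torch.ao.quantization.quantize_dynamic, and assumes a quantized
# engine (fbgemm or qnnpack) is available in the build.
def _example_dynamic_linear():
    """Minimal sketch: convert a float nn.Linear into its dynamic quantized form."""
    import torch
    from torch.ao.quantization import quantize_dynamic

    float_model = torch.nn.Sequential(torch.nn.Linear(20, 30))
    # Weights are quantized to qint8 ahead of time; activations are quantized
    # dynamically at run time inside the packed linear kernel.
    qmodel = quantize_dynamic(float_model, {torch.nn.Linear}, dtype=torch.qint8)
    return qmodel(torch.randn(128, 20))  # float in, float out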

File diff suppressed because it is too large.

@ -0,0 +1,778 @@
# mypy: allow-untyped-defs
r""" Functional interface (quantized)."""
import warnings
from typing import List, Optional
import torch
from torch import Tensor
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair, _triple
from .modules.utils import _pair_from_first
# Although some of the functions and docstrings are mirrored from the torch.nn,
# we want to have them here for future changes.
__all__ = [
"avg_pool2d",
"avg_pool3d",
"adaptive_avg_pool2d",
"adaptive_avg_pool3d",
"conv1d",
"conv2d",
"conv3d",
"interpolate",
"linear",
"max_pool1d",
"max_pool2d",
"celu",
"leaky_relu",
"hardtanh",
"hardswish",
"threshold",
"elu",
"hardsigmoid",
"clamp",
"upsample",
"upsample_bilinear",
"upsample_nearest",
]
def avg_pool2d(
input,
kernel_size,
stride=None,
padding=0,
ceil_mode=False,
count_include_pad=True,
divisor_override=None,
):
r"""
Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size
:math:`sH \times sW` steps. The number of output features is equal to the number of
input planes.
.. note:: The input quantization parameters propagate to the output.
See :class:`~torch.ao.nn.quantized.AvgPool2d` for details and output shape.
Args:
input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
kernel_size: size of the pooling region. Can be a single number or a
tuple `(kH, kW)`
stride: stride of the pooling operation. Can be a single number or a
tuple `(sH, sW)`. Default: :attr:`kernel_size`
padding: implicit zero paddings on both sides of the input. Can be a
single number or a tuple `(padH, padW)`. Default: 0
ceil_mode: when True, will use `ceil` instead of `floor` in the formula
to compute the output shape. Default: ``False``
count_include_pad: when True, will include the zero-padding in the
averaging calculation. Default: ``True``
divisor_override: if specified, it will be used as divisor, otherwise
size of the pooling region will be used. Default: None
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.avg_pool2d' must be quantized!")
return torch.nn.functional.avg_pool2d(
input,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override,
)
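# Editor's note: illustrative sketch only (not part of the original file),
# showing avg_pool2d applied to a per-tensor quantized input. The scale and
# zero_point values are arbitrary.
def _example_avg_pool2d():
    """Minimal sketch: 2D average pooling over a quantized tensor."""
    import torch
    from torch.ao.nn.quantized import functional as qF

    qx = torch.quantize_per_tensor(
        torch.randn(1, 3, 8, 8), scale=0.1, zero_point=128, dtype=torch.quint8
    )
    # The input scale/zero_point propagate to the output.
    return qF.avg_pool2d(qx, kernel_size=2, stride=2)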
def avg_pool3d(
input,
kernel_size,
stride=None,
padding=0,
ceil_mode=False,
count_include_pad=True,
divisor_override=None,
):
r"""
Applies 3D average-pooling operation in :math:`kD \times kH \times kW` regions by step size
:math:`sD \times sH \times sW` steps. The number of output features is equal to the number of
input planes.
.. note:: The input quantization parameters propagate to the output.
Args:
input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)`
kernel_size: size of the pooling region. Can be a single number or a
tuple `(kD, kH, kW)`
stride: stride of the pooling operation. Can be a single number or a
tuple `(sD, sH, sW)`. Default: :attr:`kernel_size`
padding: implicit zero paddings on both sides of the input. Can be a
single number or a tuple `(padD, padH, padW)`. Default: 0
ceil_mode: when True, will use `ceil` instead of `floor` in the formula
to compute the output shape. Default: ``False``
count_include_pad: when True, will include the zero-padding in the
averaging calculation. Default: ``True``
divisor_override: if specified, it will be used as divisor, otherwise
size of the pooling region will be used. Default: None
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.avg_pool3d' must be quantized!")
return torch.nn.functional.avg_pool3d(
input,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override,
)
def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
r"""
Applies a 2D adaptive average pooling over a quantized input signal composed
of several quantized input planes.
.. note:: The input quantization parameters propagate to the output.
See :class:`~torch.ao.nn.quantized.AdaptiveAvgPool2d` for details and output shape.
Args:
output_size: the target output size (single integer or
double-integer tuple)
"""
if not input.is_quantized:
raise ValueError(
"Input to 'quantized.functional.adaptive_avg_pool2d' must be quantized!"
)
return torch.nn.functional.adaptive_avg_pool2d(input, output_size)
def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
r"""
Applies a 3D adaptive average pooling over a quantized input signal composed
of several quantized input planes.
.. note:: The input quantization parameters propagate to the output.
See :class:`~torch.ao.nn.quantized.AdaptiveAvgPool3d` for details and output shape.
Args:
output_size: the target output size (single integer or
double-integer tuple)
"""
if not input.is_quantized:
raise ValueError(
"Input to 'quantized.functional.adaptive_avg_pool3d' must be quantized!"
)
return torch.nn.functional.adaptive_avg_pool3d(input, output_size)
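# Editor's note: illustrative sketch only (not part of the original file),
# showing the adaptive variant, which fixes the output spatial size instead of
# the kernel size.
def _example_adaptive_avg_pool2d():
    """Minimal sketch: adaptive 2D average pooling on a quantized tensor."""
    import torch
    from torch.ao.nn.quantized import functional as qF

    qx = torch.quantize_per_tensor(
        torch.randn(1, 3, 16, 16), scale=0.1, zero_point=128, dtype=torch.quint8
    )
    # Always produces a 4x4 spatial output; qparams propagate from the input.
    return qF.adaptive_avg_pool2d(qx, output_size=(4, 4))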
def conv1d(
input,
weight,
bias,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode="zeros",
scale=1.0,
zero_point=0,
dtype=torch.quint8,
):
r"""
Applies a 1D convolution over a quantized 1D input composed of several input
planes.
See :class:`~torch.ao.nn.quantized.Conv1d` for details and output shape.
Args:
input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , iW)`
bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
stride: the stride of the convolving kernel. Can be a single number or a
tuple `(sW,)`. Default: 1
padding: implicit paddings on both sides of the input. Can be a
single number or a tuple `(padW,)`. Default: 0
dilation: the spacing between kernel elements. Can be a single number or
a tuple `(dW,)`. Default: 1
groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
number of groups. Default: 1
padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros"
scale: quantization scale for the output. Default: 1.0
zero_point: quantization zero_point for the output. Default: 0
dtype: quantization data type to use. Default: ``torch.quint8``
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> from torch.ao.nn.quantized import functional as qF
>>> filters = torch.randn(33, 16, 3, dtype=torch.float)
>>> inputs = torch.randn(20, 16, 50, dtype=torch.float)
>>> bias = torch.randn(33, dtype=torch.float)
>>>
>>> scale, zero_point = 1.0, 0
>>> dtype_inputs = torch.quint8
>>> dtype_filters = torch.qint8
>>>
>>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
>>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
>>> qF.conv1d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
""" # noqa: E501
if padding_mode != "zeros":
raise NotImplementedError("Only zero-padding is supported!")
if input.dtype != torch.quint8:
raise NotImplementedError(
"Only torch.quint8 is supported for activation tensor!"
)
if weight.dtype != torch.qint8:
raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
if input.ndim != 3:
raise ValueError("Input shape must be `(N, C, L)`!")
stride = _pair_from_first(stride)
padding = _pair_from_first(padding)
dilation = _pair_from_first(dilation)
packed_params = torch.ops.quantized.conv1d_prepack(
weight, bias, stride, padding, dilation, groups
)
return torch.ops.quantized.conv1d(input, packed_params, scale, zero_point)
def conv2d(
input,
weight,
bias,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode="zeros",
scale=1.0,
zero_point=0,
dtype=torch.quint8,
):
r"""
Applies a 2D convolution over a quantized 2D input composed of several input
planes.
See :class:`~torch.ao.nn.quantized.Conv2d` for details and output shape.
Args:
input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)`
bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
stride: the stride of the convolving kernel. Can be a single number or a
tuple `(sH, sW)`. Default: 1
padding: implicit paddings on both sides of the input. Can be a
single number or a tuple `(padH, padW)`. Default: 0
dilation: the spacing between kernel elements. Can be a single number or
a tuple `(dH, dW)`. Default: 1
groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
number of groups. Default: 1
padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros"
scale: quantization scale for the output. Default: 1.0
zero_point: quantization zero_point for the output. Default: 0
dtype: quantization data type to use. Default: ``torch.quint8``
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> from torch.ao.nn.quantized import functional as qF
>>> filters = torch.randn(8, 4, 3, 3, dtype=torch.float)
>>> inputs = torch.randn(1, 4, 5, 5, dtype=torch.float)
>>> bias = torch.randn(8, dtype=torch.float)
>>>
>>> scale, zero_point = 1.0, 0
>>> dtype_inputs = torch.quint8
>>> dtype_filters = torch.qint8
>>>
>>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
>>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
>>> qF.conv2d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
""" # noqa: E501
if padding_mode != "zeros":
raise NotImplementedError("Only zero-padding is supported!")
if input.dtype != torch.quint8:
raise NotImplementedError(
"Only torch.quint8 is supported for activation tensor!"
)
if weight.dtype != torch.qint8:
raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
if input.ndim != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
packed_params = torch.ops.quantized.conv2d_prepack(
weight, bias, stride, padding, dilation, groups
)
return torch.ops.quantized.conv2d(input, packed_params, scale, zero_point)
def conv3d(
input,
weight,
bias,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode="zeros",
scale=1.0,
zero_point=0,
dtype=torch.quint8,
):
r"""
Applies a 3D convolution over a quantized 3D input composed of several input
planes.
See :class:`~torch.ao.nn.quantized.Conv3d` for details and output shape.
Args:
input: quantized input tensor of shape
:math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)`
weight: quantized filters of shape
:math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kD , kH , kW)`
bias: **non-quantized** bias tensor of shape
:math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
stride: the stride of the convolving kernel. Can be a single number or a
tuple `(sD, sH, sW)`. Default: 1
padding: implicit paddings on both sides of the input. Can be a
single number or a tuple `(padD, padH, padW)`. Default: 0
dilation: the spacing between kernel elements. Can be a single number or
a tuple `(dD, dH, dW)`. Default: 1
groups: split input into groups, :math:`\text{in\_channels}` should be
divisible by the number of groups. Default: 1
padding_mode: the padding mode to use. Only "zeros" is supported for
quantized convolution at the moment. Default: "zeros"
scale: quantization scale for the output. Default: 1.0
zero_point: quantization zero_point for the output. Default: 0
dtype: quantization data type to use. Default: ``torch.quint8``
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> from torch.ao.nn.quantized import functional as qF
>>> filters = torch.randn(8, 4, 3, 3, 3, dtype=torch.float)
>>> inputs = torch.randn(1, 4, 5, 5, 5, dtype=torch.float)
>>> bias = torch.randn(8, dtype=torch.float)
>>>
>>> scale, zero_point = 1.0, 0
>>> dtype_inputs = torch.quint8
>>> dtype_filters = torch.qint8
>>>
>>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
>>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
>>> qF.conv3d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
""" # noqa: E501
if padding_mode != "zeros":
raise NotImplementedError("Only zero-padding is supported!")
if input.dtype != torch.quint8:
raise NotImplementedError(
"Only torch.quint8 is supported for activation tensor!"
)
if weight.dtype != torch.qint8:
raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
if input.ndim != 5:
raise ValueError("Input shape must be `(N, C, D, H, W)`!")
stride = _triple(stride)
padding = _triple(padding)
dilation = _triple(dilation)
packed_params = torch.ops.quantized.conv3d_prepack(
weight, bias, stride, padding, dilation, groups
)
return torch.ops.quantized.conv3d(input, packed_params, scale, zero_point)
def interpolate(
input, size=None, scale_factor=None, mode="nearest", align_corners=None
):
r"""Down/up samples the input to either the given :attr:`size` or the given
:attr:`scale_factor`
See :func:`torch.nn.functional.interpolate` for implementation details.
The input dimensions are interpreted in the form:
`mini-batch x channels x [optional depth] x [optional height] x width`.
.. note:: The input quantization parameters propagate to the output.
.. note:: Only 2D/3D input is supported for quantized inputs
.. note:: Only the following modes are supported for the quantized inputs:
- `bilinear`
- `nearest`
Args:
input (Tensor): the input tensor
size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
output spatial size.
scale_factor (float or Tuple[float]): multiplier for spatial size. If it is a tuple, its length has to match the number of spatial dimensions.
mode (str): algorithm used for upsampling:
``'nearest'`` | ``'bilinear'``
align_corners (bool, optional): Geometrically, we consider the pixels of the
input and output as squares rather than points.
If set to ``True``, the input and output tensors are aligned by the
center points of their corner pixels, preserving the values at the corner pixels.
If set to ``False``, the input and output tensors are aligned by the corner
points of their corner pixels, and the interpolation uses edge value padding
for out-of-boundary values, making this operation *independent* of input size
when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
is ``'bilinear'``.
Default: ``False``
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.interpolate' must be quantized!")
return torch.nn.functional.interpolate(
input, size, scale_factor, mode, align_corners
)
def linear(
input: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
scale: Optional[float] = None,
zero_point: Optional[int] = None,
) -> Tensor:
r"""
Applies a linear transformation to the incoming quantized data:
:math:`y = xA^T + b`.
See :class:`~torch.ao.nn.quantized.Linear`
.. note::
Current implementation packs weights on every call, which has penalty on performance.
If you want to avoid the overhead, use :class:`~torch.ao.nn.quantized.Linear`.
Args:
input (Tensor): Quantized input of type `torch.quint8`
weight (Tensor): Quantized weight of type `torch.qint8`
bias (Tensor): None or fp32 bias of type `torch.float`
scale (double): output scale. If None, derived from the input scale
zero_point (long): output zero point. If None, derived from the input zero_point
Shape:
- Input: :math:`(N, *, in\_features)` where `*` means any number of
additional dimensions
- Weight: :math:`(out\_features, in\_features)`
- Bias: :math:`(out\_features)`
- Output: :math:`(N, *, out\_features)`
"""
if scale is None:
scale = input.q_scale()
if zero_point is None:
zero_point = input.q_zero_point()
_packed_params = torch.ops.quantized.linear_prepack(weight, bias)
return torch.ops.quantized.linear(input, _packed_params, scale, zero_point)
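# Editor's note: illustrative sketch only (not part of the original file). The
# scale/zero_point values are arbitrary, and a quantized engine (fbgemm or
# qnnpack) is assumed to be available.
def _example_quantized_linear():
    """Minimal sketch: a one-off quantized linear call (weights repacked per call)."""
    import torch
    from torch.ao.nn.quantized import functional as qF

    x = torch.quantize_per_tensor(torch.randn(3, 4), 0.1, 64, torch.quint8)
    w = torch.quantize_per_tensor(torch.randn(5, 4), 0.05, 0, torch.qint8)
    b = torch.randn(5)  # bias stays in float
    # scale/zero_point set the quantization parameters of the output.
    return qF.linear(x, w, b, scale=0.2, zero_point=64)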
def max_pool1d(
input,
kernel_size,
stride=None,
padding=0,
dilation=1,
ceil_mode=False,
return_indices=False,
):
r"""Applies a 1D max pooling over a quantized input signal composed of
several quantized input planes.
.. note:: The input quantization parameters are propagated to the output.
See :class:`~torch.ao.nn.quantized.MaxPool1d` for details.
"""
if return_indices:
raise NotImplementedError("return_indices is not yet implemented!")
if stride is None:
stride = torch.jit.annotate(List[int], [])
return torch.nn.functional.max_pool1d(
input,
kernel_size,
stride,
padding,
dilation,
ceil_mode=ceil_mode,
return_indices=return_indices,
)
def max_pool2d(
input,
kernel_size,
stride=None,
padding=0,
dilation=1,
ceil_mode=False,
return_indices=False,
):
r"""Applies a 2D max pooling over a quantized input signal composed of
several quantized input planes.
.. note:: The input quantization parameters are propagated to the output.
See :class:`~torch.ao.nn.quantized.MaxPool2d` for details.
"""
if return_indices:
raise NotImplementedError("return_indices is not yet implemented!")
if stride is None:
stride = torch.jit.annotate(List[int], [])
return torch.nn.functional.max_pool2d(
input,
kernel_size,
stride,
padding,
dilation,
ceil_mode=ceil_mode,
return_indices=return_indices,
)
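# Editor's note: illustrative sketch only (not part of the original file).
def _example_max_pool2d():
    """Minimal sketch: 2D max pooling over a quantized input."""
    import torch
    from torch.ao.nn.quantized import functional as qF

    qx = torch.quantize_per_tensor(
        torch.randn(1, 3, 8, 8), scale=0.1, zero_point=128, dtype=torch.quint8
    )
    # return_indices is not supported for quantized inputs, so it stays False.
    return qF.max_pool2d(qx, kernel_size=2, stride=2)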
def celu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.0) -> Tensor:
r"""celu(input, scale, zero_point, alpha=1.) -> Tensor
Applies the quantized CELU function element-wise.
.. math::
\text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x / \alpha) - 1))
Args:
input: quantized input
alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.celu' must be quantized!")
return torch.ops.quantized.celu(input, scale, zero_point, alpha)
def leaky_relu(
input: Tensor,
negative_slope: float = 0.01,
inplace: bool = False,
scale: Optional[float] = None,
zero_point: Optional[int] = None,
):
r"""
Quantized version of :func:`~torch.nn.functional.leaky_relu`.
leaky_relu(input, negative_slope=0.01, inplace=False, scale=None, zero_point=None) -> Tensor
Applies element-wise,
:math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)`
Args:
input: Quantized input
negative_slope: The slope of the negative input
inplace: Inplace modification of the input tensor
scale, zero_point: Scale and zero point of the output tensor.
See :class:`~torch.nn.LeakyReLU` for more details.
"""
if scale is not None and zero_point is not None:
assert not inplace, "Cannot rescale with `inplace`"
output = torch._empty_affine_quantized(
input.shape, scale=scale, zero_point=int(zero_point), dtype=input.dtype
)
torch._C._nn.leaky_relu(input, negative_slope, out=output)
return output
if inplace:
result = torch._C._nn.leaky_relu_(input, negative_slope)
else:
result = torch._C._nn.leaky_relu(input, negative_slope)
return result
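# Editor's note: illustrative sketch only (not part of the original file),
# showing the requantizing branch of leaky_relu (scale/zero_point supplied).
def _example_leaky_relu():
    """Minimal sketch: quantized leaky_relu with explicit output qparams."""
    import torch
    from torch.ao.nn.quantized import functional as qF

    qx = torch.quantize_per_tensor(
        torch.randn(10), scale=0.1, zero_point=128, dtype=torch.quint8
    )
    # Supplying scale/zero_point writes into a freshly allocated quantized
    # tensor, so inplace must remain False.
    return qF.leaky_relu(qx, negative_slope=0.2, scale=0.1, zero_point=128)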
def hardtanh(
input: Tensor, min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False
) -> Tensor:
r"""This is the quantized version of :func:`~torch.nn.functional.hardtanh`."""
if not input.is_quantized:
raise ValueError("Input to 'quantized.hardtanh' must be quantized!")
if inplace:
return torch._C._nn.hardtanh_(input, min_val, max_val)
return torch._C._nn.hardtanh(input, min_val, max_val)
def hardswish(input: Tensor, scale: float, zero_point: int) -> Tensor:
r"""This is the quantized version of :func:`~torch.nn.functional.hardswish`.
Args:
input: quantized input
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.hardswish' must be quantized!")
return torch._ops.ops.quantized.hardswish(input, scale, zero_point)
def threshold(input: Tensor, threshold: float, value: float) -> Tensor:
r"""Applies the quantized version of the threshold function element-wise:
.. math::
x = \begin{cases}
x & \text{if~} x > \text{threshold} \\
\text{value} & \text{otherwise}
\end{cases}
See :class:`~torch.nn.Threshold` for more details.
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.threshold' must be quantized!")
if threshold is None:
raise ValueError("Input to 'threshold' must be specified!")
if value is None:
raise ValueError("Input to 'value' must be specified!")
return torch._ops.ops.quantized.threshold(input, threshold, value)
def elu(input: Tensor, scale: float, zero_point: int, alpha: float = 1.0) -> Tensor:
r"""This is the quantized version of :func:`~torch.nn.functional.elu`.
Args:
input: quantized input
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
alpha: the alpha constant
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.elu' must be quantized!")
return torch.ops.quantized.elu(input, scale, zero_point, alpha)
def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor:
r"""This is the quantized version of :func:`~torch.nn.functional.hardsigmoid`."""
if not input.is_quantized:
raise ValueError("Input to 'quantized.hardsigmoid' must be quantized!")
if inplace:
return torch._C._nn.hardsigmoid_(input) # type: ignore[attr-defined]
return torch._C._nn.hardsigmoid(input)
def clamp(input: Tensor, min_: float, max_: float) -> Tensor:
r"""float(input, min\_, max\_) -> Tensor
Applies the clamp function element-wise.
See :class:`~torch.ao.nn.quantized.clamp` for more details.
Args:
input: quantized input
min_: minimum value for clamping
max_: maximum value for clamping
"""
if not input.is_quantized:
raise ValueError("Input to 'quantized.clamp' must be quantized!")
return torch.clamp(input, min_, max_)
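# Editor's note: illustrative sketch only (not part of the original file).
def _example_clamp():
    """Minimal sketch: element-wise clamp of a quantized tensor."""
    import torch
    from torch.ao.nn.quantized import functional as qF

    qx = torch.quantize_per_tensor(
        torch.randn(10), scale=0.1, zero_point=128, dtype=torch.quint8
    )
    # Clamping happens in the dequantized domain; the qparams are unchanged.
    return qF.clamp(qx, -0.5, 0.5)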
def upsample(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
r"""Upsamples the input to either the given :attr:`size` or the given
:attr:`scale_factor`
.. warning::
This function is deprecated in favor of
:func:`torch.ao.nn.quantized.functional.interpolate`.
This is equivalent to ``nn.quantized.functional.interpolate(...)``.
See :func:`torch.nn.functional.interpolate` for implementation details.
The input dimensions are interpreted in the form:
`mini-batch x channels x [optional depth] x [optional height] x width`.
.. note:: The input quantization parameters propagate to the output.
.. note:: Only 2D input is supported for quantized inputs
.. note:: Only the following modes are supported for the quantized inputs:
- `bilinear`
- `nearest`
Args:
input (Tensor): quantized input tensor
size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
output spatial size.
scale_factor (float or Tuple[float]): multiplier for spatial size. Has to be an integer.
mode (str): algorithm used for upsampling:
``'nearest'`` | ``'bilinear'``
align_corners (bool, optional): Geometrically, we consider the pixels of the
input and output as squares rather than points.
If set to ``True``, the input and output tensors are aligned by the
center points of their corner pixels, preserving the values at the corner pixels.
If set to ``False``, the input and output tensors are aligned by the corner
points of their corner pixels, and the interpolation uses edge value padding
for out-of-boundary values, making this operation *independent* of input size
when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
is ``'bilinear'``.
Default: ``False``
.. warning::
With ``align_corners = True``, the linearly interpolating modes
(`bilinear`) don't proportionally align the
output and input pixels, and thus the output values can depend on the
input size. This was the default behavior for these modes up to version
0.3.1. Since then, the default behavior is ``align_corners = False``.
See :class:`~torch.nn.Upsample` for concrete examples on how this
affects the outputs.
"""
warnings.warn(
"nn.quantized.functional.upsample is deprecated. Use nn.quantized.functional.interpolate instead."
)
return interpolate(input, size, scale_factor, mode, align_corners)
def upsample_bilinear(input, size=None, scale_factor=None):
r"""Upsamples the input, using bilinear upsampling.
.. warning::
This function is deprecated in favor of
:func:`torch.ao.nn.quantized.functional.interpolate`.
This is equivalent to
``nn.quantized.functional.interpolate(..., mode='bilinear', align_corners=True)``.
.. note:: The input quantization parameters propagate to the output.
.. note:: Only 2D inputs are supported
Args:
input (Tensor): quantized input
size (int or Tuple[int, int]): output spatial size.
scale_factor (int or Tuple[int, int]): multiplier for spatial size
"""
# DeprecationWarning is ignored by default
warnings.warn(
"nn.quantized.functional.upsample_bilinear is deprecated. Use nn.quantized.functional.interpolate instead."
)
return interpolate(input, size, scale_factor, mode="bilinear", align_corners=True)
def upsample_nearest(input, size=None, scale_factor=None):
r"""Upsamples the input, using nearest neighbours' pixel values.
.. warning::
This function is deprecated in favor of
:func:`torch.ao.nn.quantized.functional.interpolate`.
This is equivalent to ``nn.quantized.functional.interpolate(..., mode='nearest')``.
.. note:: The input quantization parameters propagate to the output.
.. note:: Only 2D inputs are supported
Args:
input (Tensor): quantized input
size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial
size.
scale_factor (int): multiplier for spatial size. Has to be an integer.
"""
# DeprecationWarning is ignored by default
warnings.warn(
"nn.quantized.functional.upsample_nearest is deprecated. Use nn.quantized.functional.interpolate instead."
)
return interpolate(input, size, scale_factor, mode="nearest")
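# Editor's note: illustrative sketch only (not part of the original file). It
# uses interpolate directly, which the deprecated upsample* wrappers above
# forward to.
def _example_interpolate():
    """Minimal sketch: nearest-neighbour upsampling of a quantized image tensor."""
    import torch
    from torch.ao.nn.quantized import functional as qF

    qx = torch.quantize_per_tensor(
        torch.randn(1, 3, 8, 8), scale=0.1, zero_point=128, dtype=torch.quint8
    )
    # Only 'nearest' and 'bilinear' modes are supported for quantized inputs.
    return qF.interpolate(qx, scale_factor=2, mode="nearest")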

@ -0,0 +1,162 @@
# mypy: allow-untyped-defs
import torch
# The quantized modules use `torch.nn` and `torch.ao.nn.quantizable`
# packages. However, the `quantizable` package uses "lazy imports"
# to avoid circular dependency.
# Hence we need to include it here to make sure it is resolved before
# they are used in the modules.
import torch.ao.nn.quantizable
from torch.nn.modules.pooling import MaxPool2d
from .activation import (
ELU,
Hardswish,
LeakyReLU,
MultiheadAttention,
PReLU,
ReLU6,
Sigmoid,
Softmax,
)
from .batchnorm import BatchNorm2d, BatchNorm3d
from .conv import (
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
)
from .dropout import Dropout
from .embedding_ops import Embedding, EmbeddingBag
from .functional_modules import FloatFunctional, FXFloatFunctional, QFunctional
from .linear import Linear
from .normalization import (
GroupNorm,
InstanceNorm1d,
InstanceNorm2d,
InstanceNorm3d,
LayerNorm,
)
from .rnn import LSTM
__all__ = [
"BatchNorm2d",
"BatchNorm3d",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
"DeQuantize",
"ELU",
"Embedding",
"EmbeddingBag",
"GroupNorm",
"Hardswish",
"InstanceNorm1d",
"InstanceNorm2d",
"InstanceNorm3d",
"LayerNorm",
"LeakyReLU",
"Linear",
"LSTM",
"MultiheadAttention",
"Quantize",
"ReLU6",
"Sigmoid",
"Softmax",
"Dropout",
"PReLU",
# Wrapper modules
"FloatFunctional",
"FXFloatFunctional",
"QFunctional",
]
class Quantize(torch.nn.Module):
r"""Quantizes an incoming tensor
Args:
`scale`: scale of the output Quantized Tensor
`zero_point`: zero_point of output Quantized Tensor
`dtype`: data type of output Quantized Tensor
`factory_kwargs`: Dictionary of kwargs used for configuring initialization
of internal buffers. Currently, `device` and `dtype` are supported.
Example: `factory_kwargs={'device': 'cuda', 'dtype': torch.float64}`
will initialize internal buffers as type `torch.float64` on the current CUDA device.
Note that `dtype` only applies to floating-point buffers.
Examples::
>>> t = torch.tensor([[1., -1.], [1., -1.]])
>>> scale, zero_point, dtype = 1.0, 2, torch.qint8
>>> qm = Quantize(scale, zero_point, dtype)
>>> # xdoctest: +SKIP
>>> qt = qm(t)
>>> print(qt)
tensor([[ 1., -1.],
[ 1., -1.]], size=(2, 2), dtype=torch.qint8, scale=1.0, zero_point=2)
"""
scale: torch.Tensor
zero_point: torch.Tensor
def __init__(self, scale, zero_point, dtype, factory_kwargs=None):
factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
super().__init__()
self.register_buffer("scale", torch.tensor([scale], **factory_kwargs))
self.register_buffer(
"zero_point",
torch.tensor(
[zero_point],
dtype=torch.long,
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
),
)
self.dtype = dtype
def forward(self, X):
return torch.quantize_per_tensor(
X, float(self.scale), int(self.zero_point), self.dtype
)
@staticmethod
def from_float(mod, use_precomputed_fake_quant=False):
assert hasattr(mod, "activation_post_process")
scale, zero_point = mod.activation_post_process.calculate_qparams()
return Quantize(
scale.float().item(),
zero_point.long().item(),
mod.activation_post_process.dtype,
)
def extra_repr(self):
return f"scale={self.scale}, zero_point={self.zero_point}, dtype={self.dtype}"
class DeQuantize(torch.nn.Module):
r"""Dequantizes an incoming tensor
Examples::
>>> input = torch.tensor([[1., -1.], [1., -1.]])
>>> scale, zero_point, dtype = 1.0, 2, torch.qint8
>>> qm = Quantize(scale, zero_point, dtype)
>>> # xdoctest: +SKIP
>>> quantized_input = qm(input)
>>> dqm = DeQuantize()
>>> dequantized = dqm(quantized_input)
>>> print(dequantized)
tensor([[ 1., -1.],
[ 1., -1.]], dtype=torch.float32)
"""
def forward(self, Xq):
return Xq.dequantize()
@staticmethod
def from_float(mod, use_precomputed_fake_quant=False):
return DeQuantize()
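# Editor's note: illustrative sketch only (not part of the original file),
# showing the quantize -> dequantize sandwich these stub modules implement in
# eager-mode quantized models. The scale/zero_point values are arbitrary.
def _example_quant_dequant():
    """Minimal sketch: round-trip a float tensor through Quantize and DeQuantize."""
    import torch

    quant = Quantize(scale=0.1, zero_point=128, dtype=torch.quint8)
    dequant = DeQuantize()
    x = torch.randn(2, 2)
    qx = quant(x)  # float -> quantized (quint8)
    return dequant(qx)  # quantized -> float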

@ -0,0 +1,343 @@
# mypy: allow-untyped-defs
from warnings import warn
import torch
__all__ = [
"ReLU6",
"Hardswish",
"ELU",
"LeakyReLU",
"Sigmoid",
"Softmax",
"MultiheadAttention",
"PReLU",
]
class ReLU6(torch.nn.ReLU):
r"""Applies the element-wise function:
:math:`\text{ReLU6}(x) = \min(\max(x_0, x), q(6))`, where :math:`x_0` is the
zero_point, and :math:`q(6)` is the quantized representation of number 6.
Args:
inplace: can optionally do the operation in-place. Default: ``False``
Shape:
- Input: :math:`(N, *)` where `*` means, any number of additional
dimensions
- Output: :math:`(N, *)`, same shape as the input
.. image:: ../scripts/activation_images/ReLU6.png
Examples::
>>> m = nn.quantized.ReLU6()
>>> input = torch.randn(2)
>>> # xdoctest: +SKIP
>>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32)
>>> output = m(input)
"""
def __init__(self, inplace=False):
super().__init__(inplace)
self.inplace = inplace
def forward(self, input):
return torch.ops.quantized.relu6(input, self.inplace)
def _get_name(self):
return "QuantizedReLU6"
@staticmethod
def from_float(mod, use_precomputed_fake_quant=False):
return ReLU6(mod.inplace)
class Hardswish(torch.nn.Hardswish):
r"""This is the quantized version of :class:`~torch.nn.Hardswish`.
Args:
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
"""
def __init__(self, scale, zero_point, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
return torch.ops.quantized.hardswish(input, self.scale, self.zero_point)
def _get_name(self):
return "QuantizedHardswish"
@staticmethod
def from_float(mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return Hardswish(float(scale), int(zero_point))
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(float(scale), int(zero_point))
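# Editor's note: illustrative sketch only (not part of the original file). The
# scale/zero_point values are arbitrary output quantization parameters.
def _example_quantized_hardswish():
    """Minimal sketch: apply the quantized Hardswish module to a quantized tensor."""
    import torch

    m = Hardswish(scale=0.05, zero_point=64)
    qx = torch.quantize_per_tensor(
        torch.randn(4), scale=0.1, zero_point=128, dtype=torch.quint8
    )
    # The module's scale/zero_point determine the output quantization.
    return m(qx)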
class ELU(torch.nn.ELU):
r"""This is the quantized equivalent of :class:`~torch.nn.ELU`.
Args:
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
alpha: the alpha constant
"""
def __init__(self, scale, zero_point, alpha=1.0):
super().__init__(alpha)
self.scale = scale
self.zero_point = zero_point
def forward(self, input):
return torch.ao.nn.quantized.functional.elu(
input, self.scale, self.zero_point, self.alpha
)
def _get_name(self):
return "QuantizedELU"
@staticmethod
def from_float(mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return ELU(float(scale), int(zero_point), mod.alpha)
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(float(scale), int(zero_point), mod.alpha)
class LeakyReLU(torch.nn.LeakyReLU):
r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`.
Args:
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
negative_slope: Controls the angle of the negative slope. Default: 1e-2
"""
def __init__(
self,
scale: float,
zero_point: int,
negative_slope: float = 1e-2,
inplace: bool = False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(negative_slope, inplace)
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
return torch.ops.quantized.leaky_relu(
input, self.negative_slope, self.inplace, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedLeakyReLU"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
class Sigmoid(torch.nn.Sigmoid):
r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`.
Args:
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
"""
def __init__(self, output_scale: float, output_zero_point: int):
super().__init__()
self.output_scale = output_scale
self.output_zero_point = output_zero_point
def forward(self, input):
return torch.ops.quantized.sigmoid(
input, self.output_scale, self.output_zero_point
)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
(
output_scale,
output_zero_point,
) = mod.activation_post_process.calculate_qparams()
return cls(float(output_scale), int(output_zero_point))
class Softmax(torch.nn.Softmax):
r"""This is the quantized version of :class:`~torch.nn.Softmax`.
Args:
dim: A dimension along which Softmax will be computed (so every slice along dim will sum to 1).
scale: quantization scale of the output tensor
zero_point: quantization zero point of the output tensor
"""
def __init__(self, dim=None, scale=1.0, zero_point=0):
super().__init__()
self.dim = dim
self.scale = scale
self.zero_point = zero_point
def forward(self, input):
dim = self.dim
if dim is None:
stacklevel = 3
# Note: adding the mypy ignore on _get_softmax_dim seems less bad
# than making `_get_softmax_dim` an official API.
dim = torch.nn.functional._get_softmax_dim( # type: ignore[attr-defined]
"softmax", input.dim(), stacklevel
)
return torch.ops.quantized.softmax(input, dim, self.scale, self.zero_point)
def _get_name(self):
return "QuantizedSoftmax"
@staticmethod
def from_float(mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return Softmax(mod.dim, float(scale), int(zero_point))
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(mod.dim, float(scale), int(zero_point))
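# Editor's note: illustrative sketch only (not part of the original file); it
# assumes the quantized softmax op is available in the build.
def _example_quantized_softmax():
    """Minimal sketch: quantized Softmax over dim=1."""
    import torch

    # Softmax outputs lie in [0, 1], so scale=1/256 with zero_point=0 covers
    # the full quint8 range.
    m = Softmax(dim=1, scale=1.0 / 256, zero_point=0)
    qx = torch.quantize_per_tensor(
        torch.randn(2, 5), scale=0.1, zero_point=128, dtype=torch.quint8
    )
    return m(qx)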
class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
_FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention
def _get_name(self):
return "QuantizedMultiheadAttention"
@classmethod
def from_float(cls, other):
# The whole flow is float -> observed -> quantized
# This class does observed -> quantized only
raise NotImplementedError(
"It looks like you are trying to convert a "
"non-observed MHA module. Please, see "
"the examples on quantizable MHAs."
)
@classmethod
def from_observed(cls, other):
converted = torch.ao.quantization.convert(
other,
mapping=None,
inplace=False,
remove_qconfig=True,
convert_custom_config_dict=None,
)
converted.__class__ = cls
# Remove the parameters for the bias_k and bias_v to quantize them
# TODO: This is a potential source of accuracy drop.
# quantized cat takes the scale and zp of the first
# element, which might lose the precision in the bias_k
# and the bias_v (which are cat'ed with k/v being first).
if converted.bias_k is not None:
bias_k = converted._parameters.pop("bias_k")
sc, zp = torch._choose_qparams_per_tensor(bias_k, reduce_range=False)
bias_k = torch.quantize_per_tensor(bias_k, sc, zp, torch.quint8)
setattr(converted, "bias_k", bias_k) # noqa: B010
if converted.bias_v is not None:
bias_v = converted._parameters.pop("bias_v")
            sc, zp = torch._choose_qparams_per_tensor(bias_v, reduce_range=False)
bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8)
setattr(converted, "bias_v", bias_v) # noqa: B010
del converted.in_proj_weight
del converted.in_proj_bias
return converted
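# A minimal sketch of the full eager-mode flow for MultiheadAttention
# (float -> observed -> quantized), since from_float above intentionally raises.
# Assumptions, not taken from this file: a build with a quantized engine such as
# fbgemm, the eager-mode custom-module config keys, and a toy wrapper model with
# arbitrary shapes invented purely for illustration.
import torch
import torch.ao.nn.quantizable as nnqa
import torch.ao.quantization as tq
class _TinyMHA(torch.nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2)
        self.dequant = tq.DeQuantStub()
    def forward(self, q, k, v):
        q, k, v = self.quant(q), self.quant(k), self.quant(v)
        out, _ = self.mha(q, k, v)
        return self.dequant(out)
_custom_config = {
    "float_to_observed_custom_module_class": {
        torch.nn.MultiheadAttention: nnqa.MultiheadAttention,
    },
    "observed_to_quantized_custom_module_class": {
        nnqa.MultiheadAttention: MultiheadAttention,  # the quantized class defined above
    },
}
_m = _TinyMHA().eval()
_m.qconfig = tq.get_default_qconfig("fbgemm")
tq.prepare(_m, inplace=True, prepare_custom_config_dict=_custom_config)
_m(torch.randn(4, 1, 8), torch.randn(4, 1, 8), torch.randn(4, 1, 8))  # calibration pass in float
tq.convert(_m, inplace=True, convert_custom_config_dict=_custom_config)  # observed -> quantized MHA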
class PReLU(torch.nn.Module):
r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`.
Args:
        output_scale: quantization scale of the output tensor
        output_zero_point: quantization zero point of the output tensor
num_parameters: number of parameters: 1, or the number of channels at input. Default: 1
"""
def __init__(
self, output_scale: float, output_zero_point: int, num_parameters: int = 1
) -> None:
super().__init__()
self.num_parameters = num_parameters
self.scale = output_scale
self.zero_point = output_zero_point
w = torch.randn(num_parameters, dtype=torch.float)
qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8)
self.set_weight(qw)
def set_weight(self, w: torch.Tensor) -> None:
self.weight = w
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.ops.quantized.prelu(
input, self.weight, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedPReLU"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
float_wt = mod.weight.float()
observer = mod.qconfig.weight()
observer(float_wt)
if observer.dtype != torch.quint8:
warn(
f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
)
wt_scale, wt_zp = observer.calculate_qparams()
qweight = torch.quantize_per_tensor(
float_wt, float(wt_scale), int(wt_zp), torch.quint8
)
qprelu.set_weight(qweight)
return qprelu
@classmethod
def from_reference(cls, mod, scale, zero_point):
qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
float_wt = mod.weight.float()
observer = mod.qconfig.weight()
observer(float_wt)
if observer.dtype != torch.quint8:
warn(
f"PReLU's weight observer should have dtype quint8 but got {observer.dtype}"
)
wt_scale, wt_zp = observer.calculate_qparams()
qweight = torch.quantize_per_tensor(
float_wt, float(wt_scale), int(wt_zp), torch.quint8
)
qprelu.set_weight(qweight)
return qprelu
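# Quick usage note for the activation modules above: they consume tensors that are
# already quantized and carry their own output scale / zero point. A minimal sketch,
# assuming a build with a quantized engine such as fbgemm; the scale and zero-point
# values below are arbitrary.
import torch
_qx = torch.quantize_per_tensor(torch.randn(2, 3), scale=0.05, zero_point=128, dtype=torch.quint8)
_sig = Sigmoid(output_scale=1.0 / 256.0, output_zero_point=0)  # class defined above
_soft = Softmax(dim=1, scale=1.0 / 256.0, zero_point=0)        # class defined above
_qy = _soft(_sig(_qx))  # both return quantized tensors with the configured output qparams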

View File

@ -0,0 +1,128 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
__all__ = ["BatchNorm2d", "BatchNorm3d"]
class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
def __init__(
self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(num_features, eps, momentum, True, True, **factory_kwargs)
self.register_buffer("scale", torch.tensor(1.0, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(0, **factory_kwargs))
@staticmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
activation_post_process = mod.activation_post_process
if type(mod) == cls._NNI_BN_RELU_MODULE:
mod = mod[0]
scale, zero_point = activation_post_process.calculate_qparams()
new_mod = cls(mod.num_features, mod.eps)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
new_mod.running_mean = mod.running_mean
new_mod.running_var = mod.running_var
new_mod.scale = scale
new_mod.zero_point = zero_point
return new_mod
@classmethod
def from_reference(cls, bn, output_scale, output_zero_point):
qbn = cls(
bn.num_features,
bn.eps,
bn.momentum,
device=bn.weight.device,
dtype=bn.weight.dtype,
)
qbn.weight = bn.weight
qbn.bias = bn.bias
qbn.running_mean = bn.running_mean
qbn.running_var = bn.running_var
qbn.scale = output_scale
qbn.zero_point = output_zero_point
return qbn
class BatchNorm2d(_BatchNorm):
r"""This is the quantized version of :class:`~torch.nn.BatchNorm2d`."""
_NNI_BN_RELU_MODULE = nni.BNReLU2d
def __init__(
self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(num_features, eps, momentum, **factory_kwargs)
def _get_name(self):
return "QuantizedBatchNorm2d"
def _check_input_dim(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
def forward(self, input: torch.Tensor) -> torch.Tensor:
# disabling this since this is not symbolically traceable
# self._check_input_dim(input)
return torch.ops.quantized.batch_norm2d(
input,
self.weight,
self.bias,
self.running_mean,
self.running_var,
self.eps,
self.scale,
self.zero_point,
)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return _BatchNorm.from_float(
cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
class BatchNorm3d(_BatchNorm):
r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`."""
_NNI_BN_RELU_MODULE = nni.BNReLU3d
def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(num_features, eps, momentum, **factory_kwargs)
def _get_name(self):
return "QuantizedBatchNorm3d"
def _check_input_dim(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
def forward(self, input: torch.Tensor) -> torch.Tensor:
# disabling this since this is not symbolically traceable
# self._check_input_dim(input)
return torch.ops.quantized.batch_norm3d(
input,
self.weight,
self.bias,
self.running_mean,
self.running_var,
self.eps,
self.scale,
self.zero_point,
)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return _BatchNorm.from_float(
cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
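# Construction-only sketch of the quantized batch norm above, assuming a build with
# a quantized engine such as fbgemm; the input qparams are arbitrary and the output
# scale / zero_point buffers keep their defaults of 1.0 / 0.
import torch
_qbn = BatchNorm2d(3)  # class defined above
_qx = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), scale=0.05, zero_point=0, dtype=torch.quint8)
_qy = _qbn(_qx)        # dispatches to torch.ops.quantized.batch_norm2d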

File diff suppressed because it is too large

View File

@ -0,0 +1,30 @@
# mypy: allow-untyped-defs
import torch
__all__ = ["Dropout"]
class Dropout(torch.nn.Dropout):
r"""This is the quantized equivalent of :class:`~torch.nn.Dropout`.
    It is a placeholder that lets models which applied dropout to fp32 tensors
    keep working with quantized tensors in both train and eval mode.
Args:
p: probability of an element to be zeroed
inplace: can optionally do the operation in-place. Default: ``False``
"""
def forward(self, input):
return input
def _get_name(self):
return "QuantizedDropout"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return cls(mod.p, mod.inplace)
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(mod.p, mod.inplace)
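# Behavior sketch for the pass-through Dropout above: the module exists only so that
# converted models keep a module at the same position; forward returns its input
# unchanged. The qparams below are arbitrary.
import torch
_drop = Dropout(p=0.5)  # class defined above
_qx = torch.quantize_per_tensor(torch.randn(4), scale=0.1, zero_point=0, dtype=torch.quint8)
assert _drop(_qx) is _qx  # no-op in both train and eval mode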

View File

@ -0,0 +1,405 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import torch
import torch.nn as nn
from torch import Tensor # noqa: F401
from torch._jit_internal import List, Optional # noqa: F401
from .utils import _hide_packed_params_repr, _quantize_weight
__all__ = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"]
class EmbeddingPackedParams(torch.nn.Module):
_version = 1
def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
super().__init__()
self.dtype = dtype
if self.dtype in [torch.quint8, torch.quint4x2]:
scales = torch.ones(num_embeddings, dtype=torch.float)
zero_points = torch.zeros(num_embeddings, dtype=torch.float)
wq = torch._empty_per_channel_affine_quantized(
[num_embeddings, embedding_dim],
scales=scales,
zero_points=zero_points,
axis=0,
dtype=self.dtype,
)
self.set_weight(wq)
else:
raise NotImplementedError(
f"Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}"
)
@torch.jit.export
def set_weight(self, weight: torch.Tensor) -> None:
if self.dtype in [torch.quint8, torch.quint4x2]:
self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
else:
raise NotImplementedError(
"Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2."
)
@torch.jit.export
def _weight(self):
if self.dtype in [torch.quint8, torch.quint4x2]:
return torch.ops.quantized.embedding_bag_unpack(self._packed_weight)
else:
raise NotImplementedError(
"Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2."
)
def forward(self, x):
return x
# Version 1
# self
# |--- _packed_weight : Tensor representing weight of EmbeddingPackedParamsBase
# |--- dtype : torch.dtype
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + "dtype"] = self.dtype
destination[prefix + "_packed_weight"] = self._weight()
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
self.dtype = state_dict[prefix + "dtype"]
state_dict.pop(prefix + "dtype")
weight = state_dict[prefix + "_packed_weight"]
state_dict.pop(prefix + "_packed_weight")
self.set_weight(weight)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
False,
missing_keys,
unexpected_keys,
error_msgs,
)
def __repr__(self):
return self._weight().__repr__()
class Embedding(torch.nn.Module):
r"""
A quantized Embedding module with quantized packed weights as inputs.
We adopt the same interface as `torch.nn.Embedding`, please see
https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html for documentation.
Similar to :class:`~torch.nn.Embedding`, attributes will be randomly
initialized at module creation time and will be overwritten later
Attributes:
weight (Tensor): the non-learnable quantized weights of the module of
shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.
Examples::
>>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12)
>>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
>>> output = m(indices)
>>> print(output.size())
torch.Size([9, 12])
"""
_version = 1
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
_weight: Optional[Tensor] = None,
dtype=torch.quint8,
) -> None:
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.dtype = dtype
if _weight is None:
scales = torch.ones(num_embeddings, dtype=torch.float)
zero_points = torch.zeros(num_embeddings, dtype=torch.float)
qweight = torch._empty_per_channel_affine_quantized(
[num_embeddings, embedding_dim],
scales=scales,
zero_points=zero_points,
axis=0,
dtype=torch.quint8,
)
else:
assert list(_weight.shape) == [
num_embeddings,
embedding_dim,
], "Shape of weight does not match num_embeddings and embedding_dim"
qweight = _weight
self._packed_params = EmbeddingPackedParams(
num_embeddings, embedding_dim, dtype
)
self._packed_params.set_weight(qweight)
def forward(self, indices: Tensor) -> Tensor:
if self.dtype == torch.quint4x2:
return torch.ops.quantized.embedding_4bit(
self._packed_params._packed_weight, indices
)
else:
return torch.ops.quantized.embedding_byte(
self._packed_params._packed_weight, indices
)
def _get_name(self):
return "QuantizedEmbedding"
def __repr__(self):
return _hide_packed_params_repr(self, EmbeddingPackedParams)
def extra_repr(self):
extra_repr_str = (
f"num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, "
f"dtype={self._packed_params.dtype}, qscheme={self.weight().qscheme()}"
)
return extra_repr_str
def set_weight(self, w: torch.Tensor) -> None:
self._packed_params.set_weight(w)
def weight(self):
return self._packed_params._weight()
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized embedding module from a float module
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by user
"""
if hasattr(mod, "weight_fake_quant"):
assert type(mod) == torch.ao.nn.qat.Embedding, (
"nnq."
+ cls.__name__
+ ".from_float "
+ "with fake quant only works for "
+ torch.ao.nn.qat.Embedding.__name__
)
weight_observer = mod.weight_fake_quant
activation_post_process = mod.activation_post_process
else:
assert type(mod) == nn.Embedding, (
"nnq."
+ cls.__name__
+ ".from_float only works for "
+ nn.Embedding.__name__
)
assert hasattr(
mod, "qconfig"
), "Embedding input float module must have qconfig defined"
from torch.ao.quantization import float_qparams_weight_only_qconfig
if mod.qconfig is not None and mod.qconfig.weight is not None: # type: ignore[union-attr]
weight_observer = mod.qconfig.weight() # type: ignore[union-attr, operator]
else:
weight_observer = float_qparams_weight_only_qconfig.weight()
dtype = weight_observer.dtype
is_float_qparams_qconfig = (
weight_observer.qscheme == torch.per_channel_affine_float_qparams
)
assert (
is_float_qparams_qconfig
), "Embedding quantization is only supported with float_qparams_weight_only_qconfig."
assert (
dtype == torch.quint8 or dtype == torch.quint4x2
), f"The only supported dtype for nnq.Embedding is torch.quint8 and torch.quint4x2, got {dtype}"
# Run the observer to calculate qparams.
weight_observer(mod.weight)
qweight = _quantize_weight(mod.weight.float(), weight_observer)
# Create quantized Embedding module and pass in the quantized weight
qembedding = Embedding(mod.num_embeddings, mod.embedding_dim)
qembedding.set_weight(qweight)
return qembedding
@classmethod
def from_reference(cls, ref_embedding):
qembedding = cls(
ref_embedding.num_embeddings,
ref_embedding.embedding_dim,
ref_embedding.padding_idx,
ref_embedding.max_norm,
ref_embedding.norm_type,
ref_embedding.scale_grad_by_freq,
ref_embedding.sparse,
ref_embedding.get_quantized_weight(),
ref_embedding.weight_dtype,
)
return qembedding
class EmbeddingBag(Embedding):
r"""
A quantized EmbeddingBag module with quantized packed weights as inputs.
We adopt the same interface as `torch.nn.EmbeddingBag`, please see
https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html for documentation.
Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly
initialized at module creation time and will be overwritten later
Attributes:
weight (Tensor): the non-learnable quantized weights of the module of
shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.
Examples::
>>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum')
>>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
>>> offsets = torch.tensor([0, 19, 20, 28, 28, 32])
>>> output = m(indices, offsets)
>>> print(output.size())
torch.Size([5, 12])
"""
_version = 1
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
mode: str = "sum",
sparse: bool = False,
_weight: Optional[Tensor] = None,
include_last_offset: bool = False,
dtype=torch.quint8,
) -> None:
super().__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype)
self.mode = mode
self.pruned_weights = False
self.include_last_offset = include_last_offset
self.dtype = dtype
def forward(
self,
indices: Tensor,
offsets: Optional[Tensor] = None,
per_sample_weights: Optional[Tensor] = None,
compressed_indices_mapping: Optional[Tensor] = None,
) -> Tensor:
if self.dtype == torch.quint4x2:
return torch.ops.quantized.embedding_bag_4bit(
self._packed_params._packed_weight,
indices,
offsets,
False,
0,
self.pruned_weights,
per_sample_weights,
compressed_indices_mapping,
self.include_last_offset,
)
else:
return torch.ops.quantized.embedding_bag_byte(
self._packed_params._packed_weight,
indices,
offsets,
False,
0,
self.pruned_weights,
per_sample_weights,
compressed_indices_mapping,
self.include_last_offset,
)
def _get_name(self):
return "QuantizedEmbeddingBag"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized embedding_bag module from a float module
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by user
"""
if hasattr(mod, "weight_fake_quant"):
weight_observer = mod.weight_fake_quant
else:
assert type(mod) == nn.EmbeddingBag, (
"nnq."
+ cls.__name__
+ ".from_float only works for "
+ nn.EmbeddingBag.__name__
)
assert hasattr(
mod, "qconfig"
), "EmbeddingBag input float module must have qconfig defined"
from torch.ao.quantization.qconfig import float_qparams_weight_only_qconfig
if mod.qconfig is not None and mod.qconfig.weight is not None: # type: ignore[union-attr]
weight_observer = mod.qconfig.weight() # type: ignore[union-attr, operator]
else:
weight_observer = float_qparams_weight_only_qconfig.weight()
dtype = weight_observer.dtype
is_float_qparams_qconfig = (
weight_observer.qscheme == torch.per_channel_affine_float_qparams
)
assert (
is_float_qparams_qconfig
), "EmbeddingBag quantization is only supported with float_qparams_weight_only_qconfig."
assert (
dtype == torch.quint8 or dtype == torch.quint4x2
), f"The only supported dtype for nnq.EmbeddingBag is torch.quint8 and torch.quint4x2, got {dtype}"
# Run the observer to calculate qparams.
weight_observer(mod.weight)
qweight = _quantize_weight(mod.weight.float(), weight_observer)
# Create quantized EmbeddingBag module and pass in the quantized weight
qembedding_bag = EmbeddingBag(
mod.num_embeddings, mod.embedding_dim, dtype=dtype
)
qembedding_bag.set_weight(qweight)
return qembedding_bag
@classmethod
def from_reference(cls, ref_embedding_bag):
qembedding_bag = cls(
ref_embedding_bag.num_embeddings,
ref_embedding_bag.embedding_dim,
ref_embedding_bag.max_norm,
ref_embedding_bag.norm_type,
ref_embedding_bag.scale_grad_by_freq,
ref_embedding_bag.mode,
ref_embedding_bag.sparse,
ref_embedding_bag.get_quantized_weight(),
ref_embedding_bag.include_last_offset,
ref_embedding_bag.weight_dtype,
)
return qembedding_bag
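# Minimal sketch of the from_float path described above: Embedding / EmbeddingBag
# quantization requires the float-qparams, weight-only qconfig. The sizes and
# indices below are arbitrary.
import torch
import torch.nn as nn
from torch.ao.quantization import float_qparams_weight_only_qconfig
_float_emb = nn.Embedding(10, 12)
_float_emb.qconfig = float_qparams_weight_only_qconfig  # per-row float qparams, quint8 weight
_qemb = Embedding.from_float(_float_emb)                # class defined above
_out = _qemb(torch.tensor([1, 2, 4, 5]))                # float output of shape (4, 12)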

View File

@ -0,0 +1,299 @@
# mypy: allow-untyped-defs
from typing import List
import torch
from torch import Tensor
from torch._ops import ops
__all__ = ["FloatFunctional", "FXFloatFunctional", "QFunctional"]
class FloatFunctional(torch.nn.Module):
r"""State collector class for float operations.
The instance of this class can be used instead of the ``torch.`` prefix for
some operations. See example usage below.
.. note::
This class does not provide a ``forward`` hook. Instead, you must use
one of the underlying functions (e.g. ``add``).
Examples::
>>> f_add = FloatFunctional()
>>> a = torch.tensor(3.0)
>>> b = torch.tensor(4.0)
>>> f_add.add(a, b) # Equivalent to ``torch.add(a, b)``
Valid operation names:
- add
- cat
- mul
- add_relu
- add_scalar
- mul_scalar
"""
def __init__(self) -> None:
super().__init__()
self.activation_post_process = torch.nn.Identity()
def forward(self, x):
raise RuntimeError(
"FloatFunctional is not intended to use the "
+ "'forward'. Please use the underlying operation"
)
r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
def add(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.add(x, y)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``torch.add(Tensor, float)``"""
def add_scalar(self, x: Tensor, y: float) -> Tensor:
r = torch.add(x, y)
# Note: this operation is not observed because the observation is not
# needed for the quantized op.
return r
r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
def mul(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.mul(x, y)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
def mul_scalar(self, x: Tensor, y: float) -> Tensor:
r = torch.mul(x, y)
# Note: this operation is not observed because the observation is not
# needed for the quantized op.
return r
r"""Operation equivalent to ``torch.cat``"""
def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
r = torch.cat(x, dim=dim)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``relu(torch.add(x,y))``"""
def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.add(x, y)
r = torch.nn.functional.relu(r)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
def matmul(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.matmul(x, y)
r = self.activation_post_process(r)
return r
class FXFloatFunctional(torch.nn.Module):
r"""module to replace FloatFunctional module before FX graph mode quantization,
since activation_post_process will be inserted in top level module directly
Valid operation names:
- add
- cat
- mul
- add_relu
- add_scalar
- mul_scalar
"""
def forward(self, x):
raise RuntimeError(
"FloatFunctional is not intended to use the "
+ "'forward'. Please use the underlying operation"
)
r"""Operation equivalent to ``torch.add(Tensor, Tensor)``"""
def add(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.add(x, y)
return r
r"""Operation equivalent to ``torch.add(Tensor, float)``"""
def add_scalar(self, x: Tensor, y: float) -> Tensor:
r = torch.add(x, y)
return r
r"""Operation equivalent to ``torch.mul(Tensor, Tensor)``"""
def mul(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.mul(x, y)
return r
r"""Operation equivalent to ``torch.mul(Tensor, float)``"""
def mul_scalar(self, x: Tensor, y: float) -> Tensor:
r = torch.mul(x, y)
return r
r"""Operation equivalent to ``torch.cat``"""
def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
r = torch.cat(x, dim=dim)
return r
r"""Operation equivalent to ``relu(torch.add(x,y))``"""
def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.add(x, y)
r = torch.nn.functional.relu(r)
return r
r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
def matmul(self, x: Tensor, y: Tensor) -> Tensor:
r = torch.matmul(x, y)
return r
class QFunctional(torch.nn.Module):
r"""Wrapper class for quantized operations.
The instance of this class can be used instead of the
``torch.ops.quantized`` prefix. See example usage below.
.. note::
This class does not provide a ``forward`` hook. Instead, you must use
one of the underlying functions (e.g. ``add``).
Examples::
>>> q_add = QFunctional()
>>> # xdoctest: +SKIP
>>> a = torch.quantize_per_tensor(torch.tensor(3.0), 1.0, 0, torch.qint32)
>>> b = torch.quantize_per_tensor(torch.tensor(4.0), 1.0, 0, torch.qint32)
>>> q_add.add(a, b) # Equivalent to ``torch.ops.quantized.add(a, b, 1.0, 0)``
Valid operation names:
- add
- cat
- mul
- add_relu
- add_scalar
- mul_scalar
"""
def __init__(self) -> None:
super().__init__()
self.scale = 1.0
self.zero_point = 0
self.activation_post_process = torch.nn.Identity()
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + "scale"] = torch.tensor(self.scale)
destination[prefix + "zero_point"] = torch.tensor(self.zero_point)
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
self.scale = float(state_dict.pop(prefix + "scale"))
self.zero_point = int(state_dict.pop(prefix + "zero_point"))
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
False,
missing_keys,
unexpected_keys,
error_msgs,
)
def _get_name(self):
return "QFunctional"
def extra_repr(self):
return f"scale={self.scale}, zero_point={self.zero_point}"
def forward(self, x):
raise RuntimeError(
"Functional is not intended to use the "
+ "'forward'. Please use the underlying operation"
)
r"""Operation equivalent to ``torch.ops.quantized.add``"""
def add(self, x: Tensor, y: Tensor) -> Tensor:
r = ops.quantized.add(x, y, scale=self.scale, zero_point=self.zero_point)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``torch.ops.quantized.add(Tensor, float)``"""
def add_scalar(self, x: Tensor, y: float) -> Tensor:
r = ops.quantized.add_scalar(x, y)
# Note: this operation is not observed because the observation is not
# needed for the quantized op.
return r
r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, Tensor)``"""
def mul(self, x: Tensor, y: Tensor) -> Tensor:
r = ops.quantized.mul(x, y, scale=self.scale, zero_point=self.zero_point)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``torch.ops.quantized.mul(Tensor, float)``"""
def mul_scalar(self, x: Tensor, y: float) -> Tensor:
r = ops.quantized.mul_scalar(x, y)
# Note: this operation is not observed because the observation is not
# needed for the quantized op.
return r
r"""Operation equivalent to ``torch.ops.quantized.cat``"""
def cat(self, x: List[Tensor], dim: int = 0) -> Tensor:
r = ops.quantized.cat(x, scale=self.scale, zero_point=self.zero_point, dim=dim)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``torch.ops.quantized.add_relu``"""
def add_relu(self, x: Tensor, y: Tensor) -> Tensor:
r = ops.quantized.add_relu(x, y, scale=self.scale, zero_point=self.zero_point)
r = self.activation_post_process(r)
return r
r"""Operation equivalent to ``torch.ops.quantized.matmul(Tensor, Tensor)``"""
def matmul(self, x: Tensor, y: Tensor) -> Tensor:
r = ops.quantized.matmul(x, y, scale=self.scale, zero_point=self.zero_point)
# Note: this operation is not observed because the observation is not
# needed for the quantized op.
return r
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
assert (
type(mod) == FloatFunctional
), "QFunctional.from_float expects an instance of FloatFunctional"
scale, zero_point = mod.activation_post_process.calculate_qparams() # type: ignore[operator]
new_mod = QFunctional()
new_mod.scale = float(scale)
new_mod.zero_point = int(zero_point)
return new_mod
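# To make the relationship between the wrappers above concrete: FloatFunctional goes
# into the float model so eager-mode quantization can observe the op's output, and
# convert() turns it into QFunctional. A direct-use sketch, assuming a build with a
# quantized engine such as fbgemm; the scale / zero_point values are arbitrary
# (normally they are filled in by convert()).
import torch
_qf = QFunctional()  # class defined above
_qf.scale, _qf.zero_point = 0.1, 0
_a = torch.quantize_per_tensor(torch.randn(4), 0.1, 0, torch.quint8)
_b = torch.quantize_per_tensor(torch.randn(4), 0.1, 0, torch.quint8)
_out = _qf.add(_a, _b)  # torch.ops.quantized.add with _qf's output qparams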

View File

@ -0,0 +1,358 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from collections.abc import Iterable
from typing import Optional
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
import torch.nn as nn
from torch.nn.utils.fusion import fuse_linear_bn_weights
from torch.nn.utils.parametrize import type_before_parametrizations
from .utils import _hide_packed_params_repr, _quantize_weight, WeightedQuantizedModule
__all__ = ["LinearPackedParams", "Linear"]
class LinearPackedParams(torch.nn.Module):
_version = 3
def __init__(self, dtype=torch.qint8):
super().__init__()
self.dtype = dtype
if self.dtype == torch.qint8:
wq = torch._empty_affine_quantized(
[1, 1], scale=1.0, zero_point=0, dtype=torch.qint8
)
elif self.dtype == torch.float16:
wq = torch.zeros([1, 1], dtype=torch.float)
self.set_weight_bias(wq, None) # type: ignore[possibly-undefined]
@torch.jit.export
def set_weight_bias(
self, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> None:
if self.dtype == torch.qint8:
self._packed_params = torch.ops.quantized.linear_prepack(weight, bias)
elif self.dtype == torch.float16:
self._packed_params = torch.ops.quantized.linear_prepack_fp16(weight, bias)
else:
raise RuntimeError("Unsupported dtype on dynamic quantized linear!")
@torch.jit.export
def _weight_bias(self):
if self.dtype == torch.qint8:
return torch.ops.quantized.linear_unpack(self._packed_params)
elif self.dtype == torch.float16:
return torch.ops.quantized.linear_unpack_fp16(self._packed_params)
else:
raise RuntimeError("Unsupported dtype on dynamic quantized linear!")
def forward(self, x):
return x
# Version 1
# self
# |--- weight : Tensor
# |--- bias : Tensor
#
# Version 2
# self
# |--- weight : Tensor
# |--- bias : Tensor
# |--- dtype : torch.dtype
#
# Version 3
# self
# |--- _packed_params : (Tensor, Tensor) representing (weight, bias)
# of LinearPackedParams
# |--- dtype : torch.dtype
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + "dtype"] = self.dtype
destination[prefix + "_packed_params"] = self._weight_bias()
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)
if version is None or version < 2:
self.dtype = torch.qint8
else:
self.dtype = state_dict[prefix + "dtype"]
state_dict.pop(prefix + "dtype")
if version is None or version < 3:
self.set_weight_bias(
state_dict[prefix + "weight"], state_dict[prefix + "bias"]
)
state_dict.pop(prefix + "weight")
state_dict.pop(prefix + "bias")
if version == 3:
weight, bias = state_dict[prefix + "_packed_params"]
state_dict.pop(prefix + "_packed_params")
self.set_weight_bias(weight, bias)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
False,
missing_keys,
unexpected_keys,
error_msgs,
)
def __repr__(self):
return self._weight_bias().__repr__()
class Linear(WeightedQuantizedModule):
r"""
A quantized linear module with quantized tensor as inputs and outputs.
We adopt the same interface as `torch.nn.Linear`, please see
https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.
Similar to :class:`~torch.nn.Linear`, attributes will be randomly
initialized at module creation time and will be overwritten later
Attributes:
weight (Tensor): the non-learnable quantized weights of the module of
shape :math:`(\text{out\_features}, \text{in\_features})`.
bias (Tensor): the non-learnable bias of the module of shape :math:`(\text{out\_features})`.
If :attr:`bias` is ``True``, the values are initialized to zero.
scale: `scale` parameter of output Quantized Tensor, type: double
zero_point: `zero_point` parameter for output Quantized Tensor, type: long
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
>>> m = nn.quantized.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> # xdoctest: +SKIP
>>> input = torch.quantize_per_tensor(input, 1.0, 0, torch.quint8)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_version = 3
_FLOAT_MODULE = (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear)
def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
super().__init__()
# We don't muck around with buffers or attributes or anything here
# to keep the module simple. *everything* is simply a Python attribute.
# Serialization logic is explicitly handled in the below serialization and
# deserialization modules
self.in_features = in_features
self.out_features = out_features
bias = None
if bias_:
bias = torch.zeros(out_features, dtype=torch.float)
if dtype == torch.qint8:
qweight = torch._empty_affine_quantized(
[out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8
)
elif dtype == torch.float16:
qweight = torch.zeros([out_features, in_features], dtype=torch.float)
else:
raise RuntimeError("Unsupported dtype specified for quantized Linear!")
self._packed_params = LinearPackedParams(dtype)
self._packed_params.set_weight_bias(qweight, bias)
self.scale = 1.0
self.zero_point = 0
def _get_name(self):
return "QuantizedLinear"
def extra_repr(self):
return (
f"in_features={self.in_features}, out_features={self.out_features}, scale={self.scale}, "
f"zero_point={self.zero_point}, qscheme={self.weight().qscheme()}"
)
def __repr__(self):
return _hide_packed_params_repr(self, LinearPackedParams)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.quantized.linear(
x, self._packed_params._packed_params, self.scale, self.zero_point
)
# ===== Serialization methods =====
# The special consideration here is that we have to unpack the weights into their
# regular QTensor form for serialization. Packed weights should not live
# outside the process in which they were created, rather they should be derived
# from the QTensor weight.
#
# Version 1
# self
# |--- scale : float
# |--- zero_point : int
# |--- weight : Tensor
# |--- bias : Tensor
#
# Version 2
# self
# |--- scale : float
# |--- zero_point : int
# |--- _packed_params : Module
# |--- weight : Tensor
# |--- bias : Tensor
#
# Version 3
# self
# |--- scale : float
# |--- zero_point : int
# |--- _packed_params : Module
# |--- _packed_params : (Tensor, Tensor) representing weight, bias
# of LinearPackedParams C++ struct
#
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + "scale"] = torch.tensor(self.scale)
destination[prefix + "zero_point"] = torch.tensor(self.zero_point)
# ===== Deserialization methods =====
# Counterpart to the serialization methods, we must pack the serialized QTensor
# weight into its packed format for use by the FBGEMM ops.
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
self.scale = float(state_dict[prefix + "scale"])
state_dict.pop(prefix + "scale")
self.zero_point = int(state_dict[prefix + "zero_point"])
state_dict.pop(prefix + "zero_point")
version = local_metadata.get("version", None)
if version is None or version == 1:
# We moved the parameters into a LinearPackedParameters submodule
weight = state_dict.pop(prefix + "weight")
bias = state_dict.pop(prefix + "bias")
state_dict.update(
{
prefix + "_packed_params.weight": weight,
prefix + "_packed_params.bias": bias,
}
)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
False,
missing_keys,
unexpected_keys,
error_msgs,
)
# Function rather than property to make sure that JIT serialization doesn't
# register this as an attribute
def _weight_bias(self):
return self._packed_params._weight_bias()
def weight(self):
return self._weight_bias()[0]
def bias(self):
return self._weight_bias()[1]
def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
self._packed_params.set_weight_bias(w, b)
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized module from an observed float module
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
use_precomputed_fake_quant (bool): if True, the module will reuse min/max
values from the precomputed fake quant module.
"""
if hasattr(mod, "weight_fake_quant"):
if type_before_parametrizations(mod) == nniqat.LinearBn1d:
mod.weight, mod.bias = fuse_linear_bn_weights(
mod.weight,
mod.bias,
mod.bn.running_mean,
mod.bn.running_var,
mod.bn.eps,
mod.bn.weight,
mod.bn.bias,
)
weight_post_process = mod.weight_fake_quant
activation_post_process = mod.activation_post_process
else:
# This function does not participate in JIT, so it is OK to ignore
# the type mismatch in assignment. Also, mypy has an issue with
# iterables not being implemented, so we are ignoring those too.
if not isinstance(cls._FLOAT_MODULE, Iterable):
cls._FLOAT_MODULE = [cls._FLOAT_MODULE] # type: ignore[assignment]
supported_modules = ", ".join([float_mod.__name__ for float_mod in cls._FLOAT_MODULE]) # type: ignore[attr-defined]
error_msg = f"nnq.{cls.__name__}.from_float only works for {supported_modules}, but got: {type(mod)}"
assert type_before_parametrizations(mod) in cls._FLOAT_MODULE, error_msg.format() # type: ignore[attr-defined]
assert hasattr(
mod, "qconfig"
), "Input float module must have qconfig defined"
activation_post_process = mod.activation_post_process
if type_before_parametrizations(mod) == nni.LinearReLU:
mod = mod[0]
weight_post_process = (
mod.qconfig.weight()
if not hasattr(mod, "weight_fake_quant")
else mod.weight_fake_quant
)
if not use_precomputed_fake_quant:
# Observer may not have been called yet
# Observer might have been called in the previous stage via PTQ algorithm e.g. AdaRound
weight_post_process(mod.weight)
dtype = weight_post_process.dtype
act_scale, act_zp = activation_post_process.calculate_qparams()
assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
qlinear.set_weight_bias(qweight, mod.bias)
qlinear.scale = float(act_scale)
qlinear.zero_point = int(act_zp)
return qlinear
@classmethod
def from_reference(cls, ref_qlinear, output_scale, output_zero_point):
r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
Args:
ref_qlinear (Module): a reference quantized linear module, either produced by torch.ao.quantization
utilities or provided by the user
output_scale (float): scale for output Tensor
output_zero_point (int): zero point for output Tensor
"""
qlinear = cls(ref_qlinear.in_features, ref_qlinear.out_features)
qweight = ref_qlinear.get_quantized_weight()
qlinear.set_weight_bias(qweight, ref_qlinear.bias)
qlinear.scale = float(output_scale)
qlinear.zero_point = int(output_zero_point)
return qlinear
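# Minimal eager-mode sketch of the from_float path above for Linear, assuming the
# fbgemm backend; the wrapper model, sizes, and calibration data are invented for
# illustration.
import torch
import torch.nn as nn
import torch.ao.quantization as tq
class _TinyLinear(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.fc = nn.Linear(20, 30)
        self.dequant = tq.DeQuantStub()
    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))
_m = _TinyLinear().eval()
_m.qconfig = tq.get_default_qconfig("fbgemm")
tq.prepare(_m, inplace=True)
_m(torch.randn(8, 20))        # calibration: observers record activation ranges
tq.convert(_m, inplace=True)  # _m.fc is now a QuantizedLinear with packed weights
_y = _m(torch.randn(8, 20))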

View File

@ -0,0 +1,346 @@
# mypy: allow-untyped-defs
import torch
__all__ = [
"LayerNorm",
"GroupNorm",
"InstanceNorm1d",
"InstanceNorm2d",
"InstanceNorm3d",
]
class LayerNorm(torch.nn.LayerNorm):
r"""This is the quantized version of :class:`~torch.nn.LayerNorm`.
Additional args:
* **scale** - quantization scale of the output, type: double.
* **zero_point** - quantization zero point of the output, type: long.
"""
def __init__(
self,
normalized_shape,
weight,
bias,
scale,
zero_point,
eps=1e-5,
elementwise_affine=True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
**factory_kwargs,
)
self.weight = weight
self.bias = bias
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
return torch.ops.quantized.layer_norm(
input,
self.normalized_shape,
weight=self.weight,
bias=self.bias,
eps=self.eps,
output_scale=self.scale,
output_zero_point=self.zero_point,
)
def _get_name(self):
return "QuantizedLayerNorm"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.normalized_shape,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
mod.eps,
mod.elementwise_affine,
)
return new_mod
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(
mod.normalized_shape,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
mod.eps,
mod.elementwise_affine,
)
class GroupNorm(torch.nn.GroupNorm):
r"""This is the quantized version of :class:`~torch.nn.GroupNorm`.
Additional args:
* **scale** - quantization scale of the output, type: double.
* **zero_point** - quantization zero point of the output, type: long.
"""
__constants__ = ["num_groups", "num_channels", "eps", "affine"]
def __init__(
self,
num_groups,
num_channels,
weight,
bias,
scale,
zero_point,
eps=1e-5,
affine=True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
self.weight = weight
self.bias = bias
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
return torch.ops.quantized.group_norm(
input,
self.num_groups,
self.weight,
self.bias,
self.eps,
self.scale,
self.zero_point,
)
def _get_name(self):
return "QuantizedGroupNorm"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_groups,
mod.num_channels,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
mod.eps,
mod.affine,
)
return new_mod
class InstanceNorm1d(torch.nn.InstanceNorm1d):
r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`.
Additional args:
* **scale** - quantization scale of the output, type: double.
* **zero_point** - quantization zero point of the output, type: long.
"""
def __init__(
self,
num_features,
weight,
bias,
scale,
zero_point,
eps=1e-5,
momentum=0.1,
affine=False,
track_running_stats=False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
)
self.weight = weight
self.bias = bias
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
return torch.ops.quantized.instance_norm(
input, self.weight, self.bias, self.eps, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedInstanceNorm1d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_features,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
mod.eps,
            affine=mod.affine,
)
return new_mod
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(
mod.num_features,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
mod.eps,
            affine=mod.affine,
)
class InstanceNorm2d(torch.nn.InstanceNorm2d):
r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`.
Additional args:
* **scale** - quantization scale of the output, type: double.
* **zero_point** - quantization zero point of the output, type: long.
"""
def __init__(
self,
num_features,
weight,
bias,
scale,
zero_point,
eps=1e-5,
momentum=0.1,
affine=False,
track_running_stats=False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
)
self.weight = weight
self.bias = bias
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
return torch.ops.quantized.instance_norm(
input, self.weight, self.bias, self.eps, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedInstanceNorm2d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_features,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
mod.eps,
            affine=mod.affine,
)
return new_mod
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(
mod.num_features,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
mod.eps,
            affine=mod.affine,
)
class InstanceNorm3d(torch.nn.InstanceNorm3d):
r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`.
Additional args:
* **scale** - quantization scale of the output, type: double.
* **zero_point** - quantization zero point of the output, type: long.
"""
def __init__(
self,
num_features,
weight,
bias,
scale,
zero_point,
eps=1e-5,
momentum=0.1,
affine=False,
track_running_stats=False,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(
num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
)
self.weight = weight
self.bias = bias
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
return torch.ops.quantized.instance_norm(
input, self.weight, self.bias, self.eps, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedInstanceNorm3d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_features,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
mod.eps,
            affine=mod.affine,
)
return new_mod
@classmethod
def from_reference(cls, mod, scale, zero_point):
return cls(
mod.num_features,
mod.weight,
mod.bias,
float(scale),
int(zero_point),
mod.eps,
            affine=mod.affine,
)
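# Construction-only sketch of the quantized LayerNorm above: the affine parameters
# and output qparams are supplied directly (values are arbitrary), and a build with
# a quantized engine such as fbgemm is assumed.
import torch
_n = 8
_ln = LayerNorm(  # class defined above
    [_n],
    weight=torch.nn.Parameter(torch.ones(_n)),
    bias=torch.nn.Parameter(torch.zeros(_n)),
    scale=0.1,       # output scale
    zero_point=0,    # output zero point
)
_qx = torch.quantize_per_tensor(torch.randn(2, _n), 0.05, 0, torch.quint8)
_qy = _ln(_qx)  # dispatches to torch.ops.quantized.layer_norm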

View File

@ -0,0 +1,57 @@
# mypy: allow-untyped-defs
import torch
__all__ = [
"LSTM",
]
class LSTM(torch.ao.nn.quantizable.LSTM):
r"""A quantized long short-term memory (LSTM).
For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`
Attributes:
layers : instances of the `_LSTMLayer`
.. note::
To access the weights and biases, you need to access them per layer.
See examples in :class:`~torch.ao.nn.quantizable.LSTM`
Examples::
>>> # xdoctest: +SKIP
>>> custom_module_config = {
... 'float_to_observed_custom_module_class': {
... nn.LSTM: nn.quantizable.LSTM,
... },
... 'observed_to_quantized_custom_module_class': {
... nn.quantizable.LSTM: nn.quantized.LSTM,
... }
... }
        >>> tq.prepare(model, prepare_custom_config_dict=custom_module_config)
        >>> tq.convert(model, convert_custom_config_dict=custom_module_config)
"""
_FLOAT_MODULE = torch.ao.nn.quantizable.LSTM # type: ignore[assignment]
def _get_name(self):
return "QuantizedLSTM"
@classmethod
def from_float(cls, *args, **kwargs):
# The whole flow is float -> observed -> quantized
# This class does observed -> quantized only
raise NotImplementedError(
"It looks like you are trying to convert a "
"non-observed LSTM module. Please, see "
"the examples on quantizable LSTMs."
)
@classmethod
def from_observed(cls, other):
assert type(other) == cls._FLOAT_MODULE # type: ignore[has-type]
converted = torch.ao.quantization.convert(
other, inplace=False, remove_qconfig=True
)
converted.__class__ = cls
return converted

View File

@ -0,0 +1,144 @@
# mypy: allow-untyped-defs
import abc
import collections
import itertools
import torch
from torch.nn.modules.module import _addindent
__all__ = [
"WeightedQuantizedModule",
]
class WeightedQuantizedModule(torch.nn.Module, metaclass=abc.ABCMeta):
"""Wrapper for quantized modules than can be lowered from reference modules."""
@classmethod
@abc.abstractmethod
def from_reference(cls, ref_module, output_scale, output_zero_point):
raise NotImplementedError
def _get_weight_observer(observer):
# FakeQuantize observer
if hasattr(observer, "activation_post_process"):
observer = observer.activation_post_process
# UniformQuantizationObserverBase observer
return observer
def _needs_weight_clamping(observer, dtype):
observer = _get_weight_observer(observer)
if dtype in [torch.qint8, torch.quint8, torch.qint32]:
info = torch.iinfo(dtype)
return observer.quant_min > info.min or observer.quant_max < info.max
return False
def _clamp_weights(qweight, observer, scale, zp):
if not _needs_weight_clamping(observer, qweight.dtype):
return qweight
observer = _get_weight_observer(observer)
min_, max_ = observer.quant_min, observer.quant_max
# Doing this because can't use torch.ops.quantized.clamp() with per_channel qscheme yet.
qw_int_max = torch.clone(qweight.int_repr()).fill_(max_)
qw_int_min = torch.clone(qweight.int_repr()).fill_(min_)
qw_int = torch.minimum(torch.maximum(qweight.int_repr(), qw_int_min), qw_int_max)
if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
qweight = torch._make_per_tensor_quantized_tensor(
qw_int, scale.item(), zp.item()
)
elif observer.qscheme in [
torch.per_channel_symmetric,
torch.per_channel_affine,
torch.per_channel_affine_float_qparams,
]:
qweight = torch._make_per_channel_quantized_tensor(
qw_int, scale, zp, axis=observer.ch_axis
)
else:
raise ValueError("Unexpected qscheme " + observer.qscheme)
return qweight
def _quantize_weight(float_wt, observer):
wt_scale, wt_zp = observer.calculate_qparams()
if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
qweight = torch.quantize_per_tensor(
float_wt, float(wt_scale), int(wt_zp), torch.qint8
)
qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
elif observer.qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]:
wt_axis = observer.ch_axis
qweight = torch.quantize_per_channel(
float_wt,
wt_scale.to(torch.double),
wt_zp.to(torch.int64),
wt_axis,
torch.qint8,
)
qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
elif observer.qscheme in [torch.per_channel_affine_float_qparams]:
qweight = torch.quantize_per_channel(
float_wt,
wt_scale.to(torch.float),
wt_zp.to(torch.float),
observer.ch_axis,
observer.dtype,
)
qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
else:
raise ValueError("Unexpected qscheme " + observer.qscheme)
return qweight
def _ntuple_from_first(n):
"""Converts the argument to a tuple of size n
with the first element repeated."""
def parse(x):
while isinstance(x, collections.abc.Sequence):
if len(x) == n:
break
x = x[0]
return tuple(itertools.repeat(x, n))
return parse
def _hide_packed_params_repr(self, params):
# We don't want to show `PackedParams` children, hence custom
# `__repr__`. This is the same as nn.Module.__repr__, except the check
# for the `params module`.
extra_lines = []
extra_repr = self.extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split("\n")
child_lines = []
for key, module in self._modules.items():
if isinstance(module, params):
continue
mod_str = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append("(" + key + "): " + mod_str)
lines = extra_lines + child_lines
main_str = self._get_name() + "("
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
return main_str
_pair_from_first = _ntuple_from_first(2)
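# For context on what _quantize_weight does in its per-channel branch, an equivalent
# sketch using only public observer APIs; the observer choice and weight shape are
# illustrative.
import torch
from torch.ao.quantization.observer import PerChannelMinMaxObserver
_w = torch.randn(4, 8)  # e.g. a linear weight, one row per output channel
_obs = PerChannelMinMaxObserver(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
_obs(_w)                # record per-output-channel min/max
_scales, _zero_points = _obs.calculate_qparams()
_qw = torch.quantize_per_channel(
    _w,
    _scales.to(torch.double),
    _zero_points.to(torch.int64),
    axis=_obs.ch_axis,
    dtype=torch.qint8,
)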

View File

@ -0,0 +1,19 @@
from .modules import * # noqa: F403
__all__ = [
"Linear",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
"RNNCell",
"LSTMCell",
"GRUCell",
"LSTM",
"GRU",
"Embedding",
"EmbeddingBag",
]

View File

@ -0,0 +1,29 @@
from .conv import (
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
)
from .linear import Linear
from .rnn import GRU, GRUCell, LSTM, LSTMCell, RNNCell
from .sparse import Embedding, EmbeddingBag
__all__ = [
"Linear",
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
"RNNCell",
"LSTMCell",
"GRUCell",
"LSTM",
"GRU",
"Embedding",
"EmbeddingBag",
]

View File

@ -0,0 +1,511 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.common_types import _size_1_t
from .utils import ReferenceQuantizedModule
__all__ = [
"Conv1d",
"Conv2d",
"Conv3d",
"ConvTranspose1d",
"ConvTranspose2d",
"ConvTranspose3d",
]
class _ConvNd(torch.nn.modules.conv._ConvNd, ReferenceQuantizedModule):
"""A reference version of nn.quantized.Conv2d
we will not pack the parameters in this module, since weight packing is an
optimization for quantized backends supported in PyTorch (fbgemm/qnnpack),
this is useful when user want to use this module in other backends like Glow.
"""
__annotations__ = {"bias": Optional[torch.Tensor]}
_IS_REFERENCE = True
@staticmethod
def from_float(cls, float_conv, weight_qparams):
qref_conv = cls(
float_conv.in_channels,
float_conv.out_channels,
float_conv.kernel_size, # type: ignore[arg-type]
float_conv.stride, # type: ignore[arg-type]
float_conv.padding, # type: ignore[arg-type]
float_conv.dilation, # type: ignore[arg-type]
float_conv.groups,
float_conv.bias is not None, # type: ignore[arg-type]
float_conv.padding_mode,
device=float_conv.weight.device,
dtype=float_conv.weight.dtype,
weight_qparams=weight_qparams,
)
qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
if float_conv.bias is not None:
qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
return qref_conv
class Conv1d(_ConvNd, nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: _size_1_t = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device=None,
dtype=None,
weight_qparams: Optional[Dict[str, Any]] = None,
):
nn.Conv1d.__init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
device,
dtype,
)
self._init_weight_qparams(weight_qparams, device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
we have:
w(float) -- quant - dequant \
x(float) ------------- F.conv1d ---
In the full model, we will see
w(float) -- quant - *dequant \
x -- quant --- *dequant -- *F.conv1d --- *quant - dequant
and the backend should be able to fuse the ops with `*` into a quantized conv1d
"""
weight_quant_dequant = self.get_weight()
result = F.conv1d(
x,
weight_quant_dequant,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)
return result
def _get_name(self):
return "QuantizedConv1d(Reference)"
@classmethod
def from_float(cls, float_conv, weight_qparams):
return _ConvNd.from_float(cls, float_conv, weight_qparams)
class Conv2d(_ConvNd, nn.Conv2d):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
weight_qparams: Optional[Dict[str, Any]] = None,
):
nn.Conv2d.__init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
device,
dtype,
)
self._init_weight_qparams(weight_qparams, device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
we have:
w(float) -- quant - dequant \
x(float) ------------- F.conv2d ---
In the full model, we will see
w(float) -- quant - *dequant \
x -- quant --- *dequant -- *F.conv2d --- *quant - dequant
and the backend should be able to fuse the ops with `*` into a quantized conv2d
"""
weight_quant_dequant = self.get_weight()
result = F.conv2d(
x,
weight_quant_dequant,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)
return result
def _get_name(self):
return "QuantizedConv2d(Reference)"
@classmethod
def from_float(cls, float_conv, weight_qparams):
return _ConvNd.from_float(cls, float_conv, weight_qparams)
class Conv3d(_ConvNd, nn.Conv3d):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
weight_qparams: Optional[Dict[str, Any]] = None,
):
nn.Conv3d.__init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
device,
dtype,
)
self._init_weight_qparams(weight_qparams, device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
we have:
w(float) -- quant - dequant \
x(float) ------------- F.conv3d ---
In the full model, we will see
w(float) -- quant - *dequant \
x -- quant --- *dequant -- *F.conv3d --- *quant - dequant
and the backend should be able to fuse the ops with `*` into a quantized conv3d
"""
weight_quant_dequant = self.get_weight()
result = F.conv3d(
x,
weight_quant_dequant,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)
return result
def _get_name(self):
return "QuantizedConv3d(Reference)"
@classmethod
def from_float(cls, float_conv, weight_qparams):
return _ConvNd.from_float(cls, float_conv, weight_qparams)
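# To make the quant - dequant decomposition in the forward docstrings above concrete,
# a sketch that builds a reference conv from a float conv. The per-tensor
# weight_qparams keys follow what _init_weight_qparams expects; the sizes and scale
# value are arbitrary.
import torch
import torch.nn as nn
_float_conv = nn.Conv2d(3, 8, kernel_size=3)
_weight_qparams = {
    "qscheme": torch.per_tensor_affine,
    "dtype": torch.qint8,
    "scale": 0.02,     # arbitrary weight scale
    "zero_point": 0,
}
_ref_conv = Conv2d.from_float(_float_conv, _weight_qparams)  # class defined above
_y = _ref_conv(torch.randn(1, 3, 16, 16))  # fake-quantizes the weight, then runs F.conv2d in float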
class _ConvTransposeNd(_ConvNd, torch.nn.modules.conv._ConvTransposeNd):
"""A reference version of nn.quantized.ConvTranspose2d
we will not pack the parameters in this module, since weight packing is an
optimization for quantized backends supported in PyTorch (fbgemm/qnnpack),
this is useful when user want to use this module in other backends like Glow.
"""
@staticmethod
def from_float(cls, float_conv, weight_qparams):
qref_conv = cls(
float_conv.in_channels,
float_conv.out_channels,
float_conv.kernel_size, # type: ignore[arg-type]
float_conv.stride, # type: ignore[arg-type]
float_conv.padding, # type: ignore[arg-type]
float_conv.output_padding, # type: ignore[arg-type]
float_conv.groups,
float_conv.bias is not None, # type: ignore[arg-type]
float_conv.dilation, # type: ignore[arg-type]
float_conv.padding_mode,
device=float_conv.weight.device,
dtype=float_conv.weight.dtype,
weight_qparams=weight_qparams,
)
qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
if float_conv.bias is not None:
qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
return qref_conv
class ConvTranspose1d(_ConvTransposeNd, nn.ConvTranspose1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: _size_1_t = 0,
output_padding: _size_1_t = 0,
groups: int = 1,
bias: bool = True,
dilation: _size_1_t = 1,
padding_mode: str = "zeros",
device=None,
dtype=None,
weight_qparams: Optional[Dict[str, Any]] = None,
):
nn.ConvTranspose1d.__init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
output_padding,
groups,
bias,
dilation,
padding_mode,
device,
dtype,
)
self._init_weight_qparams(weight_qparams, device)
def forward(
self, x: torch.Tensor, output_size: Optional[List[int]] = None
) -> torch.Tensor:
"""
we have:
w(float) -- quant - dequant \
x(float) ------------- F.convTranspose1d ---
In the full model, we will see
w(float) -- quant - *dequant \
x -- quant --- *dequant -- *F.convTranspose1d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv transpose1d
"""
assert isinstance(self.padding, tuple)
# One cannot replace List by Tuple or Sequence in "_output_padding" because
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
output_padding = self._output_padding(
            x,
output_size,
self.stride, # type: ignore[arg-type]
self.padding, # type: ignore[arg-type]
self.kernel_size, # type: ignore[arg-type]
self.dilation, # type: ignore[arg-type]
)
weight_quant_dequant = self.get_weight()
result = F.conv_transpose1d(
x,
weight_quant_dequant,
self.bias,
self.stride,
self.padding,
output_padding,
self.groups,
self.dilation,
)
return result
def _get_name(self):
return "QuantizedConvTranspose1d(Reference)"
@classmethod
def from_float(cls, float_conv, weight_qparams):
return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)
class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
output_padding=0,
groups=1,
bias=True,
dilation=1,
padding_mode="zeros",
device=None,
dtype=None,
weight_qparams: Optional[Dict[str, Any]] = None,
):
nn.ConvTranspose2d.__init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
output_padding,
groups,
bias,
dilation,
padding_mode,
device,
dtype,
)
self._init_weight_qparams(weight_qparams, device)
def forward(
self, x: torch.Tensor, output_size: Optional[List[int]] = None
) -> torch.Tensor:
"""
we have:
w(float) -- quant - dequant \
x(float) ------------- F.convTranspose2d ---
In the full model, we will see
w(float) -- quant - *dequant \
x -- quant --- *dequant -- *F.convTranspose2d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv transpose2d
"""
assert isinstance(self.padding, tuple)
# One cannot replace List by Tuple or Sequence in "_output_padding" because
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
output_padding = self._output_padding(
            x,
output_size,
self.stride, # type: ignore[arg-type]
self.padding, # type: ignore[arg-type]
self.kernel_size, # type: ignore[arg-type]
self.dilation, # type: ignore[arg-type]
)
weight_quant_dequant = self.get_weight()
result = F.conv_transpose2d(
x,
weight_quant_dequant,
self.bias,
self.stride,
self.padding,
output_padding,
self.groups,
self.dilation,
)
return result
def _get_name(self):
return "QuantizedConvTranspose2d(Reference)"
@classmethod
def from_float(cls, float_conv, weight_qparams):
return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)
class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
output_padding=0,
groups=1,
bias=True,
dilation=1,
padding_mode="zeros",
device=None,
dtype=None,
weight_qparams: Optional[Dict[str, Any]] = None,
):
nn.ConvTranspose3d.__init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
output_padding,
groups,
bias,
dilation,
padding_mode,
device,
dtype,
)
self._init_weight_qparams(weight_qparams, device)
def forward(
self, x: torch.Tensor, output_size: Optional[List[int]] = None
) -> torch.Tensor:
"""
we have:
w(float) -- quant - dequant \
x(float) ------------- F.convTranspose3d ---
In the full model, we will see
w(float) -- quant - *dequant \
x -- quant --- *dequant -- *F.convTranspose3d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv transpose3d
"""
assert isinstance(self.padding, tuple)
# One cannot replace List by Tuple or Sequence in "_output_padding" because
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
output_padding = self._output_padding(
            x,
output_size,
self.stride, # type: ignore[arg-type]
self.padding, # type: ignore[arg-type]
self.kernel_size, # type: ignore[arg-type]
self.dilation, # type: ignore[arg-type]
)
weight_quant_dequant = self.get_weight()
result = F.conv_transpose3d(
x,
weight_quant_dequant,
self.bias,
self.stride,
self.padding,
output_padding,
self.groups,
self.dilation,
)
return result
def _get_name(self):
return "QuantizedConvTranspose3d(Reference)"
@classmethod
def from_float(cls, float_conv, weight_qparams):
return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)
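
# Illustrative usage sketch, assuming placeholder quantization parameters (the
# scale/zero_point values below are made up, not calibrated): a float nn.Conv2d
# is converted into the reference Conv2d defined above via `from_float`, and the
# forward pass stays float-in/float-out with a fake-quantized weight.
if __name__ == "__main__":
    float_conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    weight_qparams = {
        "qscheme": torch.per_tensor_affine,
        "dtype": torch.qint8,
        "scale": 0.02,
        "zero_point": 0,
    }
    ref_conv = Conv2d.from_float(float_conv, weight_qparams)
    out = ref_conv(torch.randn(1, 3, 32, 32))
    print(out.shape)  # torch.Size([1, 8, 32, 32]); output is a regular float tensor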

View File

@ -0,0 +1,68 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import ReferenceQuantizedModule
__all__ = ["Linear"]
class Linear(nn.Linear, ReferenceQuantizedModule):
"""A reference quantized linear module that fits into the FX
Graph Mode Quantization workflow
activation will be floating point Tensor, we will store floating
point weight as well in the module, but in forward we'll quantize
and dequantize the weight before running the floating point functional
linear operator.
"""
_IS_REFERENCE = True
def __init__(
self,
in_features: int,
out_features: int,
bias_: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
weight_qparams: Optional[Dict[str, Any]] = None,
):
super().__init__(in_features, out_features, bias_, device, dtype)
self._init_weight_qparams(weight_qparams, device)
def _get_name(self):
return "QuantizedLinear(Reference)"
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
we have:
w(float) -- quant - dequant \
x(float) ------------- F.linear ---
In the full model, we will see
w(float) -- quant - *dequant \
x -- quant --- *dequant -- *F.linear --- *quant - dequant
and the backend should be able to fuse the ops with `*` into a quantized linear
"""
weight_quant_dequant = self.get_weight()
result = F.linear(x, weight_quant_dequant, self.bias)
return result
@classmethod
def from_float(cls, float_linear, weight_qparams):
qref_linear = Linear(
float_linear.in_features,
float_linear.out_features,
float_linear.bias is not None,
device=float_linear.weight.device,
dtype=float_linear.weight.dtype,
weight_qparams=weight_qparams,
)
qref_linear.weight = torch.nn.Parameter(float_linear.weight.detach())
if float_linear.bias is not None:
qref_linear.bias = torch.nn.Parameter(float_linear.bias.detach())
return qref_linear
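
# Illustrative usage sketch, assuming placeholder quantization parameters (scale
# and zero_point below are made up): converting a float nn.Linear into the
# reference module and inspecting the fake-quantized weight from get_weight().
if __name__ == "__main__":
    float_linear = nn.Linear(16, 4)
    weight_qparams = {
        "qscheme": torch.per_tensor_affine,
        "dtype": torch.qint8,
        "scale": 0.05,
        "zero_point": 0,
    }
    ref_linear = Linear.from_float(float_linear, weight_qparams)
    x = torch.randn(2, 16)
    print(ref_linear(x).shape)            # torch.Size([2, 4]); float in, float out
    print(ref_linear.get_weight().dtype)  # torch.float32; weight is quantized then dequantized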

View File

@ -0,0 +1,845 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from torch import _VF, Tensor
from torch.nn.utils.rnn import PackedSequence
from .utils import _quantize_and_dequantize_weight, _quantize_weight
__all__ = [
"RNNCellBase",
"RNNCell",
"LSTMCell",
"GRUCell",
"RNNBase",
"LSTM",
"GRU",
"get_quantized_weight",
]
def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
return tensor.index_select(dim, permutation)
def _get_weight_and_quantization_params(module, wn):
weight = getattr(module, wn)
params = [weight]
for param_name in [
wn + n for n in ["_qscheme", "_dtype", "_scale", "_zero_point", "_axis_int"]
]:
if hasattr(module, param_name):
param = getattr(module, param_name)
else:
param = None
params.append(param)
return params
def get_quantized_weight(module, wn):
if not hasattr(module, wn):
return None
params = _get_weight_and_quantization_params(module, wn)
weight = _quantize_weight(*params)
return weight
def _get_quantize_and_dequantized_weight(module, wn):
if not hasattr(module, wn):
return None
params = _get_weight_and_quantization_params(module, wn)
weight = _quantize_and_dequantize_weight(*params)
return weight
class RNNCellBase(nn.RNNCellBase):
def __init__(
self,
input_size: int,
hidden_size: int,
bias: bool,
num_chunks: int,
device=None,
dtype=None,
weight_qparams_dict=None,
) -> None:
super().__init__(
input_size, hidden_size, bias, num_chunks, device=device, dtype=dtype
)
# TODO(jerryzh168): maybe make this arg a required arg
if weight_qparams_dict is None:
weight_qparams = {
"qscheme": torch.per_tensor_affine,
"dtype": torch.quint8,
"scale": 1.0,
"zero_point": 0,
}
weight_qparams_dict = {
"weight_ih": weight_qparams,
"weight_hh": weight_qparams,
"is_decomposed": False,
}
assert (
len(weight_qparams_dict) == 3
), "Expected length for weight_qparams_dict to be 3 for QuantizedRNNCellBase(Reference)"
self._init_weight_qparams_dict(weight_qparams_dict, device)
def _init_weight_qparams_dict(self, weight_qparams_dict, device):
assert weight_qparams_dict is not None
self.is_decomposed = weight_qparams_dict["is_decomposed"]
for key, weight_qparams in weight_qparams_dict.items():
if key == "is_decomposed":
continue
# TODO: refactor the duplicated code to utils.py
weight_qscheme = weight_qparams["qscheme"]
weight_dtype = weight_qparams["dtype"]
setattr(self, key + "_qscheme", weight_qscheme)
setattr(self, key + "_dtype", weight_dtype)
assert weight_qscheme in [
None,
torch.per_tensor_affine,
torch.per_channel_affine,
], Exception(
f"qscheme: {weight_qscheme} is not support in {self._get_name()}"
)
if weight_qscheme is not None:
scale = weight_qparams["scale"]
scale_tensor = (
scale.clone().detach()
if isinstance(scale, torch.Tensor)
else torch.tensor(scale, dtype=torch.float, device=device)
)
self.register_buffer(key + "_scale", scale_tensor)
zp = weight_qparams["zero_point"]
zp_tensor = (
zp.clone().detach()
if isinstance(zp, torch.Tensor)
else torch.tensor(zp, dtype=torch.int, device=device)
)
self.register_buffer(key + "_zero_point", zp_tensor)
if weight_qscheme == torch.per_channel_affine:
axis = weight_qparams["axis"]
axis_tensor = (
axis.clone().detach()
if isinstance(axis, torch.Tensor)
else torch.tensor(axis, dtype=torch.int, device=device)
)
self.register_buffer(key + "_axis", axis_tensor)
else:
# added for TorchScriptability, not used
self.register_buffer(
key + "_axis", torch.tensor(0, dtype=torch.int, device=device)
)
setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())
def _get_name(self):
return "QuantizedRNNCellBase(Reference)"
def get_quantized_weight_ih(self):
return get_quantized_weight(self, "weight_ih")
def get_quantized_weight_hh(self):
return get_quantized_weight(self, "weight_hh")
def get_weight_ih(self):
return _get_quantize_and_dequantized_weight(self, "weight_ih")
def get_weight_hh(self):
return _get_quantize_and_dequantized_weight(self, "weight_hh")
class RNNCell(RNNCellBase):
"""
    We store weight_qparams for all the weights (weight_ih and weight_hh);
    a `weight_qparams_dict` that maps from weight name (e.g. weight_ih)
    to the weight_qparams for that weight must be passed in.
"""
def __init__(
self,
input_size: int,
hidden_size: int,
bias: bool = True,
nonlinearity: str = "tanh",
device=None,
dtype=None,
weight_qparams_dict: Optional[Dict[str, Any]] = None,
) -> None:
factory_kwargs = {
"device": device,
"dtype": dtype,
"weight_qparams_dict": weight_qparams_dict,
}
super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
self.nonlinearity = nonlinearity
def _get_name(self):
return "QuantizedRNNCell(Reference)"
# TODO: refactor nn.RNNCell to have a _forward that takes weight_ih and weight_hh as input
# and remove duplicated code, same for the other two Cell modules
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
assert input.dim() in (
1,
2,
), f"RNNCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
is_batched = input.dim() == 2
if not is_batched:
input = input.unsqueeze(0)
if hx is None:
hx = torch.zeros(
input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
)
else:
hx = hx.unsqueeze(0) if not is_batched else hx
if self.nonlinearity == "tanh":
ret = _VF.rnn_tanh_cell(
input,
hx,
self.get_weight_ih(),
self.get_weight_hh(),
self.bias_ih,
self.bias_hh,
)
elif self.nonlinearity == "relu":
ret = _VF.rnn_relu_cell(
input,
hx,
self.get_weight_ih(),
self.get_weight_hh(),
self.bias_ih,
self.bias_hh,
)
else:
ret = input # TODO: remove when jit supports exception flow
raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")
if not is_batched:
ret = ret.squeeze(0)
return ret
@classmethod
def from_float(cls, mod, weight_qparams_dict):
ref_mod = cls(
mod.input_size,
mod.hidden_size,
mod.bias,
mod.nonlinearity,
mod.weight_ih.device,
mod.weight_ih.dtype,
weight_qparams_dict,
)
ref_mod.weight_ih = mod.weight_ih
ref_mod.weight_hh = mod.weight_hh
ref_mod.bias_ih = mod.bias_ih
ref_mod.bias_hh = mod.bias_hh
return ref_mod
class LSTMCell(RNNCellBase):
"""
    We store weight_qparams for all the weights (weight_ih and weight_hh);
    a `weight_qparams_dict` that maps from weight name (e.g. weight_ih)
    to the weight_qparams for that weight must be passed in.
"""
def __init__(
self,
input_size: int,
hidden_size: int,
bias: bool = True,
device=None,
dtype=None,
weight_qparams_dict: Optional[Dict[str, Any]] = None,
) -> None:
factory_kwargs = {
"device": device,
"dtype": dtype,
"weight_qparams_dict": weight_qparams_dict,
}
super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)
def _get_name(self):
return "QuantizedLSTMCell(Reference)"
def forward(
self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[Tensor, Tensor]:
assert input.dim() in (
1,
2,
), f"LSTMCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
is_batched = input.dim() == 2
if not is_batched:
input = input.unsqueeze(0)
if hx is None:
zeros = torch.zeros(
input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
)
hx = (zeros, zeros)
else:
hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx
ret = _VF.lstm_cell(
input,
hx,
self.get_weight_ih(),
self.get_weight_hh(),
self.bias_ih,
self.bias_hh,
)
if not is_batched:
ret = (ret[0].squeeze(0), ret[1].squeeze(0))
return ret
@classmethod
def from_float(cls, mod, weight_qparams_dict, use_precomputed_fake_quant=False):
ref_mod = cls(
mod.input_size,
mod.hidden_size,
mod.bias,
mod.weight_ih.device,
mod.weight_ih.dtype,
weight_qparams_dict,
)
ref_mod.weight_ih = mod.weight_ih
ref_mod.weight_hh = mod.weight_hh
ref_mod.bias_ih = mod.bias_ih
ref_mod.bias_hh = mod.bias_hh
return ref_mod
class GRUCell(RNNCellBase):
"""
    We store weight_qparams for all the weights (weight_ih and weight_hh);
    a `weight_qparams_dict` that maps from weight name (e.g. weight_ih)
    to the weight_qparams for that weight must be passed in.
"""
def __init__(
self,
input_size: int,
hidden_size: int,
bias: bool = True,
device=None,
dtype=None,
weight_qparams_dict: Optional[Dict[str, Any]] = None,
) -> None:
factory_kwargs = {
"device": device,
"dtype": dtype,
"weight_qparams_dict": weight_qparams_dict,
}
super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)
def _get_name(self):
return "QuantizedGRUCell(Reference)"
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
assert input.dim() in (
1,
2,
), f"GRUCell: Expected input to be 1-D or 2-D but received {input.dim()}-D tensor"
is_batched = input.dim() == 2
if not is_batched:
input = input.unsqueeze(0)
if hx is None:
hx = torch.zeros(
input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
)
else:
hx = hx.unsqueeze(0) if not is_batched else hx
ret = _VF.gru_cell(
input,
hx,
self.get_weight_ih(),
self.get_weight_hh(),
self.bias_ih,
self.bias_hh,
)
if not is_batched:
ret = ret.squeeze(0)
return ret
@classmethod
def from_float(cls, mod, weight_qparams_dict):
ref_mod = cls(
mod.input_size,
mod.hidden_size,
mod.bias,
mod.weight_ih.device,
mod.weight_ih.dtype,
weight_qparams_dict,
)
ref_mod.weight_ih = mod.weight_ih
ref_mod.weight_hh = mod.weight_hh
ref_mod.bias_ih = mod.bias_ih
ref_mod.bias_hh = mod.bias_hh
return ref_mod
class RNNBase(nn.RNNBase):
def __init__(
self,
mode: str,
input_size: int,
hidden_size: int,
num_layers: int = 1,
bias: bool = True,
batch_first: bool = False,
dropout: float = 0.0,
bidirectional: bool = False,
proj_size: int = 0,
device=None,
dtype=None,
weight_qparams_dict: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
mode,
input_size,
hidden_size,
num_layers,
bias,
batch_first,
dropout,
bidirectional,
proj_size,
device,
dtype,
)
# TODO(jerryzh168): maybe make this arg a required arg
if weight_qparams_dict is None:
weight_qparams = {
"qscheme": torch.per_tensor_affine,
"dtype": torch.quint8,
"scale": 1.0,
"zero_point": 0,
}
weight_qparams_dict = {"is_decomposed": False} # type: ignore[dict-item]
for wn in self._flat_weights_names:
if wn.startswith("weight"):
weight_qparams_dict[wn] = weight_qparams
self._init_weight_qparams_dict(weight_qparams_dict, device)
def _init_weight_qparams_dict(self, weight_qparams_dict, device):
self.is_decomposed = weight_qparams_dict["is_decomposed"]
for key, weight_qparams in weight_qparams_dict.items():
if key == "is_decomposed":
continue
weight_qscheme = weight_qparams["qscheme"]
weight_dtype = weight_qparams["dtype"]
setattr(self, key + "_qscheme", weight_qscheme)
setattr(self, key + "_dtype", weight_dtype)
assert weight_qscheme in [
None,
torch.per_tensor_affine,
torch.per_channel_affine,
], Exception(
f"qscheme: {weight_qscheme} is not support in {self._get_name()}"
)
if weight_qscheme is not None:
self.register_buffer(
key + "_scale",
torch.tensor(
weight_qparams["scale"], dtype=torch.float, device=device
),
)
self.register_buffer(
key + "_zero_point",
torch.tensor(
weight_qparams["zero_point"], dtype=torch.int, device=device
),
)
if weight_qscheme == torch.per_channel_affine:
self.register_buffer(
key + "_axis",
torch.tensor(
weight_qparams["axis"], dtype=torch.int, device=device
),
)
else:
# added for TorchScriptability, not used
self.register_buffer(
key + "_axis", torch.tensor(0, dtype=torch.int, device=device)
)
setattr(self, key + "_axis_int", getattr(self, key + "_axis").item())
class LSTM(RNNBase):
"""Reference Quantized LSTM Module
    We store weight_qparams for all the weights in _flat_weights; a
    `weight_qparams_dict` that maps from weight name (e.g. weight_ih_l0)
    to the weight_qparams for that weight must be passed in.
"""
def __init__(self, *args, **kwargs):
super().__init__("LSTM", *args, **kwargs)
# Same as above, see torch/nn/modules/module.py::_forward_unimplemented
def permute_hidden( # type: ignore[override]
self,
hx: Tuple[Tensor, Tensor],
permutation: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
if permutation is None:
return hx
return _apply_permutation(hx[0], permutation), _apply_permutation(
hx[1], permutation
)
def get_expected_cell_size(
self, input: Tensor, batch_sizes: Optional[Tensor]
) -> Tuple[int, int, int]:
if batch_sizes is not None:
mini_batch = int(batch_sizes[0])
else:
mini_batch = input.size(0) if self.batch_first else input.size(1)
num_directions = 2 if self.bidirectional else 1
expected_hidden_size = (
self.num_layers * num_directions,
mini_batch,
self.hidden_size,
)
return expected_hidden_size
# In the future, we should prevent mypy from applying contravariance rules here.
# See torch/nn/modules/module.py::_forward_unimplemented
def check_forward_args( # type: ignore[override]
self,
input: Tensor,
hidden: Tuple[Tensor, Tensor],
batch_sizes: Optional[Tensor],
):
self.check_input(input, batch_sizes)
self.check_hidden_size(
hidden[0],
self.get_expected_hidden_size(input, batch_sizes),
"Expected hidden[0] size {}, got {}",
)
self.check_hidden_size(
hidden[1],
self.get_expected_cell_size(input, batch_sizes),
"Expected hidden[1] size {}, got {}",
)
def get_quantized_weight_bias_dict(self):
"""dictionary from flat_weight_name to quantized weight or (unquantized) bias
e.g.
{
"weight_ih_l0": quantized_weight,
"bias_ih_l0": unquantized_bias,
...
}
"""
quantized_weight_bias_dict = {}
for wn in self._flat_weights_names:
if hasattr(self, wn):
if wn.startswith("weight"):
weight_or_bias = get_quantized_weight(self, wn)
else:
weight_or_bias = getattr(self, wn)
else:
weight_or_bias = None
quantized_weight_bias_dict[wn] = weight_or_bias
return quantized_weight_bias_dict
def get_flat_weights(self):
flat_weights = []
for wn in self._flat_weights_names:
if hasattr(self, wn):
weight = getattr(self, wn)
if wn.startswith("weight"):
params = _get_weight_and_quantization_params(self, wn)
weight = _quantize_and_dequantize_weight(*params)
else:
weight = None
flat_weights.append(weight)
return flat_weights
def forward(self, input, hx=None): # noqa: F811
orig_input = input
# xxx: isinstance check needs to be in conditional for TorchScript to compile
batch_sizes = None
if isinstance(orig_input, PackedSequence):
input, batch_sizes, sorted_indices, unsorted_indices = input
max_batch_size = int(batch_sizes[0])
else:
batch_sizes = None
is_batched = input.dim() == 3
batch_dim = 0 if self.batch_first else 1
if not is_batched:
input = input.unsqueeze(batch_dim)
max_batch_size = input.size(0) if self.batch_first else input.size(1)
sorted_indices = None
unsorted_indices = None
if hx is None:
num_directions = 2 if self.bidirectional else 1
real_hidden_size = (
self.proj_size if self.proj_size > 0 else self.hidden_size
)
h_zeros = torch.zeros(
self.num_layers * num_directions,
max_batch_size,
real_hidden_size,
dtype=input.dtype,
device=input.device,
)
c_zeros = torch.zeros(
self.num_layers * num_directions,
max_batch_size,
self.hidden_size,
dtype=input.dtype,
device=input.device,
)
hx = (h_zeros, c_zeros)
else:
if batch_sizes is None: # If not PackedSequence input.
if is_batched: # type: ignore[possibly-undefined]
if hx[0].dim() != 3 or hx[1].dim() != 3:
msg = (
"For batched 3-D input, hx and cx should "
f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
)
raise RuntimeError(msg)
else:
if hx[0].dim() != 2 or hx[1].dim() != 2:
msg = (
"For unbatched 2-D input, hx and cx should "
f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
)
raise RuntimeError(msg)
hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
# Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
hx = self.permute_hidden(hx, sorted_indices)
self.check_forward_args(input, hx, batch_sizes)
if batch_sizes is None:
result = _VF.lstm(
input,
hx,
self.get_flat_weights(),
self.bias,
self.num_layers,
self.dropout,
self.training,
self.bidirectional,
self.batch_first,
)
else:
result = _VF.lstm(
input,
batch_sizes,
hx,
self.get_flat_weights(),
self.bias,
self.num_layers,
self.dropout,
self.training,
self.bidirectional,
)
output = result[0]
hidden = result[1:]
# xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence(
output, batch_sizes, sorted_indices, unsorted_indices
)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
if not is_batched: # type: ignore[possibly-undefined]
output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
return output, self.permute_hidden(hidden, unsorted_indices)
def _get_name(self):
return "QuantizedLSTM(Reference)"
@classmethod
def from_float(cls, mod, weight_qparams_dict):
ref_mod = cls(
mod.input_size,
mod.hidden_size,
mod.num_layers,
mod.bias,
mod.batch_first,
mod.dropout,
mod.bidirectional,
weight_qparams_dict=weight_qparams_dict,
)
for wn in mod._flat_weights_names:
setattr(ref_mod, wn, getattr(mod, wn))
return ref_mod
class GRU(RNNBase):
"""Reference Quantized GRU Module
    We store weight_qparams for all the weights in _flat_weights; a
    `weight_qparams_dict` that maps from weight name (e.g. weight_ih_l0)
    to the weight_qparams for that weight must be passed in.
"""
def __init__(self, *args, **kwargs):
if "proj_size" in kwargs:
raise ValueError(
"proj_size argument is only supported for LSTM, not RNN or GRU"
)
super().__init__("GRU", *args, **kwargs)
def get_quantized_weight_bias_dict(self):
"""dictionary from flat_weight_name to quantized weight or (unquantized) bias
e.g.
{
"weight_ih_l0": quantized_weight,
"bias_ih_l0": unquantized_bias,
...
}
"""
quantized_weight_bias_dict = {}
for wn in self._flat_weights_names:
if hasattr(self, wn):
if wn.startswith("weight"):
weight_or_bias = get_quantized_weight(self, wn)
else:
weight_or_bias = getattr(self, wn)
else:
weight_or_bias = None
quantized_weight_bias_dict[wn] = weight_or_bias
return quantized_weight_bias_dict
def get_flat_weights(self):
flat_weights = []
for wn in self._flat_weights_names:
if hasattr(self, wn):
weight = getattr(self, wn)
if wn.startswith("weight"):
params = _get_weight_and_quantization_params(self, wn)
weight = _quantize_and_dequantize_weight(*params)
else:
weight = None
flat_weights.append(weight)
return flat_weights
def forward(self, input, hx=None): # noqa: F811
# Note: this is copied from the forward of GRU in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
# only changed self._flat_weights to self.get_flat_weights()
# TODO: maybe we can try inheriting from that class and define get_flat_weights
# as a @property? this might interfere with TorchScript, if we remove that
# requirement in the future we should be able to do this
orig_input = input
# xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, PackedSequence):
input, batch_sizes, sorted_indices, unsorted_indices = input
max_batch_size = int(batch_sizes[0])
else:
batch_sizes = None
assert input.dim() in (
2,
3,
), f"GRU: Expected input to be 2-D or 3-D but received {input.dim()}-D tensor"
is_batched = input.dim() == 3
batch_dim = 0 if self.batch_first else 1
if not is_batched:
input = input.unsqueeze(batch_dim)
if hx is not None:
if hx.dim() != 2:
raise RuntimeError(
f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor"
)
hx = hx.unsqueeze(1)
else:
if hx is not None and hx.dim() != 3:
raise RuntimeError(
f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor"
)
max_batch_size = input.size(0) if self.batch_first else input.size(1)
sorted_indices = None
unsorted_indices = None
if hx is None:
num_directions = 2 if self.bidirectional else 1
hx = torch.zeros(
self.num_layers * num_directions,
max_batch_size,
self.hidden_size,
dtype=input.dtype,
device=input.device,
)
else:
# Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
hx = self.permute_hidden(hx, sorted_indices)
self.check_forward_args(input, hx, batch_sizes)
if batch_sizes is None:
result = _VF.gru(
input,
hx,
self.get_flat_weights(),
self.bias,
self.num_layers,
self.dropout,
self.training,
self.bidirectional,
self.batch_first,
)
else:
result = _VF.gru(
input,
batch_sizes,
hx,
self.get_flat_weights(),
self.bias,
self.num_layers,
self.dropout,
self.training,
self.bidirectional,
)
output = result[0]
hidden = result[1]
# xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence(
output, batch_sizes, sorted_indices, unsorted_indices
)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
if not is_batched: # type: ignore[possibly-undefined]
output = output.squeeze(batch_dim) # type: ignore[possibly-undefined]
hidden = hidden.squeeze(1)
return output, self.permute_hidden(hidden, unsorted_indices)
def _get_name(self):
return "QuantizedGRU(Reference)"
@classmethod
def from_float(cls, mod, weight_qparams_dict):
ref_mod = cls(
mod.input_size,
mod.hidden_size,
mod.num_layers,
mod.bias,
mod.batch_first,
mod.dropout,
mod.bidirectional,
weight_qparams_dict=weight_qparams_dict,
)
for wn in mod._flat_weights_names:
setattr(ref_mod, wn, getattr(mod, wn))
return ref_mod
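
# Illustrative usage sketch, assuming placeholder per-weight qparams (the scale
# and zero_point values are made up): building the `weight_qparams_dict` that
# LSTM.from_float expects -- one entry per flat weight name plus "is_decomposed".
if __name__ == "__main__":
    float_lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=1, batch_first=True)
    per_weight_qparams = {
        "qscheme": torch.per_tensor_affine,
        "dtype": torch.quint8,
        "scale": 0.1,
        "zero_point": 128,
    }
    weight_qparams_dict: Dict[str, Any] = {"is_decomposed": False}
    for name in float_lstm._flat_weights_names:
        if name.startswith("weight"):
            weight_qparams_dict[name] = per_weight_qparams
    ref_lstm = LSTM.from_float(float_lstm, weight_qparams_dict)
    out, (h, c) = ref_lstm(torch.randn(2, 5, 8))
    print(out.shape)  # torch.Size([2, 5, 16]); weights are fake-quantized on each forward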

View File

@ -0,0 +1,162 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, Optional
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from .utils import ReferenceQuantizedModule
__all__ = ["Embedding", "EmbeddingBag"]
class Embedding(nn.Embedding, ReferenceQuantizedModule):
"""A reference quantized Embedding module that fits into the
FX Graph Mode Quantization workflow, activation will be floating point Tensor,
we will store floating point weight as well in the module, but in forward we'll
quantize and dequantize the weight before running the floating point functional
embedding operator.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
_weight: Optional[Tensor] = None,
device=None,
dtype=None,
weight_qparams: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
num_embeddings,
embedding_dim,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
_weight,
device,
dtype,
)
self._init_weight_qparams(weight_qparams, device)
def _get_name(self):
return "QuantizedEmbedding(Reference)"
def forward(self, input: Tensor) -> Tensor:
weight_quant_dequant = self.get_weight()
return F.embedding(
input,
weight_quant_dequant,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
@classmethod
def from_float(cls, mod, weight_qparams):
return cls(
mod.num_embeddings,
mod.embedding_dim,
mod.padding_idx,
mod.max_norm,
mod.norm_type,
mod.scale_grad_by_freq,
mod.sparse,
mod.weight,
mod.weight.device,
mod.weight.dtype,
weight_qparams,
)
class EmbeddingBag(nn.EmbeddingBag, ReferenceQuantizedModule):
"""A reference quantized EmbeddingBag module that fits into the
FX Graph Mode Quantization workflow, activation will be floating point Tensor,
we will store floating point weight as well in the module, but in forward we'll
quantize and dequantize the weight before running the floating point functional
embedding operator.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
mode: str = "mean",
sparse: bool = False,
_weight: Optional[Tensor] = None,
include_last_offset: bool = False,
padding_idx: Optional[int] = None,
device=None,
dtype=None,
weight_qparams: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
num_embeddings,
embedding_dim,
max_norm,
norm_type,
scale_grad_by_freq,
mode,
sparse,
_weight,
include_last_offset,
padding_idx,
device,
dtype,
)
self._init_weight_qparams(weight_qparams, device)
def _get_name(self):
return "QuantizedEmbedding(Reference)"
def forward(
self,
input: Tensor,
offsets: Optional[Tensor] = None,
per_sample_weights: Optional[Tensor] = None,
) -> Tensor:
weight_quant_dequant = self.get_weight()
return F.embedding_bag(
input,
weight_quant_dequant,
offsets,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.mode,
self.sparse,
per_sample_weights,
self.include_last_offset,
self.padding_idx,
)
@classmethod
def from_float(cls, mod, weight_qparams, use_precomputed_fake_quant=False):
return cls(
mod.num_embeddings,
mod.embedding_dim,
mod.max_norm,
mod.norm_type,
mod.scale_grad_by_freq,
mod.mode,
mod.sparse,
mod.weight,
mod.include_last_offset,
mod.padding_idx,
mod.weight.device,
mod.weight.dtype,
weight_qparams,
)
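
# Illustrative usage sketch, assuming placeholder per-tensor qparams (scale and
# zero_point are made up; per-channel qparams with an "axis" key are also
# accepted by _init_weight_qparams): converting a float nn.Embedding.
if __name__ == "__main__":
    import torch

    float_emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
    weight_qparams = {
        "qscheme": torch.per_tensor_affine,
        "dtype": torch.quint8,
        "scale": 0.1,
        "zero_point": 128,
    }
    ref_emb = Embedding.from_float(float_emb, weight_qparams)
    print(ref_emb(torch.tensor([1, 2, 3])).shape)  # torch.Size([3, 4])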

View File

@ -0,0 +1,431 @@
# mypy: allow-untyped-defs
import typing
import torch
__all__ = [
"ReferenceQuantizedModule",
]
class ReferenceQuantizedModule(torch.nn.Module):
def _init_weight_qparams(self, weight_qparams, device):
if weight_qparams is None:
weight_qparams = {
"qscheme": torch.per_tensor_affine,
"dtype": torch.quint8,
"scale": 1.0,
"zero_point": 0,
}
self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
self.weight_dtype = weight_qparams["dtype"]
assert self.weight_qscheme in [
None,
torch.per_tensor_affine,
torch.per_channel_affine,
torch.per_channel_affine_float_qparams,
], f"qscheme: {self.weight_qscheme} is not support in reference quantized {self._get_name()}"
if self.weight_dtype in [
torch.quint8,
torch.qint8,
torch.quint4x2,
torch.qint32,
]:
zero_point_dtype = (
weight_qparams["zero_point"].dtype
if isinstance(weight_qparams["zero_point"], torch.Tensor)
else torch.int
)
w_scale = weight_qparams["scale"]
w_scale_tensor = (
w_scale.clone().detach()
if isinstance(w_scale, torch.Tensor)
else torch.tensor(w_scale, dtype=torch.float, device=device)
)
self.register_buffer("weight_scale", w_scale_tensor)
w_zp = weight_qparams["zero_point"]
w_zp_tensor = (
w_zp.clone().detach()
if isinstance(w_zp, torch.Tensor)
else torch.tensor(w_zp, dtype=zero_point_dtype, device=device)
)
self.register_buffer("weight_zero_point", w_zp_tensor)
if self.weight_qscheme in [
torch.per_channel_affine,
torch.per_channel_affine_float_qparams,
]:
w_axis = weight_qparams["axis"]
w_axis_tensor = (
w_axis.clone().detach()
if isinstance(w_axis, torch.Tensor)
else torch.tensor(w_axis, dtype=torch.int, device=device)
)
self.register_buffer("weight_axis", w_axis_tensor)
else:
# added for TorchScriptability, not used
self.register_buffer(
"weight_axis", torch.tensor(0, dtype=torch.int, device=device)
)
else:
# added for TorchScriptability, and for torch.float
self.register_buffer(
"weight_scale", torch.tensor(1.0, dtype=torch.float, device=device)
)
self.register_buffer(
"weight_zero_point", torch.tensor(0, dtype=torch.int, device=device)
)
self.register_buffer(
"weight_axis", torch.tensor(0, dtype=torch.int, device=device)
)
self.is_decomposed: bool = weight_qparams.get("is_decomposed", False)
# store weight_axis as weight_axis_int due to some constraints of torchdynamo.export
# for capturing `.item` operations
self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment]
self.weight_quant_min: typing.Optional[int] = weight_qparams.get(
"quant_min", None
)
self.weight_quant_max: typing.Optional[int] = weight_qparams.get(
"quant_max", None
)
def get_weight(self):
"""
        Fake quantize (quantize and dequantize) the weight with
        the quantization parameters for the weight. This is used to
        simulate the numerics of the quantized weight in a quantized
        model.
"""
# suppress mypy warning
assert isinstance(self.weight_scale, torch.Tensor)
assert isinstance(self.weight_zero_point, torch.Tensor)
if self.is_decomposed:
return _quantize_and_dequantize_weight_decomposed(
self.weight, # type: ignore[arg-type]
self.weight_qscheme,
self.weight_dtype,
self.weight_scale,
self.weight_zero_point,
self.weight_axis_int,
self.weight_quant_min,
self.weight_quant_max,
)
else:
return _quantize_and_dequantize_weight(
self.weight, # type: ignore[arg-type]
self.weight_qscheme,
self.weight_dtype,
self.weight_scale,
self.weight_zero_point,
self.weight_axis_int,
)
def get_quantized_weight(self):
# suppress mypy warning
assert isinstance(self.weight_scale, torch.Tensor)
assert isinstance(self.weight_zero_point, torch.Tensor)
# assert isinstance(self.weight_axis, torch.Tensor)
if self.is_decomposed:
return _quantize_weight_decomposed(
self.weight, # type: ignore[arg-type]
self.weight_qscheme,
self.weight_dtype,
self.weight_scale,
self.weight_zero_point,
self.weight_axis_int,
self.weight_quant_min,
self.weight_quant_max,
)
else:
return _quantize_weight(
self.weight, # type: ignore[arg-type]
self.weight_qscheme,
self.weight_dtype,
self.weight_scale,
self.weight_zero_point,
self.weight_axis_int,
)
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
_save_weight_qparams(
destination,
prefix,
self.weight_qscheme,
self.weight_dtype,
self.weight_scale,
self.weight_zero_point,
self.weight_axis,
)
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
for key in _get_weight_qparam_keys(state_dict, prefix):
setattr(self, key, state_dict[prefix + key])
state_dict.pop(prefix + key)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
False,
missing_keys,
unexpected_keys,
error_msgs,
)
def _quantize_weight_decomposed(
weight: torch.Tensor,
weight_qscheme: torch.qscheme,
weight_dtype: torch.dtype,
weight_scale: torch.Tensor,
weight_zero_point: torch.Tensor,
weight_axis: int,
weight_quant_min: typing.Optional[int],
weight_quant_max: typing.Optional[int],
) -> torch.Tensor:
_DTYPE_TO_QVALUE_BOUNDS = {
torch.uint8: (0, 255),
torch.int8: (-128, 127),
torch.int32: (-(2**31), 2**31 - 1),
}
# TODO: add an util function for converting qdtype to dtype
_QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = {
torch.quint8: torch.uint8,
torch.qint8: torch.int8,
torch.qint32: torch.int32,
}
if weight_qscheme == torch.per_tensor_affine:
if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
if weight_quant_min is None or weight_quant_max is None:
weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[
weight_dtype_
]
weight = torch.ops.quantized_decomposed.quantize_per_tensor(
weight,
weight_scale,
weight_zero_point,
weight_quant_min,
weight_quant_max,
weight_dtype_,
)
return weight
elif weight_qscheme in [
torch.per_channel_affine,
torch.per_channel_affine_float_qparams,
]:
# TODO: torch.quint4x2 is not supported
if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
if weight_quant_min is None or weight_quant_max is None:
weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[
weight_dtype_
]
weight = torch.ops.quantized_decomposed.quantize_per_channel(
weight,
weight_scale,
weight_zero_point,
weight_axis,
weight_quant_min,
weight_quant_max,
weight_dtype_,
) # type: ignore[arg-type]
return weight
raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
def _dequantize_weight_decomposed(
weight: torch.Tensor,
weight_qscheme: torch.qscheme,
weight_dtype: torch.dtype,
weight_scale: torch.Tensor,
weight_zero_point: torch.Tensor,
weight_axis: int,
weight_quant_min: typing.Optional[int],
weight_quant_max: typing.Optional[int],
) -> torch.Tensor:
# TODO: get the quant_min and quant_max from activation_post_process
_DTYPE_TO_QVALUE_BOUNDS = {
torch.uint8: (0, 255),
torch.int8: (-128, 127),
torch.int32: (-(2**31), 2**31 - 1),
}
# TODO: add an util function for converting qdtype to dtype
_QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = {
torch.quint8: torch.uint8,
torch.qint8: torch.int8,
torch.qint32: torch.int32,
}
weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
if weight_quant_min is None or weight_quant_max is None:
weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
if weight_qscheme == torch.per_tensor_affine:
if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
weight,
weight_scale,
weight_zero_point,
weight_quant_min,
weight_quant_max,
weight_dtype_,
)
return weight
elif weight_qscheme in [
torch.per_channel_affine,
torch.per_channel_affine_float_qparams,
]:
# TODO: torch.quint4x2 is not supported
if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
weight = torch.ops.quantized_decomposed.dequantize_per_channel(
weight,
weight_scale,
weight_zero_point,
weight_axis,
weight_quant_min,
weight_quant_max,
weight_dtype_,
) # type: ignore[arg-type]
return weight
raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
def _quantize_weight(
weight: torch.Tensor,
weight_qscheme: torch.qscheme,
weight_dtype: torch.dtype,
weight_scale: torch.Tensor,
weight_zero_point: torch.Tensor,
weight_axis_int: int,
) -> torch.Tensor:
if weight_dtype == torch.float16:
weight = weight.to(weight_dtype)
return weight
if weight_qscheme == torch.per_tensor_affine:
if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
weight = torch.quantize_per_tensor(
weight, weight_scale, weight_zero_point, weight_dtype
)
return weight
elif weight_qscheme in [
torch.per_channel_affine,
torch.per_channel_affine_float_qparams,
]:
if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
weight = torch.quantize_per_channel(
weight, weight_scale, weight_zero_point, weight_axis_int, weight_dtype
) # type: ignore[arg-type]
return weight
raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
def _quantize_and_dequantize_weight_decomposed(
weight: torch.Tensor,
weight_qscheme: torch.qscheme,
weight_dtype: torch.dtype,
weight_scale: torch.Tensor,
weight_zero_point: torch.Tensor,
weight_axis_int: int,
weight_quant_min: typing.Optional[int],
weight_quant_max: typing.Optional[int],
) -> torch.Tensor:
"""Quantize and then dequantize the weight based on
the quantization parameters
"""
if weight_qscheme in [
torch.per_tensor_affine,
torch.per_channel_affine,
torch.per_channel_affine_float_qparams,
]:
weight_quant = _quantize_weight_decomposed(
weight,
weight_qscheme,
weight_dtype,
weight_scale,
weight_zero_point,
weight_axis_int,
weight_quant_min,
weight_quant_max,
)
weight_dequant = _dequantize_weight_decomposed(
weight_quant,
weight_qscheme,
weight_dtype,
weight_scale,
weight_zero_point,
weight_axis_int,
weight_quant_min,
weight_quant_max,
)
else:
weight_dequant = weight
return weight_dequant
def _quantize_and_dequantize_weight(
weight: torch.Tensor,
weight_qscheme: torch.qscheme,
weight_dtype: torch.dtype,
weight_scale: torch.Tensor,
weight_zero_point: torch.Tensor,
weight_axis_int: int,
) -> torch.Tensor:
"""Quantize and then dequantize the weight based on
the quantization parameters
"""
if weight_qscheme in [
torch.per_tensor_affine,
torch.per_channel_affine,
torch.per_channel_affine_float_qparams,
]:
weight_quant = _quantize_weight(
weight,
weight_qscheme,
weight_dtype,
weight_scale,
weight_zero_point,
weight_axis_int,
)
weight_dequant = weight_quant.dequantize()
else:
weight_dequant = weight
return weight_dequant
def _save_weight_qparams(
destination,
prefix,
weight_qscheme,
weight_dtype,
weight_scale,
weight_zero_point,
weight_axis,
):
destination[prefix + "weight_qscheme"] = weight_qscheme
destination[prefix + "weight_dtype"] = weight_dtype
if weight_qscheme is not None:
destination[prefix + "weight_scale"] = weight_scale
destination[prefix + "weight_zero_point"] = weight_zero_point
if weight_qscheme == torch.per_channel_affine:
destination[prefix + "weight_axis"] = weight_axis
def _get_weight_qparam_keys(state_dict: typing.Dict[str, typing.Any], prefix: str):
keys = ["weight_qscheme", "weight_dtype"]
weight_qscheme = state_dict[prefix + "weight_qscheme"]
if weight_qscheme is not None:
keys.append("weight_scale")
keys.append("weight_zero_point")
        if weight_qscheme == torch.per_channel_affine:
keys.append("weight_axis")
return keys
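
# Illustrative usage sketch, assuming placeholder quantization parameters: calling
# _quantize_and_dequantize_weight directly to see the rounding error that fake
# quantization introduces on a weight tensor.
if __name__ == "__main__":
    w = torch.randn(4, 3)
    scale = torch.tensor(0.1)
    zero_point = torch.tensor(0, dtype=torch.int)
    w_qdq = _quantize_and_dequantize_weight(
        w, torch.per_tensor_affine, torch.qint8, scale, zero_point, weight_axis_int=0
    )
    print((w - w_qdq).abs().max())  # at most ~scale / 2 when no clamping occurs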