I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@@ -0,0 +1,40 @@
# mypy: allow-untyped-defs
from .modules import * # noqa: F403
from .modules.fused import _FusedModule # noqa: F403
# # Subpackages
# from . import qat # noqa: F403
# from . import quantized # noqa: F403
__all__ = [
"ConvBn1d",
"ConvBn2d",
"ConvBn3d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBnReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearReLU",
"BNReLU2d",
"BNReLU3d",
"LinearBn1d",
"LinearLeakyReLU",
"LinearTanh",
"ConvAdd2d",
"ConvAddReLU2d",
]
# We are exposing all subpackages to the end user.
# Because of possible inter-dependencies, we want to avoid
# cyclic imports; we therefore implement a lazy version
# as per https://peps.python.org/pep-0562/
def __getattr__(name):
if name in __all__:
import importlib
return importlib.import_module("." + name, __name__)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
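
The lazy `__getattr__` above follows PEP 562. A minimal standalone sketch of the pattern (package and submodule names here are illustrative, not part of this commit):

# mypkg/__init__.py
__all__ = ["heavy"]

def __getattr__(name):
    if name in __all__:
        import importlib
        # First attribute access triggers `import mypkg.heavy`; the import
        # system then binds the submodule as an attribute of mypkg, so this
        # hook runs at most once per name.
        return importlib.import_module("." + name, __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

Accessing `mypkg.heavy` therefore imports the submodule on first use rather than at package import time, which is what breaks the cyclic-import chains mentioned above.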

View File

@@ -0,0 +1,41 @@
from .fused import ( # noqa: F401
_FusedModule,
BNReLU2d,
BNReLU3d,
ConvAdd2d,
ConvAddReLU2d,
ConvBn1d,
ConvBn2d,
ConvBn3d,
ConvBnReLU1d,
ConvBnReLU2d,
ConvBnReLU3d,
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
LinearBn1d,
LinearLeakyReLU,
LinearReLU,
LinearTanh,
)
__all__ = [
"ConvBn1d",
"ConvBn2d",
"ConvBn3d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBnReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearReLU",
"BNReLU2d",
"BNReLU3d",
"LinearBn1d",
"LinearLeakyReLU",
"LinearTanh",
"ConvAdd2d",
"ConvAddReLU2d",
]

View File

@@ -0,0 +1,245 @@
# mypy: allow-untyped-defs
import torch
from torch.nn import (
BatchNorm1d,
BatchNorm2d,
BatchNorm3d,
Conv1d,
Conv2d,
Conv3d,
Linear,
ReLU,
)
from torch.nn.utils.parametrize import type_before_parametrizations
__all__ = [
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearReLU",
"ConvBn1d",
"ConvBn2d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBn3d",
"ConvBnReLU3d",
"BNReLU2d",
"BNReLU3d",
"LinearBn1d",
"LinearLeakyReLU",
"LinearTanh",
"ConvAdd2d",
"ConvAddReLU2d",
]
# Used for identifying intrinsic modules used in quantization
class _FusedModule(torch.nn.Sequential):
pass
class ConvReLU1d(_FusedModule):
r"""This is a sequential container which calls the Conv1d and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, relu):
assert (
type_before_parametrizations(conv) == Conv1d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules: {type_before_parametrizations(conv)}, {type_before_parametrizations(relu)}"
super().__init__(conv, relu)
class ConvReLU2d(_FusedModule):
r"""This is a sequential container which calls the Conv2d and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, relu):
assert (
type_before_parametrizations(conv) == Conv2d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules: {type_before_parametrizations(conv)}, {type_before_parametrizations(relu)}"
super().__init__(conv, relu)
class ConvReLU3d(_FusedModule):
r"""This is a sequential container which calls the Conv3d and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, relu):
assert (
type_before_parametrizations(conv) == Conv3d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules: {type_before_parametrizations(conv)}, {type_before_parametrizations(relu)}"
super().__init__(conv, relu)
class LinearReLU(_FusedModule):
r"""This is a sequential container which calls the Linear and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, linear, relu):
assert (
type_before_parametrizations(linear) == Linear
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules: {type_before_parametrizations(linear)}, {type_before_parametrizations(relu)}"
super().__init__(linear, relu)
class ConvBn1d(_FusedModule):
r"""This is a sequential container which calls the Conv1d and BatchNorm1d modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn):
assert (
type_before_parametrizations(conv) == Conv1d
and type_before_parametrizations(bn) == BatchNorm1d
), f"Incorrect types for input modules: {type_before_parametrizations(conv)}, {type_before_parametrizations(bn)}"
super().__init__(conv, bn)
class ConvBn2d(_FusedModule):
r"""This is a sequential container which calls the Conv2d and BatchNorm2d modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn):
assert (
type_before_parametrizations(conv) == Conv2d
and type_before_parametrizations(bn) == BatchNorm2d
), f"Incorrect types for input modules: {type_before_parametrizations(conv)}, {type_before_parametrizations(bn)}"
super().__init__(conv, bn)
class ConvBnReLU1d(_FusedModule):
r"""This is a sequential container which calls the Conv1d, BatchNorm1d, and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn, relu):
assert (
type_before_parametrizations(conv) == Conv1d
and type_before_parametrizations(bn) == BatchNorm1d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules: {type_before_parametrizations(conv)}, {type_before_parametrizations(bn)}, {type_before_parametrizations(relu)}"  # noqa: B950
super().__init__(conv, bn, relu)
class ConvBnReLU2d(_FusedModule):
r"""This is a sequential container which calls the Conv2d, BatchNorm2d, and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn, relu):
assert (
type_before_parametrizations(conv) == Conv2d
and type_before_parametrizations(bn) == BatchNorm2d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules: {type_before_parametrizations(conv)}, {type_before_parametrizations(bn)}, {type_before_parametrizations(relu)}"  # noqa: B950
super().__init__(conv, bn, relu)
class ConvBn3d(_FusedModule):
r"""This is a sequential container which calls the Conv3d and BatchNorm3d modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn):
assert (
type_before_parametrizations(conv) == Conv3d
and type_before_parametrizations(bn) == BatchNorm3d
), f"Incorrect types for input modules: {type_before_parametrizations(conv)}, {type_before_parametrizations(bn)}"
super().__init__(conv, bn)
class ConvBnReLU3d(_FusedModule):
r"""This is a sequential container which calls the Conv3d, BatchNorm3d, and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, bn, relu):
assert (
type_before_parametrizations(conv) == Conv3d
and type_before_parametrizations(bn) == BatchNorm3d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules: {type_before_parametrizations(conv)}, {type_before_parametrizations(bn)}, {type_before_parametrizations(relu)}"  # noqa: B950
super().__init__(conv, bn, relu)
class BNReLU2d(_FusedModule):
r"""This is a sequential container which calls the BatchNorm2d and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, batch_norm, relu):
assert (
type_before_parametrizations(batch_norm) == BatchNorm2d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules: {type_before_parametrizations(batch_norm)}, {type_before_parametrizations(relu)}"
super().__init__(batch_norm, relu)
class BNReLU3d(_FusedModule):
r"""This is a sequential container which calls the BatchNorm3d and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, batch_norm, relu):
assert (
type_before_parametrizations(batch_norm) == BatchNorm3d
and type_before_parametrizations(relu) == ReLU
), f"Incorrect types for input modules: {type_before_parametrizations(batch_norm)}, {type_before_parametrizations(relu)}"
super().__init__(batch_norm, relu)
class LinearBn1d(_FusedModule):
r"""This is a sequential container which calls the Linear and BatchNorm1d modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, linear, bn):
assert (
type_before_parametrizations(linear) == Linear
and type_before_parametrizations(bn) == BatchNorm1d
), f"Incorrect types for input modules: {type_before_parametrizations(linear)}, {type_before_parametrizations(bn)}"
super().__init__(linear, bn)
class LinearLeakyReLU(_FusedModule):
r"""This is a sequential container which calls the Linear and LeakyReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, linear, leaky_relu):
assert (
type(linear) == Linear and type(leaky_relu) == torch.nn.LeakyReLU
), f"Incorrect types for input modules: {type(linear)}, {type(leaky_relu)}"
super().__init__(linear, leaky_relu)
class LinearTanh(_FusedModule):
r"""This is a sequential container which calls the Linear and Tanh modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, linear, tanh):
assert (
type(linear) == Linear and type(tanh) == torch.nn.Tanh
), f"Incorrect types for input modules: {type(linear)}, {type(tanh)}"
super().__init__(linear, tanh)
class ConvAdd2d(_FusedModule):
r"""This is a sequential container which calls the Conv2d module and adds an extra input to its output.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, add):
super().__init__(conv)
self.add = add
def forward(self, x1, x2):
return self.add(self[0](x1), x2)
class ConvAddReLU2d(_FusedModule):
r"""This is a sequential container which calls the Conv2d, Add, and ReLU modules.
During quantization this will be replaced with the corresponding fused module."""
def __init__(self, conv, add, relu):
super().__init__(conv)
self.add = add
self.relu = relu
def forward(self, x1, x2):
return self.relu(self.add(self[0](x1), x2))
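
A short usage sketch for these containers (hedged: assumes a recent torch build; the import path follows this package):

import torch
import torch.nn as nn
from torch.ao.nn.intrinsic import ConvAdd2d, ConvReLU2d

# Two-module fusion behaves exactly like nn.Sequential(conv, relu)
fused = ConvReLU2d(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
y = fused(torch.randn(1, 3, 32, 32))

# ConvAdd2d stores a callable `add` and takes the residual as a second
# forward() argument; with padding=1 the shapes line up for the add
conv_add = ConvAdd2d(nn.Conv2d(3, 8, kernel_size=3, padding=1), torch.add)
z = conv_add(torch.randn(1, 3, 32, 32), torch.randn(1, 8, 32, 32))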

View File

@@ -0,0 +1 @@
from .modules import * # noqa: F403

View File

@@ -0,0 +1,32 @@
from .conv_fused import (
ConvBn1d,
ConvBn2d,
ConvBn3d,
ConvBnReLU1d,
ConvBnReLU2d,
ConvBnReLU3d,
ConvReLU1d,
ConvReLU2d,
ConvReLU3d,
freeze_bn_stats,
update_bn_stats,
)
from .linear_fused import LinearBn1d
from .linear_relu import LinearReLU
__all__ = [
"LinearReLU",
"LinearBn1d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"ConvBn1d",
"ConvBn2d",
"ConvBn3d",
"ConvBnReLU1d",
"ConvBnReLU2d",
"ConvBnReLU3d",
"update_bn_stats",
"freeze_bn_stats",
]

File diff suppressed because it is too large

View File

@@ -0,0 +1,193 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.utils.fusion import fuse_linear_bn_weights
__all__ = [
"LinearBn1d",
]
class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
r"""
A LinearBn1d module is a module fused from Linear and BatchNorm1d, attached
with FakeQuantize modules for weight, used in quantization aware training.
We combine the interfaces of :class:`torch.nn.Linear` and
:class:`torch.nn.BatchNorm1d`.
Similar to :class:`torch.nn.Linear`, with FakeQuantize modules initialized
to default.
Attributes:
freeze_bn: whether the BatchNorm running statistics are frozen during training
weight_fake_quant: fake quant module for weight
"""
def __init__(
self,
# Linear args
in_features,
out_features,
bias=True,
# BatchNorm1d args
# num_features: out_features
eps=1e-05,
momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None,
):
nn.modules.linear.Linear.__init__(self, in_features, out_features, bias)
assert qconfig, "qconfig must be provided for QAT module"
self.qconfig = qconfig
self.freeze_bn = freeze_bn if self.training else True
self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True)
self.weight_fake_quant = self.qconfig.weight()
if bias:
self.bias = Parameter(torch.empty(out_features))
else:
self.register_parameter("bias", None)
self.reset_bn_parameters()
# this needs to be called after reset_bn_parameters,
# as they modify the same state
if self.training:
if freeze_bn:
self.freeze_bn_stats()
else:
self.update_bn_stats()
else:
self.freeze_bn_stats()
def reset_running_stats(self):
self.bn.reset_running_stats()
def reset_bn_parameters(self):
self.bn.reset_running_stats()
init.uniform_(self.bn.weight)
init.zeros_(self.bn.bias)
def reset_parameters(self):
super().reset_parameters()
def update_bn_stats(self):
self.freeze_bn = False
self.bn.training = True
return self
def freeze_bn_stats(self):
self.freeze_bn = True
self.bn.training = False
return self
def forward(self, input):
assert self.bn.running_var is not None
# Scale the linear weights by BN's running statistics to reduce
# weight jitter, see https://arxiv.org/pdf/1806.08342.pdf, page 18
# for motivation.
#
# Instead of
#
# x1 = F.linear(x0, fq(w), b)
# x2 = self.bn(x1)
#
# We have
#
# # scale the weight by previous batch's running statistics
# scale_factor = bn.w / bn.running_std_from_prev_batch
# # do the linear transformation without bias
# x1_scaled = F.linear(x0, fq(w * scale_factor), 0)
# # reverse the scaling and add original bias
# x1_orig = x1_scaled / scale_factor + b
# x2 = self.bn(x1_orig)
running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
scale_factor = self.bn.weight / running_std
weight_shape = [1] * len(self.weight.shape)
weight_shape[0] = -1
bias_shape = [1] * len(self.weight.shape)
bias_shape[1] = -1
scaled_weight = self.weight_fake_quant(
self.weight * scale_factor.reshape(weight_shape)
)
if self.bias is not None:
zero_bias = torch.zeros_like(self.bias)
else:
zero_bias = torch.zeros(self.out_features, device=scaled_weight.device)
linear_out = F.linear(input, scaled_weight, zero_bias)
linear_out_orig = linear_out / scale_factor.reshape(bias_shape)
if self.bias is not None:
linear_out_orig = linear_out_orig + self.bias.reshape(bias_shape)
bn_out = self.bn(linear_out_orig)
return bn_out
def train(self, mode=True):
"""
BatchNorm's training behavior depends on the self.training flag. Prevent
changing it if BN is frozen. This makes sure that calling `model.train()`
on a model with a frozen BN behaves properly.
"""
self.training = mode
if not self.freeze_bn:
for module in self.children():
module.train(mode)
return self
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.ao.quantization
utilities or directly from the user
"""
assert type(mod) == nni.LinearBn1d, (
"qat."
+ cls.__name__
+ ".from_float only works for "
+ nni.LinearBn1d.__name__
)
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
assert mod.qconfig, "Input float module must have a valid qconfig"
qconfig = mod.qconfig
linear, bn = mod[0], mod[1]
qat_linearbn = cls(
linear.in_features,
linear.out_features,
linear.bias is not None,
bn.eps,
bn.momentum,
False,
qconfig,
)
qat_linearbn.weight = linear.weight
qat_linearbn.bias = linear.bias
qat_linearbn.bn.weight = bn.weight
qat_linearbn.bn.bias = bn.bias
qat_linearbn.bn.running_mean = bn.running_mean
qat_linearbn.bn.running_var = bn.running_var
qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
return qat_linearbn
def to_float(self):
linear = torch.nn.Linear(self.in_features, self.out_features)
assert self.bn.running_var is not None and self.bn.running_mean is not None
linear.weight, linear.bias = fuse_linear_bn_weights(
self.weight,
self.bias,
self.bn.running_mean,
self.bn.running_var,
self.bn.eps,
self.bn.weight,
self.bn.bias,
)
return linear
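
The forward() above leans on a simple identity: scaling each output row of the weight and then dividing the pre-activation by the same factor is a no-op. A minimal pure-tensor check of that identity (no QAT modules involved):

import torch
import torch.nn.functional as F

x, w, b = torch.randn(4, 10), torch.randn(6, 10), torch.randn(6)
gamma = torch.randn(6)
running_std = torch.rand(6) + 0.5          # stand-in for sqrt(running_var + eps)
scale_factor = gamma / running_std

# scale the weight rows, run linear without bias, then undo the scaling
y = F.linear(x, w * scale_factor.reshape(-1, 1))
y = y / scale_factor.reshape(1, -1) + b

assert torch.allclose(y, F.linear(x, w, b), atol=1e-5)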

View File

@@ -0,0 +1,51 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.nn.functional as F
class LinearReLU(nnqat.Linear, nni._FusedModule):
r"""
A LinearReLU module fused from Linear and ReLU modules, attached with
FakeQuantize modules for weight, used in
quantization aware training.
We adopt the same interface as :class:`torch.nn.Linear`.
Similar to `torch.ao.nn.intrinsic.LinearReLU`, with FakeQuantize modules initialized to
default.
Attributes:
weight: fake quant module for weight
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.qat.LinearReLU(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment]
def __init__(self, in_features, out_features, bias=True, qconfig=None):
super().__init__(in_features, out_features, bias, qconfig)
def forward(self, input):
return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant)
def to_float(self):
linear = torch.nn.Linear(
self.in_features, self.out_features, self.bias is not None
)
linear.weight = torch.nn.Parameter(self.weight.detach())
if self.bias is not None:
linear.bias = torch.nn.Parameter(self.bias.detach())
relu = torch.nn.ReLU()
return torch.ao.nn.intrinsic.LinearReLU(linear, relu)
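
A hedged round-trip sketch for this module (get_default_qat_qconfig is the stock torch.ao.quantization helper; the particular qconfig choice is an assumption):

import torch
import torch.nn as nn
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
from torch.ao.quantization import get_default_qat_qconfig

float_mod = nni.LinearReLU(nn.Linear(20, 30), nn.ReLU())
float_mod.qconfig = get_default_qat_qconfig("fbgemm")
qat_mod = nniqat.LinearReLU.from_float(float_mod)  # attaches weight fake-quant
out = qat_mod(torch.randn(128, 20))                # relu(linear(x, fq(w), b))
back = qat_mod.to_float()                          # an intrinsic LinearReLU again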

View File

@@ -0,0 +1,15 @@
from .modules import * # noqa: F403
__all__ = [
"BNReLU2d",
"BNReLU3d",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"LinearReLU",
"LinearLeakyReLU",
"LinearTanh",
"ConvAdd2d",
"ConvAddReLU2d",
]

View File

@@ -0,0 +1 @@
from .modules import * # noqa: F403

View File

@@ -0,0 +1,6 @@
from .linear_relu import LinearReLU
__all__ = [
"LinearReLU",
]

View File

@@ -0,0 +1,60 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized.dynamic as nnqd
__all__ = ["LinearReLU"]
class LinearReLU(nnqd.Linear):
r"""
A LinearReLU module fused from Linear and ReLU modules that can be used
for dynamic quantization.
Supports both FP16 and INT8 quantization.
We adopt the same interface as :class:`torch.ao.nn.quantized.dynamic.Linear`.
Attributes:
Same as torch.ao.nn.quantized.dynamic.Linear
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.intrinsic.quantized.dynamic.LinearReLU(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment]
def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
super().__init__(in_features, out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self._packed_params.dtype == torch.qint8:
# TODO check if we should set reduce_range = True by default here
Y = torch.ops.quantized.linear_relu_dynamic(
x, self._packed_params._packed_params, reduce_range=True
)
elif self._packed_params.dtype == torch.float16:
Y = torch.ops.quantized.linear_relu_dynamic_fp16(
x, self._packed_params._packed_params
)
else:
raise RuntimeError("Unsupported dtype on dynamic quantized linear relu!")
return Y.to(x.dtype)
def _get_name(self):
return "DynamicQuantizedLinearReLU"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, ref_qlinear_relu):
return super().from_reference(ref_qlinear_relu[0])
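
A hedged sketch of reaching this module through eager-mode dynamic quantization (this assumes the default dynamic mapping covers the intrinsic LinearReLU, which it does in current releases):

import torch
import torch.nn as nn
import torch.ao.nn.intrinsic as nni
from torch.ao.quantization import quantize_dynamic

model = nn.Sequential(nni.LinearReLU(nn.Linear(20, 30), nn.ReLU()))
qmodel = quantize_dynamic(model, {nni.LinearReLU}, dtype=torch.qint8)
out = qmodel(torch.randn(128, 20))  # dispatches to quantized.linear_relu_dynamic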

View File

@@ -0,0 +1,18 @@
from .bn_relu import BNReLU2d, BNReLU3d
from .conv_add import ConvAdd2d, ConvAddReLU2d
from .conv_relu import ConvReLU1d, ConvReLU2d, ConvReLU3d
from .linear_relu import LinearLeakyReLU, LinearReLU, LinearTanh
__all__ = [
"LinearReLU",
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
"BNReLU2d",
"BNReLU3d",
"LinearLeakyReLU",
"LinearTanh",
"ConvAdd2d",
"ConvAddReLU2d",
]

View File

@@ -0,0 +1,105 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq
__all__ = ["BNReLU2d", "BNReLU3d"]
class BNReLU2d(nnq.BatchNorm2d):
r"""
A BNReLU2d module is a fused module of BatchNorm2d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm2d`.
Attributes:
Same as torch.ao.nn.quantized.BatchNorm2d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU2d
def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
super().__init__(
num_features, eps=eps, momentum=momentum, device=device, dtype=dtype
)
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
return torch.ops.quantized.batch_norm2d_relu(
input,
self.weight,
self.bias,
self.running_mean,
self.running_var,
self.eps,
self.scale,
self.zero_point,
)
def _get_name(self):
return "QuantizedBNReLU2d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
# TODO: Add qat support for BNReLU2d
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, bn_relu, output_scale, output_zero_point):
return super().from_reference(bn_relu[0], output_scale, output_zero_point)
class BNReLU3d(nnq.BatchNorm3d):
r"""
A BNReLU3d module is a fused module of BatchNorm3d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm3d`.
Attributes:
Same as torch.ao.nn.quantized.BatchNorm3d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU3d
def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
super().__init__(
num_features, eps=eps, momentum=momentum, device=device, dtype=dtype
)
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 5:
raise ValueError("Input shape must be `(N, C, D, H, W)`!")
return torch.ops.quantized.batch_norm3d_relu(
input,
self.weight,
self.bias,
self.running_mean,
self.running_var,
self.eps,
self.scale,
self.zero_point,
)
def _get_name(self):
return "QuantizedBNReLU3d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
# TODO: Add qat support for BNReLU3d
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, bn_relu, output_scale, output_zero_point):
return super().from_reference(bn_relu[0], output_scale, output_zero_point)
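
A direct inference sketch (hedged: a freshly constructed module carries default scale/zero_point buffers, so the numbers are only illustrative):

import torch
from torch.ao.nn.intrinsic.quantized import BNReLU2d

m = BNReLU2d(3)
qx = torch.quantize_per_tensor(
    torch.randn(1, 3, 8, 8), scale=0.1, zero_point=128, dtype=torch.quint8
)
qy = m(qx)             # one fused quantized batch_norm2d_relu kernel
y = qy.dequantize()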

View File

@@ -0,0 +1,145 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq
import torch.nn.functional as F
_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
class ConvAdd2d(nnq.Conv2d):
r"""
A ConvAdd2d module is a fused module of Conv2d and Add
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
Attributes:
Same as torch.ao.nn.quantized.Conv2d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAdd2d # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
def forward(self, input, extra_input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
if self.padding_mode != "zeros":
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return torch.ops.quantized.conv2d_add(
input, extra_input, self._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedConvAdd2d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
class ConvAddReLU2d(nnq.Conv2d):
r"""
A ConvAddReLU2d module is a fused module of Conv2d, Add, and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
Attributes:
Same as torch.ao.nn.quantized.Conv2d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvAddReLU2d # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
def forward(self, input, extra_input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
if self.padding_mode != "zeros":
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return torch.ops.quantized.conv2d_add_relu(
input, extra_input, self._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedConvAddReLU2d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
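
The underlying conv2d_add kernels are provided by the onednn backend, so a usage sketch needs the quantized engine switched first (hedged; backend availability varies by build):

import torch
from torch.ao.nn.intrinsic.quantized import ConvAdd2d

torch.backends.quantized.engine = "onednn"  # conv2d_add is onednn-only

def q(t):
    return torch.quantize_per_tensor(t, scale=0.1, zero_point=0, dtype=torch.quint8)

m = ConvAdd2d(3, 8, kernel_size=3, padding=1)
# forward() takes the conv input and the residual to add, in that order
qy = m(q(torch.randn(1, 3, 16, 16)), q(torch.randn(1, 8, 16, 16)))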

View File

@@ -0,0 +1,263 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq
import torch.nn.functional as F
from torch.nn.utils import fuse_conv_bn_weights
__all__ = [
"ConvReLU1d",
"ConvReLU2d",
"ConvReLU3d",
]
_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
# TODO: factor out the common parts to ConvNd
class ConvReLU1d(nnq.Conv1d):
r"""
A ConvReLU1d module is a fused module of Conv1d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`.
Attributes:
Same as torch.ao.nn.quantized.Conv1d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 3:
raise ValueError("Input shape must be `(N, C, L)`!")
if self.padding_mode != "zeros":
# Padding in Conv1d is stored as (p, p), need to get (p,)
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return torch.ops.quantized.conv1d_relu(
input, self._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedConvReLU1d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight,
mod.bias,
mod.bn.running_mean,
mod.bn.running_var,
mod.bn.eps,
mod.bn.weight,
mod.bn.bias,
)
return super().from_float(mod, use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
assert (
type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d
), "BatchNorm1d should be fused into Conv1d before converting to reference module"
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
class ConvReLU2d(nnq.Conv2d):
r"""
A ConvReLU2d module is a fused module of Conv2d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.
Attributes:
Same as torch.ao.nn.quantized.Conv2d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 4:
raise ValueError("Input shape must be `(N, C, H, W)`!")
if self.padding_mode != "zeros":
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return torch.ops.quantized.conv2d_relu(
input, self._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedConvReLU2d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight,
mod.bias,
mod.bn.running_mean,
mod.bn.running_var,
mod.bn.eps,
mod.bn.weight,
mod.bn.bias,
)
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
assert (
type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d
), "BatchNorm2d should be fused into Conv2d before converting to reference module"
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
class ConvReLU3d(nnq.Conv3d):
r"""
A ConvReLU3d module is a fused module of Conv3d and ReLU
We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`.
Attributes:
Same as torch.ao.nn.quantized.Conv3d
"""
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d # type: ignore[assignment]
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode="zeros",
device=None,
dtype=None,
):
assert padding_mode != "reflect", "Conv3d does not support reflection padding"
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
# https://github.com/pytorch/pytorch/issues/23890
if len(input.shape) != 5:
raise ValueError("Input shape must be `(N, C, D, H, W)`!")
if self.padding_mode != "zeros":
_reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
input = F.pad(
input, _reversed_padding_repeated_twice, mode=self.padding_mode
)
return torch.ops.quantized.conv3d_relu(
input, self._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedConvReLU3d"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight,
mod.bias,
mod.bn.running_mean,
mod.bn.running_var,
mod.bn.eps,
mod.bn.weight,
mod.bn.bias,
)
return super().from_float(
mod, use_precomputed_fake_quant=use_precomputed_fake_quant
)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
assert (
type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d
), "BatchNorm3d should be fused into Conv3d before converting to reference module"
return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
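
These quantized modules are normally produced by the standard eager-mode flow rather than constructed by hand; a hedged end-to-end sketch using the stock torch.ao.quantization helpers:

import torch
import torch.nn as nn
from torch.ao.quantization import (
    DeQuantStub,
    QuantStub,
    convert,
    fuse_modules,
    get_default_qconfig,
    prepare,
)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant, self.dequant = QuantStub(), DeQuantStub()
        self.conv, self.relu = nn.Conv2d(3, 8, 3), nn.ReLU()

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))

m = M().eval()
m.qconfig = get_default_qconfig("fbgemm")
fuse_modules(m, [["conv", "relu"]], inplace=True)  # -> intrinsic ConvReLU2d
prepare(m, inplace=True)                           # attach observers
m(torch.randn(1, 3, 32, 32))                       # calibrate
convert(m, inplace=True)                           # -> QuantizedConvReLU2d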

View File

@@ -0,0 +1,187 @@
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized as nnq
from torch.ao.nn.quantized.modules.utils import _quantize_weight
__all__ = [
"LinearReLU",
"LinearLeakyReLU",
"LinearTanh",
]
class LinearReLU(nnq.Linear):
r"""
A LinearReLU module fused from Linear and ReLU modules
We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
Attributes:
Same as torch.ao.nn.quantized.Linear
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.intrinsic.LinearReLU(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_FLOAT_MODULE = nni.LinearReLU # type: ignore[assignment]
def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
super().__init__(in_features, out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.quantized.linear_relu(
x, self._packed_params._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedLinearReLU"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_linear_relu, output_scale, output_zero_point):
return super().from_reference(
ref_linear_relu[0], output_scale, output_zero_point
)
class LinearLeakyReLU(nnq.Linear):
r"""
For onednn backend only
A LinearLeakyReLU module fused from Linear and LeakyReLU modules
We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
Attributes:
Same as torch.ao.nn.quantized.Linear
+ negative_slope
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.intrinsic.LinearLeakyReLU(20, 30, 0.01)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_FLOAT_MODULE = nni.LinearLeakyReLU # type: ignore[assignment]
def __init__(
self, in_features, out_features, negative_slope, bias=True, dtype=torch.qint8
):
super().__init__(in_features, out_features, bias, dtype)
self.negative_slope = negative_slope
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.quantized.linear_leaky_relu(
x,
self._packed_params._packed_params,
self.scale,
self.zero_point,
self.negative_slope,
)
def _get_name(self):
return "QuantizedLinearLeakyReLU"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
assert (
type(mod) == nni.LinearLeakyReLU
), "Input float module should be LinearLeakyReLU"
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
activation_post_process = mod.activation_post_process
leaky_relu = mod[1]
mod = mod[0]
weight_post_process = mod.qconfig.weight()
weight_post_process(mod.weight)
dtype = weight_post_process.dtype
act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator]
assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
qlinear_leaky_relu = cls(
mod.in_features, mod.out_features, leaky_relu.negative_slope, dtype=dtype
)
qlinear_leaky_relu.set_weight_bias(qweight, mod.bias)
qlinear_leaky_relu.scale = float(act_scale)
qlinear_leaky_relu.zero_point = int(act_zp)
return qlinear_leaky_relu
@classmethod
def from_reference(cls, ref_mod, output_scale, output_zero_point):
linear = ref_mod[0]
leaky_relu = ref_mod[1]
qlinear_leaky_relu = cls(
linear.in_features, linear.out_features, leaky_relu.negative_slope
)
qweight = linear.get_quantized_weight()
qlinear_leaky_relu.set_weight_bias(qweight, linear.bias)
qlinear_leaky_relu.scale = float(output_scale)
qlinear_leaky_relu.zero_point = int(output_zero_point)
return qlinear_leaky_relu
class LinearTanh(nnq.Linear):
r"""
A LinearTanh module fused from Linear and Tanh modules
We adopt the same interface as :class:`torch.ao.nn.quantized.Linear`.
Attributes:
Same as torch.ao.nn.quantized.Linear
Examples::
>>> # xdoctest: +SKIP
>>> m = nn.intrinsic.LinearTanh(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_FLOAT_MODULE = nni.LinearTanh # type: ignore[assignment]
def __init__(self, in_features, out_features, bias=True, dtype=torch.qint8):
super().__init__(in_features, out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.quantized.linear_tanh(
x, self._packed_params._packed_params, self.scale, self.zero_point
)
def _get_name(self):
return "QuantizedLinearTanh"
@classmethod
def from_float(cls, mod, use_precomputed_fake_quant=False):
assert type(mod) == nni.LinearTanh, "Input float module should be LinearTanh"
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
activation_post_process = mod.activation_post_process
mod = mod[0]
weight_post_process = mod.qconfig.weight()
weight_post_process(mod.weight)
dtype = weight_post_process.dtype
act_scale, act_zp = activation_post_process.calculate_qparams() # type: ignore[union-attr,operator]
assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
qlinear_tanh = cls(mod.in_features, mod.out_features, dtype=dtype)
qlinear_tanh.set_weight_bias(qweight, mod.bias)
qlinear_tanh.scale = float(act_scale)
qlinear_tanh.zero_point = int(act_zp)
return qlinear_tanh
@classmethod
def from_reference(cls, ref_mod, output_scale, output_zero_point):
linear = ref_mod[0]
qlinear_tanh = cls(linear.in_features, linear.out_features)
qweight = linear.get_quantized_weight()
qlinear_tanh.set_weight_bias(qweight, linear.bias)
qlinear_tanh.scale = float(output_scale)
qlinear_tanh.zero_point = int(output_zero_point)
return qlinear_tanh
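
As with the conv variants above, a freshly constructed module runs with freshly initialized packed weights (values unspecified), which is enough for a shape-level smoke test (hedged sketch):

import torch
from torch.ao.nn.intrinsic.quantized import LinearReLU

m = LinearReLU(20, 30)  # packed int8 params with default qparams (scale=1.0, zero_point=0)
qx = torch.quantize_per_tensor(
    torch.randn(4, 20), scale=0.05, zero_point=128, dtype=torch.quint8
)
qy = m(qx)              # one fused quantized linear_relu kernel
print(qy.dequantize().shape)  # torch.Size([4, 30])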