I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1 @@
from .modules import * # noqa: F403


@@ -0,0 +1 @@
from .modules import * # noqa: F403


@@ -0,0 +1,4 @@
from .linear import Linear

__all__ = ["Linear"]


@@ -0,0 +1,35 @@
# mypy: allow-untyped-defs
import torch

__all__ = ["Linear"]


class Linear(torch.ao.nn.qat.Linear):
    r"""
    A linear module attached with FakeQuantize modules for weight,
    used for dynamic quantization aware training.

    We adopt the same interface as `torch.nn.Linear`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear
    for documentation.

    Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to
    default.
    """

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, qconfig, device, dtype)
        if not torch.ao.quantization.qconfig._activation_is_memoryless(qconfig):
            raise ValueError(
                "Dynamic QAT requires a memoryless observer. "
                + "This means a MovingAverage observer with averaging constant equal to 1"
            )
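
A minimal usage sketch for the dynamic QAT Linear above. The import path and
the default_dynamic_qat_qconfig name are assumptions not shown in this diff;
the qconfig must pair a memoryless (averaging constant 1) activation observer
with a weight fake-quant, which is exactly what the __init__ check enforces.

import torch
from torch.ao.nn.qat.dynamic import Linear  # assumed export path for this file
from torch.ao.quantization.qconfig import default_dynamic_qat_qconfig  # assumed name

# A memoryless activation observer is required, or __init__ raises ValueError.
qat_linear = Linear(16, 8, qconfig=default_dynamic_qat_qconfig)
out = qat_linear(torch.randn(4, 16))  # the weight is fake-quantized each forward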


@@ -0,0 +1,13 @@
from .conv import Conv1d, Conv2d, Conv3d
from .embedding_ops import Embedding, EmbeddingBag
from .linear import Linear

__all__ = [
    "Linear",
    "Conv1d",
    "Conv2d",
    "Conv3d",
    "Embedding",
    "EmbeddingBag",
]


@@ -0,0 +1,310 @@
# mypy: allow-untyped-defs
from typing import Tuple, TypeVar, Union

import torch
import torch.nn as nn
from torch.ao.nn.intrinsic import _FusedModule
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from torch.nn.modules.utils import _pair, _single, _triple

__all__ = ["Conv1d", "Conv2d", "Conv3d"]

MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd)


class _ConvNd(nn.modules.conv._ConvNd):
    _FLOAT_MODULE = MOD

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, ...],
        stride: Tuple[int, ...],
        padding: Tuple[int, ...],
        dilation: Tuple[int, ...],
        transposed: bool,
        output_padding: Tuple[int, ...],
        groups: int,
        bias: bool,
        padding_mode: str,
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        nn.modules.conv._ConvNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input):
        return self._conv_forward(
            input, self.weight_fake_quant(self.weight), self.bias
        )

    # Note: a staticmethod that takes `cls` explicitly; the Conv1d/2d/3d
    # classmethods below call it as super().from_float(cls, mod, ...).
    @staticmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a qat module from a float module

        Args:
            `mod`: a float module, either produced by torch.ao.quantization
            utilities or directly from user
        """
        assert type(mod) == cls._FLOAT_MODULE, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        if issubclass(type(mod), _FusedModule):
            mod = mod[0]  # type: ignore[index]
        qconfig = mod.qconfig
        qat_conv = cls(
            mod.in_channels,
            mod.out_channels,
            mod.kernel_size,
            stride=mod.stride,
            padding=mod.padding,
            dilation=mod.dilation,
            groups=mod.groups,
            bias=mod.bias is not None,
            padding_mode=mod.padding_mode,
            qconfig=qconfig,
        )
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv

    def to_float(self):
        """Convert the qat module to a floating point module.

        Works for both a standalone qat conv and a fused qat conv-relu module.
        """
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined, operator]
            self.in_channels,
            self.out_channels,
            self.kernel_size,  # type: ignore[arg-type]
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.dilation,  # type: ignore[arg-type]
            self.groups,
            self.bias is not None,
            self.padding_mode,
        )
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())
        # conv relu
        if issubclass(cls, _FusedModule):
            modules = [conv]
            assert hasattr(cls, "_FLOAT_RELU_MODULE")
            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
            modules.append(relu)
            fused = cls._FLOAT_MODULE(*modules)  # type: ignore[arg-type, attr-defined, operator]
            fused.train(self.training)
            return fused
        else:
            return conv


class Conv1d(_ConvNd, nn.Conv1d):
    r"""
    A Conv1d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as :class:`~torch.nn.Conv1d`.

    Similar to :class:`~torch.nn.Conv1d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE = nn.Conv1d
    _FLOAT_CONV_MODULE = nn.Conv1d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        kernel_size_ = _single(kernel_size)
        stride_ = _single(stride)
        padding_ = padding if isinstance(padding, str) else _single(padding)
        dilation_ = _single(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_single(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class Conv2d(_ConvNd, nn.Conv2d):
    r"""
    A Conv2d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv2d`, please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d
    for documentation.

    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE = nn.Conv2d
    _FLOAT_CONV_MODULE = nn.Conv2d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        kernel_size_ = _pair(kernel_size)
        stride_ = _pair(stride)
        padding_ = padding if isinstance(padding, str) else _pair(padding)
        dilation_ = _pair(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_pair(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        return self._conv_forward(
            input, self.weight_fake_quant(self.weight), self.bias
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class Conv3d(_ConvNd, nn.Conv3d):
    r"""
    A Conv3d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv3d`, please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d
    for documentation.

    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE = nn.Conv3d
    _FLOAT_CONV_MODULE = nn.Conv3d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: Union[str, _size_3_t] = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        kernel_size_ = _triple(kernel_size)
        stride_ = _triple(stride)
        padding_ = padding if isinstance(padding, str) else _triple(padding)
        dilation_ = _triple(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_triple(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        return self._conv_forward(
            input, self.weight_fake_quant(self.weight), self.bias
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )
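
A sketch of the intended round trip through the classes above, assuming the
standard QAT qconfig factory torch.ao.quantization.get_default_qat_qconfig.

import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig

float_conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
float_conv.qconfig = get_default_qat_qconfig("fbgemm")  # from_float asserts this

qat_conv = Conv2d.from_float(float_conv)  # Conv2d from the hunk above; weights shared
y = qat_conv(torch.randn(1, 3, 32, 32))   # forward fake-quantizes the weight only
restored = qat_conv.to_float()            # back to a plain nn.Conv2d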


@@ -0,0 +1,248 @@
# mypy: allow-untyped-defs
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

__all__ = ["Embedding", "EmbeddingBag"]


class Embedding(nn.Embedding):
    r"""
    An embedding module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Embedding`, please see
    https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding
    for documentation.

    Similar to `torch.nn.Embedding`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE = nn.Embedding

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        device=None,
        dtype=None,
        qconfig=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
            **factory_kwargs,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, (
            "Embedding weights require a qscheme of torch.per_channel_affine_float_qparams, got "
            + str(qconfig.weight().qscheme)
        )
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input) -> Tensor:
        return F.embedding(
            input,
            self.weight_fake_quant(self.weight),
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a qat module from a float module

        Args:
            `mod`: a float module, either produced by torch.ao.quantization
            utilities or directly from user
        """
        assert type(mod) == cls._FLOAT_MODULE, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        weight_qscheme = mod.qconfig.weight().qscheme  # type: ignore[union-attr, operator]
        assert weight_qscheme == torch.per_channel_affine_float_qparams, (
            "Embedding weights require a qscheme of torch.per_channel_affine_float_qparams, got "
            + str(weight_qscheme)
        )
        qconfig = mod.qconfig
        qat_embedding = cls(
            mod.num_embeddings,
            mod.embedding_dim,
            mod.padding_idx,
            mod.max_norm,
            mod.norm_type,
            mod.scale_grad_by_freq,
            mod.sparse,
            mod.weight,
            qconfig=qconfig,
        )
        return qat_embedding

    def to_float(self):
        embedding = torch.nn.Embedding(
            self.num_embeddings,
            self.embedding_dim,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
            None,
        )
        embedding.weight = torch.nn.Parameter(self.weight.detach())
        embedding.train(self.training)
        return embedding


class EmbeddingBag(nn.EmbeddingBag):
    r"""
    An embedding bag module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.EmbeddingBag`, please see
    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag
    for documentation.

    Similar to `torch.nn.EmbeddingBag`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE = nn.EmbeddingBag

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        mode="mean",
        sparse=False,
        _weight=None,
        include_last_offset=False,
        padding_idx=None,
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_embeddings,
            embedding_dim,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            mode,
            sparse,
            _weight,
            include_last_offset,
            padding_idx,
            **factory_kwargs,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, (
            "EmbeddingBag weights require a qscheme of torch.per_channel_affine_float_qparams, got "
            + str(qconfig.weight().qscheme)
        )
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor:
        return F.embedding_bag(
            input,
            self.weight_fake_quant(self.weight),
            offsets,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.mode,
            self.sparse,
            per_sample_weights,
            self.include_last_offset,
            self.padding_idx,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a qat module from a float module

        Args:
            `mod`: a float module, either produced by torch.ao.quantization
            utilities or directly from user
        """
        assert type(mod) == cls._FLOAT_MODULE, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        weight_qscheme = mod.qconfig.weight().qscheme  # type: ignore[union-attr, operator]
        assert weight_qscheme == torch.per_channel_affine_float_qparams, (
            "EmbeddingBag weights require a qscheme of torch.per_channel_affine_float_qparams, got "
            + str(weight_qscheme)
        )
        qconfig = mod.qconfig
        qat_embedding_bag = cls(
            mod.num_embeddings,
            mod.embedding_dim,
            mod.max_norm,
            mod.norm_type,
            mod.scale_grad_by_freq,
            mod.mode,
            mod.sparse,
            mod.weight,
            mod.include_last_offset,
            mod.padding_idx,
            qconfig=qconfig,
        )
        return qat_embedding_bag

    def to_float(self):
        embedding_bag = torch.nn.EmbeddingBag(
            self.num_embeddings,
            self.embedding_dim,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.mode,
            self.sparse,
            None,
            self.include_last_offset,
            self.padding_idx,
        )
        embedding_bag.weight = torch.nn.Parameter(self.weight.detach())
        embedding_bag.train(self.training)
        return embedding_bag
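
A short sketch of constructing the QAT EmbeddingBag above. The
default_embedding_qat_qconfig name is an assumption; whatever qconfig is used
must carry the torch.per_channel_affine_float_qparams weight qscheme that the
assertions demand.

import torch
from torch.ao.quantization import default_embedding_qat_qconfig  # assumed name

bag = EmbeddingBag(1000, 64, qconfig=default_embedding_qat_qconfig)
indices = torch.randint(0, 1000, (20,))
offsets = torch.tensor([0, 5, 12])  # three bags over the flat index list
pooled = bag(indices, offsets)      # the weight passes through weight_fake_quant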


@@ -0,0 +1,96 @@
# mypy: allow-untyped-defs
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.nn.intrinsic import LinearReLU
from torch.nn.utils.parametrize import (
    is_parametrized,
    transfer_parametrizations_and_params,
    type_before_parametrizations,
)

__all__ = ["Linear"]


class Linear(nn.Linear):
    r"""
    A linear module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Linear`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear
    for documentation.

    Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE = nn.Linear

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(in_features, out_features, bias, **factory_kwargs)
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input):
        return F.linear(input, self.weight_fake_quant(self.weight), self.bias)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a qat module from a float module or qparams_dict

        Args:
            `mod`: a float module, either produced by torch.ao.quantization
            utilities or directly from user
        """
        assert type_before_parametrizations(mod) == cls._FLOAT_MODULE, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        if type_before_parametrizations(mod) == LinearReLU:
            mod = mod[0]
        qconfig = mod.qconfig
        qat_linear = cls(
            mod.in_features,
            mod.out_features,
            bias=mod.bias is not None,
            qconfig=qconfig,
        )
        if is_parametrized(mod, "weight"):
            transfer_parametrizations_and_params(mod, qat_linear, "weight")
        else:
            qat_linear.weight = mod.weight
        if is_parametrized(mod, "bias"):
            transfer_parametrizations_and_params(mod, qat_linear, "bias")
        else:
            qat_linear.bias = mod.bias
        return qat_linear

    def to_float(self):
        linear = torch.nn.Linear(
            self.in_features, self.out_features, self.bias is not None
        )
        linear.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            linear.bias = torch.nn.Parameter(self.bias.detach())
        linear.train(self.training)
        return linear
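
Finally, a sketch of what the parametrization handling in from_float above
preserves, again assuming get_default_qat_qconfig; orthogonal is just one
example of a weight parametrization.

import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig

float_linear = nn.Linear(32, 16)
nn.utils.parametrizations.orthogonal(float_linear)  # parametrizes "weight"
float_linear.qconfig = get_default_qat_qconfig("fbgemm")

qat_linear = Linear.from_float(float_linear)  # Linear from the hunk above
# transfer_parametrizations_and_params moved the orthogonal parametrization onto
# qat_linear.weight rather than copying a raw tensor.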