I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,22 @@
from .quantizer import (
DerivedQuantizationSpec,
EdgeOrNode,
FixedQParamsQuantizationSpec,
QuantizationAnnotation,
QuantizationSpec,
QuantizationSpecBase,
Quantizer,
SharedQuantizationSpec,
)
__all__ = [
"EdgeOrNode",
"Quantizer",
"QuantizationSpecBase",
"QuantizationSpec",
"FixedQParamsQuantizationSpec",
"SharedQuantizationSpec",
"DerivedQuantizationSpec",
"QuantizationAnnotation",
]


@@ -0,0 +1,79 @@
from __future__ import annotations
from typing import Dict, List, TYPE_CHECKING
from .quantizer import QuantizationAnnotation, Quantizer
if TYPE_CHECKING:
import torch
from torch.fx import Node
__all__ = [
"ComposableQuantizer",
]
class ComposableQuantizer(Quantizer):
"""
ComposableQuantizer allows users to combine more than one quantizer into a single quantizer.
This allows users to quantize a model with multiple quantizers. E.g., embedding quantization
may be supported by one quantizer while linear layers and other ops might be supported by another
quantizer.
ComposableQuantizer is initialized with a list of `Quantizer` instances.
The order of the composition matters since that is the order in which the quantizers will be
applied.
Example:
```
embedding_quantizer = EmbeddingQuantizer()
linear_quantizer = MyLinearQuantizer()
xnnpack_quantizer = XNNPackQuantizer() # to handle ops not quantized by previous two quantizers
composed_quantizer = ComposableQuantizer([embedding_quantizer, linear_quantizer, xnnpack_quantizer])
prepared_m = prepare_pt2e(model, composed_quantizer)
```
"""
def __init__(self, quantizers: List[Quantizer]):
super().__init__()
self.quantizers = quantizers
self._graph_annotations: Dict[Node, QuantizationAnnotation] = {}
def _record_and_validate_annotations(
self, gm: torch.fx.GraphModule, quantizer: Quantizer
) -> None:
for n in gm.graph.nodes:
if "quantization_annotation" in n.meta:
# check if the annotation has been changed by
# comparing QuantizationAnnotation object id
if n in self._graph_annotations and (
id(self._graph_annotations[n])
!= id(n.meta["quantization_annotation"])
):
raise RuntimeError(
f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}"
)
else:
self._graph_annotations[n] = n.meta["quantization_annotation"]
else:
if n in self._graph_annotations:
raise RuntimeError(
f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}"
)
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""just handling global spec for now"""
for quantizer in self.quantizers:
quantizer.annotate(model)
self._record_and_validate_annotations(model, quantizer)
return model
def transform_for_annotation(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
for quantizer in self.quantizers:
model = quantizer.transform_for_annotation(model)
return model
def validate(self, model: torch.fx.GraphModule) -> None:
pass
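
The sketch below is not part of the diff; it is a minimal, hedged example of how the pieces above might be wired together end to end. The toy model, its inputs, and the use of torch.export.export(...).module() are assumptions (older releases used capture_pre_autograd_graph instead); prepare_pt2e/convert_pt2e are the standard PT2E entry points.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class ToyModel(torch.nn.Module):  # hypothetical model, for illustration only
    def __init__(self) -> None:
        super().__init__()
        self.emb = torch.nn.Embedding(100, 16)
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, idx):
        return self.linear(self.emb(idx))

example_inputs = (torch.randint(0, 100, (1, 8)),)

# embeddings handled by one quantizer, remaining ops by another;
# order matters: quantizers are applied in list order
composed = ComposableQuantizer(
    [
        EmbeddingQuantizer(),
        XNNPACKQuantizer().set_global(get_symmetric_quantization_config()),
    ]
)

exported = torch.export.export(ToyModel().eval(), example_inputs).module()
prepared = prepare_pt2e(exported, composed)  # inserts observers per the merged annotations
prepared(*example_inputs)  # calibration pass
quantized = convert_pt2e(prepared)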


@@ -0,0 +1,98 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import copy
from typing import List, Set
import torch
import torch.nn.functional as F
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.quantizer.quantizer import (
QuantizationAnnotation,
QuantizationSpec,
Quantizer,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
OperatorConfig,
OperatorPatternType,
QuantizationConfig,
)
__all__ = [
"get_embedding_operators_config",
"EmbeddingQuantizer",
]
def get_embedding_operators_config() -> OperatorConfig:
weight_quantization_spec = QuantizationSpec(
dtype=torch.uint8,
qscheme=torch.per_channel_affine_float_qparams,
ch_axis=0,
observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(eps=2**-12),
)
quantization_config = QuantizationConfig(None, None, weight_quantization_spec, None)
ops: List[OperatorPatternType] = [[torch.nn.Embedding]]
ops.append([F.embedding])
supported_config_and_operators = OperatorConfig(
config=quantization_config, operators=ops
)
return copy.deepcopy(supported_config_and_operators)
class EmbeddingQuantizer(Quantizer):
def __init__(self) -> None:
super().__init__()
@classmethod
def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
op_configs: Set[QuantizationConfig] = {
spec for spec, _ in cls.get_supported_operators()
}
return list(op_configs)
@classmethod
def get_supported_operator_for_quantization_config(
cls, quantization_config: QuantizationConfig
) -> List[OperatorPatternType]:
for config, ops in cls.get_supported_operators():
# note: this assumes each entry in cls.get_supported_operators()
# corresponds to one config, e.g. we don't have
# [(config1, op_list1), (config1, op_list2), (config2, op_list3)]
# where the first and second entries have the same config but did not
# merge the op lists
if config == quantization_config:
return ops
return []
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""just handling global spec for now"""
self._annotate_embedding_ops(model.graph)
return model
def _annotate_embedding_ops(self, graph: torch.fx.Graph) -> None:
embedding_config: OperatorConfig = get_embedding_operators_config()
for node in graph.nodes:
# Keep node parsing based annotations instead of module partitioners
# just as an example of alternate ways of annotating
if (
node.op == "call_function"
and node.target == torch.ops.aten.embedding.default
):
if embedding_config.config.weight is None:
raise ValueError(
"Embedding config must have a valid weight quantization spec."
)
node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
node.args[0]: embedding_config.config.weight,
}
)
def validate(self, model: torch.fx.GraphModule) -> None:
pass
@classmethod
def get_supported_operators(cls) -> List[OperatorConfig]:
return [get_embedding_operators_config()]
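
A small hedged sketch (not in the diff) of what EmbeddingQuantizer.annotate leaves on the graph: every aten.embedding call_function node gets a QuantizationAnnotation whose input_qspec_map sends the weight argument to the per-channel uint8 spec from get_embedding_operators_config(). The toy module and the export call are assumptions.

import torch
from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer

class TinyEmb(torch.nn.Module):  # hypothetical toy module
    def __init__(self) -> None:
        super().__init__()
        self.emb = torch.nn.Embedding(100, 16)

    def forward(self, idx):
        return self.emb(idx)

gm = torch.export.export(TinyEmb().eval(), (torch.randint(0, 100, (4,)),)).module()
EmbeddingQuantizer().annotate(gm)
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.ops.aten.embedding.default:
        # arg 0 of aten.embedding is the weight; only it is mapped to a spec,
        # so this quantizer is weight-only
        print(node.meta["quantization_annotation"].input_qspec_map)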


@@ -0,0 +1,161 @@
# mypy: allow-untyped-defs
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.ao.quantization import ObserverOrFakeQuantize
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.fx import Node
__all__ = [
"Quantizer",
"QuantizationSpecBase",
"QuantizationSpec",
"FixedQParamsQuantizationSpec",
"EdgeOrNode",
"SharedQuantizationSpec",
"DerivedQuantizationSpec",
"QuantizationAnnotation",
]
class QuantizationSpecBase(ABC): # noqa: B024
"""Base class for different types of quantization specs that allows users to
specify how to quantize a Tensor (input/output of a Node) in the model
"""
@dataclass(eq=True, frozen=True)
class QuantizationSpec(QuantizationSpecBase):
"""Quantization spec for common operators that allows user to specify how to
quantize a Tensor, this includes dtype, quant_min, quant_max etc.
"""
dtype: torch.dtype
# observer or fake_quantize constructor such as
# MinMaxObserver, PerChannelHistogramObserver etc.
# or we can attach some custom args to them
# e.g. MinMaxObserver.with_args(eps=eps)
observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor
quant_min: Optional[int] = None
quant_max: Optional[int] = None
qscheme: Optional[torch.qscheme] = None
ch_axis: Optional[int] = None
is_dynamic: bool = False
def __post_init__(self):
# TODO: add init for quant_min/quant_max
# quant_min must be <= quant_max
if (
self.quant_min is not None
and self.quant_max is not None
and self.quant_min > self.quant_max
):
raise ValueError(
f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}."
)
# ch_axis must be less than the number of channels,
# but there is no way to check that here. Just check that it is non-negative.
if self.ch_axis is not None and self.ch_axis < 0:
raise ValueError(f"ch_axis must be non-negative, got {self.ch_axis}.")
@dataclass(eq=True, frozen=True)
class FixedQParamsQuantizationSpec(QuantizationSpecBase):
dtype: torch.dtype
scale: float
zero_point: int
quant_min: Optional[int] = None
quant_max: Optional[int] = None
qscheme: Optional[torch.qscheme] = None
is_dynamic: bool = False
"""
The way we refer to other points of quantization in the graph will be either
an input edge or an output value:
an input edge is the connection between an input node and the node consuming the input, so it is a Tuple[Node, Node];
an output value is an fx Node.
"""
EdgeOrNode = Union[Tuple[Node, Node], Node]
EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer"
@dataclass(eq=True, frozen=True)
class SharedQuantizationSpec(QuantizationSpecBase):
"""
Quantization spec for the Tensors whose quantization parameters are shared with other Tensors
"""
# the edge or node to share observer or fake quant instances with
edge_or_node: EdgeOrNode
@dataclass(eq=True, frozen=True)
class DerivedQuantizationSpec(QuantizationSpecBase):
"""Quantization spec for the Tensors whose quantization parameters are derived from other Tensors"""
derived_from: List[EdgeOrNode]
derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]]
dtype: torch.dtype
quant_min: Optional[int] = None
quant_max: Optional[int] = None
qscheme: Optional[torch.qscheme] = None
ch_axis: Optional[int] = None
is_dynamic: bool = False
@dataclass
class QuantizationAnnotation:
"""How are input arguemnt or output should be quantized,
expressed as QuantizationSpec, this corresponds to how a Tensor in the
operator Graph is observed (PTQ) or fake quantized (QAT)
"""
# a map from torch.fx.Node to a type of QuantizationSpecBase
input_qspec_map: Dict[Node, Optional[QuantizationSpecBase]] = field(
default_factory=dict
)
# How the output of this node is quantized, expressed as QuantizationSpec
# TODO: change the value to QuantizationSpec in a separate PR
output_qspec: Optional[QuantizationSpecBase] = None
# For a Node: node1 and edge: (node1, node2), since they are observing the same
# Tensor, we may want to implicitly share observers, this flag allows people to
# turn off this behavior for the output of the node
allow_implicit_sharing: bool = True
# whether the node is annotated or not
_annotated: bool = False
class Quantizer(ABC):
def transform_for_annotation(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
"""Allows for user defined transforms to run before annotating the graph.
This allows quantizer to allow quantizing part of the model that are otherwise not quantizable.
For example quantizer can
a) decompose a compound operator like scaled dot product attention,
into bmm and softmax if quantizer knows how to quantize bmm/softmax but not sdpa
or b) transform scalars to tensor to allow quantizing scalares.
Note: this is an optional method
"""
return model
# annotate nodes in the graph with observer or fake quant constructors
# to convey the desired way of quantization
@abstractmethod
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
pass
# validate the annotated graph is supported by the backend
@abstractmethod
def validate(self, model: torch.fx.GraphModule) -> None:
pass
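
To make the Quantizer contract concrete, here is a hedged sketch of a minimal subclass that annotates only aten.linear nodes; the class name, specs, and observer choices are illustrative assumptions, not part of the diff.

import torch
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantizer.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
)

_ACT_SPEC = QuantizationSpec(  # per-tensor affine int8 activations
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)
_WEIGHT_SPEC = QuantizationSpec(  # per-tensor symmetric int8 weights
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    observer_or_fake_quant_ctr=MinMaxObserver.with_args(eps=2**-12),
)

class LinearOnlyQuantizer(Quantizer):  # hypothetical example subclass
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for node in model.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.linear.default:
                act, weight = node.args[0], node.args[1]
                node.meta["quantization_annotation"] = QuantizationAnnotation(
                    input_qspec_map={act: _ACT_SPEC, weight: _WEIGHT_SPEC},
                    output_qspec=_ACT_SPEC,
                    _annotated=True,
                )
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        # a real backend would check here that the annotated graph is supported
        pass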


@@ -0,0 +1,83 @@
# mypy: allow-untyped-defs
from typing import List
from torch.ao.quantization.pt2e.utils import _is_sym_size_node
from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
from torch.fx import Node
def _annotate_input_qspec_map(node: Node, input_node: Node, qspec):
quantization_annotation = node.meta.get(
"quantization_annotation", QuantizationAnnotation()
)
if quantization_annotation.input_qspec_map is None:
quantization_annotation.input_qspec_map = {}
quantization_annotation.input_qspec_map[input_node] = qspec
node.meta["quantization_annotation"] = quantization_annotation
def _annotate_output_qspec(node: Node, qspec):
quantization_annotation = node.meta.get(
"quantization_annotation", QuantizationAnnotation()
)
quantization_annotation.output_qspec = qspec
node.meta["quantization_annotation"] = quantization_annotation
def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]):
"""
This utility is used to handle cases when dynamic_shape=True tracing leads
to symint nodes in the pattern of a linear module. In those cases, we need to
distinguish between the nodes that are in the input just for extracting the value of
some dimensions (and symint nodes) vs. the one that is the activation.
For example:
graph(x, y, weight):
size_0 = torch.ops.aten.sym_size([x], [0])
size_1 = torch.ops.aten.sym_size([y], [1])
view_size = size_0 * size_1
size_3 = torch.ops.aten.sym_size([x], [2])
view_out = torch.ops.aten.view(x, [view_size, size_3])
return mm(view_out, weight)
In the example above, the y node is not an actual input. It exists only to extract size_1.
"""
if _is_sym_size_node(node):
return True
return all(
((user not in partition_nodes) or _is_sym_size_node(user))
for user in node.users
)
def _get_module_name_filter(module_name: str):
"""Get the module_name_filter function for a given module name, the filter accepts
a node and checks if the node comes from a module that has certain module name
For example:
node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1
>> module_name_filter = _get_module_name_filter("blocks.sub")
>> print(module_name_filter(node))
True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
"""
def module_name_filter(n: Node) -> bool:
# example: {
# 'L__self___sub': ("L['self'].sub", <class '....Sub'>),
# 'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
# }
# get_attr nodes don't have nn_module_stack?
nn_module_stack = n.meta.get("nn_module_stack", {})
def _normalize_path(n):
prefix = 0
# TODO This is non standard behavior and should be removed when we migrate off capture_pre_autograd_graph.
if n.startswith("L['self']."):
prefix = len("L['self'].")
return n[prefix:]
names = [_normalize_path(n) for n, _ in nn_module_stack.values()]
return module_name in names
return module_name_filter
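
An illustrative, hedged sketch of how these helpers compose inside an annotator. `model` (an exported torch.fx.GraphModule), the submodule name "blocks.sub", and the activation spec are assumptions for the example; the helpers themselves are the ones defined above.

import torch
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer.utils import (
    _annotate_input_qspec_map,
    _annotate_output_qspec,
    _get_module_name_filter,
)
from torch.fx import Node

act_spec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)

is_in_sub = _get_module_name_filter("blocks.sub")  # hypothetical submodule name
for node in model.graph.nodes:  # `model` is assumed to be an exported GraphModule
    if node.op != "call_function" or not is_in_sub(node):
        continue
    if node.args and isinstance(node.args[0], Node):
        # map the first Node input and the output to the same spec
        _annotate_input_qspec_map(node, node.args[0], act_spec)
        _annotate_output_qspec(node, act_spec)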

File diff suppressed because it is too large


@@ -0,0 +1,436 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import copy
import functools
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING
import torch
import torch._dynamo as torchdynamo
import torch.nn.functional as F
from torch.ao.quantization.fake_quantize import (
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
HistogramObserver,
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
PlaceholderObserver,
)
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.utils import _get_module_name_filter
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
_convert_scalars_to_attrs,
OP_TO_ANNOTATOR,
OperatorConfig,
OperatorPatternType,
propagate_annotation,
QuantizationConfig,
)
if TYPE_CHECKING:
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.fx import Node
__all__ = [
"XNNPACKQuantizer",
"get_symmetric_quantization_config",
]
def _get_dynamo_graph(function: Callable, inputs) -> torch.fx.Graph:
gm, _ = torchdynamo.export(function, aten_graph=True)(*inputs)
gm.graph.eliminate_dead_code()
return gm.graph
def _get_linear_patterns(input_size: List[int]):
in_channels = input_size[-1]
out_channels = 8 # hard coding but this should not matter
weight = torch.ones((out_channels, in_channels))
bias = torch.ones((out_channels,))
act = torch.ones(input_size)
def linear_op(act, weight, bias=None):
return F.linear(act, weight, bias)
pattern_w_bias = _get_dynamo_graph(linear_op, (act, weight, bias))
pattern_wo_bias = _get_dynamo_graph(linear_op, (act, weight))
return [pattern_w_bias, pattern_wo_bias]
def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
supported_operators: Dict[str, List[OperatorPatternType]] = {
# Both conv and linear should be able to handle relu + hardtanh fusion since
# those are clamp ops
"conv2d": [
[torch.nn.Conv2d, torch.nn.ReLU],
[torch.nn.Conv2d, F.relu],
[F.conv2d, torch.nn.ReLU],
[F.conv2d, F.relu],
],
"linear": [[torch.nn.Linear], [F.linear]],
"add": [[torch.add]],
"adaptive_avg_pool2d": [
[torch.nn.AdaptiveAvgPool2d],
[F.adaptive_avg_pool2d],
],
}
return copy.deepcopy(supported_operators)
def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
supported_config_and_operators: List[OperatorConfig] = []
for quantization_config in [
get_symmetric_quantization_config(),
get_symmetric_quantization_config(is_qat=True),
get_symmetric_quantization_config(is_per_channel=True),
get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
]:
ops = _supported_symmetric_quantized_operators()
for pattern_list in ops.values():
supported_config_and_operators.append(
OperatorConfig(quantization_config, pattern_list)
)
return copy.deepcopy(supported_config_and_operators)
@functools.lru_cache
def get_symmetric_quantization_config(
is_per_channel: bool = False,
is_qat: bool = False,
is_dynamic: bool = False,
act_qmin: int = -128,
act_qmax: int = 127,
weight_qmin: int = -127,
weight_qmax: int = 127,
):
extra_args: Dict[str, Any] = {"eps": 2**-12}
if is_qat:
if is_dynamic:
act_observer_or_fake_quant_ctr = FakeQuantize
dynamic_quant_observer = MovingAverageMinMaxObserver.with_args(
averaging_constant=1
)
extra_args["observer"] = dynamic_quant_observer
else:
act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment]
else:
if is_dynamic:
act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment]
else:
act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment]
act_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=act_qmin,
quant_max=act_qmax,
qscheme=torch.per_tensor_affine,
is_dynamic=is_dynamic,
observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
**extra_args,
),
)
weight_qscheme = (
torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
)
weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
MinMaxObserver
)
if is_qat:
# TODO: qat + per channel?
weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
elif is_per_channel:
weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
extra_args: Dict[str, Any] = {"eps": 2**-12}
if is_qat:
if weight_qscheme == torch.per_tensor_symmetric:
extra_args["observer"] = MovingAverageMinMaxObserver
else:
extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item]
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=weight_qmin,
quant_max=weight_qmax,
qscheme=weight_qscheme,
ch_axis=0,
is_dynamic=False,
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
**extra_args
),
)
bias_quantization_spec = None
if is_dynamic:
quantization_config = QuantizationConfig(
act_quantization_spec,
None,
weight_quantization_spec,
bias_quantization_spec,
is_qat,
)
else:
quantization_config = QuantizationConfig(
act_quantization_spec,
act_quantization_spec,
weight_quantization_spec,
bias_quantization_spec,
is_qat,
)
return quantization_config
def _get_supported_config_and_operators() -> List[OperatorConfig]:
return _get_supported_symmetric_config_and_operators()
def _get_module_type_filter(tp: Callable):
"""Get the module_type_filter function for a given module type, the filter accepts
a node and checks if the node comes from a module that has certain module type
For example:
node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear
>> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule
>> print(module_type_filter(node))
True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
"""
tp_str = tp.__module__ + "." + tp.__qualname__
def module_type_filter(n: Node) -> bool:
# example: {
# 'L__self___sub': ("L['self'].sub", <class '....Sub'>),
# 'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
# }
nn_module_stack = n.meta.get("nn_module_stack", {})
types = []
for _, t in nn_module_stack.values():
# export() returns str, but older APIs (e.g. capture_pre_autograd_graph)
# return type. Handle both cases.
if isinstance(t, type):
t = t.__module__ + "." + t.__qualname__
types.append(t)
return tp_str in types
return module_type_filter
def _get_not_module_type_or_name_filter(
tp_list: List[Callable], module_name_list: List[str]
) -> Callable[[Node], bool]:
module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
def not_module_type_or_name_filter(n: Node) -> bool:
return not any(f(n) for f in module_type_filters + module_name_list_filters)
return not_module_type_or_name_filter
class XNNPACKQuantizer(Quantizer):
supported_config_and_operators = _get_supported_config_and_operators()
STATIC_QAT_ONLY_OPS = [
"conv_bn_relu",
"conv_bn",
"conv_transpose_bn_relu",
"conv_transpose_bn",
]
# static quantization ops (both PTQ and QAT)
# Preserve the order that fusions come before singular ops
STATIC_OPS = [
"linear_relu",
"linear",
"conv_relu",
"conv",
"conv_transpose_relu",
"adaptive_avg_pool2d",
# TODO: move this to BoltNNQuantizer?
"gru_io_only",
"add_relu",
"add",
"mul_relu",
"mul",
"cat",
]
DYNAMIC_OPS = [
"linear",
]
def __init__(self) -> None:
super().__init__()
self.global_config: Optional[QuantizationConfig] = None
self.operator_type_config: Dict[
torch._ops.OpOverloadPacket, Optional[QuantizationConfig]
] = {}
self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}
@classmethod
def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
op_configs: Set[QuantizationConfig] = {
spec for spec, _ in cls.supported_config_and_operators
}
return list(op_configs)
@classmethod
def get_supported_operator_for_quantization_config(
cls, quantization_config: Optional[QuantizationConfig]
) -> List[OperatorPatternType]:
if quantization_config is None:
all_ops = []
for _, ops in cls.supported_config_and_operators:
all_ops.extend(ops)
return all_ops
for config, ops in cls.supported_config_and_operators:
# note: this assumes each entry in cls.supported_config_and_operators
# corresponds to one config, e.g. we don't have
# [(config1, op_list1), (config1, op_list2), (config2, op_list3)]
# where the first and second entries have the same config but did not
# merge the op lists
if config == quantization_config:
return ops
return []
def set_global(self, quantization_config: QuantizationConfig) -> XNNPACKQuantizer:
self.global_config = quantization_config
return self
def set_operator_type(
self,
operator_type: torch._ops.OpOverloadPacket,
quantization_config: QuantizationConfig,
) -> XNNPACKQuantizer:
self.operator_type_config[operator_type] = quantization_config
return self
def set_module_type(
self, module_type: Callable, quantization_config: QuantizationConfig
):
"""Set quantization_config for a submodule with type: `module_type`, for example:
quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator
patterns in the submodule with this module type with the given `quantization_config`
"""
self.module_type_config[module_type] = quantization_config
return self
def set_module_name(
self, module_name: str, quantization_config: Optional[QuantizationConfig]
):
"""Set quantization_config for a submodule with name: `module_name`, for example:
quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
patterns in the submodule with this module name with the given `quantization_config`
"""
assert (
quantization_config is not None
), "quantization_config == None is not supported yet"
self.module_name_config[module_name] = quantization_config
return self
def transform_for_annotation(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
"""Transforms scalar values to tensor attributes"""
return _convert_scalars_to_attrs(model)
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""just handling global spec for now"""
# hacked for handling dynamic linear quant. will fix later.
if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr]
model = self._annotate_for_dynamic_quantization_config(model)
else:
model = self._annotate_for_static_quantization_config(model)
propagate_annotation(model)
return model
def _annotate_all_static_patterns(
self,
model: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> torch.fx.GraphModule:
# TODO: implement support for None to cancel out previous annotations
if quantization_config is None:
return model
if quantization_config.is_qat:
for op in self.STATIC_QAT_ONLY_OPS:
OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
for op in self.STATIC_OPS:
OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
return model
def _annotate_all_dynamic_patterns(
self,
model: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> torch.fx.GraphModule:
# TODO: implement support for None to cancel out previous annotations
if quantization_config is None:
return model
for op in self.DYNAMIC_OPS:
OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
return model
def _annotate_for_static_quantization_config(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
module_name_list = list(self.module_name_config.keys())
for module_name, config in self.module_name_config.items():
self._annotate_all_static_patterns(
model, config, _get_module_name_filter(module_name)
)
tp_list = list(self.module_type_config.keys())
for module_type, config in self.module_type_config.items():
self._annotate_all_static_patterns(
model, config, _get_module_type_filter(module_type)
)
self._annotate_all_static_patterns(
model,
self.global_config,
_get_not_module_type_or_name_filter(tp_list, module_name_list),
)
return model
def _annotate_for_dynamic_quantization_config(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
module_name_list = list(self.module_name_config.keys())
for module_name, config in self.module_name_config.items():
self._annotate_all_dynamic_patterns(
model, config, _get_module_name_filter(module_name)
)
tp_list = list(self.module_type_config.keys())
for module_type, config in self.module_type_config.items():
self._annotate_all_dynamic_patterns(
model, config, _get_module_type_filter(module_type)
)
self._annotate_all_dynamic_patterns(
model,
self.global_config,
_get_not_module_type_or_name_filter(tp_list, module_name_list),
)
return model
def validate(self, model: torch.fx.GraphModule) -> None:
pass
@classmethod
def get_supported_operators(cls) -> List[OperatorConfig]:
return cls.supported_config_and_operators
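
A hedged usage sketch (not part of the diff) showing a typical XNNPACKQuantizer configuration: a per-channel global config plus a per-module-name override, followed by the PT2E prepare/calibrate/convert flow. The toy model, the submodule name "linear", and the use of torch.export.export(...).module() are assumptions (older releases used capture_pre_autograd_graph).

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class TinyConvNet(torch.nn.Module):  # hypothetical model
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.relu = torch.nn.ReLU()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        x = self.relu(self.conv(x))
        return self.linear(x.mean(dim=(2, 3)))

quantizer = (
    XNNPACKQuantizer()
    .set_global(get_symmetric_quantization_config(is_per_channel=True))
    # per-tensor override for the submodule named "linear" (name is hypothetical)
    .set_module_name("linear", get_symmetric_quantization_config())
)

example_inputs = (torch.randn(1, 3, 32, 32),)
exported = torch.export.export(TinyConvNet().eval(), example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)  # observers inserted per annotations
prepared(*example_inputs)  # calibration
quantized = convert_pt2e(prepared)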

File diff suppressed because it is too large