I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,3 @@
from .convert import convert
from .fuse import fuse
from .prepare import prepare

File diff suppressed because it is too large


@@ -0,0 +1,951 @@
# mypy: allow-untyped-defs
import operator
import warnings
from collections import namedtuple
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.ao.nn.intrinsic as nni
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
from torch.ao.quantization.observer import (
_with_args,
ObserverBase,
PerChannelMinMaxObserver,
)
from torch.ao.quantization.utils import _parent_name, check_min_max_valid
from torch.fx import GraphModule
from torch.fx.graph import Node
from .utils import (
get_new_attr_name_with_prefix,
maybe_get_next_module,
node_arg_is_weight,
)
CUSTOM_MODULE_SUPP_LIST: List[Any] = []
def reshape_scale(scale: torch.Tensor, axis: int, input: torch.Tensor) -> torch.Tensor:
"""Reshapes the scale so that we can multiply it to the input by the given axis."""
new_shape = [1] * input.ndim
new_shape[axis] = input.size(axis)
return scale.view(new_shape)
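# Illustrative sketch (hypothetical helper, not part of the original module):
# shows how reshape_scale broadcasts a per-channel scale against a conv
# activation. For an NCHW input and axis=1, a scale of shape (C,) is viewed
# as (1, C, 1, 1).
def _example_reshape_scale() -> torch.Tensor:
    x = torch.randn(2, 3, 4, 4)  # hypothetical NCHW activation
    scale = torch.tensor([0.5, 1.0, 2.0])  # one scale value per channel
    reshaped = reshape_scale(scale, 1, x)  # shape becomes (1, 3, 1, 1)
    return x * reshaped  # broadcasts the scale along the channel axis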
qsheme_mapping_per_tensor_to_per_channel = {
torch.per_tensor_affine: torch.per_channel_affine,
torch.per_tensor_symmetric: torch.per_channel_symmetric,
}
class _InputEqualizationObserver(nn.Module):
r"""Observer for tracking the running min/max values of input columns, and
computing the quantization parameters for the overall min/max input values.
Args:
dtype: Quantized data type
qscheme: Quantization scheme
quant_min: Minimum quantization value. If unspecified, it will
follow the 8-bit setup.
quant_max: Maximum quantization value. If unspecified, it will
follow the 8-bit setup.
The running minimum/maximum :math:`x_\text{min/max}` are computed in the
same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`,
with the difference that the running min/max values are stored per column.
This observer is intended to be used along with a WeightEqualizationObserver
to calculate the equalization scale.
"""
def __init__(
self,
dtype=torch.quint8,
qscheme=torch.per_tensor_affine,
quant_min=None,
quant_max=None,
factory_kwargs=None,
) -> None:
super().__init__()
if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
raise TypeError("Input qscheme must be per-tensor")
self.dtype = dtype
self.qscheme = qscheme
per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme]
self.input_obs = PerChannelMinMaxObserver(
ch_axis=1,
dtype=dtype,
qscheme=per_channel_qscheme,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs,
)
self.equalization_scale = torch.tensor(1)
self.equalization_shape: List[int] = []
def forward(self, x_orig):
if not (x_orig.ndim >= 2 and x_orig.ndim <= 5):
raise ValueError(
"InputEqualizationObserver only supports Linear and Conv layers"
)
# Calculate the shape needed to reshape the equalization scale later (needed for Conv layers)
self.equalization_shape = [1] * x_orig.ndim
self.equalization_shape[1] = x_orig.size(1)
return self.input_obs(x_orig)
def get_input_minmax(self):
return (self.input_obs.min_val, self.input_obs.max_val)
def set_equalization_scale(self, equalization_scale):
# Reshape the equalization scale along axis=1 so that it can be
# multiplied with the input along axis=1
if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
return
self.equalization_scale = torch.reshape(
equalization_scale, self.equalization_shape
)
def calculate_scaled_minmax(self):
r"""Returns the scaled min/max inputs"""
if (
self.equalization_scale.nelement() == 1
and self.equalization_scale == torch.tensor(1)
):
warnings.warn(
"Must call calculate_equalization_scale before calling calculate_scaled_minmax. "
+ "Will not scale the next quantization observer."
)
return None, None
# Calculate qparams for the scaled min/max inputs
# Scale the input by the equalization scale located at the same column
# index
(min_inputs, max_inputs) = self.get_input_minmax()
equalization_scale_reshaped = reshape_scale(
self.equalization_scale, 0, min_inputs
)
min_input_scaled = torch.min(torch.mul(min_inputs, equalization_scale_reshaped))
max_input_scaled = torch.max(torch.mul(max_inputs, equalization_scale_reshaped))
return min_input_scaled, max_input_scaled
with_args = classmethod(_with_args)
class _WeightEqualizationObserver(nn.Module):
r"""Observer for tracking the running min/max values of weight columns and
rows, and computing the quantization parameters for the weight rows.
Args:
dtype: Quantized data type
qscheme: Quantization scheme
quant_min: Minimum quantization value. If unspecified, it will
follow the 8-bit setup.
quant_max: Maximum quantization value. If unspecified, it will
follow the 8-bit setup.
This observer is made up of 1 PerChannelMinMaxObserver `weight_col_obs` used
to record the running minimum and maximum of columns of incoming weight
tensors. This observer is intended to be used along with an
InputEqualizationObserver to calculate the equalization scale.
The running minimum/maximum :math:`w_\text{min/max}` are computed in the
same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`.
"""
def __init__(
self,
dtype=torch.qint8,
qscheme=torch.per_tensor_affine,
quant_min=None,
quant_max=None,
factory_kwargs=None,
) -> None:
super().__init__()
self.dtype = dtype
self.qscheme = qscheme
self.ch_axis = 1
per_channel_qscheme = qscheme
if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme]
self.weight_col_obs = PerChannelMinMaxObserver(
ch_axis=1,
dtype=dtype,
qscheme=per_channel_qscheme,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs,
)
self.equalization_scale = torch.tensor(1)
def forward(self, w_orig):
if not (w_orig.ndim >= 2 and w_orig.ndim <= 5):
raise ValueError(
"InputEqualizationObserver only supports Linear and Conv layers"
)
return self.weight_col_obs(w_orig)
def get_weight_col_minmax(self):
return (self.weight_col_obs.min_val, self.weight_col_obs.max_val)
def set_equalization_scale(self, equalization_scale):
self.equalization_scale = equalization_scale
with_args = classmethod(_with_args)
def calculate_equalization_scale(
input_obs: _InputEqualizationObserver, weight_obs: _WeightEqualizationObserver
) -> torch.Tensor:
r"""Calculates the equalization scale and sets the equalization_scale value
in the observers.
Args:
input_obs: Observer that tracks the ranges for the input columns
weight_obs: Observer that tracks the ranges for the weight columns
"""
(min_inputs, max_inputs) = input_obs.get_input_minmax()
(min_weights, max_weights) = weight_obs.get_weight_col_minmax()
if not (
check_min_max_valid(min_inputs, max_inputs)
and check_min_max_valid(min_weights, max_weights)
):
warnings.warn(
"Must run observer before calling calculate_equalization_scale. "
+ "Returning default equalization scale torch.tensor(1)."
)
return torch.tensor(1)
if not (min_inputs.shape == min_weights.shape):
raise ValueError(
"Input and Weight must have the same column dimension. "
+ f"Found {min_inputs.shape} and {min_weights.shape} shapes instead."
)
equalization_scale = torch.sqrt(
(max_weights - min_weights) / (max_inputs - min_inputs)
)
# Replace all 'inf', 'nan', 0's with 1s to prevent errors
equalization_scale[equalization_scale == 0.0] = 1
equalization_scale = torch.nan_to_num(equalization_scale, nan=1, posinf=1, neginf=1)
return equalization_scale
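# Illustrative sketch (hypothetical helper, not part of the original module):
# a worked example of the formula above. With a per-column input range of
# [-1, 1] and weight range of [-4, 4], the equalization scale for that column
# is sqrt((4 - (-4)) / (1 - (-1))) = sqrt(4) = 2, so the input column is
# multiplied by 2 and the corresponding weight column by 1/2.
def _example_equalization_scale() -> torch.Tensor:
    min_inputs = torch.tensor([-1.0, -2.0])
    max_inputs = torch.tensor([1.0, 2.0])
    min_weights = torch.tensor([-4.0, -1.0])
    max_weights = torch.tensor([4.0, 1.0])
    # Same computation as calculate_equalization_scale, without the observers;
    # expected result: tensor([2.0000, 0.7071])
    return torch.sqrt((max_weights - min_weights) / (max_inputs - min_inputs))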
class EqualizationQConfig(
namedtuple("EqualizationQConfig", ["input_activation", "weight"])
):
"""
Describes how to quantize a layer or a part of the network specifically for
input-weight equalization by providing settings (observer classes) for
inputs, outputs, and weights.
Note that EqualizationQConfig needs to contain observer **classes** (like
MinMaxObserver) or a callable that returns instances on invocation, not the
concrete observer instances themselves.
Quantization function will instantiate observers multiple times for each of
the layers.
Observer classes have usually reasonable default arguments, but they can be
overwritten with `with_args` method (that behaves like functools.partial):
my_qconfig = EqualizationQConfig(input_activation=_InputEqualizationObserver.with_args(dtype=torch.qint8),
weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8))
"""
def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity):
if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module):
raise ValueError(
"EqualizationQConfig received observer instance, please pass observer class instead. "
+ "Use MyObserver.with_args(x=1) to override arguments to constructor if needed"
)
self = super().__new__(cls, input_activation, weight)
return self
input_equalization_observer = _InputEqualizationObserver.with_args(
dtype=torch.quint8, qscheme=torch.per_tensor_symmetric
)
weight_equalization_observer = _WeightEqualizationObserver.with_args(
dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)
default_equalization_qconfig = EqualizationQConfig(
input_activation=input_equalization_observer, weight=weight_equalization_observer
)
def fused_module_supports_equalization(module) -> bool:
"""Checks if the fused node supports equalization."""
return type(module) in [
nni.LinearReLU,
nni.ConvReLU1d,
nni.ConvReLU2d,
nni.ConvReLU3d,
]
def nn_module_supports_equalization(module) -> bool:
"""Checks if the torch.nn node supports equalization."""
return type(module) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d]
def custom_module_supports_equalization(module) -> bool:
"""Checks if the custom node supports equalization."""
return type(module) in CUSTOM_MODULE_SUPP_LIST
def node_supports_equalization(node: Node, modules) -> bool:
"""Checks if the current node supports equalization
Currently we only support nn.Linear/F.linear and nn.Conv/F.conv layers
"""
if node.op == "call_module":
return (
nn_module_supports_equalization(modules[str(node.target)])
or fused_module_supports_equalization(modules[str(node.target)])
or custom_module_supports_equalization(modules[str(node.target)])
)
elif node.op == "call_function":
return node.target in [F.linear, F.conv1d, F.conv2d, F.conv3d]
return False
def is_equalization_observer(observer: nn.Module) -> bool:
return isinstance(
observer, (_InputEqualizationObserver, _WeightEqualizationObserver)
)
###############################################################################
# Functions for equalization during convert #
###############################################################################
def get_op_node_and_weight_eq_obs(
input_eq_obs_node: Node, model: GraphModule, modules: Dict[str, nn.Module]
) -> Tuple[Optional[Node], Optional[_WeightEqualizationObserver]]:
"""Gets the following weight equalization observer. There should always
exist a weight equalization observer after an input equalization observer.
Returns the operation node that follows the input equalization observer node
and the weight equalization observer
"""
# Find the op node that comes directly after the input equalization observer
op_node = None
for user in input_eq_obs_node.users.keys():
if node_supports_equalization(user, modules):
op_node = user
break
assert op_node is not None
if op_node.op == "call_module":
# If the op_node is an nn.Linear layer, then it must have a
# WeightEqualizationObserver configuration
maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(
model, "equalization_node_name_to_qconfig"
)
assert maybe_equalization_node_name_to_config is not None
equalization_node_name_to_qconfig: Dict[str, Any] = maybe_equalization_node_name_to_config # type: ignore[assignment]
assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None
weight_eq_obs = equalization_node_name_to_qconfig.get(
op_node.name, None
).weight()
assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
return op_node, weight_eq_obs
elif op_node.op == "call_function":
weight_node = maybe_get_weight_eq_obs_node(op_node, modules)
if weight_node is not None:
weight_eq_obs = modules[str(weight_node.target)]
assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
return op_node, weight_eq_obs
return None, None
def maybe_get_weight_eq_obs_node(
op_node: Node, modules: Dict[str, nn.Module]
) -> Optional[Node]:
"""Gets the weight equalization observer node if it exists."""
assert op_node.op == "call_function"
for node_arg in op_node.args:
if node_arg_is_weight(op_node, node_arg):
assert (
isinstance(node_arg, Node)
and node_arg.op == "call_module"
and isinstance(
modules[str(node_arg.target)], _WeightEqualizationObserver
)
)
return node_arg
return None
def maybe_get_next_input_eq_obs(
node: Node, modules: Dict[str, nn.Module]
) -> Optional[_InputEqualizationObserver]:
"""Gets the following input equalization observer if it exists.
For example, in the case of connecting linear layers:
x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
If the node being passed in is the linear1 node, then we want to return eq_obs2,
the following equalization observer for linear2.
However, if there are no connecting layers:
x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add
Then we want to return None.
In the case of an unfused linear-relu layer with a connecting linear layer:
linear1 -> relu -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
Since it is unfused, we want to skip over the relu layer and return eq_obs2,
the following equalization observer for linear2.
"""
assert node_supports_equalization(node, modules)
# Locate the following nn.ReLU or F.relu node if it exists
maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU)
if maybe_relu_node is None:
maybe_relu_node = maybe_get_next_module(
node, modules, target_functional_type=F.relu
)
# Locate the following output observer if it exists.
# We will skip the relu node if it exists.
maybe_obs_node = (
maybe_get_next_module(node, modules, ObserverBase)
if maybe_relu_node is None
else maybe_get_next_module(maybe_relu_node, modules, ObserverBase)
)
if maybe_obs_node is None:
return None
maybe_eq_obs_node = maybe_get_next_module(
maybe_obs_node, modules, _InputEqualizationObserver
)
if maybe_eq_obs_node is None:
return None
maybe_eq_obs = modules[str(maybe_eq_obs_node)]
assert isinstance(maybe_eq_obs, _InputEqualizationObserver)
return maybe_eq_obs
def maybe_get_next_equalization_scale(
node: Node, modules: Dict[str, nn.Module]
) -> Optional[torch.Tensor]:
"""If the next next node is an InputEqualizationObserver then we want to
return its equalization scale, else we return 1
This is used in the case where there are two connecting linear layers:
linear1 -> LinearOutObs -> InputEqObs -> linear2
In this case, the node given is linear1 and we want to locate the InputEqObs.
"""
next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
if next_inp_eq_obs:
if (
next_inp_eq_obs.equalization_scale.nelement() == 1
and next_inp_eq_obs.equalization_scale == torch.tensor(1)
):
return None
return next_inp_eq_obs.equalization_scale
return None
def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None:
"""Scales the following input quantization observer's min/max values by
updating the values with the scaled min/max values calculated by the input
equalization observer
"""
input_eq_obs = modules[str(node.target)]
assert isinstance(input_eq_obs, _InputEqualizationObserver)
input_quant_obs_node = node.args[0]
assert isinstance(input_quant_obs_node, Node)
input_quant_obs = modules[str(input_quant_obs_node.target)]
if not isinstance(input_quant_obs, ObserverBase):
return
min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
if min_input_scaled is None and max_input_scaled is None:
return
input_quant_obs.min_val = min_input_scaled
input_quant_obs.max_val = max_input_scaled
def scale_weight_node(
node: Node,
modules: Dict[str, nn.Module],
equalization_scale: torch.Tensor,
next_equalization_scale: Optional[torch.Tensor],
) -> None:
"""Scale the weights for input-weight equalization by multiplying the
weight by 1/equalization_scale and next_equalization_scale
Args:
node: Current node whose weights we want to scale
equalization_scale: Current node's calculated equalization scale
next_equalization_scale: Next node's calculated equalization scale if
the following node needs to be equalized, None otherwise
"""
if equalization_scale is None:
return
if fused_module_supports_equalization(modules[str(node.target)]):
op_module = modules[str(node.target)][0] # type: ignore[index]
else:
op_module = modules[str(node.target)]
assert nn_module_supports_equalization(
op_module
) or custom_module_supports_equalization(op_module)
# Scale the weights for input-weight equalization
# If the following layer needs to be equalized then we will multiply its scale
weight = op_module.weight
assert isinstance(weight, torch.Tensor)
# Scale the weights by the reciprocal of the equalization scale
# Reshape the equalization scale so that we can multiply it to the weight along axis=1
equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight)
scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped))
if next_equalization_scale is None:
op_module.weight = nn.Parameter(scaled_weight)
return
# Multiply the weights row wise by the next equalization scale
# Reshape the equalization scale so that we can multiply it to the weight along axis=0
next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, weight)
scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)
op_module.weight = nn.Parameter(scaled_weight)
# Multiply the bias element wise by the next equalization scale
bias = op_module.bias
if bias is None:
return
assert isinstance(bias, torch.Tensor)
# Reshape the equalization scale so that we can multiply it element-wise to the bias
next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
scaled_bias = torch.mul(bias, next_equalization_scale_reshaped)
op_module.bias = nn.Parameter(scaled_bias)
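# Illustrative sketch (hypothetical helper, not part of the original module):
# the tensor math performed by scale_weight_node / scale_weight_functional for
# a Linear weight of shape (out_features, in_features). The weight columns are
# divided by the current equalization scale; if the next layer is also
# equalized, the weight rows and the bias are multiplied by the next scale.
def _example_scale_linear_weight(
    weight: torch.Tensor,  # shape (out_features, in_features)
    bias: torch.Tensor,  # shape (out_features,)
    equalization_scale: torch.Tensor,  # shape (in_features,)
    next_equalization_scale: torch.Tensor,  # shape (out_features,)
) -> Tuple[torch.Tensor, torch.Tensor]:
    scaled_weight = weight * torch.reciprocal(
        reshape_scale(equalization_scale, 1, weight)
    )
    scaled_weight = scaled_weight * reshape_scale(next_equalization_scale, 0, weight)
    scaled_bias = bias * reshape_scale(next_equalization_scale, 0, bias)
    return scaled_weight, scaled_bias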
def scale_weight_functional(
op_node: Node,
model: GraphModule,
modules: Dict[str, nn.Module],
equalization_scale: torch.Tensor,
next_equalization_scale: Optional[torch.Tensor],
) -> None:
"""Scales the weight value for functional layers"""
if equalization_scale is None:
return
# From the given op_node, the path looks like:
# get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node
# So we want to trace back from the op_node to get the equalization observer
# node, then the quantization observer node, and then finally the weight
# node which contains the weight values.
# Get the equalization observer node
weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
if weight_eq_obs_node is None:
return
# Get the quantization observer node
weight_quant_obs_node = weight_eq_obs_node.args[0]
if weight_quant_obs_node is None:
return
assert isinstance(weight_quant_obs_node, Node) and isinstance(
modules[str(weight_quant_obs_node.target)], ObserverBase
)
# Get the get_attr(weight) node
weight_node = weight_quant_obs_node.args[0]
if weight_node is None:
return
assert isinstance(weight_node, Node) and weight_node.op == "get_attr"
weight_parent_name, weight_name = _parent_name(weight_node.target)
weight = getattr(modules[weight_parent_name], weight_name)
# Scale the weights for input-weight equalization
# If the following layer needs to be equalized then we will multiply its scale
# Reshape the equalization scale so that we can multiply it to the weight along axis=1
equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight)
scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped))
if next_equalization_scale is None:
setattr(modules[weight_parent_name], weight_name, scaled_weight)
return
# Multiply the weights row wise by the next equalization scale
# Reshape the equalization scale so that we can multiply it to the weight along axis=0
next_equalization_scale_reshaped = reshape_scale(
next_equalization_scale, 0, scaled_weight
)
scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)
setattr(modules[weight_parent_name], weight_name, scaled_weight)
assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight)
# Multiply the bias element wise by the next equalization scale
bias_node = None
for node in op_node.args:
# Find the node containing the bias values
if isinstance(node, Node) and node.op == "get_attr" and "bias" in node.name:
bias_node = node
break
if bias_node is None:
return
bias_parent_name, bias_name = _parent_name(bias_node.target)
bias = getattr(modules[bias_parent_name], bias_name)
# Reshape the equalization scale so that we can multiply it element-wise to the bias
next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
scaled_bias = torch.mul(bias, next_equalization_scale_reshaped)
setattr(modules[bias_parent_name], bias_name, scaled_bias)
def clear_weight_quant_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> None:
"""Given the operation node, we want find the corresponding quantization
observer and reset its min/max values
"""
weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
if weight_eq_obs_node is None:
return
weight_quant_obs_node = weight_eq_obs_node.args[0]
if weight_quant_obs_node is None:
return
assert isinstance(weight_quant_obs_node, Node)
weight_quant_obs = modules[str(weight_quant_obs_node.target)]
assert isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)
weight_quant_obs.reset_min_max_vals() # type: ignore[operator]
def remove_node(model: GraphModule, node: Node, prev_node: Node):
"""Removes the given node from the model by replacing all of its users with
the given previous node
"""
# For all of the current node's users, replace the current node with
# the input quantization observer node
orig_users = list(node.users.keys())
for user_node in orig_users:
user_node.replace_input_with(node, prev_node)
# Erase the InputEqualizationObserver node
model.graph.erase_node(node)
def update_obs_for_equalization(
model: GraphModule, modules: Dict[str, nn.Module]
) -> Dict[str, _WeightEqualizationObserver]:
"""Update all of the observer's equalization scale. For each
InputEqualizationObserver, we will find the location of the next
WeightEqualizationObserver, create it, and calculate the equalization scale
based on the two observers.
We will then return a dictionary mapping operation node names to
the corresponding WeightEqualizationObservers for that operation.
"""
weight_eq_obs_dict = {}
for node in model.graph.nodes:
if node.op == "call_module" and isinstance(
modules[node.target], _InputEqualizationObserver
):
input_eq_obs = modules[node.target]
assert isinstance(input_eq_obs, _InputEqualizationObserver)
op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)
if op_node is None or weight_eq_obs is None:
continue
if op_node.op == "call_module":
# Calibrate the weight equalization observer since it has just
# been created
if fused_module_supports_equalization(modules[str(op_node.target)]):
module = modules[str(op_node.target)][0] # type: ignore[index]
assert nn_module_supports_equalization(module)
weight_eq_obs(module.weight)
else:
weight_eq_obs(modules[str(op_node.target)].weight)
# Calculate and set the equalization scale values
equalization_scale = calculate_equalization_scale(
input_eq_obs, weight_eq_obs
)
input_eq_obs.set_equalization_scale(equalization_scale)
weight_eq_obs.set_equalization_scale(equalization_scale)
weight_eq_obs_dict[op_node.name] = weight_eq_obs
return weight_eq_obs_dict
def convert_eq_obs(
model: GraphModule,
modules: Dict[str, nn.Module],
weight_eq_obs_dict: Dict[str, _WeightEqualizationObserver],
) -> None:
"""Converts the equalization operations and updates the other nodes in the
following way:
- Removes the input equalization observers and inserts a mul operator
along with an equalization scale node wherever applicable (we do not
want to insert a mul operator between connecting linear layers).
- Updates the input quantization observers with the scaled input min/max
values.
- Scales the weights by the current and next equalization scales.
- Removes the weight equalization observer node if it exists.
Before (after prepare):
weight values
|
WeightQuantObs
|
WeightEqObs
|
x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs
After this function:
scaled weight values
|
equalization scale WeightQuantObs
| |
x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs
After convert:
equalization scale scaled weight values
| |
x -> mul -> quantize_per_tensor -> quantized::linear
Note that although the equalization observer appeared after the quantization
observer after prepare_fx, the mul node appears before the quantization node
after convert_fx. This is because placing the equalization observer after
the quantization observer in prepare_fx allows us to keep the invariant that
the part of the graph before the current node is not modified when the
current node inserts its observers.
Having the equalization observer before the quantization observer would also
cause some inconsistencies between the ordering of the quantization and
equalization observers.
For example, a single linear layer would look like:
x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1
But between two connected linear layers, it would look like:
linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2
"""
for node in model.graph.nodes:
if node.op == "call_module" and isinstance(
modules[node.target], _InputEqualizationObserver
):
inp_quant_obs_node = node.args[0]
prev_node = inp_quant_obs_node.args[0]
# If the previous node is a layer that needs to be equalized, then
# we will remove the current node because we do not need to add any
# equalization nodes between two layers that need to be equalized
# Before: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> input_eq_obs2 (node) -> linear2
# After: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> linear2
if (
node_supports_equalization(prev_node, modules)
or "relu" in prev_node.name
):
remove_node(model, node, inp_quant_obs_node)
continue
# Update the following input quantization observer's min/max values
scale_input_observer(node, modules)
# Remove the InputEqualization node and add a mul operator before
# the quantization observer node that appears before the equalization node
# Before: x -> input_quant_obs -> input_eq_obs -> linear
# After: x -> mul -> input_quant_obs -> linear
# Create a node containing the equalization scale
with model.graph.inserting_before(inp_quant_obs_node):
get_new_eq_scale_name = get_new_attr_name_with_prefix(
prev_node.name + "_equalization_scale"
)
name = get_new_eq_scale_name(modules)
setattr(model, name, modules[node.target].equalization_scale)
eq_scale_node = model.graph.create_node("get_attr", name)
# Create a node multiplying the input with the equalization scale
with model.graph.inserting_after(eq_scale_node):
inputs = (prev_node, eq_scale_node)
mul_node = model.graph.create_node("call_function", torch.mul, inputs)
# Set the mul node to be the inp_quant_obs_node's input instead of
# the previous node
inp_quant_obs_node.replace_input_with(prev_node, mul_node)
remove_node(model, node, inp_quant_obs_node)
elif weight_eq_obs_dict.get(node.name, None) is not None:
weight_eq_obs = weight_eq_obs_dict.get(node.name)
assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
equalization_scale = weight_eq_obs.equalization_scale
if (
equalization_scale.nelement() == 1
and equalization_scale == torch.tensor(1)
):
equalization_scale = None # type: ignore[assignment]
maybe_next_equalization_scale = maybe_get_next_equalization_scale(
node, modules
)
# Scale the weight nodes
if node.op == "call_module":
scale_weight_node(
node, modules, equalization_scale, maybe_next_equalization_scale
)
elif node.op == "call_function":
scale_weight_functional(
node,
model,
modules,
equalization_scale,
maybe_next_equalization_scale,
)
weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules)
if weight_eq_obs_node is None:
return
assert isinstance(
modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver
)
# Clear the quantization observer's min/max values so that they
# can get updated later based on the new scale values
clear_weight_quant_obs_node(node, modules)
# Erase the weight equalization observer node
prev_node = weight_eq_obs_node.args[0]
remove_node(model, weight_eq_obs_node, prev_node)
else:
raise ValueError(
"Expected operation node to be 'call_module' or 'call_function"
+ f"Instead got node {node.name} as '{node.op}'."
)
def _convert_equalization_ref(model: GraphModule):
"""Reference function which applies changes needed for equalization, but
does not quantize the nodes
"""
modules = dict(model.named_modules(remove_duplicate=False))
# Calculate the equalization scale, update the observers with the scaled
# inputs, and scale the weight
weight_eq_obs_dict = update_obs_for_equalization(model, modules)
convert_eq_obs(model, modules, weight_eq_obs_dict)
return GraphModule(model, model.graph)
###############################################################################
# Functions for running the equalized model on the Numeric Suite #
###############################################################################
def get_layer_sqnr_dict(
model_a: nn.Module, model_b: nn.Module, x: torch.Tensor
) -> Dict[str, float]:
"""Runs the Numeric Suite on model_a and model_b and returns a dictionary
containing the SQNR between layers in model_a and model_b.
Note: In order to support equalized models, this function has a hacky fix in
which we do not match any torch.mul operators. This is because equalized
models contain extra mul operators to scale the input by the equalization
scale, but this edge case has not been resolved yet within the numeric suite code.
Args:
model_a: A float model
model_b: A quantized model
x: Inputs to use during calibration
"""
import torch.ao.ns._numeric_suite_fx as ns
from torch.ao.ns.fx.mappings import get_unmatchable_types_map
unmatchable_types_map = get_unmatchable_types_map()
unmatchable_types_map["funs_unmatchable"].add(torch.mul)
model_a_ns, model_b_ns = ns.add_loggers(
"fp32",
model_a,
"int8",
model_b,
ns.OutputLogger,
unmatchable_types_map=unmatchable_types_map,
)
model_a_ns(x)
model_b_ns(x)
activation_comparison_dict = ns.extract_logger_info(
model_a_ns, model_b_ns, ns.OutputLogger, "int8"
)
ns.extend_logger_results_with_comparison(
activation_comparison_dict,
"fp32",
"int8",
torch.ao.ns.fx.utils.compute_sqnr,
"sqnr",
)
# Construct a dictionary mapping layer names to the SQNR values
layer_sqnr_dict = {}
for key in activation_comparison_dict:
layer = activation_comparison_dict[key]["node_output"]["int8"][0]["fqn"]
sqnr = activation_comparison_dict[key]["node_output"]["int8"][0]["sqnr"][0]
layer_sqnr_dict[layer] = sqnr
return layer_sqnr_dict
def get_equalization_qconfig_dict(
layer_sqnr_dict: Dict[str, float], num_layers_to_equalize: int
) -> Any:
"""Given the layer to SQNR dictionary, find the layers with the highest
quantization errors, and return an equalization_qconfig_dict
specifying to only equalize those top layers.
Args:
layer_sqnr_dict: Dictionary mapping layer names to SQNR values (found
when comparing an equalized model against a float model)
num_layers_to_equalize: Number of layers with the highest quantization
errors to equalize
"""
# Sort the layer_sqnr_dictionary values and get the layers with the lowest
# SQNR values (aka highest quantization errors)
layer_sqnr_sorted = sorted(layer_sqnr_dict.items(), key=operator.itemgetter(1))
layers_to_equalize = layer_sqnr_sorted[:num_layers_to_equalize]
# Constructs an equalization_qconfig_dict that specifies to only equalize
# the layers with the highest quantization errors
module_to_qconfig_list = [
(item[0], default_equalization_qconfig) for item in layers_to_equalize
]
equalization_qconfig_dict = {"module_name": module_to_qconfig_list}
return equalization_qconfig_dict
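A minimal sketch of how the two functions above can be combined for selective equalization, assuming `float_model` is a hypothetical float model, `equalized_model` is a hypothetical equalized and quantized copy of it, and `calib_data` is a hypothetical calibration batch (all three names are illustrative, not from this file):

    # Find the 3 layers with the lowest SQNR (i.e. the highest quantization error)
    layer_sqnr_dict = get_layer_sqnr_dict(float_model, equalized_model, calib_data)
    # Build a qconfig dict that only equalizes those layers
    equalization_qconfig_dict = get_equalization_qconfig_dict(
        layer_sqnr_dict, num_layers_to_equalize=3
    )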

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,655 @@
# mypy: allow-untyped-defs
from collections import OrderedDict
from typing import Any, Callable, Dict, Set, Tuple
import torch
from torch.ao.quantization.fx._equalize import EqualizationQConfig
from torch.ao.quantization.fx._model_report.detector import (
DETECTOR_IS_POST_OBS_KEY,
DETECTOR_OBS_ARGS_KEY,
DETECTOR_OBS_TO_INSERT_KEY,
DETECTOR_TARGET_NODE_KEY,
DetectorBase,
DetectorQConfigInfo,
)
from torch.ao.quantization.fx._model_report.model_report_visualizer import (
ModelReportVisualizer,
)
from torch.ao.quantization.fx.graph_module import GraphModule
from torch.ao.quantization.observer import ObserverBase
from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping
class ModelReport:
r"""
The ModelReport class aims to provide users an easy way to diagnose issues that they run into
with their models. The class works with all traceable GraphModules to help diagnose issues,
though the requirements on the type of model depend more on the specific report the user
is trying to generate. With respect to the reports, the ModelReport class is initialized with
a set of Detector classes, each of which generates a report on quantization configuration
issues a user might have.
Currently supports generating reports on:
- Suggestions for per-channel vs. per-tensor quantization (nn.Module)
- Suggestions for dynamic vs static quantization for linear layers (Graph Modules)
- Suggestions for input-weight equalization for linear and conv layers (Graph Modules)
- Suggestions for outlier detection for all layers (Graph Modules)
The ModelReport class has the primary functionality of inserting observers (primarily the ModelReportObserver)
where needed for each detector to gather the information it needs, and then after calibration, the ModelReport
class compiles the reports generated by each Detector class into a single report to return to the user. It can
also remove all the observers it inserted.
* :attr:`_model` The model we wish to generate the report for. Must be a traceable GraphModule
* :attr:`_desired_report_detectors` The set of Detectors representing desired reports from the ModelReport class
Make sure that these are all unique types of detectors [do not have more than 1 of the same class]
* :attr:`_desired_detector_names` The set of detector names of the _desired_report_detectors.
This set is generated by calling the get_detector_name() of each detector
* :attr:`_detector_name_to_observer_fqns` The mapping from each detector to fqns of observers of interest
The purpose of this is to keep track of what observers were inserted for each detector, so that they
can be removed at the end if desired
* :attr:`_prepared_flag` A boolean flag that keeps track of whether we have prepared the model or not
This is to ensure we only insert observers once with the ModelReport instance
* :attr:`_removed_observers` A boolean to track if we have removed observers already
The purpose is to ensure we don't attempt to remove observers twice with the same ModelReport
instance. This also allows the functionality where we can generate the report multiple times
as long as we haven't removed the observers yet.
Note:
This class was initially designed with the Fx Graph Mode workflow in mind. However,
full functionality is available as long as there is a traceable GraphModule that is being used.
One method to get a traceable GraphModule without going through the Fx workflow is to use
the QuantizationTracer class.
General Flow for Fx workflow:
1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects and model
2.) Prepare your model with prepare_fx
3.) Call model_report.prepare_detailed_calibration to add relevant observers
4.) Calibrate your model with data
5.) Call model_report.generate_report on your model to generate report and optionally remove added observers
Optional
6.) Call model_report.generate_visualizer to get a ModelReportVisualizer instance
7.) To help in parsing report information and debugging, view report info as a:
- Table
- Histogram
- Line plot
8.) Call model_report.generate_qconfigs to generate the qconfigs based on the report suggestions
Example (with QuantizationTracer):
>>> # xdoctest: +SKIP
>>> # get the necessary qconfig
>>> config = PrepareCustomConfig()
>>> skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes(config, False)
>>> # initialize our model and get GraphModule
>>> model = SomeModel()
>>> tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)
>>> graph_module = GraphModule(model, tracer.trace(model))
>>> # get our set of detectors and ModelReport instance
>>> detector_set = set([DynamicStaticDetector(tolerance=0.5), InputWeightEqualizationDetector(ratio_threshold=0.7)])
>>> tracer_reporter = ModelReport(graph_module, detector_set)
>>> # now we insert the observers and calibrate the model
>>> tracer_model_with_observers = tracer_reporter.prepare_detailed_calibration()
>>> for i in range(num_calibration_batches):
>>> example_input = get_calibration_input()
>>> tracer_model_with_observers(example_input)
>>> # finally we generate the reports and optionally remove the observers we inserted
>>> reports = tracer_reporter.generate_model_report(remove_inserted_observers=True)
>>> # Optional: we can generate the qconfig mapping based on the suggestions
>>> qconfig_mapping = tracer_reporter.generate_qconfig_mapping()
>>> # Optional: we can generate the equalization mapping based on the suggestions
>>> equalization_mapping = tracer_reporter.generate_equalization_mapping()
>>> # Optional: we get a ModelReportVisualizer instance to do any visualizations desired
>>> model_report_visualizer = tracer_reporter.generate_visualizer()
"""
def __init__(self, model: GraphModule, desired_report_detectors: Set[DetectorBase]):
if len(desired_report_detectors) == 0:
raise ValueError("Should include at least 1 desired report")
# keep track of the model we wish to generate report for
self._model: GraphModule = model
# keep the reports private so they can't be modified
self._desired_report_detectors = desired_report_detectors
self._desired_detector_names = {
detector.get_detector_name() for detector in desired_report_detectors
}
# keep a mapping of desired reports to observers of interest
# this is used both to get the readings and to remove the observers later;
# the union of these sets can be used to traverse the graph and remove the added observers
self._detector_name_to_observer_fqns: Dict[str, Set[str]] = {}
# initialize each report to have empty set of observers of interest
for desired_report in self._desired_detector_names:
self._detector_name_to_observer_fqns[desired_report] = set()
# flags to ensure that we can only prepare and remove observers once
self._prepared_flag = False
self._removed_observers = False
# store the reports that we generated for visualization purposes
# initially empty since no reports generated
self._generated_reports: Dict[str, Dict] = {}
def get_desired_reports_names(self) -> Set[str]:
"""Returns a copy of the desired reports for viewing"""
return self._desired_detector_names.copy()
def get_observers_of_interest(self) -> Dict[str, Set[str]]:
"""Returns a copy of the observers of interest for viewing"""
return self._detector_name_to_observer_fqns.copy()
def prepare_detailed_calibration(self) -> GraphModule:
r"""
Takes in a graph model and inserts the following observers:
- ModelReportObserver
Each observer is inserted based on the desired_reports into the relevant locations
Right now, each report in self._desired_detector_names has independent insertions
However, if a module already has an Observer of the same type, the insertion will not occur
This is because Observers of the same type collect the same information, so another insertion would be redundant
Returns the same GraphModule with the observers inserted
"""
# if already prepared once, cannot prepare again
if self._prepared_flag:
raise ValueError(
"Already ran preparing detailed callibration. Run the report generation next after callibration."
)
# loop through each detector, find where placements should be, and keep track
insert_observers_fqns: Dict[str, Any] = {}
for detector in self._desired_report_detectors:
# determine observer points for each detector
obs_fqn_to_info = detector.determine_observer_insert_points(self._model)
# map each insert point to the observer to use
insert_observers_fqns.update(obs_fqn_to_info)
# update the set of observers this report cares about
self._detector_name_to_observer_fqns[detector.get_detector_name()] = set(
obs_fqn_to_info.keys()
)
# now insert all the observers at their desired locations
for observer_fqn in insert_observers_fqns:
target_node = insert_observers_fqns[observer_fqn][DETECTOR_TARGET_NODE_KEY]
insert_obs = insert_observers_fqns[observer_fqn][DETECTOR_OBS_TO_INSERT_KEY]
insert_post = insert_observers_fqns[observer_fqn][DETECTOR_IS_POST_OBS_KEY]
observer_args = insert_observers_fqns[observer_fqn][DETECTOR_OBS_ARGS_KEY]
self._insert_observer_around_module(
observer_fqn, target_node, insert_obs, observer_args, insert_post
)
self._prepared_flag = True
return self._model
def _insert_observer_around_module(
self,
obs_fqn: str,
target_node: torch.fx.node.Node,
obs_to_insert: ObserverBase,
observer_args: Tuple,
insert_post: bool,
):
r"""
Helper function that inserts the observer into both the graph structure and the module of the model
Args
obs_fqn (str): The fully qualified name of the observer we want to insert
target_node (torch.fx.node.Node): The node in model we are inserting observers around
obs_to_insert (ObserverBase): The observer we are inserting around target_node
observer_args (Tuple): The arguments we want to pass into the observer
insert_post (bool): whether this is meant to be a post observer for this node
"""
# if we are inserting post, then our target node is the next node
if insert_post:
target_node = target_node.next
with self._model.graph.inserting_before(target_node):
self._model.add_submodule(obs_fqn, obs_to_insert)
self._model.graph.create_node(
op="call_module", target=obs_fqn, args=observer_args
)
# recompile model after inserts are made
self._model.recompile()
def _get_node_from_fqn(self, node_fqn: str) -> torch.fx.node.Node:
r"""
Takes in a node fqn and returns the node based on the fqn
Args
node_fqn (str): The fully qualified name of the node we want to find in model
Returns the Node object of the given node_fqn otherwise returns None
"""
node_to_return = None
for node in self._model.graph.nodes:
# if the target matches the fqn, it's the node we are looking for
if node.target == node_fqn:
node_to_return = node
break
if node_to_return is None:
raise ValueError("The node_fqn is was not found within the module.")
# assert for MyPy
assert isinstance(node_to_return, torch.fx.node.Node)
return node_to_return
def generate_model_report(
self, remove_inserted_observers: bool
) -> Dict[str, Tuple[str, Dict]]:
r"""
Generates all the requested reports.
Note:
You should have calibrated the model with relevant data before calling this
The reports generated are determined by the detectors passed in at initialization
Can optionally remove all the observers inserted by the ModelReport instance
Args:
remove_inserted_observers (bool): True to remove the observers inserted by this ModelReport instance
Returns a mapping of each desired report name to a tuple with:
The textual summary of that report information
A dictionary containing relevant statistics or information for that report
Note:
Throws exception if we try to generate report on model we already removed observers from
Throws exception if we try to generate report without preparing for calibration
"""
# if we haven't prepped model for calibration, then we shouldn't generate report yet
if not self._prepared_flag:
raise Exception( # noqa: TRY002
"Cannot generate report without preparing model for callibration"
)
# if we already removed the observers, we cannot generate report
if self._removed_observers:
raise Exception( # noqa: TRY002
"Cannot generate report on model you already removed observers from"
)
# keep track of all the reports of interest and their outputs
reports_of_interest = {}
for detector in self._desired_report_detectors:
# generate the individual report for the detector
report_output = detector.generate_detector_report(self._model)
reports_of_interest[detector.get_detector_name()] = report_output
# if user wishes to remove inserted observers, go ahead and remove
if remove_inserted_observers:
self._removed_observers = True
# get the set of all Observers inserted by this instance of ModelReport
all_observers_of_interest: Set[str] = set()
for desired_report in self._detector_name_to_observer_fqns:
observers_of_interest = self._detector_name_to_observer_fqns[
desired_report
]
all_observers_of_interest.update(observers_of_interest)
# go through all_observers_of_interest and remove them from the graph and model
for observer_fqn in all_observers_of_interest:
# remove the observer from the model
self._model.delete_submodule(observer_fqn)
# remove the observer from the graph structure
node_obj = self._get_node_from_fqn(observer_fqn)
if node_obj:
self._model.graph.erase_node(node_obj)
else:
raise ValueError("Node no longer exists in GraphModule structure")
# remember to recompile the model
self._model.recompile()
# save the generated reports for visualization purposes
saved_reports: Dict[str, Dict] = {
report_name: report_tuple[1]
for report_name, report_tuple in reports_of_interest.items()
}
self._generated_reports = saved_reports
# return the reports of interest
return reports_of_interest
def _is_same_info_for_same_key(self, info_dict_a: Dict, info_dict_b: Dict) -> bool:
r"""
Takes in two dictionaries and ensures that any common keys between the two have the same
values.
Args:
info_dict_a (Dict): First dictionary we wish to compare
info_dict_b (Dict): Second dictionary we wish to compare
Returns True if all shared keys have same values, false otherwise
"""
# get the set of keys for both
dict_a_keys: Set = set(info_dict_a.keys())
dict_b_keys: Set = set(info_dict_b.keys())
# get the intersection of the keys and check that each shared key maps to the same value in both dicts
intersecting_keys: Set = dict_a_keys.intersection(dict_b_keys)
for key in intersecting_keys:
dict_a_val = info_dict_a[key]
dict_b_val = info_dict_b[key]
# if it's a tensor we have to handle separately
if type(dict_a_val) == torch.Tensor:
# if dict_b_val not tensor, automatically false
if (
type(dict_b_val) != torch.Tensor
or sum(dict_a_val != dict_b_val) != 0
):
return False
else:
# for non-tensor vals
if dict_a_val != dict_b_val:
return False
# if no non matching shared keys found, return true
return True
def _reformat_reports_for_visualizer(self) -> OrderedDict:
r"""
Takes the generated reports and reformats them into the format that is desired by the
ModelReportVisualizer
Returns an OrderedDict mapping module_fqns to their features
"""
# we want to reorder and reformat the information so it follows the order in which
# the modules appear in the model
# first create new dict with all modules as keys and features under respective module
module_fqns_to_features: Dict[str, Dict] = {}
for report_name in self._generated_reports:
# get mod -> feature dict and go through
module_info = self._generated_reports[report_name]
for module_fqn in module_info:
# check if already in our accumulation dict
if module_fqn in module_fqns_to_features:
# we merge all the features together
new_info: Dict = module_info[module_fqn]
present_info: Dict = module_fqns_to_features[module_fqn]
# merge them together into the new unioned dict
# same feature keys -> same info, so it is okay to override
# do safety check to make sure shared keys have same info
if self._is_same_info_for_same_key(new_info, present_info):
module_fqns_to_features[module_fqn] = {
**new_info,
**present_info,
}
else:
error_str = "You have the same key with different values across detectors. "
error_str += "Someone incorrectly implemented a detector with conflicting keys to existing detectors."
raise ValueError(error_str)
else:
# we just set it
module_fqns_to_features[module_fqn] = module_info[module_fqn]
# our ordered dict so that modules can be ordered in order of how they appear in model
features_by_module: OrderedDict[str, Dict] = OrderedDict()
# we loop through modules in graph in order
for fqn, module in self._model.named_modules():
# find that fqn in fqns_to_features
if fqn in module_fqns_to_features:
# add it to our ordered dict
features_by_module[fqn] = module_fqns_to_features[fqn]
# return the ordered dict of info we created
return features_by_module
def generate_visualizer(self) -> ModelReportVisualizer:
r"""
Generates a ModelReportVisualizer instance using the reports generated
by the generate_model_report() method.
Returns the generated ModelReportVisualizer instance initialized
Note:
Throws exception if attempt to get visualizers without generating report
"""
# check if user has generated reports at least once
if len(self._generated_reports) == 0:
raise Exception( # noqa: TRY002
"Unable to generate visualizers without first generating reports"
)
# get the ordered dict mapping modules to their full set of collected features / stats
module_fqns_to_features: OrderedDict = self._reformat_reports_for_visualizer()
# create and return ModelReportVisualizer instance
visualizer: ModelReportVisualizer = ModelReportVisualizer(
module_fqns_to_features
)
return visualizer
def _generate_qconfig_mapping_helper(
self,
detector_qconfig_info_combined: Dict[str, DetectorQConfigInfo],
generation_function: Callable,
) -> QConfigMapping:
r"""
This helper takes in the compiled detector qconfig info that
has been compiled together and merges it into a QConfigMapping
"""
# keep track of the qconfigmapping
qconfig_mapping = QConfigMapping()
# loop through each module / fqn and attempt to create QConfigMapping
for fqn, module in self._model.named_modules():
# if we have a qconfig info for this module
if fqn in detector_qconfig_info_combined:
qconfig_info_compiled = detector_qconfig_info_combined[fqn]
# now generate the qconfig and add it to the mapping
generated_qconfig = generation_function(qconfig_info_compiled, module)
# add to our config
qconfig_mapping.set_module_name(fqn, generated_qconfig)
# return compiled mapping
return qconfig_mapping
def _update_detector_quantizaiton_qconfig_info(
self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo
):
r"""
Takes in the old and new information and updates the combined information.
Args:
combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in
new_info (DetectorQConfigInfo): The DetectorQConfigInfo containing the new information
we are trying to merge into combined_info
"""
combined_info.is_activation_dynamic = (
combined_info.is_activation_dynamic or new_info.is_activation_dynamic
)
combined_info.is_weight_per_channel = (
combined_info.is_weight_per_channel or new_info.is_weight_per_channel
)
def _update_detector_equalization_qconfig_info(
self, combined_info: DetectorQConfigInfo, new_info: DetectorQConfigInfo
):
r"""
Takes in the old and new information and updates the combined information.
Args:
combined_info (DetectorQConfigInfo): The DetectorQConfigInfo we are compiling all of the information in
new_info (DetectorQConfigInfo): The DetectorQConfigInfo containing the new information
we are trying to merge into combined_info
"""
is_equalization_recommended = (
combined_info.is_equalization_recommended
or new_info.is_equalization_recommended
)
combined_info.is_equalization_recommended = is_equalization_recommended
def _generate_module_fqn_to_detector_info_mapping(
self, update_qconfig_info_function: Callable
) -> Dict[str, DetectorQConfigInfo]:
r"""
Generates a mapping from module fqns to DetectorQConfigInfo objects based on
the suggestions of the ModelReport API. The generated mapping combines the
different types of feedback from the different detectors into one place.
These configs are based on the suggestions provided by the ModelReport API
and can only be generated once the reports have been generated.
Args:
update_qconfig_info_function (Callable) takes in a function that takes in two DetectorQConfigInfo
and updates the one that is being compiled
Returns a Dict mapping module_fqns to DetectorQConfigInfo objects
Note:
Throws exception if we try to generate mapping on model we already removed observers from
Throws exception if we try to generate mapping without preparing for calibration
"""
# if we haven't prepped model for calibration, then we shouldn't generate mapping yet
if not self._prepared_flag:
raise Exception( # noqa: TRY002
"Cannot generate report without preparing model for callibration"
)
# if we already removed the observers, we cannot mapping
if self._removed_observers:
raise Exception( # noqa: TRY002
"Cannot generate report on model you already removed observers from"
)
# keep track of qconfig info for each module across detectors
detector_qconfig_info_combined: Dict[str, DetectorQConfigInfo] = {}
for detector in self._desired_report_detectors:
# get the info from the detector
detector_info: Dict[str, DetectorQConfigInfo] = detector.get_qconfig_info(
self._model
)
# we go through the modules
for module_fqn in detector_info:
# see if we already have info on it
if module_fqn in detector_qconfig_info_combined:
# we combine the current options with what is there
current_options = detector_qconfig_info_combined[module_fqn]
detector_options = detector_info[module_fqn]
update_qconfig_info_function(current_options, detector_options)
else:
# we just use this for now
detector_qconfig_info_combined[module_fqn] = detector_info[
module_fqn
]
return detector_qconfig_info_combined
def generate_qconfig_mapping(self) -> QConfigMapping:
r"""
Generates a QConfigMapping based on the suggestions of the
ModelReport API. The generated mapping combines the
different types of feedback from the different detectors
into one place.
These configs are based on the suggestions provided by the ModelReport API
and can only be generated once the reports have been generated.
Returns a QConfigMapping for the quantization configuration
Note:
Throws exception if we try to generate mapping on model we already removed observers from
Throws exception if we try to generate mapping without preparing for calibration
"""
# get the mapping info
detector_qconfig_info_combined = (
self._generate_module_fqn_to_detector_info_mapping(
self._update_detector_quantizaiton_qconfig_info
)
)
# we will do a bit of processing and remove fqns that don't have input weight recommended
# now we generate the QConfig for each of the options
mapping: QConfigMapping = self._generate_qconfig_mapping_helper(
detector_qconfig_info_combined, self._quantization_config_generator
)
# return the generated mapping
return mapping
def _quantization_config_generator(
self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module
) -> QConfig:
r"""
Returns the quantization configuration generated by the DetectorQConfigInfo object
"""
return detector_qconfig_info.generate_quantization_qconfig(module)
def _equalization_config_generator(
self, detector_qconfig_info: DetectorQConfigInfo, module: torch.nn.Module
) -> EqualizationQConfig:
r"""
We ignore the module argument here, and only focus on the detector_qconfig_info
Returns the equalization configuration generated by the DetectorQConfigInfo object
"""
return detector_qconfig_info.generate_equalization_qconfig()
def generate_equalization_mapping(self) -> QConfigMapping:
r"""
Generates a QConfigMapping based on the suggestions of the
ModelReport API for equalization. The generated mapping encompasses all the
different types of feedback from the input-weight equalization detector.
These configs are based on the suggestions provided by the ModelReport API
and can only be generated once the reports have been generated.
Returns a QConfigMapping for the equalization configuration
"""
# get the mapping info
detector_qconfig_info_combined = (
self._generate_module_fqn_to_detector_info_mapping(
self._update_detector_equalization_qconfig_info
)
)
# now we generate the QConfig for each of the options
mapping: QConfigMapping = self._generate_qconfig_mapping_helper(
detector_qconfig_info_combined, self._equalization_config_generator
)
# return the generated mapping
return mapping
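A minimal sketch of feeding the generated mappings back into the FX workflow, assuming `tracer_reporter` is a ModelReport instance whose reports have already been generated, `float_model` and `example_inputs` are hypothetical, and assuming the version of prepare_fx in use accepts the private `_equalization_config` argument:

    from torch.ao.quantization.quantize_fx import prepare_fx

    qconfig_mapping = tracer_reporter.generate_qconfig_mapping()
    equalization_mapping = tracer_reporter.generate_equalization_mapping()
    # Prepare the float model with the suggested quantization and equalization configs
    prepared_model = prepare_fx(
        float_model,
        qconfig_mapping,
        example_inputs,
        _equalization_config=equalization_mapping,
    )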


@@ -0,0 +1,285 @@
# mypy: allow-untyped-defs
import torch
from torch.ao.quantization.observer import ObserverBase
class ModelReportObserver(ObserverBase):
r"""This observer is used to record additional information regarding keeping track
of S = average_batch_activation_range/epoch_activation_range.
The purpose of this information is to prepare a report to present to users on whether
Dynamic or Static Quantization is more appropriate for their model given the general
distributions of their data.
Args:
ch_axis (int, optional): The channel axis for which the range and outlier stats are computed
Default: 1
comp_percentile (float, optional): The percentile to compare against 100 percentile to find outliers
Should be between 0 and 1 exclusive
Default: 0.9
* :attr:`num_batches_tracked` specifies number of batches passed through the observer
* :attr:`average_batch_activation_range` defines average across the ranges of each batch passed through
* :attr:`epoch_activation_min` defines the minimum value passed through the observer
* :attr:`epoch_activation_max` defines the maximum value passed through the observer
* :attr:`ch_axis` defines the channel being used to compute per channel min max stats
* :attr:`min_val` defines the per channel minimum values passed through
* :attr:`max_val` defines the per channel maximum values passed through
* :attr:`comp_percentile` defines comparison percentile to find outliers
* :attr:`average_percentile_ratio` defines the per channel average percentile ratios
* :attr:`percentile_batches_tracked` defines the number of percentile batches tracked for each channel
    * :attr:`constant_channels` defines, per channel, the number of batches in which the channel was constant
Note: this tool is meant for FX Graph Mode Quantization
"""
epoch_activation_min: torch.Tensor
epoch_activation_max: torch.Tensor
min_val: torch.Tensor
max_val: torch.Tensor
comp_percentile: torch.Tensor
average_percentile_ratio: torch.Tensor
percentile_batches_tracked: torch.Tensor
constant_channels: torch.Tensor
def __init__(self, ch_axis: int = 1, comp_percentile: float = 0.9):
super().__init__(torch.qint8)
self.num_batches_tracked = 0
        # keep track of the min and max of the range for the average batch and for the epoch as a whole
self.average_batch_activation_range: torch.Tensor = torch.tensor(float(0))
self.register_buffer("epoch_activation_min", torch.tensor(float("inf")))
self.register_buffer("epoch_activation_max", torch.tensor(float("-inf")))
# keep track of per channel min max information using the given channel
self.ch_axis: int = ch_axis
self.register_buffer("min_val", torch.tensor([]))
self.register_buffer("max_val", torch.tensor([]))
# keep track of percentile ratio information per channel
self.register_buffer("comp_percentile", torch.tensor([comp_percentile]))
self.register_buffer("average_percentile_ratio", torch.tensor([]))
self.register_buffer("percentile_batches_tracked", torch.tensor([]))
self.register_buffer("constant_channels", torch.tensor([]))
def forward(self, x):
x_copy = x.detach() # avoid keeping autograd tape
x_copy = x_copy.to(self.epoch_activation_min.dtype)
x_copy = self._calculate_range_stats(x_copy)
x_copy = self._calculate_min_max_stats(x_copy)
x_copy = self._calculate_percentile_stats(x_copy)
        # return the passed-in value
return x
def _calculate_range_stats(self, x_copy):
r"""Calculates and stores range stats with forward values.
        Args:
x_copy: A copy of the forward data
Returns the passed in x_copy
"""
# get the min, max values of the data
min_val_cur, max_val_cur = torch.aminmax(x_copy)
# calculate new epoch range values
epoch_min_val = torch.min(self.epoch_activation_min, min_val_cur)
epoch_max_val = torch.max(self.epoch_activation_max, max_val_cur)
self.epoch_activation_min.copy_(epoch_min_val)
self.epoch_activation_max.copy_(epoch_max_val)
# calculate the average batch activation range
current_batch_range = max_val_cur - min_val_cur
new_range = (
self.average_batch_activation_range * self.num_batches_tracked
+ current_batch_range
) / (self.num_batches_tracked + 1)
self.average_batch_activation_range = new_range
self.num_batches_tracked += 1 # new batch was processed
return x_copy
def _calculate_min_max_stats(self, x_copy):
r"""Calculates and stores the per_channel min, max stats with forward values.
Does calculation based on channel axis: self.ch_axis
        Args:
x_copy: A copy of the forward data
Returns the passed in x_copy
"""
# get the current min and max vals
min_val = self.min_val
max_val = self.max_val
x_dim = x_copy.size()
new_axis_list = [i for i in range(len(x_dim))] # noqa: C416
new_axis_list[self.ch_axis] = 0
new_axis_list[0] = self.ch_axis
y = x_copy.permute(new_axis_list)
# Need to match dtype of min/max because the updates to buffers
# are done in place and types need to match for comparisons
y = y.to(self.min_val.dtype)
y = torch.flatten(y, start_dim=1)
if min_val.numel() == 0 or max_val.numel() == 0:
min_val, max_val = torch.aminmax(y, dim=1)
else:
min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
min_val = torch.min(min_val_cur, min_val)
max_val = torch.max(max_val_cur, max_val)
self.min_val.resize_(min_val.shape)
self.max_val.resize_(max_val.shape)
self.min_val.copy_(min_val)
self.max_val.copy_(max_val)
return x_copy
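    # Editor's note (illustrative sketch, not part of the upstream file): for an
    # input of shape (N, C, H, W) observed with ch_axis=1, the permute above moves
    # the channel dimension to dim 0 and the flatten yields shape (C, N*H*W), so
    # aminmax(dim=1) returns one (min, max) pair per channel. Roughly:
    #
    #     x = torch.randn(8, 3, 16, 16)                      # hypothetical batch
    #     y = x.permute(1, 0, 2, 3).flatten(start_dim=1)     # -> shape (3, 2048)
    #     per_channel_min, per_channel_max = torch.aminmax(y, dim=1)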
def _calculate_percentile_stats(self, x_copy):
r"""Calculates and stores the per_channel percentile stats with forward values.
Does calculation based on channel axis: self.ch_axis
        Args:
x_copy: A copy of the forward data
Returns the passed in x_copy
"""
# get the dimension of the copy
x_dim = x_copy.size()
new_axis_list = [i for i in range(len(x_dim))] # noqa: C416
new_axis_list[self.ch_axis] = 0
new_axis_list[0] = self.ch_axis
y = x_copy.permute(new_axis_list)
# Need to match dtype of min/max because the updates to buffers
# are done in place and types need to match for comparisons
y = y.to(self.min_val.dtype)
y = torch.flatten(y, start_dim=1)
y = y.to(dtype=self.min_val.dtype, device="cpu")
# find the percentile values along the axis
# we want both 100th percentile and comp_percentile
        # we also want to find the 0th quantile to see if we have a constant channel
quantiles_list = [0, self.comp_percentile, 1.00]
quantiles_to_find = torch.tensor(quantiles_list, dtype=self.min_val.dtype)
# find the quantiles
desired_quantiles = torch.quantile(
y, quantiles_to_find, dim=self.ch_axis, interpolation="lower"
)
zero_quantile = desired_quantiles[0]
comp_quantile = desired_quantiles[1]
        hundredth_quantile = desired_quantiles[2]
# if any of the channels have 0s, we ignore that channel for this calculation
any_non_zero_quantile_value: torch.Tensor = (
comp_quantile != torch.tensor([0])
        ) | (hundredth_quantile != torch.tensor([0]))
any_non_zero_quantile_value = (
any_non_zero_quantile_value.int()
) # transform boolean values to int values
# we also check if we have a constant channel
any_constant_channels: torch.Tensor = (
            hundredth_quantile - zero_quantile
) == torch.tensor([0])
any_constant_channels = (
any_constant_channels.int()
) # transform boolean values to int values
# possibilities to get nan as an answer
# will ignore any of these three cases with 0s and just not deal with them for now
        # case (1) 0 in numerator: only an issue if 0 is the largest value, i.e. everything else is negative
        # case (2) 0 in denominator: possible unless we are in case (3); we just ignore it
        # case (3) 0 in both: not an outlier, the channel is just not informative; ignore it
# get the ratio and get rid of nan values
        quantile_ratios = hundredth_quantile / comp_quantile
quantile_ratios = torch.nan_to_num(quantile_ratios)
        # update averages, remembering to only update if we didn't have zeros
ratio_if_not_zero = any_non_zero_quantile_value * quantile_ratios
# if num_batches and average_ratio are not initialized, we want to initialize them
if (
self.percentile_batches_tracked.shape[0] == 0
or self.average_percentile_ratio.shape[0] == 0
):
self.percentile_batches_tracked = torch.zeros_like(
any_non_zero_quantile_value
)
self.average_percentile_ratio = torch.zeros_like(ratio_if_not_zero)
# also initialize the constant channel var if that is not initialized separately
if self.constant_channels.shape[0] == 0:
self.constant_channels = torch.zeros_like(any_constant_channels)
# get current num batches and average ratio
num_batches = self.percentile_batches_tracked
average_ratio = self.average_percentile_ratio
        # calculate the new number of batches and new ratios, and get rid of nans caused by 0-size batches
new_number_of_batches: torch.Tensor = num_batches + any_non_zero_quantile_value
new_ratios: torch.Tensor = (
(average_ratio * num_batches) + ratio_if_not_zero
) / new_number_of_batches
new_ratios = torch.nan_to_num(new_ratios)
# update the number of non-constant channels
new_constant_count: torch.Tensor = (
self.constant_channels + any_constant_channels
)
# update the values locally
self.percentile_batches_tracked.copy_(new_number_of_batches)
self.average_percentile_ratio.copy_(new_ratios)
self.constant_channels.copy_(new_constant_count)
return x_copy
@torch.jit.export
def get_batch_to_epoch_ratio(self):
epoch_activation_range = self.epoch_activation_max - self.epoch_activation_min
if epoch_activation_range == torch.tensor(float(0)):
raise ValueError("Range for Epoch is 0")
elif epoch_activation_range == torch.tensor(float("inf")):
raise ValueError(
"No data has been run through observer or infinity value present"
)
else:
return self.average_batch_activation_range / epoch_activation_range
@torch.jit.export
def reset_batch_and_epoch_values(self):
# set all the values back to their original defaults for a new epoch
# keep device
device = self.max_val.device
self.num_batches_tracked = 0
self.average_batch_activation_range = torch.tensor(float(0), device=device)
self.epoch_activation_min = torch.tensor(float("inf"), device=device)
self.epoch_activation_max = torch.tensor(float("-inf"), device=device)
self.min_val = torch.tensor([], device=device)
self.max_val = torch.tensor([], device=device)
self.average_percentile_ratio = torch.tensor([], device=device)
self.percentile_batches_tracked = torch.tensor([], device=device)
self.constant_channels = torch.tensor([], device=device)
@torch.jit.export
def calculate_qparams(self):
raise Exception( # noqa: TRY002
"calculate_qparams should not be called for ModelReportObserver"
)
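# --- Editor's usage sketch (not part of the upstream file) ---
# A small, self-contained example of exercising the observer directly; the random
# input below is a placeholder for real calibration data.
#
#     import torch
#
#     obs = ModelReportObserver(ch_axis=1)
#     for _ in range(4):                       # simulate four calibration batches
#         obs(torch.randn(8, 3, 16, 16))       # (N, C, H, W) activations
#     ratio = obs.get_batch_to_epoch_ratio()   # S = avg batch range / epoch range
#     obs.reset_batch_and_epoch_values()       # start tracking a new epoch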

View File

@ -0,0 +1,713 @@
# mypy: allow-untyped-defs
from collections import OrderedDict as OrdDict
from typing import Any, Dict, List, OrderedDict, Set, Tuple
import torch
# try to import tabulate
got_tabulate = True
try:
from tabulate import tabulate
except ImportError:
got_tabulate = False
# var to see if we could import matplotlib
got_matplotlib = True
try:
import matplotlib.pyplot as plt
except ImportError:
got_matplotlib = False
class ModelReportVisualizer:
r"""
The ModelReportVisualizer class aims to provide users a way to visualize some of the statistics
that were generated by the ModelReport API. However, at a higher level, the class aims to provide
    some level of visualization of statistics to PyTorch users in order to make it easier to parse data and
diagnose any potential issues with data or a specific model. With respect to the visualizations,
the ModelReportVisualizer class currently supports several methods of visualizing data.
Supported Visualization Methods Include:
- Table format
- Plot format (line graph)
- Histogram format
For all of the existing visualization methods, there is the option to filter data based on:
- A module fqn prefix
- Feature [required for the plot and histogram]
* :attr:`generated_reports` The reports generated by the ModelReport class in the structure below
    Ensure that features that are the same across different reports have the same name
Ensure that objects representing the same features are the same type / dimension (where applicable)
Note:
Currently, the ModelReportVisualizer class supports visualization of data generated by the
ModelReport class. However, this structure is extensible and should allow the visualization of
other information as long as the information is structured in the following general format:
Report Structure
-- module_fqn [module with attached detectors]
|
-- feature keys [not every detector extracts same information]
[same collected info has same keys, unless can be specific to detector]
The goal behind the class is that the generated visualizations can be used in conjunction with the generated
report for people to get a better understanding of issues and what the fix might be. It is also just to provide
    a good visualization platform, since it might be hard to parse through the dictionary returned by ModelReport as
    it grows in size.
General Use Flow Expected
1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects
2.) Prepare your model with prepare_fx
3.) Call model_report.prepare_detailed_calibration on your model to add relevant observers
    4.) Calibrate your model with data
5.) Call model_report.generate_report on your model to generate report and optionally remove added observers
6.) Use output of model_report.generate_report to initialize ModelReportVisualizer instance
7.) Use instance to view different views of data as desired, applying filters as needed
8.) Either see the super detailed information or just the actual printed or shown table / plot / histogram
"""
# keys for table dict
TABLE_TENSOR_KEY = "tensor_level_info"
TABLE_CHANNEL_KEY = "channel_level_info"
# Constants for header vals
NUM_NON_FEATURE_TENSOR_HEADERS = 2
NUM_NON_FEATURE_CHANNEL_HEADERS = 3
# Constants for row index in header
CHANNEL_NUM_INDEX = 2
def __init__(self, generated_reports: OrderedDict[str, Any]):
r"""
Initializes the ModelReportVisualizer instance with the necessary reports.
Args:
generated_reports (Dict[str, Any]): The reports generated by the ModelReport class
can also be a dictionary generated in another manner, as long as format is same
"""
self.generated_reports = generated_reports
def get_all_unique_module_fqns(self) -> Set[str]:
r"""
The purpose of this method is to provide a user the set of all module_fqns so that if
they wish to use some of the filtering capabilities of the ModelReportVisualizer class,
they don't need to manually parse the generated_reports dictionary to get this information.
Returns all the unique module fqns present in the reports the ModelReportVisualizer
instance was initialized with.
"""
# returns the keys of the ordered dict
return set(self.generated_reports.keys())
def get_all_unique_feature_names(
self, plottable_features_only: bool = True
) -> Set[str]:
r"""
The purpose of this method is to provide a user the set of all feature names so that if
they wish to use the filtering capabilities of the generate_table_view(), or use either of
the generate_plot_view() or generate_histogram_view(), they don't need to manually parse
the generated_reports dictionary to get this information.
Args:
plottable_features_only (bool): True if the user is only looking for plottable features,
False otherwise
plottable features are those that are tensor values
Default: True (only return those feature names that are plottable)
        Returns all the unique feature names present in the reports the ModelReportVisualizer
        instance was initialized with.
"""
unique_feature_names = set()
for module_fqn in self.generated_reports:
# get dict of the features
feature_dict: Dict[str, Any] = self.generated_reports[module_fqn]
# loop through features
for feature_name in feature_dict:
# if we need plottable, ensure type of val is tensor
if (
not plottable_features_only
or type(feature_dict[feature_name]) == torch.Tensor
):
unique_feature_names.add(feature_name)
# return our compiled set of unique feature names
return unique_feature_names
def _get_filtered_data(
self, feature_filter: str, module_fqn_filter: str
) -> OrderedDict[str, Any]:
r"""
Filters the data and returns it in the same ordered dictionary format so the relevant views can be displayed.
Args:
feature_filter (str): The feature filter, if we want to filter the set of data to only include
a certain set of features that include feature_filter
If feature = "", then we do not filter based on any features
module_fqn_filter (str): The filter on prefix for the module fqn. All modules that have fqn with
this prefix will be included
If module_fqn_filter = "" we do not filter based on module fqn, and include all modules
First, the data is filtered based on module_fqn, and then filtered based on feature
Returns an OrderedDict (sorted in order of model) mapping:
module_fqns -> feature_names -> values
"""
# create return dict
filtered_dict: OrderedDict[str, Any] = OrdDict()
for module_fqn in self.generated_reports:
# first filter based on module
if module_fqn_filter == "" or module_fqn_filter in module_fqn:
# create entry for module and loop through features
filtered_dict[module_fqn] = {}
module_reports = self.generated_reports[module_fqn]
for feature_name in module_reports:
# check if filtering on features and do so if desired
if feature_filter == "" or feature_filter in feature_name:
filtered_dict[module_fqn][feature_name] = module_reports[
feature_name
]
# we have populated the filtered dict, and must return it
return filtered_dict
def _generate_tensor_table(
self,
filtered_data: OrderedDict[str, Dict[str, Any]],
tensor_features: List[str],
) -> Tuple[List, List]:
r"""
Takes in the filtered data and features list and generates the tensor headers and table
        Currently meant to generate the headers and table for the tensor-level information.
Args:
filtered_data (OrderedDict[str, Dict[str, Any]]): An OrderedDict (sorted in order of model) mapping:
module_fqns -> feature_names -> values
tensor_features (List[str]): A list of the tensor level features
Returns a tuple with:
A list of the headers of the tensor table
A list of lists containing the table information row by row
The 0th index row will contain the headers of the columns
The rest of the rows will contain data
"""
# now we compose the tensor information table
tensor_table: List[List[Any]] = []
tensor_headers: List[str] = []
# append the table row to the table only if we have features
if len(tensor_features) > 0:
# now we add all the data
for index, module_fqn in enumerate(filtered_data):
# we make a new row for the tensor table
tensor_table_row = [index, module_fqn]
for feature in tensor_features:
# we iterate in same order of added features
if feature in filtered_data[module_fqn]:
# add value if applicable to module
feature_val = filtered_data[module_fqn][feature]
else:
# add that it is not applicable
feature_val = "Not Applicable"
# if it's a tensor we want to extract val
if isinstance(feature_val, torch.Tensor):
feature_val = feature_val.item()
# we add to our list of values
tensor_table_row.append(feature_val)
tensor_table.append(tensor_table_row)
        # add the row of headers if we actually have something, otherwise leave it empty
if len(tensor_table) != 0:
tensor_headers = ["idx", "layer_fqn"] + tensor_features
return (tensor_headers, tensor_table)
def _generate_channels_table(
self,
filtered_data: OrderedDict[str, Any],
channel_features: List[str],
num_channels: int,
) -> Tuple[List, List]:
r"""
Takes in the filtered data and features list and generates the channels headers and table
        Currently meant to generate the headers and table for the channel-level information.
Args:
filtered_data (OrderedDict[str, Any]): An OrderedDict (sorted in order of model) mapping:
module_fqns -> feature_names -> values
channel_features (List[str]): A list of the channel level features
num_channels (int): Number of channels in the channel data
Returns a tuple with:
A list of the headers of the channel table
A list of lists containing the table information row by row
The 0th index row will contain the headers of the columns
The rest of the rows will contain data
"""
# now we compose the table for the channel information table
channel_table: List[List[Any]] = []
channel_headers: List[str] = []
        # counter to keep track of the number of entries in the channel table
channel_table_entry_counter: int = 0
if len(channel_features) > 0:
# now we add all channel data
for module_fqn in filtered_data:
# we iterate over all channels
for channel in range(num_channels):
# we make a new row for the channel
new_channel_row = [channel_table_entry_counter, module_fqn, channel]
for feature in channel_features:
if feature in filtered_data[module_fqn]:
# add value if applicable to module
feature_val = filtered_data[module_fqn][feature][channel]
else:
# add that it is not applicable
feature_val = "Not Applicable"
# if it's a tensor we want to extract val
if type(feature_val) is torch.Tensor:
feature_val = feature_val.item()
# add value to channel specific row
new_channel_row.append(feature_val)
# add to table and increment row index counter
channel_table.append(new_channel_row)
channel_table_entry_counter += 1
        # add the row of headers if we actually have something, otherwise leave it empty
if len(channel_table) != 0:
channel_headers = ["idx", "layer_fqn", "channel"] + channel_features
return (channel_headers, channel_table)
def generate_filtered_tables(
self, feature_filter: str = "", module_fqn_filter: str = ""
) -> Dict[str, Tuple[List, List]]:
r"""
Takes in optional filter values and generates two tables with desired information.
        The generated tables are each presented in a list-of-lists format.
        The reason for the two tables is that they handle different things:
1.) the first table handles all tensor level information
2.) the second table handles and displays all channel based information
The reasoning for this is that having all the info in one table can make it ambiguous which collected
statistics are global, and which are actually per-channel, so it's better to split it up into two
tables. This also makes the information much easier to digest given the plethora of statistics collected
Tensor table columns:
idx layer_fqn feature_1 feature_2 feature_3 .... feature_n
---- --------- --------- --------- --------- ---------
Per-Channel table columns:
idx layer_fqn channel feature_1 feature_2 feature_3 .... feature_n
---- --------- ------- --------- --------- --------- ---------
Args:
feature_filter (str, optional): Filters the features presented to only those that
contain this filter substring
Default = "", results in all the features being printed
            module_fqn_filter (str, optional): Only includes modules that contain this string
Default = "", results in all the modules in the reports to be visible in the table
Returns a dictionary with two keys:
(Dict[str, Tuple[List, List]]) A dict containing two keys:
"tensor_level_info", "channel_level_info"
Each key maps to a tuple with:
A list of the headers of each table
A list of lists containing the table information row by row
The 0th index row will contain the headers of the columns
The rest of the rows will contain data
Example Use:
>>> # xdoctest: +SKIP("undefined variables")
>>> mod_report_visualizer.generate_filtered_tables(
... feature_filter = "per_channel_min",
... module_fqn_filter = "block1"
... ) # generates table with per_channel_min info for all modules in block 1 of the model
"""
# first get the filtered data
filtered_data: OrderedDict[str, Any] = self._get_filtered_data(
feature_filter, module_fqn_filter
)
# now we split into tensor and per-channel data
tensor_features: Set[str] = set()
channel_features: Set[str] = set()
# keep track of the number of channels we have
num_channels: int = 0
for module_fqn in filtered_data:
for feature_name in filtered_data[module_fqn]:
# get the data for that specific feature
feature_data = filtered_data[module_fqn][feature_name]
# check if not zero dim tensor
is_tensor: bool = isinstance(feature_data, torch.Tensor)
is_not_zero_dim: bool = is_tensor and len(feature_data.shape) != 0
if is_not_zero_dim or isinstance(feature_data, list):
                    # a non-zero-dim tensor or a list means the feature is per-channel
channel_features.add(feature_name)
num_channels = len(feature_data)
else:
# means is per-tensor
tensor_features.add(feature_name)
# we make them lists for iteration purposes
tensor_features_list: List[str] = sorted(tensor_features)
channel_features_list: List[str] = sorted(channel_features)
# get the tensor info
tensor_headers, tensor_table = self._generate_tensor_table(
filtered_data, tensor_features_list
)
# get the channel info
channel_headers, channel_table = self._generate_channels_table(
filtered_data, channel_features_list, num_channels
)
# let's now create the dictionary to return
table_dict = {
self.TABLE_TENSOR_KEY: (tensor_headers, tensor_table),
self.TABLE_CHANNEL_KEY: (channel_headers, channel_table),
}
# return the two tables
return table_dict
def generate_table_visualization(
self, feature_filter: str = "", module_fqn_filter: str = ""
):
r"""
Takes in optional filter values and prints out formatted tables of the information.
        The reason two tables are printed out instead of one large one is that they handle different things:
1.) the first table handles all tensor level information
2.) the second table handles and displays all channel based information
The reasoning for this is that having all the info in one table can make it ambiguous which collected
statistics are global, and which are actually per-channel, so it's better to split it up into two
tables. This also makes the information much easier to digest given the plethora of statistics collected
Tensor table columns:
idx layer_fqn feature_1 feature_2 feature_3 .... feature_n
---- --------- --------- --------- --------- ---------
Per-Channel table columns:
idx layer_fqn channel feature_1 feature_2 feature_3 .... feature_n
---- --------- ------- --------- --------- --------- ---------
Args:
feature_filter (str, optional): Filters the features presented to only those that
contain this filter substring
Default = "", results in all the features being printed
            module_fqn_filter (str, optional): Only includes modules that contain this string
Default = "", results in all the modules in the reports to be visible in the table
Example Use:
>>> # xdoctest: +SKIP("undefined variables")
>>> mod_report_visualizer.generate_table_visualization(
... feature_filter = "per_channel_min",
... module_fqn_filter = "block1"
... )
>>> # prints out neatly formatted table with per_channel_min info
>>> # for all modules in block 1 of the model
"""
# see if we got tabulate
if not got_tabulate:
print("Make sure to install tabulate and try again.")
return None
# get the table dict and the specific tables of interest
table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter)
tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY]
channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY]
# get the table string and print it out
# now we have populated the tables for each one
# let's create the strings to be returned
table_str = ""
# the tables will have some headers columns that are non-feature
# ex. table index, module name, channel index, etc.
        # we want to look at the header columns for features, which come after those non-feature headers
if len(tensor_headers) > self.NUM_NON_FEATURE_TENSOR_HEADERS:
# if we have at least one tensor level feature to be added we add tensor table
table_str += "Tensor Level Information \n"
table_str += tabulate(tensor_table, headers=tensor_headers)
if len(channel_headers) > self.NUM_NON_FEATURE_CHANNEL_HEADERS:
            # if we have at least one channel level feature to be added we add the channel table
table_str += "\n\n Channel Level Information \n"
table_str += tabulate(channel_table, headers=channel_headers)
# if no features at all, let user know
if table_str == "":
table_str = "No data points to generate table with."
print(table_str)
def _get_plottable_data(
self, feature_filter: str, module_fqn_filter: str
) -> Tuple[List, List[List], bool]:
r"""
Takes in the feature filters and module filters and outputs the x and y data for plotting
Args:
feature_filter (str): Filters the features presented to only those that
contain this filter substring
            module_fqn_filter (str): Only includes modules that contain this string
        Returns a tuple of three elements:
            The first is a list containing the relevant x-axis data
            The second is a list containing the corresponding y-axis data
            The third is a boolean indicating whether the data is per-channel
"""
# get the table dict and the specific tables of interest
table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter)
tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY]
channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY]
# make sure it is only 1 feature that is being plotted
# get the number of features in each of these
tensor_info_features_count = (
len(tensor_headers) - ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS
)
channel_info_features_count = (
len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS
)
# see if valid tensor or channel plot
is_valid_per_tensor_plot: bool = tensor_info_features_count == 1
is_valid_per_channel_plot: bool = channel_info_features_count == 1
        # the feature column offset is for either the tensor table or the channel table; default to the tensor table
feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS
table = tensor_table
# if a per_channel plot, we have different offset and table
if is_valid_per_channel_plot:
feature_column_offset = (
ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS
)
table = channel_table
x_data: List = []
y_data: List[List] = []
# the feature will either be a tensor feature or channel feature
if is_valid_per_tensor_plot:
for table_row_num, row in enumerate(table):
# get x_value to append
x_val_to_append = table_row_num
                # the index of the feature will be the number of non-feature columns
tensor_feature_index = feature_column_offset
row_value = row[tensor_feature_index]
if not type(row_value) == str:
x_data.append(x_val_to_append)
y_data.append(row_value)
elif is_valid_per_channel_plot:
# gather the x_data and multiple y_data
# calculate the number of channels
num_channels: int = max(row[self.CHANNEL_NUM_INDEX] for row in table) + 1
for channel in range(num_channels):
y_data.append([]) # separate data list per channel
for table_row_num, row in enumerate(table):
# get x_value to append
x_val_to_append = table_row_num
current_channel = row[
self.CHANNEL_NUM_INDEX
                ]  # the channel this row belongs to
new_module_index: int = table_row_num // num_channels
x_val_to_append = new_module_index
                # the index of the feature will be the number of non-feature columns
tensor_feature_index = feature_column_offset
row_value = row[tensor_feature_index]
if not type(row_value) == str:
                    # only append the x value if it corresponds to a new index
if len(x_data) == 0 or x_data[-1] != x_val_to_append:
x_data.append(x_val_to_append)
# append value for that channel
y_data[current_channel].append(row_value)
else:
# more than one feature was chosen
error_str = "Make sure to pick only a single feature with your filter to plot a graph."
error_str += " We recommend calling get_all_unique_feature_names() to find unique feature names."
error_str += " Pick one of those features to plot."
raise ValueError(error_str)
# return x, y values, and if data is per-channel
return (x_data, y_data, is_valid_per_channel_plot)
def generate_plot_visualization(
self, feature_filter: str, module_fqn_filter: str = ""
):
r"""
Takes in a feature and optional module_filter and plots of the desired data.
For per channel features, it averages the value across the channels and plots a point
per module. The reason for this is that for models with hundreds of channels, it can
be hard to differentiate one channel line from another, and so the point of generating
a single average point per module is to give a sense of general trends that encourage
further deep dives.
Note:
Only features in the report that have tensor value data are plottable by this class
When the tensor information is plotted, it will plot:
idx as the x val, feature value as the y_val
When the channel information is plotted, it will plot:
the first idx of each module as the x val, feature value as the y_val [for each channel]
The reason for this is that we want to be able to compare values across the
channels for same layer, and it will be hard if values are staggered by idx
This means each module is represented by only 1 x value
Args:
feature_filter (str): Filters the features presented to only those that
contain this filter substring
            module_fqn_filter (str, optional): Only includes modules that contain this string
Default = "", results in all the modules in the reports to be visible in the table
Example Use:
>>> # xdoctest: +SKIP("undefined variables")
>>> mod_report_visualizer.generate_plot_visualization(
... feature_filter = "per_channel_min",
... module_fqn_filter = "block1"
... )
>>> # outputs line plot of per_channel_min information for all
        >>> # modules in block1 of the model; the per-channel values are averaged
        >>> # into a single line plotted across the in-order modules on the x-axis
"""
        # check if we have matplotlib and let the user know to install it if we don't
        if not got_matplotlib:
            print("Make sure to install matplotlib and try again.")
return None
# get the x and y data and if per channel
x_data, y_data, data_per_channel = self._get_plottable_data(
feature_filter, module_fqn_filter
)
# plot based on whether data is per channel or not
ax = plt.subplot()
ax.set_ylabel(feature_filter)
ax.set_title(feature_filter + " Plot")
plt.xticks(x_data) # only show ticks for actual points
if data_per_channel:
ax.set_xlabel("First idx of module")
# set the legend as well
# plot a single line that is average of the channel values
num_modules = len(
y_data[0]
) # all y_data have same length, so get num modules
num_channels = len(
y_data
) # we want num channels to be able to calculate average later
            avg_vals = [
                sum(y_data[channel][index] for channel in range(num_channels))
                / num_channels
                for index in range(num_modules)
            ]
            # plot the single line of per-module averages
ax.plot(
x_data, avg_vals, label=f"Average Value Across {num_channels} Channels"
)
ax.legend(loc="upper right")
else:
ax.set_xlabel("idx")
ax.plot(x_data, y_data)
# actually show the plot
plt.show()
def generate_histogram_visualization(
self, feature_filter: str, module_fqn_filter: str = "", num_bins: int = 10
):
r"""
Takes in a feature and optional module_filter and plots the histogram of desired data.
Note:
Only features in the report that have tensor value data can be viewed as a histogram
If you want to plot a histogram from all the channel values of a specific feature for
a specific model, make sure to specify both the model and the feature properly
in the filters and you should be able to see a distribution of the channel data
Args:
            feature_filter (str): Filters the features presented to only those that
                contain this filter substring
module_fqn_filter (str, optional): Only includes modules that contains this string
Default = "", results in all the modules in the reports to be visible in the table
num_bins (int, optional): The number of bins to create the histogram with
Default = 10, the values will be split into 10 equal sized bins
Example Use:
>>> # xdoctest: +SKIP
        >>> mod_report_visualizer.generate_histogram_visualization(
... feature_filter = "per_channel_min",
... module_fqn_filter = "block1"
... )
        >>> # outputs a histogram of per_channel_min information for all modules in block1 of the model;
        >>> # the information is gathered across all channels of all modules in block1 and is
        >>> # displayed in a histogram of equally sized bins
"""
        # check if we have matplotlib and let the user know to install it if we don't
        if not got_matplotlib:
            print("Make sure to install matplotlib and try again.")
return None
# get the x and y data and if per channel
x_data, y_data, data_per_channel = self._get_plottable_data(
feature_filter, module_fqn_filter
)
# for histogram, we just care about plotting the y data
# plot based on whether data is per channel or not
ax = plt.subplot()
ax.set_xlabel(feature_filter)
ax.set_ylabel("Frequency")
ax.set_title(feature_filter + " Histogram")
if data_per_channel:
# set the legend as well
# combine all the data
all_data = []
for channel_info in y_data:
all_data.extend(channel_info)
val, bins, _ = plt.hist(
all_data,
bins=num_bins,
stacked=True,
rwidth=0.8,
)
plt.xticks(bins)
else:
val, bins, _ = plt.hist(
y_data,
bins=num_bins,
stacked=False,
rwidth=0.8,
)
plt.xticks(bins)
plt.show()
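# --- Editor's usage sketch (not part of the upstream file) ---
# `report_dict` stands in for the OrderedDict produced by the ModelReport API
# (step 5 of the use flow described in the class docstring).
#
#     visualizer = ModelReportVisualizer(report_dict)
#     print(visualizer.get_all_unique_feature_names())   # discover plottable features
#     tables = visualizer.generate_filtered_tables(
#         feature_filter="per_channel_min", module_fqn_filter="block1"
#     )
#     headers, rows = tables[ModelReportVisualizer.TABLE_CHANNEL_KEY]
#     visualizer.generate_table_visualization(module_fqn_filter="block1")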

File diff suppressed because it is too large

View File

@ -0,0 +1,519 @@
# mypy: allow-untyped-defs
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.quant_type import (
_get_quant_type_to_str,
_quant_type_from_str,
QuantType,
)
__all__ = [
"ConvertCustomConfig",
"FuseCustomConfig",
"PrepareCustomConfig",
"StandaloneModuleConfigEntry",
]
# TODO: replace all usages with these constants
STANDALONE_MODULE_NAME_DICT_KEY = "standalone_module_name"
STANDALONE_MODULE_CLASS_DICT_KEY = "standalone_module_class"
FLOAT_TO_OBSERVED_DICT_KEY = "float_to_observed_custom_module_class"
OBSERVED_TO_QUANTIZED_DICT_KEY = "observed_to_quantized_custom_module_class"
NON_TRACEABLE_MODULE_NAME_DICT_KEY = "non_traceable_module_name"
NON_TRACEABLE_MODULE_CLASS_DICT_KEY = "non_traceable_module_class"
INPUT_QUANTIZED_INDEXES_DICT_KEY = "input_quantized_idxs"
OUTPUT_QUANTIZED_INDEXES_DICT_KEY = "output_quantized_idxs"
PRESERVED_ATTRIBUTES_DICT_KEY = "preserved_attributes"
@dataclass
class StandaloneModuleConfigEntry:
# qconfig_mapping for the prepare function called in the submodule,
# None means use qconfig from parent qconfig_mapping
qconfig_mapping: Optional[QConfigMapping]
example_inputs: Tuple[Any, ...]
prepare_custom_config: Optional[PrepareCustomConfig]
backend_config: Optional[BackendConfig]
class PrepareCustomConfig:
"""
Custom configuration for :func:`~torch.ao.quantization.quantize_fx.prepare_fx` and
:func:`~torch.ao.quantization.quantize_fx.prepare_qat_fx`.
Example usage::
prepare_custom_config = PrepareCustomConfig() \
.set_standalone_module_name("module1", qconfig_mapping, example_inputs, \
child_prepare_custom_config, backend_config) \
.set_standalone_module_class(MyStandaloneModule, qconfig_mapping, example_inputs, \
child_prepare_custom_config, backend_config) \
.set_float_to_observed_mapping(FloatCustomModule, ObservedCustomModule) \
.set_non_traceable_module_names(["module2", "module3"]) \
.set_non_traceable_module_classes([NonTraceableModule1, NonTraceableModule2]) \
.set_input_quantized_indexes([0]) \
.set_output_quantized_indexes([0]) \
.set_preserved_attributes(["attr1", "attr2"])
"""
def __init__(self) -> None:
self.standalone_module_names: Dict[str, StandaloneModuleConfigEntry] = {}
self.standalone_module_classes: Dict[Type, StandaloneModuleConfigEntry] = {}
self.float_to_observed_mapping: Dict[QuantType, Dict[Type, Type]] = {}
self.non_traceable_module_names: List[str] = []
self.non_traceable_module_classes: List[Type] = []
self.input_quantized_indexes: List[int] = []
self.output_quantized_indexes: List[int] = []
self.preserved_attributes: List[str] = []
def __repr__(self):
dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0}
return f"PrepareCustomConfig({dict_nonempty})"
def set_standalone_module_name(
self,
module_name: str,
qconfig_mapping: Optional[QConfigMapping],
example_inputs: Tuple[Any, ...],
prepare_custom_config: Optional[PrepareCustomConfig],
backend_config: Optional[BackendConfig],
) -> PrepareCustomConfig:
"""
Set the configuration for running a standalone module identified by ``module_name``.
If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead.
If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used.
If ``backend_config`` is None, the parent ``backend_config`` will be used instead.
"""
self.standalone_module_names[module_name] = StandaloneModuleConfigEntry(
qconfig_mapping, example_inputs, prepare_custom_config, backend_config
)
return self
def set_standalone_module_class(
self,
module_class: Type,
qconfig_mapping: Optional[QConfigMapping],
example_inputs: Tuple[Any, ...],
prepare_custom_config: Optional[PrepareCustomConfig],
backend_config: Optional[BackendConfig],
) -> PrepareCustomConfig:
"""
Set the configuration for running a standalone module identified by ``module_class``.
If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead.
If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used.
If ``backend_config`` is None, the parent ``backend_config`` will be used instead.
"""
self.standalone_module_classes[module_class] = StandaloneModuleConfigEntry(
qconfig_mapping, example_inputs, prepare_custom_config, backend_config
)
return self
def set_float_to_observed_mapping(
self,
float_class: Type,
observed_class: Type,
quant_type: QuantType = QuantType.STATIC,
) -> PrepareCustomConfig:
"""
Set the mapping from a custom float module class to a custom observed module class.
The observed module class must have a ``from_float`` class method that converts the float module class
to the observed module class. This is currently only supported for static quantization.
"""
if quant_type != QuantType.STATIC:
raise ValueError(
"set_float_to_observed_mapping is currently only supported for static quantization"
)
if quant_type not in self.float_to_observed_mapping:
self.float_to_observed_mapping[quant_type] = {}
self.float_to_observed_mapping[quant_type][float_class] = observed_class
return self
def set_non_traceable_module_names(
self, module_names: List[str]
) -> PrepareCustomConfig:
"""
Set the modules that are not symbolically traceable, identified by name.
"""
self.non_traceable_module_names = module_names
return self
def set_non_traceable_module_classes(
self, module_classes: List[Type]
) -> PrepareCustomConfig:
"""
Set the modules that are not symbolically traceable, identified by class.
"""
self.non_traceable_module_classes = module_classes
return self
def set_input_quantized_indexes(self, indexes: List[int]) -> PrepareCustomConfig:
"""
Set the indexes of the inputs of the graph that should be quantized.
        Inputs are otherwise assumed to be in fp32 by default.
"""
self.input_quantized_indexes = indexes
return self
def set_output_quantized_indexes(self, indexes: List[int]) -> PrepareCustomConfig:
"""
Set the indexes of the outputs of the graph that should be quantized.
        Outputs are otherwise assumed to be in fp32 by default.
"""
self.output_quantized_indexes = indexes
return self
def set_preserved_attributes(self, attributes: List[str]) -> PrepareCustomConfig:
"""
Set the names of the attributes that will persist in the graph module even if they are not used in
the model's ``forward`` method.
"""
self.preserved_attributes = attributes
return self
# TODO: remove this
@classmethod
def from_dict(
cls, prepare_custom_config_dict: Dict[str, Any]
) -> PrepareCustomConfig:
"""
Create a ``PrepareCustomConfig`` from a dictionary with the following items:
"standalone_module_name": a list of (module_name, qconfig_mapping, example_inputs,
child_prepare_custom_config, backend_config) tuples
"standalone_module_class" a list of (module_class, qconfig_mapping, example_inputs,
child_prepare_custom_config, backend_config) tuples
"float_to_observed_custom_module_class": a nested dictionary mapping from quantization
mode to an inner mapping from float module classes to observed module classes, e.g.
{"static": {FloatCustomModule: ObservedCustomModule}}
"non_traceable_module_name": a list of modules names that are not symbolically traceable
"non_traceable_module_class": a list of module classes that are not symbolically traceable
"input_quantized_idxs": a list of indexes of graph inputs that should be quantized
"output_quantized_idxs": a list of indexes of graph outputs that should be quantized
"preserved_attributes": a list of attributes that persist even if they are not used in ``forward``
This function is primarily for backward compatibility and may be removed in the future.
"""
def _get_qconfig_mapping(obj: Any, dict_key: str) -> Optional[QConfigMapping]:
"""
Convert the given object into a QConfigMapping if possible, else throw an exception.
"""
if isinstance(obj, QConfigMapping) or obj is None:
return obj
if isinstance(obj, Dict):
return QConfigMapping.from_dict(obj)
raise ValueError(
f"Expected QConfigMapping in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'"
)
def _get_prepare_custom_config(
obj: Any, dict_key: str
) -> Optional[PrepareCustomConfig]:
"""
Convert the given object into a PrepareCustomConfig if possible, else throw an exception.
"""
if isinstance(obj, PrepareCustomConfig) or obj is None:
return obj
if isinstance(obj, Dict):
return PrepareCustomConfig.from_dict(obj)
raise ValueError(
f"Expected PrepareCustomConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'"
)
def _get_backend_config(obj: Any, dict_key: str) -> Optional[BackendConfig]:
"""
Convert the given object into a BackendConfig if possible, else throw an exception.
"""
if isinstance(obj, BackendConfig) or obj is None:
return obj
if isinstance(obj, Dict):
return BackendConfig.from_dict(obj)
raise ValueError(
f"Expected BackendConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'"
)
conf = cls()
for (
module_name,
qconfig_dict,
example_inputs,
_prepare_custom_config_dict,
backend_config_dict,
) in prepare_custom_config_dict.get(STANDALONE_MODULE_NAME_DICT_KEY, []):
qconfig_mapping = _get_qconfig_mapping(
qconfig_dict, STANDALONE_MODULE_NAME_DICT_KEY
)
prepare_custom_config = _get_prepare_custom_config(
_prepare_custom_config_dict, STANDALONE_MODULE_NAME_DICT_KEY
)
backend_config = _get_backend_config(
backend_config_dict, STANDALONE_MODULE_NAME_DICT_KEY
)
conf.set_standalone_module_name(
module_name,
qconfig_mapping,
example_inputs,
prepare_custom_config,
backend_config,
)
for (
module_class,
qconfig_dict,
example_inputs,
_prepare_custom_config_dict,
backend_config_dict,
) in prepare_custom_config_dict.get(STANDALONE_MODULE_CLASS_DICT_KEY, []):
qconfig_mapping = _get_qconfig_mapping(
qconfig_dict, STANDALONE_MODULE_CLASS_DICT_KEY
)
prepare_custom_config = _get_prepare_custom_config(
_prepare_custom_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY
)
backend_config = _get_backend_config(
backend_config_dict, STANDALONE_MODULE_CLASS_DICT_KEY
)
conf.set_standalone_module_class(
module_class,
qconfig_mapping,
example_inputs,
prepare_custom_config,
backend_config,
)
for quant_type_name, custom_module_mapping in prepare_custom_config_dict.get(
FLOAT_TO_OBSERVED_DICT_KEY, {}
).items():
quant_type = _quant_type_from_str(quant_type_name)
for float_class, observed_class in custom_module_mapping.items():
conf.set_float_to_observed_mapping(
float_class, observed_class, quant_type
)
conf.set_non_traceable_module_names(
prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_NAME_DICT_KEY, [])
)
conf.set_non_traceable_module_classes(
prepare_custom_config_dict.get(NON_TRACEABLE_MODULE_CLASS_DICT_KEY, [])
)
conf.set_input_quantized_indexes(
prepare_custom_config_dict.get(INPUT_QUANTIZED_INDEXES_DICT_KEY, [])
)
conf.set_output_quantized_indexes(
prepare_custom_config_dict.get(OUTPUT_QUANTIZED_INDEXES_DICT_KEY, [])
)
conf.set_preserved_attributes(
prepare_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])
)
return conf
def to_dict(self) -> Dict[str, Any]:
"""
Convert this ``PrepareCustomConfig`` to a dictionary with the items described in
:func:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig.from_dict`.
"""
def _make_tuple(key: Any, e: StandaloneModuleConfigEntry):
qconfig_dict = e.qconfig_mapping.to_dict() if e.qconfig_mapping else None
prepare_custom_config_dict = (
e.prepare_custom_config.to_dict() if e.prepare_custom_config else None
)
return (
key,
qconfig_dict,
e.example_inputs,
prepare_custom_config_dict,
e.backend_config,
)
d: Dict[str, Any] = {}
for module_name, sm_config_entry in self.standalone_module_names.items():
if STANDALONE_MODULE_NAME_DICT_KEY not in d:
d[STANDALONE_MODULE_NAME_DICT_KEY] = []
d[STANDALONE_MODULE_NAME_DICT_KEY].append(
_make_tuple(module_name, sm_config_entry)
)
for module_class, sm_config_entry in self.standalone_module_classes.items():
if STANDALONE_MODULE_CLASS_DICT_KEY not in d:
d[STANDALONE_MODULE_CLASS_DICT_KEY] = []
d[STANDALONE_MODULE_CLASS_DICT_KEY].append(
_make_tuple(module_class, sm_config_entry)
)
for (
quant_type,
float_to_observed_mapping,
) in self.float_to_observed_mapping.items():
if FLOAT_TO_OBSERVED_DICT_KEY not in d:
d[FLOAT_TO_OBSERVED_DICT_KEY] = {}
d[FLOAT_TO_OBSERVED_DICT_KEY][
_get_quant_type_to_str(quant_type)
] = float_to_observed_mapping
if len(self.non_traceable_module_names) > 0:
d[NON_TRACEABLE_MODULE_NAME_DICT_KEY] = self.non_traceable_module_names
if len(self.non_traceable_module_classes) > 0:
d[NON_TRACEABLE_MODULE_CLASS_DICT_KEY] = self.non_traceable_module_classes
if len(self.input_quantized_indexes) > 0:
d[INPUT_QUANTIZED_INDEXES_DICT_KEY] = self.input_quantized_indexes
if len(self.output_quantized_indexes) > 0:
d[OUTPUT_QUANTIZED_INDEXES_DICT_KEY] = self.output_quantized_indexes
if len(self.preserved_attributes) > 0:
d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
return d
class ConvertCustomConfig:
"""
Custom configuration for :func:`~torch.ao.quantization.quantize_fx.convert_fx`.
Example usage::
convert_custom_config = ConvertCustomConfig() \
.set_observed_to_quantized_mapping(ObservedCustomModule, QuantizedCustomModule) \
.set_preserved_attributes(["attr1", "attr2"])
"""
def __init__(self) -> None:
self.observed_to_quantized_mapping: Dict[QuantType, Dict[Type, Type]] = {}
self.preserved_attributes: List[str] = []
def __repr__(self):
dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0}
return f"ConvertCustomConfig({dict_nonempty})"
def set_observed_to_quantized_mapping(
self,
observed_class: Type,
quantized_class: Type,
quant_type: QuantType = QuantType.STATIC,
) -> ConvertCustomConfig:
"""
Set the mapping from a custom observed module class to a custom quantized module class.
The quantized module class must have a ``from_observed`` class method that converts the observed module class
to the quantized module class.
"""
if quant_type not in self.observed_to_quantized_mapping:
self.observed_to_quantized_mapping[quant_type] = {}
self.observed_to_quantized_mapping[quant_type][observed_class] = quantized_class
return self
def set_preserved_attributes(self, attributes: List[str]) -> ConvertCustomConfig:
"""
Set the names of the attributes that will persist in the graph module even if they are not used in
the model's ``forward`` method.
"""
self.preserved_attributes = attributes
return self
# TODO: remove this
@classmethod
def from_dict(
cls, convert_custom_config_dict: Dict[str, Any]
) -> ConvertCustomConfig:
"""
Create a ``ConvertCustomConfig`` from a dictionary with the following items:
"observed_to_quantized_custom_module_class": a nested dictionary mapping from quantization
mode to an inner mapping from observed module classes to quantized module classes, e.g.::
{
"static": {FloatCustomModule: ObservedCustomModule},
"dynamic": {FloatCustomModule: ObservedCustomModule},
"weight_only": {FloatCustomModule: ObservedCustomModule}
}
"preserved_attributes": a list of attributes that persist even if they are not used in ``forward``
This function is primarily for backward compatibility and may be removed in the future.
"""
conf = cls()
for quant_type_name, custom_module_mapping in convert_custom_config_dict.get(
OBSERVED_TO_QUANTIZED_DICT_KEY, {}
).items():
quant_type = _quant_type_from_str(quant_type_name)
for observed_class, quantized_class in custom_module_mapping.items():
conf.set_observed_to_quantized_mapping(
observed_class, quantized_class, quant_type
)
conf.set_preserved_attributes(
convert_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])
)
return conf
def to_dict(self) -> Dict[str, Any]:
"""
Convert this ``ConvertCustomConfig`` to a dictionary with the items described in
:func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`.
"""
d: Dict[str, Any] = {}
for (
quant_type,
observed_to_quantized_mapping,
) in self.observed_to_quantized_mapping.items():
if OBSERVED_TO_QUANTIZED_DICT_KEY not in d:
d[OBSERVED_TO_QUANTIZED_DICT_KEY] = {}
d[OBSERVED_TO_QUANTIZED_DICT_KEY][
_get_quant_type_to_str(quant_type)
] = observed_to_quantized_mapping
if len(self.preserved_attributes) > 0:
d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
return d
class FuseCustomConfig:
"""
Custom configuration for :func:`~torch.ao.quantization.quantize_fx.fuse_fx`.
Example usage::
fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"])
"""
def __init__(self) -> None:
self.preserved_attributes: List[str] = []
def __repr__(self):
dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0}
return f"FuseCustomConfig({dict_nonempty})"
def set_preserved_attributes(self, attributes: List[str]) -> FuseCustomConfig:
"""
Set the names of the attributes that will persist in the graph module even if they are not used in
the model's ``forward`` method.
"""
self.preserved_attributes = attributes
return self
# TODO: remove this
@classmethod
def from_dict(cls, fuse_custom_config_dict: Dict[str, Any]) -> FuseCustomConfig:
"""
        Create a ``FuseCustomConfig`` from a dictionary with the following items:
"preserved_attributes": a list of attributes that persist even if they are not used in ``forward``
This function is primarily for backward compatibility and may be removed in the future.
"""
conf = cls()
conf.set_preserved_attributes(
fuse_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])
)
return conf
def to_dict(self) -> Dict[str, Any]:
"""
Convert this ``FuseCustomConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig.from_dict`.
"""
d: Dict[str, Any] = {}
if len(self.preserved_attributes) > 0:
d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
return d
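# --- Editor's usage sketch (not part of the upstream file) ---
# A brief illustration of the builder pattern and the dict round trip these config
# classes support; the module name and attribute below are placeholders.
#
#     prepare_custom_config = (
#         PrepareCustomConfig()
#         .set_non_traceable_module_names(["submodule.custom_op"])
#         .set_preserved_attributes(["version"])
#     )
#     as_dict = prepare_custom_config.to_dict()
#     # {"non_traceable_module_name": ["submodule.custom_op"],
#     #  "preserved_attributes": ["version"]}
#     restored = PrepareCustomConfig.from_dict(as_dict)   # equivalent configuration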

View File

@ -0,0 +1,191 @@
# mypy: allow-untyped-defs
import warnings
from typing import Any, Callable, Dict, List, Tuple, Union
from torch.ao.quantization.backend_config import (
BackendConfig,
get_native_backend_config,
)
from torch.ao.quantization.backend_config.utils import (
get_fuser_method_mapping,
get_fusion_pattern_to_extra_inputs_getter,
get_fusion_pattern_to_root_node_getter,
)
from torch.ao.quantization.utils import NodePattern, Pattern
from torch.fx import GraphModule, map_arg, Node
from torch.fx.graph import Graph
from .custom_config import FuseCustomConfig
from .fuse_handler import _get_fusion_pattern_to_fuse_handler_cls, FuseHandler
from .match_utils import _is_match, MatchAllNode
from .pattern_utils import _sorted_patterns_dict
__all__ = [
"fuse",
# TODO: We should make this private in the future
# This is currently needed for test_public_bindings for some reason
"FuseHandler",
]
def fuse(
model: GraphModule,
is_qat: bool,
fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
if fuse_custom_config is None:
fuse_custom_config = FuseCustomConfig()
if isinstance(fuse_custom_config, dict):
warnings.warn(
"Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
"in a future version. Please pass in a FuseCustomConfig instead.",
FutureWarning,
stacklevel=2,
)
fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)
if isinstance(backend_config, dict):
warnings.warn(
"Passing a backend_config_dict to prepare is deprecated and will not be supported "
"in a future version. Please pass in a BackendConfig instead.",
FutureWarning,
stacklevel=2,
)
backend_config = BackendConfig.from_dict(backend_config)
named_modules = dict(model.named_modules())
if backend_config is None:
backend_config = get_native_backend_config()
fusion_pattern_to_fuse_handler_cls = _sorted_patterns_dict(
_get_fusion_pattern_to_fuse_handler_cls(backend_config)
)
fuser_method_mapping = get_fuser_method_mapping(backend_config)
fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(
backend_config
)
fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(
backend_config
)
# find fusion
fusion_pairs = _find_matches(model, model.graph, fusion_pattern_to_fuse_handler_cls)
# TODO: change this to inplace changes to graph, since we no longer construct
# new GraphModule anymore
fused_graph = Graph()
env: Dict[Any, Any] = {}
def load_arg(a):
return map_arg(a, lambda node: env[node.name])
def default_root_node_getter(node_pattern):
while not isinstance(node_pattern[-1], Node):
node_pattern = node_pattern[-1]
return node_pattern[-1]
for node in model.graph.nodes:
(
maybe_last_node,
pattern,
matched_node_pattern,
obj,
node_to_subpattern,
) = fusion_pairs.get(node.name, (None, None, None, None, None))
# get the corresponding subpattern for the current node
if node_to_subpattern is not None:
node_subpattern = node_to_subpattern.get(node, None)
else:
node_subpattern = None
if maybe_last_node is node:
assert obj is not None
root_node_getter = fusion_pattern_to_root_node_getter.get(
pattern, default_root_node_getter
)
root_node = root_node_getter(matched_node_pattern) # type: ignore[index]
extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get(
pattern, None
)
extra_inputs = []
if extra_inputs_getter is not None:
extra_inputs = extra_inputs_getter(matched_node_pattern)
# TODO: add validation that root_node is a module and has the same type
# as the root_module in the configuration
env[node.name] = obj.fuse(
load_arg,
named_modules,
fused_graph,
root_node,
extra_inputs,
matched_node_pattern, # type: ignore[arg-type]
fuse_custom_config,
fuser_method_mapping,
is_qat,
)
elif maybe_last_node is None or node_subpattern is MatchAllNode:
env[node.name] = fused_graph.node_copy(node, load_arg)
        # nodes that were matched as part of a pattern but are not the root node
        # are dropped here, i.e. not copied into the fused graph
model = GraphModule(model, fused_graph)
return model
def _find_matches(
root: GraphModule,
graph: Graph,
pattern_to_fuse_handler_cls: Dict[Pattern, Callable],
) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]]:
modules = dict(root.named_modules())
# node name -> (root_node, match_value)
match_map: Dict[
str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]
] = {}
# a map from node to the matched subpattern
node_to_subpattern: Dict[Node, Any] = {}
# TODO: dedup with quantization matching function in match_utils.py
def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern):
if isinstance(pattern, tuple):
s, *args = pattern
current_node_pattern: List[Node] = []
apply_match(s, node, match, current_node_pattern, node_to_subpattern)
for subpattern, arg in zip(args, node.args):
apply_match(
subpattern, arg, match, current_node_pattern, node_to_subpattern
)
matched_node_pattern.append(tuple(current_node_pattern))
else:
            # the first pattern that matches takes precedence
if node.name not in match_map:
matched_node_pattern.append(node)
# MatchAllNode here is actually MatchAllInputNode which should not
# be added to match_map
if pattern is not MatchAllNode:
node_to_subpattern[node] = pattern
root_node, pattern, handler = match
match_map[node.name] = (
root_node,
pattern,
matched_node_pattern,
handler,
node_to_subpattern,
)
for node in reversed(graph.nodes):
if node.name not in match_map:
for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items():
matched_node_pattern: List[Node] = []
if _is_match(modules, node, pattern):
apply_match(
pattern,
node,
(node, pattern, fuse_handler_cls(node)),
matched_node_pattern,
node_to_subpattern,
)
break
return match_map
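
An illustrative sketch (not part of the file above) of how this internal fuse pass is normally reached: users call torch.ao.quantization.quantize_fx.fuse_fx, which symbolically traces the model and dispatches here with the default (native) backend config. The toy module below is hypothetical:

    import torch
    from torch.ao.quantization.quantize_fx import fuse_fx

    class ConvBNReLU(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.bn = torch.nn.BatchNorm2d(3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.bn(self.conv(x)))

    fused = fuse_fx(ConvBNReLU().eval())
    print(fused)  # conv/bn/relu are folded into a single fused conv-relu style module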

View File

@ -0,0 +1,132 @@
# mypy: allow-untyped-defs
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Union
import torch
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.fuser_method_mappings import get_fuser_method_new
from torch.ao.quantization.utils import _parent_name, NodePattern, Pattern
from torch.fx.graph import Graph, Node
from torch.nn.utils.parametrize import type_before_parametrizations
from .custom_config import FuseCustomConfig
from .match_utils import MatchAllNode
__all__ = [
"DefaultFuseHandler",
"FuseHandler",
]
# ----------------------------
# Fusion Pattern Registrations
# ----------------------------
# Base Pattern Handler
class FuseHandler(ABC):
"""Base handler class for the fusion patterns"""
@abstractmethod
def __init__(self, node: Node):
pass
@abstractmethod
def fuse(
self,
load_arg: Callable,
named_modules: Dict[str, torch.nn.Module],
fused_graph: Graph,
root_node: Node,
extra_inputs: List[Any],
matched_node_pattern: NodePattern,
fuse_custom_config: FuseCustomConfig,
fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
is_qat: bool,
) -> Node:
pass
class DefaultFuseHandler(FuseHandler):
def __init__(self, node: Node):
super().__init__(node) # type:ignore[safe-super]
def fuse(
self,
load_arg: Callable,
named_modules: Dict[str, torch.nn.Module],
fused_graph: Graph,
root_node: Node,
extra_inputs: List[Any],
matched_node_pattern: NodePattern,
fuse_custom_config: FuseCustomConfig,
fuser_method_mapping: Dict[Pattern, Union[torch.nn.Sequential, Callable]],
is_qat: bool,
) -> Node:
assert (
root_node.op == "call_module"
), "Expecting module node to be a call_module Node"
root_module = named_modules[str(root_node.target)]
def get_modules(pattern):
"""Given a node pattern, extract the corresponding modules
e.g. input: (relu_node, (bn_node, conv_node))
output: (relu_module, (bn_module, conv_module))
"""
if isinstance(pattern, (tuple, list)):
n, *args = pattern
modules: List[torch.nn.Module] = []
modules.append(get_modules(n))
for a in args:
modules.append(get_modules(a))
return tuple(modules)
else:
n = pattern
if n.op == "call_module":
return named_modules[n.target]
elif n.op == "call_function" and n.target == torch.nn.functional.relu:
relu = torch.nn.ReLU()
relu.training = root_module.training
return relu
elif n.op == "call_function" or n.op == "call_method":
return n.target
else:
return MatchAllNode
# since relu can be used multiple times, we'll need to create a relu module for each match
matched_modules = get_modules(matched_node_pattern)
def get_matched_types(m):
if isinstance(m, tuple):
return tuple(map(get_matched_types, m))
if isinstance(m, torch.nn.Module):
return type_before_parametrizations(m)
return m
matched_module_types = get_matched_types(matched_modules)
module_parent_name, module_name = _parent_name(root_node.target)
fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
# TODO: change the signature for fuser_method to take matched module patterns
# as input
fused_module = fuser_method(is_qat, *matched_modules)
setattr(named_modules[module_parent_name], module_name, fused_module)
extra_args = []
for input in extra_inputs:
extra_args.append(load_arg(input))
node = fused_graph.node_copy(root_node, load_arg)
args = list(node.args)
args.extend(extra_args)
node.args = tuple(args)
return node
def _get_fusion_pattern_to_fuse_handler_cls(
backend_config: BackendConfig,
) -> Dict[Pattern, Callable]:
fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {}
for pattern, config in backend_config._pattern_complex_format_to_config.items():
if config.fuser_method is not None:
# TODO: is this logic right?
fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
return fusion_pattern_to_fuse_handlers
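
A small sketch (not part of the file above) of what the resolved fuser method does with the matched modules. fuse_conv_bn_relu is the underlying fuser function the native backend registers for (Conv2d, BatchNorm2d, ReLU); the nested-argument reordering normally handled by get_fuser_method_new is skipped here, and the layer sizes are arbitrary:

    import torch.nn as nn
    from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn_relu

    conv = nn.Conv2d(3, 3, 3).eval()
    bn = nn.BatchNorm2d(3).eval()
    relu = nn.ReLU().eval()
    # DefaultFuseHandler.fuse ultimately calls fuser_method(is_qat, *matched_modules);
    # for eval-mode fusion the batch norm statistics are folded into the conv weights
    fused = fuse_conv_bn_relu(False, conv, bn, relu)
    print(type(fused).__name__)  # e.g. ConvReLU2d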

View File

@ -0,0 +1,203 @@
# mypy: allow-untyped-defs
import copy
from typing import Any, Dict, Set, Union
import torch
from torch.fx import GraphModule
from torch.fx.graph import Graph
__all__ = [
"FusedGraphModule",
"ObservedGraphModule",
"ObservedStandaloneGraphModule",
"QuantizedGraphModule",
]
class FusedGraphModule(GraphModule):
def __init__(
self,
root: Union[torch.nn.Module, Dict[str, Any]],
graph: Graph,
preserved_attr_names: Set[str],
):
self.preserved_attr_names = preserved_attr_names
preserved_attrs = {
attr: getattr(root, attr)
for attr in self.preserved_attr_names
if hasattr(root, attr)
}
super().__init__(root, graph)
for attr in preserved_attrs:
setattr(self, attr, preserved_attrs[attr])
# GraphModule does not copy attributes which are not in the __dict__
# of vanilla nn.Module. So, we override __deepcopy__ in order
# to copy the quantization specific attributes correctly.
def __deepcopy__(self, memo):
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
return FusedGraphModule(
fake_mod,
copy.deepcopy(self.graph),
copy.deepcopy(self.preserved_attr_names),
)
class ObservedGraphModule(GraphModule):
def __init__(
self,
root: Union[torch.nn.Module, Dict[str, Any]],
graph: Graph,
preserved_attr_names: Set[str],
):
self.preserved_attr_names = {
"_activation_post_process_map",
"_activation_post_process_indexes",
"_patterns",
"_node_name_to_qconfig",
"_prepare_custom_config",
"_equalization_node_name_to_qconfig",
"_node_name_to_scope",
"_qconfig_mapping",
"_is_qat",
"_observed_node_names",
}.union(preserved_attr_names)
preserved_attrs = {
attr: getattr(root, attr)
for attr in self.preserved_attr_names
if hasattr(root, attr)
}
super().__init__(root, graph)
for attr in preserved_attrs:
setattr(self, attr, preserved_attrs[attr])
# GraphModule does not copy attributes which are not in the __dict__
# of vanilla nn.Module. So, we override __deepcopy__ in order
# to copy the quantization specific attributes correctly.
def __deepcopy__(self, memo):
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
return ObservedGraphModule(
fake_mod,
copy.deepcopy(self.graph),
copy.deepcopy(self.preserved_attr_names),
)
def _is_observed_module(module: Any) -> bool:
return hasattr(module, "meta") and "_observed_graph_module_attrs" in module.meta
def _get_observed_graph_module_attr(
model: Union[torch.nn.Module, GraphModule], attr_name: str
) -> Any:
if hasattr(model, "meta") and "_observed_graph_module_attrs" in model.meta: # type: ignore[operator, index]
return getattr(model.meta["_observed_graph_module_attrs"], attr_name) # type: ignore[index]
return None
class ObservedStandaloneGraphModule(ObservedGraphModule):
def __init__(
self,
root: Union[torch.nn.Module, Dict[str, Any]],
graph: Graph,
preserved_attr_names: Set[str],
):
preserved_attr_names = preserved_attr_names.union(
{
"_standalone_module_input_quantized_idxs",
"_standalone_module_output_quantized_idxs",
}
)
super().__init__(root, graph, preserved_attr_names)
def __deepcopy__(self, memo):
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
return ObservedStandaloneGraphModule(
fake_mod,
copy.deepcopy(self.graph),
copy.deepcopy(self.preserved_attr_names),
)
def _is_observed_standalone_module(module: Any) -> bool:
return (
_is_observed_module(module)
and module.meta["_observed_graph_module_attrs"].is_observed_standalone_module
)
def _save_packed_weight(self, destination, prefix, keep_vars):
for attr_name in dir(self):
if "_packed_weight" in attr_name and isinstance(
getattr(self, attr_name), torch._C.ScriptObject
): # type: ignore[attr-defined]
packed_weight = getattr(self, attr_name)
destination[prefix + attr_name] = packed_weight
class QuantizedGraphModule(GraphModule):
"""This class is created to make sure PackedParams
(e.g. LinearPackedParams, Conv2dPackedParams) to appear in state_dict
so that we can serialize and deserialize quantized graph module with
torch.save(m.state_dict()) and m.load_state_dict(state_dict)
"""
def __init__(
self,
root: Union[torch.nn.Module, Dict[str, Any]],
graph: Graph,
preserved_attr_names: Set[str],
):
self.preserved_attr_names = preserved_attr_names
preserved_attrs = {
attr: getattr(root, attr)
for attr in self.preserved_attr_names
if hasattr(root, attr)
}
super().__init__(root, graph)
for attr in preserved_attrs:
setattr(self, attr, preserved_attrs[attr])
self._register_state_dict_hook(_save_packed_weight)
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
attrs_to_pop = []
for attr_name in state_dict:
if attr_name.startswith("_packed_weight") and isinstance(state_dict[attr_name], torch._C.ScriptObject): # type: ignore[attr-defined] # noqa: B950
setattr(self, attr_name, state_dict[attr_name])
attrs_to_pop.append(attr_name)
        # pop the packed param attributes
for attr_name in attrs_to_pop:
state_dict.pop(attr_name)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def __deepcopy__(self, memo):
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
return QuantizedGraphModule(
fake_mod,
copy.deepcopy(self.graph),
copy.deepcopy(self.preserved_attr_names),
)
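
A sketch (not part of the file above) of the serialization behavior these hooks enable: after convert, weights prepacked during lowering of functional patterns live on the graph module as "_packed_weight_*" script objects, and the two hooks above let them round-trip through state_dict()/load_state_dict(). The toy model and shapes are hypothetical, and a backend with quantized kernels (e.g. fbgemm/x86) is assumed:

    import torch
    import torch.nn.functional as F
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    class FunctionalLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(4, 4))
            self.bias = torch.nn.Parameter(torch.zeros(4))

        def forward(self, x):
            return F.linear(x, self.weight, self.bias)

    example_inputs = (torch.randn(1, 4),)
    prepared = prepare_fx(FunctionalLinear().eval(), get_default_qconfig_mapping(), example_inputs)
    prepared(*example_inputs)          # calibrate
    quantized = convert_fx(prepared)   # lowering prepacks the weight for quantized::linear
    state = quantized.state_dict()     # packed params are saved by _save_packed_weight
    quantized.load_state_dict(state)   # and restored by _load_from_state_dict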

View File

@ -0,0 +1,20 @@
from typing import Dict, Tuple
from torch.ao.quantization.qconfig import QConfigAny
from torch.fx import GraphModule
from ._lower_to_native_backend import _lower_to_native_backend
__all__ = ["lower_to_fbgemm"]
def lower_to_fbgemm(
model: GraphModule,
qconfig_map: Dict[str, QConfigAny],
node_name_to_scope: Dict[str, Tuple[str, type]],
) -> GraphModule:
"""Lower a quantized reference model (with reference quantized operator patterns)
to fbgemm
"""
return _lower_to_native_backend(model, qconfig_map, node_name_to_scope)

View File

@ -0,0 +1,20 @@
from typing import Dict, Tuple
from torch.ao.quantization.qconfig import QConfigAny
from torch.fx import GraphModule
from ._lower_to_native_backend import _lower_to_native_backend
__all__ = ["lower_to_qnnpack"]
def lower_to_qnnpack(
model: GraphModule,
qconfig_map: Dict[str, QConfigAny],
node_name_to_scope: Dict[str, Tuple[str, type]],
) -> GraphModule:
"""Lower a quantized reference model (with reference quantized operator patterns)
to qnnpack
"""
return _lower_to_native_backend(model, qconfig_map, node_name_to_scope)

View File

@ -0,0 +1,202 @@
import copy
import operator
from typing import Any, Callable, Optional, Tuple
import torch
from torch.ao.quantization import (
default_weight_fake_quant,
default_weight_observer,
FakeQuantizeBase,
QConfig,
QConfigMapping,
)
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.observer import _PartialWrapper
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
# TODO: move all LSTM util functions from fx/utils.py to this file
def _get_lstm_with_individually_observed_parts(
float_lstm: torch.nn.LSTM,
example_inputs: Tuple[Any, ...],
backend_config: Optional[BackendConfig] = None,
linear_output_obs_ctr: Optional[_PartialWrapper] = None,
sigmoid_obs_ctr: Optional[_PartialWrapper] = None,
tanh_obs_ctr: Optional[_PartialWrapper] = None,
cell_state_obs_ctr: Optional[_PartialWrapper] = None,
hidden_state_obs_ctr: Optional[_PartialWrapper] = None,
) -> torch.ao.nn.quantizable.LSTM:
"""
Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM`
with specific observers or fake quantizes assigned to the inner ops or submodules.
In both eager and FX graph mode quantization, `torch.ao.nn.quantizable.LSTM` is
used as an observed custom module, which is responsible for inserting its own
observers. By default, all inner ops inherit the parent custom module's QConfig.
Users who wish to override this behavior may extend `torch.ao.nn.quantizable.LSTM`
and use this helper function to customize the observer insertion logic.
This is meant to be used to convert a float module to an observed module in the
custom module flow.
Args:
`float_lstm`: The float LSTM module
`example_inputs`: example inputs for the forward function of the LSTM module
`backend_config`: BackendConfig to use to observe the LSTM module
`linear_output_obs_ctr`: observer or fake quantize for linear outputs Wx + b,
where W is the weight matrix, b is the bias, and x is either the inputs
or the hidden state from the previous layer (if any)
`sigmoid_obs_ctr`: observer or fake quantize for sigmoid activations
`tanh_obs_ctr`: observer or fake quantize for tanh activations
`cell_state_obs_ctr`: observer or fake quantize for the cell state
`hidden_state_obs_ctr`: observer or fake quantize for the hidden state and
the output
Return:
A `torch.ao.nn.quantizable.LSTM` with the specified observers or fake quantizes
assigned to the inner ops.
"""
def make_qconfig(obs_ctr: _PartialWrapper) -> QConfig:
"""
Make a QConfig with fixed qparams observers or fake quantizes.
"""
if isinstance(obs_ctr(), FakeQuantizeBase):
weight = default_weight_fake_quant
else:
weight = default_weight_observer
return QConfig(activation=obs_ctr, weight=weight)
quantizable_lstm = torch.ao.nn.quantizable.LSTM(
float_lstm.input_size,
float_lstm.hidden_size,
float_lstm.num_layers,
float_lstm.bias,
float_lstm.batch_first,
float_lstm.dropout,
float_lstm.bidirectional,
)
quantizable_lstm.qconfig = float_lstm.qconfig
for idx in range(float_lstm.num_layers):
quantizable_lstm.layers[
idx
] = torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float(
float_lstm, idx, float_lstm.qconfig, batch_first=False
)
# Build QConfigMapping for the LSTM cell
# Note: FloatFunctional qconfigs will be configured separately below
cell_qm = QConfigMapping().set_global(float_lstm.qconfig) # type: ignore[arg-type]
if sigmoid_obs_ctr is not None:
cell_qm.set_module_name("input_gate", make_qconfig(sigmoid_obs_ctr))
cell_qm.set_module_name("forget_gate", make_qconfig(sigmoid_obs_ctr))
cell_qm.set_module_name("output_gate", make_qconfig(sigmoid_obs_ctr))
if tanh_obs_ctr is not None:
cell_qm.set_module_name("cell_gate", make_qconfig(tanh_obs_ctr))
# Insert observers into each LSTM cell
# TODO: maybe make this work for layer_bw as well
for layer in quantizable_lstm.layers:
cell = layer.layer_fw.cell
cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config)
# HACK: Manually replace the activation_post_process following these ops.
# This is needed for FloatFunctional ops because there is currently no way
# to configure these ops in FX graph mode quantization today. This is because
# the FloatFunctional modules simply disappear from the graph after tracing.
# In the future, we should rewrite quantizable LSTM without FloatFunctionals.
op_index_to_activation_post_process_ctr = {
(torch.add, 0): linear_output_obs_ctr, # gates.add
(torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul
(torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul
(torch.add, 1): cell_state_obs_ctr, # fgate_cx_igate_cgate.add
(torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul
}
add_count = 0
mul_count = 0
for node in cell.graph.nodes:
op_index: Optional[Tuple[Callable, int]] = None # e.g. (torch.add, 1)
if node.target == torch.add:
op_index = (torch.add, add_count)
add_count += 1
elif node.target == torch.mul:
op_index = (torch.mul, mul_count)
mul_count += 1
else:
# Neither torch.add nor torch.mul
continue
if op_index not in op_index_to_activation_post_process_ctr:
continue
assert len(node.users) == 1
activation_post_process_name = next(iter(node.users.keys())).name
activation_post_process_ctr = op_index_to_activation_post_process_ctr[
op_index
]
if activation_post_process_ctr is not None:
setattr(
cell, activation_post_process_name, activation_post_process_ctr()
)
layer.layer_fw.cell = cell
return quantizable_lstm
def _get_reference_quantized_lstm_module(
observed_lstm: torch.ao.nn.quantizable.LSTM,
backend_config: Optional[BackendConfig] = None,
) -> torch.ao.nn.quantized.LSTM:
"""
Return a `torch.ao.nn.quantized.LSTM` created from a `torch.ao.nn.quantizable.LSTM`
with observers or fake quantizes inserted through `prepare_fx`, e.g. from
`_get_lstm_with_individually_observed_parts`.
This is meant to be used to convert an observed module to a quantized module in the
custom module flow.
Args:
`observed_lstm`: a `torch.ao.nn.quantizable.LSTM` observed through `prepare_fx`
`backend_config`: BackendConfig to use to produce the reference quantized model
Return:
A reference `torch.ao.nn.quantized.LSTM` module.
"""
quantized_lstm = torch.ao.nn.quantized.LSTM(
observed_lstm.input_size,
observed_lstm.hidden_size,
observed_lstm.num_layers,
observed_lstm.bias,
observed_lstm.batch_first,
observed_lstm.dropout,
observed_lstm.bidirectional,
)
for i, layer in enumerate(quantized_lstm.layers):
cell = copy.deepcopy(observed_lstm.layers.get_submodule(str(i)).layer_fw.cell) # type: ignore[union-attr]
cell = convert_to_reference_fx(cell, backend_config=backend_config) # type: ignore[arg-type]
assert isinstance(cell, torch.fx.GraphModule)
# HACK: Manually remove input quantize nodes and output dequantize nodes,
# since custom modules expect quint8 inputs and outputs for now. Note that
# this functionality is supposedly handled through PrepareCustomConfig's
# `set_input_quantized_indexes` and `set_output_quantized_indexes`, but that
# API doesn't currently handle tuple inputs and outputs, so we have to do
# this manually for now. In the future we should (1) relax the restriction
# on custom module input/output dtypes, and (2) expand support for complex
# input/output structures.
for node in cell.graph.nodes:
if node.target == torch.quantize_per_tensor:
arg = node.args[0]
# Remove quantize(x), quantize(hidden[0]), and quantize(hidden[1])
if arg.target == "x" or (
arg.target == operator.getitem and arg.args[0].target == "hidden"
):
with cell.graph.inserting_before(node):
node.replace_all_uses_with(arg)
cell.graph.erase_node(node)
if node.target == "output":
# Remove all dequantize nodes in the output tuple
for arg in node.args[0]:
with cell.graph.inserting_before(node):
node.replace_input_with(arg, arg.args[0])
cell.graph.eliminate_dead_code()
cell.recompile()
layer.layer_fw.cell = cell
return quantized_lstm
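
An illustrative sketch (not part of the file above) of how these two helpers are meant to be paired. In practice the first one is called from a custom observed LSTM's from_float; here it is invoked directly, and the observer choices, scales, and shapes are hypothetical:

    import torch
    from torch.ao.quantization import FixedQParamsObserver, default_qconfig

    float_lstm = torch.nn.LSTM(input_size=8, hidden_size=8, num_layers=1)
    float_lstm.qconfig = default_qconfig
    example_inputs = (torch.randn(5, 1, 8),)

    observed = _get_lstm_with_individually_observed_parts(
        float_lstm,
        example_inputs,
        sigmoid_obs_ctr=FixedQParamsObserver.with_args(scale=2**-8, zero_point=0),
        tanh_obs_ctr=FixedQParamsObserver.with_args(scale=2**-7, zero_point=128),
    )
    observed(*example_inputs)                                  # calibrate
    quantized = _get_reference_quantized_lstm_module(observed)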

View File

@ -0,0 +1,227 @@
# mypy: allow-untyped-defs
import sys
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type
import torch
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.utils import MatchAllNode, Pattern
from torch.fx.graph import Graph, Node
from torch.nn.utils.parametrize import type_before_parametrizations
from .graph_module import _is_observed_standalone_module
from .quantize_handler import QuantizeHandler
__all__: List[str] = []
# TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type
# would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]`
_MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]
_MatchResultWithQConfig = Tuple[
Node, List[Node], Optional[Pattern], QuantizeHandler, QConfigAny
]
# Note: The order of patterns is important! The match function will take whatever is matched first, so we
# need to put the fusion patterns before the single patterns. For example, add_relu should be registered before relu.
# Decorators are applied in the reverse order in which they appear. Also, when we match the nodes in the graph
# against these patterns, we start from the last node of the graph and traverse back.
def _is_match(modules, node, pattern, max_uses=sys.maxsize):
"""Matches a node in fx against a pattern"""
if isinstance(pattern, tuple):
self_match, *arg_matches = pattern
if self_match is getattr:
assert len(pattern) == 2, "Expecting getattr pattern to have two elements"
arg_matches = []
else:
self_match = pattern
arg_matches = []
if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
return True
if node == pattern:
return True
if not isinstance(node, Node) or len(node.users) > max_uses:
return False
if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
if node.op != "call_module":
return False
if not type_before_parametrizations(modules[node.target]) == self_match:
return False
elif callable(self_match):
if node.op != "call_function" or node.target is not self_match:
return False
elif node.target is getattr:
if node.args[1] != pattern[1]:
return False
elif isinstance(self_match, str):
if node.op != "call_method" or node.target != self_match:
return False
elif node.target != self_match:
return False
if not arg_matches:
return True
if len(arg_matches) != len(node.args):
return False
return all(
_is_match(modules, node, arg_match, max_uses=1)
for node, arg_match in zip(node.args, arg_matches)
)
def _find_matches(
graph: Graph,
modules: Dict[str, torch.nn.Module],
patterns: Dict[Pattern, QuantizeHandler],
root_node_getter_mapping: Dict[Pattern, Callable],
standalone_module_names: Optional[List[str]] = None,
standalone_module_classes: Optional[List[Type]] = None,
custom_module_classes: Optional[List[Any]] = None,
) -> Dict[str, _MatchResult]:
"""
Matches the nodes in the input graph to quantization patterns, and
outputs the information needed to quantize them in future steps.
Inputs:
- graph: an fx.Graph object
- modules: a mapping of fully qualified module name to instance,
for example, {'foo': ModuleFoo, ...}
- patterns: a mapping from a tuple of nodes in reverse order to
uninitialized QuantizeHandler subclass.
Outputs a map of
node_name ->
(node, matched_values, matched_pattern, QuantizeHandler instance,
qconfig)
For example, {
'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
<CopyNodeQuantizeHandler instance>, QConfig(...)),
...
}
"""
if custom_module_classes is None:
custom_module_classes = []
if standalone_module_classes is None:
standalone_module_classes = []
if standalone_module_names is None:
standalone_module_names = []
match_map: Dict[str, _MatchResult] = {}
all_matched: Set[str] = set()
def _recursive_record_node_in_match_map(
last_node, match_map, node_pattern, matched_node_pattern, pattern, match_value
):
if isinstance(node_pattern, Node):
match_map[node_pattern.name] = (
last_node,
matched_node_pattern,
pattern,
match_value,
)
elif not isinstance(node_pattern, Iterable):
return
else:
for n in node_pattern:
_recursive_record_node_in_match_map(
last_node, match_map, n, matched_node_pattern, pattern, match_value
)
# TODO: 1. merge with fuse matcher 2. document the code
def record_match(pattern, node, last_node, matched_node_pattern, match_map):
if isinstance(pattern, tuple):
s, *args = pattern
is_single_arg = len(args) == 1
current_node_pattern: List[Node] = []
record_match(s, node, last_node, matched_node_pattern, match_map)
if pattern[0] is not getattr:
for subpattern, arg in zip(args, node.args):
record_match(subpattern, arg, node, current_node_pattern, match_map)
if len(current_node_pattern) > 1:
# current_node_pattern is the node pattern we get from matching
# the subpattern with arguments of the node
# we use is_single_arg to recover the original structure of the pattern
# if the original pattern has a single argument, we will have
# (original_op, (original_arg, ...))
# otherwise, we'll have a list of arguments
# (original_op, arg0, arg1, arg2, ...)
if is_single_arg:
matched_node_pattern.append(tuple(current_node_pattern))
else:
matched_node_pattern.extend(list(current_node_pattern))
else:
matched_node_pattern.append(current_node_pattern[0])
else:
matched_node_pattern.append(node)
for node in reversed(graph.nodes):
if node.name not in match_map and node.name not in all_matched:
for pattern, quantize_handler_cls in patterns.items():
root_node_getter = root_node_getter_mapping.get(pattern, None)
if _is_match(modules, node, pattern) and node.name not in match_map:
matched_node_pattern: List[Node] = []
record_match(pattern, node, node, matched_node_pattern, match_map)
quantize_handler = quantize_handler_cls( # type: ignore[operator]
matched_node_pattern, modules, root_node_getter
)
last_node = node
# record the match for all nodes in the pattern
_recursive_record_node_in_match_map(
last_node,
match_map,
# we need to record all nodes in the matched pattern in the match_map
matched_node_pattern,
# this is a part of the value corresponding to the node
matched_node_pattern,
pattern,
quantize_handler,
)
break
# add custom module instances to the match result
assert modules is not None
for node in graph.nodes:
if (
node.op == "call_module"
and type(modules[node.target]) in custom_module_classes
):
match_map[node.name] = (
node,
node,
None,
QuantizeHandler(node, modules, is_custom_module=True),
)
def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
assert modules is not None
return (
node_target in standalone_module_names
or type(modules[node_target]) # type: ignore[operator]
in standalone_module_classes # type: ignore[operator]
)
# add standalone modules to the match
for node in graph.nodes:
if node.op == "call_module" and (
is_standalone_module(node.target, modules)
or _is_observed_standalone_module(modules[node.target])
):
# add node to matched nodes
match_map[node.name] = (
node,
node,
None,
QuantizeHandler(node, modules, is_standalone_module=True),
)
return match_map
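
A small sketch (not part of the file above) showing the reverse-order pattern convention _is_match expects; the module is hypothetical:

    import torch
    import torch.nn as nn

    class ConvBNReLU(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 3)
            self.bn = nn.BatchNorm2d(3)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.bn(self.conv(x)))

    gm = torch.fx.symbolic_trace(ConvBNReLU())
    modules = dict(gm.named_modules())
    relu_node = next(n for n in gm.graph.nodes if n.op == "call_module" and n.target == "relu")
    # patterns are written with the outermost (last-executed) op first
    assert _is_match(modules, relu_node, (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)))
    assert not _is_match(modules, relu_node, nn.Conv2d)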

View File

@ -0,0 +1,112 @@
# mypy: allow-untyped-defs
import copy
from collections import OrderedDict
from typing import Any, Dict
from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
from torch.ao.quantization.observer import ObserverBase
from torch.ao.quantization.utils import Pattern
__all__ = [
"get_default_fusion_patterns",
"get_default_quant_patterns",
"get_default_output_activation_post_process_map",
]
# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
QuantizeHandler = Any
# pattern for conv bn fusion
_DEFAULT_FUSION_PATTERNS: Dict[Pattern, QuantizeHandler] = OrderedDict()
def _register_fusion_pattern(pattern):
def insert(fn):
_DEFAULT_FUSION_PATTERNS[pattern] = fn
return fn
return insert
def get_default_fusion_patterns() -> Dict[Pattern, QuantizeHandler]:
return copy.copy(_DEFAULT_FUSION_PATTERNS)
_DEFAULT_QUANTIZATION_PATTERNS: Dict[Pattern, QuantizeHandler] = OrderedDict()
# Mapping from pattern to activation_post_process(observer/fake_quant) constructor for output activation
# e.g. pattern: torch.sigmoid,
# output_activation_post_process: default_fixed_qparams_range_0to1_fake_quant
_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP: Dict[Pattern, QuantizeHandler] = {}
_DEFAULT_OUTPUT_OBSERVER_MAP: Dict[Pattern, QuantizeHandler] = {}
# Register pattern for both static quantization and qat
def _register_quant_pattern(pattern, fixed_qparams_observer=None):
def insert(fn):
_DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
if fixed_qparams_observer is not None:
_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[
pattern
] = FixedQParamsFakeQuantize.with_args(observer=fixed_qparams_observer)
_DEFAULT_OUTPUT_OBSERVER_MAP[pattern] = fixed_qparams_observer
return fn
return insert
# Get patterns for both static quantization and qat
def get_default_quant_patterns() -> Dict[Pattern, QuantizeHandler]:
return copy.copy(_DEFAULT_QUANTIZATION_PATTERNS)
# a map from pattern to output activation post process constructor
# e.g. torch.sigmoid -> default_affine_fixed_qparam_fake_quant
def get_default_output_activation_post_process_map(
is_training,
) -> Dict[Pattern, ObserverBase]:
if is_training:
return copy.copy(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP)
else:
return copy.copy(_DEFAULT_OUTPUT_OBSERVER_MAP)
# Example use of register pattern function:
# @_register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvOrLinearBNReLUFusion():
# def __init__(...):
# ...
#
def _sorted_patterns_dict(
patterns_dict: Dict[Pattern, QuantizeHandler]
) -> Dict[Pattern, QuantizeHandler]:
"""
Return a sorted version of the patterns dictionary such that longer patterns are matched first,
e.g. match (F.relu, F.linear) before F.relu.
This works for current use cases, but we may need to have a more clever way to sort
things to address more complex patterns
"""
    def get_len(pattern):
        """Calculate the length of a pattern by counting all of its entries.
        This makes sure (nn.ReLU, (nn.BatchNorm, nn.Conv2d)) comes before
        (nn.BatchNorm, nn.Conv2d), so that the longer pattern is matched first.
        """
        total = 0
        if isinstance(pattern, tuple):
            for item in pattern:
                total += get_len(item)
        else:
            total += 1
        return total
return OrderedDict(
sorted(
patterns_dict.items(),
key=lambda kv: -get_len(kv[0]) if isinstance(kv[0], tuple) else 1,
)
)
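
A quick sketch (not part of the file above) of the ordering this produces; the handler values are stand-in strings:

    import torch.nn as nn

    unsorted = {
        nn.Conv2d: "single_handler",
        (nn.BatchNorm2d, nn.Conv2d): "conv_bn_handler",
        (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): "conv_bn_relu_handler",
    }
    ordered = _sorted_patterns_dict(unsorted)
    # longer (more specific) patterns come first, so they are matched before their sub-patterns
    assert list(ordered) == [
        (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)),
        (nn.BatchNorm2d, nn.Conv2d),
        nn.Conv2d,
    ]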

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,398 @@
# mypy: allow-untyped-defs
import re
from collections import defaultdict, OrderedDict
from typing import Any, Callable, Dict, List, Set, Tuple, Union
import torch
from torch.ao.nn.intrinsic import _FusedModule
from torch.ao.quantization import QConfig
from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig
from torch.ao.quantization.backend_config.utils import get_module_to_qat_module
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.qconfig import (
_add_module_to_qconfig_obs_ctr,
qconfig_equals,
QConfigAny,
)
from torch.ao.quantization.qconfig_mapping import (
_MODULE_NAME_DICT_KEY,
_MODULE_NAME_REGEX_DICT_KEY,
_OBJECT_TYPE_DICT_KEY,
QConfigMapping,
)
from torch.ao.quantization.utils import _parent_name, get_qconfig_dtypes
from torch.fx import GraphModule
from torch.fx.graph import Graph
__all__: List[str] = []
def _maybe_adjust_qconfig_for_module_name_object_type_order(
qconfig_mapping: QConfigMapping,
cur_module_path: str,
cur_object_type: Callable,
cur_object_type_idx: int,
fallback_qconfig: QConfigAny,
) -> QConfigAny:
for (
module_name,
object_type,
index,
), qconfig in qconfig_mapping.module_name_object_type_order_qconfigs.items():
if (
(module_name == cur_module_path)
and (object_type == cur_object_type)
and (index == cur_object_type_idx)
):
return qconfig
return fallback_qconfig
def _update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping):
"""
Update the QConfigMapping to account for fused modules such as LinearReLU.
This assumes the QConfigMapping's attributes have already been converted to OrderedDicts.
"""
object_type_dict = qconfig_mapping.object_type_qconfigs
if len(object_type_dict) == 0:
return qconfig_mapping
modules = dict(model.named_modules())
for node in model.graph.nodes:
if node.op == "call_module" and node.target in modules:
maybe_fused_module = modules[str(node.target)]
if not isinstance(maybe_fused_module, _FusedModule):
continue
ops = list(maybe_fused_module._modules.values())
fused_qconfig = object_type_dict.get(type(ops[0]), None)
# Raise an error if the modules in the fused module have
# different qconfigs specified in the qconfig_dict
# TODO: currently it only works for modules,
# need to make this work for torch.nn.functional.relu
# TODO: currently it only works for object_type configurations,
# ideally it should work for different types of configurations,
# maybe we want to redesign this part
for op in ops[1:]:
if not qconfig_equals(
object_type_dict.get(type(op), None), fused_qconfig
):
raise LookupError(
"During fusion, we need to specify the same "
+ f"qconfigs for all module types in {type(maybe_fused_module)} "
+ f"offending type: {type(op)}"
)
if fused_qconfig is not None:
object_type_dict[type(maybe_fused_module)] = fused_qconfig
def _generate_node_name_to_qconfig(
root: torch.nn.Module,
modules: Dict[str, torch.nn.Module],
input_graph: Graph,
qconfig_mapping: QConfigMapping,
node_name_to_scope: Dict[str, Tuple[str, type]],
) -> Dict[str, QConfigAny]:
global_qconfig = qconfig_mapping.global_qconfig
node_name_to_qconfig = {}
# example:
#
# {'foo.bar': {F.linear: 0, F.conv2d: 1, ...}, ...}
#
# meaning in submodule 'foo.bar', we have seen 0 F.linear and
# 1 F.conv2d invocations so far.
submodule_to_object_type_to_cur_idx: Dict[str, Dict[Callable, int]] = defaultdict(
lambda: defaultdict(int)
)
for node in input_graph.nodes:
qconfig = None
if node.op == "get_attr":
module_name, _ = _parent_name(node.target)
qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
qconfig_mapping, type(modules[module_name]), module_name, global_qconfig
)
qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
qconfig, modules.get(node.target, None)
)
elif node.op == "call_function":
            # precedence: module_name_qconfig > function_qconfig > global_qconfig,
            # i.e. a module_name match takes precedence over the function qconfig
function_qconfig = _get_object_type_qconfig(
qconfig_mapping, node.target, global_qconfig
)
module_path, module_type = node_name_to_scope[node.name]
qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
qconfig_mapping, module_type, module_path, function_qconfig
)
cur_object_type_idx = submodule_to_object_type_to_cur_idx[module_path][
node.target
]
submodule_to_object_type_to_cur_idx[module_path][node.target] += 1
qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
qconfig_mapping, module_path, node.target, cur_object_type_idx, qconfig
)
qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
qconfig, modules.get(node.target, None)
)
elif node.op == "call_method":
module_path, module_type = node_name_to_scope[node.name]
# first use node.target (string) to get the qconfig
# this is to support configs like
# "object_type": [("reshape", qconfig)]
qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
qconfig_mapping, node.target, module_path, global_qconfig
)
# if there is no special config for the method, we'll fall back to the
# config for the module that contains the call_method node
qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
qconfig_mapping, module_type, module_path, qconfig
)
# currently call_method does not support modifying qconfig
# by order, we can add this later if it is needed.
qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
qconfig, modules.get(node.target, None)
)
elif node.op == "call_module":
# if the node is an observer, just continue - don't add it to the qconfig_map
if _is_activation_post_process(modules[node.target]):
continue
qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
qconfig_mapping, type(modules[node.target]), node.target, global_qconfig
)
module_path, module_type = node_name_to_scope[node.name]
# Note: for call_module, the module_path is the current module's name.
# to meaningfully count invocations, we need to count them in the parent
# module.
parent_name, _ = _parent_name(module_path)
cur_object_type_idx = submodule_to_object_type_to_cur_idx[parent_name][
module_type
]
submodule_to_object_type_to_cur_idx[parent_name][module_type] += 1
qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
qconfig_mapping, parent_name, module_type, cur_object_type_idx, qconfig
)
qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(
qconfig, modules.get(node.target, None)
)
            # regex is not supported by eager mode propagate_qconfig_, so we
            # need to set the qconfig explicitly here in case a regex-based
            # config was used
modules[node.target].qconfig = qconfig_with_device_check
else:
qconfig_with_device_check = None
node_name_to_qconfig[node.name] = qconfig_with_device_check
return node_name_to_qconfig
def _check_is_valid_config_dict(
config_dict: Any, allowed_keys: Set[str], dict_name: str
) -> None:
r"""Checks if the given config_dict has the correct keys
Args:
`config_dict`: dictionary whose keys we want to check
"""
for k in config_dict.keys():
if k not in allowed_keys:
raise ValueError(
"Expected "
+ dict_name
+ " to have the following keys: "
+ str(allowed_keys)
+ ". But found '"
+ k
+ "' instead."
)
def _compare_prepare_convert_qconfig_mappings(
prepare_qconfig_mapping: QConfigMapping, convert_qconfig_mapping: QConfigMapping
):
r"""Compare the qconfig_mapping passed in convert to the one from prepare and check the values
Args:
`prepare_qconfig_mapping`: configuration for prepare quantization step
`convert_qconfig_mapping`: configuration for convert quantization step
"""
assert qconfig_equals(
prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig
), "Expected global qconfigs to be the same in the prepare and convert quantization configs"
prepare_dicts: List[OrderedDict] = [
prepare_qconfig_mapping.object_type_qconfigs,
prepare_qconfig_mapping.module_name_qconfigs,
prepare_qconfig_mapping.module_name_regex_qconfigs,
]
convert_dicts: List[OrderedDict] = [
convert_qconfig_mapping.object_type_qconfigs,
convert_qconfig_mapping.module_name_qconfigs,
convert_qconfig_mapping.module_name_regex_qconfigs,
]
dict_names = [
_OBJECT_TYPE_DICT_KEY,
_MODULE_NAME_DICT_KEY,
_MODULE_NAME_REGEX_DICT_KEY,
]
for i in range(len(prepare_dicts)):
for name in prepare_dicts[i].keys():
assert (
name in convert_dicts[i]
), f"Missing key {dict_names[i]} {name} in convert QConfigMapping \
when it was present in prepare"
assert convert_dicts[i][name] is None or qconfig_equals(
prepare_dicts[i][name], convert_dicts[i][name]
), f"Expected convert QConfigMapping to have the same qconfig as prepare for key {dict_names[i]} {name}; \
prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}"
def _is_qconfig_supported_by_dtype_configs(
qconfig: QConfig, dtype_configs: List[DTypeConfig]
):
for dtype_config in dtype_configs:
is_dynamic = dtype_config.is_dynamic
if is_dynamic is None:
is_dynamic = False
input_dtype = dtype_config.input_dtype or torch.float
weight_dtype = dtype_config.weight_dtype or torch.float
bias_dtype = dtype_config.bias_dtype or torch.float
output_dtype = dtype_config.output_dtype or torch.float
(
qconfig_activation_dtype,
qconfig_weight_dtype,
qconfig_input_act_is_dynamic,
) = get_qconfig_dtypes(qconfig)
qconfig_bias_dtype = (
torch.float16
if (
qconfig_activation_dtype == torch.float16
and qconfig_weight_dtype == torch.float16
and not is_dynamic
)
else torch.float
)
if is_dynamic:
is_match = (
qconfig_input_act_is_dynamic
and input_dtype == qconfig_activation_dtype
and output_dtype == torch.float
and weight_dtype == qconfig_weight_dtype
)
else:
is_match = (
input_dtype == qconfig_activation_dtype
and output_dtype == qconfig_activation_dtype
and weight_dtype == qconfig_weight_dtype
and bias_dtype == qconfig_bias_dtype
)
if is_match:
return True
return False
def _get_object_type_qconfig(
qconfig_mapping: QConfigMapping,
object_type: Union[Callable, str],
fallback_qconfig: QConfigAny,
) -> QConfigAny:
return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig)
def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig):
for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items():
if re.match(regex_pattern, module_name):
# first match wins
return qconfig
return fallback_qconfig
def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig):
if module_name == "":
# module name qconfig not found
return fallback_qconfig
if module_name in qconfig_mapping.module_name_qconfigs:
return qconfig_mapping.module_name_qconfigs[module_name]
else:
parent, _ = _parent_name(module_name)
return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig)
def _maybe_adjust_qconfig_for_module_type_or_name(
qconfig_mapping, module_type, module_name, global_qconfig
):
# get qconfig for module_name,
# fallback to module_name_regex_qconfig, module_type_qconfig,
# global_qconfig if necessary
module_type_qconfig = _get_object_type_qconfig(
qconfig_mapping, module_type, global_qconfig
)
module_name_regex_qconfig = _get_module_name_regex_qconfig(
qconfig_mapping, module_name, module_type_qconfig
)
module_name_qconfig = _get_module_name_qconfig(
qconfig_mapping, module_name, module_name_regex_qconfig
)
return module_name_qconfig
def _get_flattened_qconfig_dict(
qconfig_mapping: QConfigMapping,
) -> Dict[Union[Callable, str], QConfigAny]:
"""flatten the global, object_type and module_name qconfig
to the same qconfig_dict so that it can be used by
propagate_qconfig_ function.
"module_name_regex" is ignored for now since it's not supported
in propagate_qconfig_, but it can be fixed later.
For example:
Input: {
"": qconfig,
"object_type": [
(torch.add, qconfig)
],
"module_name": [
("conv", qconfig)
]
}
Output: {
"": qconfig,
torch.add: qconfig,
"conv": qconfig
}
"""
flattened: Dict[Union[Callable, str], QConfigAny] = {
"": qconfig_mapping.global_qconfig
}
for obj, qconfig in qconfig_mapping.object_type_qconfigs.items():
flattened[obj] = qconfig
for obj, qconfig in qconfig_mapping.module_name_qconfigs.items():
flattened[obj] = qconfig
return flattened
def _update_qconfig_for_qat(
qconfig_mapping: QConfigMapping, backend_config: BackendConfig
):
"""
Update the qconfig_mapping to account for module swaps during QAT.
During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types.
"""
module_to_qat_module_class = get_module_to_qat_module(backend_config)
object_type_dict = qconfig_mapping.object_type_qconfigs
new_object_type_dict = object_type_dict.copy()
for k, v in new_object_type_dict.items():
if k in module_to_qat_module_class:
object_type_dict[module_to_qat_module_class[k]] = v
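
A short sketch (not part of the file above) of the precedence _maybe_adjust_qconfig_for_module_type_or_name implements: an exact module name beats a module-name regex, which beats the object type, which beats the global default. The module names are hypothetical:

    import torch.nn as nn
    from torch.ao.quantization import QConfigMapping, default_dynamic_qconfig, default_qconfig

    qconfig_mapping = (
        QConfigMapping()
        .set_global(default_qconfig)
        .set_object_type(nn.Linear, default_dynamic_qconfig)
        .set_module_name("head.fc", None)  # None disables quantization for this module
    )

    # no module-name or regex entry matches "backbone.fc", so the object-type entry applies
    assert _maybe_adjust_qconfig_for_module_type_or_name(
        qconfig_mapping, nn.Linear, "backbone.fc", default_qconfig
    ) is default_dynamic_qconfig

    # the exact module-name entry wins over the object-type entry
    assert _maybe_adjust_qconfig_for_module_type_or_name(
        qconfig_mapping, nn.Linear, "head.fc", default_qconfig
    ) is None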

View File

@ -0,0 +1,225 @@
# mypy: allow-untyped-defs
from abc import ABC
from typing import Callable, Dict, List, Optional, Type
import torch
from torch.ao.quantization.backend_config import (
BackendConfig,
DTypeConfig,
ObservationType,
)
from torch.ao.quantization.utils import NodePattern, Pattern, QuantizerCls
from torch.fx.graph import Node
from .utils import all_node_args_have_no_tensors
__all__ = [
"QuantizeHandler",
"BinaryOpQuantizeHandler",
"CatQuantizeHandler",
"ConvReluQuantizeHandler",
"LinearReLUQuantizeHandler",
"BatchNormQuantizeHandler",
"EmbeddingQuantizeHandler",
"RNNDynamicQuantizeHandler",
"DefaultNodeQuantizeHandler",
"FixedQParamsOpQuantizeHandler",
"CopyNodeQuantizeHandler",
"GeneralTensorShapeOpQuantizeHandler",
"CustomModuleQuantizeHandler",
"StandaloneModuleQuantizeHandler",
]
def _default_root_node_getter(node_pattern):
if node_pattern is None:
return node_pattern
while not isinstance(node_pattern, Node):
node_pattern = node_pattern[-1]
return node_pattern
# Base Pattern Handler
class QuantizeHandler(ABC): # noqa: B024
"""Base handler class for the quantizer patterns"""
def __init__(
self,
node_pattern: NodePattern,
modules: Dict[str, torch.nn.Module],
root_node_getter: Optional[Callable] = None,
is_custom_module=False,
is_standalone_module=False,
):
"""Records pattern information in __init__, which will be used
in convert
"""
self.node_pattern = node_pattern
self.modules = modules
if root_node_getter is None:
root_node_getter = _default_root_node_getter
self.root_node = root_node_getter(node_pattern)
self.is_custom_module_ = is_custom_module
self.is_standalone_module_ = is_standalone_module
self.num_tensor_args = 0
# determine how many of the first two args are Tensors (versus scalars)
# this distinguishes things like "x + y" from "x + 2" or "2 + x"
if isinstance(self.root_node, Node):
cache_for_no_tensor_check: Dict[Node, bool] = {}
for arg_idx in range(len(self.root_node.args)):
arg = self.root_node.args[arg_idx]
if isinstance(arg, Node) and (
not all_node_args_have_no_tensors(
arg, self.modules, cache_for_no_tensor_check
)
):
self.num_tensor_args += 1
def is_general_tensor_value_op(self) -> bool:
"""
        Returns True if the operator works for both floating point and
        quantized input and either does some computation directly on the input
        Tensor (e.g. avgpool2d) or only re-arranges the Tensor values / queries
        some metadata about the Tensor (e.g. reshape, transpose), so that we
        need to insert an observer/fake_quant for the output of the operator
        that is the same instance as the one observing the input. Input and
        output then share quantization parameters, even though (for
        HistogramObserver) the distribution of values can differ between the
        input and output Tensors.
        Example operators: avgpool2d, reshape, transpose, maxpool2d
        Example observed operator:
        observer_0 - avgpool2d - observer_0 (same observer instance as input)
"""
return False
def is_custom_module(self):
return self.is_custom_module_
def is_standalone_module(self):
return self.is_standalone_module_
def _get_quantize_handler_cls(
observation_type: ObservationType,
dtype_configs: List[DTypeConfig],
num_tensor_args_to_observation_type: Dict[int, ObservationType],
) -> Type[QuantizeHandler]:
"""
Return a configurable QuantizeHandler that matches the given specifications from the backend.
"""
class ConfigurableQuantizeHandler(QuantizeHandler):
def __init__(
self,
node_pattern: NodePattern,
modules: Dict[str, torch.nn.Module],
root_node_getter: Optional[Callable] = None,
):
super().__init__(node_pattern, modules, root_node_getter)
if num_tensor_args_to_observation_type:
assert self.num_tensor_args in num_tensor_args_to_observation_type, (
f"Must provide observation_type config for tensor number {self.num_tensor_args}"
f" in num_tensor_args_to_observation_type for {node_pattern}"
)
self.observation_type = num_tensor_args_to_observation_type[
self.num_tensor_args
]
else:
self.observation_type = observation_type
self.dtype_configs = dtype_configs
def is_general_tensor_value_op(self) -> bool:
return (
self.observation_type
== ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
)
return ConfigurableQuantizeHandler
def _get_pattern_to_quantize_handlers(
backend_config: BackendConfig,
) -> Dict[Pattern, QuantizerCls]:
"""
    Note: the QuantizeHandler is just a holder for some check methods like
    (should_insert_observer_for_output); maybe this can be an enum as well.
    We can refactor this after we convert the path for fbgemm/qnnpack fully to
    the new path. This is not exposed to backend developers.
"""
pattern_to_quantize_handlers = {}
for pattern, config in backend_config._pattern_complex_format_to_config.items():
observation_type = config.observation_type
dtype_configs = config.dtype_configs
num_tensor_args_to_observation_type = (
config._num_tensor_args_to_observation_type
)
pattern_to_quantize_handlers[pattern] = _get_quantize_handler_cls(
observation_type, dtype_configs, num_tensor_args_to_observation_type
)
return pattern_to_quantize_handlers
# TODO: remove this class, this is still exposed in torch.ao.quantization
# but we should be able to break bc
class BinaryOpQuantizeHandler(QuantizeHandler):
pass
class CatQuantizeHandler(QuantizeHandler):
pass
# TODO: remove this class
class ConvReluQuantizeHandler(QuantizeHandler):
pass
# TODO: remove this class
class LinearReLUQuantizeHandler(QuantizeHandler):
pass
# TODO: remove this class
class BatchNormQuantizeHandler(QuantizeHandler):
pass
# TODO: remove this class
class EmbeddingQuantizeHandler(QuantizeHandler):
pass
# TODO: remove this class
class RNNDynamicQuantizeHandler(QuantizeHandler):
pass
# TODO: remove this class
class DefaultNodeQuantizeHandler(QuantizeHandler):
"""Common quantized op, first input and first output will be quantized"""
# TODO: remove this class
class FixedQParamsOpQuantizeHandler(QuantizeHandler):
pass
# TODO: remove
class CopyNodeQuantizeHandler(QuantizeHandler):
pass
# TODO: remove
class GeneralTensorShapeOpQuantizeHandler(QuantizeHandler):
pass
# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated
class CustomModuleQuantizeHandler(QuantizeHandler):
pass
# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated
class StandaloneModuleQuantizeHandler(QuantizeHandler):
pass
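
An illustrative sketch (not part of the file above) of the handler classes this produces, assuming the native backend config registers the Tensor method "reshape" as a share-qparams op:

    from torch.ao.quantization.backend_config import get_native_backend_config

    pattern_to_handler_cls = _get_pattern_to_quantize_handlers(get_native_backend_config())
    handler_cls = pattern_to_handler_cls["reshape"]
    handler = handler_cls(node_pattern=None, modules={})
    print(handler.is_general_tensor_value_op())  # True: the output reuses the input's observer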

View File

@ -0,0 +1,48 @@
from typing import Callable, List
import torch
from torch.ao.nn.intrinsic import _FusedModule
from torch.fx._symbolic_trace import Tracer
from torch.fx.proxy import Scope
__all__ = [
"QuantizationTracer",
]
class ScopeContextManager(torch.fx.proxy.ScopeContextManager):
def __init__(
self, scope: Scope, current_module: torch.nn.Module, current_module_path: str
):
super().__init__(scope, Scope(current_module_path, type(current_module)))
class QuantizationTracer(Tracer):
def __init__(
self, skipped_module_names: List[str], skipped_module_classes: List[Callable]
):
super().__init__()
self.skipped_module_names = skipped_module_names
self.skipped_module_classes = skipped_module_classes
        # NB: initialize the module_type of the top level module to None.
        # We are assuming people won't configure the model with the type of the
        # top level module here, since they can use "" for the global config.
        # We can change this if there is a use case that configures the qconfig
        # using the top level module type.
self.scope = Scope("", None)
self.record_stack_traces = True
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
return (
(
(
m.__module__.startswith("torch.nn")
or m.__module__.startswith("torch.ao.nn")
)
and not isinstance(m, torch.nn.Sequential)
)
or module_qualified_name in self.skipped_module_names
or type(m) in self.skipped_module_classes
or isinstance(m, _FusedModule)
)
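
A tiny sketch (not part of the file above) of the leaf decisions this tracer makes; the module names and the Head class are hypothetical:

    import torch.nn as nn

    tracer = QuantizationTracer(skipped_module_names=["head"], skipped_module_classes=[])

    class Head(nn.Module):
        def forward(self, x):
            return x

    print(tracer.is_leaf_module(nn.Linear(4, 4), "backbone.fc"))  # True: torch.nn modules stay leaves
    print(tracer.is_leaf_module(nn.Sequential(), "backbone"))     # False: Sequential is traced through
    print(tracer.is_leaf_module(Head(), "head"))                  # True: explicitly skipped by name
    print(tracer.is_leaf_module(Head(), "neck"))                  # False: user modules are traced by default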

View File

@ -0,0 +1,959 @@
# mypy: allow-untyped-defs
import copy
import operator
import warnings
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import torch
import torch.nn as nn
from torch.ao.quantization import QConfigAny, QuantType
from torch.ao.quantization.backend_config import DTypeWithConstraints
from torch.ao.quantization.fake_quantize import (
FakeQuantizeBase,
FixedQParamsFakeQuantize,
)
from torch.ao.quantization.observer import (
_is_activation_post_process,
FixedQParamsObserver,
ObserverBase,
)
from torch.ao.quantization.qconfig import (
float16_dynamic_qconfig,
float16_static_qconfig,
qconfig_equals,
)
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.stubs import DeQuantStub
from torch.ao.quantization.utils import (
_assert_and_get_unique_device,
activation_is_statically_quantized,
)
from torch.fx import GraphModule, map_arg
from torch.fx.graph import Graph, Node
# importing the lib so that the quantized_decomposed ops are registered
from ._decomposed import quantized_decomposed_lib # noqa: F401
from .custom_config import PrepareCustomConfig
# TODO: revisit this list. Many helper methods shouldn't be public
__all__ = [
"all_node_args_except_first",
"all_node_args_have_no_tensors",
"assert_and_get_unique_device",
"collect_producer_nodes",
"create_getattr_from_value",
"create_node_from_old_node_preserve_meta",
"EMPTY_ARG_DICT",
"get_custom_module_class_keys",
"get_linear_prepack_op_for_dtype",
"get_new_attr_name_with_prefix",
"get_non_observable_arg_indexes_and_types",
"get_qconv_prepack_op",
"get_skipped_module_name_and_classes",
"graph_module_from_producer_nodes",
"maybe_get_next_module",
"NodeInfo",
"node_arg_is_bias",
"node_arg_is_weight",
"NON_OBSERVABLE_ARG_DICT",
"NON_QUANTIZABLE_WEIGHT_OPS",
"return_arg_list",
"ObservedGraphModuleAttrs",
]
NON_QUANTIZABLE_WEIGHT_OPS = {
torch.nn.functional.layer_norm,
torch.nn.functional.group_norm,
torch.nn.functional.instance_norm,
}
@dataclass
class ObservedGraphModuleAttrs:
node_name_to_qconfig: Dict[str, QConfigAny]
node_name_to_scope: Dict[str, Tuple[str, type]]
prepare_custom_config: PrepareCustomConfig
equalization_node_name_to_qconfig: Dict[str, Any]
qconfig_mapping: QConfigMapping
is_qat: bool
observed_node_names: Set[str]
is_observed_standalone_module: bool = False
standalone_module_input_quantized_idxs: Optional[List[int]] = None
standalone_module_output_quantized_idxs: Optional[List[int]] = None
def node_arg_is_weight(node: Node, arg: Any) -> bool:
"""Returns if node arg is weight"""
weight_index = None
if "target_dtype_info" in node.meta:
weight_index = node.meta["target_dtype_info"].get("weight_index", None)
if (
weight_index is not None
and weight_index < len(node.args)
and node.args[weight_index] is arg
):
return True
return node.kwargs.get("weight") is arg
def node_arg_is_bias(node: Node, arg: Any) -> bool:
"""Returns if node arg is bias"""
bias_index = None
if "target_dtype_info" in node.meta:
bias_index = node.meta["target_dtype_info"].get("bias_index", None)
if (
bias_index is not None
and bias_index < len(node.args)
and node.args[bias_index] is arg
):
return True
return node.kwargs.get("bias") is arg
def get_custom_module_class_keys(
custom_module_mapping: Dict[QuantType, Dict[Type, Type]]
) -> List[Any]:
r"""Get all the unique custom module keys in the custom config dict
e.g.
Input:
{
QuantType.STATIC: {
CustomModule1: ObservedCustomModule
},
QuantType.DYNAMIC: {
CustomModule2: DynamicObservedCustomModule
},
QuantType.WEIGHT_ONLY: {
CustomModule3: WeightOnlyObservedCustomModule
},
}
Output:
# extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts
[CustomModule1, CustomModule2, CustomModule3]
"""
# using set to dedup
float_custom_module_classes: Set[Any] = set()
for quant_mode in [QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY]:
quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {})
quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys())
float_custom_module_classes |= quant_mode_custom_module_classes
return list(float_custom_module_classes)
def get_linear_prepack_op_for_dtype(dtype):
if dtype == torch.float16:
return torch.ops.quantized.linear_prepack_fp16
elif dtype == torch.qint8:
return torch.ops.quantized.linear_prepack
else:
raise Exception("can't get linear prepack op for dtype:", dtype) # noqa: TRY002
def get_qconv_prepack_op(conv_op: Callable) -> Callable:
prepack_ops = {
torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack,
torch.nn.functional.conv_transpose1d: torch.ops.quantized.conv_transpose1d_prepack,
torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack,
torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
}
prepack_op = prepack_ops.get(conv_op, None)
assert prepack_op, f"Didn't find prepack op for {conv_op}"
return prepack_op
# Returns a function that can get a new attribute name for module with given
# prefix, for example,
# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
# >> new_name = get_new_observer_name(module)
# new_name will be an unused attribute name on module, e.g. `_observer_1`
def get_new_attr_name_with_prefix(prefix: str) -> Callable:
prefix = prefix.replace(".", "_")
def get_new_attr_name(module: torch.nn.Module):
def get_attr_name(i: int):
return prefix + str(i)
i = 0
attr_name = get_attr_name(i)
while hasattr(module, attr_name):
i += 1
attr_name = get_attr_name(i)
return attr_name
return get_new_attr_name
def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
r"""Starting from a target node, trace back until we hit inpu or
getattr node. This is used to extract the chain of operators
starting from getattr to the target node, for example
def forward(self, x):
observed = self.observer(self.weight)
return F.linear(x, observed)
    collect_producer_nodes(observed) will either return a list of nodes that
    produce the observed node, or None if we can't extract a self-contained
    graph without free variables (inputs of the forward function).
"""
nodes = [node]
frontier = [node]
while frontier:
node = frontier.pop()
all_args = list(node.args) + list(node.kwargs.values())
for arg in all_args:
if not isinstance(arg, Node):
continue
if arg.op == "placeholder":
# hit input, can't fold in this case
return None
nodes.append(arg)
if not (arg.op == "call_function" and arg.target == getattr):
frontier.append(arg)
return nodes
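# Illustrative sketch (not part of the original file): for the forward shown in the
# docstring above, starting from the observed weight node yields the self-contained
# chain [observer_call, weight_getattr]; starting from a node fed by a placeholder
# yields None. `observed_node` is a hypothetical node handle.
# >> producers = collect_producer_nodes(observed_node)
# >> if producers is None:
# >>     pass  # chain reaches a graph input, so it cannot be folded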
def graph_module_from_producer_nodes(
root: GraphModule, producer_nodes: List[Node]
) -> GraphModule:
r"""Construct a graph module from extracted producer nodes
from `collect_producer_nodes` function
Args:
root: the root module for the original graph
producer_nodes: a list of nodes we use to construct the graph
Return:
A graph module constructed from the producer nodes
"""
assert len(producer_nodes) > 0, "list of producer nodes can not be empty"
# since we traced back from node to getattr
producer_nodes.reverse()
graph = Graph()
env: Dict[Any, Any] = {}
def load_arg(a):
return map_arg(a, lambda node: env[node])
for producer_node in producer_nodes:
env[producer_node] = graph.node_copy(producer_node, load_arg)
graph.output(load_arg(producer_nodes[-1]))
graph_module = GraphModule(root, graph)
return graph_module
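# Illustrative sketch (not part of the original file): the two helpers above are
# typically combined to fold a weight-producing subgraph (getattr -> observer) into
# a standalone GraphModule that can be evaluated ahead of time. `gm` and
# `weight_node` are hypothetical.
# >> producers = collect_producer_nodes(weight_node)
# >> if producers is not None:
# >>     folded = graph_module_from_producer_nodes(gm, producers)
# >>     folded_weight = folded()  # runs only the producer chain, no other inputs needed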
# TODO: delete
def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
"""
Returns the unique device for a module, or None if no device is found.
Throws an error if multiple devices are detected.
"""
return _assert_and_get_unique_device(module)
def create_getattr_from_value(
module: torch.nn.Module, graph: Graph, prefix: str, value: Any
) -> Node:
"""
Given a value of any type, creates a getattr node corresponding to the value and
registers the value as a buffer to the module.
"""
get_new_attr_name = get_new_attr_name_with_prefix(prefix)
attr_name = get_new_attr_name(module)
device = assert_and_get_unique_device(module)
new_value = (
value.clone().detach()
if isinstance(value, torch.Tensor)
else torch.tensor(value, device=device)
)
module.register_buffer(attr_name, new_value)
# Create get_attr with value
attr_node = graph.create_node("get_attr", attr_name)
return attr_node
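# Illustrative sketch (not part of the original file): registering a scale value as
# a buffer and getting back the matching get_attr node. `model` (a GraphModule) and
# `scale` are hypothetical.
# >> scale_node = create_getattr_from_value(model, model.graph, "_scale_", scale)
# >> # `model` now owns a buffer such as "_scale_0", and `scale_node` is its get_attr node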
def all_node_args_have_no_tensors(
node: Node, modules: Dict[str, torch.nn.Module], cache: Dict[Node, bool]
) -> bool:
"""
If we know for sure that all of this node's args have no
tensors (are primitives), return True. If we either
find a tensor or are not sure, return False. Note: this
function is not exact.
"""
if cache and node in cache:
return cache[node]
result = False # will be overwritten
if not isinstance(node, Node):
result = True
elif node.op == "placeholder":
result = False
elif node.op == "call_module":
assert isinstance(node.target, str)
if _is_activation_post_process(modules[node.target]):
result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type]
elif node.op == "call_module":
result = False
elif node.op == "call_function" and node.target is operator.getitem:
result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type]
elif node.op == "get_attr":
result = False
elif node.target is getattr and node.args[1] in ["ndim", "shape"]:
# x1 = x0.ndim
result = True
elif node.op == "call_method" and node.target == "size":
# x1 = x0.size(0)
result = True
else:
found_one_tensor = False
for arg in node.args:
if isinstance(arg, list):
for list_el in arg:
if isinstance(list_el, Node):
this_list_el_args_have_no_tensors = (
all_node_args_have_no_tensors(list_el, modules, cache)
)
found_one_tensor = found_one_tensor or (
not this_list_el_args_have_no_tensors
)
# If found_one_tensor is True, there is no point in
# recursing further as the end result will always
# be True.
# TODO(future PR): remove this entire function and
# change to dtype inference without recursion.
if found_one_tensor:
result = not found_one_tensor
if cache:
cache[node] = result
return result
elif isinstance(arg, int):
pass
else:
if isinstance(arg, Node):
this_arg_args_have_no_tensors = all_node_args_have_no_tensors(
arg, modules, cache
)
found_one_tensor = found_one_tensor or (
not this_arg_args_have_no_tensors
)
# If found_one_tensor is True, there is no point in
# recursing further as the end result will always
# be True.
# TODO(future PR): remove this entire function and
# change to dtype inference without recursion.
if found_one_tensor:
result = not found_one_tensor
if cache:
cache[node] = result
return result
else:
found_one_tensor = True
result = not found_one_tensor
if cache:
cache[node] = result
return result
def all_node_args_except_first(node: Node) -> List[int]:
"""
Returns all node arg indices after first
"""
return list(range(1, len(node.args)))
def return_arg_list(arg_indices: List[int]) -> Callable[[Node], List[int]]:
"""
Constructs a function that takes a node as arg and returns the arg_indices
that are valid for node.args
"""
def arg_indices_func(node: Node) -> List[int]:
return [i for i in arg_indices if i < len(node.args)]
return arg_indices_func
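# Illustrative sketch (not part of the original file): return_arg_list builds a
# bounds-checked index selector used by the table below. `two_arg_node` is a
# hypothetical node with exactly two positional args.
# >> select = return_arg_list([1, 3])
# >> select(two_arg_node)  # [1] -- index 3 is dropped because the node has no args[3]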
NodeInfo = namedtuple("NodeInfo", "op target")
# This dict identifies which arg indices of a node are non-tensors, so that they
# can be propagated correctly; inserting observers for them would cause errors.
NON_OBSERVABLE_ARG_DICT: Dict[
NodeInfo, Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]
] = {
NodeInfo("call_method", "masked_fill"): {
torch.bool: return_arg_list([1]),
float: return_arg_list([2]),
},
NodeInfo("call_method", "permute"): {int: all_node_args_except_first},
NodeInfo("call_method", "repeat"): {int: all_node_args_except_first},
NodeInfo("call_method", "reshape"): {int: all_node_args_except_first},
NodeInfo("call_method", "size"): {int: return_arg_list([1])},
NodeInfo("call_method", "transpose"): {int: all_node_args_except_first},
NodeInfo("call_method", torch.transpose): {int: all_node_args_except_first},
NodeInfo("call_method", "unsqueeze"): {int: return_arg_list([1])},
NodeInfo("call_method", "unsqueeze_"): {int: return_arg_list([1])},
NodeInfo("call_method", torch.unsqueeze): {int: return_arg_list([1])},
NodeInfo("call_method", "view"): {int: all_node_args_except_first},
}
EMPTY_ARG_DICT: Dict[Union[type, torch.dtype], Callable[[Node], List[int]]] = {}
def get_non_observable_arg_indexes_and_types(
node: Node,
) -> Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]:
"""
    Returns a dict mapping each non-float-tensor arg type to a function that,
    given the node, returns the list of argument indices of that type
"""
info = NodeInfo(node.op, node.target)
return NON_OBSERVABLE_ARG_DICT.get(info, EMPTY_ARG_DICT)
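# Illustrative sketch (not part of the original file): for a call_method "reshape"
# node such as x.reshape(4, -1), the table above maps int-typed args to every index
# after the first, so observer insertion skips them. `reshape_node` is hypothetical.
# >> arg_types = get_non_observable_arg_indexes_and_types(reshape_node)
# >> arg_types[int](reshape_node)  # [1, 2] for x.reshape(4, -1)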
def maybe_get_next_module(
node: Node,
modules: Dict[str, nn.Module],
target_module_type: Optional[Type[nn.Module]] = None,
target_functional_type: Any = None,
) -> Optional[Node]:
"""Gets the next module that matches what is needed in
is_target_module_type if it exists
Args:
node: The node whose users we want to look at
target_module_type: Module type that we want to check
target_functional_type: Functional type that we want to check
"""
for user in node.users.keys():
if (
user.op == "call_module"
and target_module_type is not None
and isinstance(modules[str(user.target)], target_module_type)
):
return user
elif (
user.op == "call_function"
and target_functional_type is not None
and user.target == target_functional_type
):
return user
return None
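# Illustrative sketch (not part of the original file): checking whether a conv node
# is immediately followed by a BatchNorm2d user, a typical fusion-style query.
# `conv_node` and `named_modules` are hypothetical.
# >> bn_node = maybe_get_next_module(conv_node, named_modules, target_module_type=nn.BatchNorm2d)
# >> if bn_node is not None:
# >>     pass  # the conv output feeds a BatchNorm2d module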
def create_node_from_old_node_preserve_meta(
quantized_graph: Graph,
create_node_args: Tuple[Any, ...],
old_node: Node,
) -> Node:
"""
Creates `new_node` and copies the necessary metadata to it from `old_node`.
"""
new_node = quantized_graph.create_node(*create_node_args)
new_node.stack_trace = old_node.stack_trace
return new_node
def get_skipped_module_name_and_classes(
prepare_custom_config: PrepareCustomConfig, is_standalone_module: bool
) -> Tuple[List[str], List[Type[Any]]]:
skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names)
skipped_module_classes = copy.copy(
prepare_custom_config.non_traceable_module_classes
)
if not is_standalone_module:
# standalone module and custom module config are applied in top level module
skipped_module_names += list(
prepare_custom_config.standalone_module_names.keys()
)
skipped_module_classes += list(
prepare_custom_config.standalone_module_classes.keys()
)
skipped_module_classes += get_custom_module_class_keys(
prepare_custom_config.float_to_observed_mapping
)
return skipped_module_names, skipped_module_classes
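# Illustrative sketch (not part of the original file): for a top-level prepare call,
# non-traceable modules, standalone modules, and float-to-observed custom modules
# from the PrepareCustomConfig are all skipped during tracing and observation.
# >> names, classes = get_skipped_module_name_and_classes(prepare_custom_config, is_standalone_module=False)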
def _is_custom_module_lstm(
node: Node,
named_modules: Dict[str, torch.nn.Module],
qconfig: QConfigAny = None,
# QuantizeHandler, but we cannot include the type here due to circular imports
qhandler: Optional[Any] = None,
) -> bool:
"""
Return whether this refers to the custom module LSTM flow.
"""
mod = _get_module(node, named_modules)
if qconfig is not None and qhandler is not None:
assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler) # type: ignore[attr-defined]
return (
isinstance(mod, torch.nn.LSTM)
and activation_is_statically_quantized(qconfig)
and qhandler.is_custom_module()
)
else:
return isinstance(mod, torch.ao.nn.quantizable.LSTM)
def _is_custom_module_mha(
node: Node,
named_modules: Dict[str, torch.nn.Module],
qconfig: QConfigAny = None,
# QuantizeHandler, but we cannot include the type here due to circular imports
qhandler: Optional[Any] = None,
) -> bool:
"""
Return whether this refers to the custom module MultiheadAttention flow.
"""
mod = _get_module(node, named_modules)
if qconfig is not None and qhandler is not None:
assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler) # type: ignore[attr-defined]
return (
isinstance(mod, torch.nn.MultiheadAttention)
and activation_is_statically_quantized(qconfig)
and qhandler.is_custom_module()
)
else:
return isinstance(mod, torch.ao.nn.quantizable.MultiheadAttention)
def _get_module(
node: Node, named_modules: Dict[str, torch.nn.Module]
) -> Optional[torch.nn.Module]:
"""
If `node` refers to a call_module node, return the module, else None.
"""
if node.op == "call_module" and str(node.target) in named_modules:
return named_modules[str(node.target)]
else:
return None
def _insert_dequant_stub(
node: Node,
model: torch.nn.Module,
named_modules: Dict[str, torch.nn.Module],
graph: Graph,
) -> Node:
"""
Attach a `DeQuantStub` to the model and create a node that calls this
`DeQuantStub` on the output of `node`, similar to how observers are inserted.
"""
prefix = "dequant_stub_"
get_new_dequant_stub_name = get_new_attr_name_with_prefix(prefix)
dequant_stub_name = get_new_dequant_stub_name(model)
dequant_stub = DeQuantStub()
setattr(model, dequant_stub_name, dequant_stub)
named_modules[dequant_stub_name] = dequant_stub
with graph.inserting_after(node):
return graph.call_module(dequant_stub_name, (node,))
def _insert_dequant_stubs_for_custom_module_lstm_output(
node: Node,
model: torch.nn.Module,
named_modules: Dict[str, torch.nn.Module],
graph: Graph,
) -> Node:
"""
Insert DeQuantStubs after each internal output node of custom module LSTM.
    Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)).
Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its
components through `getitem`. This function transforms the graph as follows:
(1) Split the LSTM node into (output, (hidden0, hidden1))
(2) Insert a DeQuantStub after each internal node
(3) Recombine the DeQuantStubs into the same structure as before
(4) Reroute all consumers of the original LSTM node and its sub-nodes
(e.g. lstm[0])
Before:
lstm_output
|
v
original_user(s)
After:
lstm_output
/ \\
/ (getitem) \\
/ \\
v v
output hidden
| / \\
(DeQuantStub) (getitem)
| / \\
v v v
output_dq hidden0 hidden1
| | |
| (DeQuantStub) (DeQuantStub)
| | |
| v v
| hidden0_dq hidden1_dq
| \\ /
| (tuple)
| \\ /
| v v
| hidden_dq
\\ /
\\ (tuple) /
v v
lstm_output_dq
|
v
original_user(s)
For step (4), reroute all users of the original LSTM node(s) as follows:
lstm_output -> lstm_output_dq
lstm_output[0] -> output_dq
lstm_output[1] -> hidden_dq
lstm_output[1][0] -> hidden0_dq
lstm_output[1][1] -> hidden1_dq
Return the node `lstm_output_dq`.
"""
# (1) Split the LSTM node into (output, (hidden0, hidden1))
# (2) Insert a DeQuantStub after each internal node
with graph.inserting_after(node):
output = graph.call_function(operator.getitem, (node, 0))
output_dq = _insert_dequant_stub(output, model, named_modules, graph)
with graph.inserting_after(output_dq):
hidden = graph.call_function(operator.getitem, (node, 1))
with graph.inserting_after(hidden):
hidden0 = graph.call_function(operator.getitem, (hidden, 0))
hidden0_dq = _insert_dequant_stub(hidden0, model, named_modules, graph)
with graph.inserting_after(hidden0_dq):
hidden1 = graph.call_function(operator.getitem, (hidden, 1))
hidden1_dq = _insert_dequant_stub(hidden1, model, named_modules, graph)
# (3) Recombine the DeQuantStubs into the same structure as before
with graph.inserting_after(hidden1_dq):
hidden_dq = graph.call_function(tuple, ([hidden0_dq, hidden1_dq],))
with graph.inserting_after(hidden_dq):
lstm_output_dq = graph.call_function(tuple, ([output_dq, hidden_dq],))
# (4) Reroute all consumers of the original LSTM node and its sub-nodes
for user in list(node.users.keys()):
if user != output and user != hidden:
user.replace_input_with(node, lstm_output_dq)
# The getitem and tuple nodes we added here may interfere with reference quantized
# pattern matching, so we need to redirect the consumers of internal nodes to the
# corresponding nodes with DeQuantStubs (e.g. lstm_output_dq[0] -> output_dq) attached,
# in order to preserve reference patterns like "dequantize - consumer - quantize".
_reroute_tuple_getitem_pattern(graph)
return lstm_output_dq
def _maybe_get_custom_module_lstm_from_node_arg(
arg: Node,
named_modules: Dict[str, torch.nn.Module],
) -> Optional[Node]:
"""
Given an argument of a node, if the argument refers to the path through which the node
is a consumer of custom module LSTM, return the custom module LSTM node, or None otherwise.
This is used to determine whether a node is a consumer of custom module LSTM, and, if so,
skip inserting input observers for this node. This is because custom module LSTM produces
quantized outputs, so inserting an input observer for the consumer of custom module LSTM
would unnecessarily quantize the outputs again.
lstm -> consumer
In practice, however, custom module LSTM outputs a tuple (output, (hidden0, hidden1)) with
DeQuantStubs attached to each internal node (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
This tuple can be consumed in one of four ways:
lstm -> getitem -> DeQuantStub -> consumer # consume lstm[0]
lstm -> getitem -> getitem -> DeQuantStub -> tuple -> consumer # consume lstm[1]
lstm -> getitem -> getitem -> DeQuantStub -> consumer # consume lstm[1][0] or lstm[1][1]
lstm -> getitem -> DeQuantStub -> tuple -> consumer # consume lstm
Thus, we must match against the above patterns instead of simply checking the parent node
to determine whether this node is a consumer of a custom module LSTM.
"""
def match_dq(a):
return isinstance(_get_module(a, named_modules), DeQuantStub)
def match_lstm(a):
return _is_custom_module_lstm(a, named_modules)
def match_getitem(a):
return a.op == "call_function" and a.target == operator.getitem
def match_tuple(a):
return a.op == "call_function" and a.target == tuple
def _match_pattern(match_pattern: List[Callable]) -> Optional[Node]:
"""
Traverse up the graph and match the args one by one.
If there is a match, return the last matched node, or None otherwise.
"""
a = arg
for i, match in enumerate(match_pattern):
if not match(a):
return None
# Match next arg, for tuple the arg is a tuple of a list, e.g. ([dq_1, other_node],)
if i < len(match_pattern) - 1:
if match == match_tuple:
a = a.args[0][0] # type: ignore[assignment,index]
else:
a = a.args[0] # type: ignore[assignment]
return a
all_match_patterns = [
[match_dq, match_getitem, match_lstm],
[match_tuple, match_dq, match_getitem, match_getitem, match_lstm],
[match_dq, match_getitem, match_getitem, match_lstm],
[match_tuple, match_dq, match_getitem, match_lstm],
]
for p in all_match_patterns:
matched_node = _match_pattern(p)
if matched_node is not None:
return matched_node
return None
def _reroute_tuple_getitem_pattern(graph: Graph):
"""
Search for patterns where N consecutive `tuple` call_function nodes are followed by
N consecutive `getitem` call_function nodes that are "reverses" of the `tuple` nodes.
If we find this pattern, reroute the consumers of the last `getitem` to skip these
N `tuple` and `getitem` nodes.
Before:
a b c
| \\ /
\\ tuple
\\ /
tuple
|
getitem(1)
|
getitem(0)
|
d
After:
b
|
d
"""
def find_patterns(
node: Node,
index_stack: List[int],
current_pattern: List[Node],
matched_patterns: List[List[Node]],
seen: Set[Tuple[Node, Tuple[int, ...]]],
):
"""
Traverse the graph recursively to match for the N-tuple - N-getitem patterns,
starting at the given node.
We use a stack to keep track of the expected `getitem` indices, since these are
reversed from the `tuple` indices. In the above example, the stack after
(b -> tuple -> tuple) will be [0, 1], which will be popped by getitem(1) first
and then by getitem(0).
TODO: traverse upwards from the output and handle the case when tuple is not a
separate node, e.g. graph.call_function(operator.getitem, args=(a, (b, c)))
"""
if len(index_stack) == 0 and len(current_pattern) > 0:
matched_patterns.append(copy.copy(current_pattern))
current_pattern.clear()
# Avoid duplicating work
state = (node, tuple(index_stack))
if state in seen:
return
seen.add(state)
# Iterate through users of this node to find tuple/getitem nodes to match
for user in node.users:
if user.op == "call_function" and user.target == tuple:
for i, user_arg in enumerate(user.args[0]): # type: ignore[arg-type]
if user_arg == node:
index_stack.append(i)
current_pattern.append(user)
find_patterns(
user, index_stack, current_pattern, matched_patterns, seen
)
elif user.op == "call_function" and user.target == operator.getitem:
if len(index_stack) > 0:
if user.args[1] == index_stack[-1]:
index_stack.pop()
current_pattern.append(user)
find_patterns(
user, index_stack, current_pattern, matched_patterns, seen
)
return matched_patterns
# Collect all matched patterns
matched_patterns: List[List[Node]] = []
seen: Set[Tuple[Node, Tuple[int, ...]]] = set() # (node, index_stack)
for node in graph.nodes:
find_patterns(node, [], [], matched_patterns, seen)
# For each pattern, redirect all consumers of the last getitem node to the correct input
# of the first tuple node
for pattern in matched_patterns:
first_tuple = pattern[0]
last_getitem = pattern[-1]
assert first_tuple.op == "call_function" and first_tuple.target == tuple
assert (
last_getitem.op == "call_function"
and last_getitem.target == operator.getitem
)
last_getitem_index = last_getitem.args[1]
new_input = first_tuple.args[0][last_getitem_index] # type: ignore[index]
for user in list(last_getitem.users.keys()):
user.replace_input_with(last_getitem, new_input) # type: ignore[arg-type]
def _get_observer_from_activation_post_process(
activation_post_process: Union[ObserverBase, FakeQuantizeBase],
) -> ObserverBase:
"""
If `activation_post_process` is an observer, return the observer.
If `activation_post_process` is a fake quantize, return the internal observer.
"""
if isinstance(activation_post_process, ObserverBase):
return activation_post_process
else:
assert isinstance(activation_post_process, FakeQuantizeBase)
return activation_post_process.activation_post_process # type: ignore[return-value]
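# Illustrative sketch (not part of the original file): a FakeQuantize module wraps
# its observer in `.activation_post_process`, so both cases normalize to an observer.
# >> fq = torch.ao.quantization.FakeQuantize()
# >> _get_observer_from_activation_post_process(fq)  # the FakeQuantize's internal observer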
def _qconfig_satisfies_dtype_config_constraints(
qconfig: QConfigAny,
dtype_with_constraints: DTypeWithConstraints,
is_activation: bool = True,
) -> bool:
"""
Return whether `qconfig` satisfies the following constraints from the backend,
specified through the activation and weight DTypeWithConstraints.
1. QConfig specified a quantization range that falls within the backend's, if any
2. QConfig specified a min scale value that is >= the backend's, if any
3. QConfig specified a FixedQParamsObserver or FixedQParamsFakeQuantize that has
scale and zero point that match the backend's, if any
If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`.
If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True.
"""
# TODO: log warnings only when the user enabled a debug flag
def _activation_post_process_satisfies_dtype_config_constraints(
activation_post_process: Union[ObserverBase, FakeQuantizeBase],
dtype_with_constraints: DTypeWithConstraints,
debug_string: str,
) -> bool:
observer = _get_observer_from_activation_post_process(activation_post_process)
app_quant_min = getattr(observer, "quant_min", None)
app_quant_max = getattr(observer, "quant_max", None)
# TODO: for now, just use the existing eps value as scale_min. In the future, we should
# resolve the differences between the two, either by renaming eps or some other way
app_scale_min = getattr(observer, "eps", None)
backend_quant_min = dtype_with_constraints.quant_min_lower_bound
backend_quant_max = dtype_with_constraints.quant_max_upper_bound
backend_scale_min = dtype_with_constraints.scale_min_lower_bound
backend_scale_exact_match = dtype_with_constraints.scale_exact_match
backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match
# check quantization ranges
if backend_quant_min is not None and backend_quant_max is not None:
if app_quant_min is None or app_quant_max is None:
warnings.warn(
f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}"
)
return False
elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max:
warnings.warn(
f"QConfig {debug_string} quantization range must fall within the backend's:\n"
f"QConfig range = ({app_quant_min}, {app_quant_max}), "
f"BackendConfig range = ({backend_quant_min}, {backend_quant_max}), "
f"ignoring {qconfig}"
)
return False
# check scale min
if backend_scale_min is not None:
if app_scale_min is None:
warnings.warn(
f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}"
)
return False
if app_scale_min < backend_scale_min:
warnings.warn(
f"QConfig {debug_string} eps ({app_scale_min}) must be greater than or equal to "
f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}"
)
return False
# check fixed scale and zero point
if (
backend_scale_exact_match is not None
and backend_zero_point_exact_match is not None
):
# For tests only, accept the following qconfigs for now
# TODO: handle fp16 qconfigs properly
for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]:
if qconfig_equals(qconfig, accepted_qconfig):
return True
suggestion_str = (
"Please use torch.ao.quantization.get_default_qconfig_mapping or "
"torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n"
' qconfig_mapping = get_default_qconfig_mapping("fbgemm")\n'
" model = prepare_fx(model, qconfig_mapping, example_inputs)"
)
if not isinstance(
activation_post_process, FixedQParamsObserver
) and not isinstance(activation_post_process, FixedQParamsFakeQuantize):
warnings.warn(
f"QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "
f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}"
)
return False
if (
observer.scale != backend_scale_exact_match
or observer.zero_point != backend_zero_point_exact_match
):
warnings.warn(
f"QConfig fixed scale ({observer.scale}) and zero point ({observer.zero_point}) "
f"do not match the backend's ({backend_scale_exact_match} and {backend_zero_point_exact_match}), "
f"ignoring {qconfig}.\n{suggestion_str}"
)
return False
return True
if qconfig is None or dtype_with_constraints.dtype is None:
return True
activation_post_process_ctr = (
qconfig.activation if is_activation else qconfig.weight
)
debug_string = "activation" if is_activation else "weight"
satisfies_constraints = True
if activation_post_process_ctr is not None:
activation_post_process = activation_post_process_ctr()
assert _is_activation_post_process(activation_post_process)
# If dtypes don't match, don't check the activation_post_process and return True early
if activation_post_process.dtype != dtype_with_constraints.dtype:
return True
satisfies_constraints = (
_activation_post_process_satisfies_dtype_config_constraints(
activation_post_process, dtype_with_constraints, debug_string
)
)
return satisfies_constraints
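# Illustrative sketch (not part of the original file): a backend whose activation
# dtype is quint8 restricted to [0, 127] would reject a qconfig whose observer
# range is the wider [0, 255]. DTypeWithConstraints comes from
# torch.ao.quantization.backend_config; `qconfig` is hypothetical here.
# >> constrained = DTypeWithConstraints(
# >>     dtype=torch.quint8, quant_min_lower_bound=0, quant_max_upper_bound=127
# >> )
# >> _qconfig_satisfies_dtype_config_constraints(qconfig, constrained, is_activation=True)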