commit 40e2a747cf (parent 720dc28c09)
Date: 2024-10-30 22:14:35 +01:00

    I am done

36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,563 @@
# mypy: allow-untyped-defs
from typing import Any, Callable, Dict, List, Optional, Set, Union
import torch
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.nn as nn
from torch.ao.quantization import prepare
from torch.ao.quantization.quantization_mappings import (
get_default_compare_output_module_list,
)
NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST = {
nnqd.Linear,
nnq.Linear,
nnqd.LSTM,
nn.LSTM,
}
def _find_match(
str_list: Union[Dict[str, Any], List[str]],
key_str: str,
postfix: str,
) -> Optional[str]:
split_str = key_str.split(".")
if split_str[-1] == postfix:
match_string = "".join(key_str.split(".")[0:-1])
for s2 in str_list:
pattern1 = "".join(s2.split(".")[0:-1])
pattern2 = "".join(s2.split(".")[0:-2])
if match_string == pattern1:
return s2
if match_string == pattern2:
return s2
# For matching "fc.weight" and "fc._packed_params._packed_params"
if postfix == "_packed_params":
match_string = "".join(key_str.split(".")[0:-2])
if len(match_string) == 0:
return None
for s2 in str_list:
pattern1 = "".join(s2.split(".")[0:-1])
pattern2 = "".join(s2.split(".")[0:-2])
if match_string == pattern1:
return s2
if match_string == pattern2:
return s2
return None
else:
return None
def compare_weights(
float_dict: Dict[str, Any], quantized_dict: Dict[str, Any]
) -> Dict[str, Dict[str, torch.Tensor]]:
r"""Compare the weights of the float module with its corresponding quantized
module. Return a dict with key corresponding to module names and each entry being
a dictionary with two keys 'float' and 'quantized', containing the float and
quantized weights. This dict can be used to compare and compute the quantization
error of the weights of float and quantized models.
Example usage::
wt_compare_dict = compare_weights(
float_model.state_dict(), qmodel.state_dict())
for key in wt_compare_dict:
print(
key,
compute_error(
wt_compare_dict[key]['float'],
wt_compare_dict[key]['quantized'].dequantize()
)
)
Args:
float_dict: state dict of the float model
quantized_dict: state dict of the quantized model
Return:
weight_dict: dict with key corresponding to module names and each entry being
a dictionary with two keys 'float' and 'quantized', containing the float and
quantized weights
"""
torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_weights")
weight_dict: Dict[str, Dict] = {}
for key in quantized_dict:
match_key = _find_match(float_dict, key, "weight")
if match_key is not None:
weight_dict[key] = {}
weight_dict[key]["float"] = float_dict[match_key]
weight_dict[key]["quantized"] = quantized_dict[key]
continue
# For matching "fc.weight" and "fc._packed_params._packed_params"
match_key = _find_match(float_dict, key, "_packed_params")
if match_key is not None:
weight_dict[key] = {}
weight_dict[key]["float"] = float_dict[match_key]
weight_dict[key]["quantized"] = quantized_dict[key][0]
# For LSTM
split_str = key.split(".")
if split_str[-1] == "param" and split_str[-3] == "_all_weight_values":
layer = split_str[-2]
module_name = ".".join(split_str[:-3])
float_weight_ih_key = module_name + ".weight_ih_l" + layer
float_weight_hh_key = module_name + ".weight_hh_l" + layer
if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict:
weight_dict[key] = {}
weight_dict[key]["float"] = float_dict[float_weight_ih_key]
weight_dict[key]["quantized"] = (
quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0]
)
weight_dict[key]["float"] = float_dict[float_weight_hh_key]
weight_dict[key]["quantized"] = (
quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0]
)
return weight_dict
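# The docstring example above uses a `compute_error` helper that is not defined in
# this snippet. A minimal SQNR-style sketch, assuming the common
# 20 * log10(||x|| / ||x - y||) formulation (the name below is illustrative and not
# part of this module's API):
def _example_compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # ratio of signal power to quantization noise power, expressed in dB
    return 20 * torch.log10(torch.norm(x) / torch.norm(x - y))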
def _get_logger_dict_helper(
mod: nn.Module,
target_dict: Dict[str, Any],
prefix: str = "",
) -> None:
r"""This is the helper function for get_logger_dict
Args:
mod: module we want to save all logger stats
prefix: prefix for the current module
target_dict: the dictionary used to save all logger stats
"""
def get_prefix(prefix):
return prefix if prefix == "" else prefix + "."
for name, child in mod.named_children():
if isinstance(child, Logger):
target_dict[get_prefix(prefix) + "stats"] = child.stats
break
for name, child in mod.named_children():
module_prefix = get_prefix(prefix) + name if prefix else name
_get_logger_dict_helper(child, target_dict, module_prefix)
def get_logger_dict(mod: nn.Module, prefix: str = "") -> Dict[str, Dict]:
r"""Traverse the modules and save all logger stats into target dict.
This is mainly used for quantization accuracy debug.
Type of loggers supported:
ShadowLogger: used to log the outputs of the quantized module and its matching float shadow module,
OutputLogger: used to log the outputs of the modules
Args:
mod: module we want to save all logger stats
prefix: prefix for the current module
Return:
target_dict: the dictionary used to save all logger stats
"""
torch._C._log_api_usage_once("quantization_api._numeric_suite.get_logger_dict")
target_dict: Dict[str, Dict] = {}
_get_logger_dict_helper(mod, target_dict, prefix)
return target_dict
class Logger(nn.Module):
r"""Base class for stats logging"""
def __init__(self):
super().__init__()
self.stats = {}
# We only insert observer if the op is quantized with static quantization,
# which is identified by activation_observer.dtype == quint8. This is needed
# when attaching Logger as observer for FX mode
self.dtype = torch.quint8
def forward(self, x):
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
class ShadowLogger(Logger):
r"""Class used in Shadow module to record the outputs of the original and
shadow modules.
"""
def __init__(self):
super().__init__()
self.stats["float"] = []
self.stats["quantized"] = []
def forward(self, x, y):
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
if len(x) > 1:
x = x[0]
if len(y) > 1:
y = y[0]
self.stats["quantized"].append(x.detach())
self.stats["float"].append(y.detach())
class OutputLogger(Logger):
r"""Class used to log the outputs of the module"""
def __init__(self):
super().__init__()
self.stats["tensor_val"] = []
def forward(self, x):
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
self.stats["tensor_val"].append(x)
return x
def _convert_tuple_to_list(t: Any) -> Any:
return [_convert_tuple_to_list(x) for x in t] if type(t) is tuple else t
def _dequantize_tensor_list(t: Any) -> Any:
return (
[_dequantize_tensor_list(x) for x in t]
if type(t) is list
else t.dequantize()
if t.is_quantized
else t
)
class Shadow(nn.Module):
r"""Shadow module attaches the float module to its matching quantized module
as the shadow. Then it uses Logger module to process the outputs of both
modules.
Args:
q_module: module quantized from float_module that we want to shadow
float_module: float module used to shadow q_module
logger_cls: type of logger used to process the outputs of q_module and
float_module. ShadowLogger or custom loggers can be used.
"""
def __init__(self, q_module, float_module, logger_cls):
super().__init__()
self.orig_module = q_module
self.shadow_module = float_module
self.dequant = nnq.DeQuantize()
self.logger = logger_cls()
def forward(self, *x) -> torch.Tensor:
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
xl = _convert_tuple_to_list(x)
output = self.orig_module(*xl)
xl_float = _dequantize_tensor_list(xl)
shadow_output = self.shadow_module(*xl_float)
self.logger(output, shadow_output)
return output
def add(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
output = self.orig_module.add(x, y)
x = x.dequantize()
y = y.dequantize()
shadow_output = self.shadow_module.add(x, y)
self.logger(output, shadow_output)
return output
def add_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
output = self.orig_module.add_scalar(x, y)
x = x.dequantize()
shadow_output = self.shadow_module.add_scalar(x, y)
self.logger(output, shadow_output)
return output
def mul(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
output = self.orig_module.mul(x, y)
x = x.dequantize()
y = y.dequantize()
shadow_output = self.shadow_module.mul(x, y)
self.logger(output, shadow_output)
return output
def mul_scalar(self, x: torch.Tensor, y: float) -> torch.Tensor:
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
output = self.orig_module.mul_scalar(x, y)
x = x.dequantize()
shadow_output = self.shadow_module.mul_scalar(x, y)
self.logger(output, shadow_output)
return output
def cat(self, x: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
output = self.orig_module.cat(x, dim)
x = [y.dequantize() for y in x]
shadow_output = self.shadow_module.cat(x, dim)
self.logger(output, shadow_output)
return output
def add_relu(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# fmt: off
"""
""" # blank docblock to make autodoc happy
# fmt: on
output = self.orig_module.add_relu(x, y)
x = x.dequantize()
y = y.dequantize()
shadow_output = self.shadow_module.add_relu(x, y)
self.logger(output, shadow_output)
return output
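# The `logger_cls` passed to Shadow does not have to be ShadowLogger; any Logger
# subclass that accepts the (quantized_output, float_output) pair works. A minimal
# sketch that records a per-call SQNR instead of raw tensors (the class name and
# stats key are illustrative, and plain tensor outputs are assumed):
class _ExampleSqnrShadowLogger(Logger):
    def __init__(self):
        super().__init__()
        self.stats["sqnr"] = []

    def forward(self, x, y):
        # x: output of the quantized module, y: output of the float shadow module
        q = x.dequantize() if x.is_quantized else x
        self.stats["sqnr"].append(
            (20 * torch.log10(torch.norm(y) / torch.norm(y - q))).detach()
        )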
def prepare_model_with_stubs(
float_module: nn.Module,
q_module: nn.Module,
module_swap_list: Set[type],
logger_cls: Callable,
) -> None:
r"""Prepare the model by attaching the float module to its matching quantized
module as the shadow if the float module type is in module_swap_list.
Example usage::
prepare_model_with_stubs(float_model, q_model, module_swap_list, Logger)
q_model(data)
ob_dict = get_logger_dict(q_model)
Args:
float_module: float module used to generate the q_module
q_module: module quantized from float_module
module_swap_list: list of float module types to attach the shadow
logger_cls: type of logger to be used in shadow module to process the outputs of
quantized module and its float shadow module
"""
torch._C._log_api_usage_once(
"quantization_api._numeric_suite.prepare_model_with_stubs"
)
float_module_children = {}
for name, mod in float_module.named_children():
float_module_children[name] = mod
reassign = {}
for name, mod in q_module.named_children():
if name not in float_module_children:
continue
float_mod = float_module_children[name]
if type(float_mod) not in module_swap_list:
prepare_model_with_stubs(float_mod, mod, module_swap_list, logger_cls)
# Insert shadow module only if the module is not of the same type as
# the floating point module
if type(float_mod) in module_swap_list and not _is_identical_module_type(
mod, float_mod
):
reassign[name] = Shadow(mod, float_mod, logger_cls)
for key, value in reassign.items():
q_module._modules[key] = value
def _is_identical_module_type(mod1, mod2):
    # Compare whether two modules contain the same set of module types
mod1_module_types = [type(mod) for mod in mod1.modules()]
mod2_module_types = [type(mod) for mod in mod2.modules()]
return mod1_module_types == mod2_module_types
def compare_model_stub(
float_model: nn.Module,
q_model: nn.Module,
module_swap_list: Set[type],
*data,
logger_cls=ShadowLogger,
) -> Dict[str, Dict]:
r"""Compare quantized module in a model with its floating point counterpart,
feeding both of them the same input. Return a dict with key corresponding to
module names and each entry being a dictionary with two keys 'float' and
'quantized', containing the output tensors of quantized and its matching
float shadow module. This dict can be used to compare and compute the module
level quantization error.
    This function first calls prepare_model_with_stubs() to swap the quantized
    module that we want to compare with the Shadow module, which takes the quantized
    module, the corresponding float module, and a logger as input, and creates a
    forward path inside it so that the float module shadows the quantized module with
    the same input. The logger is customizable; the default logger is ShadowLogger,
    and it saves the outputs of the quantized module and the float module, which
    can be used to compute the module level quantization error.
Example usage::
module_swap_list = [torchvision.models.quantization.resnet.QuantizableBasicBlock]
ob_dict = compare_model_stub(float_model,qmodel,module_swap_list, data)
for key in ob_dict:
print(key, compute_error(ob_dict[key]['float'], ob_dict[key]['quantized'].dequantize()))
Args:
float_model: float model used to generate the q_model
q_model: model quantized from float_model
module_swap_list: list of float module types at which shadow modules will
be attached.
data: input data used to run the prepared q_model
logger_cls: type of logger to be used in shadow module to process the outputs of
quantized module and its float shadow module
"""
torch._C._log_api_usage_once("quantization_api._numeric_suite.compare_model_stub")
prepare_model_with_stubs(float_model, q_model, module_swap_list, logger_cls)
q_model(*data)
ob_dict = get_logger_dict(q_model)
return ob_dict
def get_matching_activations(
float_module: nn.Module,
q_module: nn.Module,
) -> Dict[str, Dict[str, torch.Tensor]]:
r"""Find the matching activation between float and quantized modules.
Args:
float_module: float module used to generate the q_module
q_module: module quantized from float_module
Return:
act_dict: dict with key corresponding to quantized module names and each
entry being a dictionary with two keys 'float' and 'quantized', containing
the matching float and quantized activations
"""
torch._C._log_api_usage_once(
"quantization_api._numeric_suite.get_matching_activations"
)
float_dict = get_logger_dict(float_module)
quantized_dict = get_logger_dict(q_module)
act_dict: Dict[str, Dict] = {}
for key in quantized_dict:
if len(quantized_dict[key]["tensor_val"]) == 0:
continue
match_key = _find_match(sorted(float_dict, reverse=True), key, "stats")
if match_key is not None:
act_dict[key] = {}
act_dict[key]["float"] = float_dict[match_key]["tensor_val"]
act_dict[key]["quantized"] = quantized_dict[key]["tensor_val"]
return act_dict
def prepare_model_outputs(
float_module: nn.Module,
q_module: nn.Module,
logger_cls=OutputLogger,
allow_list=None,
) -> None:
r"""Prepare the model by attaching the logger to both float module
and quantized module if they are in the allow_list.
Args:
float_module: float module used to generate the q_module
q_module: module quantized from float_module
logger_cls: type of logger to be attached to float_module and q_module
allow_list: list of module types to attach logger
"""
torch._C._log_api_usage_once(
"quantization_api._numeric_suite.prepare_model_outputs"
)
if allow_list is None:
allow_list = get_default_compare_output_module_list()
qconfig_debug = torch.ao.quantization.QConfig(activation=logger_cls, weight=None)
float_module.qconfig = qconfig_debug # type: ignore[assignment]
prepare(
float_module, inplace=True, allow_list=allow_list, prepare_custom_config_dict={}
)
q_module.qconfig = qconfig_debug # type: ignore[assignment]
prepare(
q_module,
inplace=True,
allow_list=allow_list,
observer_non_leaf_module_list=NON_LEAF_MODULE_TO_ADD_OBSERVER_ALLOW_LIST,
prepare_custom_config_dict={},
)
def compare_model_outputs(
float_model: nn.Module,
q_model: nn.Module,
*data,
logger_cls=OutputLogger,
allow_list=None,
) -> Dict[str, Dict[str, torch.Tensor]]:
r"""Compare output activations between float and quantized models at
corresponding locations for the same input. Return a dict with key corresponding
to quantized module names and each entry being a dictionary with two keys
'float' and 'quantized', containing the activations of quantized model and
float model at matching locations. This dict can be used to compare and
compute the propagation quantization error.
Example usage::
act_compare_dict = compare_model_outputs(float_model, qmodel, data)
for key in act_compare_dict:
print(
key,
compute_error(
act_compare_dict[key]['float'],
act_compare_dict[key]['quantized'].dequantize()
)
)
Args:
float_model: float model used to generate the q_model
q_model: model quantized from float_model
data: input data used to run the prepared float_model and q_model
logger_cls: type of logger to be attached to float_module and q_module
allow_list: list of module types to attach logger
Return:
act_compare_dict: dict with key corresponding to quantized module names
and each entry being a dictionary with two keys 'float' and 'quantized',
containing the matching float and quantized activations
"""
torch._C._log_api_usage_once(
"quantization_api._numeric_suite.compare_model_outputs"
)
if allow_list is None:
allow_list = get_default_compare_output_module_list()
prepare_model_outputs(float_model, q_model, logger_cls, allow_list)
float_model(*data)
q_model(*data)
act_compare_dict = get_matching_activations(float_model, q_model)
return act_compare_dict
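# A self-contained sketch (not part of the module API) tying the three comparison
# entry points above together on a toy eager-mode model. The model, the shapes, and
# the "fbgemm" qconfig are illustrative assumptions; each comparison gets its own
# deep copy because the helpers mutate the models they receive.
def _example_compare_toy_model():
    import copy

    from torch.ao.quantization import (
        convert,
        DeQuantStub,
        get_default_qconfig,
        QuantStub,
    )

    class _ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = QuantStub()
            self.fc = nn.Linear(8, 8)
            self.relu = nn.ReLU()
            self.dequant = DeQuantStub()

        def forward(self, x):
            return self.dequant(self.relu(self.fc(self.quant(x))))

    float_model = _ToyModel().eval()
    float_model.qconfig = get_default_qconfig("fbgemm")
    prepared = prepare(copy.deepcopy(float_model), inplace=True)
    prepared(torch.randn(4, 8))  # calibration pass
    q_model = convert(prepared, inplace=True)

    data = torch.randn(4, 8)
    # 1. weight comparison
    wt_cmp = compare_weights(float_model.state_dict(), q_model.state_dict())
    # 2. module-level comparison via Shadow modules
    stub_cmp = compare_model_stub(
        copy.deepcopy(float_model), copy.deepcopy(q_model), {nn.Linear}, data
    )
    # 3. activation comparison at matching locations
    act_cmp = compare_model_outputs(
        copy.deepcopy(float_model), copy.deepcopy(q_model), data
    )
    return wt_cmp, stub_cmp, act_cmp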

File diff suppressed because it is too large


@@ -0,0 +1,470 @@
# mypy: allow-untyped-defs
import collections
import enum
from typing import Any, Dict, List, Optional, Set, Tuple
import torch
from torch.ao.quantization import FakeQuantizeBase, ObserverBase
from torch.ao.quantization.utils import getattr_from_fqn
from torch.fx import GraphModule
from torch.fx.graph import Graph, Node
from .mappings import get_base_name_to_sets_of_related_ops, get_unmatchable_types_map
from .ns_types import NSNodeTargetType, NSSubgraph
from .pattern_utils import (
end_node_matches_reversed_fusion,
get_reversed_fusions,
get_type_a_related_to_b,
)
toq = torch.ops.quantized
def _get_output_nodes(g: Graph) -> List[Node]:
return [n for n in g.nodes if n.op == "output"]
class _NSGraphMatchableSubgraphsIterator:
"""
Iterates through the graph of gm, starting with the output nodes
and continuing backwards.
1. Returns matchable subgraphs, in order. A subgraph is defined by
(start_node, end_node).
2. Skips over non-matchable subgraphs
"""
def __init__(
self,
gm: GraphModule,
non_matchable_functions: Set[NSNodeTargetType],
non_matchable_modules: Set[NSNodeTargetType],
non_matchable_methods: Set[NSNodeTargetType],
):
self.gm: GraphModule = gm
self.non_matchable_functions: Set[NSNodeTargetType] = non_matchable_functions
self.non_matchable_modules: Set[NSNodeTargetType] = non_matchable_modules
self.non_matchable_methods: Set[NSNodeTargetType] = non_matchable_methods
self.seen_nodes: Set[Node] = set()
self.stack: List[Node] = []
for start_node in _get_output_nodes(self.gm.graph):
self.stack.append(start_node)
def __iter__(self):
return self
def __next__(self) -> NSSubgraph:
"""
Returns the next matchable subgraph.
"""
while len(self.stack) > 0:
cur_end_node = self.stack.pop()
if cur_end_node in self.seen_nodes:
continue
# for subgraphs which are single nodes, start_node == end_node
# for subgraphs with more than one node, start node != end_node
cur_start_node = cur_end_node
# Subgraphs like linear-relu have the base node as the start node.
# Subgraphs like dequantize-linear-relu-to(torch.float16) have the
# base node as the second node.
# The cur_base_op_node var will move to the actual node during
# the fusion matching later in this code block.
cur_base_op_node = cur_end_node
# Check for potential fusions. For now, we are greedy
# and always skip all non-base nodes of a fusion. For example,
# if we match linear-relu backwards, we will always skip the
# relu node and attempt to match the linear node. This can
# be made configurable later if needed.
for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
is_match = end_node_matches_reversed_fusion(
cur_end_node, _reverse_fusion_ops, self.gm, self.seen_nodes
)
if is_match:
# navigate to the base node
for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
self.seen_nodes.add(cur_start_node)
# for now, assume that there are no other nodes
# which need to be added to the stack
cur_start_node = cur_start_node.args[0] # type: ignore[assignment]
# if the base op index matches the current node, set it
rev_base_op_idx = len(_reverse_fusion_ops) - 2 - base_op_idx
if rev_fusion_idx == rev_base_op_idx:
cur_base_op_node = cur_start_node
break
self.seen_nodes.add(cur_start_node)
# add args of previous nodes to stack
for arg in cur_start_node.all_input_nodes:
self._recursively_add_node_arg_to_stack(arg)
# skip unmatchable nodes
# note: this check is done on the start_node, i.e.
# if we are matching linear-relu in reverse, this would do the matchable
# check on the linear
if not self._is_matchable(cur_base_op_node):
continue
# If an observer or a fake_quant was not matched as a part of
# a pattern of multiple nodes, ignore it. One case where this is
# relevant is an observer on a graph input, which was added because
# it is necessary for the next node.
if cur_end_node.op == "call_module" and cur_start_node is cur_end_node:
maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target) # type: ignore[arg-type]
if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)):
continue
return NSSubgraph(
start_node=cur_start_node,
end_node=cur_end_node,
base_op_node=cur_base_op_node,
)
raise StopIteration
def _recursively_add_node_arg_to_stack(self, arg: Any) -> None:
"""
Adds all of the nodes in this arg to the stack, properly navigating
through list, dicts and tuples.
"""
if isinstance(arg, Node):
self.stack.append(arg)
elif (
isinstance(arg, torch.fx.immutable_collections.immutable_list)
or type(arg) is tuple
):
for inner_arg in arg:
self._recursively_add_node_arg_to_stack(inner_arg)
elif isinstance(arg, torch.fx.immutable_collections.immutable_dict):
for value in arg.values():
self._recursively_add_node_arg_to_stack(value)
def _is_matchable(self, node: Node) -> bool:
if node.op == "call_function":
return node.target not in self.non_matchable_functions
elif node.op == "call_module":
assert isinstance(node.target, str)
target_mod = getattr_from_fqn(self.gm, node.target)
return not any(
isinstance(target_mod, t) # type: ignore[arg-type]
for t in self.non_matchable_modules
)
elif node.op == "call_method":
return node.target not in self.non_matchable_methods
else:
return False
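# A small sketch of driving the iterator above directly on a toy traced model,
# using the default unmatchable-type sets (the model and names are illustrative):
def _example_iterate_subgraphs():
    import torch.nn as nn

    gm = torch.fx.symbolic_trace(nn.Sequential(nn.Conv2d(1, 1, 1), nn.ReLU()).eval())
    unmatchable = get_unmatchable_types_map()
    it = _NSGraphMatchableSubgraphsIterator(
        gm,
        unmatchable["funs_unmatchable"],
        unmatchable["mods_unmatchable"],
        unmatchable["meths_unmatchable"],
    )
    # conv + relu is expected to come back as a single subgraph, since the pair is
    # a known fusion pattern
    return [(sg.start_node.name, sg.base_op_node.name, sg.end_node.name) for sg in it]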
class GraphMatchingException(Exception):
"""
Exception raised when two graphs cannot be matched.
"""
class SubgraphTypeRelationship(enum.Enum):
# same type, known
# example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d
EQUAL = enum.auto()
# same type, but the type is not known to Numerical Suite
# (user defined type, etc).
EQUAL_BUT_UKNOWN = enum.auto()
# known, same subgraph_relationship set, but not the same type
# example: F.linear and toq.linear
RELATED_BUT_NOT_EQUAL = enum.auto()
# not related
NOT_RELATED = enum.auto()
def _get_subgraph_relationship_type(
subgraph_a: NSSubgraph,
subgraph_b: NSSubgraph,
gm_a: GraphModule,
gm_b: GraphModule,
type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]],
) -> SubgraphTypeRelationship:
node_a = subgraph_a.base_op_node
node_b = subgraph_b.base_op_node
# TODO(next): make this code handle matching by what is before the base op
if node_a.op != node_b.op:
if not (
node_a.op in ("call_function", "call_method")
and node_b.op in ("call_function", "call_method")
):
return SubgraphTypeRelationship.NOT_RELATED
if node_a.op in ("call_function", "call_method"):
key = (node_a.target, node_b.target)
if key not in type_a_related_to_b:
if node_a.target == node_b.target:
return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
else:
return SubgraphTypeRelationship.NOT_RELATED
# after this point, we are dealing with known types
if node_a.target == node_b.target:
node_a_has_prev = subgraph_a.base_op_node == subgraph_a.start_node
node_b_has_prev = subgraph_b.base_op_node == subgraph_b.start_node
if node_a_has_prev and (not node_b_has_prev):
return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
elif (not node_a_has_prev) and node_b_has_prev:
return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
elif (not node_a_has_prev) and (not node_b_has_prev):
return SubgraphTypeRelationship.EQUAL
else:
# TODO(future PR): check for matches start_op_node and base_op_node
return SubgraphTypeRelationship.EQUAL
if key in type_a_related_to_b:
return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
else:
return SubgraphTypeRelationship.NOT_RELATED
elif node_a.op == "call_module":
assert (
subgraph_a.base_op_node == subgraph_a.start_node
and subgraph_b.base_op_node == subgraph_b.start_node
), "Matching call_module patterns where base_op_node != start_node is not supported yet"
# for call_module, we need to look up the modules to do the type check
assert isinstance(node_a.target, str)
mod_a = getattr_from_fqn(gm_a, node_a.target)
assert isinstance(node_b.target, str)
mod_b = getattr_from_fqn(gm_b, node_b.target)
key = (type(mod_a), type(mod_b))
if key not in type_a_related_to_b:
if type(mod_a) == type(mod_b):
return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
else:
return SubgraphTypeRelationship.NOT_RELATED
elif type(mod_a) == type(mod_b):
return SubgraphTypeRelationship.EQUAL
else:
return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
return SubgraphTypeRelationship.NOT_RELATED
def _get_name_for_subgraph(
subgraph_a: NSSubgraph,
gm_a: GraphModule,
base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
existing_names: Set[str],
) -> str:
"""
Returns a unique name for a subgraph. This name is based on two things:
1. the name of the set containing the underlying type of the base op in the
subgraph (i.e. 'torch.nn.functional.linear' if this is related to a linear op)
2. the number of previous subgraphs with related underlying type of the base op
For example, in the graph
linear0 -> relu0 -> linear1 -> relu1
The subgraphs are (linear0, relu0) and (linear1, relu1). If we iterate
from the output node backwards, the name given to (linear1, relu1) will be
`base_op_torch.nn.functional.linear_0`, and the name given to (linear0, relu0)
will be `base_op_torch.nn.functional.linear_1`.
Why are we not just using the node name? Answer: because of two requirements:
A. fusions must be supported
B. some Numeric Suite APIs can be called without having all of the models in memory
For example, let's say we need to match nodes of
(1) ... -> linear0 -> relu0 -> ...
And
(2) ... -> linear_relu0 -> ...
Without being able to inspect them together. With the current naming scheme, if
we iterate through both of these graphs in the same order, and assuming the rest
of the graphs match, both of these subgraphs will get the same name without
(1) and (2) knowing anything about each other.
"""
target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a)
target_base_type = None
for base_name, sets_of_related_ops in base_name_to_sets_of_related_ops.items():
if target_type in sets_of_related_ops:
target_base_type = base_name
target_base_name = "base_op_" + str(target_base_type)
counter = 0
proposed_name = target_base_name + "_" + str(counter)
while proposed_name in existing_names:
counter += 1
proposed_name = target_base_name + "_" + str(counter)
existing_names.add(proposed_name)
return proposed_name
def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetType]:
if node.op in ("call_function", "call_method"):
return node.target
elif node.op == "call_module":
assert isinstance(node.target, str)
mod = getattr_from_fqn(gm, node.target)
return type(mod)
return None
def get_matching_subgraph_pairs(
gm_a: GraphModule,
gm_b: GraphModule,
base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Dict[str, Tuple[NSSubgraph, NSSubgraph]]:
"""
Matches matchable subgraphs of graph_a to graph_b.
    A node is "matchable" if it is not an observer, fake_quant, quant, or
    dequant node.
A subgraph can contain one or more nodes. A subgraph is matchable if
at least one node inside of it is matchable. Currently, all nodes in
a subgraph must be matchable (because we assume no observers will be
inserted in the middle of a fusion).
A subgraph is defined by (start_node, end_node). We assume that only
start_node and end_node are linked with the surrounding graph, all other
nodes in a subgraph are self-contained.
A pair of nodes is "related" if both nodes represent the same mathematical
operation across different quantization flavors. For example,
`F.linear` and `torch.ops.quantized.linear` are related, and
`F.linear` and `torch.nn.Conv` are not related.
For each matchable pair of nodes node_a and node_b, they will match
if node_a and node_b are related.
For graphs A and B, they will match iff:
1. the number of matchable subgraphs in A and B is equivalent
2. when iterating through the matchable subgraphs of A and B in the same order, each
corresponding pair of base nodes is related.
This enables us to find the corresponding subgraphs between
graphs of related models. For example, if we had two graphs such as:
graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1
w -/
b -/
graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1
packed_params_0 -/
This function will return the following result:
{
'conv_0': ( # the name of the node in graph_b
(conv_0, conv_0), # (start_node_a, end_node_a)
(qconv_0, qconv_0), # (start_node_b, end_node_b)
),
}
Or, if we have a fusion pattern,
graph_a: x0 -> linear_0 -> relu_0 -> obs_0 -> x1
w -/
b -/
graph_b: x0 -> quant_0 -> linear_relu_0 -> dequant_0 -> x1
packed_params_0 -/
This function will return the following result:
{
'linear_relu_0': ( # the name of the node in graph_b
(linear_0, relu_0), # (start_node_a, end_node_a)
(linear_relu_0, linear_relu_0), # (start_node_b, end_node_b)
),
}
"""
if unmatchable_types_map is None:
unmatchable_types_map = get_unmatchable_types_map()
non_matchable_functions = unmatchable_types_map["funs_unmatchable"]
non_matchable_modules = unmatchable_types_map["mods_unmatchable"]
non_matchable_methods = unmatchable_types_map["meths_unmatchable"]
graph_a_iterator = _NSGraphMatchableSubgraphsIterator(
gm_a, non_matchable_functions, non_matchable_modules, non_matchable_methods
)
graph_b_iterator = _NSGraphMatchableSubgraphsIterator(
gm_b, non_matchable_functions, non_matchable_modules, non_matchable_methods
)
results = collections.OrderedDict()
if base_name_to_sets_of_related_ops is None:
base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
type_a_related_to_b = get_type_a_related_to_b(base_name_to_sets_of_related_ops)
existing_names_a: Set[str] = set()
existing_names_b: Set[str] = set()
while True:
# fetch the next subgraphs from a and b
cur_subgraph_a, cur_subgraph_b = None, None
try:
cur_subgraph_a = next(graph_a_iterator)
except StopIteration:
pass
try:
cur_subgraph_b = next(graph_b_iterator)
except StopIteration:
pass
# look up types of a and b for useful error messages
type_start_a, type_start_b = None, None
if cur_subgraph_a is not None:
type_start_a = _get_node_target_type(cur_subgraph_a.start_node, gm_a)
if cur_subgraph_b is not None:
type_start_b = _get_node_target_type(cur_subgraph_b.start_node, gm_b)
# check for results and determine what to do next
if cur_subgraph_a is not None and cur_subgraph_b is not None:
# both nodes were fetched, check for subgraph_relationship
# note: subgraph_relationship is checked on the start node, i.e.
# if a linear-relu pattern is checked, we would check for subgraph_relationship
# of the linear
subgraph_relationship = _get_subgraph_relationship_type(
cur_subgraph_a, cur_subgraph_b, gm_a, gm_b, type_a_related_to_b
)
if subgraph_relationship == SubgraphTypeRelationship.NOT_RELATED:
msg = f"""
The subgraphs
({cur_subgraph_a}, {type_start_a}) and
({cur_subgraph_b}, {type_start_b})
are not related. Please ensure that the two models you pass in have the same number
of subgraphs, and each pair of subgraphs is related to each other."""
raise GraphMatchingException(msg)
elif subgraph_relationship == SubgraphTypeRelationship.EQUAL_BUT_UKNOWN:
# skip matching but unknown types
continue
key_name_a = _get_name_for_subgraph(
cur_subgraph_a, gm_a, base_name_to_sets_of_related_ops, existing_names_a
)
key_name_b = _get_name_for_subgraph(
cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops, existing_names_b
)
assert (
key_name_a == key_name_b
), f"Subgraph names {key_name_a} and {key_name_b} do not match"
results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)
continue
elif cur_subgraph_a is None and cur_subgraph_b is None:
# we reached the end of both graphs
break
else:
# only one node was fetched, no match possible, throw error
msg = f"""
Attempting to match
({cur_subgraph_a}, {type_start_a}) and
({cur_subgraph_b}, {type_start_b}),
one of which is empty. Please ensure that the two models you pass in have the same number
of subgraphs."""
raise GraphMatchingException(msg)
# The subgraph pairs are originally created by traversing the two graphs
# from the outputs to the inputs. Reverse the results to return the
# subgraphs in their order of execution.
results = collections.OrderedDict(reversed(list(results.items())))
return results
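# A minimal sketch of calling the matcher: tracing the same toy float model twice
# stands in for a (model_a, model_b) pair whose subgraphs are all related, so the
# call below returns one matched pair per subgraph (the model and names are
# illustrative):
def _example_match_toy_models():
    import torch.nn as nn

    gm_a = torch.fx.symbolic_trace(nn.Sequential(nn.Conv2d(1, 1, 1), nn.ReLU()).eval())
    gm_b = torch.fx.symbolic_trace(nn.Sequential(nn.Conv2d(1, 1, 1), nn.ReLU()).eval())
    matches = get_matching_subgraph_pairs(gm_a, gm_b)
    # keys are the names produced by _get_name_for_subgraph, values are
    # (NSSubgraph_a, NSSubgraph_b) pairs in execution order
    return matches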

File diff suppressed because it is too large


@@ -0,0 +1,759 @@
import operator
from typing import Callable, Dict, List, Optional, Set, Tuple
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
import torch.ao.nn.qat as nnqat
import torch.ao.nn.qat.dynamic as nnqatd
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.quantization.fx._lower_to_native_backend as _lower_to_native_backend
import torch.ao.quantization.quantization_mappings as quantization_mappings
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.backend_config import get_native_backend_config
from .ns_types import NSNodeTargetType
toq = torch.ops.quantized
def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
# note: this set is modified below by items from backend_config
sets_of_related_ops: List[Set[NSNodeTargetType]] = [
# conv modules
{
nn.Conv1d,
},
{
nn.Conv2d,
},
{
nn.Conv3d,
},
# conv functionals
{
F.conv1d,
},
{
F.conv2d,
},
{
F.conv3d,
},
# linear modules
{
nn.Linear,
},
# linear functionals
{
F.linear,
},
# average pool
{
nn.AvgPool1d,
torch.avg_pool1d,
},
{
nn.AvgPool2d,
torch._C._nn.avg_pool2d,
},
{
nn.AvgPool3d,
torch._C._nn.avg_pool3d,
},
# adaptive average pool
{
nn.AdaptiveAvgPool1d,
F.adaptive_avg_pool1d,
},
{
nn.AdaptiveAvgPool2d,
F.adaptive_avg_pool2d,
},
{
nn.AdaptiveAvgPool3d,
F.adaptive_avg_pool3d,
},
# LSTM
{
nn.LSTM,
},
# add
{
torch.add,
operator.add, # x + y
},
# cat
{
torch.cat,
},
# mul
{
torch.mul,
operator.mul,
},
# relu
{
F.relu,
nn.ReLU,
"relu",
"relu_",
torch.relu,
},
# maxpool
{
nn.MaxPool1d,
F.max_pool1d,
},
{
nn.MaxPool2d,
F.max_pool2d,
},
{
nn.MaxPool3d,
F.max_pool3d,
},
# sigmoid
{
torch.sigmoid,
"sigmoid",
"sigmoid_",
nn.Sigmoid,
F.sigmoid,
},
# BatchNorm
{
nn.BatchNorm2d,
},
{
nn.BatchNorm3d,
},
# ConvTranspose
{
nn.ConvTranspose1d,
},
{
nn.ConvTranspose2d,
},
{
nn.ConvTranspose3d,
},
# functional transposed conv
{
F.conv_transpose1d,
},
{
F.conv_transpose2d,
},
{
F.conv_transpose3d,
},
# ELU
{
nn.ELU,
},
# Embedding
{
nn.Embedding,
},
# EmbeddingBag
{
nn.EmbeddingBag,
},
# GroupNorm
{
nn.GroupNorm,
},
# Hardswish
{
nn.Hardswish,
},
# InstanceNorm
{
nn.InstanceNorm1d,
},
{
nn.InstanceNorm2d,
},
{
nn.InstanceNorm3d,
},
# LayerNorm
{
nn.LayerNorm,
},
# LeakyReLU
{
nn.LeakyReLU,
},
# ReLU6
{
nn.ReLU6,
F.relu6,
},
# F.elu
{
F.elu,
},
# F.hardswish
{
F.hardswish,
},
# F.group_norm
{
F.group_norm,
},
# F.instance_norm
{
F.instance_norm,
},
# F.layer_norm
{
F.layer_norm,
},
# F.leaky_relu
{
F.leaky_relu,
},
# F.silu
{
nn.SiLU,
F.silu,
},
# F.mish
{
nn.Mish,
F.mish,
},
# F.tanh
{
nn.Tanh,
F.tanh,
torch.tanh,
"tanh_",
"tanh",
},
# F.hardsigmoid
{
"hardsigmoid_",
"hardsigmoid",
F.hardsigmoid,
nn.Hardsigmoid,
},
# F.hardtanh
{
nn.Hardtanh,
F.hardtanh,
F.hardtanh_,
},
# floordiv
{
operator.floordiv,
},
# unsqueeze
{
torch.unsqueeze,
},
# stack
{
torch.stack,
},
# squeeze
{
torch.squeeze,
},
# sort
{
torch.sort,
},
# repeat_interleave
{
torch.repeat_interleave,
},
# min
{
torch.min,
},
# mean
{
torch.mean,
},
# max
{
torch.max,
},
# transpose
{
torch.transpose,
},
# flatten
{
torch.flatten,
},
# clamp
{
torch.clamp,
},
# chunk
{
torch.chunk,
},
# interpolate
{
torch.nn.functional.interpolate,
},
# dropout
{
nn.Dropout,
},
# F.dropout
{
F.dropout,
},
# matmul
{
torch.matmul,
},
# Softmax
{
nn.Softmax,
},
# PReLU
{
nn.PReLU,
nnq.PReLU,
},
# F.prelu
{
F.prelu,
toq.prelu,
},
# pixel shuffle
{
nn.PixelShuffle,
},
{
F.pixel_shuffle,
},
# pixel unshuffle
{
nn.PixelUnshuffle,
},
{
F.pixel_unshuffle,
},
# narrow
{
torch.narrow,
},
]
# for each floating point op, add versions of the op added by
# backend_config
backend_config = get_native_backend_config()
new_connections: List[Tuple[Callable, Callable]] = [
# technical debt edge case
(nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
]
for pattern, config in backend_config._pattern_complex_format_to_config.items():
# pattern format: (c, (b, a))
first_element = pattern
# look from the end, because pattern is in reverse order
while isinstance(first_element, (list, tuple)):
first_element = first_element[-1]
if config.fused_module is not None:
# case 1: pattern fuses a pattern of ops into an op
# example: nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d
new_connections.append((first_element, config.fused_module))
if config.qat_module is not None:
# case 2: pattern swaps a module into a QAT module
# example: nni.ConvReLU1d swapped into nniqat.ConvReLU1d
new_connections.append((first_element, config.qat_module))
if config.reference_quantized_module is not None:
# case 3: reference version of floating point module, such as
# nn.Conv2d and nnqr.Conv2d
new_connections.append((first_element, config.reference_quantized_module))
#
# Add reference module swaps from default lowering path
#
for source_to_target in (
_lower_to_native_backend.STATIC_LOWER_MODULE_MAP,
_lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP,
_lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP,
_lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP,
):
for source, target in source_to_target.items(): # type: ignore[attr-defined]
new_connections.append((source, target))
for source_to_double_target in (
_lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP,
_lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP,
_lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP,
):
for source, (target1, target2) in source_to_double_target.items(): # type: ignore[attr-defined]
new_connections.append((source, target1))
new_connections.append((source, target2))
#
# Add function swaps from default lowering path
#
for source, (
target1,
target2,
) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
new_connections.append((source, target1))
new_connections.append((source, target2))
for source_to_target in (
_lower_to_native_backend.QBIN_OP_MAPPING,
_lower_to_native_backend.QBIN_RELU_OP_MAPPING,
quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
):
for source, target in source_to_target.items():
new_connections.append((source, target))
#
# Add other swaps, ideally in the future this could be removed
# after the lowering code stops using these.
#
for source_to_target in (
quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
):
for source, target in source_to_target.items():
new_connections.append((source, target))
# add the new connections from backend_config
for item1, item2 in new_connections:
for set_of_related_ops in sets_of_related_ops:
if item1 in set_of_related_ops or item2 in set_of_related_ops:
set_of_related_ops.add(item1)
set_of_related_ops.add(item2)
break
base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]] = {}
counter = 0
for set_of_related_ops in sets_of_related_ops:
base_name = str(counter)
counter += 1
base_name_to_sets_of_related_ops[base_name] = set_of_related_ops
return base_name_to_sets_of_related_ops
def get_base_name_for_op(
base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
op: NSNodeTargetType,
) -> Optional[str]:
for base_name, set_of_related_ops in base_name_to_sets_of_related_ops.items():
if op in set_of_related_ops:
return base_name
return None
def add_op_to_sets_of_related_ops(
base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
op: NSNodeTargetType,
related_op: Optional[NSNodeTargetType],
) -> None:
if related_op is not None:
for set_of_related_ops in base_name_to_sets_of_related_ops.values():
if related_op in set_of_related_ops:
set_of_related_ops.add(op)
return
# if we got here, related_op was not found
raise AssertionError(f"{related_op} was not found")
else:
counter = 0
while str(counter) in base_name_to_sets_of_related_ops:
counter += 1
base_name_to_sets_of_related_ops[str(counter)] = {op}
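# A short sketch of extending the mapping with a user-defined op: here a
# hypothetical custom linear function is registered as related to F.linear, so the
# matcher treats the two as the same mathematical operation (the function name is
# illustrative):
def _example_register_custom_op():
    def my_custom_linear(x, w, b):  # hypothetical user-defined op
        return F.linear(x, w, b)

    base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    add_op_to_sets_of_related_ops(
        base_name_to_sets_of_related_ops, op=my_custom_linear, related_op=F.linear
    )
    # the custom op now shares a base name with F.linear
    assert get_base_name_for_op(
        base_name_to_sets_of_related_ops, my_custom_linear
    ) == get_base_name_for_op(base_name_to_sets_of_related_ops, F.linear)
    return base_name_to_sets_of_related_ops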
# TODO(future PR): clean this up
def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
F.linear,
F.conv1d,
F.conv2d,
F.conv3d,
torch.cat,
F.elu,
F.hardswish,
F.instance_norm,
F.layer_norm,
F.leaky_relu,
F.dropout,
F.silu,
F.mish,
operator.add,
torch.add,
operator.mul,
torch.mul,
torch.sum,
F.prelu,
}
FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set()
FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
toq.linear,
toq.linear_relu,
toq.conv1d,
toq.conv1d_relu,
toq.conv2d,
toq.conv2d_relu,
toq.conv3d,
toq.conv3d_relu,
toq.cat,
toq.elu,
toq.hardswish,
toq.instance_norm,
toq.layer_norm,
toq.leaky_relu,
toq.dropout,
toq.prelu,
# TODO(future PR): implement shadowing for binary ops and
# uncomment below
# toq.add,
# toq.mul,
}
FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
F.relu,
F.tanh,
torch.tanh,
F.sigmoid,
torch.sigmoid,
F.hardsigmoid,
operator.floordiv,
torch.adaptive_avg_pool1d,
F.adaptive_avg_pool2d,
F.adaptive_avg_pool3d,
F.dropout,
F.hardtanh,
F.hardtanh_,
F.interpolate,
F.max_pool1d,
F.max_pool2d,
F.max_pool3d,
F.relu6,
F.pixel_shuffle,
F.pixel_unshuffle,
torch.avg_pool1d,
torch._C._nn.avg_pool2d,
torch._C._nn.avg_pool3d,
torch.cat,
torch.chunk,
torch.clamp,
torch.flatten,
torch.transpose,
torch.max,
torch.mean,
torch.min,
torch.narrow,
torch.repeat_interleave,
torch.sort,
torch.squeeze,
torch.stack,
torch.unsqueeze,
operator.add,
}
MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
nn.Linear,
nnqat.Linear,
nnqatd.Linear,
nnqd.Linear,
torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nnqat.Conv1d,
nnqat.Conv2d,
nnqat.Conv3d,
nnqat.Embedding,
nnqat.EmbeddingBag,
nn.LSTM,
# note: nnqd.Linear is an instance of nnq.Linear, so this
# check has to happen before the int8 module check
nnqd.LSTM,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.Dropout,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
nn.ELU,
nn.GroupNorm,
nn.InstanceNorm1d,
nn.InstanceNorm2d,
nn.InstanceNorm3d,
nn.LayerNorm,
nn.Hardswish,
nn.LeakyReLU,
nn.ReLU6,
nn.SiLU,
nn.Mish,
nn.Softmax,
nn.PReLU,
nni.BNReLU2d,
nni.BNReLU3d,
nni.ConvReLU1d,
nni.ConvReLU2d,
nni.ConvReLU3d,
nni.LinearReLU,
nni.LinearBn1d,
nni.ConvBn1d,
nni.ConvBn2d,
nni.ConvBn3d,
nniqat.ConvBn1d,
nniqat.ConvBn2d,
nniqat.ConvBn3d,
nniqat.ConvBnReLU1d,
nniqat.ConvBnReLU2d,
nniqat.ConvBnReLU3d,
nniqat.ConvReLU1d,
nniqat.ConvReLU2d,
nniqat.ConvReLU3d,
nniqat.LinearReLU,
nniqat.LinearBn1d,
nniqd.LinearReLU,
nni.LinearLeakyReLU,
nni.LinearTanh,
nni.ConvAdd2d,
nni.ConvAddReLU2d,
}
MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
nnq.Linear,
nnq.Conv1d,
nnq.Conv2d,
nnq.Conv3d,
nnq.BatchNorm2d,
nnq.BatchNorm3d,
nnq.Dropout,
nnq.ConvTranspose1d,
nnq.ConvTranspose2d,
nnq.ELU,
nnq.InstanceNorm1d,
nnq.InstanceNorm2d,
nnq.InstanceNorm3d,
nnq.LayerNorm,
nnq.Hardswish,
nnq.LeakyReLU,
nnq.Embedding,
nnq.EmbeddingBag,
nnq.Dropout,
nnq.Softmax,
nnq.PReLU,
nniq.BNReLU2d,
nniq.BNReLU3d,
nniq.ConvReLU1d,
nniq.ConvReLU2d,
nniq.ConvReLU3d,
nniq.LinearReLU,
nniq.LinearLeakyReLU,
nniq.LinearTanh,
nniq.ConvAdd2d,
nniq.ConvAddReLU2d,
}
MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
nn.ReLU,
nn.Tanh,
nn.Sigmoid,
nn.Hardsigmoid,
nn.AdaptiveAvgPool1d,
nn.AdaptiveAvgPool2d,
nn.AdaptiveAvgPool3d,
nn.AvgPool1d,
nn.AvgPool2d,
nn.AvgPool3d,
nn.Dropout,
nn.Hardtanh,
nn.Identity,
nn.MaxPool1d,
nn.MaxPool2d,
nn.MaxPool3d,
nn.PixelShuffle,
nn.PixelUnshuffle,
nn.ReLU6,
}
METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
"sigmoid_",
"sigmoid",
"tanh_",
"tanh",
"hardsigmoid_",
"hardsigmoid",
"relu_",
"relu",
}
return {
"funs_io_type_fp32": FUNS_IO_TYPE_FP32,
"funs_io_type_fp16": FUNS_IO_TYPE_FP16,
"funs_io_type_int8": FUNS_IO_TYPE_INT8,
"funs_io_type_fp32_or_int8": FUNS_IO_TYPE_FP32_OR_INT8,
"mods_io_type_fp32": MODS_IO_TYPE_FP32,
"mods_io_type_int8": MODS_IO_TYPE_INT8,
"mods_io_type_fp32_or_int8": MODS_IO_TYPE_FP32_OR_INT8,
"meths_io_type_fp32_or_int8": METHS_IO_TYPE_FP32_OR_INT8,
}
def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]:
FUNS_UNMATCHABLE: Set[NSNodeTargetType] = {
torch.quantize_per_tensor,
operator.getitem,
}
MODS_UNMATCHABLE: Set[NSNodeTargetType] = {
nn.Identity,
}
METHS_UNMATCHABLE: Set[NSNodeTargetType] = {
"to",
"dequantize",
"reshape",
"view",
"unsqueeze_",
"unsqueeze",
"transpose",
"squeeze_",
"squeeze",
"size",
"shape",
"resize_",
"repeat_interleave",
"repeat",
"permute",
"numel",
"mean",
"detach_",
"detach",
"contiguous",
"clamp",
"chunk",
}
return {
"funs_unmatchable": FUNS_UNMATCHABLE,
"mods_unmatchable": MODS_UNMATCHABLE,
"meths_unmatchable": METHS_UNMATCHABLE,
}

File diff suppressed because it is too large


@@ -0,0 +1,65 @@
import enum
from typing import Any, Callable, Dict, List, NamedTuple, Union
from torch.fx.graph import Node
class NSSingleResultValuesType(str, enum.Enum):
WEIGHT = "weight"
NODE_OUTPUT = "node_output"
NODE_INPUT = "node_input"
class NSSubgraph(NamedTuple):
start_node: Node
end_node: Node
base_op_node: Node
# TODO(future PR): see if we can use typing_extensions's TypedDict instead
# to properly type the various keys
# {
# # one of NSSingleResultValuesType
# 'type': 'weight',
# # the values of type specified above
# 'values': [torch.tensor(...), ...],
# # name of the node directly before the logger
# 'prev_node_name': 'linear1',
# # type of the underlying function or module
# 'prev_node_target_type': torch.nn.functional.linear # or torch.nn.Linear, etc
# # name of the node responsible for adding this logger
# # Note: this may differ from prev_node_name if we are logging inputs
# 'ref_node_name': 'linear1',
# # index of this node within the arg of the input/output node
# # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
# 'index_within_arg': 0,
# # index of this node within the args of the input/output node
# # for example, in add(x1, x2), x2 would have index_of_arg == 1
# 'index_of_arg': 0,
# # precomputed comparisons of logger values to reference values
# 'comparisons': [torch.tensor(...), ...]
# # name of function used for precomputed comparisons
# 'comparison_fn_name': 'sqnr',
# # string representation of qconfig responsible for creating this logger
# 'qconfig_str': 'QConfig(...)',
# }
NSSingleResultType = Dict[str, Any]
# {
# 'layer_name_1': { # subgraph name
# 'node_output': { # results type (node_output, node_input, weight)
# 'model_name_a': # model name
# [NSSingleResultType, ...], # results, ordered by index_within_arg
# 'model_name_b':
# [NSSingleResultType, ...],
# },
# },
# }
#
NSResultsType = Dict[str, Dict[str, Dict[str, List[NSSingleResultType]]]]
# Defines the underlying target type of a node, for example:
# `F.conv1d` for a `call_function` conv node
# `nn.Conv1d` for a `call_module` node calling the forward of a `nn.Conv1d` module
# `'sigmoid'` for a `call_method` node calling `x.sigmoid()`
NSNodeTargetType = Union[Callable, str]
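# A tiny illustrative value conforming to NSResultsType, with one subgraph, one
# results type, and two models; the names and the subset of NSSingleResultType keys
# shown are made up for the example:
_EXAMPLE_NS_RESULTS: NSResultsType = {
    "layer_name_1": {
        "node_output": {
            "model_name_a": [
                {
                    "type": NSSingleResultValuesType.NODE_OUTPUT.value,
                    "values": [],
                    "prev_node_name": "linear1",
                    "ref_node_name": "linear1",
                    "index_within_arg": 0,
                    "index_of_arg": 0,
                }
            ],
            "model_name_b": [],
        },
    },
}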


@@ -0,0 +1,209 @@
from typing import Any, Callable, Dict, List, Set, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import FakeQuantizeBase, ObserverBase
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
from torch.ao.quantization.utils import getattr_from_fqn
from torch.fx import GraphModule
from torch.fx.graph import Node
from .ns_types import NSNodeTargetType
toq = torch.ops.quantized
def get_type_a_related_to_b(
base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
) -> Set[Tuple[NSNodeTargetType, NSNodeTargetType]]:
# TODO(future PR): allow customizations
# TODO(future PR): reuse existing quantization mappings
# TODO(future PR): add the rest of modules and ops here
type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]] = set()
for s in base_name_to_sets_of_related_ops.values():
s_list = list(s)
# add every bidirectional pair
for idx_0 in range(0, len(s_list)):
for idx_1 in range(idx_0, len(s_list)):
type_a_related_to_b.add((s_list[idx_0], s_list[idx_1]))
type_a_related_to_b.add((s_list[idx_1], s_list[idx_0]))
return type_a_related_to_b
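# A quick sketch of the bidirectional pairs produced above for a toy mapping with a
# single related set (the base name "linear" is illustrative):
def _example_type_a_related_to_b():
    pairs = get_type_a_related_to_b({"linear": {F.linear, toq.linear}})
    # every ordered pair is present, including the reflexive ones
    assert pairs == {
        (F.linear, F.linear),
        (F.linear, toq.linear),
        (toq.linear, F.linear),
        (toq.linear, toq.linear),
    }
    return pairs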
NSFusionElType = Union[
Callable, # call_function or call_module type, example: F.linear or nn.Conv2d
str, # call_method name, example: "dequantize"
Tuple[
str, Any
], # call_method name and first argument, example: ("to", torch.float16)
]
NSFusionType = Union[
Tuple[NSFusionElType, NSFusionElType],
Tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
]
def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
"""
Set of potential fusions, in reverse order. The order is reversed
to match how fusion patterns are defined in quantization code.
Fusion format:
((fusion_op_0, fusion_op_1), base_op_idx)
Where base_op_idx is the idx of the op we should use to match other related
ops. Note: base_op_idx is specified in non-reverse order, i.e. a base_op_idx
of 0 represents the first op in regular (non-reverse) order, 1 represents the
second op, etc.
"""
results: List[Tuple[NSFusionType, int]] = []
# Possible syntaxes:
# * single op: torch.nn.Conv2d
# * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
# For fusions, we only care about patterns composed of multiple ops.
# TODO(future PR): allow customizations from default patterns.
all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config())
default_base_op_idx = 0
for quant_pattern in all_quant_patterns.keys():
# TODO: this is a temporary hack to flatten the patterns from quantization so
# that it works with the ns matcher function, maybe we should use `_is_match`
# in torch.ao.quantization.fx.match_utils to match the patterns
if (
isinstance(quant_pattern, tuple)
and len(quant_pattern) == 2
and isinstance(quant_pattern[1], tuple)
and len(quant_pattern[1]) == 2
):
# flatten the pattern with form (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
quant_pattern = (quant_pattern[0], quant_pattern[1][0], quant_pattern[1][1])
# Only patterns of multiple ops are fusions, ignore
# patterns which contain a single ops (they get matched
# without caring about fusions).
if isinstance(quant_pattern, tuple):
results.append((quant_pattern, default_base_op_idx)) # type: ignore[arg-type]
# For each pattern, add additional patterns with observers and
# fake quants at the end.
# TODO(future PR): if needed, implement matching for a node
# having multiple output observers.
for cls in (ObserverBase, FakeQuantizeBase):
if isinstance(quant_pattern, tuple):
new_pattern = (cls, *quant_pattern)
else:
new_pattern = (cls, quant_pattern)
results.append((new_pattern, default_base_op_idx)) # type: ignore[arg-type]
# After this point, results contains values such as
    # [..., ((torch.nn.ReLU, torch.nn.Conv2d), 0), ...]
# Patterns for matching fp16 emulation are not specified in the quantization
# fusion mappings. For now, define them here.
fp16_em_base_op_idx = 1
patterns_to_add = [
# linear-relu fp16 emulation:
# fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
(
(("to", torch.float16), F.relu, F.linear, "dequantize"),
fp16_em_base_op_idx,
),
# Conv-BN fusion (this happens outside of quantization patterns,
# which is why it is defined separately here).
((nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
((nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
((nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
]
for p in patterns_to_add:
results.append(p) # type: ignore[arg-type]
results.append(((ObserverBase, *p[0]), p[1])) # type: ignore[arg-type]
results.append(((FakeQuantizeBase, *p[0]), p[1])) # type: ignore[arg-type]
return results
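# A small sketch of inspecting the reversed fusion list; the explicit Conv-BN
# patterns added above should be present alongside the patterns derived from the
# native backend config (the slice size below is arbitrary):
def _example_inspect_reversed_fusions():
    fusions = get_reversed_fusions()
    # each entry pairs a reversed pattern with the index of its base op in
    # non-reversed order
    assert ((nn.BatchNorm2d, nn.Conv2d), 0) in fusions
    return fusions[:5]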
def end_node_matches_reversed_fusion(
end_node: Node,
reversed_fusion: NSFusionType,
gm: GraphModule,
seen_nodes: Set[Node],
) -> bool:
"""
Returns true if a pattern ending with `end_node` matches
the fusion pattern.
"""
cur_node = end_node
for fusion_idx in range(len(reversed_fusion)):
# each node can only belong to one matched pattern
if cur_node in seen_nodes:
return False
cur_fusion_el = reversed_fusion[fusion_idx]
if cur_node.op == "call_function":
fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and (
not isinstance(cur_fusion_el, type)
)
if fusion_el_is_fun:
if cur_node.target != cur_fusion_el:
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
else:
return False
else:
return False
elif cur_node.op == "call_module":
fusion_el_is_mod = isinstance(cur_fusion_el, type)
if fusion_el_is_mod:
assert isinstance(cur_node.target, str)
target_mod = getattr_from_fqn(gm, cur_node.target)
if not isinstance(cur_fusion_el, type):
return False
if not isinstance(target_mod, cur_fusion_el):
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
else:
return False
else:
return False
elif cur_node.op == "call_method":
fusion_el_is_meth_with_second_arg = (
isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2
)
fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str)
if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg:
if fusion_el_is_meth_without_args:
if cur_node.target != cur_fusion_el:
return False
else:
assert isinstance(cur_fusion_el, tuple)
if cur_node.target != cur_fusion_el[0]:
return False
elif len(cur_node.args) < 2:
return False
elif cur_node.args[1] != cur_fusion_el[1]:
return False
if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
cur_node = cur_node.args[0]
else:
return False
else:
return False
else:
return False
return True


@@ -0,0 +1,249 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import copy
from typing import Any, Callable, Dict, List, TYPE_CHECKING, Union
import torch
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER
if TYPE_CHECKING:
from torch.ao.quantization.qconfig import QConfigAny
__all__ = ["QConfigMultiMapping"]
_QCONFIG_STYLE_TO_METHOD: Dict[str, str] = {
"global_qconfig": "set_global",
"object_type_qconfigs": "set_object_type",
"module_name_regex_qconfigs": "set_module_name_regex",
"module_name_qconfigs": "set_module_name",
"module_name_object_type_order_qconfigs": "set_module_name_object_type_order",
}
def _remove_duplicates_and_none(qconfig_list: List[QConfigAny]) -> None:
to_remove = []
for index, cur_qconfig in enumerate(qconfig_list):
if cur_qconfig is None:
to_remove.append(index)
break
for checked_qconfig in qconfig_list[:index]:
if torch.ao.quantization.qconfig_equals(cur_qconfig, checked_qconfig):
to_remove.append(index)
break
for index in to_remove[::-1]:
qconfig_list.pop(index)
class QConfigMultiMapping:
"""
This class, used with the prepare_n_shadows_model API, stores a list of :class:`torch.ao.quantization.QConfigMapping`s
so that multiple QConfigs can be specified for each QConfig matching style.
The user can specify QConfigs using the following methods (in increasing match priority):
``set_global`` : sets the global (default) QConfigs
``set_object_type`` : sets the QConfigs for a given module type, function, or method name
``set_module_name_regex`` : sets the QConfigs for modules matching the given regex string
``set_module_name`` : sets the QConfigs for modules matching the given module name
``set_module_name_object_type_order`` : sets the QConfigs for modules matching a combination
of the given module name, object type, and the index at which the module appears
Note: Usage of set methods is the same as in QConfigMapping except with a passed in list of QConfigs rather than a
single QConfig.
Example usage::
qconfig_mapping = QConfigMultiMapping()
.set_global([qconfig1, qconfig2])
.set_object_type(torch.nn.Linear, [qconfig2, qconfig3])
.set_object_type(torch.nn.ReLU, [qconfig1])
.set_module_name_regex("foo.*bar.*conv[0-9]+", [qconfig2])
.set_module_name_regex("foo.*", [qconfig1, qconfig2, qconfig3])
.set_module_name("module1", [None])
.set_module_name("module2", [qconfig2])
.set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, [qconfig3])
"""
def __init__(self) -> None:
# initialize this with 1 QConfigMapping to avoid corner cases
self.qconfig_mappings_list: List[QConfigMapping] = [QConfigMapping()]
def _handle_list_size_mismatch(
self, qconfig_list: List[QConfigAny], style: str
) -> None:
# this method handles cases where the size of qconfig_list does not match
# the size of qconfig_mappings_list.
# Issue: Consider a user inserting global_qconfig A and B first, then inserting
# qconfig C as an object_type_qconfig for conv ops. If we internally store
# 1 QConfigMapping with A and C and another with just B, then the
# second QConfigMapping will match B to conv ops (which is not wanted), since B is global.
# we avoid this by maintaining the invariant that if any QConfigMapping
# has a qconfig style+key with a qconfig in it, all QConfigMappings must
# have either a qconfig or None for that same style+key. In the above
# example, a None qconfig would prevent the unwanted match in the
# second QConfigMapping
if len(qconfig_list) > len(self.qconfig_mappings_list):
# Case: we have more qconfigs (in qconfig_list) than QConfigMappings
# Add new QConfigMappings (initialized so we maintain the `invariant`)
new_qconfig_mapping = QConfigMapping()
# searches other QConfigMappings for qconfig style+keys
# that need to be inserted as `None` into the new QConfigMapping
for qconfig_mapping in self.qconfig_mappings_list:
# global_qconfig has None by default
for check_style in _QCONFIG_STYLE_ORDER[1:]:
qconfigs_dict = getattr(qconfig_mapping, check_style)
target_qconfigs_dict = getattr(new_qconfig_mapping, check_style)
for key in qconfigs_dict:
target_qconfigs_dict[key] = None
break
# insert copies of this new QConfigMapping until all entries
# in qconfig_list can fit among the QConfigMappings
while len(qconfig_list) > len(self.qconfig_mappings_list):
self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping))
else:
# Case: we have fewer qconfigs in qconfig_list than QConfigMappings
# pad qconfig_list with `None` until length is same
while len(qconfig_list) < len(self.qconfig_mappings_list):
qconfig_list.append(None)
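# A minimal sketch of the invariant described in the comments above, using
# hypothetical qconfigs A, B, C (names are illustrative only):
#
#     multi_mapping = QConfigMultiMapping()
#     multi_mapping.set_global([A, B])                # mapping 0: global=A, mapping 1: global=B
#     multi_mapping.set_object_type(torch.nn.Conv2d, [C])
#     # qconfig_list [C] is padded with None, so mapping 0 gets {Conv2d: C}
#     # and mapping 1 gets {Conv2d: None}; B can no longer silently match
#     # conv ops through its global entry.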
# this function applies the insertion method across each QConfigMapping
def _insert_qconfig_list(
self,
style: str,
args: List[Union[str, int, Callable]],
qconfig_list: List[QConfigAny],
) -> None:
# we remove duplicates and None to make the ordering of qconfigs
# deterministic upon insertion.
_remove_duplicates_and_none(qconfig_list)
self._handle_list_size_mismatch(qconfig_list, style)
method_name = _QCONFIG_STYLE_TO_METHOD[style]
for qconfig_mapping, qconfig in zip(self.qconfig_mappings_list, qconfig_list):
# uses QConfigMapping set method to insert qconfig
set_method = getattr(qconfig_mapping, method_name)
set_method(*args, qconfig)
def set_global(self, global_qconfig_list: List[QConfigAny]) -> QConfigMultiMapping:
"""
Set global QConfigs
see :func:`~torch.ao.quantization.QConfigMapping.set_global()` for more info
"""
self._insert_qconfig_list("global_qconfig", [], global_qconfig_list)
return self
def set_object_type(
self, object_type: Union[Callable, str], qconfig_list: List[QConfigAny]
) -> QConfigMultiMapping:
"""
Set object type QConfigs
see :func:`~torch.ao.quantization.QConfigMapping.set_object_type()` for more info
"""
self._insert_qconfig_list("object_type_qconfigs", [object_type], qconfig_list)
return self
def set_module_name_regex(
self, module_name_regex: str, qconfig_list: List[QConfigAny]
) -> QConfigMultiMapping:
"""
Set module_name_regex QConfigs
see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_regex()` for more info
"""
self._insert_qconfig_list(
"module_name_regex_qconfigs", [module_name_regex], qconfig_list
)
return self
def set_module_name(
self, module_name: str, qconfig_list: List[QConfigAny]
) -> QConfigMultiMapping:
"""
Set module_name QConfigs
see :func:`~torch.ao.quantization.QConfigMapping.set_module_name()` for more info
"""
self._insert_qconfig_list("module_name_qconfigs", [module_name], qconfig_list)
return self
def set_module_name_object_type_order(
self,
module_name: str,
object_type: Callable,
index: int,
qconfig_list: List[QConfigAny],
) -> QConfigMultiMapping:
"""
Set module_name_object_type_order QConfigs
see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_object_type_order()` for more info
"""
self._insert_qconfig_list(
"module_name_object_type_order_qconfigs",
[module_name, object_type, index],
qconfig_list,
)
return self
def __repr__(self):
return (
self.__class__.__name__
+ " ["
+ "".join(
f"\n{qconfig_mapping.__repr__()},"
for qconfig_mapping in self.qconfig_mappings_list
)
+ "\n]"
)
@classmethod
def from_list_qconfig_mapping(
cls, qconfig_mapping_list: List[QConfigMapping]
) -> QConfigMultiMapping:
"""
Creates a QConfigMultiMapping from a list of QConfigMappings
"""
new_qconfig_multi_mapping = cls()
new_qconfig_multi_mapping.qconfig_mappings_list = copy.deepcopy(
qconfig_mapping_list
)
# we need to avoid the issue described in _handle_list_size_mismatch,
# so we reinsert all the qconfigs using the QConfigMultiMapping
# set methods
# go through all qconfig styles
# note: global can be ignored since it is None by default
for style in _QCONFIG_STYLE_ORDER[1:]:
# gather all key+qconfigs for current style
# into qconfig_dict_list
qconfig_dict_list: Dict[Any, List[QConfigAny]] = {}
for qconfig_mapping in qconfig_mapping_list:
qconfig_dict = getattr(qconfig_mapping, style)
for key, qconfig in qconfig_dict.items():
if key not in qconfig_dict_list:
qconfig_dict_list[key] = []
qconfig_dict_list[key].append(qconfig)
# reinsert all gathered key+qconfigs
set_method_name = _QCONFIG_STYLE_TO_METHOD[style]
set_method = getattr(new_qconfig_multi_mapping, set_method_name)
for key, qconfig_list in qconfig_dict_list.items():
if isinstance(key, tuple):
set_method(*key, qconfig_list)
else:
set_method(key, qconfig_list)
return new_qconfig_multi_mapping
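# A minimal usage sketch of from_list_qconfig_mapping, assuming two existing
# QConfigMapping objects (`qconfig_mapping_a` and `qconfig_mapping_b` are
# hypothetical names):
#
#     multi_mapping = QConfigMultiMapping.from_list_qconfig_mapping(
#         [qconfig_mapping_a, qconfig_mapping_b]
#     )
#     # equivalent to re-inserting every style+key of both mappings through the
#     # set_* methods, which preserves the None-padding invariant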

View File

@ -0,0 +1,540 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import enum
import operator
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq
import torch.nn as nn
from torch.ao.quantization import FakeQuantizeBase, ObserverBase
from torch.ao.quantization.observer import _is_activation_post_process
from torch.ao.quantization.utils import getattr_from_fqn
from torch.fx import GraphModule
from torch.fx.graph import Node
from .ns_types import NSNodeTargetType, NSResultsType
toq = torch.ops.quantized
# TODO(future PR): consider deleting this enum and using the torch types
# directly. This might be tricky because it is not a one to one mapping.
class NodeInputOrOutputType(enum.Enum):
FP32 = enum.auto() # torch.float
INT8 = enum.auto() # torch.qint8 or torch.quint8
FP16 = enum.auto() # torch.float16
UNKNOWN = enum.auto() # we cannot determine input/output dtype
# TODO(future PR): while these functions can support multiple dtypes,
# for the purposes of numerical debugging we want to get the actual
# dtype used in the model. We will likely need some kind of dtype
# propagation to estimate this.
FP32_OR_INT8 = enum.auto() # either torch.float or torch.quint8 or torch.qint8
# TODO(future PRs): dynamic quant, fake quant, etc
def get_node_first_input_and_output_type(
node: Node,
gm: GraphModule,
logger_cls: Callable,
node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Tuple[NodeInputOrOutputType, NodeInputOrOutputType]:
# TODO(future PR): clean this up
FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"]
FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"]
FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"]
FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"]
MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"]
MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]
if node.op == "call_function":
if node.target in FUNS_IO_TYPE_FP32:
return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
if node.target in FUNS_IO_TYPE_FP16:
return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16)
elif node.target in FUNS_IO_TYPE_INT8:
return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
first_arg = get_normalized_nth_input(node, gm, 0)
assert isinstance(first_arg, Node)
(
_prev_node_input_type,
prev_node_output_type,
) = get_node_first_input_and_output_type(
first_arg, gm, logger_cls, node_type_to_io_type_map
)
return (prev_node_output_type, prev_node_output_type)
else:
return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
elif node.op == "call_module":
assert node.op == "call_module"
assert isinstance(node.target, str)
mod = getattr_from_fqn(gm, node.target)
is_known_fp32_or_int8_input_module = any(
isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8 # type: ignore[arg-type]
)
if (
isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase)) # type: ignore[arg-type]
or is_known_fp32_or_int8_input_module
):
# A logger or observer's input and output type is the output
# type of the preceding node.
first_arg = get_normalized_nth_input(node, gm, 0)
assert isinstance(first_arg, Node)
(
_prev_node_input_type,
prev_node_output_type,
) = get_node_first_input_and_output_type(
first_arg, gm, logger_cls, node_type_to_io_type_map
)
return (prev_node_output_type, prev_node_output_type)
is_known_fp32_input_module = any(
isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32 # type: ignore[arg-type]
)
is_known_int8_input_module = any(
isinstance(mod, target_type) for target_type in MODS_IO_TYPE_INT8 # type: ignore[arg-type]
)
if is_known_fp32_input_module:
return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
elif is_known_int8_input_module:
return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
else:
return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
elif node.op == "call_method":
if node.target == "dequantize":
# Dequantize is a special node because it allows multiple input types.
# So, we look up the output type of the previous node and return that
# as the input type of this node instance.
prev_node = get_normalized_nth_input(node, gm, 0)
assert isinstance(prev_node, Node)
(
_prev_node_input_type,
prev_node_output_type,
) = get_node_first_input_and_output_type(
prev_node, gm, logger_cls, node_type_to_io_type_map
)
return (prev_node_output_type, NodeInputOrOutputType.FP32)
elif node.target == "to":
# to is a special node because it allows multiple input types.
# So, we look up the output type of the previous node and return that
# as the input type of this node instance. We also look up the target
# of to and return the correct output type.
prev_node = get_normalized_nth_input(node, gm, 0)
assert isinstance(prev_node, Node)
(
_prev_node_input_type,
prev_node_output_type,
) = get_node_first_input_and_output_type(
prev_node, gm, logger_cls, node_type_to_io_type_map
)
cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
assert (
cur_node_dtype_target is torch.float16
), f"{cur_node_dtype_target} handling needs to be added"
return (prev_node_output_type, NodeInputOrOutputType.FP16)
elif node.target in METHS_IO_TYPE_FP32_OR_INT8:
first_arg = get_normalized_nth_input(node, gm, 0)
assert isinstance(first_arg, Node)
(
_prev_node_input_type,
prev_node_output_type,
) = get_node_first_input_and_output_type(
first_arg, gm, logger_cls, node_type_to_io_type_map
)
return (prev_node_output_type, prev_node_output_type)
return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
else:
return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
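# A minimal sketch of calling the function above, assuming `gm` is a GraphModule
# produced by FX graph mode quantization and that OutputLogger and the io type
# map helper are available from torch.ao.ns:
#
#     from torch.ao.ns._numeric_suite_fx import OutputLogger
#     from torch.ao.ns.fx.mappings import get_node_type_to_io_type_map
#
#     io_map = get_node_type_to_io_type_map()
#     for n in gm.graph.nodes:
#         input_type, output_type = get_node_first_input_and_output_type(
#             n, gm, OutputLogger, io_map
#         )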
def get_node_input_qparams(
node: Node,
gm: GraphModule,
node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Optional[Tuple[Union[torch.Tensor, float], Union[torch.Tensor, int]]]:
"""
Returns the qparams (scale, zero_point) of the first input to `node`,
if they can be inferred from the graph.
"""
prev_node = get_normalized_nth_input(node, gm, 0)
if not isinstance(prev_node, Node):
return None
MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx):
scale_node = get_normalized_nth_input(node, gm, scale_arg_idx)
zp_node = get_normalized_nth_input(node, gm, zp_arg_idx)
assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
scale_obj = getattr_from_fqn(gm, scale_node.target)
zp_obj = getattr_from_fqn(gm, zp_node.target)
return (scale_obj, zp_obj)
if prev_node.op == "call_function":
# quantize - read the args directly
if prev_node.target == torch.quantize_per_tensor:
return _get_scale_zp_from_function_args(prev_node, gm, 1, 2)
elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
return _get_scale_zp_from_function_args(prev_node, gm, 2, 3)
return None
# TODO(future PR): handle more functionals
# TODO(future PR): handle functional ops which inherit qparams from input
elif prev_node.op == "call_module":
# get type of the module
assert isinstance(prev_node.target, str)
module_obj = getattr_from_fqn(gm, prev_node.target)
if isinstance(
module_obj,
(
nnq.Linear,
nnq.Conv1d,
nnq.Conv2d,
nniq.ConvReLU2d,
nnq.Conv3d,
nnq.BatchNorm2d,
nnq.BatchNorm3d,
nnq.ConvTranspose1d,
nnq.ConvTranspose2d,
nnq.ELU,
nnq.GroupNorm,
nnq.InstanceNorm1d,
nnq.InstanceNorm2d,
nnq.InstanceNorm3d,
nnq.LayerNorm,
nnq.Hardswish,
nnq.LeakyReLU,
nnq.ReLU6,
nniq.BNReLU2d,
nniq.BNReLU3d,
nniq.ConvReLU1d,
nniq.ConvReLU2d,
nniq.ConvReLU3d,
nniq.LinearReLU,
),
):
return (module_obj.scale, module_obj.zero_point) # type: ignore[return-value]
is_known_fp32_or_int8_input_module = any(
isinstance(module_obj, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8 # type: ignore[arg-type]
)
if is_known_fp32_or_int8_input_module:
return get_node_input_qparams(prev_node, gm, node_type_to_io_type_map)
return None
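# Sketch of the expected call pattern for get_node_input_qparams, reusing the
# hypothetical `gm`, `io_map`, and node `n` from the sketch above:
#
#     qparams = get_node_input_qparams(n, gm, io_map)
#     if qparams is not None:
#         scale, zero_point = qparams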
def return_first_non_observer_node(
node: Node,
gm: GraphModule,
) -> Node:
"""
If node is not an observer, returns it. If node is an observer,
navigates up the graph and returns the first parent which is not an
observer. For example,
graph: (node_non_obs), node = node_non_obs : returns node_non_obs
graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs
"""
if node.op == "call_module":
node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type]
if _is_activation_post_process(node_obj):
assert len(node.args) == 1
assert isinstance(node.args[0], Node)
node = node.args[0]
# code duplication intended, not worth refactoring
assert isinstance(node.target, str)
node_obj = getattr_from_fqn(gm, node.target)
if _is_activation_post_process(node_obj):
assert len(node.args) == 1
assert isinstance(node.args[0], Node)
node = node.args[0]
return node
def get_number_of_non_param_args(
node: Node,
gm: GraphModule,
) -> int:
"""
Assumes that all non-param args occur first. Returns the number of
non-param args expected for a node. For example, for
F.linear(x, weight, bias)
Returns 1, because x is a non-param arg and weight and bias are params.
For
lstm_mod(x, hid)
Returns 2, because both x and hid are non-param args.
"""
if node.op == "call_module":
node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type]
if isinstance(node_obj, nn.LSTM):
return 2
# default is 1
return 1
def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
"""
Returns the indices of args of the node which we should attach
loggers to, if input logging is enabled.
For example,
* for (x + y), returns [0, 1]
* for (1 + y), returns [1]
* for (x + 1), returns [0]
* for (linear(x, w, b)) returns [0]
* by default, returns [0]
"""
if len(node.args) == 0:
return []
if node.op == "call_function" and (
# TODO(future PR): use relationship map instead of hardcoding
node.target in (torch.add, torch.ops.quantized.add, operator.add)
or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
):
result = []
for i in range(2):
if type(node.args[i]) == Node:
result.append(i)
return result
return [0]
def get_target_type_str(node: Node, gm: GraphModule) -> str:
"""
Returns a string representation of the type of the function or module
pointed to by this node, or '' for other node types.
"""
target_type = ""
if node.op in ("call_function", "call_method"):
target_type = torch.typename(node.target)
elif node.op == "call_module":
assert isinstance(node.target, str)
target_mod = getattr_from_fqn(gm, node.target)
target_type = torch.typename(target_mod)
return target_type
def rekey_logger_info_on_node_name_of_model(
results: NSResultsType,
model_name: str,
) -> NSResultsType:
"""
Rekeys the layer name of a results dictionary to use node names
from `model_name`.
For example, transforms
{'base_op_1_0': {'node_output': {'model_a':
[{'ref_node_name': 'linear1', ...}]}}}
into
{'linear1': {'node_output': {'model_a':
[{'ref_node_name': 'linear1', ...}]}}}
Note: we cannot use these node names directly because they are not
guaranteed to be consistent across models. This is why we extract
the results first and rekey afterwards.
"""
new_results = {}
for old_layer_name, result_type_to_results in results.items():
new_layer_name = None
for model_name_to_results in result_type_to_results.values():
for cur_model_name, list_of_results in model_name_to_results.items():
if cur_model_name == model_name:
assert len(list_of_results)
new_layer_name = list_of_results[0]["ref_node_name"]
else:
continue
if new_layer_name is not None:
new_results[new_layer_name] = result_type_to_results
else:
new_results[old_layer_name] = result_type_to_results
return new_results
def maybe_add_missing_fqns(results: NSResultsType) -> None:
"""
If `fqn` entries are filled in for one of the models in `results`, copies
them over to any models which do not have them filled out.
A common use case benefitting from this is comparing a model prepared by
quantization to a quantized model. In this case, the model prepared by
quantization would have `fqn` entries, and the quantized model would not.
"""
# Check in the first result to find any model with fqn entries defined.
model_name_with_fqns = None
for result_type_to_results in results.values():
for model_name_to_results in result_type_to_results.values():
for model_name, model_results in model_name_to_results.items():
if len(model_results) > 0:
if model_results[0]["fqn"] is not None:
model_name_with_fqns = model_name
break
break
break
if model_name_with_fqns:
for result_type_to_results in results.values():
for model_name_to_results in result_type_to_results.values():
ref_model_results = model_name_to_results[model_name_with_fqns]
for model_name, model_results in model_name_to_results.items():
if model_name == model_name_with_fqns:
continue
for i in range(len(model_results)):
fqn = ref_model_results[i]["fqn"]
model_results[i]["fqn"] = fqn
def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
def inner(*args, **kwargs):
a0, a1, *a_other = args
if (isinstance(a0, tuple) and isinstance(a1, tuple)) or (
isinstance(a0, list) and isinstance(a1, list)
):
results = []
for el0, el1 in zip(a0, a1):
new_args = (el0, el1, *a_other)
results.append(inner(*new_args, **kwargs))
return results
elif isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor):
if a0.is_quantized:
a0 = a0.dequantize()
if a1.is_quantized:
a1 = a1.dequantize()
# for the purposes of this util, only handle floats
if a0.dtype != torch.float or a1.dtype != torch.float:
return None
new_args = (a0, a1, *a_other)
return f(*new_args, **kwargs)
return inner
@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Computes the SQNR between `x` and `y`.
Args:
x: Tensor or tuple of tensors
y: Tensor or tuple of tensors
Return:
float or tuple of floats
"""
Ps = torch.norm(x)
Pn = torch.norm(x - y)
return 20 * torch.log10(Ps / Pn)
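# A minimal sketch of compute_sqnr on a float tensor and its quantized copy;
# the decorator above dequantizes the quantized argument before comparing
# (tensor sizes and qparams are illustrative):
#
#     x = torch.randn(4, 4)
#     xq = torch.quantize_per_tensor(x, scale=0.05, zero_point=0, dtype=torch.qint8)
#     sqnr = compute_sqnr(x, xq)  # scalar tensor, in dB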
@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Computes the normalized L2 error between `x` and `y`.
Args:
x: Tensor or tuple of tensors
y: Tensor or tuple of tensors
Return:
float or tuple of floats
"""
return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum())
@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Computes the cosine similarity between `x` and `y`.
Args:
x: Tensor or tuple of tensors
y: Tensor or tuple of tensors
Return:
float or tuple of floats
"""
# For convolutions, the shape of the quantized weight has one additional
# dimension compared to the shape of the fp32 weight. Match the shapes
# to enable cosine similarity comparison.
x = x.reshape(1, -1)
y = y.reshape(1, -1)
return torch.nn.functional.cosine_similarity(x, y)
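# The same call pattern works for the other comparison utilities, for example
# on a pair of float/quantized weight tensors (`w_float` and `w_quant` are
# hypothetical names):
#
#     err = compute_normalized_l2_error(w_float, w_quant)
#     cos = compute_cosine_similarity(w_float, w_quant)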
def op_type_supports_shadowing(node: Node) -> bool:
if node.op == "call_function":
if node.target in (
torch.add,
torch.mul,
operator.add,
operator.mul,
torch.cat,
torch.stack,
):
# shadowing for ops with multiple tensor inputs is not implemented yet
return False
return True
def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
"""
Given a node, gets the n'th input to that node, normalizing
args and kwargs to the best of its ability.
"""
try:
norm_args_and_kwargs = node.normalized_arguments(
gm, normalize_to_only_use_kwargs=True
)
if norm_args_and_kwargs is not None:
norm_args, norm_kwargs = norm_args_and_kwargs
assert len(norm_args) + len(norm_kwargs) > idx
if idx < len(norm_args):
return norm_args[idx]
else:
# note: in Python 3.7+ dicts are ordered
return list(norm_kwargs.values())[idx]
else:
assert len(node.args) + len(node.kwargs) > idx
if idx < len(node.args):
return node.args[idx] # type: ignore[return-value]
else:
kwargs_idx = idx + len(node.args)
return list(node.kwargs.values())[kwargs_idx] # type: ignore[return-value]
except RuntimeError:
# this RuntimeError happens when node argument normalization
# requires typehints to proceed, such as for torch.add where
# either the first, second or both arguments could be tensors
assert len(node.args) + len(node.kwargs) > idx
if idx < len(node.args):
return node.args[idx] # type: ignore[return-value]
else:
kwargs_idx = idx + len(node.args)
return list(node.kwargs.values())[kwargs_idx] # type: ignore[return-value]
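# A minimal sketch of get_normalized_nth_input, assuming `gm` is a GraphModule
# and `n` is a call_function or call_module node in gm.graph:
#
#     first_input = get_normalized_nth_input(n, gm, 0)
#     # returns the value feeding `n`, whether it was passed positionally or by keyword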

View File

@ -0,0 +1,280 @@
from typing import Callable, Dict, List, Optional
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.qat as nnqat
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import GraphModule
from torch.fx.graph import Node
from .ns_types import NSSingleResultType, NSSingleResultValuesType
from .utils import get_target_type_str, getattr_from_fqn, return_first_non_observer_node
toq = torch.ops.quantized
def mod_weight_detach(mod: nn.Module) -> torch.Tensor:
return mod.weight.detach() # type: ignore[operator]
def mod_0_weight_detach(mod: nn.Module) -> torch.Tensor:
return mod[0].weight.detach() # type: ignore[index]
def mod_weight_bias_0(mod: nn.Module) -> torch.Tensor:
return mod._weight_bias()[0] # type: ignore[operator]
def get_lstm_weight(mod: nn.Module) -> List[torch.Tensor]:
res = []
for idx, param_name in enumerate(mod._flat_weights_names): # type: ignore[arg-type]
if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
param_value = mod._flat_weights[idx].detach() # type: ignore[index]
res.append(param_value)
return res
def get_qlstm_weight(mod: nn.Module) -> List[torch.Tensor]:
res = []
for weight_value in mod._all_weight_values: # type: ignore[union-attr]
res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
return res
def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor:
if isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
return mod.weight.detach()
elif isinstance(mod, (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)):
return mod[0].weight.detach()
else:
return mod._weight_bias()[0] # type: ignore[operator]
def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor:
if isinstance(mod, nn.Linear):
return mod.weight.detach()
elif isinstance(mod, nni.LinearReLU):
return mod[0].weight.detach()
else:
return mod._weight_bias()[0] # type: ignore[operator]
def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
# TODO(future PR): make more generic, handle everything
if isinstance(mod, nn.LSTM):
res = []
for idx, param_name in enumerate(mod._flat_weights_names):
if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
param_value = mod._flat_weights[idx].detach()
res.append(param_value)
return res
else:
assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
res = []
for weight_value in mod._all_weight_values:
res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
return res
def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# traverse backwards from the weight arg, accounting for any observers
weight_arg_node = node.args[1]
assert isinstance(weight_arg_node, Node)
weight_node = return_first_non_observer_node(weight_arg_node, gm)
assert isinstance(weight_node, Node)
assert weight_node.op == "get_attr"
weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type]
return weight.detach()
def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# qconv state is arg 1
qconv_state_node = node.args[1]
assert isinstance(qconv_state_node, Node)
assert qconv_state_node.op == "get_attr"
qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target) # type: ignore[arg-type]
return qconv_state_obj.weight()
def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# traverse backwards from the weight arg, accounting for any observers
# supported patterns:
# weight -> obs -> linear
# weight -> to(torch.float16) -> dequantize -> linear
linear_second_arg = node.args[1]
assert isinstance(linear_second_arg, Node)
if linear_second_arg.op == "call_module":
# weight -> obs -> linear
weight_arg_node = node.args[1]
assert isinstance(weight_arg_node, Node)
weight_node = weight_arg_node.args[0]
assert isinstance(weight_node, Node)
assert weight_node.op == "get_attr"
weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type]
return weight.detach()
elif linear_second_arg.op == "call_method":
# weight -> to(torch.float16) -> dequantize -> linear
assert linear_second_arg.op == "call_method"
dequant_node = node.args[1]
assert isinstance(dequant_node, Node)
to_fp16_node = dequant_node.args[0]
assert isinstance(to_fp16_node, Node)
# extract the dtype, so we can cast to it before returning
target_dtype = to_fp16_node.args[1]
weight_node = to_fp16_node.args[0]
assert isinstance(weight_node, Node)
assert weight_node.op == "get_attr"
weight = getattr_from_fqn(gm, weight_node.target) # type: ignore[arg-type]
# return the weight with fp16 cast
return weight.detach().to(target_dtype)
else:
assert linear_second_arg.op == "get_attr"
weight = getattr_from_fqn(gm, linear_second_arg.target) # type: ignore[arg-type]
return weight.detach()
def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
# packed weight is arg 1
packed_weight_node = node.args[1]
assert isinstance(packed_weight_node, Node)
assert packed_weight_node.op == "get_attr"
packed_weight = getattr_from_fqn(gm, packed_weight_node.target) # type: ignore[arg-type]
# TODO(future PR): why does packed_weight.unpack() not work?
(weight, _bias), _name = packed_weight.__getstate__()
return weight
def get_op_to_type_to_weight_extraction_fn() -> Dict[str, Dict[Callable, Callable]]:
op_to_type_to_weight_extraction_fn: Dict[str, Dict[Callable, Callable]] = {
"call_module": {
# Conv1d
nn.Conv1d: mod_weight_detach,
nni.ConvReLU1d: mod_0_weight_detach,
nnq.Conv1d: mod_weight_bias_0,
nnqat.Conv1d: mod_weight_detach,
nniqat.ConvBn1d: mod_weight_detach,
nniqat.ConvBnReLU1d: mod_weight_detach,
nniqat.ConvReLU1d: mod_weight_detach,
nniq.ConvReLU1d: mod_weight_bias_0,
# Conv2d
nn.Conv2d: mod_weight_detach,
nni.ConvReLU2d: mod_0_weight_detach,
nnq.Conv2d: mod_weight_bias_0,
nnqat.Conv2d: mod_weight_detach,
nniqat.ConvBn2d: mod_weight_detach,
nniqat.ConvBnReLU2d: mod_weight_detach,
nniqat.ConvReLU2d: mod_weight_detach,
nniq.ConvReLU2d: mod_weight_bias_0,
# Conv3d
nn.Conv3d: mod_weight_detach,
nni.ConvReLU3d: mod_0_weight_detach,
nnq.Conv3d: mod_weight_bias_0,
nnqat.Conv3d: mod_weight_detach,
nniqat.ConvBn3d: mod_weight_detach,
nniqat.ConvBnReLU3d: mod_weight_detach,
nniqat.ConvReLU3d: mod_weight_detach,
nniq.ConvReLU3d: mod_weight_bias_0,
# Linear
nn.Linear: mod_weight_detach,
nnq.Linear: mod_weight_bias_0,
nni.LinearReLU: mod_0_weight_detach,
nniq.LinearReLU: mod_weight_bias_0,
nnqat.Linear: mod_weight_detach,
nnqd.Linear: mod_weight_bias_0,
nniqat.LinearReLU: mod_weight_detach,
nniqat.LinearBn1d: mod_weight_detach,
nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach,
# LSTM
nn.LSTM: get_lstm_weight,
nnqd.LSTM: get_qlstm_weight,
},
"call_function": {
# Conv
F.conv1d: get_conv_fun_weight,
F.conv2d: get_conv_fun_weight,
F.conv3d: get_conv_fun_weight,
toq.conv1d: get_qconv_fun_weight,
toq.conv2d: get_qconv_fun_weight,
toq.conv3d: get_qconv_fun_weight,
toq.conv1d_relu: get_qconv_fun_weight,
toq.conv2d_relu: get_qconv_fun_weight,
toq.conv3d_relu: get_qconv_fun_weight,
# Linear
F.linear: get_linear_fun_weight,
toq.linear: get_qlinear_fun_weight,
toq.linear_relu: get_qlinear_fun_weight,
},
}
return op_to_type_to_weight_extraction_fn
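# A minimal sketch of extending the mapping above for a custom module type
# before weight extraction (MyConvWrapper is a hypothetical class whose
# .weight attribute holds the tensor of interest):
#
#     fn_map = get_op_to_type_to_weight_extraction_fn()
#     fn_map["call_module"][MyConvWrapper] = mod_weight_detach
#     # fn_map can then be passed as op_to_type_to_weight_extraction_fn below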
def extract_weight_from_node(
node: Node,
gm: GraphModule,
op_to_type_to_weight_extraction_fn: Optional[
Dict[str, Dict[Callable, Callable]]
] = None,
) -> Optional[NSSingleResultType]:
res_type = NSSingleResultValuesType.WEIGHT.value
# Not all graphmodules have _node_name_to_scope, so only fill it
# out if it exists.
fqn = None
if hasattr(gm, "_node_name_to_scope"):
fqn = gm._node_name_to_scope[node.name][0] # type: ignore[index]
if op_to_type_to_weight_extraction_fn is None:
op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn()
ref_node_type = get_target_type_str(node, gm)
# for extracting weights, these are always the same
prev_node_type = ref_node_type
if node.op == "call_function":
function_mapping = op_to_type_to_weight_extraction_fn["call_function"]
for target_fn_type, weight_extraction_fn in function_mapping.items():
if node.target == target_fn_type:
weight = weight_extraction_fn(node, gm)
return {
"type": res_type,
"values": [weight],
"prev_node_name": node.name,
"prev_node_target_type": prev_node_type,
"ref_node_name": node.name,
"ref_node_target_type": ref_node_type,
"index_within_arg": 0,
"index_of_arg": 0,
"fqn": fqn,
}
elif node.op == "call_module":
# for call_module, we need to look up the modules to do the type check
assert isinstance(node.target, str)
mod = getattr_from_fqn(gm, node.target)
module_mapping = op_to_type_to_weight_extraction_fn["call_module"]
for target_mod_type, weight_extraction_fn in module_mapping.items():
if type(mod) == target_mod_type:
weight = weight_extraction_fn(mod)
return {
"type": res_type,
"values": [weight],
"prev_node_name": node.name,
"prev_node_target_type": prev_node_type,
"ref_node_name": node.name,
"ref_node_target_type": ref_node_type,
"index_within_arg": 0,
"index_of_arg": 0,
"fqn": fqn,
}
return None
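# A minimal sketch of running extract_weight_from_node over a traced model,
# assuming `gm` is a GraphModule (for example from torch.fx.symbolic_trace):
#
#     for n in gm.graph.nodes:
#         result = extract_weight_from_node(n, gm)
#         if result is not None:
#             print(n.name, result["ref_node_target_type"])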