Files
2024-10-30 22:14:35 +01:00

64 lines
2.1 KiB
Python

import onnx
from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg
from .base_operator import QuantOperatorBase
from .qdq_base_operator import QDQOperatorBase
class QSplit(QuantOperatorBase):
    """QOperator-format quantizer for an ONNX ``Split`` node.

    Only the data input (index 0) is quantized. Every output is registered
    with that input's scale and zero point rather than getting its own.
    """

    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def quantize(self):
        """Emit a quantized ``Split`` node plus the nodes that quantize its data input.

        If the quantizer declines to quantize input 0 (it returns ``None``
        names), this defers to ``QuantOperatorBase.quantize`` instead.
        """
        split_node = self.node
        q_inputs, zp_names, scale_names, emitted = self.quantizer.quantize_activation(split_node, [0])
        if q_inputs is None:
            # Input 0 was not quantized: fall back to the base-class handling.
            return super().quantize()

        q_node_name = split_node.name + "_quant" if split_node.name else ""

        # Carry every attribute of the original node over to the quantized node.
        attrs = {}
        for attr in split_node.attribute:
            attrs.update(attribute_to_kwarg(attr))

        # Outputs reuse the input's quantization parameters (scale/zero point of input 0).
        q_outputs = [float_name + "quantized" for float_name in split_node.output]
        for float_name, quant_name in zip(split_node.output, q_outputs):
            self.quantizer.quantized_value_map[float_name] = QuantizedValue(
                float_name,
                quant_name,
                scale_names[0],
                zp_names[0],
                QuantizedValueType.Input,
            )

        # Remaining inputs (for ONNX Split, the optional `split` sizes tensor)
        # are passed through unquantized.
        extra_inputs = split_node.input[1:]
        if extra_inputs:
            q_inputs.extend(extra_inputs)

        emitted.append(
            onnx.helper.make_node(split_node.op_type, q_inputs, q_outputs, q_node_name, **attrs)
        )
        self.quantizer.new_nodes += emitted
class QDQSplit(QDQOperatorBase):
    """QDQ-format quantizer for an ONNX ``Split`` node.

    Requests activation quantization for the data input (if it is not
    already quantized). Then, unless output QDQ is disabled for this node,
    asks the quantizer to quantize each output the same way as that input.
    """

    def quantize(self):
        """Register QDQ quantization for this Split node's data input and outputs."""
        split_node = self.node
        assert split_node.op_type == "Split"

        data_input = split_node.input[0]
        if not self.quantizer.is_tensor_quantized(data_input):
            self.quantizer.quantize_activation_tensor(data_input)

        if self.disable_qdq_for_node_output:
            return
        for out_name in split_node.output:
            self.quantizer.quantize_output_same_as_input(out_name, data_input, split_node.name)