# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- from __future__ import annotations import argparse import copy import importlib import logging import os import numpy as np import numpy.typing as npt import onnx from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto from packaging import version from onnxruntime.capi._pybind_state import quantize_matmul_4bits, quantize_qdq_matmul_4bits from .calibrate import CalibrationDataReader from .onnx_model import ONNXModel from .quant_utils import QuantFormat, attribute_to_kwarg logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.INFO) logger = logging.getLogger(__name__) class WeightOnlyQuantConfig: def __init__(self, algorithm, quant_format): """This is the Base class for Weight Only Quant Configuration. Args: algorithm: weight only quantize algorithm name. quant_format: QuantFormat{QOperator, QDQ}. QOperator format quantizes the model with quantized operators directly. QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. """ self.algorithm = algorithm self.quant_format = quant_format class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, ratios=None, quant_format=QuantFormat.QOperator, ): """ This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration. RTN is the most straightforward way to quantize weight using scale maps. Args: ratios: percentile of clip. Defaults to {}. quant_format (QuantFormat{QOperator, QDQ}, optional): QOperator format quantizes the model with quantized operators directly. QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. Defaults to QuantFormat.QOperator. """ assert quant_format == QuantFormat.QOperator, "RTN only supports QOperator format" if ratios is None: ratios = {} super().__init__( algorithm="RTN", quant_format=quant_format, ) self.ratios = ratios class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, calibration_data_reader: CalibrationDataReader, percdamp=0.01, block_size=128, actorder=False, mse=False, perchannel=True, quant_format=QuantFormat.QOperator, ): """ This is a class for GPTQ algorithm Weight Only Quant Configuration. GPTQ algorithm provides more accurate quantization but requires more computational resources. Args: calibration_data_reader: a calibration data reader. It enumerates calibration data and generates inputs for the original model. percdamp: percent of the average Hessian diagonal to use for dampening. block_size (int, optional): channel number in one block to execute a GPTQ quantization iteration. actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value. mse (bool, optional): whether get scale and zero point with mse error. perchannel (bool, optional): whether quantize weight per-channel. quant_format (QuantFormat{QOperator, QDQ}, optional): QOperator format quantizes the model with quantized operators directly. QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. Defaults to QuantFormat.QOperator. """ assert quant_format == QuantFormat.QOperator, "GPTQ only supports QOperator format" super().__init__( algorithm="GPTQ", quant_format=quant_format, ) self.calibration_data_reader = calibration_data_reader self.percdamp = percdamp self.block_size = block_size self.actorder = actorder self.mse = mse self.perchannel = perchannel class HQQWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, block_size=128, bits=4, axis=1, quant_format=QuantFormat.QOperator, ): """ This is a class for HQQ algorithm Weight Only Quant Configuration. HQQ algorithm quant weight without needing calibrate data. Args: block_size (int, optional): channel number in one block to execute a HQQ quantization iteration. bits (int, optional): how many bits to represent weight. axis (int, optional): 0 or 1. which axis to quantize. https://arxiv.org/pdf/2309.15531.pdf quant_format (QuantFormat{QOperator, QDQ}, optional): QOperator format quantizes the model with quantized operators directly. QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. Defaults to QuantFormat.QOperator. """ assert quant_format == QuantFormat.QOperator, "HQQ only supports QOperator format" super().__init__( algorithm="HQQ", quant_format=quant_format, ) self.block_size = block_size self.bits = bits self.axis = axis class DefaultWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, block_size: int = 128, is_symmetric: bool = False, accuracy_level: int | None = None, quant_format=QuantFormat.QOperator, ): """ This is a class for weight only affine quantization configuration. Args: block_size (int, optional): channel number in one block to execute an affine quantization iteration. is_symmetric (bool, optional): whether quantize weight symmetrically. accuracy_level (int, optional): Accuracy level of the 4-bit quantized MatMul computation. Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details. (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits) quant_format (QuantFormat{QOperator, QDQ}, optional): QOperator format quantizes the model with quantized operators directly. QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. Defaults to QuantFormat.QOperator. """ super().__init__(algorithm="DEFAULT", quant_format=quant_format) self.block_size = block_size self.is_symmetric = is_symmetric self.bits = 4 self.accuracy_level = accuracy_level def is_divisible(val1, val2): return int(val2 * np.ceil(val1 / val2)) == val1 class HQQWeightOnlyQuantizer: def __init__( self, config: HQQWeightOnlyQuantConfig, ): self.config = config # Proximal solver || weight - dequantize(quantize(weight))||_p^p @staticmethod def optimize_weights( tensor, scale, zero, min_max: list[int], axis: int = 0, opt_params: dict = None, # noqa: RUF013 verbose=False, ): import torch opt_params = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20} if opt_params is None else opt_params lp_norm, beta, kappa, iters = ( opt_params["lp_norm"], opt_params["beta"], opt_params["kappa"], opt_params["iters"], ) dtype = torch.float16 if tensor.is_cuda else torch.float32 w_f = tensor.to(dtype) scale = scale.to(dtype) zero = zero.to(dtype) if lp_norm == 1: def shrink_op(x, beta): return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta) else: def shrink_op(x, beta, p=lp_norm): return torch.sign(x) * torch.nn.functional.relu( torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x) + 1e-8, p - 1) ) best_error = 1e4 for i in range(iters): w_q = torch.round(w_f * scale + zero).clamp(min_max[0], min_max[1]) w_r = (w_q - zero) / scale w_e = shrink_op(w_f - w_r, beta) zero = torch.mean(w_q - (w_f - w_e) * scale, axis=axis, keepdim=True) beta *= kappa current_error = float(torch.abs(w_f - w_r).mean()) if verbose: print(i, np.round(current_error, 6)) if current_error < best_error: best_error = current_error else: break del w_f, w_q, w_r, w_e return scale, zero @staticmethod def pack_on_row_fast_248bit(pack_tensor, ori_int_tensor, bits): if pack_tensor.shape[0] == ori_int_tensor.shape[0]: ori_int_tensor = ori_int_tensor.T pack_tensor = pack_tensor.T if bits in [2, 4, 8]: compress_ratio = pack_tensor.element_size() * 8 // bits for j in range(compress_ratio): pack_tensor[0:] |= ori_int_tensor[j::compress_ratio] << (bits * (j)) else: raise NotImplementedError("Only 2,4,8 bits are supported.") # from Official implementation of Half-Quadratic Quantization (HQQ) def quantize_internal( self, tensor, bits=4, channel_wise=True, group_size=64, optimize=True, round_zero=True, axis=1 ): import torch weight = tensor.float() ori_shape = weight.shape pad_len = (group_size - ori_shape[axis] % group_size) % group_size if axis == 1: weight = torch.nn.functional.pad(weight, (0, pad_len), "constant", 0) else: weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_len), "constant", 0) shape = weight.shape # Reshape for grouping if (group_size is not None) and channel_wise: weight = weight.reshape([-1, group_size]) if (axis == 1) else weight.reshape([group_size, -1]) # Get min/max values if channel_wise is False: _min, _max = weight.min(), weight.max() optimize = False else: _min = weight.min(axis=axis, keepdim=True)[0] _max = weight.max(axis=axis, keepdim=True)[0] max_v = 2**bits - 1 min_v = 0 min_max = [min_v, max_v] # Note: here we work with the inverse of the scale to avoid division and quantize instead via weight*scale + zero, the scale is inverted later on. # clamp to avoid half-precision problems scale = (max_v / (_max - _min)).clamp(max=2e4) #!!!!!!!!!!!!!!! min_max_axis = _max - _min if (min_max_axis == 0).sum().item() > 0: min_max_axis[min_max_axis == 0] = max_v scale = (max_v / min_max_axis).clamp(max=2e4) zero = -_min * scale if round_zero: zero = torch.round(zero) # Fine-tune weights if optimize: scale, zero = self.optimize_weights(tensor=weight, scale=scale, zero=zero, min_max=min_max, axis=axis) # Quantize # Necessary for fake quantization backprop w_q = torch.round(weight * scale + zero).clamp(min_max[0], min_max[1]) w_q = w_q.reshape(shape).int() scale = 1.0 / scale if axis == 1: scale = scale.reshape(shape[0], -1) zero = zero.reshape(shape[0], -1) else: scale = scale.reshape(-1, shape[-1]) zero = zero.reshape(-1, shape[-1]) # cleanup del weight, _min, _max return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype) def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]: """ If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node. If QOperator format, return MatMulNbits. If QDQ format, return DeQuantizeLinear + MatMul. """ if node.op_type != "MatMul": return [node] # only care about MatMul for now import torch logger.info(f"start to quantize {node.name} ...") input_b = node.input[1] b_pb, bs_graph = get_initializer(input_b, graph_stack) if b_pb is None: logger.info("MatMul doesn't have const weight. Skip to quantize") return [node] # only care about constant weight b_array = onnx.numpy_helper.to_array(b_pb) if len(b_array.shape) != 2: logger.info("MatMul weight is not 2D. Skip to quantize") return [node] # can only process 2-D matrix b_array_torch = torch.from_numpy(b_array) if torch.cuda.is_available(): b_array_torch = b_array_torch.cuda() quant_weight_torch, scales_torch, zero_points_torch = self.quantize_internal( b_array_torch.T, bits=self.config.bits, group_size=self.config.block_size ) quant_weight_torch = quant_weight_torch.contiguous() scales_torch = scales_torch.contiguous() zero_points_torch = zero_points_torch.contiguous() packed_torch = torch.zeros( (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // 2), dtype=torch.uint8, device=quant_weight_torch.device, ) self.pack_on_row_fast_248bit(packed_torch, quant_weight_torch, self.config.bits) scales = scales_torch.cpu().numpy() zero_points = zero_points_torch.cpu().numpy() # reshape to the predefined shape in MatmulNbits scales = scales.reshape(-1) zero_points = zero_points.reshape(-1) rows, cols = b_array_torch.shape block_size = self.config.block_size blob_size = block_size // 2 k_blocks = (rows + block_size - 1) // block_size packed_torch = packed_torch.reshape(cols, k_blocks, blob_size) b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy()) b_quant.name = b_pb.name + "_Q4" for input in bs_graph.input: if input.name == input_b: bs_graph.input.remove(input) break scales_tensor = onnx.numpy_helper.from_array(scales) scales_tensor.name = b_pb.name + "_scales" bs_graph.initializer.extend([b_quant, scales_tensor]) input_names = [node.input[0], b_quant.name, scales_tensor.name] zp_tensor = onnx.numpy_helper.from_array(zero_points) zp_tensor.name = b_pb.name + "_zero_points" bs_graph.initializer.extend([zp_tensor]) input_names.append(zp_tensor.name) kwargs = {} rows, cols = b_array.shape kwargs["K"] = rows kwargs["N"] = cols kwargs["bits"] = self.config.bits kwargs["block_size"] = self.config.block_size matmul_q4_node = onnx.helper.make_node( "MatMulNBits", inputs=input_names, outputs=[node.output[0]], name=node.name + "_Q4" if node.name else "", domain="com.microsoft", **kwargs, ) logger.info(f"complete quantization of {node.name} ...") return [matmul_q4_node] def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: for gid in range(len(graph_path) - 1, -1, -1): graph = graph_path[gid] for tensor in graph.initializer: if tensor.name == name: return tensor, graph return None, None class DefaultWeightOnlyQuantizer: def __init__(self, config: DefaultWeightOnlyQuantConfig): self.config = config def int4_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """4b quantize fp32 weight to a blob""" if len(fp32weight.shape) != 2: raise ValueError("Current int4 block quantization only supports 2D tensors!") rows, cols = fp32weight.shape block_size = self.config.block_size k_blocks = (rows + block_size - 1) // block_size if self.config.quant_format == QuantFormat.QOperator: blob_size = block_size // 2 padded_rows = k_blocks * block_size pad_len = padded_rows - rows if pad_len > 0: fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant") # block wise quantization, each block comes from a single column packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) quantize_matmul_4bits( packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric ) else: packed = np.zeros((rows * cols + 1) // 2, dtype="uint8") zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8") scales = np.zeros((k_blocks, cols), dtype=fp32weight.dtype) quantize_qdq_matmul_4bits( packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric ) return (packed, scales, zero_point) def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]: """ If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node. If QOperator format, return MatMulNbits. If QDQ format, return DeQuantizeLinear + MatMul. """ if node.op_type != "MatMul": return [node] # only care about MatMul for now logger.info(f"start to quantize {node.name} ...") qtype = TensorProto.INT4 if self.config.is_symmetric else TensorProto.UINT4 input_b = node.input[1] b_tensor, b_graph = get_initializer(input_b, graph_stack) if b_tensor is None: logger.info("MatMul doesn't have const weight. Skip to quantize") return [node] # only care about constant weight b_ndarray = onnx.numpy_helper.to_array(b_tensor) if len(b_ndarray.shape) != 2: logger.info("MatMul weight is not 2D. Skip to quantize") return [node] # can only process 2-D matrix packed, scales, zero_points = self.int4_block_quant(b_ndarray) if self.config.quant_format == QuantFormat.QOperator: b_quant = onnx.numpy_helper.from_array(packed, b_tensor.name + "_Q4") scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_scales") else: b_quant = onnx.helper.make_tensor(b_tensor.name + "_DQ_Q4", qtype, b_ndarray.shape, packed.tobytes(), True) scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales") for input in b_graph.input: if input.name == input_b: b_graph.input.remove(input) break b_graph.initializer.extend([b_quant, scales_tensor]) output_nodes = [] if self.config.quant_format == QuantFormat.QOperator: input_names = [node.input[0], b_quant.name, scales_tensor.name] if not self.config.is_symmetric: zp_tensor = onnx.numpy_helper.from_array(zero_points, b_tensor.name + "_zero_points") input_names.append(zp_tensor.name) b_graph.initializer.extend([zp_tensor]) kwargs = {} rows, cols = b_ndarray.shape kwargs["K"] = rows kwargs["N"] = cols kwargs["bits"] = 4 kwargs["block_size"] = self.config.block_size if self.config.accuracy_level is not None: kwargs["accuracy_level"] = self.config.accuracy_level matmul_q4_node = onnx.helper.make_node( "MatMulNBits", inputs=input_names, outputs=[node.output[0]], name=node.name + "_Q4" if node.name else "", domain="com.microsoft", **kwargs, ) output_nodes.append(matmul_q4_node) else: dq_input_names = [b_quant.name, scales_tensor.name] dq_output_names = [b_quant.name + "_output"] matmul_input_names = [node.input[0], dq_output_names[0]] matmul_output_names = [node.output[0]] if not self.config.is_symmetric: zp_tensor = onnx.helper.make_tensor( b_tensor.name + "_DQ_zero_points", qtype, scales.shape, zero_points.tobytes(), True ) dq_input_names.append(zp_tensor.name) b_graph.initializer.extend([zp_tensor]) dq_kwargs = {"axis": 0, "block_size": self.config.block_size} dq_node = onnx.helper.make_node( "DequantizeLinear", inputs=dq_input_names, outputs=dq_output_names, name=node.name + "_DQ_Q4" if node.name else "", **dq_kwargs, ) matmul_node = onnx.helper.make_node( "MatMul", inputs=matmul_input_names, outputs=matmul_output_names, name=node.name + "_matmul_Q4" if node.name else "", ) output_nodes.extend([dq_node, matmul_node]) logger.info(f"complete quantization of {node.name} ...") return output_nodes class MatMul4BitsQuantizer: """ Perform 4b quantization of constant MatMul weights. If algo_config.quant_format is QOperator, the quantized weight is stored in a MatMulNBits node, which relaces the MatMul node. If algo_config.quant_format is QDQ, the quantized weight is stored in a DeQuantizeLinear node. The MatMul node is replaced by the DequantizeLinear + MatMul nodes. """ def __init__( self, model: ModelProto | str, block_size: int = 128, is_symmetric: bool = False, accuracy_level: int | None = None, nodes_to_exclude=None, quant_format=QuantFormat.QOperator, algo_config: WeightOnlyQuantConfig | None = None, ): if nodes_to_exclude is None: nodes_to_exclude = [] self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model) self.model_path = model if isinstance(model, str) else None self.block_size = block_size self.is_symmetric = is_symmetric self.accuracy_level = accuracy_level self.nodes_to_exclude = set(nodes_to_exclude) self.node_quantizer = None if algo_config is None: algo_config = DefaultWeightOnlyQuantConfig( block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level, quant_format=quant_format, ) self.algo_config = algo_config if algo_config.algorithm == "HQQ": self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config) elif algo_config.algorithm == "DEFAULT": self.node_quantizer = DefaultWeightOnlyQuantizer(self.algo_config) def _process_subgraph(self, graph_stack: list[GraphProto]): new_nodes = [] graph = graph_stack[-1] for node in graph.node: graph_attrs = [ attr for attr in node.attribute if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS ] if len(graph_attrs): kwargs = {} for attr in node.attribute: if attr.type == onnx.AttributeProto.GRAPH: # recursive call to take care of sub-graph graph_stack.append(attr.g) kv = {attr.name: self._process_subgraph(graph_stack)} elif attr.type == onnx.AttributeProto.GRAPHS: value = [] for subgraph in attr.graphs: # recursive call to take care of sub-graph graph_stack.append(subgraph) value.extend([self._process_subgraph(graph_stack)]) kv = {attr.name: value} else: kv = attribute_to_kwarg(attr) kwargs.update(kv) node = onnx.helper.make_node( # noqa: PLW2901 node.op_type, node.input, node.output, name=node.name, **kwargs ) out_nodes = [] if node.name in self.nodes_to_exclude: logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") out_nodes = [node] elif self.algo_config is not None and self.algo_config.algorithm == "HQQ": out_nodes = self.node_quantizer.quantize(node, graph_stack) else: out_nodes = self.node_quantizer.quantize(node, graph_stack) new_nodes.extend(out_nodes) graph.ClearField("node") graph.node.extend(new_nodes) graph_stack.pop() return graph def _generate_q4_node_config(self): """Generate weight only quant configuration for nodes.""" q4_node_config = {} template_config_q4 = { "bits": 4, "group_size": self.block_size, "scheme": "sym" if self.is_symmetric else "asym", } for node in self.model.model.graph.node: if node.op_type in ["MatMul"]: if not all([self.model.get_initializer(i) is None for i in node.input]): q4_node_config[node.name] = template_config_q4 return q4_node_config def int4_quant_algo(self): """4b quantize a model with RTN or GPTQ algorithm. Please refer to https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md for more details on weight only quantization using Intel® Neural Compressor. """ def inc_dataloader(): data_reader = copy.deepcopy(self.algo_config.calibration_data_reader) for data in data_reader: yield data, None kwargs = {} if self.accuracy_level is not None: kwargs["accuracy_level"] = self.accuracy_level weight_only_node_config = self._generate_q4_node_config() algorithm = self.algo_config.algorithm logger.info(f"start to quantize model with {algorithm} algorithm...") if algorithm == "RTN": from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize kwargs["ratios"] = self.algo_config.ratios self.model = rtn_quantize( model=self.model_path if self.model_path is not None else self.model.model, weight_config=weight_only_node_config, **kwargs, ) elif algorithm == "GPTQ": from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize kwargs["percdamp"] = self.algo_config.percdamp kwargs["blocksize"] = self.algo_config.block_size kwargs["actorder"] = self.algo_config.actorder kwargs["mse"] = self.algo_config.mse kwargs["perchannel"] = self.algo_config.perchannel kwargs["n_samples"] = -1 dataloader = inc_dataloader() self.model = gptq_quantize( model=self.model_path if self.model_path is not None else self.model.model, weight_config=weight_only_node_config, dataloader=dataloader, **kwargs, ) logger.info(f"complete quantization of model with {algorithm} algorithm.") def process(self): if self.algo_config.algorithm in ["HQQ", "DEFAULT"]: # use a stack to keep track of sub-graphs graph_stack = [self.model.graph()] # Update domain opset if self.algo_config.quant_format == QuantFormat.QOperator: self.model.set_opset_import("com.microsoft", 1) else: opset_import = self.model.opset_import() for opset in opset_import: if opset.domain in [None, "ai.onnx", ""] and opset.version < 21: logger.warning( "The opset of the input model is under 21 and doesn't support int4 data type. " "Force to update it to opset 21, but the generated model may not be a valid model." ) self.model.set_opset_import(opset.domain, 21) self._process_subgraph(graph_stack) self.model.clean_initializers() else: # use Intel® Neural Compressor for RTN or GPTQ weight-only quantize algorithm try: importlib.import_module("neural_compressor") except Exception as e: logging.error(f"{e}.") raise RuntimeError( "neural-compressor is not correctly installed. Please check your environment." ) from e import neural_compressor assert version.parse(neural_compressor.__version__) >= version.parse( "2.3.2" ), "Require neural-compressor >= 2.3.2 to support weight only quantization!" self.int4_quant_algo() def ort_convert_str_to_bool(value): return value.lower() in ("true", "1") def parse_args(): parser = argparse.ArgumentParser( description="""Blockwise int4 quantization for MatMul 2D weight matrices. A weight matrix is partitioned into into blocks, where each block is a continguous subset inside each column. Each block is quantized into a set of 4b integers with a scaling factor and an optional offset. """ ) parser.add_argument("--input_model", required=True, help="Path to the input model file") parser.add_argument("--output_model", required=True, help="Path to the output model file") parser.add_argument("--block_size", required=False, default=32, type=int, help="Block size for quantization") parser.add_argument( "--quant_method", default="default", type=str, choices=["default", "hqq", "rtn", "gptq"], help="the algorithm used to quantize weight, \nrtn and gptq leverage Intel® Neural Compressor", ) parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight") parser.add_argument( "--symmetric", required=False, default=True, const=True, nargs="?", type=ort_convert_str_to_bool, choices=[True, False], help="Indicate whether to quantize the model symmetrically, symmetric is not supported by hqq", ) parser.add_argument( "--accuracy_level", required=False, type=int, help="Accuracy level of the 4-bit quantized MatMul computation. " "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details " "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).", ) parser.add_argument("-v", "--verbose", required=False, action="store_true") parser.set_defaults(verbose=False) parser.add_argument( "--nodes_to_exclude", nargs="+", type=str, required=False, default=[], help="Specify the nodes to be excluded from quantization with node names", ) parser.add_argument( "--quant_format", default="QOperator", type=str, choices=["QOperator", "QDQ"], help="QuantFormat {QOperator, QDQ}" "QOperator format quantizes the model with quantized operators directly." "QDQ format quantize the model by inserting DeQuantizeLinear before the MatMul.", ) return parser.parse_args() if __name__ == "__main__": args = parse_args() if args.verbose: logger.setLevel(logging.DEBUG) input_model_path = args.input_model output_model_path = args.output_model quant_format = QuantFormat[args.quant_format] if os.path.exists(output_model_path): logger.error(f"file {output_model_path} already exists") raise Exception(f"file {output_model_path} already exists") if args.symmetric and args.quant_method == "hqq": logger.warning("symmetric is not supportted by hqq, will force to symmetric=False") args.symmetric = False model = onnx.load(input_model_path) if args.quant_method == "hqq": quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits) elif args.quant_method == "default": quant_config = DefaultWeightOnlyQuantConfig( block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level, quant_format=quant_format, ) elif args.quant_method == "rtn": quant_config = RTNWeightOnlyQuantConfig() elif args.quant_method == "gptq": quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size) else: raise ValueError(f"Unsupported quantization method: {args.quant_method}") quant = MatMul4BitsQuantizer( model=model, accuracy_level=args.accuracy_level, nodes_to_exclude=args.nodes_to_exclude, algo_config=quant_config, ) quant.process() quant.model.save_model_to_file(output_model_path, True)