# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- import argparse import logging import os from typing import List, Tuple import numpy as np import numpy.typing as npt import onnx from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto from onnxruntime.capi._pybind_state import quantize_matmul_bnb4 from .onnx_model import ONNXModel from .quant_utils import attribute_to_kwarg logger = logging.getLogger(__name__) class MatMulBnb4Quantizer: """Perform 4b quantization of constant MatMul weights using FP4 or NF4 data type""" ################## # quantization types, must be consistent with native code type # Bnb_DataType_t defined in blockwise_quant_block_bnb4.h # 4b floating point with bias of 3 FP4 = 0 # 4b NormalFloat NF4 = 1 def __init__(self, model: ModelProto, quant_type: int, block_size: int, nodes_to_exclude=None): nodes_to_exclude = nodes_to_exclude or [] assert quant_type in [MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4] self.model = ONNXModel(model) self.quant_type = quant_type self.block_size = block_size self.nodes_to_exclude = set(nodes_to_exclude) @staticmethod def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]: for gid in range(len(graph_path) - 1, -1, -1): graph = graph_path[gid] for tensor in graph.initializer: if tensor.name == name: return tensor, graph return None, None def bnb4_block_quant(self, fpweight: npt.ArrayLike) -> np.ndarray: """4b quantize fp32/fp16 weight""" if len(fpweight.shape) != 2: raise ValueError("Current bnb4 block quantization only supports 2D tensors!") # need to copy since the transposed weight still has the original memory layout # Linear4bit quantizes its weight data which is the transposed weight fpweight_t = fpweight.transpose().copy() rows, cols = fpweight.shape numel = rows * cols block_size = self.block_size num_blocks = (numel + block_size - 1) // block_size quantized_numel = (numel + 1) // 2 packed = np.zeros(quantized_numel, dtype="uint8") absmax = np.zeros(num_blocks, dtype=fpweight.dtype) # block wise quantization, fpweight_t is flattened and divided into blocks quantize_matmul_bnb4(packed, fpweight_t, absmax, block_size, self.quant_type, cols, rows) return (packed, absmax) def _bnb4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto: """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" if node.op_type != "MatMul": return node # only care about MatMul for now logger.debug(f"start to quantize {node.name} ...") if node.name in self.nodes_to_exclude: logger.debug(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") return node inputB = node.input[1] # noqa: N806 B, Bs_graph = MatMulBnb4Quantizer.__get_initializer(inputB, graph_stack) # noqa: N806 if B is None: logger.debug("MatMul doesn't have const weight. Skip to quantize") return node # only care about constant weight B_array = onnx.numpy_helper.to_array(B) # noqa: N806 if len(B_array.shape) != 2: logger.debug("MatMul weight is not 2D. Skip to quantize") return node # can only process 2-D matrix packed, absmax = self.bnb4_block_quant(B_array) B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806 B_quant.name = B.name + "_Bnb4" for input in Bs_graph.input: if input.name == inputB: Bs_graph.input.remove(input) break absmax_tensor = onnx.numpy_helper.from_array(absmax) absmax_tensor.name = B.name + "_absmax" Bs_graph.initializer.extend([B_quant, absmax_tensor]) kwargs = {} rows, cols = B_array.shape kwargs["K"] = rows kwargs["N"] = cols kwargs["block_size"] = self.block_size kwargs["quant_type"] = self.quant_type matmul_bnb4_node = onnx.helper.make_node( "MatMulBnb4", inputs=[node.input[0], B_quant.name, absmax_tensor.name], outputs=[node.output[0]], name=node.name + "_Bnb4" if node.name else "", domain="com.microsoft", **kwargs, ) logger.debug(f"complete quantization of {node.name} ...") return matmul_bnb4_node def _process_subgraph(self, graph_stack: List[GraphProto]): new_nodes = [] graph = graph_stack[-1] for node in graph.node: graph_attrs = [ attr for attr in node.attribute if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS ] if len(graph_attrs): kwargs = {} for attr in node.attribute: if attr.type == onnx.AttributeProto.GRAPH: # recursive call to take care of sub-graph graph_stack.append(attr.g) kv = {attr.name: self._process_subgraph(graph_stack)} elif attr.type == onnx.AttributeProto.GRAPHS: value = [] for subgraph in attr.graphs: # recursive call to take care of sub-graph graph_stack.append(subgraph) value.extend([self._process_subgraph(graph_stack)]) kv = {attr.name: value} else: kv = attribute_to_kwarg(attr) kwargs.update(kv) node = onnx.helper.make_node( # noqa: PLW2901 node.op_type, node.input, node.output, name=node.name, **kwargs ) new_nodes.append(self._bnb4_matmul_node_weight(node, graph_stack)) graph.ClearField("node") graph.node.extend(new_nodes) graph_stack.pop() return graph def process(self): # use a stack to keep track of sub-graphs graph_stack = [self.model.graph()] opset_import = self.model.opset_import() has_ms_domain = False for opset in opset_import: if opset.domain == "com.microsoft": has_ms_domain = True if not has_ms_domain: opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) self._process_subgraph(graph_stack) self.model.clean_initializers() def parse_args(): parser = argparse.ArgumentParser( description="""Blockwise FP4/NF4 quantization for MatMul 2D weight matrices. A weight matrix is partitioned into blocks, where each block is a contiguous subset inside the flattened transposed weight matrix. Each block is quantized into a set of 4b integers with an absolute value scaling factor. """ ) parser.add_argument("--input_model", required=True, help="Path to the input model file") parser.add_argument("--output_model", required=True, help="Path to the output model file") parser.add_argument( "--quant_type", required=False, default=1, choices=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4], help="Quantization data type. 0: FP4, 1: NF4", ) parser.add_argument( "--block_size", required=False, default=64, help="Block size for blockwise quantization. Note: bnb.nn.Linear4bit only uses block_size=64", ) parser.add_argument("-v", "--verbose", required=False, action="store_true") parser.set_defaults(verbose=False) parser.add_argument( "--nodes_to_exclude", nargs="+", type=str, required=False, default=[], help="Specify the nodes to be excluded from quantization with node names", ) return parser.parse_args() if __name__ == "__main__": args = parse_args() if args.verbose: logger.setLevel(logging.DEBUG) input_model_path = args.input_model output_model_path = args.output_model if os.path.exists(output_model_path): logger.error(f"file {output_model_path} already exists") raise Exception(f"file {output_model_path} already exists") model = onnx.load(input_model_path) quant = MatMulBnb4Quantizer(model, args.quant_type, args.block_size, nodes_to_exclude=args.nodes_to_exclude) quant.process() quant.model.save_model_to_file(output_model_path, True)