# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from fusion_utils import NumpyHelper
from onnx import ModelProto, TensorProto
from onnx.external_data_helper import set_external_data
from onnx_model import OnnxModel

from onnxruntime import OrtValue


def extract_raw_data_from_model(model: ModelProto):
    """
    Extract raw initializer data from the model and return it as OrtValues.

    Every initializer (in every graph returned by ``OnnxModel.graphs()``) that
    stores its tensor in ``raw_data`` is converted to an ``OrtValue``. The
    initializer inside ``model`` is then rewritten in place: it is marked as
    external data and its ``raw_data`` is cleared. Note this function does not
    handle external data that is not loaded into the model as raw data.

    Args:
        model (ModelProto): the model proto to extract external data from.
            It is modified in place.

    Returns:
        (external_names, external_values): a tuple of two lists holding the
        initializer names and the matching OrtValues, in the same order. Both
        lists are empty when no initializer carries raw data, so the result can
        always be unpacked into two names.
    """
    external_names = []
    external_values = []
    onnx_model = OnnxModel(model)
    for graph in onnx_model.graphs():
        for initializer in graph.initializer:
            if not initializer.HasField("raw_data"):
                continue
            name = initializer.name
            numpy_tensor = NumpyHelper.to_array(initializer)
            ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor)
            external_names.append(name)
            external_values.append(ort_value)
            # Mimic set_external_data: mark the tensor as EXTERNAL so the proto
            # no longer carries the bytes. "foo.bin" is only a dummy location;
            # presumably callers supply the real data through the returned
            # OrtValues -- NOTE(review): confirm against callers.
            set_external_data(initializer, location="foo.bin")
            initializer.name = name
            initializer.ClearField("raw_data")

    # Return concrete lists instead of zip(*pairs): zip() of an empty list
    # yields nothing, so unpacking it into two names would raise ValueError
    # for models without raw-data initializers.
    return external_names, external_values


def has_external_data(model: ModelProto) -> bool:
    """
    Check if the model has external data.

    Only the ``data_location`` field of each initializer (in every graph
    returned by ``OnnxModel.graphs()``) is inspected; the referenced external
    files are not checked.

    Args:
        model (ModelProto): the model proto to check for external data.

    Returns:
        bool: True if any initializer has data_location EXTERNAL, False otherwise.
    """
    onnx_model = OnnxModel(model)
    for graph in onnx_model.graphs():
        for initializer in graph.initializer:
            if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL:
                return True
    return False