Files
2024-10-30 22:14:35 +01:00

56 lines
2.1 KiB
Python

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from fusion_utils import NumpyHelper
from onnx import ModelProto, TensorProto
from onnx.external_data_helper import set_external_data
from onnx_model import OnnxModel
from onnxruntime import OrtValue
def extract_raw_data_from_model(model: ModelProto):
    """
    Extract initializer raw data from the model and return it as (names, OrtValues).

    Every initializer (in each graph yielded by ``OnnxModel(model).graphs()``) whose payload is
    stored in the ``raw_data`` field is converted to an ``OrtValue``. The initializer is then
    rewritten *in place* so it looks like an external-data tensor: it is marked as external with
    a placeholder location (``foo.bin``; this function writes no file), and its ``raw_data`` is
    cleared. Initializers without a ``raw_data`` field are left untouched.

    Note this function does not handle external data that is not loaded into the model as raw data.

    Args:
        model (ModelProto): the model proto to extract external data from. It is modified in place.

    Returns:
        (external_names, external_values): a tuple of two lists of equal length holding the
        initializer names and the matching OrtValues, in traversal order. Both lists are empty
        when no initializer carries raw data.
    """
    external_names = []
    external_values = []
    onnx_model = OnnxModel(model)
    for graph in onnx_model.graphs():
        for initializer in graph.initializer:
            if not initializer.HasField("raw_data"):
                continue
            name = initializer.name
            numpy_tensor = NumpyHelper.to_array(initializer)
            # Create the OrtValue before recording anything so names/values never get out of step.
            ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor)
            external_names.append(name)
            external_values.append(ort_value)
            # Mimic set_external_data: mark the tensor as external. This runs while raw_data is
            # still present (onnx's helper checks for it); the payload is cleared only afterwards.
            set_external_data(initializer, location="foo.bin")
            initializer.name = name
            initializer.ClearField("raw_data")
    # Return two explicit lists instead of ``zip(*pairs)``: unpacking an empty zip raises
    # ValueError, which broke callers of models with no raw-data initializers.
    return external_names, external_values
def has_external_data(model: ModelProto):
    """
    Report whether any initializer of the model is stored as external data.

    The initializers of every graph returned by ``OnnxModel(model).graphs()`` are scanned in
    order. The scan stops at the first one whose ``data_location`` field is set and equals
    ``TensorProto.EXTERNAL``.

    Args:
        model (ModelProto): the model proto to inspect.

    Returns:
        bool: True if at least one initializer is marked as external, False otherwise.
    """
    every_initializer = (
        tensor
        for graph in OnnxModel(model).graphs()
        for tensor in graph.initializer
    )
    return any(
        tensor.HasField("data_location") and tensor.data_location == TensorProto.EXTERNAL
        for tensor in every_initializer
    )