56 lines
2.1 KiB
Python
56 lines
2.1 KiB
Python
# -------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
# --------------------------------------------------------------------------
|
|
from fusion_utils import NumpyHelper
|
|
from onnx import ModelProto, TensorProto
|
|
from onnx.external_data_helper import set_external_data
|
|
from onnx_model import OnnxModel
|
|
|
|
from onnxruntime import OrtValue
|
|
|
|
|
|
def extract_raw_data_from_model(model: ModelProto):
    """
    Extract raw initializer data from a model and return it as parallel name/value lists.

    Every initializer (in every graph returned by ``OnnxModel(model).graphs()``) that
    has its ``raw_data`` field set is converted to a numpy array, wrapped in an
    ``OrtValue`` and collected. The initializer itself is then modified in place:
    ``set_external_data`` is called on it with the placeholder location ``"foo.bin"``
    and its ``raw_data`` field is cleared.

    Note this function does not handle external data that is not loaded into the model as raw data.

    Args:
        model (ModelProto): the model proto to extract external data from.
            Its initializers are modified in place as described above.
    Returns:
        (external_names, external_values): a tuple of two lists of external data names and values.
        Both lists are empty when no initializer carries raw data, so the result can
        always be unpacked into two variables.
    """
    external_names = []
    external_values = []
    onnx_model = OnnxModel(model)
    for graph in onnx_model.graphs():
        for initializer in graph.initializer:
            # Initializers without raw_data are left untouched.
            if not initializer.HasField("raw_data"):
                continue

            name = initializer.name
            numpy_tensor = NumpyHelper.to_array(initializer)
            ort_value = OrtValue.ortvalue_from_numpy(numpy_tensor)
            external_names.append(name)
            external_values.append(ort_value)

            # mimic set_external_data: "foo.bin" is only a placeholder location,
            # no file is written by this function.
            set_external_data(initializer, location="foo.bin")
            initializer.name = name
            initializer.ClearField("raw_data")

    # Return two concrete lists instead of zip(*pairs): zip() over an empty list
    # yields nothing, which made `names, values = ...` raise ValueError for models
    # without raw-data initializers.
    return external_names, external_values
|
|
|
|
|
|
def has_external_data(model: ModelProto):
    """
    Tell whether any initializer of the model is flagged as externally stored.

    An initializer counts as external when its ``data_location`` field is set
    and equals ``TensorProto.EXTERNAL``. All graphs returned by
    ``OnnxModel(model).graphs()`` are scanned, stopping at the first match.

    Args:
        model (ModelProto): the model proto to check for external data.
    Returns:
        bool: True if the model has external data, False otherwise.
    """
    wrapped_model = OnnxModel(model)
    return any(
        tensor.HasField("data_location") and tensor.data_location == TensorProto.EXTERNAL
        for subgraph in wrapped_model.graphs()
        for tensor in subgraph.initializer
    )
|