# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import os
import tarfile

import onnx.checker
import onnx.helper
import onnx.shape_inference
from onnx import FunctionProto, ModelProto, NodeProto, TensorProto, ValueInfoProto


class Extractor:
    def __init__(self, model: ModelProto) -> None:
        self.model = onnx.shape_inference.infer_shapes(model)
        self.graph = self.model.graph
        self.wmap = self._build_name2obj_dict(self.graph.initializer)
        self.vimap = self._build_name2obj_dict(self.graph.value_info)

    @staticmethod
    def _build_name2obj_dict(objs):  # type: ignore
        return {obj.name: obj for obj in objs}

    def _collect_new_io_core(self, original_io, io_names_to_extract):  # type: ignore
        original_io_map = self._build_name2obj_dict(original_io)
        original_io_names = set(original_io_map)
        s_io_names_to_extract = set(io_names_to_extract)
        io_names_to_keep = s_io_names_to_extract & original_io_names
        new_io_names_to_add = s_io_names_to_extract - original_io_names

        new_io_tensors = [original_io_map[name] for name in io_names_to_keep]
        # activations become graph inputs or outputs
        new_io_tensors.extend(self.vimap[name] for name in new_io_names_to_add)

        # keep the order requested in io_names_to_extract
        new_io_tensors_map = self._build_name2obj_dict(new_io_tensors)
        return [new_io_tensors_map[name] for name in io_names_to_extract]
    def _collect_new_inputs(self, names: list[str]) -> list[ValueInfoProto]:
        return self._collect_new_io_core(self.graph.input, names)  # type: ignore

    def _collect_new_outputs(self, names: list[str]) -> list[ValueInfoProto]:
        return self._collect_new_io_core(self.graph.output, names)  # type: ignore
    def _dfs_search_reachable_nodes(
        self,
        node_output_name: str,
        graph_input_names: set[str],
        nodes: list[NodeProto],
        reachable: set[int],
        unreachable: set[int],
    ) -> None:
        """Helper function to find nodes which are connected to an output.

        Arguments:
            node_output_name (str): The name of the output
            graph_input_names (set of string): The names of all inputs of the graph
            nodes (list of nodes): The list of all nodes of the graph
            reachable (set of int): The set of indices of reachable nodes in `nodes`
            unreachable (set of int): The set of indices of unreachable nodes in `nodes`
        """
        # finish search at inputs
        if node_output_name in graph_input_names:
            return

        # find nodes connected to this output
        nodes_to_search = [
            index for index in unreachable if node_output_name in nodes[index].output
        ]

        # add nodes connected to this output to sets
        for node_index in nodes_to_search:
            reachable.add(node_index)
            unreachable.remove(node_index)

        # recurse on inputs
        for node_index in nodes_to_search:
            for name in nodes[node_index].input:
                self._dfs_search_reachable_nodes(
                    name, graph_input_names, nodes, reachable, unreachable
                )
    def _collect_reachable_nodes(
        self,
        input_names: list[str],
        output_names: list[str],
    ) -> list[NodeProto]:
        _input_names = set(input_names)
        nodes = list(self.graph.node)
        reachable: set[int] = set()
        unreachable: set[int] = set(range(len(nodes)))
        for name in output_names:
            self._dfs_search_reachable_nodes(
                name, _input_names, nodes, reachable, unreachable
            )
        # the result needs to be topologically sorted; selecting nodes in index
        # order preserves the (topologically sorted) order of the original graph
        nodes = [nodes[node_index] for node_index in sorted(reachable)]
        return nodes
    def _collect_referred_local_functions(
        self,
        nodes,  # type: list[NodeProto]
    ):  # type: (...) -> list[FunctionProto]
        # A node in the model graph may refer to a function.
        # A function itself contains nodes, some of which may in turn refer to a function.
        # We need to find functions referred to by graph nodes and
        # by nodes used to define functions.
        def find_referred_funcs(nodes, referred_local_functions):  # type: ignore
            new_nodes = []  # type: list[NodeProto]
            for node in nodes:
                # check if the node is a function op
                match_function = next(
                    (
                        f
                        for f in self.model.functions
                        if f.name == node.op_type and f.domain == node.domain
                    ),
                    None,
                )
                if match_function and match_function not in referred_local_functions:
                    referred_local_functions.append(match_function)
                    new_nodes.extend(match_function.node)

            return new_nodes

        referred_local_functions = []  # type: list[FunctionProto]
        new_nodes = find_referred_funcs(nodes, referred_local_functions)
        while new_nodes:
            new_nodes = find_referred_funcs(new_nodes, referred_local_functions)

        return referred_local_functions
    def _collect_reachable_tensors(
        self,
        nodes: list[NodeProto],
    ) -> tuple[list[TensorProto], list[ValueInfoProto]]:
        all_tensors_names: set[str] = set()

        for node in nodes:
            all_tensors_names.update(node.input)
            all_tensors_names.update(node.output)

        initializer = [self.wmap[t] for t in self.wmap if t in all_tensors_names]
        value_info = [self.vimap[t] for t in self.vimap if t in all_tensors_names]
        len_sparse_initializer = len(self.graph.sparse_initializer)
        if len_sparse_initializer != 0:
            raise ValueError(
                f"len_sparse_initializer is {len_sparse_initializer}, it must be 0."
            )
        len_quantization_annotation = len(self.graph.quantization_annotation)
        if len_quantization_annotation != 0:
            raise ValueError(
                f"len_quantization_annotation is {len_quantization_annotation}, it must be 0."
            )
        return initializer, value_info
    def _make_model(
        self,
        nodes: list[NodeProto],
        inputs: list[ValueInfoProto],
        outputs: list[ValueInfoProto],
        initializer: list[TensorProto],
        value_info: list[ValueInfoProto],
        local_functions: list[FunctionProto],
    ) -> ModelProto:
        name = "Extracted from {" + self.graph.name + "}"
        graph = onnx.helper.make_graph(
            nodes, name, inputs, outputs, initializer=initializer, value_info=value_info
        )

        meta = {
            "ir_version": self.model.ir_version,
            "opset_imports": self.model.opset_import,
            "producer_name": "onnx.utils.extract_model",
            "functions": local_functions,
        }
        return onnx.helper.make_model(graph, **meta)
    def extract_model(
        self,
        input_names: list[str],
        output_names: list[str],
    ) -> ModelProto:
        inputs = self._collect_new_inputs(input_names)
        outputs = self._collect_new_outputs(output_names)
        nodes = self._collect_reachable_nodes(input_names, output_names)
        initializer, value_info = self._collect_reachable_tensors(nodes)
        local_functions = self._collect_referred_local_functions(nodes)
        model = self._make_model(
            nodes, inputs, outputs, initializer, value_info, local_functions
        )

        return model
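

# The Extractor class above can also be used directly; a minimal sketch, where
# the model path and tensor names are hypothetical placeholders, not taken from
# any particular model:
#
#     extractor = Extractor(onnx.load("model.onnx"))
#     sub_model = extractor.extract_model(["conv_output"], ["relu_output"])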
def extract_model(
    input_path: str | os.PathLike,
    output_path: str | os.PathLike,
    input_names: list[str],
    output_names: list[str],
    check_model: bool = True,
) -> None:
"""Extracts sub-model from an ONNX model.
|
|
|
|
The sub-model is defined by the names of the input and output tensors *exactly*.
|
|
|
|
Note: For control-flow operators, e.g. If and Loop, the _boundary of sub-model_,
|
|
which is defined by the input and output tensors, should not _cut through_ the
|
|
subgraph that is connected to the _main graph_ as attributes of these operators.
|
|
|
|
Arguments:
|
|
input_path (str | os.PathLike): The path to original ONNX model.
|
|
output_path (str | os.PathLike): The path to save the extracted ONNX model.
|
|
input_names (list of string): The names of the input tensors that to be extracted.
|
|
output_names (list of string): The names of the output tensors that to be extracted.
|
|
check_model (bool): Whether to run model checker on the extracted model.
|
|
"""
|
|
if not os.path.exists(input_path):
|
|
raise ValueError(f"Invalid input model path: {input_path}")
|
|
if not output_path:
|
|
raise ValueError("Output model path shall not be empty!")
|
|
if not output_names:
|
|
raise ValueError("Output tensor names shall not be empty!")
|
|
|
|
onnx.checker.check_model(input_path)
|
|
model = onnx.load(input_path)
|
|
|
|
e = Extractor(model)
|
|
extracted = e.extract_model(input_names, output_names)
|
|
|
|
onnx.save(extracted, output_path)
|
|
if check_model:
|
|
onnx.checker.check_model(output_path)
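

# Minimal usage sketch for extract_model; the file paths and tensor names are
# hypothetical placeholders, not taken from any particular model:
#
#     extract_model(
#         "model.onnx",
#         "sub_model.onnx",
#         input_names=["conv_input"],
#         output_names=["relu_output"],
#     )
#
# The extracted sub-model keeps every node reachable from the requested outputs
# back to the requested inputs, together with the initializers and local
# functions those nodes use.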


def _tar_members_filter(
    tar: tarfile.TarFile, base: str | os.PathLike
) -> list[tarfile.TarInfo]:
    """Check that the content of ``tar`` will be extracted safely.

    Args:
        tar: The tarball file
        base: The directory where the tarball will be extracted

    Returns:
        list of tarball members
    """
    result = []
    for member in tar:
        member_path = os.path.join(base, member.name)
        abs_base = os.path.abspath(base)
        abs_member = os.path.abspath(member_path)
        if not abs_member.startswith(abs_base):
            raise RuntimeError(
                f"The tarball member {member_path} in the downloaded model contains "
                f"a directory traversal sequence which may contain a harmful payload."
            )
        elif member.issym() or member.islnk():
            raise RuntimeError(
                f"The tarball member {member_path} in the downloaded model contains "
                f"symbolic links which may contain a harmful payload."
            )
        result.append(member)
    return result


def _extract_model_safe(
    model_tar_path: str | os.PathLike, local_model_with_data_dir_path: str | os.PathLike
) -> None:
    """Safely extracts a tar file to a specified directory.

    This function ensures that the extraction process mitigates against
    directory traversal vulnerabilities by validating or sanitizing paths
    within the tar file. It also provides compatibility for different versions
    of the tarfile module by checking for the availability of certain attributes
    or methods before invoking them.

    Args:
        model_tar_path: The path to the tar file to be extracted.
        local_model_with_data_dir_path: The directory path where the tar file
            contents will be extracted to.
    """
    with tarfile.open(model_tar_path) as model_with_data_zipped:
        # Mitigate tarball directory traversal risks
        if hasattr(tarfile, "data_filter"):
            model_with_data_zipped.extractall(
                path=local_model_with_data_dir_path, filter="data"
            )
        else:
            model_with_data_zipped.extractall(
                path=local_model_with_data_dir_path,
                members=_tar_members_filter(
                    model_with_data_zipped, local_model_with_data_dir_path
                ),
            )
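

# Minimal usage sketch for the helper above; the paths are hypothetical
# placeholders:
#
#     _extract_model_safe("model.tar.gz", "extracted_model_dir")
#
# When tarfile.data_filter is available (added in Python 3.12 and backported to
# some earlier security releases), extraction relies on filter="data";
# otherwise the members are vetted by _tar_members_filter before extraction.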