734 lines
26 KiB
Python
734 lines
26 KiB
Python
# Copyright (c) ONNX Project Contributors
|
|
#
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
from __future__ import annotations
|
|
|
|
from typing import MutableMapping
|
|
|
|
from onnx import (
|
|
AttributeProto,
|
|
GraphProto,
|
|
ModelProto,
|
|
TensorProto,
|
|
checker,
|
|
helper,
|
|
utils,
|
|
)
|
|
|
|
|
|
def check_overlapping_names(
    g1: GraphProto, g2: GraphProto, io_map: list[tuple[str, str]] | None = None
) -> list[tuple[str, list[str]]]:
    """Reports the names that are used by both graphs.

    Each entry of the returned list is a ``(category, names)`` tuple, where
    ``category`` is one of "edge", "value_info", "initializer" or
    "sparse_initializer", and ``names`` lists the names that appear in both
    graphs for that category. Categories without collisions are omitted.

    Arguments:
        g1 (GraphProto): First graph
        g2 (GraphProto): Second graph
        io_map (list of pairs of string): Optional pairs ``(g1_output, g2_input)``
                                          that are about to be connected. If provided,
                                          the g2 input names listed in it are ignored
                                          when looking for overlapping edge names.

    Returns:
        list of (category, names) tuples; empty when there is no collision

    Raises:
        ValueError: if ``g1`` or ``g2`` is not a ``GraphProto``
    """
    if type(g1) is not GraphProto:
        raise ValueError("g1 argument is not an ONNX graph")
    if type(g2) is not GraphProto:
        raise ValueError("g2 argument is not an ONNX graph")

    def _shared(names_a: list[str], names_b: list[str]) -> list[str]:
        # Names present in both collections (order follows set iteration).
        return list(set(names_a) & set(names_b))

    def _edge_names(graph: GraphProto, exclude: set[str] | None = None) -> list[str]:
        # All non-empty node input/output names of ``graph`` not in ``exclude``.
        skipped = exclude if exclude is not None else set()
        return [
            edge
            for node in graph.node
            for edge in (*node.input, *node.output)
            if edge != "" and edge not in skipped
        ]

    # g2 inputs that will be wired to g1 outputs are not considered collisions.
    connected_inputs = {pair[1] for pair in (io_map or [])}

    collisions = []

    # Node edges already cover the graph inputs/outputs.
    shared_edges = _shared(
        _edge_names(g1), _edge_names(g2, exclude=connected_inputs)
    )
    if shared_edges:
        collisions.append(("edge", shared_edges))

    # These categories are plain repeated fields whose entries carry a ``name``.
    for category in ("value_info", "initializer"):
        shared = _shared(
            [entry.name for entry in getattr(g1, category)],
            [entry.name for entry in getattr(g2, category)],
        )
        if shared:
            collisions.append((category, shared))

    # Sparse initializers are named through their values and indices tensors.
    shared_values = _shared(
        [sp.values.name for sp in g1.sparse_initializer],
        [sp.values.name for sp in g2.sparse_initializer],
    )
    shared_indices = _shared(
        [sp.indices.name for sp in g1.sparse_initializer],
        [sp.indices.name for sp in g2.sparse_initializer],
    )
    shared_sparse = shared_values + shared_indices
    if shared_sparse:
        collisions.append(("sparse_initializer", shared_sparse))

    return collisions
|
|
|
|
|
|
def merge_graphs(
    g1: GraphProto,
    g2: GraphProto,
    io_map: list[tuple[str, str]],
    inputs: list[str] | None = None,
    outputs: list[str] | None = None,
    prefix1: str | None = None,
    prefix2: str | None = None,
    name: str | None = None,
    doc_string: str | None = None,
) -> GraphProto:
    """Combines two ONNX graphs into a single one.

    The combined graph is defined by connecting the specified set of outputs/inputs. Those inputs/outputs
    not specified in the io_map argument will remain as inputs/outputs of the combined graph.

    Arguments:
        g1 (GraphProto): First graph
        g2 (GraphProto): Second graph
        io_map (list of pairs of string): The pairs of names [(out0, in0), (out1, in1), ...]
                                          representing outputs of the first graph and inputs of the second
                                          to be connected
        inputs (list of string): Optional list of inputs to be included in the combined graph
                                 By default, all inputs not present in the ``io_map`` argument will be
                                 included in the combined model.
                                 Names are compared against the graphs after prefixing.
        outputs (list of string): Optional list of outputs to be included in the combined graph
                                  By default, all outputs not present in the ``io_map`` argument will be
                                  included in the combined model.
                                  Names are compared against the graphs after prefixing.
        prefix1 (string): Optional prefix to be added to all names in g1
        prefix2 (string): Optional prefix to be added to all names in g2
        name (string): Optional name for the combined graph
                       By default, the name is g1.name and g2.name concatenated with an underscore delimiter
        doc_string (string): Optional docstring for the combined graph
                             If not provided, a default docstring with the concatenation of g1 and g2 docstrings is used

    Returns:
        GraphProto

    Raises:
        ValueError: if ``g1``/``g2`` are not graphs, if a name in ``io_map`` is not an
                    output of g1 / input of g2, or if the two graphs share names
    """
    if type(g1) is not GraphProto:
        raise ValueError("g1 argument is not an ONNX graph")
    if type(g2) is not GraphProto:
        raise ValueError("g2 argument is not an ONNX graph")

    # Prefixing names in the graph if requested, adjusting io_map accordingly.
    # add_prefix_graph works on its own copy (inplace defaults to False), so the
    # caller's graphs are left untouched without copying them here as well.
    if prefix1 or prefix2:
        if prefix1:
            g1 = add_prefix_graph(g1, prefix=prefix1)
        if prefix2:
            g2 = add_prefix_graph(g2, prefix=prefix2)
        io_map = [
            (
                prefix1 + io[0] if prefix1 else io[0],
                prefix2 + io[1] if prefix2 else io[1],
            )
            for io in io_map
        ]

    io_map_g1_outs = {io[0] for io in io_map}
    io_map_g2_ins = {io[1] for io in io_map}
    # g2 input name -> g1 output name that replaces it
    reversed_io_map = {in_name: out_name for out_name, in_name in io_map}
    # Captured before any sub-graph extraction so io_map is validated against
    # the full (prefixed) graphs.
    g1_outs = {o.name for o in g1.output}
    g2_ins = {i.name for i in g2.input}

    # An empty list is treated like None: no explicit selection.
    input_set = set(inputs) if inputs else None
    output_set = set(outputs) if outputs else None

    # If necessary extract subgraphs
    if input_set is not None or output_set is not None:
        if input_set is None:
            g1_inputs = [i.name for i in g1.input]
            g2_inputs = [i.name for i in g2.input]
        else:
            g1_inputs = [i.name for i in g1.input if i.name in input_set]
            # g2 inputs fed by g1 must survive the extraction
            g2_inputs = [
                i.name
                for i in g2.input
                if i.name in input_set or i.name in io_map_g2_ins
            ]

        if output_set is None:
            g1_outputs = [o.name for o in g1.output]
            g2_outputs = [o.name for o in g2.output]
        else:
            # g1 outputs feeding g2 must survive the extraction
            g1_outputs = [
                o.name
                for o in g1.output
                if o.name in output_set or o.name in io_map_g1_outs
            ]
            g2_outputs = [o.name for o in g2.output if o.name in output_set]

        if len(g1_inputs) < len(g1.input) or len(g1_outputs) < len(g1.output):
            e1 = utils.Extractor(helper.make_model(g1))
            g1 = e1.extract_model(g1_inputs, g1_outputs).graph

        if len(g2_inputs) < len(g2.input) or len(g2_outputs) < len(g2.output):
            e2 = utils.Extractor(helper.make_model(g2))
            g2 = e2.extract_model(g2_inputs, g2_outputs).graph

    # Check that input/output names specified in the io_map argument are valid input/output names
    for g1_out_name, g2_in_name in io_map:
        if g1_out_name not in g1_outs:
            raise ValueError(f"Output {g1_out_name} is not present in g1")
        if g2_in_name not in g2_ins:
            raise ValueError(f"Input {g2_in_name} is not present in g2")

    # Check for name collision
    overlapping_names = check_overlapping_names(g1, g2, io_map)
    if len(overlapping_names) > 0:
        category, names = overlapping_names[0]
        raise ValueError(
            "Cant merge two graphs with overlapping names. "
            f"Found repeated {category} names: "
            + ", ".join(names)
            + "\n"
            + "Consider using ``onnx.compose.add_prefix`` to add a prefix to names in one of the graphs."
        )

    g = GraphProto()

    g.node.extend(g1.node)
    g2_nodes_begin = len(g.node)
    g.node.extend(g2.node)
    g2_nodes_end = len(g.node)

    # Rewires node inputs in sub_graph.node[start:end], descending into nested
    # graphs held in GRAPH and GRAPHS attributes (they may reference g2 inputs
    # from the enclosing scope).
    def connect_io(sub_graph: GraphProto, start: int, end: int) -> None:
        for node_idx in range(start, end):
            node = sub_graph.node[node_idx]
            for attr in node.attribute:
                if attr.type == AttributeProto.GRAPH:
                    connect_io(attr.g, 0, len(attr.g.node))
                elif attr.type == AttributeProto.GRAPHS:
                    for nested in attr.graphs:
                        connect_io(nested, 0, len(nested.node))

            for index, name_ in enumerate(node.input):
                if name_ in reversed_io_map:
                    node.input[index] = reversed_io_map[name_]

    # Connecting outputs of the first graph with the inputs of the second
    connect_io(g, g2_nodes_begin, g2_nodes_end)

    if input_set is not None:
        g.input.extend([i for i in g1.input if i.name in input_set])
        g.input.extend([i for i in g2.input if i.name in input_set])
    else:
        g.input.extend(g1.input)
        g.input.extend([i for i in g2.input if i.name not in io_map_g2_ins])

    if output_set is not None:
        g.output.extend([o for o in g1.output if o.name in output_set])
        g.output.extend([o for o in g2.output if o.name in output_set])
    else:
        g.output.extend([o for o in g1.output if o.name not in io_map_g1_outs])
        g.output.extend(g2.output)

    # Entries of g2 named after a connected input are dropped: that name no
    # longer exists in the combined graph.
    g.initializer.extend(g1.initializer)
    g.initializer.extend(
        [init for init in g2.initializer if init.name not in io_map_g2_ins]
    )

    g.sparse_initializer.extend(g1.sparse_initializer)
    g.sparse_initializer.extend(
        [
            init
            for init in g2.sparse_initializer
            if init.values.name not in io_map_g2_ins
        ]
    )

    g.value_info.extend(g1.value_info)
    g.value_info.extend([vi for vi in g2.value_info if vi.name not in io_map_g2_ins])

    g.name = name if name is not None else "_".join([g1.name, g2.name])

    if doc_string is None:
        doc_string = (
            f"Graph combining {g1.name} and {g2.name}\n"
            + g1.name
            + "\n\n"
            + g1.doc_string
            + "\n\n"
            + g2.name
            + "\n\n"
            + g2.doc_string
        )
    g.doc_string = doc_string

    return g
|
|
|
|
|
|
def merge_models(
    m1: ModelProto,
    m2: ModelProto,
    io_map: list[tuple[str, str]],
    inputs: list[str] | None = None,
    outputs: list[str] | None = None,
    prefix1: str | None = None,
    prefix2: str | None = None,
    name: str | None = None,
    doc_string: str | None = None,
    producer_name: str | None = "onnx.compose.merge_models",
    producer_version: str | None = "1.0",
    domain: str | None = "",
    model_version: int | None = 1,
) -> ModelProto:
    """Combines two ONNX models into a single one.

    The combined model is defined by connecting the specified set of outputs/inputs.
    Those inputs/outputs not specified in the io_map argument will remain as
    inputs/outputs of the combined model.

    Both models should have the same IR version, and same operator sets imported.

    Arguments:
        m1 (ModelProto): First model
        m2 (ModelProto): Second model
        io_map (list of pairs of string): The pairs of names [(out0, in0), (out1, in1), ...]
                                          representing outputs of the first graph and inputs of the second
                                          to be connected
        inputs (list of string): Optional list of inputs to be included in the combined graph
                                 By default, all inputs not present in the ``io_map`` argument will be
                                 included in the combined model
        outputs (list of string): Optional list of outputs to be included in the combined graph
                                  By default, all outputs not present in the ``io_map`` argument will be
                                  included in the combined model
        prefix1 (string): Optional prefix to be added to all names in m1
        prefix2 (string): Optional prefix to be added to all names in m2
        name (string): Optional name for the combined graph
                       By default, the name is g1.name and g2.name concatenated with an underscore delimiter
        doc_string (string): Optional docstring for the combined graph
                             If not provided, a default docstring with the concatenation of g1 and g2 docstrings is used
        producer_name (string): Optional producer name for the combined model. Default: 'onnx.compose.merge_models'
        producer_version (string): Optional producer version for the combined model. Default: "1.0"
        domain (string): Optional domain of the combined model. Default: ""
        model_version (int): Optional version of the graph encoded. Default: 1

    Returns:
        ModelProto

    Raises:
        ValueError: if ``m1``/``m2`` are not models, if their IR versions differ, if they
                    import different versions of the same operator set domain, if they
                    disagree on a metadata property, or if they define local functions
                    with the same name
    """
    if type(m1) is not ModelProto:
        raise ValueError("m1 argument is not an ONNX model")
    if type(m2) is not ModelProto:
        raise ValueError("m2 argument is not an ONNX model")

    if m1.ir_version != m2.ir_version:
        raise ValueError(
            f"IR version mismatch {m1.ir_version} != {m2.ir_version}."
            " Both models should have the same IR version"
        )
    ir_version = m1.ir_version

    # Collect one opset import per domain (first occurrence wins), rejecting
    # conflicting versions. Passing both lists through unchanged would list
    # every shared domain twice in the merged model.
    opset_import_map: MutableMapping[str, int] = {}
    opset_imports = []
    for entry in list(m1.opset_import) + list(m2.opset_import):
        if entry.domain in opset_import_map:
            found_version = opset_import_map[entry.domain]
            if entry.version != found_version:
                raise ValueError(
                    "Can't merge two models with different operator set ids for a given domain. "
                    f"Got: {m1.opset_import} and {m2.opset_import}"
                )
        else:
            opset_import_map[entry.domain] = entry.version
            opset_imports.append(entry)

    # Prefixing names in the graph if requested, adjusting io_map accordingly.
    # add_prefix returns a prefixed copy (inplace defaults to False), so the
    # caller's models are left untouched without copying them here as well.
    if prefix1 or prefix2:
        if prefix1:
            m1 = add_prefix(m1, prefix=prefix1)
        if prefix2:
            m2 = add_prefix(m2, prefix=prefix2)
        io_map = [
            (
                prefix1 + io[0] if prefix1 else io[0],
                prefix2 + io[1] if prefix2 else io[1],
            )
            for io in io_map
        ]

    # The graphs are already prefixed at this point, so no prefix is forwarded.
    graph = merge_graphs(
        m1.graph,
        m2.graph,
        io_map,
        inputs=inputs,
        outputs=outputs,
        name=name,
        doc_string=doc_string,
    )
    model = helper.make_model(
        graph,
        producer_name=producer_name,
        producer_version=producer_version,
        domain=domain,
        model_version=model_version,
        opset_imports=opset_imports,
        ir_version=ir_version,
    )

    # Merging model metadata props: the union of both, provided that keys
    # present in both models carry the same value.
    model_props = {}
    for meta_entry in m1.metadata_props:
        model_props[meta_entry.key] = meta_entry.value
    for meta_entry in m2.metadata_props:
        if meta_entry.key in model_props:
            value = model_props[meta_entry.key]
            if value != meta_entry.value:
                raise ValueError(
                    "Can't merge models with different values for the same model metadata property."
                    f" Found: property = {meta_entry.key}, with values {value} and {meta_entry.value}."
                )
        else:
            model_props[meta_entry.key] = meta_entry.value
    helper.set_model_props(model, model_props)

    # Merging functions; sorted so the error message is deterministic.
    function_overlap = sorted(
        {f.name for f in m1.functions} & {f.name for f in m2.functions}
    )
    if function_overlap:
        raise ValueError(
            "Can't merge models with overlapping local function names."
            " Found in both graphs: " + ", ".join(function_overlap)
        )
    model.functions.MergeFrom(m1.functions)
    model.functions.MergeFrom(m2.functions)

    checker.check_model(model)
    return model
|
|
|
|
|
|
def add_prefix_graph(
    graph: GraphProto,
    prefix: str,
    rename_nodes: bool | None = True,
    rename_edges: bool | None = True,
    rename_inputs: bool | None = True,
    rename_outputs: bool | None = True,
    rename_initializers: bool | None = True,
    rename_value_infos: bool | None = True,
    inplace: bool | None = False,
    name_map: dict[str, str] | None = None,
) -> GraphProto:
    """Adds a prefix to names of elements in a graph: nodes, edges, inputs, outputs,
    initializers, sparse initializer, value infos.

    It can be used as a utility before merging graphs that have overlapping names.
    Empty names are not prefixed.

    Graphs nested in node attributes (``g`` and ``graphs``) are processed with the
    same ``rename_*`` settings as the enclosing graph.

    Arguments:
        graph (GraphProto): Graph
        prefix (str): Prefix to be added to each name in the graph
        rename_nodes (bool): Whether to prefix node names
        rename_edges (bool): Whether to prefix node edge names
        rename_inputs (bool): Whether to prefix input names
        rename_outputs (bool): Whether to prefix output names
        rename_initializers (bool): Whether to prefix initializer and sparse initializer names
        rename_value_infos (bool): Whether to prefix value info names
        inplace (bool): If True, mutates the graph directly.
                        Otherwise, a copy will be created
        name_map: (Dict): shared name_map in subgraph. Maps original names to their
                          prefixed form; it is updated by this call.

    Returns:
        GraphProto

    Raises:
        ValueError: if ``graph`` is not a ``GraphProto``
    """
    if type(graph) is not GraphProto:
        raise ValueError("graph argument is not an ONNX graph")

    if not inplace:
        g = GraphProto()
        g.CopyFrom(graph)
    else:
        g = graph

    def _prefixed(prefix: str, name: str) -> str:
        return prefix + name if len(name) > 0 else name

    if name_map is None:
        name_map = {}

    # Phase 1: record every requested rename for this graph in name_map.
    if rename_edges:
        for n in g.node:
            for e in n.input:
                name_map[e] = _prefixed(prefix, e)
            for e in n.output:
                name_map[e] = _prefixed(prefix, e)

    if rename_inputs:
        for entry in g.input:
            name_map[entry.name] = _prefixed(prefix, entry.name)
    if rename_outputs:
        for entry in g.output:
            name_map[entry.name] = _prefixed(prefix, entry.name)

    if rename_initializers:
        for init in g.initializer:
            name_map[init.name] = _prefixed(prefix, init.name)
        for sparse_init in g.sparse_initializer:
            name_map[sparse_init.values.name] = _prefixed(
                prefix, sparse_init.values.name
            )
            name_map[sparse_init.indices.name] = _prefixed(
                prefix, sparse_init.indices.name
            )

    if rename_value_infos:
        for entry in g.value_info:
            name_map[entry.name] = _prefixed(prefix, entry.name)

    # Phase 2: prefix node names and process nested graphs. This runs after
    # name_map is complete so that references from a nested graph to any renamed
    # name of this graph are rewritten too. It is independent of rename_nodes
    # and forwards the caller's flags, so nested graphs stay consistent with
    # the enclosing one.
    for n in g.node:
        if rename_nodes:
            n.name = _prefixed(prefix, n.name)
        for attribute in n.attribute:
            nested_graphs = []
            # HasField avoids recursing into the empty default graph of
            # attributes that do not hold a graph.
            if attribute.HasField("g"):
                nested_graphs.append(attribute.g)
            nested_graphs.extend(attribute.graphs)
            for nested in nested_graphs:
                add_prefix_graph(
                    nested,
                    prefix,
                    rename_nodes=rename_nodes,
                    rename_edges=rename_edges,
                    rename_inputs=rename_inputs,
                    rename_outputs=rename_outputs,
                    rename_initializers=rename_initializers,
                    rename_value_infos=rename_value_infos,
                    inplace=True,
                    name_map=name_map,
                )

    # Phase 3: apply the collected renames to this graph.
    for n in g.node:
        for i, output in enumerate(n.output):
            if output in name_map:
                n.output[i] = name_map[output]
        for i, input_ in enumerate(n.input):
            if input_ in name_map:
                n.input[i] = name_map[input_]

    for in_desc in g.input:
        if in_desc.name in name_map:
            in_desc.name = name_map[in_desc.name]
    for out_desc in g.output:
        if out_desc.name in name_map:
            out_desc.name = name_map[out_desc.name]

    for initializer in g.initializer:
        if initializer.name in name_map:
            initializer.name = name_map[initializer.name]
    for sparse_initializer in g.sparse_initializer:
        if sparse_initializer.values.name in name_map:
            sparse_initializer.values.name = name_map[sparse_initializer.values.name]
        if sparse_initializer.indices.name in name_map:
            sparse_initializer.indices.name = name_map[sparse_initializer.indices.name]

    for value_info in g.value_info:
        if value_info.name in name_map:
            value_info.name = name_map[value_info.name]

    return g
|
|
|
|
|
|
def add_prefix(
    model: ModelProto,
    prefix: str,
    rename_nodes: bool | None = True,
    rename_edges: bool | None = True,
    rename_inputs: bool | None = True,
    rename_outputs: bool | None = True,
    rename_initializers: bool | None = True,
    rename_value_infos: bool | None = True,
    rename_functions: bool | None = True,
    inplace: bool | None = False,
) -> ModelProto:
    """Adds a prefix to names of elements in a graph: nodes, edges, inputs, outputs,
    initializers, sparse initializer, value infos, and local functions.

    It can be used as a utility before merging graphs that have overlapping names.
    Empty names are not prefixed.

    Arguments:
        model (ModelProto): Model
        prefix (str): Prefix to be added to each name in the graph
        rename_nodes (bool): Whether to prefix node names
        rename_edges (bool): Whether to prefix node edge names
        rename_inputs (bool): Whether to prefix input names
        rename_outputs (bool): Whether to prefix output names
        rename_initializers (bool): Whether to prefix initializer and sparse initializer names
        rename_value_infos (bool): Whether to prefix value info names
        rename_functions (bool): Whether to prefix local function names
        inplace (bool): If True, mutates the model directly.
                        Otherwise, a copy will be created

    Returns:
        ModelProto

    Raises:
        ValueError: if ``model`` is not a ``ModelProto``
    """
    if type(model) is not ModelProto:
        raise ValueError("model argument is not an ONNX model")

    if not inplace:
        m = ModelProto()
        m.CopyFrom(model)
        model = m

    add_prefix_graph(
        model.graph,
        prefix,
        rename_nodes=rename_nodes,
        rename_edges=rename_edges,
        rename_inputs=rename_inputs,
        rename_outputs=rename_outputs,
        rename_initializers=rename_initializers,
        rename_value_infos=rename_value_infos,
        inplace=True,  # No need to create a copy, since it's a new model
    )

    if rename_functions:
        # original function name -> prefixed function name
        f_name_map = {}
        for f in model.functions:
            new_f_name = prefix + f.name
            f_name_map[f.name] = new_f_name
            f.name = new_f_name

        def _rename_calls(nodes) -> None:
            # Points nodes that call a renamed local function at its new name,
            # descending into graphs nested in node attributes so that calls
            # made from control-flow bodies are not left dangling.
            for n in nodes:
                if n.op_type in f_name_map:
                    n.op_type = f_name_map[n.op_type]
                for attribute in n.attribute:
                    if attribute.HasField("g"):
                        _rename_calls(attribute.g.node)
                    for nested in attribute.graphs:
                        _rename_calls(nested.node)

        # Adjust references to local functions in other local function
        # definitions
        for f in model.functions:
            _rename_calls(f.node)
        # Adjust references to local functions in the graph
        _rename_calls(model.graph.node)

    return model
|
|
|
|
|
|
def expand_out_dim_graph(
    graph: GraphProto,
    dim_idx: int,
    inplace: bool | None = False,
) -> GraphProto:
    """Inserts an extra dimension with extent 1 to each output in the graph.

    Inserts an Unsqueeze node for each output. It can be used as a utility before merging graphs,
    for example when the second one expects a batch dimension.

    Each original output ``X`` is renamed to ``X_collapsed_dim_{dim_idx}`` inside the
    graph, and a new output named ``X`` is produced by the Unsqueeze node. The axis
    is supplied to Unsqueeze through a Constant node added to the graph.

    Arguments:
        graph (GraphProto): Graph
        dim_idx (int): Index of the dimension to be inserted.
                       A negative value means counting dimensions from the back.
        inplace (bool): If True, mutates the model directly.
                        Otherwise, a copy will be created

    Returns:
        GraphProto

    Raises:
        ValueError: if ``graph`` is not a ``GraphProto``
    """
    if type(graph) is not GraphProto:
        raise ValueError("graph argument is not an ONNX graph")

    if not inplace:
        g = GraphProto()
        g.CopyFrom(graph)
    else:
        g = graph

    orig_out_names = {output.name for output in g.output}
    collapsed_suffix = f"_collapsed_dim_{dim_idx}"

    # Move the producers (and any internal consumers) of the original outputs
    # onto the "collapsed" names, freeing the original names for the Unsqueeze
    # results.
    for n in g.node:
        for i, out in enumerate(n.output):
            if out in orig_out_names:
                n.output[i] = out + collapsed_suffix
        for i, inp in enumerate(n.input):
            if inp in orig_out_names:
                n.input[i] = inp + collapsed_suffix

    # Single constant holding the axis, shared by every Unsqueeze node.
    expand_dim_k = g.name + "_expand_out_dim_idx"
    g.node.append(
        helper.make_node(
            "Constant",
            inputs=[],
            outputs=[expand_dim_k],
            name=f"{expand_dim_k}-constant",
            value=helper.make_tensor(
                name=f"{expand_dim_k}-value",
                data_type=TensorProto.INT64,
                dims=[
                    1,
                ],
                vals=[
                    dim_idx,
                ],
            ),
        )
    )

    # Replace every output: pop the old description from the front and append
    # the expanded one at the back, which keeps the original output order.
    for _ in range(len(g.output)):
        o = g.output.pop(0)
        prev_output = o.name + collapsed_suffix
        g.node.append(
            helper.make_node(
                "Unsqueeze",
                inputs=[prev_output, expand_dim_k],
                outputs=[o.name],
                name=f"unsqueeze-{o.name}",
            )
        )

        tensor_type = o.type.tensor_type
        if tensor_type.HasField("shape"):
            # Keep symbolic (dim_param) and unknown dimensions as they are,
            # instead of collapsing them to the default dim_value of 0.
            new_shape = []
            for d in tensor_type.shape.dim:
                which = d.WhichOneof("value")
                new_shape.append(getattr(d, which) if which else None)
            # Unsqueeze resolves a negative axis against the output rank
            # (input rank + 1), unlike list.insert which resolves it against
            # the current length; e.g. -1 must append the new dimension.
            insert_at = dim_idx if dim_idx >= 0 else dim_idx + len(new_shape) + 1
            new_shape.insert(insert_at, 1)
        else:
            # Unknown rank stays unknown.
            new_shape = None

        g.output.append(
            helper.make_tensor_value_info(o.name, tensor_type.elem_type, new_shape)
        )
    return g
|
|
|
|
|
|
def expand_out_dim(
    model: ModelProto,
    dim_idx: int,
    inplace: bool | None = False,
) -> ModelProto:
    """Adds a dimension of extent 1 to every output of the model's graph.

    This is the model-level counterpart of ``expand_out_dim_graph``: an Unsqueeze
    node is inserted for each graph output. It is useful before merging models,
    for example when the second one expects a batch dimension.

    Arguments:
        model (ModelProto): Model
        dim_idx (int): Index of the dimension to be inserted.
                       A negative value means counting dimensions from the back.
        inplace (bool): If True, mutates the model directly.
                        Otherwise, a copy will be created

    Returns:
        ModelProto

    Raises:
        ValueError: if ``model`` is not a ``ModelProto``
    """
    if type(model) is not ModelProto:
        raise ValueError("model argument is not an ONNX model")

    target = model
    if not inplace:
        # Work on a private copy so the caller's model is left untouched.
        target = ModelProto()
        target.CopyFrom(model)

    # The graph can be edited in place: ``target`` is either the caller's model
    # (inplace requested) or the fresh copy made above.
    expand_out_dim_graph(target.graph, dim_idx, inplace=True)
    return target
|