I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,3 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

View File

@ -0,0 +1,155 @@
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
# A library and utility for drawing ONNX nets. Most of this implementation has
# been borrowed from the caffe2 implementation
# https://github.com/pytorch/pytorch/blob/v2.3.1/caffe2/python/net_drawer.py
#
# The script takes two required arguments:
#   --input: a path to a serialized ModelProto .pb file.
#   --output: a path to write a dot file representation of the graph.
#
# Given this dot file representation, you can, for example, export it to svg
# with the graphviz `dot` utility, like so:
#
# $ dot -Tsvg my_output.dot -o my_output.svg
from __future__ import annotations
import argparse
import json
from collections import defaultdict
from typing import Any, Callable
import pydot
from onnx import GraphProto, ModelProto, NodeProto
# Graphviz attributes passed to pydot for every operator node
# (a filled box with white text).
OP_STYLE = {
"shape": "box",
"color": "#0F9D58",
"style": "filled",
"fontcolor": "#FFFFFF",
}
# Graphviz attributes passed to pydot for the nodes created for operator
# inputs and outputs.
BLOB_STYLE = {"shape": "octagon"}
# Type of a callable that receives an operator and its index in the graph and
# returns the pydot node representing it.
_NodeProducer = Callable[[NodeProto, int], pydot.Node]
def _escape_label(name: str) -> str:
# json.dumps is poor man's escaping
return json.dumps(name)
def _form_and_sanitize_docstring(s: str) -> str:
    """Build a ``javascript:alert(...)`` URL whose argument is the text *s*.

    The text is escaped with ``_escape_label``; double quotes are then replaced
    by single quotes and the characters ``<`` and ``>`` are removed.
    """
    sanitized = _escape_label(s)
    for old, new in (('"', "'"), ("<", ""), (">", "")):
        sanitized = sanitized.replace(old, new)
    return f"javascript:alert({sanitized})"
def GetOpNodeProducer(  # noqa: N802
    embed_docstring: bool = False, **kwargs: Any
) -> _NodeProducer:
    """Return a function that converts an operator into a pydot node.

    Args:
        embed_docstring: when true, the operator's ``doc_string`` is attached
            to the node as a ``javascript:alert`` URL.
        **kwargs: attributes forwarded to every ``pydot.Node`` created.

    Returns:
        A callable taking ``(op, op_id)`` and returning the pydot node.
    """

    def really_get_op_node(op: NodeProto, op_id: int) -> pydot.Node:
        # First line: "<name>/<op_type> (op#<id>)", the name part being
        # omitted for unnamed operators.
        title = f"{op.op_type} (op#{op_id})"
        if op.name:
            title = f"{op.name}/{title}"
        # One extra line per input, then one per output.
        lines = [title]
        lines.extend(f" input{i} {input_}" for i, input_ in enumerate(op.input))
        lines.extend(f" output{i} {output}" for i, output in enumerate(op.output))
        node = pydot.Node("\n".join(lines), **kwargs)
        if embed_docstring:
            node.set_URL(_form_and_sanitize_docstring(op.doc_string))
        return node

    return really_get_op_node
def GetPydotGraph(  # noqa: N802
    graph: GraphProto,
    name: str | None = None,
    rankdir: str = "LR",
    node_producer: _NodeProducer | None = None,
    embed_docstring: bool = False,
) -> pydot.Dot:
    """Convert an ONNX graph into a pydot graph.

    Every operator becomes one node; every input/output name becomes a blob
    node connected to the operators consuming or producing it.

    Args:
        graph: the graph to draw.
        name: name given to the pydot graph.
        rankdir: rank direction of the pydot graph.
        node_producer: callable building the node of an operator; when omitted
            a producer using ``OP_STYLE`` is created.
        embed_docstring: forwarded to the default node producer only.

    Returns:
        The populated ``pydot.Dot`` graph.
    """
    if node_producer is None:
        node_producer = GetOpNodeProducer(embed_docstring=embed_docstring, **OP_STYLE)

    dot = pydot.Dot(name, rankdir=rankdir)
    # Latest blob node known for each value name.
    blobs: dict[str, pydot.Node] = {}
    # Number of times a value name has been re-produced; used to keep the
    # pydot node identifiers distinct while the displayed label stays the same.
    versions: dict[str, int] = defaultdict(int)

    def make_blob(blob_name: str) -> pydot.Node:
        return pydot.Node(
            _escape_label(blob_name + str(versions[blob_name])),
            label=_escape_label(blob_name),
            **BLOB_STYLE,
        )

    for op_id, op in enumerate(graph.node):
        op_node = node_producer(op, op_id)
        dot.add_node(op_node)

        for input_name in op.input:
            blob = blobs.get(input_name)
            if blob is None:
                blob = make_blob(input_name)
                blobs[input_name] = blob
            dot.add_node(blob)
            dot.add_edge(pydot.Edge(blob, op_node))

        for output_name in op.output:
            # A name already seen gets a new version so that a fresh node is
            # created for this output.
            if output_name in blobs:
                versions[output_name] += 1
            blob = make_blob(output_name)
            blobs[output_name] = blob
            dot.add_node(blob)
            dot.add_edge(pydot.Edge(op_node, blob))

    return dot
def main() -> None:
    """Command-line entry point: draw a serialized ONNX model as a dot file.

    Reads the ``ModelProto`` stored at ``--input``, converts its graph with
    ``GetPydotGraph`` and writes the dot representation to ``--output``.
    """
    parser = argparse.ArgumentParser(description="ONNX net drawer")
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="The input protobuf file.",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        # The script writes a dot file (see write_dot below), not a protobuf.
        help="The output dot file.",
    )
    parser.add_argument(
        "--rankdir",
        type=str,
        default="LR",
        help="The rank direction of the pydot graph.",
    )
    parser.add_argument(
        "--embed_docstring",
        action="store_true",
        help="Embed docstring as javascript alert. Useful for SVG format.",
    )
    args = parser.parse_args()

    model = ModelProto()
    with open(args.input, "rb") as fid:
        content = fid.read()
    model.ParseFromString(content)

    pydot_graph = GetPydotGraph(
        model.graph,
        name=model.graph.name,
        rankdir=args.rankdir,
        node_producer=GetOpNodeProducer(
            embed_docstring=args.embed_docstring, **OP_STYLE
        ),
    )
    pydot_graph.write_dot(args.output)


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,425 @@
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import numpy as np
from onnx import (
AttributeProto,
FunctionProto,
GraphProto,
ModelProto,
NodeProto,
SparseTensorProto,
TensorProto,
)
from onnx.helper import (
make_attribute,
make_function,
make_graph,
make_model,
make_node,
make_tensor,
make_tensor_value_info,
set_model_props,
tensor_dtype_to_np_dtype,
)
from onnx.numpy_helper import from_array
def _replace_constant(
    node: NodeProto, threshold: int, value_constant_of_shape: float
) -> list[NodeProto]:
    """Replace a large *Constant* node by nodes producing a dummy constant.

    A Constant holding a tensor with more than ``threshold`` elements is
    replaced by a small *Constant* holding the shape followed by a
    *ConstantOfShape* producing a tensor of the same shape and type filled
    with ``value_constant_of_shape``.

    Args:
        node: the node to process, its type must be ``Constant``.
        threshold: tensors with at most this number of elements are kept.
        value_constant_of_shape: fill value of the dummy constant.

    Returns:
        ``[node]`` when nothing is replaced, otherwise the two new nodes.

    Raises:
        TypeError: if the node is not a *Constant*.
        NotImplementedError: if the constant is defined by an attribute other
            than ``value`` (``sparse_value`` included).
    """
    if node.op_type != "Constant":
        raise TypeError(f"Node type must be 'Constant' not {node.op_type!r}.")
    for att in node.attribute:
        if att.name == "sparse_value":
            raise NotImplementedError(
                f"This feature is not yet implemented for a sparse constant "
                f"(node name={node.name!r})."
            )
        if att.name != "value":
            raise NotImplementedError(
                f"Replacement of constant with attribute {att.name!r}"
            )
        value = att.t
        dims = value.dims
        size = np.prod(dims)
        if size <= threshold:
            return [node]
        # The tensor stored in a Constant node is often unnamed. Falling back
        # on the node output keeps the name of the shape tensor distinct when
        # several unnamed constants are replaced in the same graph.
        base_name = value.name or (node.output[0] if node.output else "")
        new_name = f"{base_name}__SHAPE"
        init = from_array(np.array(list(dims), dtype=np.int64), name=new_name)
        dtype = tensor_dtype_to_np_dtype(value.data_type)
        node_shape = make_node(
            "Constant",
            [],
            [new_name],
            value=init,
        )
        new_node = make_node(
            "ConstantOfShape",
            [new_name],
            node.output,
            value=from_array(np.array([value_constant_of_shape], dtype=dtype)),
        )
        return [node_shape, new_node]
    return [node]
def _replace_constant_of_shape_with_range(
    onx: GraphProto | FunctionProto,
) -> GraphProto | FunctionProto:
    """Replace every *ConstantOfShape* by nodes based on *Range*.

    Each ConstantOfShape becomes ``Reshape(Range(0, N, 1) / N, shape)`` where
    ``N`` is the product of the requested shape, so that the result is not a
    constant tensor. The function does not recurse into subgraphs; recursion
    is handled by *replace_initializer_by_constant_of_shape*.

    Args:
        onx: the graph or function to rewrite.

    Returns:
        A new graph or function, of the same kind as the input.

    Raises:
        TypeError: if ``onx`` is neither a GraphProto nor a FunctionProto.
    """
    if not isinstance(onx, (GraphProto, FunctionProto)):
        raise TypeError(f"Not implemented for type {type(onx)}.")

    # Every name consumed or produced by a node is considered taken.
    taken = set()
    for node in onx.node:
        taken.update(node.input)
        taken.update(node.output)

    def _reserve(prefix):
        # Returns prefix, or prefix_2, prefix_3, ... whichever is free first,
        # and marks it as taken.
        candidate = prefix
        suffix = 2
        while candidate in taken:
            candidate = f"{prefix}_{suffix}"
            suffix += 1
        taken.add(candidate)
        return candidate

    cst0 = make_node("Constant", [], [_reserve("zero")], value_int=0)
    cst1 = make_node("Constant", [], [_reserve("one")], value_int=1)
    zero_name = cst0.output[0]
    one_name = cst1.output[0]

    # Both constants come first, then the original nodes in their original
    # order, every ConstantOfShape being expanded in place.
    nodes = [cst0, cst1]
    for node in onx.node:
        if node.op_type != "ConstantOfShape":
            nodes.append(node)
            continue

        shape = node.input[0]
        # Names are reserved in a fixed order: N, RANGE, RANGEf, Nf, FLAT.
        count_name = _reserve(f"{shape}_N")
        count = make_node("ReduceProd", [shape], [count_name])
        range_name = _reserve(f"{shape}_RANGE")
        arange = make_node("Range", [zero_name, count_name, one_name], [range_name])

        # The target type comes from the fill tensor when the node has
        # exactly one attribute, float otherwise.
        if len(node.attribute) == 1:
            to = node.attribute[0].t.data_type
        else:
            to = TensorProto.FLOAT

        range_cast_name = _reserve(f"{shape}_RANGEf")
        range_cast = make_node("Cast", [range_name], [range_cast_name], to=to)
        count_cast_name = _reserve(f"{shape}_Nf")
        count_cast = make_node("Cast", [count_name], [count_cast_name], to=to)
        flat_name = _reserve(f"{shape}_FLAT")
        ratio = make_node("Div", [range_cast_name, count_cast_name], [flat_name])
        reshaped = make_node("Reshape", [flat_name, shape], node.output)

        nodes.extend([count, arange, range_cast, count_cast, ratio, reshaped])

    if isinstance(onx, GraphProto):
        return make_graph(
            nodes,
            onx.name,
            onx.input,
            onx.output,
            initializer=onx.initializer,
            sparse_initializer=onx.sparse_initializer,
        )
    return make_function(
        onx.domain,
        onx.name,
        onx.input,
        onx.output,
        nodes,
        opset_imports=onx.opset_import,
    )
def _replace_constant_of_shape_value(
    onx: GraphProto | FunctionProto, value_constant_of_shape: float
) -> GraphProto | FunctionProto:
    """Replace the fill value of every *ConstantOfShape* node.

    Args:
        onx: the graph or function to rewrite.
        value_constant_of_shape: new fill value of every ConstantOfShape.

    Returns:
        A new graph or function, of the same kind as the input.

    Raises:
        TypeError: if ``onx`` is neither a GraphProto nor a FunctionProto.
    """
    if not isinstance(onx, (GraphProto, FunctionProto)):
        raise TypeError(f"Not implemented for type {type(onx)}.")

    nodes = []
    for node in onx.node:
        if node.op_type != "ConstantOfShape":
            nodes.append(node)
            continue
        if node.attribute:
            att_name = node.attribute[0].name
            tensor_name = node.attribute[0].t.name
            data_type = node.attribute[0].t.data_type
        else:
            # The attribute of ConstantOfShape is optional; a node without it
            # is treated as a float constant, as the Range-based replacement
            # in this module does.
            att_name = "value"
            tensor_name = ""
            data_type = TensorProto.FLOAT
        new_tensor = make_tensor(
            tensor_name, data_type, [1], [value_constant_of_shape]
        )
        # An empty name is ignored by make_node, a real one is preserved.
        new_node = make_node(
            "ConstantOfShape", node.input, node.output, name=node.name
        )
        new_node.attribute.append(make_attribute(att_name, value=new_tensor))
        nodes.append(new_node)

    if isinstance(onx, GraphProto):
        return make_graph(
            nodes,
            onx.name,
            onx.input,
            onx.output,
            initializer=onx.initializer,
            sparse_initializer=onx.sparse_initializer,
        )
    return make_function(
        onx.domain,
        onx.name,
        onx.input,
        onx.output,
        nodes,
        opset_imports=onx.opset_import,
    )
def replace_initializer_by_constant_of_shape(  # noqa: PLR0911
    onx: FunctionProto | GraphProto | ModelProto,
    threshold: int = 128,
    ir_version: int | None = None,
    use_range: bool = False,
    value_constant_of_shape: float = 0.5,
):
    """Replace initializers or constant nodes by *ConstantOfShape* nodes to reduce the size.

    This reduces the cost of writing a unit test about a specific graph
    structure.

    Args:
        onx: ModelProto, GraphProto or FunctionProto to process
        threshold: every initializer under this threshold is not
            impacted
        ir_version: initializer must be specified as input for
            `ir_version <= 3`, this must be specified if onx is
            :class:`FunctionProto` or :class:`GraphProto`
        use_range: if uses operator *Range* instead of *ConstantOfShape*
            to avoid constant tensors
        value_constant_of_shape: value to use as a value for all nodes
            *ConstantOfShape*, a high value may produce nan or inf
            predictions

    Returns:
        The modified proto, of the same kind as ``onx``. The input itself is
        returned when nothing has to be changed.

    Raises:
        TypeError: if ``onx`` is not a ModelProto, GraphProto or FunctionProto.
        NotImplementedError: if a large sparse initializer or an unsupported
            constant is found.
        RuntimeError: if the default opset of a model is too old for the
            operators inserted by the replacement.

    The function is designed so that it can be reapplied on a modified model
    and either replace *ConstantOfShape* with *Range* operators, or replace the
    fill value of every *ConstantOfShape*.
    """

    def _post_process(proto):
        # Final pass shared by the function and graph branches.
        if use_range:
            return _replace_constant_of_shape_with_range(proto)
        if value_constant_of_shape != 1:
            return _replace_constant_of_shape_value(proto, value_constant_of_shape)
        return proto

    if isinstance(onx, FunctionProto):
        modified = False
        function_nodes: list[NodeProto] = []
        for node in onx.node:
            if node.op_type != "Constant":
                function_nodes.append(node)
                continue
            cst_nodes = _replace_constant(node, threshold, value_constant_of_shape)
            # Two nodes are returned only when the constant was replaced.
            if len(cst_nodes) == 2:  # noqa: PLR2004
                modified = True
            function_nodes.extend(cst_nodes)
        if not modified:
            return _post_process(onx)
        return _post_process(
            make_function(
                onx.domain,
                onx.name,
                onx.input,
                onx.output,
                function_nodes,
                opset_imports=onx.opset_import,
            )
        )

    if isinstance(onx, ModelProto):
        model_ir_version = ir_version or onx.ir_version
        new_graph = replace_initializer_by_constant_of_shape(
            onx.graph,
            ir_version=model_ir_version,
            threshold=threshold,
            use_range=use_range,
            value_constant_of_shape=value_constant_of_shape,
        )
        new_functions = [
            replace_initializer_by_constant_of_shape(
                f,
                threshold=threshold,
                ir_version=model_ir_version,
                use_range=use_range,
                value_constant_of_shape=value_constant_of_shape,
            )
            for f in onx.functions
        ]
        model = make_model(
            new_graph,
            functions=new_functions,
            producer_name=onx.producer_name,
            producer_version=onx.producer_version,
            ir_version=model_ir_version,
            doc_string=onx.doc_string,
            domain=onx.domain,
            model_version=onx.model_version,
        )
        if len(onx.metadata_props) > 0:  # pragma: no cover
            values = {p.key: p.value for p in onx.metadata_props}
            set_model_props(model, values)

        # make_model fills opset_import on its own; the original opsets are
        # restored after checking they support the inserted operators.
        del model.opset_import[:]
        for oimp in onx.opset_import:
            op_set = model.opset_import.add()
            if oimp.domain == "" and oimp.version < 11 and use_range:  # noqa: PLR2004
                raise RuntimeError(
                    f"Range was introduced in opset 11 but opset is {oimp.version}."
                )
            if oimp.domain == "" and oimp.version < 9:  # noqa: PLR2004
                raise RuntimeError(
                    f"ConstantOfShape was introduced in "
                    f"opset 9 but opset is {oimp.version}."
                )
            op_set.domain = oimp.domain
            op_set.version = oimp.version
        return model

    if not isinstance(onx, GraphProto):
        raise TypeError(f"onx should be a GraphProto at this stage not {type(onx)}.")

    n_modifications = 0
    new_nodes = []
    removed = set()
    additional_inputs = []
    new_inits: list[TensorProto] = []
    for init in onx.initializer:
        dims = tuple(init.dims)
        size = np.prod(dims)
        if size <= threshold:
            new_inits.append(init)
            continue
        n_modifications += 1
        # The large initializer is replaced by a small one holding its shape
        # and a ConstantOfShape node producing the original name.
        new_name = f"{init.name}__SHAPE"
        new_inits.append(
            from_array(np.array(list(dims), dtype=np.int64), name=new_name)
        )
        dtype = tensor_dtype_to_np_dtype(init.data_type)
        node = make_node(
            "ConstantOfShape",
            [new_name],
            [init.name],
            # Uses the requested fill value, like the replacement of Constant
            # nodes, instead of a hard-coded 0.5.
            value=from_array(np.array([value_constant_of_shape], dtype=dtype)),
        )
        new_nodes.append(node)
        removed.add(init.name)
        if ir_version is not None and ir_version <= 3:  # noqa: PLR2004
            additional_inputs.append(
                make_tensor_value_info(new_name, TensorProto.INT64, [len(dims)])
            )

    new_sparse_inits: list[SparseTensorProto] = []
    for sp_init in onx.sparse_initializer:
        dims = tuple(sp_init.dims)
        size = np.prod(dims)
        if size <= threshold:
            new_sparse_inits.append(sp_init)
            continue
        raise NotImplementedError(
            f"This feature is not yet implemented for a sparse initializer "
            f"(indices.name={sp_init.indices.name!r}, "
            f"values.name={sp_init.values.name!r})."
        )

    for node in onx.node:
        if node.op_type == "Constant":
            shape_nodes = _replace_constant(node, threshold, value_constant_of_shape)
            if len(shape_nodes) == 2:  # noqa: PLR2004
                n_modifications += 1
            new_nodes.extend(shape_nodes)
            continue

        # Subgraphs stored in attributes are processed recursively.
        modified = False
        atts = []
        for att in node.attribute:
            if (
                att.type == AttributeProto.GRAPH
                and hasattr(att, "g")
                and att.g is not None
            ):
                g = replace_initializer_by_constant_of_shape(
                    att.g,
                    threshold=threshold,
                    ir_version=ir_version,
                    use_range=use_range,
                    value_constant_of_shape=value_constant_of_shape,
                )
                if g is not att.g:
                    modified = True
                    att = make_attribute(att.name, g)  # noqa: PLW2901
            atts.append(att)
        if modified:
            # name, doc_string and domain are carried over so that the rebuilt
            # node still refers to the same operator.
            new_node = make_node(
                node.op_type,
                node.input,
                node.output,
                name=node.name or None,
                doc_string=node.doc_string or None,
                domain=node.domain or None,
            )
            new_node.attribute.extend(atts)
            new_nodes.append(new_node)
            n_modifications += 1
        else:
            new_nodes.append(node)

    if n_modifications == 0:
        return _post_process(onx)

    graph = make_graph(
        new_nodes,
        onx.name,
        # A replaced initializer is now produced by a node and cannot remain
        # a graph input.
        [i for i in onx.input if i.name not in removed] + additional_inputs,
        onx.output,
        initializer=new_inits,
        sparse_initializer=new_sparse_inits,
    )
    return _post_process(graph)

View File

@ -0,0 +1,98 @@
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from typing import Any
import onnx.checker
from onnx import ModelProto, ValueInfoProto
def update_inputs_outputs_dims(
    model: ModelProto,
    input_dims: dict[str, list[Any]],
    output_dims: dict[str, list[Any]],
) -> ModelProto:
    """Update the dimensions of the inputs and outputs of a model.

    The dimension sizes of the model's inputs and outputs are set to the
    values provided in input_dims and output_dims. If a provided dim value is
    negative, a unique dim_param is generated for that dimension.

    Example. If we have the following shape for inputs and outputs:

    * shape(input_1) = ('b', 3, 'w', 'h')
    * shape(input_2) = ('b', 4)
    * shape(output) = ('b', 'd', 5)

    The parameters can be provided as:

    ::

        input_dims = {
            "input_1": ['b', 3, 'w', 'h'],
            "input_2": ['b', 4],
        }

        output_dims = {
            "output": ['b', -1, 5]
        }

    Putting it together:

    ::

        model = onnx.load('model.onnx')
        updated_model = update_inputs_outputs_dims(model, input_dims, output_dims)
        onnx.save(updated_model, 'model.onnx')
    """
    graph = model.graph

    # dim_param values already used by inputs, outputs and value_info;
    # generated names must not collide with them.
    known_dim_params: set[str] = set()
    for value_infos in (graph.input, graph.output, graph.value_info):
        for info in value_infos:
            for existing in info.type.tensor_type.shape.dim:
                if existing.HasField("dim_param"):
                    known_dim_params.add(existing.dim_param)  # type: ignore

    def apply_dim(value_info: ValueInfoProto, dim: Any, axis: int, owner: str) -> None:
        target = value_info.type.tensor_type.shape.dim[axis]
        if isinstance(dim, str):
            target.dim_param = dim
            return
        if not isinstance(dim, int):
            raise ValueError(  # noqa: TRY004
                f"Only int or str is accepted as dimension value, incorrect type: {type(dim)}"
            )
        if dim < 0:
            # A negative value requests a generated symbolic dimension.
            generated = owner + "_" + str(axis)
            if generated in known_dim_params:
                raise ValueError(
                    f"Unable to generate unique dim_param for axis {axis} of {owner}. Please manually provide a dim_param value."
                )
            target.dim_param = generated
            return
        # A fixed size must agree with a size already stored in the model.
        if target.HasField("dim_value") and target.dim_value != dim:
            raise ValueError(
                f"Unable to set dimension value to {dim} for axis {axis} of {owner}. Contradicts existing dimension value {target.dim_value}."
            )
        target.dim_value = dim

    # Inputs are processed first, then outputs.
    for value_infos, requested in ((graph.input, input_dims), (graph.output, output_dims)):
        for value_info in value_infos:
            owner = value_info.name
            for axis, dim in enumerate(requested[owner]):
                apply_dim(value_info, dim, axis, owner)

    onnx.checker.check_model(model)
    return model