426 lines
14 KiB
Python
426 lines
14 KiB
Python
# Copyright (c) ONNX Project Contributors
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
|
|
from onnx import (
|
|
AttributeProto,
|
|
FunctionProto,
|
|
GraphProto,
|
|
ModelProto,
|
|
NodeProto,
|
|
SparseTensorProto,
|
|
TensorProto,
|
|
)
|
|
from onnx.helper import (
|
|
make_attribute,
|
|
make_function,
|
|
make_graph,
|
|
make_model,
|
|
make_node,
|
|
make_tensor,
|
|
make_tensor_value_info,
|
|
set_model_props,
|
|
tensor_dtype_to_np_dtype,
|
|
)
|
|
from onnx.numpy_helper import from_array
|
|
|
|
|
|
def _replace_constant(
    node: NodeProto, threshold: int, value_constant_of_shape: float
) -> list[NodeProto]:
    """Replaces a large *Constant* node by nodes producing a dummy tensor of the same shape.

    A *Constant* holding a tensor with more than ``threshold`` elements is
    replaced by a small *Constant* holding the shape followed by a
    *ConstantOfShape* filled with ``value_constant_of_shape``.

    Args:
        node: node to inspect, its type must be *Constant*
        threshold: constants with at most this number of elements are
            returned unchanged
        value_constant_of_shape: fill value of the created *ConstantOfShape*,
            it is cast to the element type of the original tensor

    Returns:
        ``[node]`` when nothing is replaced, otherwise a list of two nodes
        ``[shape constant, ConstantOfShape]``

    Raises:
        TypeError: if *node* is not a *Constant*
        NotImplementedError: if the constant is sparse, or if it is stored
            in a non-tensor attribute and is bigger than ``threshold``
    """
    if node.op_type != "Constant":
        raise TypeError(f"Node type must be 'Constant' not {node.op_type!r}.")
    for att in node.attribute:
        if att.name == "sparse_value":
            raise NotImplementedError(
                f"This feature is not yet implemented for a sparse constant "
                f"(node name={node.name!r})."
            )
        if att.name == "value":
            value = att.t
            dims = value.dims
            size = np.prod(dims)
            if size <= threshold:
                return [node]
            # The tensor stored in the attribute is frequently unnamed. Falling
            # back on the node output keeps the shape name unique, otherwise
            # every replaced constant would produce an output called "__SHAPE".
            prefix = value.name or node.output[0]
            new_name = f"{prefix}__SHAPE"
            init = from_array(np.array(list(dims), dtype=np.int64), name=new_name)
            dtype = tensor_dtype_to_np_dtype(value.data_type)
            node_shape = make_node(
                "Constant",
                [],
                [new_name],
                value=init,
            )
            new_node = make_node(
                "ConstantOfShape",
                [new_name],
                node.output,
                value=from_array(np.array([value_constant_of_shape], dtype=dtype)),
            )
            return [node_shape, new_node]

        # Constants stored without a tensor: a scalar attribute holds one
        # element, a list attribute holds as many elements as the list.
        # Small ones are left untouched, this is needed to process a graph
        # already modified by *_replace_constant_of_shape_with_range* which
        # inserts constants defined with attribute *value_int*.
        if att.name in ("value_float", "value_int", "value_string"):
            n_elements = 1
        elif att.name == "value_floats":
            n_elements = len(att.floats)
        elif att.name == "value_ints":
            n_elements = len(att.ints)
        elif att.name == "value_strings":
            n_elements = len(att.strings)
        else:
            n_elements = None
        if n_elements is not None and n_elements <= threshold:
            return [node]
        raise NotImplementedError(
            f"Replacement of constant with attribute {att.name!r}"
        )
    return [node]
|
|
|
|
|
|
def _replace_constant_of_shape_with_range(
    onx: GraphProto | FunctionProto,
) -> GraphProto | FunctionProto:
    """Substitutes every *ConstantOfShape* node with a subgraph based on *Range*.

    Each *ConstantOfShape* becomes the sequence *ReduceProd*, *Range*,
    *Cast*, *Cast*, *Div*, *Reshape*: the values ``0, 1, ..., N - 1`` are
    cast to the element type of the replaced node, divided by ``N`` (the
    product of the shape) and reshaped with the original shape input.
    Two *Constant* nodes holding 0 and 1 are always added at the beginning.

    The function is not recursive, subgraphs are handled by
    *replace_initializer_by_constant_of_shape*.

    Args:
        onx: graph or function to modify

    Returns:
        a new graph or a new function

    Raises:
        TypeError: if *onx* is neither a GraphProto nor a FunctionProto
    """
    if not isinstance(onx, (GraphProto, FunctionProto)):
        raise TypeError(f"Not implemented for type {type(onx)}.")
    original_nodes = list(onx.node)

    # Every name already used by a node, new results must not clash with them.
    taken: set[str] = set()
    for current in original_nodes:
        taken.update(current.input)
        taken.update(current.output)

    def _unique(prefix):
        # Returns prefix if it is free, otherwise prefix_2, prefix_3, ...
        candidate = prefix
        counter = 2
        while candidate in taken:
            candidate = f"{prefix}_{counter}"
            counter += 1
        taken.add(candidate)
        return candidate

    zero = make_node("Constant", [], [_unique("zero")], value_int=0)
    one = make_node("Constant", [], [_unique("one")], value_int=1)

    rewritten = [zero, one]
    for current in original_nodes:
        if current.op_type != "ConstantOfShape":
            rewritten.append(current)
            continue

        shape_name = current.input[0]
        count = make_node("ReduceProd", [shape_name], [_unique(f"{shape_name}_N")])
        ramp = make_node(
            "Range",
            [zero.output[0], count.output[0], one.output[0]],
            [_unique(f"{shape_name}_RANGE")],
        )
        # The element type comes from the fill value when there is exactly
        # one attribute, it is float otherwise.
        target_type = (
            current.attribute[0].t.data_type
            if len(current.attribute) == 1
            else TensorProto.FLOAT
        )
        ramp_cast = make_node(
            "Cast", [ramp.output[0]], [_unique(f"{shape_name}_RANGEf")], to=target_type
        )
        count_cast = make_node(
            "Cast", [count.output[0]], [_unique(f"{shape_name}_Nf")], to=target_type
        )
        ratio = make_node(
            "Div",
            [ramp_cast.output[0], count_cast.output[0]],
            [_unique(f"{shape_name}_FLAT")],
        )
        reshaped = make_node("Reshape", [ratio.output[0], shape_name], current.output)
        rewritten.extend([count, ramp, ramp_cast, count_cast, ratio, reshaped])

    if isinstance(onx, GraphProto):
        return make_graph(
            rewritten,
            onx.name,
            onx.input,
            onx.output,
            initializer=onx.initializer,
            sparse_initializer=onx.sparse_initializer,
        )
    # The type was validated above: onx is a FunctionProto here.
    return make_function(
        onx.domain,
        onx.name,
        onx.input,
        onx.output,
        rewritten,
        opset_imports=onx.opset_import,
    )
|
|
|
|
|
|
def _replace_constant_of_shape_value(
    onx: GraphProto | FunctionProto, value_constant_of_shape: float
) -> GraphProto | FunctionProto:
    """Replaces the fill value of every node *ConstantOfShape*.

    The element type of every node is preserved, only the value changes.
    The function is not recursive, subgraphs are not modified.

    Args:
        onx: graph or function to modify
        value_constant_of_shape: new fill value, it is cast to the element
            type of every modified node

    Returns:
        a new graph or a new function

    Raises:
        TypeError: if *onx* is neither a GraphProto nor a FunctionProto
    """
    if not isinstance(onx, (GraphProto, FunctionProto)):
        raise TypeError(f"Not implemented for type {type(onx)}.")
    nodes = list(onx.node)

    for inode, node in enumerate(nodes):
        if node.op_type != "ConstantOfShape":
            continue
        if node.attribute:
            att_name = node.attribute[0].name
            tensor_name = node.attribute[0].t.name
            data_type = node.attribute[0].t.data_type
        else:
            # Attribute *value* is optional, the operator then produces
            # float zeros: the new value is stored as a float.
            att_name = "value"
            tensor_name = ""
            data_type = TensorProto.FLOAT
        # The value is cast to the element type first so that an integer
        # tensor receives an integer and not the raw float.
        fill = (
            np.array([value_constant_of_shape])
            .astype(tensor_dtype_to_np_dtype(data_type))
            .tolist()
        )
        new_tensor = make_tensor(tensor_name, data_type, [1], fill)
        new_node = make_node("ConstantOfShape", node.input, node.output)
        new_node.attribute.append(make_attribute(att_name, new_tensor))
        nodes[inode] = new_node

    if isinstance(onx, GraphProto):
        return make_graph(
            nodes,
            onx.name,
            onx.input,
            onx.output,
            initializer=onx.initializer,
            sparse_initializer=onx.sparse_initializer,
        )
    # The type was validated above: onx is a FunctionProto here.
    return make_function(
        onx.domain,
        onx.name,
        onx.input,
        onx.output,
        nodes,
        opset_imports=onx.opset_import,
    )
|
|
|
|
|
|
def _finalize_constant_of_shape(
    onx: FunctionProto | GraphProto,
    use_range: bool,
    value_constant_of_shape: float,
) -> FunctionProto | GraphProto:
    """Applies the post-processing shared by the graph and the function code paths."""
    if use_range:
        return _replace_constant_of_shape_with_range(onx)
    if value_constant_of_shape != 1:
        return _replace_constant_of_shape_value(onx, value_constant_of_shape)
    return onx


def replace_initializer_by_constant_of_shape(
    onx: FunctionProto | GraphProto | ModelProto,
    threshold: int = 128,
    ir_version: int | None = None,
    use_range: bool = False,
    value_constant_of_shape: float = 0.5,
) -> FunctionProto | GraphProto | ModelProto:
    """Replaces initializers or constant nodes by nodes *ConstantOfShape* to reduce the size.

    This reduces the cost of writing a unit test about a specific graph structure.

    Args:
        onx: ModelProto, GraphProto or FunctionProto
        threshold: every initializer or constant with at most this number
            of elements is left unchanged
        ir_version: initializers must be specified as inputs for
            `ir_version <= 3`, this must be specified if onx is
            :class:`FunctionProto` or :class:`GraphProto`
        use_range: if True, uses operator *Range* instead of *ConstantOfShape*
            to avoid constant tensors
        value_constant_of_shape: value to use as a value for all nodes
            *ConstantOfShape*, a high value may produce nan or inf
            predictions

    Returns:
        the modified object, of the same kind as *onx*

    Raises:
        TypeError: if *onx* is not one of the supported types
        NotImplementedError: if a sparse initializer or a sparse constant
            is bigger than the threshold
        RuntimeError: if the default opset of the model is older than the
            opset introducing *ConstantOfShape* (9) or *Range* (11, only
            checked when *use_range* is True)

    The function is designed so that it can be reapplied on a modified model
    and either replace *ConstantOfShape* with *Range* operators, or replace
    the fill value of every *ConstantOfShape*.
    """
    if isinstance(onx, FunctionProto):
        modified = False
        new_nodes: list[NodeProto] = []
        for node in onx.node:
            if node.op_type == "Constant":
                cst_nodes = _replace_constant(node, threshold, value_constant_of_shape)
                # Two nodes are returned only when the constant was replaced.
                if len(cst_nodes) == 2:  # noqa: PLR2004
                    modified = True
                new_nodes.extend(cst_nodes)
                continue
            new_nodes.append(node)
        function = onx
        if modified:
            function = make_function(
                onx.domain,
                onx.name,
                onx.input,
                onx.output,
                new_nodes,
                opset_imports=onx.opset_import,
            )
        return _finalize_constant_of_shape(
            function, use_range, value_constant_of_shape
        )

    if isinstance(onx, ModelProto):
        new_graph = replace_initializer_by_constant_of_shape(
            onx.graph,
            ir_version=ir_version or onx.ir_version,
            threshold=threshold,
            use_range=use_range,
            value_constant_of_shape=value_constant_of_shape,
        )
        new_functions = [
            replace_initializer_by_constant_of_shape(
                f,
                threshold=threshold,
                ir_version=ir_version or onx.ir_version,
                use_range=use_range,
                value_constant_of_shape=value_constant_of_shape,
            )
            for f in onx.functions
        ]
        model = make_model(
            new_graph,
            functions=new_functions,
            producer_name=onx.producer_name,
            producer_version=onx.producer_version,
            ir_version=ir_version or onx.ir_version,
            doc_string=onx.doc_string,
            domain=onx.domain,
            model_version=onx.model_version,
        )
        if len(onx.metadata_props) > 0:  # pragma: no cover
            values = {p.key: p.value for p in onx.metadata_props}
            set_model_props(model, values)

        # The opsets of the original model replace the ones make_model added.
        del model.opset_import[:]
        for oimp in onx.opset_import:
            op_set = model.opset_import.add()
            if oimp.domain == "" and oimp.version < 11 and use_range:  # noqa: PLR2004
                raise RuntimeError(
                    f"Range was introduced in opset 11 but opset is {oimp.version}."
                )
            if oimp.domain == "" and oimp.version < 9:  # noqa: PLR2004
                raise RuntimeError(
                    f"ConstantOfShape was introduced in "
                    f"opset 9 but opset is {oimp.version}."
                )
            op_set.domain = oimp.domain
            op_set.version = oimp.version
        return model

    if not isinstance(onx, GraphProto):
        raise TypeError(f"onx should be a GraphProto at this stage not {type(onx)}.")

    n_modifications = 0
    new_nodes = []
    removed = set()
    additional_inputs = []

    new_inits: list[TensorProto] = []
    for init in onx.initializer:
        dims = tuple(init.dims)
        size = np.prod(dims)
        if size <= threshold:
            new_inits.append(init)
            continue
        n_modifications += 1
        # The big initializer is replaced by a small one holding its shape
        # and a node ConstantOfShape producing the original name.
        new_name = f"{init.name}__SHAPE"
        new_inits.append(
            from_array(np.array(list(dims), dtype=np.int64), name=new_name)
        )
        dtype = tensor_dtype_to_np_dtype(init.data_type)
        node = make_node(
            "ConstantOfShape",
            [new_name],
            [init.name],
            # The requested value must be used here: the pass rewriting the
            # fill values is skipped when value_constant_of_shape is 1.
            value=from_array(np.array([value_constant_of_shape], dtype=dtype)),
        )
        new_nodes.append(node)
        removed.add(init.name)
        if ir_version is not None and ir_version <= 3:  # noqa: PLR2004
            # Initializers must also be declared as inputs for these versions.
            additional_inputs.append(
                make_tensor_value_info(new_name, TensorProto.INT64, [len(dims)])
            )

    new_sparse_inits: list[SparseTensorProto] = []
    for sp_init in onx.sparse_initializer:
        dims = tuple(sp_init.dims)
        size = np.prod(dims)
        if size <= threshold:
            new_sparse_inits.append(sp_init)
            continue
        raise NotImplementedError(
            f"This feature is not yet implemented for a sparse initializer "
            f"(indices.name={sp_init.indices.name!r}, "
            f"values.name={sp_init.values.name!r})."
        )

    for node in onx.node:
        if node.op_type == "Constant":
            shape_nodes = _replace_constant(node, threshold, value_constant_of_shape)
            if len(shape_nodes) == 2:  # noqa: PLR2004
                n_modifications += 1
            new_nodes.extend(shape_nodes)
            continue
        modified = False
        atts = []
        for att in node.attribute:
            new_att = att
            if att.type == AttributeProto.GRAPH:
                # Subgraphs (If, Loop, Scan, ...) are processed recursively.
                g = replace_initializer_by_constant_of_shape(
                    att.g,
                    threshold=threshold,
                    ir_version=ir_version,
                    use_range=use_range,
                    value_constant_of_shape=value_constant_of_shape,
                )
                if g is not att.g:
                    modified = True
                    new_att = make_attribute(att.name, g)
            atts.append(new_att)
        if modified:
            # name, domain and doc_string are forwarded so that the rebuilt
            # node still identifies the same operator.
            new_node = make_node(
                node.op_type,
                node.input,
                node.output,
                name=node.name,
                doc_string=node.doc_string,
                domain=node.domain,
            )
            new_node.attribute.extend(atts)
            new_nodes.append(new_node)
            n_modifications += 1
        else:
            new_nodes.append(node)

    graph = onx
    if n_modifications > 0:
        graph = make_graph(
            new_nodes,
            onx.name,
            [i for i in onx.input if i.name not in removed] + additional_inputs,
            onx.output,
            initializer=new_inits,
            sparse_initializer=new_sparse_inits,
        )
    return _finalize_constant_of_shape(graph, use_range, value_constant_of_shape)
|