# NOTE(review): the lines below are web-UI file-listing residue (path,
# timestamp, size) accidentally captured with the source; commented out so
# the module remains valid Python.
# Files
# Reinforced-Learning-Godot/rl/Lib/site-packages/onnx/test/inference_function_test.py
# 2024-10-30 22:14:35 +01:00
#
# 299 lines
# 10 KiB
# Python
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) ONNX Project Contributors
from __future__ import annotations
import unittest
import numpy as np
import onnx
from onnx import TensorProto, TypeProto
from onnx.checker import ValidationError
from onnx.defs import OpSchema, get_all_schemas_with_history, get_schema
from onnx.helper import (
make_graph,
make_node,
make_opsetid,
make_tensor_type_proto,
make_tensor_value_info,
)
from onnx.numpy_helper import from_array
from onnx.shape_inference import InferenceError, infer_node_outputs
def _latest_schema(op_name: str) -> OpSchema:
    """Return the highest-versioned default-domain schema for *op_name*."""
    history = [
        schema
        for schema in get_all_schemas_with_history()
        if schema.name == op_name and schema.domain == ""
    ]
    return max(history, key=lambda schema: schema.since_version)


# Latest default-domain op schemas exercised by the tests below.
ADD_SCHEMA = _latest_schema("Add")
RESHAPE_SCHEMA = _latest_schema("Reshape")
CLIP_SCHEMA = _latest_schema("Clip")
def _to_tensor_types(
    tensor_types: dict[str, tuple[int, tuple[int | str | None, ...]]]
) -> dict[str, TypeProto]:
    """Build TypeProtos from ``name -> (elem_type, shape)`` descriptions.

    Shapes may mix concrete ints, symbolic dimension names (str) and
    unknown dimensions (None), matching ``make_tensor_type_proto``.
    """
    converted: dict[str, TypeProto] = {}
    for name, (elem_type, shape) in tensor_types.items():
        converted[name] = make_tensor_type_proto(elem_type, shape)
    return converted
def _run_case(
    schema: OpSchema,
    input_names: list[str],
    output_names: list[str],
    input_types: dict[str, TypeProto],
    input_data: dict[str, np.ndarray] | None = None,
) -> dict[str, TypeProto]:
    """Run single-node shape inference for an op described by *schema*.

    A node is synthesized from *schema* with the given input/output names;
    *input_data* (optional) supplies constant tensors for data-dependent
    inference (e.g. Reshape's target shape). Returns the inferred output
    TypeProtos keyed by output name.
    """
    data = input_data if input_data is not None else {}
    node = make_node(schema.name, input_names, output_names, domain=schema.domain)
    initializers = {name: from_array(array) for name, array in data.items()}
    return infer_node_outputs(schema, node, input_types, initializers)
class TestInferenceFunctionCall(unittest.TestCase):
    """Tests for ``infer_node_outputs`` and model-level ``infer_shapes``.

    Bare ``assert`` statements were replaced with ``self.assertEqual`` so the
    checks are not stripped under ``python -O`` and failures report both
    operands instead of a plain AssertionError.
    """

    def test_add_inference(self) -> None:
        """Add inference across broadcasting patterns.

        Shapes mix concrete dims, unknown dims (None) and symbolic names;
        mismatched symbolic dims ("x" vs "y") must infer as unknown (None).
        """
        cases = [
            (
                {"A": (TensorProto.FLOAT, ()), "B": (TensorProto.FLOAT, ())},
                {"C": (TensorProto.FLOAT, ())},
            ),
            (
                {
                    "A": (TensorProto.FLOAT, (None, 2)),
                    "B": (TensorProto.FLOAT, (2,)),
                },
                {"C": (TensorProto.FLOAT, (None, 2))},
            ),
            (
                {
                    "A": (TensorProto.FLOAT, (None, 2)),
                    "B": (TensorProto.FLOAT, (1, 2)),
                },
                {"C": (TensorProto.FLOAT, (None, 2))},
            ),
            (
                {
                    "A": (TensorProto.DOUBLE, ("n", "m")),
                    "B": (TensorProto.DOUBLE, (1, "n", "m")),
                },
                {"C": (TensorProto.DOUBLE, (1, "n", "m"))},
            ),
            (
                {
                    "A": (TensorProto.FLOAT, ("x", 2)),
                    "B": (TensorProto.FLOAT, ("y", 2)),
                },
                {"C": (TensorProto.FLOAT, (None, 2))},
            ),
        ]
        for ins, outs in cases:
            self.assertEqual(
                _run_case(ADD_SCHEMA, ["A", "B"], ["C"], _to_tensor_types(ins)),  # type: ignore[arg-type]
                _to_tensor_types(outs),  # type: ignore[arg-type]
            )

    def test_clip_inference_with_optional_input(self) -> None:
        """Clip with the optional 'min' input omitted (empty input name)."""
        input_names = ["X", "", "max"]
        output_names = ["Y"]
        input_types = _to_tensor_types(
            {"X": (TensorProto.FLOAT, (3, 4)), "max": (TensorProto.FLOAT, ())}
        )
        expected_output_types = _to_tensor_types({"Y": (TensorProto.FLOAT, (3, 4))})
        self.assertEqual(
            _run_case(CLIP_SCHEMA, input_names, output_names, input_types),
            expected_output_types,
        )

    def test_add_inference_raises_errors(self) -> None:
        """Malformed nodes/inputs surface the expected exception types."""
        # Too few inputs for the Add schema -> checker ValidationError.
        with self.assertRaises(ValidationError):
            _run_case(
                ADD_SCHEMA,
                ["A"],
                ["C"],
                _to_tensor_types({"A": (TensorProto.FLOAT, (3, 4))}),
            )
        # Mismatched input element types (FLOAT vs elem-type 2) -> ValidationError.
        with self.assertRaises(ValidationError):
            _run_case(
                ADD_SCHEMA,
                ["A", "B"],
                ["C"],
                _to_tensor_types({"A": (TensorProto.FLOAT, (3, 4)), "B": (2, (3, 4))}),
            )
        # Non-broadcastable shapes (2,4) vs (3,4) -> InferenceError.
        with self.assertRaises(InferenceError):
            _run_case(
                ADD_SCHEMA,
                ["A", "B"],
                ["C"],
                _to_tensor_types(
                    {
                        "A": (TensorProto.FLOAT, (2, 4)),
                        "B": (TensorProto.FLOAT, (3, 4)),
                    }
                ),
            )
        # Input "B" named on the node but missing from input_types -> KeyError.
        with self.assertRaises(KeyError):
            _run_case(
                ADD_SCHEMA,
                ["A", "B"],
                ["C"],
                _to_tensor_types({"A": (TensorProto.FLOAT, (3, 4))}),
            )

    def test_reshape_inference(self) -> None:
        """Reshape infers a concrete shape when the target tensor is constant."""
        self.assertEqual(
            _run_case(
                RESHAPE_SCHEMA,
                ["x", "t"],
                ["y"],
                _to_tensor_types(
                    {
                        "x": (TensorProto.FLOAT, (5, 4)),
                        "t": (TensorProto.INT64, (3,)),
                    }
                ),
                {"t": np.array([2, 2, 5], dtype=np.int64)},
            ),
            _to_tensor_types({"y": (TensorProto.FLOAT, (2, 2, 5))}),
        )

    def test_scan_inference_with_subgraph(self) -> None:
        """Scan propagates loop-state and scan-output types through its body."""
        seq_len = "sequence"
        input_size = 2
        loop_state_size = 3
        # Subgraph inputs use UNDEFINED types on purpose: inference fills them
        # in from the outer Scan node's inputs.
        input_value_infos = [
            make_tensor_value_info("loop_state_in", TensorProto.UNDEFINED, None),
            make_tensor_value_info("input", TensorProto.UNDEFINED, None),
            make_tensor_value_info("outer", TensorProto.UNDEFINED, None),
        ]
        output_value_infos = [
            make_tensor_value_info("loop_state_out", TensorProto.UNDEFINED, None),
            make_tensor_value_info("output", TensorProto.FLOAT, (seq_len, input_size)),
        ]
        subgraph = make_graph(
            [
                make_node("Identity", ["loop_state_in"], ["loop_state_out"]),
                make_node("Add", ["input", "outer"], ["output"]),
            ],
            "subgraph",
            input_value_infos,
            output_value_infos,
        )
        self.assertEqual(
            infer_node_outputs(
                get_schema("Scan", 9),
                make_node(
                    "Scan",
                    ["loop_state_orig", "scan_input", "scan_outer"],
                    ["loop_state_final", "scan_output"],
                    num_scan_inputs=1,
                    body=subgraph,
                ),
                _to_tensor_types(
                    {
                        "loop_state_orig": (TensorProto.FLOAT, (loop_state_size,)),
                        "scan_input": (TensorProto.FLOAT, (seq_len, input_size)),
                        "scan_outer": (TensorProto.FLOAT, (input_size,)),
                    }
                ),
                # Same as default value in Scan-9
                opset_imports=[make_opsetid("", 9)],
                ir_version=4,
            ),
            _to_tensor_types(
                {
                    "loop_state_final": (TensorProto.FLOAT, (loop_state_size,)),
                    "scan_output": (TensorProto.FLOAT, (seq_len, input_size)),
                }
            ),
        )

    def test_inference_with_conflow(self) -> None:
        """Control-flow model: non-strict inference passes, strict raises."""
        # NOTE: the misspelled "reault" is part of the original model text and
        # is kept verbatim; the parser only needs the names to be consistent.
        model_script = """
<
ir_version: 8,
opset_import: ["" : 18, "onnxscript.atenlib" : 1],
producer_name: "pytorch",
producer_version: "2.1.0"
>
torch_jit (float input_0) => (float reault, int64 index)
{
reault, index = onnxscript.atenlib.aten_min_dim <dim = 0, keepdim = 1> (input_0)
}
<
domain: "onnxscript.atenlib",
opset_import: ["" : 18]
>
aten_min_dim <dim>(self) => (result_7, indices_6)
{
tmp = Shape (self)
tmp_0 = Size (tmp)
tmp_1 = Constant <value = int64 tmp_1 {0}> ()
tmp_1_cast = CastLike (tmp_1, tmp_0)
tmp_2 = Equal (tmp_0, tmp_1_cast)
cond = Not (tmp_2)
indices_6, result_7 = If (cond) <
then_branch = thenGraph_4 () => ( indices, result) {
dim = Constant <value_int: int = @dim> ()
tmp_3 = Constant <value_ints = [-1]> ()
dims = Reshape (dim, tmp_3)
result = ReduceMin <keepdims: int = @keepdim> (self, dims)
indices = ArgMin <axis: int = @dim, keepdims: int = @keepdim> (self)
}, else_branch = elseGraph_4 () => ( indices_4, result_5) {
indices_4 = Constant <value_int = 0> ()
result_5 = Identity (self)
}
>
}
"""
        model = onnx.parser.parse_model(model_script)
        # Non-strict mode must tolerate the function-body inference failure...
        onnx.shape_inference.infer_shapes(model, strict_mode=False)
        # ...while strict mode surfaces it as an InferenceError.
        with self.assertRaises(onnx.shape_inference.InferenceError):
            onnx.shape_inference.infer_shapes(model, strict_mode=True)

    def test_inference_with_attribute(self) -> None:
        """Strict inference succeeds through a function with a ref attribute."""
        model_script = """
<
ir_version: 8,
opset_import: ["" : 18, "custom" : 1],
producer_name: "",
producer_version: "1.0"
>
MeanVarianceNormalization (float[N] x) => (float[M] y)
{
y = custom.custom_mvn <axes = [0]> (x)
}
<
domain: "custom",
opset_import: ["" : 18]
>
custom_mvn <axes>(X) => (Y)
{
Exponent = Constant <value = float {2.0}>()
Epsilon = Constant <value = float {1e-9}>()
axes = Constant <value_ints: ints = @axes>()
X_RM = ReduceMean (X, axes)
EX_squared = Pow (X_RM, Exponent)
X_squared = Pow (X, Exponent)
E_Xsquared = ReduceMean (X_squared, axes)
Variance = Sub (E_Xsquared, EX_squared)
STD = Sqrt (Variance)
X_variance = Sub (X, X_RM)
Processed_STD = Add (STD, Epsilon)
Y = Div (X_variance, Processed_STD)
}
"""
        model = onnx.parser.parse_model(model_script)
        # Must not raise even in strict mode.
        onnx.shape_inference.infer_shapes(model, strict_mode=True)
if __name__ == "__main__":
unittest.main()