# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import argparse
import json
import os
import shutil
import warnings

import onnx.backend.test.case.model as model_test
import onnx.backend.test.case.node as node_test
from onnx import ONNX_ML, TensorProto, numpy_helper
# Repository-relative locations: TOP_DIR is the directory containing this
# script; DATA_DIR (TOP_DIR/data) is the default output root for generated
# test data.
TOP_DIR = os.path.realpath(os.path.dirname(__file__))
DATA_DIR = os.path.join(TOP_DIR, "data")
def _serialize_value(value, value_info) -> bytes:
    """Serialize one test input/output ``value`` to protobuf bytes.

    The wire representation is selected from the declared type of the
    matching graph ``value_info`` entry: map, sequence, and optional values
    go through the corresponding ``numpy_helper`` converters; everything
    else must be a tensor, written either directly (already a
    ``TensorProto``) or via ``numpy_helper.from_array``.
    """
    value_type = value_info.type
    if value_type.HasField("map_type"):
        return numpy_helper.from_dict(value, value_info.name).SerializeToString()
    if value_type.HasField("sequence_type"):
        return numpy_helper.from_list(value, value_info.name).SerializeToString()
    if value_type.HasField("optional_type"):
        return numpy_helper.from_optional(value, value_info.name).SerializeToString()
    assert value_type.HasField("tensor_type")
    if isinstance(value, TensorProto):
        return value.SerializeToString()
    return numpy_helper.from_array(value, value_info.name).SerializeToString()


def generate_data(args: argparse.Namespace) -> None:
    """Generate serialized test data for the collected backend testcases.

    For each collected case a directory ``<output>/<kind>/<name>`` is
    (re)created. "real" cases only get a ``data.json`` metadata file; all
    other cases get the serialized model plus one ``test_data_set_<i>``
    directory per data set containing ``input_<j>.pb`` / ``output_<j>.pb``
    files. Emits a warning when stale node directories are detected and
    ``--clean`` was not given.
    """

    def prepare_dir(path: str) -> None:
        # Recreate `path` empty, removing any previous contents.
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)

    # Clean the output directory before generating data for node testcases.
    # It is used to check new generated data is correct in CIs.
    node_root = os.path.join(args.output, "node")
    # Count the pre-existing node testcase directories so we can warn later
    # about stale ones that this run did not regenerate.
    # BUG FIX: the original used `os.path.isfile(name)` on a bare listdir
    # entry, which (a) resolved against the CWD instead of node_root and
    # (b) tested for files although each testcase is a *directory* — so the
    # count was effectively always 0 and the staleness check was broken.
    if os.path.exists(node_root):
        original_dir_number = len(
            [
                name
                for name in os.listdir(node_root)
                if os.path.isdir(os.path.join(node_root, name))
            ]
        )
    else:
        original_dir_number = 0
    if args.clean and os.path.exists(node_root):
        for sub_dir in os.listdir(node_root):
            # ai.onnx.ml data can only be regenerated by an ONNX_ML build,
            # so non-ML builds must leave those directories untouched.
            if ONNX_ML or not sub_dir.startswith("test_ai_onnx_ml_"):
                shutil.rmtree(os.path.join(node_root, sub_dir))

    cases = model_test.collect_testcases()
    # If op_type is specified, only include those testcases including the
    # given operator; otherwise, include all of the testcases.
    if args.diff:
        cases += node_test.collect_diff_testcases()
    else:
        cases += node_test.collect_testcases(args.op_type)

    node_number = 0
    for case in cases:
        output_dir = os.path.join(args.output, case.kind, case.name)
        prepare_dir(output_dir)
        if case.kind == "node":
            node_number += 1
        if case.kind == "real":
            # "real" models are downloaded at test time; only record the
            # metadata needed to fetch and compare them.
            with open(os.path.join(output_dir, "data.json"), "w") as fi:
                json.dump(
                    {
                        "url": case.url,
                        "model_name": case.model_name,
                        "rtol": case.rtol,
                        "atol": case.atol,
                    },
                    fi,
                    sort_keys=True,
                )
        else:
            assert case.model
            with open(os.path.join(output_dir, "model.onnx"), "wb") as f:
                f.write(case.model.SerializeToString())
            assert case.data_sets
            for i, (inputs, outputs) in enumerate(case.data_sets):
                data_set_dir = os.path.join(output_dir, f"test_data_set_{i}")
                prepare_dir(data_set_dir)
                # Inputs and outputs share the same serialization logic;
                # only the graph value_info list and file prefix differ.
                for j, input_value in enumerate(inputs):
                    with open(
                        os.path.join(data_set_dir, f"input_{j}.pb"), "wb"
                    ) as f:
                        f.write(
                            _serialize_value(input_value, case.model.graph.input[j])
                        )
                for j, output_value in enumerate(outputs):
                    with open(
                        os.path.join(data_set_dir, f"output_{j}.pb"), "wb"
                    ) as f:
                        f.write(
                            _serialize_value(output_value, case.model.graph.output[j])
                        )

    if not args.clean and node_number != original_dir_number:
        # Message typo fixed: "cannot not" -> "cannot".
        warnings.warn(
            "There are some models under 'onnx/backend/test/data/node' which cannot"
            " be generated by the script from 'onnx/backend/test/case/node'. Please add"
            " '--clean' option for 'python onnx/backend/test/cmd_tools.py generate-data'"
            " to cleanup the existing directories and regenerate them.",
            Warning,
            stacklevel=2,
        )
def parse_args() -> argparse.Namespace:
    """Build the ``backend-test-tools`` CLI and parse the command line.

    One sub-command is registered, ``generate-data``; its handler function
    is stored on the parsed namespace as ``func`` so ``main`` can dispatch
    to it.
    """
    root = argparse.ArgumentParser("backend-test-tools")
    commands = root.add_subparsers()

    gen = commands.add_parser(
        "generate-data", help="convert testcases to test data."
    )
    gen.add_argument(
        "-c",
        "--clean",
        action="store_true",
        default=False,
        help="Clean the output directory before generating data for node testcases.",
    )
    gen.add_argument(
        "-o",
        "--output",
        default=DATA_DIR,
        help="output directory (default: %(default)s)",
    )
    gen.add_argument(
        "-t",
        "--op_type",
        default=None,
        help="op_type for test case generation. (generates test data for the specified op_type only.)",
    )
    gen.add_argument(
        "-d",
        "--diff",
        action="store_true",
        default=False,
        help="only generates test data for those changed files (compared to the main branch).",
    )
    gen.set_defaults(func=generate_data)

    return root.parse_args()
def main() -> None:
    """CLI entry point: parse arguments and run the selected sub-command."""
    options = parse_args()
    options.func(options)
# Run the CLI when this module is executed as a script.
if __name__ == "__main__":
    main()