I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,60 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <list>
#include <utility>
#include "gtest/gtest.h"
#include "onnx/common/path.h"
#ifdef _WIN32
// Only test clean_relative_path and normalize_separator on non-Windows
// because Windows has its own implementation for them from std::filesystem::path.
#else
using namespace ONNX_NAMESPACE;
namespace ONNX_NAMESPACE {
namespace Test {
TEST(PathTest, CleanRelativePathTest) {
// Already normal.
EXPECT_EQ(clean_relative_path("abc"), "abc");
EXPECT_EQ(clean_relative_path("abc/def"), "abc/def");
EXPECT_EQ(clean_relative_path("a/b/c"), "a/b/c");
EXPECT_EQ(clean_relative_path("."), ".");
EXPECT_EQ(clean_relative_path(".."), "..");
EXPECT_EQ(clean_relative_path("../.."), "../..");
EXPECT_EQ(clean_relative_path("../../abc"), "../../abc");
// Remove trailing slash
EXPECT_EQ(clean_relative_path("abc/"), "abc");
EXPECT_EQ(clean_relative_path("abc/def/"), "abc/def");
EXPECT_EQ(clean_relative_path("a/b/c/"), "a/b/c");
EXPECT_EQ(clean_relative_path("./"), ".");
EXPECT_EQ(clean_relative_path("../"), "..");
EXPECT_EQ(clean_relative_path("../../"), "../..");
// Remove doubled slash
EXPECT_EQ(clean_relative_path("abc//def//ghi"), "abc/def/ghi");
EXPECT_EQ(clean_relative_path("abc///"), "abc");
EXPECT_EQ(clean_relative_path("abc//"), "abc");
// Remove . elements
EXPECT_EQ(clean_relative_path("abc/./def"), "abc/def");
EXPECT_EQ(clean_relative_path("./abc/def"), "abc/def");
EXPECT_EQ(clean_relative_path("abc/."), "abc");
// Remove .. elements
EXPECT_EQ(clean_relative_path("abc/def/ghi/../jkl"), "abc/def/jkl");
EXPECT_EQ(clean_relative_path("abc/def/../ghi/../jkl"), "abc/jkl");
EXPECT_EQ(clean_relative_path("abc/def/.."), "abc");
EXPECT_EQ(clean_relative_path("abc/def/../.."), ".");
EXPECT_EQ(clean_relative_path("abc/def/../../.."), "..");
EXPECT_EQ(clean_relative_path("abc/def/../../../ghi/jkl/../../../mno"), "../../mno");
EXPECT_EQ(clean_relative_path("../abc"), "../abc");
// Combinations
EXPECT_EQ(clean_relative_path("abc/./../def"), "def");
EXPECT_EQ(clean_relative_path("abc//./../def"), "def");
EXPECT_EQ(clean_relative_path("abc/../../././../def"), "../../def");
}
} // namespace Test
} // namespace ONNX_NAMESPACE
#endif
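The cases above fully pin down a purely lexical normalization. As a reading aid only, here is a minimal sketch of such a cleaner that satisfies the expectations in this test; it is hypothetical and independent of the ONNX implementation, and the name lexically_clean_relative_path is made up for illustration.

#include <string>
#include <vector>

// Hypothetical sketch of a lexical relative-path cleaner matching the test above.
std::string lexically_clean_relative_path(const std::string& path) {
  std::vector<std::string> parts;
  std::string segment;
  auto flush = [&]() {
    if (segment.empty() || segment == ".") {
      // Drop empty segments (from "//" or a trailing "/") and "." segments.
    } else if (segment == ".." && !parts.empty() && parts.back() != "..") {
      parts.pop_back();  // ".." cancels the preceding real segment.
    } else {
      parts.push_back(segment);  // Keep real segments and leading "..".
    }
    segment.clear();
  };
  for (char c : path) {
    if (c == '/') {
      flush();
    } else {
      segment.push_back(c);
    }
  }
  flush();
  if (parts.empty()) {
    return ".";  // Everything cancelled out, e.g. "abc/def/../..".
  }
  std::string result = parts[0];
  for (size_t i = 1; i < parts.size(); ++i) {
    result += "/" + parts[i];
  }
  return result;
}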


@@ -0,0 +1,401 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <iostream>
#include "gtest/gtest.h"
#include "onnx/checker.h"
#include "onnx/defs/parser.h"
#include "onnx/defs/schema.h"
#include "onnx/defs/shape_inference.h"
#include "onnx/onnx_pb.h"
#include "onnx/shape_inference/implementation.h"
using namespace ONNX_NAMESPACE::shape_inference;
namespace ONNX_NAMESPACE {
namespace Test {
inline bool CompareShape(
const TensorShapeProto& inferredShape,
const TensorShapeProto& expectedShape,
bool checkSameParam = false) {
EXPECT_TRUE(inferredShape.dim_size() == expectedShape.dim_size())
<< "Dim size for inferred and expected shape is different.";
for (int i = 0; i < inferredShape.dim_size(); i++) {
EXPECT_TRUE(
(inferredShape.dim(i).has_dim_value() == expectedShape.dim(i).has_dim_value()) &&
(inferredShape.dim(i).has_dim_param() == expectedShape.dim(i).has_dim_param()))
<< "Inferred and expected dim values are different.";
EXPECT_TRUE(
inferredShape.dim(i).has_dim_value() ? inferredShape.dim(i).dim_value() == expectedShape.dim(i).dim_value()
: checkSameParam ? inferredShape.dim(i).dim_param() == expectedShape.dim(i).dim_param()
: true)
<< "Inferred and expected dims are different.";
}
return true;
}
TensorShapeProto RunDataPropagation(const char* graphCode, int domainVersion = 15) {
// Parses the graph from graphCode
GraphProto graph;
OnnxParser parser(graphCode);
auto status = parser.Parse(graph);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
EXPECT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected.";
// Constructs name to TypeProto map from value_info, input, output
std::unordered_map<std::string, TypeProto*> valueTypesByName;
for (auto& vi : *graph.mutable_value_info()) {
if (vi.has_type()) {
valueTypesByName[vi.name()] = vi.mutable_type();
}
}
for (auto& vi : *graph.mutable_input()) {
if (vi.has_type()) {
valueTypesByName[vi.name()] = vi.mutable_type();
}
}
for (auto& vi : *graph.mutable_output()) {
if (vi.has_type()) {
valueTypesByName[vi.name()] = vi.mutable_type();
}
}
// Constructs name to TensorProto map from initializer
std::unordered_map<std::string, const TensorProto*> inputDataByName;
for (const auto& tp : graph.initializer()) {
inputDataByName[tp.name()] = &tp;
}
// Collects data from constant nodes
for (const auto& n : graph.node()) {
if (n.op_type() != "Constant" || n.output().size() != 1) {
continue;
}
for (const auto& attr : n.attribute()) {
if (attr.name() == "value") {
if (attr.type() == AttributeProto::TENSOR && attr.has_t()) {
inputDataByName[n.output(0)] = &attr.t();
}
}
}
}
// Runs data propagation on each node
std::unordered_map<std::string, TensorShapeProto> generatedShapeDataByName;
auto* schemaRegistry = OpSchemaRegistry::Instance();
TensorShapeProto inferredShape;
for (auto n : graph.node()) {
// No need to run data propagation on Constant
if (n.op_type() == "Constant") {
continue;
}
DataPropagationContextImpl dataPropagationCtx(n, valueTypesByName, inputDataByName, generatedShapeDataByName);
const auto schema = schemaRegistry->GetSchema(n.op_type(), domainVersion, n.domain());
EXPECT_TRUE(schema->has_data_propagation_function());
schema->GetDataPropagationFunction()(dataPropagationCtx);
}
// Assuming the graph being tested only has 1 output.
// If this ever changes then fixes are required here.
const auto inputShapeDataIter = generatedShapeDataByName.find(graph.output(0).name());
EXPECT_TRUE(inputShapeDataIter != generatedShapeDataByName.cend());
inferredShape.CopyFrom(inputShapeDataIter->second);
// Returns the partial shape data for output
return inferredShape;
}
TEST(DataPropagationImplTest, ShapeTest) {
const char* code = R"ONNX(
agraph (int32[7,4,1] x) => (int32[3] y)
{
xs = Shape(x)
y = Cast<to = 7>(xs)
}
)ONNX";
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_value(7);
expected_tsp.mutable_dim()->Add()->set_dim_value(4);
expected_tsp.mutable_dim()->Add()->set_dim_value(1);
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}
TEST(DataPropagationImplTest, SymbolicShapeTest) {
const char* code = R"ONNX(
agraph (int32[N,3,256,256] x) => (int32[4] y)
{
xs = Shape(x)
y = Cast<to = 7>(xs)
}
)ONNX";
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_param("N");
expected_tsp.mutable_dim()->Add()->set_dim_value(3);
expected_tsp.mutable_dim()->Add()->set_dim_value(256);
expected_tsp.mutable_dim()->Add()->set_dim_value(256);
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp, true));
}
TEST(DataPropagationImplTest, CastTest) {
const char* code = R"ONNX(
agraph (int32[2,5] x) => (int32[2] y)
{
xs = Shape(x)
y = Cast<to = 7>(xs)
}
)ONNX";
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_value(2);
expected_tsp.mutable_dim()->Add()->set_dim_value(5);
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}
TEST(DataPropagationImplTest, SqueezeTest) {
const char* code = R"ONNX(
agraph (int32[2,5] x) => (int32[2] z)
{
xs = Shape(x)
y = Squeeze(xs)
z = Cast<to = 7>(y)
}
)ONNX";
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_value(2);
expected_tsp.mutable_dim()->Add()->set_dim_value(5);
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}
TEST(DataPropagationImplTest, UnsqueezeTest) {
const char* code = R"ONNX(
agraph (int32[2,5] x) => (int32[1,2] w)
{
xs = Shape(x)
axis = Constant<value = int64[1] {1}>()
z = Unsqueeze(xs, axis)
w = Cast<to = 7>(z)
}
)ONNX";
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_value(2);
expected_tsp.mutable_dim()->Add()->set_dim_value(5);
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}
TEST(DataPropagationImplTest, SizeTest) {
const char* code = R"ONNX(
agraph (int64[1] x) => (int32[1] w)
<int64[3] init = {2,3,5}>
{
z = Size(init)
w = Cast<to = 7>(z)
}
)ONNX";
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_value(3);
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}
TEST(DataPropagationImplTest, AddTest) {
const char* code = R"ONNX(
agraph (int32[2,4,5] x, int32[2,4,5] y) => (int32[3] w)
{
xs = Shape(x)
ys = Shape(y)
z = Add(xs, ys)
w = Cast<to = 7>(z)
}
)ONNX";
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_value(4);
expected_tsp.mutable_dim()->Add()->set_dim_value(8);
expected_tsp.mutable_dim()->Add()->set_dim_value(10);
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}
TEST(DataPropagationImplTest, AddSymbolicShapeTest) {
const char* code = R"ONNX(
agraph (int32[2,4,5] x, int32[2,4,M] y) => (int32[3] w)
{
xs = Shape(x)
ys = Shape(y)
z = Add(xs, ys)
w = Cast<to = 7>(z)
}
)ONNX";
// Add({2,4,5}, {2,4,M}) = {4,8,?}
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_value(4);
expected_tsp.mutable_dim()->Add()->set_dim_value(8);
// Not computable so do not set value or param
expected_tsp.mutable_dim()->Add();
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}
TEST(DataPropagationImplTest, SubTest) {
const char* code = R"ONNX(
agraph (int32[10,11,6] x, int32[5] y) => (int32[3] w)
{
xs = Shape(x)
ys = Shape(y)
z = Sub(xs, ys)
w = Cast<to = 7>(z)
}
)ONNX";
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_value(5);
expected_tsp.mutable_dim()->Add()->set_dim_value(6);
expected_tsp.mutable_dim()->Add()->set_dim_value(1);
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}
TEST(DataPropagationImplTest, MulTest) {
const char* code = R"ONNX(
agraph (int32[2] x, int32[5,1,7] y) => (int32[3] w)
{
xs = Shape(x)
ys = Shape(y)
z = Mul(xs, ys)
w = Cast<to = 7>(z)
}
)ONNX";
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_value(10);
expected_tsp.mutable_dim()->Add()->set_dim_value(2);
expected_tsp.mutable_dim()->Add()->set_dim_value(14);
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}
TEST(DataPropagationImplTest, ConcatTest) {
const char* code = R"ONNX(
agraph (int32[1,2] x, int32[3,4] y) => (int32[4] w)
{
xs = Shape(x)
ys = Shape(y)
z = Concat<axis = 0>(xs, ys)
w = Cast<to = 7>(z)
}
)ONNX";
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_value(1);
expected_tsp.mutable_dim()->Add()->set_dim_value(2);
expected_tsp.mutable_dim()->Add()->set_dim_value(3);
expected_tsp.mutable_dim()->Add()->set_dim_value(4);
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}
TEST(DataPropagationImplTest, GatherTest) {
const char* code = R"ONNX(
agraph (int32[1,2,3,4,5,6] x) => (int32[3] w)
{
xs = Shape(x)
indices = Constant<value = int64[3] {0,3,5}>()
z = Gather<axis = 0>(xs, indices)
w = Cast<to = 7>(z)
}
)ONNX";
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_value(1);
expected_tsp.mutable_dim()->Add()->set_dim_value(4);
expected_tsp.mutable_dim()->Add()->set_dim_value(6);
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}
TEST(DataPropagationImplTest, GatherNegativeIndicesTest) {
const char* code = R"ONNX(
agraph (int32[1,2,3,4,5,6] x) => (int32[2] w)
{
xs = Shape(x)
indices = Constant<value = int64[2] {-2,-1}>()
z = Gather<axis = 0>(xs, indices)
w = Cast<to = 7>(z)
}
)ONNX";
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_value(5);
expected_tsp.mutable_dim()->Add()->set_dim_value(6);
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}
TEST(DataPropagationImplTest, SliceTest) {
const char* code = R"ONNX(
agraph (int32[1,2,3,4,5,6,7,8] x) => (int32[2] w)
{
xs = Shape(x)
starts = Constant<value = int64[1] {1}>()
ends = Constant<value = int64[1] {7}>()
axes = Constant<value = int64[1] {0}>()
steps = Constant<value = int64[1] {3}>()
z = Slice(xs, starts, ends, axes, steps)
w = Cast<to = 7>(z)
}
)ONNX";
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_value(2);
expected_tsp.mutable_dim()->Add()->set_dim_value(5);
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}
TEST(DataPropagationImplTest, SliceDefaultAxesAndStepTest) {
const char* code = R"ONNX(
agraph (int32[1,2,3,4,5,6,7,8] x) => (int32[3] w)
{
xs = Shape(x)
starts = Constant<value = int64[1] {2}>()
ends = Constant<value = int64[1] {5}>()
z = Slice(xs, starts, ends)
w = Cast<to = 7>(z)
}
)ONNX";
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_value(3);
expected_tsp.mutable_dim()->Add()->set_dim_value(4);
expected_tsp.mutable_dim()->Add()->set_dim_value(5);
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}
TEST(DataPropagationImplTest, SliceNegativeStartEndStepTest) {
const char* code = R"ONNX(
agraph (int32[1,2,3,4,5,6,7,8] x) => (int32[3] w)
{
xs = Shape(x)
starts = Constant<value = int64[1] {-3}>()
ends = Constant<value = int64[1] {-7}>()
axes = Constant<value = int64[1] {0}>()
steps = Constant<value = int64[1] {-2}>()
z = Slice(xs, starts, ends, axes, steps)
w = Cast<to = 7>(z)
}
)ONNX";
TensorShapeProto expected_tsp;
expected_tsp.mutable_dim()->Add()->set_dim_value(6);
expected_tsp.mutable_dim()->Add()->set_dim_value(4);
const auto propagated_tsp = RunDataPropagation(code);
EXPECT_TRUE(CompareShape(propagated_tsp, expected_tsp));
}
} // namespace Test
} // namespace ONNX_NAMESPACE


@@ -0,0 +1,279 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <iostream>
#include "gtest/gtest.h"
#include "onnx/checker.h"
#include "onnx/common/constants.h"
#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
using namespace ONNX_NAMESPACE::checker;
#pragma warning(push)
#pragma warning(disable : 4530)
namespace ONNX_NAMESPACE {
namespace Test {
// Utilities. TODO: Turn them into reusable ONNX utilities.
TensorProto ToTensor(double value, TensorProto_DataType elem_type) {
TensorProto t;
t.set_data_type(elem_type);
switch (elem_type) {
case TensorProto_DataType::TensorProto_DataType_FLOAT:
t.add_float_data((float)value);
break;
case TensorProto_DataType::TensorProto_DataType_DOUBLE:
t.add_double_data(value);
break;
// case TensorProto_DataType::TensorProto_DataType_FLOAT16:
// t.add_int32_data(onnxruntime::math::floatToHalf((float)value));
// break;
default:
assert(false);
}
return t;
}
void BuildNodes(FunctionProto& functionProto, const std::vector<FunctionBodyHelper::NodeDef>& node_defs) {
for (size_t i = 0; i < node_defs.size(); i++) {
const FunctionBodyHelper::NodeDef& node = node_defs[i];
auto* np = functionProto.add_node();
np->set_op_type(node.op_type);
for (const auto& inp : node.inputs) {
np->add_input(inp);
}
for (const auto& o : node.outputs) {
np->add_output(o);
}
for (const auto& attr : node.attributes) {
*(np->add_attribute()) = attr.proto;
}
}
}
bool BuildFunctionProto(
FunctionProto& functionProto,
const OpSchema& schema,
const std::vector<FunctionBodyHelper::NodeDef>& node_defs) {
BuildNodes(functionProto, node_defs);
schema.BuildFunction(functionProto);
return true;
}
// A monomorphic context-dependent function test-case.
static bool
BuildFloatFunctionBody(const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
// Create a scalar-tensor constant 2.0 of float type:
auto two_as_tensor = ToTensor(2.0, TensorProto_DataType::TensorProto_DataType_FLOAT);
std::vector<FunctionBodyHelper::NodeDef> body{// nodes: {outputs, op, inputs, attributes}
{{"Two"}, "Constant", {}, {{"value", two_as_tensor}}},
{{"Y"}, "Mul", {"X", "Two"}}};
return BuildFunctionProto(functionProto, schema, body);
}
void RegisterCustomFuncFloatSchema() {
ONNX_NAMESPACE::OpSchema schema;
schema.SetName("CustomFuncFloat")
.SetDomain(ONNX_DOMAIN)
.SinceVersion(12)
.SetDoc("This operator returns an output tensor that is twice the input tensor.")
.Input(0, "X", "Input tensor", "T", OpSchema::Single)
.Output(0, "Y", "Output tensor", "T", OpSchema::Single)
.TypeConstraint("T", {"tensor(float)"}, "Type of the input and output values")
.SetContextDependentFunctionBodyBuilder(BuildFloatFunctionBody);
ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(schema);
(void)unused;
}
// Test for a context-dependent function without type context
TEST(FunctionAPITest, ContextDependentFunctionTest) {
RegisterCustomFuncFloatSchema();
const auto* schema = OpSchemaRegistry::Schema("CustomFuncFloat", 12, ONNX_DOMAIN);
EXPECT_TRUE(schema);
EXPECT_FALSE(schema->HasFunction());
EXPECT_TRUE(schema->HasContextDependentFunction());
NodeProto nodeProto;
nodeProto.set_op_type("CustomFuncFloat");
nodeProto.add_input("X");
nodeProto.add_output("Y");
FunctionBodyBuildContextImpl ctx(nodeProto);
FunctionProto fnProto;
EXPECT_TRUE(schema->BuildContextDependentFunction(ctx, fnProto));
EXPECT_EQ(fnProto.node_size(), 2);
LexicalScopeContext lexicalScope;
CheckerContext checkerCtx;
std::unordered_map<std::string, int> opset_imports({{ONNX_DOMAIN, 12}});
checkerCtx.set_opset_imports(opset_imports);
checkerCtx.set_ir_version(7);
check_function(fnProto, checkerCtx, lexicalScope);
}
// A polymorphic context-dependent function test-case.
static bool
BuildFunctionBody(const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
// Create a scalar-tensor constant 2.0 of input-type:
auto* tp = ctx.getInputType(0);
if ((tp == nullptr) || (!tp->has_tensor_type()))
return false;
auto elem_type = (TensorProto_DataType)tp->tensor_type().elem_type();
auto two_as_tensor = ToTensor(2.0, elem_type);
std::vector<FunctionBodyHelper::NodeDef> body{// nodes: {outputs, op, inputs, attributes}
{{"Two"}, "Constant", {}, {{"value", two_as_tensor}}},
{{"Y"}, "Mul", {"X", "Two"}}};
return BuildFunctionProto(functionProto, schema, body);
}
void RegisterCustomFunctionSchema() {
ONNX_NAMESPACE::OpSchema schema;
schema.SetName("CustomFunction")
.SetDomain(ONNX_DOMAIN)
.SinceVersion(12)
.SetDoc("This operator returns an output tensor that is twice the input tensor.")
.Input(0, "X", "Input tensor", "T", OpSchema::Single)
.Output(0, "Y", "Output tensor", "T", OpSchema::Single)
.TypeConstraint("T", {"tensor(float)", "tensor(double)"}, "Type of the input and output values")
.SetContextDependentFunctionBodyBuilder(BuildFunctionBody);
ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(schema);
(void)unused;
}
TEST(FunctionAPITest, VersionedFunctionBodyTest) {
// This test illustrates versioning issues of ONNX function ops.
// It is oversimplified in that only one primary op (Sub) is used in the function body.
// ONNX opset   1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
// MySub:       2 9               // The MySub function op is created at opset 2.
//                                // Its semantics are updated at opset 7.
// Body Ideal:  2 6 7 9 13 14 16  // Ideally, a function body should be provided each time
//                                // any primary op used by the body gets a version bump;
//                                // this happens more often when more primary ops are used.
// Body Real:   2 9 16            // In practice, we seldom add a function body just because
//                                // a primary op was updated.
// Sub:         1 6 7 13 14       // Version bumps of Sub.
// Model:       y y y y n n n y y y y n n n y y y
//                                // Whether a model with the given opset import (2..18)
//                                // can(y)/cannot(n) be used.
ONNX_NAMESPACE::OpSchema schema_ver2;
schema_ver2.SetName("MySub")
.SetDomain(ONNX_DOMAIN)
.SinceVersion(2)
.SetDoc("Z = Sub (X, Y)")
.Input(0, "X", "Input tensor X", "T", OpSchema::Single)
.Input(1, "Y", "Input tensor Y", "T", OpSchema::Single)
.Output(0, "Z", "Output tensor Z", "T", OpSchema::Single)
.TypeConstraint("T", {"tensor(float)", "tensor(double)"}, "Type of the input and output values")
.FunctionBody(
R"ONNX(
{
Z = Sub (X, Y)
}
)ONNX",
2);
ONNX_NAMESPACE::OpSchema schema_ver9;
schema_ver9.SetName("MySub")
.SetDomain(ONNX_DOMAIN)
.SinceVersion(9)
.SetDoc("Z = Sub (X, Y)")
.Input(0, "X", "Input tensor X", "T", OpSchema::Single)
.Input(1, "Y", "Input tensor Y", "T", OpSchema::Single)
.Output(0, "Z", "Output tensor Z", "T", OpSchema::Single)
.TypeConstraint("T", {"tensor(float)", "tensor(double)"}, "Type of the input and output values")
.FunctionBody(
R"ONNX(
{
Z = Sub (X, Y)
}
)ONNX",
9)
.FunctionBody(
R"ONNX(
{
Z = Sub (X, Y)
}
)ONNX",
16);
ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused2(schema_ver2);
(void)unused2;
ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused9(schema_ver9);
(void)unused9;
const auto* schema2 = OpSchemaRegistry::Schema("MySub", 2, ONNX_DOMAIN);
EXPECT_TRUE(schema2);
for (int model_opset_import = 2; model_opset_import < 9; model_opset_import++) {
try {
bool validate = true;
const FunctionProto* function = schema2->GetFunction(model_opset_import, validate);
if (model_opset_import >= 6) { // function body should be updated at opset 6 where Sub is updated
ASSERT_TRUE(function == nullptr);
} else {
ASSERT_TRUE(function);
}
} catch (const std::runtime_error& err) {
ASSERT_TRUE(model_opset_import == 6 || model_opset_import == 7 || model_opset_import == 8);
}
}
const auto* schema9 = OpSchemaRegistry::Schema("MySub", 9, ONNX_DOMAIN);
EXPECT_TRUE(schema9);
for (int model_opset_import = 9; model_opset_import < 10; model_opset_import++) {
try {
const FunctionProto* function = schema9->GetFunction(model_opset_import);
ASSERT_TRUE(function);
} catch (const std::runtime_error& err) {
ASSERT_TRUE(model_opset_import == 13 || model_opset_import == 14 || model_opset_import == 15);
}
}
}
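// The comment table in VersionedFunctionBodyTest above boils down to a simple rule.
// The helper below is a hypothetical illustration (not part of the ONNX API): a body
// registered at `body_version` stays usable for a model importing `model_opset` only if
// no primary op used by the body (here, Sub) was updated in (body_version, model_opset].
inline bool FunctionBodyUsable(int body_version, int model_opset, const std::vector<int>& sub_updates) {
  for (int update : sub_updates) {
    if (update > body_version && update <= model_opset)
      return false; // Sub changed after the body was authored; the body may be stale.
  }
  return true;
}
// For example, checking each registered body version {2, 9, 16} against Sub's version
// bumps {1, 6, 7, 13, 14} reproduces the y/n row above: some body is usable exactly for
// opset imports 2-5, 9-12, and 16-18.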
TEST(FunctionAPITest, TypeContextTest) {
RegisterCustomFunctionSchema();
const auto* schema = OpSchemaRegistry::Schema("CustomFunction", 12, ONNX_DOMAIN);
EXPECT_TRUE(schema);
EXPECT_FALSE(schema->HasFunction());
EXPECT_TRUE(schema->HasContextDependentFunction());
NodeProto nodeProto;
nodeProto.set_op_type("CustomFunction");
nodeProto.add_input("X");
nodeProto.add_output("Y");
TypeProto floatTypeProto;
floatTypeProto.mutable_tensor_type()->set_elem_type(TensorProto_DataType::TensorProto_DataType_FLOAT);
FunctionBodyBuildContextImpl ctx(nodeProto, {floatTypeProto});
FunctionProto fnProto;
EXPECT_TRUE(schema->BuildContextDependentFunction(ctx, fnProto));
EXPECT_EQ(fnProto.node_size(), 2);
LexicalScopeContext lexicalScope;
CheckerContext checkerCtx;
std::unordered_map<std::string, int> opset_imports({{ONNX_DOMAIN, 12}});
checkerCtx.set_opset_imports(opset_imports);
checkerCtx.set_ir_version(7);
check_function(fnProto, checkerCtx, lexicalScope);
}
} // namespace Test
} // namespace ONNX_NAMESPACE
#pragma warning(pop)


@@ -0,0 +1,49 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <iostream>
#include "gtest/gtest.h"
#include "onnx/common/constants.h"
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
namespace Test {
TEST(FunctionAPITest, GetFunctionOpWithVersion) {
const auto* schema = OpSchemaRegistry::Schema("MeanVarianceNormalization", 9, "");
EXPECT_TRUE(schema);
EXPECT_TRUE(schema->HasFunction());
auto func = schema->GetFunction();
EXPECT_EQ(func->name(), "MeanVarianceNormalization");
}
TEST(FunctionAPITest, GetMeanVarianceNormalizationFunctionWithVersion) {
{
const auto* schema = OpSchemaRegistry::Schema("MeanVarianceNormalization", 13, "");
EXPECT_TRUE(schema);
EXPECT_TRUE(schema->HasFunction());
auto func = schema->GetFunction();
EXPECT_EQ(func->name(), "MeanVarianceNormalization");
}
{
const auto* schema = OpSchemaRegistry::Schema("MeanVarianceNormalization", 17, "");
EXPECT_TRUE(schema);
EXPECT_TRUE(schema->HasFunction());
auto func = schema->GetFunction();
EXPECT_EQ(func->name(), "MeanVarianceNormalization");
}
{
const auto* schema = OpSchemaRegistry::Schema("MeanVarianceNormalization", 18, "");
EXPECT_TRUE(schema);
EXPECT_TRUE(schema->HasFunction());
auto func = schema->GetFunction();
EXPECT_EQ(func->name(), "MeanVarianceNormalization");
}
}
} // namespace Test
} // namespace ONNX_NAMESPACE


@@ -0,0 +1,558 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <iostream>
#include <set>
#include "gtest/gtest.h"
#include "onnx/checker.h"
#include "onnx/common/constants.h"
#include "onnx/defs/parser.h"
#include "onnx/defs/printer.h"
#include "onnx/defs/schema.h"
#include "onnx/onnx-operators_pb.h"
#include "onnx/onnx_pb.h"
#include "onnx/shape_inference/implementation.h"
namespace ONNX_NAMESPACE {
namespace Test {
using namespace checker;
using TENSOR_TYPES_MAP = std::unordered_map<std::string, std::vector<std::string>>;
void GetFunctionProtoOpsetImport(
const OpSchema& op,
const FunctionProto* function_proto,
std::unordered_map<std::string, int>& op_set) {
if (function_proto->opset_import_size() > 0) {
for (const auto& opset_import : function_proto->opset_import()) {
op_set.insert({opset_import.domain(), opset_import.version()});
}
} else {
op_set.insert({op.domain(), op.since_version()});
}
}
void VerifyTypeConstraint(const OpSchema& function_op, const FunctionProto* function_proto, int& counter) {
// This is a simple partial type-checker for a function-body.
// TODO: Revisit to make the type-checker more complete.
TENSOR_TYPES_MAP tc_map;
std::set<std::string> primitive_types(OpSchema::all_tensor_types().begin(), OpSchema::all_tensor_types().end());
for (const auto& input : function_op.inputs()) {
std::string name = input.GetName();
auto& tvec = tc_map[name];
for (const auto& t : input.GetTypes()) {
tvec.emplace_back(*t);
}
}
for (const auto& output : function_op.outputs()) {
std::string name = output.GetName();
auto& tvec = tc_map[name];
for (const auto& t : output.GetTypes()) {
tvec.emplace_back(*t);
}
}
std::unordered_map<std::string, int> op_set;
GetFunctionProtoOpsetImport(function_op, function_proto, op_set);
for (auto& node : function_proto->node()) {
std::string op_type = node.op_type();
std::unordered_map<std::string, int>::const_iterator it = op_set.find(node.domain());
if (it == op_set.end()) {
fail_check(
"Op " + op_type + " of domain " + node.domain() + " used in " + function_op.Name() +
" function body does not has a opset import.");
}
int opset_version = it->second;
const OpSchema* schema = OpSchemaRegistry::Schema(op_type, opset_version, node.domain());
// Check that the types of actual inputs, if known, are legal as per schema
// of called op:
auto num_formal_inputs = static_cast<size_t>(schema->inputs().size());
auto num_actual_inputs = static_cast<size_t>(node.input_size());
for (size_t i = 0; i < num_actual_inputs; ++i) {
auto actual_param_name = node.input(static_cast<int>(i));
auto iter = tc_map.find(actual_param_name);
if (iter != tc_map.end()) {
// if i >= num_formal_inputs, it is a variadic parameter corresponding
// to the last formal parameter.
auto formal_i = std::min(i, num_formal_inputs - 1);
const auto& types = schema->inputs().at(formal_i).GetTypes();
std::unordered_set<std::string> allowed_types;
for (auto& s : types) {
allowed_types.insert(*s);
}
for (auto& actual_type : iter->second) {
if (allowed_types.find(actual_type) == allowed_types.end()) {
fail_check(
"Input type " + actual_type + " of parameter " + actual_param_name + " of function " +
function_op.Name() + " is not allowed by operator " + op_type);
}
}
}
}
// No simple check exists for outputs: we need to integrate type inference
// to identify the possible output types and verify that they are included
// in the function-schema.
}
++counter;
}
// Testing the function-definitions provided for function-ops in ONNX schema registry.
// We type-check the function-definition for all possible input-typings, as permitted
// by the op-schema. Since the type-checking is dependent on attribute-values, we specify
// the attribute-values for which we want to do the testing down below.
// The set of attribute-values (for testing a function) is represented using a vector.
using AttributeValues = std::vector<AttributeProto>;
// FunctionOpAttributeMap: Used to implement a map from OpSchema to a set of AttributeValues
// (implemented as a vector). The testing will be done for each attribute-values specified.
struct FunctionOpAttributeMap {
std::unordered_map<std::string, std::vector<AttributeValues>> map;
std::string key(std::string domain, std::string opname, int opset_version) const {
return domain + ":" + opname + ":" + std::to_string(opset_version);
}
void addTestCase(const std::string& opname, int opset_version, std::initializer_list<const char*> attributes) {
auto& schema_test_cases = map[key("", opname, opset_version)];
schema_test_cases.push_back(AttributeValues());
auto& test_case = schema_test_cases.back();
for (auto attr_text : attributes) {
test_case.push_back(AttributeProto());
OnnxParser::Parse(test_case.back(), attr_text);
}
}
FunctionOpAttributeMap() {
addTestCase("Elu", 6, {"alpha = 1.0"});
addTestCase("LeakyRelu", 16, {"alpha = 0.1"});
addTestCase("HardSigmoid", 6, {"alpha = 0.2", "beta=0.5"});
addTestCase("Selu", 6, {"alpha = 1.6", "gamma=1.05"});
addTestCase("ReduceL1", 18, {}); // Use default-value for attributes
addTestCase("ReduceL1", 18, {"keepdims = 0"});
addTestCase("ReduceL1", 18, {"noop_with_empty_axes = 1"});
addTestCase("ReduceL2", 18, {});
addTestCase("ReduceL2", 18, {"noop_with_empty_axes = 1", "keepdims = 0"});
addTestCase("ReduceSumSquare", 18, {});
addTestCase("ReduceLogSumExp", 18, {});
addTestCase("ThresholdedRelu", 10, {"alpha = 0.9"});
addTestCase("HannWindow", 17, {"output_datatype = 1", "periodic = 1"});
addTestCase("HammingWindow", 17, {"output_datatype = 1", "periodic = 1"});
addTestCase("BlackmanWindow", 17, {"output_datatype = 1", "periodic = 1"});
addTestCase("MeanValueNormalization", 13, {});
addTestCase("AffineGrid", 20, {"align_corners = 0"});
addTestCase("AffineGrid", 20, {"align_corners = 1"});
// The following test-cases fail, correctly so: some clarifications/changes are required
// to handle unsigned integers or similar issues:
// addTestCase("Shrink", 9, {"bias = 0.0", "lambd = 0.5"});
// addTestCase("ReduceLogSum", 18, {});
// addTestCase("Range", 11, {});
// The following test-case fails because the checker doesn't support handling of
// default-values of attributes of function-ops
// addTestCase("ThresholdedRelu", 10, {});
}
const std::vector<AttributeValues>& getTestCases(const OpSchema& schema) {
auto key_value = key(schema.domain(), schema.Name(), schema.SinceVersion());
auto it = map.find(key_value);
if (it != map.end())
return it->second;
if (schema.attributes().size() == 0) {
// Test with no-attributes
map[key_value].push_back(std::vector<AttributeProto>());
}
return map[key_value];
}
static FunctionOpAttributeMap& instance() {
static FunctionOpAttributeMap _instance;
return _instance;
}
};
struct FunctionTypeChecker {
const OpSchema& schema;
const FunctionProto& function_proto;
const std::vector<AttributeValues>* attribute_cases;
FunctionTypeChecker(const OpSchema& op_schema, const FunctionProto& proto)
: schema(op_schema), function_proto(proto) {
attribute_cases = &FunctionOpAttributeMap::instance().getTestCases(op_schema);
}
// Binds each type-variable in schema to a type-value
std::unordered_map<std::string, DataType> typeVarBindings;
std::vector<std::string> errors;
void recordError(const std::string& error, AttributeValues attrs) {
std::ostringstream ostr;
ostr << "Type checking failed for instantiation " << schema.Name() << ":" << schema.SinceVersion() << " {";
for (auto& pair : typeVarBindings) {
ostr << pair.first << " = " << *pair.second << ", ";
}
for (auto& attr : attrs) {
ostr << attr << ", ";
}
ostr << "}\n" << error << "\n";
errors.push_back(ostr.str());
}
void recordSuccess(AttributeValues attrs) {
std::cout << "Type checking succeeded for instantiation " << schema.Name() << ":" << schema.SinceVersion() << " {";
for (auto& pair : typeVarBindings) {
std::cout << pair.first << " = " << *pair.second << ", ";
}
for (auto& attr : attrs) {
std::cout << attr << ", ";
}
std::cout << "}\n";
}
// forTypeVar: This is used to iterate through all possible bindings of type-values
// to all type-variables used in the op schema, and invoke the type-checker for
// each possible instantiation.
void forTypeVar(int i) {
auto& typeConstraintVector = schema.typeConstraintParams();
if (i < typeConstraintVector.size()) {
std::string typeVar = typeConstraintVector[i].type_param_str;
auto& values = schema.typeConstraintMap().at(typeVar).first;
for (auto typeValue : values) {
typeVarBindings[typeVar] = typeValue;
// Now, process remaining type-variables
forTypeVar(i + 1);
}
} else {
// Generated a complete instantiation of type-values to all type-variables.
// Now, check for this instantiation.
typeCheckBinding();
}
}
// typeCheckBinding: Type-check the function-body for the current type-instantiation
void typeCheckBinding() {
std::vector<TypeProto> input_types;
for (const auto& input : schema.inputs()) {
// If only one type is possible, select it; otherwise select the type
// bound to the type-variable in the current instantiation.
DataType datatype = (1 == input.GetTypes().size()) ? *(input.GetTypes().begin())
: typeVarBindings[input.GetTypeStr()];
input_types.push_back(Utils::DataTypeUtils::ToTypeProto(datatype));
}
for (auto& attribute_vals : *attribute_cases) {
ONNX_TRY {
auto output_types = shape_inference::InferFunctionOutputTypes(function_proto, input_types, attribute_vals);
}
ONNX_CATCH(ONNX_NAMESPACE::InferenceError & e) {
ONNX_HANDLE_EXCEPTION(([&]() { recordError(e.what(), attribute_vals); }));
}
}
}
std::string checkAll() {
if (attribute_cases->size() > 0)
forTypeVar(0);
std::string all_errors = "";
for (const std::string& error : errors)
all_errors += error;
return all_errors;
}
};
void VerifyFunction(const OpSchema& op, const FunctionProto* function_proto, int& counter) {
// Verify function proto is valid
if (!function_proto) {
fail_check("Cannot get function body for op '", op.Name(), "'");
}
CheckerContext ctx;
std::unordered_map<std::string, int> op_set;
GetFunctionProtoOpsetImport(op, function_proto, op_set);
auto version_range = OpSchemaRegistry::DomainToVersionRange::Instance().Map().at(op.domain());
if (op.since_version() > version_range.second || op.since_version() < version_range.first) {
fail_check("Invalid function version in function op '", op.Name(), "'");
}
ctx.set_opset_imports(op_set);
ctx.set_is_main_graph(false);
LexicalScopeContext lex_ctx;
ONNX_TRY {
check_function(*function_proto, ctx, lex_ctx);
}
ONNX_CATCH(ValidationError & ex) {
ONNX_HANDLE_EXCEPTION([&]() { fail_check(ex.what()); });
}
// Verify function op has compatible Type constraints defined in
// op and function body.
VerifyTypeConstraint(op, function_proto, counter);
FunctionTypeChecker type_checker(op, *function_proto);
auto type_errors = type_checker.checkAll();
auto success = (type_errors == "");
ASSERT_TRUE(success) << type_errors;
}
// Verify that registered ops with a function body have compatible
// TypeConstraint definitions between the op and its function body.
TEST(FunctionVerification, VerifyFunctionOps) {
const std::vector<OpSchema> schemas = OpSchemaRegistry::get_all_schemas();
int function_counter = 0, verified_counter = 0;
for (const auto& s : schemas) {
if (!s.HasFunction())
continue;
// Skip test for functions with known errors that need to be fixed:
// Range currently permits int16 parameters, but the operator Sub, called
// from the body of Range, does not yet support int16 parameters.
if (s.Name() == "Range")
continue;
ONNX_TRY {
++function_counter;
std::vector<int> function_versions = s.function_opset_versions();
for (int function_version : function_versions) {
auto function_body = s.GetFunction(function_version);
VerifyFunction(s, function_body, verified_counter);
}
}
ONNX_CATCH(ONNX_NAMESPACE::checker::ValidationError & e) {
ONNX_HANDLE_EXCEPTION([&]() { FAIL() << e.what(); });
}
}
std::cerr << "[ ] Verified " << verified_counter << "/" << function_counter << " Functions." << std::endl;
}
// Verify that FunctionExpandHelper obtains missing default attributes
// from schema and adds them to ops in expanded subgraph.
TEST(FunctionVerification, VerifyFunctionExpandHelper) {
GraphProto graph;
NodeProto* new_node = graph.add_node();
new_node->set_op_type("MeanVarianceNormalization");
const auto* schema = OpSchemaRegistry::Schema("MeanVarianceNormalization", 9, "");
const FunctionProto* func = schema->GetFunction();
const auto default_axes_attribute = schema->attributes().at("axes").default_value;
FunctionExpandHelper(*new_node, *func, graph);
for (const auto& node : graph.node()) {
if (node.op_type() == "ReduceMean") {
auto attr = node.attribute(0);
EXPECT_EQ(attr.name(), "axes");
EXPECT_EQ(attr.ints().size(), default_axes_attribute.ints().size());
for (int i = 0; i < default_axes_attribute.ints().size(); ++i) {
EXPECT_EQ(attr.ints(i), default_axes_attribute.ints(i));
}
return;
}
}
FAIL() << "During expanding MeanVarianceNormalization function, "
<< "the default attribute `axes` has not been assigned to ReduceMean op.";
}
void RegisterFunctionSchema() {
ONNX_NAMESPACE::OpSchema function_schema;
function_schema.SetName("DynamicQuantizeLinear_Fake")
.SetDomain(AI_ONNX_ML_DOMAIN)
.SinceVersion(2)
.SetDoc("Test Op")
.Input(0, "x", "Input tensor", "T1")
.Output(0, "y", "Quantized output tensor", "T2")
.Output(
1, "y_scale", "Output scale. It's a scalar, which means a per-tensor/layer quantization.", "tensor(float)")
.Output(2, "y_zero_point", "Output zero point. It's a scalar, which means a per-tensor/layer quantization.", "T2")
.TypeConstraint("T1", {"tensor(float)"}, "Constrain 'x' to float tensor.")
.TypeConstraint("T2", {"tensor(uint8)"}, "Constrain 'y_zero_point' and 'y' to 8-bit unsigned integer tensor.")
.FunctionBody(
FunctionBodyHelper::BuildNodes(
{// nodes: {outputs, op, inputs, attributes}
FunctionBodyHelper::Const<float>("Q_Min", 0.f),
FunctionBodyHelper::Const<float>("Q_Max", 255.f),
{{"X_Min"}, "ReduceMin", {"x"}, {MakeAttribute("keepdims", int64_t(0))}},
{{"X_Min_Adjusted"}, "Min", {"X_Min", "Q_Min"}},
{{"X_Max"}, "ReduceMax", {"x"}, {MakeAttribute("keepdims", int64_t(0))}},
{{"X_Max_Adjusted"}, "Max", {"X_Max", "Q_Min"}},
{{"X_Range"}, "Sub", {"X_Max_Adjusted", "X_Min_Adjusted"}},
{{"Scale"}, "Div", {"X_Range", "Q_Max"}},
{{"Min_Scaled"}, "Div", {"X_Min_Adjusted", "Scale"}},
{{"Initial_ZeroPoint_FP"}, "Sub", {"Q_Min", "Min_Scaled"}},
{{"Clipped_ZeroPoint_FP"}, "Clip", {"Initial_ZeroPoint_FP", "Q_Min", "Q_Max"}},
{{"Rounded_ZeroPoint_FP"}, "Round", {"Clipped_ZeroPoint_FP"}},
{{"Zeropoint"}, "Cast", {"Rounded_ZeroPoint_FP"}, {MakeAttribute("to", int64_t(2))}},
{{"y_scale"}, "Identity", {"Scale"}},
{{"y_zero_point"}, "Identity", {"Zeropoint"}},
{{"y"}, "QuantizeLinear", {"x", "Scale", "Zeropoint"}}}),
[]() {
std::vector<OperatorSetIdProto> operator_sets(2);
auto& onnx_opset = operator_sets[0];
onnx_opset.set_domain("");
onnx_opset.set_version(13);
auto& test_opset = operator_sets[1];
test_opset.set_domain(AI_ONNX_ML_DOMAIN);
test_opset.set_version(2);
return operator_sets;
}());
ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(function_schema);
(void)unused;
}
TEST(FunctionVerification, VerifyFunctionBodyWithMultipleDomains) {
RegisterFunctionSchema();
const auto* schema = OpSchemaRegistry::Schema("DynamicQuantizeLinear_Fake", 2, AI_ONNX_ML_DOMAIN);
EXPECT_TRUE(schema);
EXPECT_TRUE(schema->HasFunction());
EXPECT_FALSE(schema->HasContextDependentFunction());
const FunctionProto* fnProto = schema->GetFunction();
EXPECT_EQ(fnProto->node_size(), 16);
LexicalScopeContext lexicalScope;
CheckerContext checkerCtx;
std::unordered_map<std::string, int> opset_imports({{AI_ONNX_ML_DOMAIN, 2}, {"", 13}});
checkerCtx.set_opset_imports(opset_imports);
checkerCtx.set_ir_version(7);
check_function(*fnProto, checkerCtx, lexicalScope);
}
TEST(FunctionVerification, VerifyModelLocalFunctions) {
const char* code = R"ONNX(
<
ir_version: 8,
opset_import: [ "" : 13, "custom_domain_1" : 1, "custom_domain_2" : 1],
producer_name: "FunctionProtoTest",
producer_version: "1.0",
model_version: 1,
doc_string: "A test model for model local functions."
>
agraph (float[N] x) => (uint8[N] out)
{
o1, o2 = custom_domain_1.bar(x)
o3 = Add(o1, o2)
o4 = custom_domain_2.foo(o3)
out = Identity(o4)
}
<
domain: "custom_domain_1",
opset_import: [ "" : 13],
doc_string: "Test function proto"
>
bar (x) => (o1, o2) {
o1 = Identity (x)
o2 = Identity (o1)
}
<
domain: "custom_domain_2",
opset_import: [ "" : 13],
doc_string: "Test function proto"
>
foo (x) => (y) {
Q_Min = Constant <value = float[1] {0.0}> ()
Q_Max = Constant <value = float[1] {255.0}> ()
X_Min = ReduceMin <keepdims = 0> (x)
X_Max = ReduceMax <keepdims = 0> (x)
X_Range = Sub (X_Max, X_Min)
Scale = Div (X_Range, Q_Max)
ZeroPoint_FP = Sub (Q_Min, Scale)
Zeropoint = Cast <to = 2> (ZeroPoint_FP)
y = QuantizeLinear (x, Scale, Zeropoint)
}
)ONNX";
ModelProto model;
auto status = OnnxParser::Parse(model, code);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
check_model(model);
ShapeInferenceOptions options{true, 1, true};
ONNX_NAMESPACE::shape_inference::InferShapes(model, OpSchemaRegistry::Instance(), options);
}
TEST(FunctionVerification, VerifyNestedModelLocalFunctions) {
const char* code = R"ONNX(
<
ir_version: 8,
opset_import: [ "" : 13, "custom_domain_1" : 1, "custom_domain_2" : 1],
producer_name: "FunctionProtoTest",
producer_version: "1.0",
model_version: 1,
doc_string: "A test model for model local functions."
>
agraph (float[N] x) => (uint8[N] out)
{
o1, o2 = custom_domain_1.bar(x)
o3 = Add(o1, o2)
o4 = custom_domain_2.foo(o3)
out = Identity(o4)
}
<
domain: "custom_domain_1",
opset_import: [ "" : 13],
doc_string: "Test function proto"
>
bar (x) => (o1, o2) {
o1 = Identity (x)
o2 = Identity (o1)
}
<
domain: "custom_domain_2",
opset_import: [ "" : 13, "custom_domain_3" : 1],
doc_string: "Test function proto"
>
foo (x) => (o4) {
o1 = custom_domain_3.foo (x)
o4 = Identity (o1)
}
<
domain: "custom_domain_3",
opset_import: [ "" : 13],
doc_string: "Test function proto"
>
foo (x) => (y) {
Q_Min = Constant <value = float[1] {0.0}> ()
Q_Max = Constant <value = float[1] {255.0}> ()
X_Min = ReduceMin <keepdims = 0> (x)
X_Max = ReduceMax <keepdims = 0> (x)
X_Range = Sub (X_Max, X_Min)
Scale = Div (X_Range, Q_Max)
ZeroPoint_FP = Sub (Q_Min, Scale)
Zeropoint = Cast <to = 2> (ZeroPoint_FP)
y = QuantizeLinear (x, Scale, Zeropoint)
}
)ONNX";
ModelProto model;
auto status = OnnxParser::Parse(model, code);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
check_model(model);
ShapeInferenceOptions options{true, 1, true};
ONNX_NAMESPACE::shape_inference::InferShapes(model, OpSchemaRegistry::Instance(), options);
}
} // namespace Test
} // namespace ONNX_NAMESPACE


@@ -0,0 +1,376 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <iostream>
#include "gtest/gtest.h"
#include "onnx/checker.h"
#include "onnx/common/constants.h"
#include "onnx/defs/parser.h"
#include "onnx/defs/printer.h"
#include "onnx/defs/schema.h"
#include "onnx/inliner/inliner.h"
#include "onnx/shape_inference/implementation.h"
namespace ONNX_NAMESPACE {
namespace Test {
static void InlineFunctions(ModelProto& model, const char* input, const inliner::FunctionIdSet* to_inline = nullptr) {
OnnxParser parser(input);
auto status = parser.Parse(model);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
EXPECT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected.";
checker::check_model(model, false, true);
shape_inference::InferShapes(model);
// std::cout << "Before inlining:\n" << ProtoToString(model) << "\n";
if (to_inline != nullptr)
inliner::InlineSelectedFunctions(model, *to_inline);
else
inliner::InlineLocalFunctions(model, true);
// std::cout << "After inlining:\n" << ProtoToString(model) << "\n";
// The following will ensure basic safety checks hold after inlining, including
// absence of duplicate names (multiple assignments to same name).
checker::check_model(model, true, true);
}
TEST(FunctionInliner, BasicTest) {
const char* code = R"ONNX(
<
ir_version: 8,
opset_import: [ "" : 10, "local" : 1 ]
>
agraph (float[N, 128] X, float[128,10] W, float[10] B) => (float[N, 10] C)
{
T = local.foo (X, W, B)
C = local.square(T)
}
<
opset_import: [ "" : 10 ],
domain: "local",
doc_string: "Function foo."
>
foo (x, w, b) => (c) {
T = MatMul(x, w)
S = Add(T, b)
c = Softmax(S)
}
<
opset_import: [ "" : 10 ],
domain: "local",
doc_string: "Function square."
>
square (x) => (y) {
y = Mul (x, x)
}
)ONNX";
ModelProto model;
InlineFunctions(model, code);
auto num_nodes = model.graph().node_size();
ASSERT_EQ(num_nodes, 4);
auto num_functions = model.functions_size();
ASSERT_EQ(num_functions, 0);
}
// Test that inlining processes subgraphs.
TEST(FunctionInliner, SubgraphTest) {
const char* code = R"ONNX(
<
ir_version: 8,
opset_import: [ "" : 10, "local" : 1 ]
>
agraph (bool cond, float[N] X) => (float[N] Y)
{
Y = If (cond) <
then_branch = then_graph () => (y) {
y = local.square (X)
},
else_branch = else_graph () => (y) {
y = local.square (X)
}
>
}
<
opset_import: [ "" : 10 ],
domain: "local",
doc_string: "Function square."
>
square (x) => (y) {
y = Mul (x, x)
}
)ONNX";
ModelProto model;
InlineFunctions(model, code);
auto& if_node = model.graph().node(0);
auto& graph1 = if_node.attribute(0).g();
ASSERT_EQ(graph1.node(0).op_type(), "Mul");
auto& graph2 = if_node.attribute(1).g();
ASSERT_EQ(graph2.node(0).op_type(), "Mul");
auto num_functions = model.functions_size();
ASSERT_EQ(num_functions, 0);
}
TEST(FunctionInliner, Nested) {
const char* code = R"ONNX(
<ir_version: 8, opset_import: [ "" : 17, "local" : 1 ]>
agraph (float[N] X) => (float[N] Y)
{
Y = local.foo (X)
}
<opset_import: [ "" : 17, "local" : 1 ], domain: "local">
foo (x) => (y) {
temp = Add(x, x)
y = local.bar(temp)
}
<opset_import: [ "" : 17 ], domain: "local">
bar (x) => (y) {
y = Mul (x, x)
}
)ONNX";
ModelProto model;
InlineFunctions(model, code);
auto num_nodes = model.graph().node_size();
ASSERT_EQ(num_nodes, 2);
auto num_functions = model.functions_size();
ASSERT_EQ(num_functions, 0);
}
TEST(FunctionInliner, Renaming) {
const char* code = R"ONNX(
<ir_version: 8, opset_import: [ "" : 17, "local" : 1 ]>
agraph (float[N] X) => (float[N] Y)
{
temp = local.foo (X)
temp__1 = Mul (temp, temp)
Y = Abs (temp__1)
}
<opset_import: [ "" : 17, "local" : 1 ], domain: "local">
foo (x) => (y) {
temp = Add(x, x)
y = Neg (temp)
}
)ONNX";
ModelProto model;
// Check that renaming handles accidental collision of names: when "temp" in "foo" is
// inlined, it will be renamed into something distinct from "temp" and "temp__1" as
// both these names occur in the main graph.
InlineFunctions(model, code);
}
TEST(FunctionInliner, ValueInfoPropagation) {
const char* code = R"ONNX(
<ir_version: 10, opset_import: [ "" : 17, "local" : 1 ]>
agraph (float[N] X) => (float[N] Y)
{
result = local.foo (X)
Y = Abs (result)
}
<opset_import: [ "" : 17, "local" : 1 ], domain: "local">
foo (x) => (y)
<float[N] temp> {
temp = Add(x, x)
y = Neg (temp)
}
)ONNX";
ModelProto model;
InlineFunctions(model, code);
// Check that valueinfo is propagated from the function to the main graph.
auto& graph = model.graph();
auto& temp_new_name = graph.node(0).output(0);
auto& valueinfos = graph.value_info();
for (auto& valueinfo : valueinfos) {
if (valueinfo.name() == temp_new_name) {
ASSERT_TRUE(valueinfo.has_type());
ASSERT_TRUE(valueinfo.type().has_tensor_type());
ASSERT_TRUE(valueinfo.type().tensor_type().has_shape());
ASSERT_TRUE(valueinfo.type().tensor_type().shape().dim_size() == 1);
return;
}
}
ASSERT_TRUE(false) << "ValueInfo not found";
}
TEST(FunctionInliner, TwoCallsToSameFunction) {
const char* code = R"ONNX(
<ir_version: 8, opset_import: [ "" : 17, "local" : 1 ]>
agraph (float[N] X) => (float[N] Y)
{
temp = local.foo (X)
Y = local.foo (temp)
}
<opset_import: [ "" : 17, "local" : 1 ], domain: "local">
foo (x) => (y) {
temp = Add(x, x)
y = Neg (temp)
}
)ONNX";
ModelProto model;
// The call below will check that multiple assignments to the same name do not happen
// after inlining two calls to the same function.
InlineFunctions(model, code);
}
TEST(FunctionInliner, OpsetMismatch) {
const char* code = R"ONNX(
<ir_version: 8, opset_import: [ "" : 17, "local" : 1 ]>
agraph (float[N] X) => (float[N] Y)
{
temp = local.foo (X)
Y = local.bar (temp)
}
<opset_import: [ "" : 18], domain: "local">
foo (x) => (y) {
y = Add(x, x)
}
<opset_import: [ "" : 17], domain: "local">
bar (x) => (y) {
y = Add(x, x)
}
)ONNX";
ModelProto model;
InlineFunctions(model, code);
// The first node's call, to foo, must be inlined.
auto& first_node = model.graph().node(0);
// Check that it is a call to Add
ASSERT_EQ(first_node.op_type(), "Add");
// The second node's call, to bar, must be inlined.
auto& second_node = model.graph().node(1);
// Check that it is a call to Add
ASSERT_EQ(second_node.op_type(), "Add");
ASSERT_EQ(model.functions_size(), 0);
}
TEST(FunctionInliner, SelectiveInlining) {
const char* code = R"ONNX(
<ir_version: 8, opset_import: [ "" : 17, "local" : 1 ]>
agraph (float[N] X) => (float[N] Y)
{
temp = local.foo (X)
Y = local.bar (temp)
}
<opset_import: [ "" : 17], domain: "local">
foo (x) => (y) {
y = Add(x, x)
}
<opset_import: [ "" : 17, "local" : 1], domain: "local">
bar (x) => (y) {
y = local.foo(x)
}
)ONNX";
ModelProto model;
inliner::FunctionIdVector to_inline = {{"local", "foo"}};
auto to_inline_set = inliner::FunctionIdSet::Create(std::move(to_inline));
InlineFunctions(model, code, to_inline_set.get());
// The first node's call, to foo, must be inlined.
auto& first_node = model.graph().node(0);
// Check that it is a call to Add
ASSERT_EQ(first_node.op_type(), "Add");
// The second node's call, to bar, must not be inlined.
auto& second_node = model.graph().node(1);
// Check that it is a call to bar
ASSERT_EQ(second_node.op_type(), "bar");
// foo will be removed, bar will remain, in model.functions()
ASSERT_EQ(model.functions_size(), 1);
auto& bar_node = model.functions(0).node(0);
// Check that it is a call to Add, due to inlining
// the call to foo in bar.
ASSERT_EQ(bar_node.op_type(), "Add");
}
TEST(FunctionInliner, VersionConversion) {
const char* code = R"ONNX(
<ir_version: 8, opset_import: [ "" : 18, "local" : 1 ]>
agraph (float[N,M] X) => (float[N,M] Y)
{
Y = local.foo (X)
}
<opset_import: [ "" : 17], domain: "local">
foo (x) => (y) {
y = ReduceLogSum <axes = [0]> (x)
}
)ONNX";
ModelProto model;
InlineFunctions(model, code);
// Inlining ReduceLogSum (version 17) should convert it to ReduceLogSum (version 18)
// by promoting axes from attribute to input.
auto& node = model.graph().node(1);
ASSERT_EQ(node.op_type(), "ReduceLogSum");
ASSERT_EQ(node.input_size(), 2);
ASSERT_EQ(node.attribute_size(), 0);
}
TEST(FunctionInliner, NestedVersionConversion) {
const char* code = R"ONNX(
<ir_version: 8, opset_import: [ "" : 18, "local" : 1 ]>
agraph (float[N,M] X) => (float[N,M] Y)
{
Y = local.foo (X)
}
<opset_import: [ "" : 17, "local" : 1], domain: "local">
foo (x) => (y) {
t = ReduceLogSum <axes = [0]> (x)
y = local.bar (t)
}
<opset_import: [ "" : 17], domain: "local">
bar (x) => (y) {
y = ReduceLogSum <axes = [1]> (x)
}
)ONNX";
ModelProto model;
InlineFunctions(model, code);
// Inlining ReduceLogSum (version 17) should convert it to ReduceLogSum (version 18)
// by promoting axes from attribute to input, with a preceding Constant node for
// the axes value.
// Check that both ReduceLogSum nodes have been converted.
ASSERT_EQ(model.graph().node_size(), 4);
ASSERT_EQ(model.graph().node(0).op_type(), "Constant");
auto& node = model.graph().node(1);
ASSERT_EQ(node.op_type(), "ReduceLogSum");
ASSERT_EQ(node.input_size(), 2);
ASSERT_EQ(node.attribute_size(), 0);
ASSERT_EQ(model.graph().node(2).op_type(), "Constant");
auto node2 = model.graph().node(3);
ASSERT_EQ(node2.op_type(), "ReduceLogSum");
ASSERT_EQ(node2.input_size(), 2);
ASSERT_EQ(node2.attribute_size(), 0);
}
} // namespace Test
} // namespace ONNX_NAMESPACE


@@ -0,0 +1,60 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <cctype>
#include <iostream>
#include <string>
#include "gtest/gtest.h"
#include "onnx/common/ir.h"
#include "onnx/common/ir_pb_converter.h"
#include "onnx/defs/printer.h"
namespace ONNX_NAMESPACE {
namespace Test {
static bool IsValidIdentifier(const std::string& name) {
if (name.empty()) {
return false;
}
if (!isalpha(name[0]) && name[0] != '_') {
return false;
}
for (size_t i = 1; i < name.size(); ++i) {
if (!isalnum(name[i]) && name[i] != '_') {
return false;
}
}
return true;
}
TEST(IR, ValidIdentifierTest) {
Graph* g = new Graph();
g->setName("test");
Value* x = g->addInput();
x->setUniqueName("x");
x->setElemType(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
x->setSizes({Dimension("M"), Dimension("N")});
Node* node1 = g->create(kNeg, 1);
node1->addInput(x);
g->appendNode(node1);
Value* temp1 = node1->outputs()[0];
Node* node2 = g->create(kNeg, 1);
node2->addInput(temp1);
g->appendNode(node2);
Value* y = node2->outputs()[0];
g->registerOutput(y);
ModelProto model;
ExportModelProto(&model, std::shared_ptr<Graph>(g));
for (auto& node : model.graph().node()) {
for (auto& name : node.output()) {
EXPECT_TRUE(IsValidIdentifier(name));
}
}
}
} // namespace Test
} // namespace ONNX_NAMESPACE


@@ -0,0 +1,28 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <iostream>
#include "gtest/gtest.h"
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
namespace Test {
TEST(OpRegistrationTest, GemmOp) {
auto opSchema = OpSchemaRegistry::Schema("Gemm");
EXPECT_TRUE(nullptr != opSchema);
size_t input_size = opSchema->inputs().size();
EXPECT_EQ(input_size, 3);
EXPECT_EQ(opSchema->inputs()[0].GetTypes(), opSchema->outputs()[0].GetTypes());
size_t attr_size = opSchema->attributes().size();
EXPECT_EQ(attr_size, 4);
EXPECT_NE(opSchema->attributes().count("alpha"), 0);
EXPECT_EQ(opSchema->attributes().at("alpha").type, AttributeProto_AttributeType_FLOAT);
EXPECT_NE(opSchema->attributes().count("beta"), 0);
EXPECT_EQ(opSchema->attributes().at("beta").type, AttributeProto_AttributeType_FLOAT);
}
} // namespace Test
} // namespace ONNX_NAMESPACE


@@ -0,0 +1,667 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "gtest/gtest.h"
#include "onnx/checker.h"
#include "onnx/defs/parser.h"
#include "onnx/defs/printer.h"
using namespace ONNX_NAMESPACE;
namespace ONNX_NAMESPACE {
namespace Test {
template <typename T>
static void Parse(T& parsedData, const char* input) {
OnnxParser parser(input);
auto status = parser.Parse(parsedData);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
EXPECT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected.";
// Extra checks for printer:
// Check we can convert data back to text form.
std::string text1 = ProtoToString(parsedData);
// Check that we can round-trip between the two representations.
// We cannot expect equality between text1 and input due to white-space and syntactic sugar,
// so, we convert it once more, and check for equality.
T temp;
status = OnnxParser::Parse(temp, text1.c_str());
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
std::string text2 = ProtoToString(temp);
EXPECT_EQ(text1, text2);
}
template <typename T>
static void ExpectParseFailure(T& parsedData, const char* input) {
auto status = OnnxParser::Parse(parsedData, input);
EXPECT_FALSE(status.IsOK());
}
static void CheckModel(const char* code) {
ModelProto model;
Parse(model, code);
checker::check_model(model);
}
TEST(ParserTest, EscapeStringLiteral) {
OnnxParser parser(R"(
"123\"56\\89"
)");
std::string s;
auto status = parser.ParserBase::Parse(s);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
EXPECT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected.";
EXPECT_EQ(s, std::string("123\"56\\89"));
}
TEST(ParserTest, TypeTest) {
TypeProto type;
// 1-dimensional tensor type with symbolic dimension:
Parse(type, "float[N]");
EXPECT_TRUE(type.has_tensor_type());
int float_type = static_cast<int>(TensorProto_DataType::TensorProto_DataType_FLOAT);
int int32_type = static_cast<int>(TensorProto_DataType::TensorProto_DataType_INT32);
EXPECT_EQ(type.tensor_type().elem_type(), float_type);
EXPECT_TRUE(type.tensor_type().has_shape());
EXPECT_EQ(type.tensor_type().shape().dim_size(), 1);
EXPECT_EQ(type.tensor_type().shape().dim(0).dim_param(), "N");
// scalar type:
Parse(type, "float");
EXPECT_TRUE(type.has_tensor_type());
EXPECT_EQ(type.tensor_type().elem_type(), float_type);
EXPECT_TRUE(type.tensor_type().has_shape());
EXPECT_EQ(type.tensor_type().shape().dim_size(), 0);
// tensor type with unknown rank:
Parse(type, "float[]");
EXPECT_TRUE(type.has_tensor_type());
EXPECT_EQ(type.tensor_type().elem_type(), float_type);
EXPECT_FALSE(type.tensor_type().has_shape());
// 3-dimensional tensor
Parse(type, "float[N,M,K]");
EXPECT_EQ(type.tensor_type().shape().dim_size(), 3);
// Unspecified dimension (neither symbolic nor constant)
Parse(type, "float[N,?,K]");
EXPECT_FALSE(type.tensor_type().shape().dim(1).has_dim_param());
EXPECT_FALSE(type.tensor_type().shape().dim(1).has_dim_value());
// sequence type:
Parse(type, "seq(float[])");
EXPECT_TRUE(type.has_sequence_type());
auto& elttype = type.sequence_type().elem_type();
EXPECT_TRUE(elttype.has_tensor_type());
EXPECT_EQ(elttype.tensor_type().elem_type(), float_type);
EXPECT_FALSE(elttype.tensor_type().has_shape());
// optional type:
Parse(type, "optional(float)");
EXPECT_TRUE(type.has_optional_type());
auto& optelttype = type.optional_type().elem_type();
EXPECT_TRUE(optelttype.has_tensor_type());
EXPECT_EQ(optelttype.tensor_type().elem_type(), float_type);
EXPECT_TRUE(optelttype.tensor_type().has_shape());
// sparse tensor type:
Parse(type, "sparse_tensor(float[1000])");
EXPECT_TRUE(type.has_sparse_tensor_type());
EXPECT_EQ(type.sparse_tensor_type().elem_type(), float_type);
EXPECT_EQ(type.sparse_tensor_type().shape().dim_size(), 1);
// map type:
Parse(type, "map(int32, float[N])");
EXPECT_TRUE(type.has_map_type());
EXPECT_EQ(type.map_type().key_type(), int32_type);
auto& valtype = type.map_type().value_type();
EXPECT_TRUE(valtype.has_tensor_type());
EXPECT_EQ(valtype.tensor_type().elem_type(), float_type);
EXPECT_EQ(valtype.tensor_type().shape().dim_size(), 1);
}
TEST(ParserTest, TensorProtoTest) {
TensorProto tensorProto;
// Concrete tensor-type with numeric dimensions expected:
ExpectParseFailure(tensorProto, "int32[] {1, 2, 3, 4, 5}");
// Symbolic dimensions are not allowed.
ExpectParseFailure(tensorProto, "int32[N] {1, 2, 3, 4, 5}");
Parse(tensorProto, "int32[5] {1, 2, 3, 4, 5}");
Parse(tensorProto, "int32[5] T {1, 2, 3, 4, 5}");
EXPECT_EQ(tensorProto.name(), "T");
Parse(tensorProto, "float[5] {1, 2.0, 3.1, 4, 5.5}");
Parse(tensorProto, "float[5] {1e1, 2.0e-1, 3.1E-1, 4E+1, 5.5e-10}");
Parse(tensorProto, "string[2] { \"Hello\", \"World\" }");
// String literals with escape character
Parse(tensorProto, R"(
string[2] { "Use a \"quoted\" word", "Use a backslash \\ like this." }
)");
}
TEST(ParserTest, AttributeTest) {
AttributeProto attr;
Parse(attr, "x = 2");
EXPECT_EQ(attr.name(), "x");
EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_INT);
EXPECT_EQ(attr.i(), 2);
Parse(attr, "x = 0.625");
EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT);
EXPECT_FLOAT_EQ(attr.f(), 0.625);
Parse(attr, "x = [2, 4, 6]");
EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_INTS);
EXPECT_EQ(attr.ints_size(), 3);
Parse(attr, "x = [0.125, 0.625]");
EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS);
EXPECT_EQ(attr.floats_size(), 2);
Parse(attr, "x = float[3] {2.1, 4.1, 6.1}");
EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR);
Parse(attr, "x = \"astring\"");
EXPECT_EQ(attr.name(), "x");
EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_STRING);
EXPECT_EQ(attr.s(), "astring");
Parse(attr, "x = [\"abc\", \"def\"]");
EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS);
Parse(attr, "x : ints = @xyz");
EXPECT_EQ(attr.ref_attr_name(), "xyz");
EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_INTS);
Parse(attr, "x : ints = []");
EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_INTS);
EXPECT_EQ(attr.ints_size(), 0);
Parse(attr, R"ONNX(
body = somegraph (float[N] y, float[N] z) => (float[N] w)
{
x = foo(y, z)
w = bar(x, y)
}
)ONNX");
EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPH);
EXPECT_EQ(attr.g().node_size(), 2);
Parse(attr, "type = float[3]");
EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_TYPE_PROTO);
EXPECT_TRUE(attr.tp().has_tensor_type());
int float_type = static_cast<int>(TensorProto_DataType::TensorProto_DataType_FLOAT);
EXPECT_EQ(attr.tp().tensor_type().elem_type(), float_type);
}
TEST(ParserTest, AttrListTest) {
const char* code = R"ONNX(
<
x = 2,
w = 3
>
)ONNX";
AttrList attributes;
Parse(attributes, code);
EXPECT_EQ(attributes.size(), 2);
EXPECT_EQ(attributes.Get(0).name(), "x");
EXPECT_EQ(attributes.Get(1).name(), "w");
}
TEST(ParserTest, DomainOpCallTest) {
const char* code = "x = somedomain.foo(y, z)";
NodeProto n;
Parse(n, code);
}
TEST(ParserTest, NodeTest) {
const char* code = "x = foo(y, z)";
NodeProto n;
Parse(n, code);
EXPECT_EQ(n.input_size(), 2);
EXPECT_EQ(n.input(0), "y");
EXPECT_EQ(n.input(1), "z");
EXPECT_EQ(n.output_size(), 1);
EXPECT_EQ(n.output(0), "x");
EXPECT_EQ(n.op_type(), "foo");
NodeList nl;
Parse(nl, R"ONNX(
{
sub_result = Sub(limit, start)
sub_result_casted = Cast<to = 1>(sub_result)
delta_casted = Cast<to = 1>(delta)
div_result = Div(sub_result_casted, delta_casted)
ceil_result = Ceil(div_result)
ceil_result_relu = Relu(ceil_result)
ceil_result_relu_int = Cast<to = 7>(ceil_result_relu)
ceil_result_relu_bool = Cast<to = 9>(ceil_result_relu)
variadic_output, output = Loop (ceil_result_relu_int, ceil_result_relu_bool, start)
}
)ONNX");
}
TEST(ParserTest, QualifiedOpNameTest) {
const char* code = "x = com.example.foo(y, z)";
NodeProto n;
Parse(n, code);
EXPECT_EQ(n.domain(), "com.example");
EXPECT_EQ(n.op_type(), "foo");
}
TEST(ParserTest, NodeListTest) {
const char* code = R"ONNX(
{
x = foo(y, z)
w = bar(x, y)
}
)ONNX";
GraphProto graph;
Parse(*graph.mutable_node(), code);
EXPECT_EQ(graph.node_size(), 2);
EXPECT_EQ(graph.node(0).op_type(), "foo");
EXPECT_EQ(graph.node(1).op_type(), "bar");
}
TEST(ParserTest, NodeAttrTest1) {
const char* code = "x = foo <a = 100, b = 200.5, c = \"astring\"> (y, z)";
NodeProto n;
Parse(n, code);
EXPECT_EQ(n.attribute_size(), 3);
EXPECT_EQ(n.attribute(0).name(), "a");
EXPECT_EQ(n.attribute(1).name(), "b");
EXPECT_EQ(n.attribute(2).name(), "c");
}
TEST(ParserTest, NodeAttrTest2) {
const char* code = "x = foo <d = [5, 10], e = [0.55, 0.66], f = [\"str1\", \"str2\"]> (y, z)";
NodeProto n;
Parse(n, code);
EXPECT_EQ(n.attribute_size(), 3);
}
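// w1 and w2 are declared with values and become initializers; x is declared
// with a type but no value, so it is recorded in value_info instead.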
TEST(ParserTest, GraphTest) {
const char* code = R"ONNX(
agraph (float[N] y, float[N] z) => (float[N] w)
<float[2] w1 = {1.0, 2.0}, float[3] w2 = {4.0, 5.0, 6.0}, float[N] x>
{
# This is a comment.
x = foo(y, z, w1) # More comments.
w = bar(x, y, w2)
}
)ONNX";
GraphProto graph;
Parse(graph, code);
EXPECT_EQ(graph.name(), "agraph");
EXPECT_EQ(graph.input_size(), 2);
EXPECT_EQ(graph.output_size(), 1);
EXPECT_EQ(graph.node_size(), 2);
EXPECT_EQ(graph.initializer_size(), 2);
EXPECT_EQ(graph.value_info_size(), 1);
}
TEST(ParserTest, GraphPartialTypeTest) {
const char* code = R"ONNX(
agraph (float[N] y, z) => (float[N] w)
{
x = foo(y, z)
w = bar(x, y)
}
)ONNX";
GraphProto graph;
Parse(graph, code);
EXPECT_EQ(graph.name(), "agraph");
EXPECT_EQ(graph.input_size(), 2);
EXPECT_EQ(graph.output_size(), 1);
}
TEST(ParserTest, FunctionTest) {
const char* code = R"ONNX(
<
opset_import: [ "" : 10 ],
domain: "ai.onnx.ml",
doc_string: "A function test case."
>
f (y, z) => (w)
{
x = Add(y, z)
w = Mul(x, y)
}
)ONNX";
FunctionProto fp;
Parse(fp, code);
EXPECT_EQ(fp.name(), "f");
EXPECT_EQ(fp.input_size(), 2);
EXPECT_EQ(fp.output_size(), 1);
EXPECT_EQ(fp.node_size(), 2);
EXPECT_EQ(fp.attribute_size(), 0);
EXPECT_EQ(fp.opset_import_size(), 1);
}
TEST(ParserTest, FunctionValueInfoTest) {
const char* code = R"ONNX(
<
opset_import: [ "" : 10 ],
domain: "ai.onnx.ml",
doc_string: "A function test case."
>
f (float[N] y, float[N] z) => (float[N] w)
{
x = Add(y, z)
w = Mul(x, y)
}
)ONNX";
FunctionProto fp;
Parse(fp, code);
EXPECT_EQ(fp.input_size(), 2);
EXPECT_EQ(fp.output_size(), 1);
ASSERT_EQ(fp.value_info_size(), 3);
EXPECT_EQ(fp.value_info(0).name(), "y");
EXPECT_EQ(fp.value_info(1).name(), "z");
EXPECT_EQ(fp.value_info(2).name(), "w");
}
TEST(ParserTest, FunctionValueInfoTest2) {
const char* code = R"ONNX(
<
opset_import: [ "" : 10 ],
domain: "ai.onnx.ml",
doc_string: "A function test case."
>
f (float[N] y, float[N] z) => (float[N] w)
<float[N] x>
{
x = Add(y, z)
w = Mul(x, y)
}
)ONNX";
FunctionProto fp;
Parse(fp, code);
EXPECT_EQ(fp.input_size(), 2);
EXPECT_EQ(fp.value_info_size(), 4);
ASSERT_EQ(fp.output_size(), 1);
EXPECT_EQ(fp.value_info(0).name(), "y");
EXPECT_EQ(fp.value_info(1).name(), "z");
EXPECT_EQ(fp.value_info(2).name(), "w");
EXPECT_EQ(fp.value_info(3).name(), "x");
}
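// Input z is declared without a type, so value_info records only the typed
// values: y, w, x, and t.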
TEST(ParserTest, FunctionValueInfoTest3) {
const char* code = R"ONNX(
<
opset_import: [ "" : 10 ],
domain: "ai.onnx.ml",
doc_string: "A function test case."
>
f (float[N] y, z) => (float[N] w)
<float[N] x, float[N] t>
{
x = Add(y, z)
t = Add(x, x)
w = Mul(t, y)
}
)ONNX";
FunctionProto fp;
Parse(fp, code);
EXPECT_EQ(fp.input_size(), 2);
ASSERT_EQ(fp.value_info_size(), 4);
EXPECT_EQ(fp.output_size(), 1);
EXPECT_EQ(fp.value_info(0).name(), "y");
EXPECT_EQ(fp.value_info(1).name(), "w");
EXPECT_EQ(fp.value_info(2).name(), "x");
EXPECT_EQ(fp.value_info(3).name(), "t");
}
TEST(ParserTest, InitializerTest) {
const char* code = R"ONNX(
agraph (float y = {1.0}, float[N] z) => (float[N] w)
<float[2] w1 = {1.0, 2.0}, float[3] w2 = {4.0, 5.0, 6.0}, float[N] x>
{
x = foo(y, z, w1)
w = bar(x, y, w2)
}
)ONNX";
GraphProto graph;
Parse(graph, code);
EXPECT_EQ(graph.input_size(), 2);
EXPECT_EQ(graph.output_size(), 1);
EXPECT_EQ(graph.initializer_size(), 3); // y, w1, w2
EXPECT_EQ(graph.value_info_size(), 1); // x
}
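// The then_branch and else_branch subgraphs are parsed as graph-valued
// attributes of the If node, giving attribute_size() == 2.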
TEST(ParserTest, IfNodeTest) {
const char* code = R"ONNX(
z = If (b) <
then_branch = g1 () => (float[N] z_then)
{
z_then = foo(y)
},
else_branch = g2 () => (float[N] z_else)
{
z_else = bar(x)
}
>
)ONNX";
NodeProto node;
Parse(node, code);
EXPECT_EQ(node.input_size(), 1);
EXPECT_EQ(node.output_size(), 1);
EXPECT_EQ(node.attribute_size(), 2);
}
TEST(ParserTest, ModelTest) {
const char* code = R"ONNX(
<
ir_version: 7,
opset_import: [ "ai.onnx.ml" : 10 ],
producer_name: "ParserTest",
producer_version: "1.0",
domain: "ai.onnx.ml",
model_version: 1,
doc_string: "A parser test case model.",
metadata_props: [ "somekey" : "somevalue", "key2" : "value2" ]
>
agraph (float[N] y, float[N] z) => (float[N] w)
{
x = foo(y, z)
w = bar(x, y)
}
)ONNX";
ModelProto model;
Parse(model, code);
EXPECT_EQ(model.graph().input_size(), 2);
EXPECT_EQ(model.graph().output_size(), 1);
EXPECT_EQ(model.graph().node_size(), 2);
}
TEST(ParserTest, ModelCheckTest) {
const char* code = R"ONNX(
<
ir_version: 7,
opset_import: [ "" : 10 ]
>
agraph (float[N, 128] X, float[128,10] W, float[10] B) => (float[N] C)
{
T = MatMul(X, W)
S = Add(T, B)
C = Softmax(S)
}
)ONNX";
CheckModel(code);
}
TEST(ParserTest, IfModelTest) {
const char* code = R"ONNX(
<
ir_version: 7,
opset_import: [ "" : 13 ]
>
iftest (bool b, float[128] X, float[128] Y) => (float[128] Z)
{
Z = If (b) <
then_branch = g1 () => (float[128] z_then) { z_then = Identity(X) },
else_branch = g2 () => (float[128] z_else) { z_else = Identity(Y) }
>
}
)ONNX";
CheckModel(code);
}
TEST(ParserTest, FunModelTest) {
const char* code = R"ONNX(
<
ir_version: 8,
opset_import: [ "" : 10, "local" : 1 ]
>
agraph (float[N, 128] X, float[128,10] W, float[10] B) => (float[N] C)
{
T = local.foo (X, W, B)
C = local.square(T)
}
<
opset_import: [ "" : 10 ],
domain: "local",
doc_string: "Function foo."
>
foo (x, w, b) => (c) {
T = MatMul(x, w)
S = Add(T, b)
c = Softmax(S)
}
<
opset_import: [ "" : 10 ],
domain: "local",
doc_string: "Function square."
>
square (x) => (y) {
y = Mul (x, x)
}
)ONNX";
CheckModel(code);
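// The second model exercises function attributes: alpha has a default value
// (4.0) and gamma is a declared attribute; both are referenced via @alpha
// and @gamma in the function body.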
const char* code_function_with_attributes = R"ONNX(
<
ir_version: 9,
opset_import: [ "" : 15, "custom_domain" : 1]
>
agraph (float[N] x) => (float[N] out)
{
out = custom_domain.foo<alpha=2.0, gamma=3.0>(x)
}
<
domain: "custom_domain",
opset_import: [ "" : 15],
doc_string: "function foo"
>
foo
<alpha: float=4.0, gamma>
(X) => (C)
{
constant_alpha = Constant<value_float: float=@alpha>()
constant_gamma = Constant<value_float: float=@gamma>()
constant_alpha_x = Mul(constant_alpha, X)
C = Add(constant_alpha_x, constant_gamma)
}
)ONNX";
CheckModel(code_function_with_attributes);
}
TEST(ParserTest, TypesModelTest1) {
const char* code = R"ONNX(
<
ir_version: 8,
opset_import: [ "" : 18 ]
>
agraph (seq(float[N]) seqX) => (float[M, N] X)
{
X = ConcatFromSequence < axis = 0, new_axis = 1 >(seqX)
}
)ONNX";
CheckModel(code);
}
TEST(ParserTest, TypesModelTest2) {
const char* code = R"ONNX(
<
ir_version: 8,
opset_import: [ "" : 18 ]
>
agraph (float[N] tensorX, seq(float[N]) seqX, map(int32, float[N]) mapX, optional(float[N]) optionalX, sparse_tensor(float[N]) sparseX) => (float[N] X)
{
X = Identity (tensorX)
}
)ONNX";
CheckModel(code);
}
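// An initializer written as ["key": "value", ...] is stored as external data
// (data_location EXTERNAL) rather than as inline tensor values.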
TEST(ParserTest, ExternalDataTest) {
const char* code = R"ONNX(
agraph (float y = {1.0}, float[N] z) => (w) <
float[3, 2] m1 = ["location": "weight_1.bin", "offset": "17"],
float[2, 1] m2 = {1.0, 2.0}
>
{
x = Add(y, z)
m = Mul(m1, m1)
}
)ONNX";
GraphProto graph;
Parse(graph, code);
EXPECT_EQ(graph.input_size(), 2);
EXPECT_EQ(graph.output_size(), 1);
EXPECT_EQ(graph.initializer_size(), 3); // y, m1, m2
EXPECT_EQ(graph.value_info_size(), 0); // no value_info entries
EXPECT_EQ(graph.initializer().Get(1).data_location(), TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL);
EXPECT_EQ(graph.initializer().Get(1).external_data().Get(0).key(), "location");
EXPECT_EQ(graph.initializer().Get(1).external_data().Get(0).value(), "weight_1.bin");
EXPECT_EQ(graph.initializer().Get(1).external_data().Get(1).key(), "offset");
EXPECT_EQ(graph.initializer().Get(1).external_data().Get(1).value(), "17");
}
} // namespace Test
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,275 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <iostream>
#include "gtest/gtest.h"
#include "onnx/defs/operator_sets.h"
#include "onnx/defs/schema.h"
using namespace ONNX_NAMESPACE;
namespace ONNX_NAMESPACE {
namespace Test {
TEST(SchemaRegistrationTest, DisabledOnnxStaticRegistrationAPICall) {
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION
EXPECT_TRUE(IsOnnxStaticRegistrationDisabled());
#else
EXPECT_FALSE(IsOnnxStaticRegistrationDisabled());
#endif
}
// Schemas for all opset versions are registered by default.
// Further schema manipulation is expected to be error-free.
TEST(SchemaRegistrationTest, RegisterAllByDefaultAndManipulateSchema) {
#ifndef __ONNX_DISABLE_STATIC_REGISTRATION
// Expects all opset registered by default
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 0);
// Should find schema for all versions
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 1));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 6));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 7));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 13));
// Clear all opset schema registration
DeregisterOnnxOperatorSetSchema();
// Should not find any opset
EXPECT_EQ(nullptr, OpSchemaRegistry::Schema("Add"));
// Register all opset versions
RegisterOnnxOperatorSetSchema();
// Should find all opset
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add"));
#endif
}
// By default ONNX registers all opset versions, so selective schema loading cannot be tested.
// These tests therefore run only when static registration is disabled.
TEST(SchemaRegistrationTest, RegisterAndDeregisterAllOpsetSchemaVersion) {
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION
// Clear all opset schema registration
DeregisterOnnxOperatorSetSchema();
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == -1);
// Should not find schema for any op
EXPECT_EQ(nullptr, OpSchemaRegistry::Schema("Acos"));
EXPECT_EQ(nullptr, OpSchemaRegistry::Schema("Add"));
EXPECT_EQ(nullptr, OpSchemaRegistry::Schema("Trilu"));
// Register all opset versions
RegisterOnnxOperatorSetSchema(0);
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 0);
// Should find schema for all ops. Available versions are:
// Acos-7
// Add-1,6,7,13,14
// Trilu-14
auto schema = OpSchemaRegistry::Schema("Acos", 7);
EXPECT_NE(nullptr, schema);
EXPECT_EQ(schema->SinceVersion(), 7);
schema = OpSchemaRegistry::Schema("Add", 14);
EXPECT_NE(nullptr, schema);
EXPECT_EQ(schema->SinceVersion(), 14);
schema = OpSchemaRegistry::Schema("Trilu");
EXPECT_NE(nullptr, schema);
EXPECT_EQ(schema->SinceVersion(), 14);
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 1));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 6));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 7));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 13));
// Clear all opset schema registration
DeregisterOnnxOperatorSetSchema();
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == -1);
// Should not find schema for any op
EXPECT_EQ(nullptr, OpSchemaRegistry::Schema("Acos"));
EXPECT_EQ(nullptr, OpSchemaRegistry::Schema("Add"));
EXPECT_EQ(nullptr, OpSchemaRegistry::Schema("Trilu"));
#endif
}
TEST(SchemaRegistrationTest, RegisterSpecifiedOpsetSchemaVersion) {
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION
DeregisterOnnxOperatorSetSchema();
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == -1);
RegisterOnnxOperatorSetSchema(13);
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 13);
auto opSchema = OpSchemaRegistry::Schema("Add");
EXPECT_NE(nullptr, opSchema);
EXPECT_EQ(opSchema->SinceVersion(), 13);
// Should not find opset 12
opSchema = OpSchemaRegistry::Schema("Add", 12);
EXPECT_EQ(nullptr, opSchema);
// Should not find opset 14
opSchema = OpSchemaRegistry::Schema("Trilu");
EXPECT_EQ(nullptr, opSchema);
// Acos-7 is the latest Acos before specified 13
opSchema = OpSchemaRegistry::Schema("Acos", 13);
EXPECT_NE(nullptr, opSchema);
EXPECT_EQ(opSchema->SinceVersion(), 7);
#endif
}
// Register opset-11, then opset-14
// Expects Reg(11, 14) == Reg(11) U Reg(14)
TEST(SchemaRegistrationTest, RegisterMultipleOpsetSchemaVersions_UpgradeVersion) {
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION
DeregisterOnnxOperatorSetSchema();
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == -1);
// Register opset 11
RegisterOnnxOperatorSetSchema(11);
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 11);
// Register opset 14
// Do not fail on duplicate schema registration request
RegisterOnnxOperatorSetSchema(14, false);
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 14);
// Acos-7 is the latest before/at opset 11 and 14
auto opSchema = OpSchemaRegistry::Schema("Acos");
EXPECT_NE(nullptr, opSchema);
EXPECT_EQ(opSchema->SinceVersion(), 7);
// Add-7 is the latest before/at opset 11
// Add-14 is the latest before/at opset 14
// Should find both Add-7,14
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 7));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 14));
// Should find the max version 14
opSchema = OpSchemaRegistry::Schema("Add");
EXPECT_NE(nullptr, opSchema);
EXPECT_EQ(opSchema->SinceVersion(), 14);
// Should find Add-7 as the max version <=13
opSchema = OpSchemaRegistry::Schema("Add", 13);
EXPECT_NE(nullptr, opSchema);
EXPECT_EQ(opSchema->SinceVersion(), 7);
// Should find opset 14
opSchema = OpSchemaRegistry::Schema("Trilu");
EXPECT_NE(nullptr, opSchema);
EXPECT_EQ(opSchema->SinceVersion(), 14);
#endif
}
// Register opset-14, then opset-11
// Expects Reg(14, 11) == Reg(11) U Reg(14)
TEST(SchemaRegistrationTest, RegisterMultipleOpsetSchemaVersions_DowngradeVersion) {
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION
DeregisterOnnxOperatorSetSchema();
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == -1);
// Register opset 14
RegisterOnnxOperatorSetSchema(14);
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 14);
// Register opset 11
// Do not fail on duplicate schema registration request
RegisterOnnxOperatorSetSchema(11, false);
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 11);
// Acos-7 is the latest before/at opset 11 and 14
auto opSchema = OpSchemaRegistry::Schema("Acos");
EXPECT_NE(nullptr, opSchema);
EXPECT_EQ(opSchema->SinceVersion(), 7);
// Add-7 is the latest before/at opset 11
// Add-14 is the latest before/at opset 14
// Should find both Add-7,14
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 7));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 14));
// Should find the max version 14
opSchema = OpSchemaRegistry::Schema("Add");
EXPECT_NE(nullptr, opSchema);
EXPECT_EQ(opSchema->SinceVersion(), 14);
// Should find Add-7 as the max version <=13
opSchema = OpSchemaRegistry::Schema("Add", 13);
EXPECT_NE(nullptr, opSchema);
EXPECT_EQ(opSchema->SinceVersion(), 7);
// Should find opset 14
opSchema = OpSchemaRegistry::Schema("Trilu");
EXPECT_NE(nullptr, opSchema);
EXPECT_EQ(opSchema->SinceVersion(), 14);
#endif
}
// Register opset-11, then all versions
// Expects no error
TEST(SchemaRegistrationTest, RegisterSpecificThenAllVersion) {
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION
DeregisterOnnxOperatorSetSchema();
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == -1);
// Register opset 11
RegisterOnnxOperatorSetSchema(11);
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 11);
// Register all opset versions
// Do not fail on duplicate schema registration request
RegisterOnnxOperatorSetSchema(0, false);
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 0);
// Should find schema for all ops
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Acos"));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add"));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Trilu"));
// Should find schema for all versions
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 1));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 6));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 7));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 13));
#endif
}
// Register all versions, then opset 11
// Expects no error
TEST(SchemaRegistrationTest, RegisterAllThenSpecificVersion) {
#ifdef __ONNX_DISABLE_STATIC_REGISTRATION
DeregisterOnnxOperatorSetSchema();
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == -1);
// Register all opset versions
RegisterOnnxOperatorSetSchema(0);
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 0);
// Register opset 11
// Do not fail on duplicate schema registration request
RegisterOnnxOperatorSetSchema(11, false);
EXPECT_TRUE(OpSchemaRegistry::Instance()->GetLoadedSchemaVersion() == 11);
// Should find schema for all ops
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Acos"));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add"));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Trilu"));
// Should find schema for all versions
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 1));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 6));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 7));
EXPECT_NE(nullptr, OpSchemaRegistry::Schema("Add", 13));
#endif
}
} // namespace Test
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,660 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <iostream>
#include "gtest/gtest.h"
#include "onnx/defs/parser.h"
#include "onnx/defs/schema.h"
#include "onnx/defs/shape_inference.h"
#include "onnx/onnx_pb.h"
#include "onnx/shape_inference/implementation.h"
using namespace ONNX_NAMESPACE::shape_inference;
namespace ONNX_NAMESPACE {
// onnx/defs/controlflow/old.cc
void ScanInferenceFunctionOpset8(InferenceContext& ctx);
// onnx/defs/controlflow/defs.cc
void ScanInferenceFunction(InferenceContext& ctx);
namespace Test {
template <class Type>
void CreateDims(Type& proto, int num_dims) {
auto mutable_shape = proto.mutable_shape();
mutable_shape->clear_dim();
for (int i = 0; i < num_dims; ++i)
mutable_shape->add_dim();
}
template <class Type>
void SetDimValues(Type& proto, const std::vector<int>& values) {
auto* mutable_shape = proto.mutable_shape();
EXPECT_TRUE(mutable_shape->dim_size() == static_cast<int>(values.size()));
int idx = 0;
for (auto value : values) {
auto mutable_dim = mutable_shape->mutable_dim(idx++);
if (value != -1)
mutable_dim->set_dim_value(value);
}
}
template <class Type>
void SetDimParams(Type& proto, const std::vector<const std::string*>& values) {
auto mutable_shape = proto.mutable_shape();
EXPECT_TRUE(mutable_shape->dim_size() == static_cast<int>(values.size()));
int idx = 0;
for (auto value : values) {
auto mutable_dim = mutable_shape->mutable_dim(idx++);
if (value)
mutable_dim->set_dim_param(*value);
}
}
template <class Type>
void Dump(const Type& t) {
auto& s_shape = t.shape();
auto num_dims = s_shape.dim_size();
std::cout << num_dims << " dims. ";
for (int i = 0; i < num_dims; ++i) {
auto x = s_shape.dim(i);
auto y = x.has_dim_value();
auto z = x.has_dim_param();
std::cout << "Dim " << i << " Value:" << (y ? ONNX_NAMESPACE::to_string(x.dim_value()) : "<unset>")
<< ", Param:" << (z ? x.dim_param() : "<unset>") << "\n";
}
}
TEST(ShapeInferenceTest, mergeShapeInfo_HasShape) {
// source has shape, target doesn't
{
TypeProto_Tensor source;
TypeProto_Tensor target;
CreateDims(source, 1);
SetDimValues(source, {1});
mergeInShapeInfo(source, target);
Dump(target);
auto& shape = target.shape();
EXPECT_TRUE(shape.dim_size() == 1 && shape.dim(0).dim_value() == 1);
}
// source has no shape, target does
{
TypeProto_Tensor source;
TypeProto_Tensor target;
CreateDims(target, 1);
SetDimValues(target, {1});
mergeInShapeInfo(source, target);
Dump(target);
auto& shape = target.shape();
EXPECT_TRUE(shape.dim_size() == 1 && shape.dim(0).dim_value() == 1);
}
// source has shape, target doesn't
{
TypeProto_SparseTensor source;
TypeProto_SparseTensor target;
CreateDims(source, 1);
SetDimValues(source, {1});
mergeInShapeInfo(source, target);
Dump(target);
auto& shape = target.shape();
EXPECT_TRUE(shape.dim_size() == 1 && shape.dim(0).dim_value() == 1);
}
// source has no shape, target does
{
TypeProto_SparseTensor source;
TypeProto_SparseTensor target;
CreateDims(target, 1);
SetDimValues(target, {1});
mergeInShapeInfo(source, target);
Dump(target);
auto& shape = target.shape();
EXPECT_TRUE(shape.dim_size() == 1 && shape.dim(0).dim_value() == 1);
}
}
TEST(ShapeInferenceTest, mergeShapeInfo_PreferValueOverParam) {
std::string param = "A";
// source has value, target has param. prefer value
{
TypeProto_Tensor source;
TypeProto_Tensor target;
CreateDims(source, 1);
SetDimValues(source, {1});
CreateDims(target, 1);
SetDimParams(target, {&param});
mergeInShapeInfo(source, target);
Dump(target);
auto& shape = target.shape();
EXPECT_TRUE(shape.dim_size() == 1 && shape.dim(0).dim_value() == 1);
}
// source has param, target has value.
{
TypeProto_Tensor source;
TypeProto_Tensor target;
CreateDims(source, 1);
SetDimParams(source, {&param});
CreateDims(target, 1);
SetDimValues(target, {1});
mergeInShapeInfo(source, target);
Dump(target);
auto& shape = target.shape();
EXPECT_TRUE(shape.dim_size() == 1 && shape.dim(0).dim_value() == 1);
}
}
TEST(ShapeInferenceTest, mergeShapeInfo_CombineShapes) {
// merge from both sides, preferring real value over -1
{
TypeProto_Tensor source;
TypeProto_Tensor target;
CreateDims(source, 2);
SetDimValues(source, {-1, 2});
CreateDims(target, 2);
SetDimValues(target, {1, -1});
mergeInShapeInfo(source, target);
Dump(target);
auto& shape = target.shape();
EXPECT_TRUE(shape.dim(0).dim_value() == 1 && shape.dim(1).dim_value() == 2);
}
{
TypeProto_SparseTensor source;
TypeProto_SparseTensor target;
CreateDims(source, 2);
SetDimValues(source, {-1, 2});
CreateDims(target, 2);
SetDimValues(target, {1, -1});
mergeInShapeInfo(source, target);
Dump(target);
auto& shape = target.shape();
EXPECT_TRUE(shape.dim(0).dim_value() == 1 && shape.dim(1).dim_value() == 2);
}
// prefer value over param,
{
TypeProto_Tensor source;
TypeProto_Tensor target;
CreateDims(source, 2);
SetDimValues(source, {-1, 2});
CreateDims(target, 2);
SetDimValues(target, {1, 0});
// replace second dim with a param. the value from the source should be
// preferred
const std::string param = "A";
target.mutable_shape()->mutable_dim(1)->set_dim_param(param);
mergeInShapeInfo(source, target);
Dump(target);
auto& shape = target.shape();
EXPECT_TRUE(shape.dim(0).dim_value() == 1 && shape.dim(1).dim_value() == 2);
}
{
TypeProto_SparseTensor source;
TypeProto_SparseTensor target;
CreateDims(source, 2);
SetDimValues(source, {-1, 2});
CreateDims(target, 2);
SetDimValues(target, {1, 0});
// replace second dim with a param. the value from the source should be
// preferred
const std::string param = "A";
target.mutable_shape()->mutable_dim(1)->set_dim_param(param);
mergeInShapeInfo(source, target);
Dump(target);
auto& shape = target.shape();
EXPECT_TRUE(shape.dim(0).dim_value() == 1 && shape.dim(1).dim_value() == 2);
}
}
TEST(ShapeInferenceTest, mergeShapeInfo_Mismatches) {
#ifndef ONNX_NO_EXCEPTIONS
// mismatched num dims
{
TypeProto_Tensor source;
TypeProto_Tensor target;
CreateDims(source, 2);
SetDimValues(source, {-1, 2});
CreateDims(target, 3);
SetDimValues(target, {1, -1, 1});
EXPECT_THROW(mergeInShapeInfo(source, target), ONNX_NAMESPACE::InferenceError);
}
{
TypeProto_SparseTensor source;
TypeProto_SparseTensor target;
CreateDims(source, 2);
SetDimValues(source, {-1, 2});
CreateDims(target, 3);
SetDimValues(target, {1, -1, 1});
EXPECT_THROW(mergeInShapeInfo(source, target), ONNX_NAMESPACE::InferenceError);
}
// mismatched dim values
{
TypeProto_Tensor source;
TypeProto_Tensor target;
CreateDims(source, 2);
SetDimValues(source, {2, 2});
CreateDims(target, 2);
SetDimValues(target, {2, 1});
EXPECT_THROW(mergeInShapeInfo(source, target), ONNX_NAMESPACE::InferenceError);
}
{
TypeProto_SparseTensor source;
TypeProto_SparseTensor target;
CreateDims(source, 2);
SetDimValues(source, {2, 2});
CreateDims(target, 2);
SetDimValues(target, {2, 1});
EXPECT_THROW(mergeInShapeInfo(source, target), ONNX_NAMESPACE::InferenceError);
}
#endif
// mismatched param value. prefer target
{
TypeProto_Tensor source;
TypeProto_Tensor target;
const std::string param_a = "A";
const std::string param_b = "B";
CreateDims(source, 1);
SetDimParams(source, {&param_a});
CreateDims(target, 1);
SetDimParams(target, {&param_b});
mergeInShapeInfo(source, target);
auto& shape = target.shape();
EXPECT_TRUE(shape.dim(0).dim_param() == "B");
}
{
TypeProto_SparseTensor source;
TypeProto_SparseTensor target;
const std::string param_a = "A";
const std::string param_b = "B";
CreateDims(source, 1);
SetDimParams(source, {&param_a});
CreateDims(target, 1);
SetDimParams(target, {&param_b});
mergeInShapeInfo(source, target);
auto& shape = target.shape();
EXPECT_TRUE(shape.dim(0).dim_param() == "B");
}
}
// Check subgraph inferencing via GraphInferencer using a Scan
static void doInferencingTest(bool use_scan_opset8) {
auto* schemaRegistry = OpSchemaRegistry::Instance();
GraphProto subgraph;
// simple tensor without shape info
TypeProto simple_tensor_no_shape;
auto* tensor_type = simple_tensor_no_shape.mutable_tensor_type();
tensor_type->set_elem_type(TensorProto_DataType_FLOAT);
// simple tensor with shape info
TypeProto simple_tensor = simple_tensor_no_shape;
simple_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2);
// setup simple graph that can be used with Scan containing two Identity
// nodes. one for the loop state variable. one for the scan output.
{
NodeProto loop_state_identity;
loop_state_identity.set_name("loop_state_identity");
loop_state_identity.set_domain(ONNX_DOMAIN);
loop_state_identity.set_op_type("Identity");
loop_state_identity.set_doc_string("loop state identity");
loop_state_identity.add_input("loop_state_in");
loop_state_identity.add_output("loop_state_out");
*subgraph.add_node() = loop_state_identity;
NodeProto scan_in_out_identity;
scan_in_out_identity.set_name("scan_in_out_identity");
scan_in_out_identity.set_domain(ONNX_DOMAIN);
scan_in_out_identity.set_op_type("Identity");
scan_in_out_identity.set_doc_string("scan identity");
scan_in_out_identity.add_input("scan_in");
scan_in_out_identity.add_output("scan_out");
*subgraph.add_node() = scan_in_out_identity;
ValueInfoProto loop_state_in;
loop_state_in.set_name("loop_state_in");
*loop_state_in.mutable_type() = simple_tensor;
*subgraph.add_input() = loop_state_in;
ValueInfoProto scan_in;
scan_in.set_name("scan_in");
*scan_in.mutable_type() = simple_tensor;
*subgraph.add_input() = scan_in;
ValueInfoProto loop_state_out = loop_state_in;
loop_state_out.set_name("loop_state_out");
*loop_state_out.mutable_type() = simple_tensor_no_shape;
*subgraph.add_output() = loop_state_out;
ValueInfoProto scan_state_out = scan_in;
scan_state_out.set_name("scan_out");
*scan_state_out.mutable_type() = simple_tensor_no_shape;
*subgraph.add_output() = scan_state_out;
}
std::unordered_map<std::string, int> opset_imports;
opset_imports[ONNX_DOMAIN] = 8; // Scan is v8
const std::unordered_map<std::string, TypeProto*> outer_scope_value_types;
SymbolTableImpl symbolTable;
symbolTable.addFromGraph(subgraph);
GraphInferenceContext graphInfCtx(outer_scope_value_types, opset_imports, &symbolTable);
GraphInferencerImpl graphInferencer(subgraph, graphInfCtx);
// loop_state_in and scan_in are the two inputs.
// order in subgraphInputTypes matches their order as graph inputs.
std::vector<const TypeProto*> subgraphInputTypes = {&simple_tensor, &simple_tensor};
std::vector<const TensorProto*> subgraphInputData = {};
ShapeInferenceOptions options{false, 0, false};
auto output = graphInferencer.doInferencing(subgraphInputTypes, subgraphInputData);
// check the subgraph outputs had their shape inferred when we called
// doInferencing directly
EXPECT_TRUE(output.size() == 2);
auto checkType = [](const TypeProto& type, const TypeProto_Tensor& expect) {
auto checkDims = [](const TensorShapeProto& l, const TensorShapeProto& r) {
EXPECT_TRUE(l.dim_size() == r.dim_size());
for (int i = 0, end = l.dim_size(); i < end; ++i) {
// if (l.dim().Get(i).dim_value() != r.dim().Get(i).dim_value())
// break;
EXPECT_TRUE(l.dim().Get(i).dim_value() == r.dim().Get(i).dim_value());
}
};
EXPECT_TRUE(type.has_tensor_type());
EXPECT_TRUE(type.tensor_type().elem_type() == expect.elem_type());
checkDims(type.tensor_type().shape(), expect.shape());
};
checkType(*output[0], simple_tensor.tensor_type());
checkType(*output[1], simple_tensor.tensor_type());
// setup Scan node to test subgraph inferencing works as expected when called
// from the operators type/shape inferencing function
NodeProto scan;
{
AttributeProto num_scan_inputs;
num_scan_inputs.set_name("num_scan_inputs");
num_scan_inputs.set_i(1);
AttributeProto body;
body.set_name("body");
*body.mutable_g() = subgraph;
*scan.add_attribute() = num_scan_inputs;
*scan.add_attribute() = body;
scan.set_name("Scan");
scan.set_domain(ONNX_DOMAIN);
scan.set_doc_string("Scan node");
scan.set_op_type("Scan");
if (use_scan_opset8)
scan.add_input(""); // optional sequence lens
scan.add_input("loop_state_start");
scan.add_input("scan_op_in");
scan.add_output("loop_state_final");
scan.add_output("scan_op_out");
}
TypeProto loop_state_in_tensor = simple_tensor_no_shape;
auto* shape = loop_state_in_tensor.mutable_tensor_type()->mutable_shape();
if (use_scan_opset8)
shape->add_dim()->set_dim_value(1); // batch size
shape->add_dim()->set_dim_value(2); // input size. must match subgraph
TypeProto loop_state_out_tensor = loop_state_in_tensor; // should be unchanged
TypeProto scan_in_tensor = simple_tensor_no_shape;
shape = scan_in_tensor.mutable_tensor_type()->mutable_shape();
if (use_scan_opset8)
shape->add_dim()->set_dim_value(1); // batch size
shape->add_dim()->set_dim_value(1); // sequence length
shape->add_dim()->set_dim_value(2); // input size. must match subgraph
TypeProto scan_out_tensor = scan_in_tensor; // should be unchanged
std::unordered_map<std::string, TypeProto*> valueTypesByName;
valueTypesByName["loop_state_start"] = &loop_state_in_tensor;
valueTypesByName["scan_op_in"] = &scan_in_tensor;
InferenceContextImpl ctx(scan, valueTypesByName, {}, {}, options, {}, &graphInfCtx);
if (use_scan_opset8)
ScanInferenceFunctionOpset8(ctx);
else
ScanInferenceFunction(ctx);
EXPECT_TRUE(ctx.getNumOutputs() == 2);
checkType(*ctx.getOutputType(0), loop_state_out_tensor.tensor_type());
checkType(*ctx.getOutputType(1), scan_out_tensor.tensor_type());
}
// Check subgraph inferencing via GraphInferencer using a Scan (from opset 8)
TEST(GraphInferencerImplTest, Scan8_BasicTest) {
doInferencingTest(true);
}
// Check subgraph inferencing via GraphInferencer using a Scan (from opset 9)
TEST(GraphInferencerImplTest, Scan9_BasicTest) {
doInferencingTest(false);
}
void ParseAndInfer(ModelProto& model, const char* modelStr) {
OnnxParser parser(modelStr);
auto status = parser.Parse(model);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
EXPECT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected.";
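// Run inference with type checking, strict error reporting, and data propagation enabled.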
ShapeInferenceOptions options{true, 1, true};
ONNX_NAMESPACE::shape_inference::InferShapes(model, ONNX_NAMESPACE::OpSchemaRegistry::Instance(), options);
}
void RunReshapeShapeInfTest(const char* modelStr, TensorShapeProto& expectedShape) {
ModelProto model;
ParseAndInfer(model, modelStr);
const auto inferredShape = model.graph().output(0).type().tensor_type().shape();
EXPECT_TRUE(inferredShape.dim_size() == expectedShape.dim_size());
for (int i = 0; i < inferredShape.dim_size(); i++) {
EXPECT_TRUE(
(inferredShape.dim(i).has_dim_value() && expectedShape.dim(i).has_dim_value()) ||
(inferredShape.dim(i).has_dim_param() && expectedShape.dim(i).has_dim_param()));
EXPECT_TRUE(
inferredShape.dim(i).has_dim_value() ? inferredShape.dim(i).dim_value() == expectedShape.dim(i).dim_value()
: inferredShape.dim(i).dim_param() == expectedShape.dim(i).dim_param());
}
}
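// Data propagation forwards the output of Shape<start=0, end=3>(x), i.e. the
// first three dims of x, into Reshape's shape input, so the inferred output
// shape is [batch_size, 256, 768].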
TEST(ShapeInferenceTest, ReshapeTestWithShapeAsSymInput) {
const char* modelStr = R"ONNX(
<
ir_version: 8,
opset_import: [ "" : 15],
producer_name: "DataPropagationTest",
producer_version: "1.0",
model_version: 1,
doc_string: "A test model for data propagation."
>
agraph (float[batch_size, 256, 768, 3] x, float[batch_size, 196608] m) => (float[?, ?, ?] z)
{
y = Shape<start = 0, end = 3>(x)
z = Reshape(m, y)
}
)ONNX";
TensorShapeProto expectedShape;
expectedShape.mutable_dim()->Add()->set_dim_param("batch_size");
expectedShape.mutable_dim()->Add()->set_dim_value(256);
expectedShape.mutable_dim()->Add()->set_dim_value(768);
RunReshapeShapeInfTest(modelStr, expectedShape);
}
TEST(ShapeInferenceTest, ReshapeTestWithShapeAsInitializer) {
const char* modelStr = R"ONNX(
<
ir_version: 8,
opset_import: [ "" : 15],
producer_name: "DataPropagationTest",
producer_version: "1.0",
model_version: 1,
doc_string: "A test model for data propagation."
>
agraph (float[1, 196608] m) => (float[?, ?, ?] z)
<int64[3] shape = {1, 768, 256}>
{
z = Reshape(m, shape)
}
)ONNX";
TensorShapeProto expectedShape;
expectedShape.mutable_dim()->Add()->set_dim_value(1);
expectedShape.mutable_dim()->Add()->set_dim_value(768);
expectedShape.mutable_dim()->Add()->set_dim_value(256);
RunReshapeShapeInfTest(modelStr, expectedShape);
}
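// The -1 entry in the shape initializer is resolved from the element count:
// 196608 / (1 * 256) = 768.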
TEST(ShapeInferenceTest, ReshapeTestWithShapeAsInitializer1) {
const char* modelStr = R"ONNX(
<
ir_version: 8,
opset_import: [ "" : 15],
producer_name: "DataPropagationTest",
producer_version: "1.0",
model_version: 1,
doc_string: "A test model for data propagation."
>
agraph (float[1, 196608] m) => (float[?, ?, ?] z)
<int64[3] shape = {1, -1, 256}>
{
z = Reshape(m, shape)
}
)ONNX";
TensorShapeProto expectedShape;
expectedShape.mutable_dim()->Add()->set_dim_value(1);
expectedShape.mutable_dim()->Add()->set_dim_value(768);
expectedShape.mutable_dim()->Add()->set_dim_value(256);
RunReshapeShapeInfTest(modelStr, expectedShape);
}
TEST(ShapeInferenceTest, CheckShapesAndTypesTest) {
#ifndef ONNX_NO_EXCEPTIONS
// Tensor element types mis-match should cause an exception.
TypeProto tensor_infer;
auto* tensor_infer_type = tensor_infer.mutable_tensor_type();
tensor_infer_type->set_elem_type(TensorProto_DataType_FLOAT);
TypeProto tensor_exist;
auto* tensor_exist_type = tensor_exist.mutable_tensor_type();
tensor_exist_type->set_elem_type(TensorProto_DataType_UINT8);
EXPECT_THROW(checkShapesAndTypes(tensor_infer, tensor_exist), ONNX_NAMESPACE::InferenceError);
#endif
}
TEST(ShapeInferenceTest, CustomOpTest) {
const char* modelStr = R"ONNX(
<ir_version: 8, opset_import: ["" : 15, "custom.domain" : 1]>
agraph (float[256, 768, 3] x) => (z1, z2)
{
z1 = custom.domain.CustomOp (x)
# Inference cannot determine the type/shape of z1
z2 = Abs(x)
# Inference SHOULD determine the type/shape of z2 (same as that of x)
}
)ONNX";
ModelProto model;
ParseAndInfer(model, modelStr);
auto& z1_value_info = model.graph().output(0);
// Check no inferred type for z1 (It's a quirk of the implementation that it
// has a dummy TypeProto, but it should have no values filled in.)
ASSERT_TRUE(z1_value_info.has_type());
ASSERT_FALSE(z1_value_info.type().has_tensor_type());
// Check inferred type for z2:
auto& z2_value_info = model.graph().output(1);
ASSERT_TRUE(z2_value_info.has_type());
ASSERT_TRUE(z2_value_info.type().has_tensor_type());
EXPECT_EQ(z2_value_info.type().tensor_type().elem_type(), TensorProto_DataType_FLOAT);
EXPECT_EQ(z2_value_info.type().tensor_type().shape().dim_size(), 3);
EXPECT_EQ(z2_value_info.type().tensor_type().shape().dim(0).dim_value(), 256);
EXPECT_EQ(z2_value_info.type().tensor_type().shape().dim(1).dim_value(), 768);
EXPECT_EQ(z2_value_info.type().tensor_type().shape().dim(2).dim_value(), 3);
}
} // namespace Test
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,15 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <iostream>
#include "gtest/gtest.h"
GTEST_API_ int main(int argc, char** argv) {
std::cout << "Running main() from test_main.cc" << std::endl;
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}