896 lines
27 KiB
C++
896 lines
27 KiB
C++
/*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*/
|
|
|
|
// Experimental language syntax and parser for ONNX. Please note that the syntax as formalized
|
|
// by this parser is preliminary and may change.
|
|
|
|
#include "onnx/defs/parser.h"
|
|
|
|
#include <cctype>
|
|
#include <iostream>
|
|
#include <stdexcept>
|
|
#include <string>
|
|
#include <unordered_map>
|
|
#include <vector>
|
|
|
|
#include "onnx/common/common.h"
|
|
#include "onnx/onnx_pb.h"
|
|
#include "onnx/string_utils.h"
|
|
|
|
// Convenience wrappers used throughout this file. Each one forwards the Status produced by
// the wrapped call to CHECK_PARSER_STATUS (defined elsewhere; presumably it returns from the
// enclosing function when the status is not OK -- confirm against its definition).
//
// PARSE_TOKEN(x): parse a token into x using the ParserBase::Parse overload set.
#define PARSE_TOKEN(x) CHECK_PARSER_STATUS(ParserBase::Parse(x))
|
|
// PARSE(...): call the Parse(...) overload visible in the current scope.
#define PARSE(...) CHECK_PARSER_STATUS(Parse(__VA_ARGS__))
|
|
// MATCH(...): call Match(...) with the given arguments.
#define MATCH(...) CHECK_PARSER_STATUS(Match(__VA_ARGS__))
|
|
|
|
namespace ONNX_NAMESPACE {
|
|
|
|
// Scans one literal token starting at the current position and stores its kind and text
// in `result`.
//
// Accepted forms:
//   - String literal: "..." where a backslash escapes the following character.
//     `result.value` receives the contents without the enclosing quotes and without the
//     escape characters.
//   - Special float literal: an optional '-' followed by nan, inf or infinity (as accepted
//     by NextIsValidFloatString).
//   - Numeric literal: an optional '-', digits with at most one '.', and an optional
//     exponent (e|E)(+|-)?digits. It is a FLOAT_LITERAL if it contains a decimal point or
//     an exponent, and an INT_LITERAL otherwise.
//
// For the non-string forms `result.value` is the sign (if any) immediately followed by the
// token text, so whitespace between '-' and the token never reaches the numeric conversion
// routines applied to the value later.
//
// Returns a parse error for an unterminated string, for an alphabetic token that is not a
// special float literal, and when no literal is present at all.
Status ParserBase::Parse(Literal& result) {
  // The <cctype> classifiers have undefined behavior for negative arguments other than EOF,
  // so every character is converted to unsigned char before classification.
  const auto is_alpha = [](int c) { return std::isalpha(static_cast<unsigned char>(c)) != 0; };
  const auto is_digit = [](int c) { return std::isdigit(static_cast<unsigned char>(c)) != 0; };

  auto nextch = NextChar();
  auto from = next_;

  if (nextch == '"') {
    ++next_;
    bool has_escape = false;
    while ((next_ < end_) && (*next_ != '"')) {
      if (*next_ == '\\') {
        has_escape = true;
        ++next_;
        if (next_ >= end_)
          return ParseError("Incomplete string literal.");
      }
      ++next_;
    }
    if (next_ >= end_)
      return ParseError("Incomplete string literal.");
    ++next_;
    result.type = LiteralType::STRING_LITERAL;
    if (has_escape) {
      std::string& target = result.value;
      target.clear();
      target.reserve(next_ - from - 2); // upper bound
      // *from is the starting quote. *(next_-1) is the ending quote.
      // Copy what is in-between, except for the escape character
      while (++from < next_ - 1) {
        // Copy current char, if not escape, or next char otherwise.
        target.push_back(*from != '\\' ? (*from) : *(++from));
      }
    } else {
      result.value = std::string(from + 1, next_ - from - 2); // skip enclosing quotes
    }
    return Status::OK();
  }

  // Simplify the remaining cases by consuming a possible negative sign.
  const bool negative = (nextch == '-');
  if (negative) {
    ++next_;
    nextch = NextChar();
  }

  // First character of the token proper (after the optional sign and any whitespace
  // skipped by NextChar).
  const auto token_start = next_;
  const auto token_text = [&]() {
    std::string text(negative ? "-" : "");
    text.append(token_start, static_cast<size_t>(next_ - token_start));
    return text;
  };

  // Float literals that start with alphabet characters: (-)?(nan|inf|infinity).
  if (is_alpha(nextch)) {
    if (!NextIsValidFloatString())
      return ParseError("Encountered invalid float literal!");
    while (next_ < end_ && is_alpha(*next_)) {
      ++next_;
    }
    const std::string text = token_text();
    // The status is captured by reference so that a failure detected in the handler is
    // actually reported to the caller instead of being discarded.
    Status status = Status::OK();
    ONNX_TRY {
      static_cast<void>(std::stof(text));
      result.type = LiteralType::FLOAT_LITERAL;
      result.value = text;
    }
    ONNX_CATCH(...) {
      ONNX_HANDLE_EXCEPTION([&]() { status = ParseError("Encountered invalid float literal!"); });
    }
    return status;
  }

  // Numeric int or float literal.
  if (is_digit(nextch)) {
    bool decimal_point = false;
    ++next_;

    while ((next_ < end_) && (is_digit(*next_) || (*next_ == '.'))) {
      if (*next_ == '.') {
        if (decimal_point)
          break; // Only one decimal point allowed in numeric literal
        decimal_point = true;
      }
      ++next_;
    }

    // Optional exponent syntax: (e|E)(+|-)?[0-9]+
    if ((next_ < end_) && ((*next_ == 'e') || (*next_ == 'E'))) {
      decimal_point = true; // treat as float-literal
      ++next_;
      if ((next_ < end_) && ((*next_ == '+') || (*next_ == '-')))
        ++next_;
      while ((next_ < end_) && is_digit(*next_))
        ++next_;
    }

    result.value = token_text();
    result.type = decimal_point ? LiteralType::FLOAT_LITERAL : LiteralType::INT_LITERAL;
    return Status::OK();
  }

  // Nothing that looks like a literal: report it rather than returning OK with `result`
  // left untouched.
  return ParseError("Value expected but not found.");
}
|
|
|
|
bool ParserBase::NextIsValidFloatString() {
|
|
auto nextch = NextChar();
|
|
auto from = next_;
|
|
constexpr int INFINITY_LENGTH = 8;
|
|
|
|
if (isalpha(nextch)) {
|
|
while (next_ < end_ && isalpha(*next_) && (next_ - from) <= INFINITY_LENGTH) {
|
|
++next_;
|
|
}
|
|
|
|
if (isdigit(*next_)) { // No trailing digits
|
|
next_ = from;
|
|
return false;
|
|
}
|
|
|
|
std::string candidate = std::string(from, next_ - from);
|
|
|
|
// Reset parser location before continuing.
|
|
next_ = from;
|
|
|
|
std::transform(
|
|
candidate.begin(), candidate.end(), candidate.begin(), [](unsigned char c) { return std::tolower(c); });
|
|
if (candidate == std::string("inf") || candidate == std::string("infinity") || candidate == std::string("nan")) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Parses a comma-separated list of identifiers into `idlist` (cleared first).
//
// If no identifier is found at the current position the list is left empty. After a comma,
// whatever ParseOptionalIdentifier yields is appended, including an empty name.
// Errors reported by ParseOptionalIdentifier are propagated to the caller.
Status OnnxParser::Parse(IdList& idlist) {
  idlist.Clear();
  std::string id;
  CHECK_PARSER_STATUS(ParseOptionalIdentifier(id));
  if (id.empty())
    return Status::OK(); // Treat as empty list of identifiers
  *idlist.Add() = id;
  while (Matches(',')) {
    // Start from an empty string so a missing identifier can never repeat the previous name.
    id.clear();
    CHECK_PARSER_STATUS(ParseOptionalIdentifier(id));
    *idlist.Add() = id;
  }
  return Status::OK();
}
|
|
|
|
// Parses an optional, delimited identifier list: `open` id-list `close`.
// `idlist` is always cleared; it stays empty when the opening delimiter is absent.
Status OnnxParser::Parse(char open, IdList& idlist, char close) {
  idlist.Clear();
  if (!Matches(open)) {
    // No delimited list present: nothing to parse.
    return Status::OK();
  }
  PARSE(idlist);
  return Match(close);
}
|
|
|
|
// Parses a comma-separated list whose elements are either plain identifiers or attribute
// definitions. Both output lists are cleared first.
//
// An identifier followed by ':' or '=' starts an attribute definition, which is parsed into
// a new entry of `attrlist`; any other identifier is appended to `idlist`.
// Failures from identifier or attribute parsing are propagated to the caller.
Status OnnxParser::Parse(IdList& idlist, AttrList& attrlist) {
  idlist.Clear();
  attrlist.Clear();
  do {
    std::string id;
    CHECK_PARSER_STATUS(ParseIdentifier(id));
    const auto next = NextChar();
    if (next == ':' || next == '=') {
      CHECK_PARSER_STATUS(Parse(*attrlist.Add(), id));
    } else {
      *idlist.Add() = id;
    }
  } while (Matches(','));
  return Status::OK();
}
|
|
|
|
// Parses an optional, delimited mixed list: `open` (id | attribute-definition), ... `close`.
// When the opening delimiter is absent both output lists are simply cleared.
Status OnnxParser::Parse(char open, IdList& idlist, AttrList& attrlist, char close) {
  if (!Matches(open)) {
    idlist.Clear();
    attrlist.Clear();
    return Status::OK();
  }
  PARSE(idlist, attrlist);
  return Match(close);
}
|
|
|
|
// Parses a comma-separated list of dimensions into `shape` (existing dims are cleared).
// Each dimension is one of:
//   '?'          -> a dimension with neither a value nor a parameter name
//   identifier   -> a symbolic dimension (dim_param)
//   integer      -> a fixed dimension (dim_value)
Status OnnxParser::Parse(TensorShapeProto& shape) {
  shape.clear_dim();
  do {
    if (Matches('?')) {
      shape.add_dim();
      continue;
    }
    // Try a symbolic name first ...
    std::string symbol;
    CHECK_PARSER_STATUS(ParseOptionalIdentifier(symbol));
    if (symbol.empty()) {
      // ... otherwise an integer extent is required.
      int64_t extent = 0;
      PARSE_TOKEN(extent);
      shape.add_dim()->set_dim_value(extent);
    } else {
      shape.add_dim()->set_dim_param(symbol);
    }
  } while (Matches(','));
  return Status::OK();
}
|
|
|
|
// Parses a type expression into `typeProto`.
//
// Grammar handled here:
//   prim-type                      scalar tensor (shape with zero dimensions)
//   prim-type []                   tensor of unknown rank (no shape set)
//   prim-type [dims]               tensor of known rank > 0
//   seq ( type )
//   map ( prim-type , type )
//   optional ( type )
//   sparse_tensor ( tensor-type )  with the same three shape forms as above
//
// Returns a parse error for an unknown type name, a non-primitive map key type, or a
// non-primitive sparse-tensor element type.
Status OnnxParser::Parse(TypeProto& typeProto) {
  std::string type_name;
  CHECK_PARSER_STATUS(ParseIdentifier(type_name));
  int elem_type = PrimitiveTypeNameMap::Lookup(type_name);

  // Dense tensor type: the identifier names a primitive element type.
  if (elem_type != 0) {
    auto* tensor = typeProto.mutable_tensor_type();
    tensor->set_elem_type(elem_type);
    tensor->clear_shape();
    if (!Matches('[')) {
      // Scalar: materialize a shape with zero dimensions.
      (void)(tensor->mutable_shape());
    } else if (!Matches(']')) {
      PARSE(*tensor->mutable_shape());
      MATCH(']');
    }
    // "[]" leaves the shape unset, which denotes unknown rank.
    return Status::OK();
  }

  switch (KeyWordMap::Lookup(type_name)) {
    case KeyWordMap::KeyWord::SEQ_TYPE: {
      // Grammar: seq ( type )
      MATCH('(');
      auto* sequence = typeProto.mutable_sequence_type();
      PARSE(*sequence->mutable_elem_type());
      MATCH(')');
      break;
    }
    case KeyWordMap::KeyWord::MAP_TYPE: {
      // Grammar: map ( prim-type , type )
      MATCH('(');
      auto* map = typeProto.mutable_map_type();
      CHECK_PARSER_STATUS(ParseIdentifier(type_name));
      elem_type = PrimitiveTypeNameMap::Lookup(type_name);
      if (elem_type == 0) {
        return ParseError("Expecting primitive type as map key type.");
      }
      map->set_key_type(elem_type);
      MATCH(',');
      PARSE(*map->mutable_value_type());
      MATCH(')');
      break;
    }
    case KeyWordMap::KeyWord::OPTIONAL_TYPE: {
      // Grammar: optional ( type )
      MATCH('(');
      auto* optional = typeProto.mutable_optional_type();
      PARSE(*optional->mutable_elem_type());
      MATCH(')');
      break;
    }
    case KeyWordMap::KeyWord::SPARSE_TENSOR_TYPE: {
      // Grammar: sparse_tensor ( tensor-type )
      MATCH('(');
      CHECK_PARSER_STATUS(ParseIdentifier(type_name));
      elem_type = PrimitiveTypeNameMap::Lookup(type_name);
      if (elem_type == 0) {
        return ParseError("Unexpected type in sparse-tensor element type.");
      }
      auto* sparse = typeProto.mutable_sparse_tensor_type();
      sparse->set_elem_type(elem_type);
      sparse->clear_shape();
      // Same shape forms as for dense tensors: none (scalar), "[]" (unknown rank), "[dims]".
      if (!Matches('[')) {
        (void)(sparse->mutable_shape());
      } else if (!Matches(']')) {
        PARSE(*sparse->mutable_shape());
        MATCH(']');
      }
      MATCH(')');
      break;
    }
    default:
      return ParseError("Unexpected type.");
  }
  return Status::OK();
}
|
|
|
|
// Parses a value-info: an optional type followed by a mandatory name.
Status OnnxParser::Parse(ValueInfoProto& valueinfo) {
  // The type is only parsed when the upcoming identifier is a type name or type keyword.
  if (NextIsType()) {
    PARSE(*valueinfo.mutable_type());
  }
  std::string parsed_name;
  CHECK_PARSER_STATUS(ParseIdentifier(parsed_name));
  valueinfo.set_name(parsed_name);
  return Status::OK();
}
|
|
|
|
// Parses a mandatory delimited list of value-infos: `open` [value-info, ...] `close`.
// Parsed entries are appended to `vilist`; the list is not cleared here.
Status OnnxParser::Parse(char open, ValueInfoList& vilist, char close) {
  MATCH(open);
  if (Matches(close)) {
    // Empty list.
    return Status::OK();
  }
  do {
    PARSE(*vilist.Add());
  } while (Matches(','));
  return Match(close);
}
|
|
|
|
// Parses a parenthesized list of value-infos, replacing the previous contents of `vilist`.
Status OnnxParser::ParseGraphInputOutput(ValueInfoList& vilist) {
  vilist.Clear();
  return Parse('(', vilist, ')');
}
|
|
|
|
// Parses a parenthesized list of function inputs or outputs.
//
// Each element has the form `Name` or `Type Name`. Every name is appended to `idlist`
// (cleared first). For typed elements a value-info carrying the type and name is appended
// to `vilist`, which is deliberately NOT cleared because it accumulates entries across the
// input and the output list.
Status OnnxParser::ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilist) {
  idlist.Clear();
  MATCH('(');
  if (Matches(')')) {
    return Status::OK();
  }
  do {
    std::string* element_name = idlist.Add();
    // Only typed elements get a value-info entry.
    ValueInfoProto* typed_entry = NextIsType() ? vilist.Add() : nullptr;
    if (typed_entry != nullptr) {
      PARSE(*(typed_entry->mutable_type()));
    }
    CHECK_PARSER_STATUS(ParseIdentifier(*element_name));
    if (typed_entry != nullptr) {
      typed_entry->set_name(*element_name);
    }
  } while (Matches(','));
  return Match(')');
}
|
|
|
|
// Each input element is a value-info with an optional initializer of the form "= initial-value".
|
|
// The value-info is added to the "inputs", while the initializer is added to initializers.
|
|
Status OnnxParser::ParseInput(ValueInfoList& inputs, TensorList& initializers) {
|
|
inputs.Clear();
|
|
if (Matches('(')) {
|
|
if (!Matches(')')) {
|
|
do {
|
|
ValueInfoProto vi;
|
|
PARSE(vi);
|
|
*inputs.Add() = vi;
|
|
if (Matches('=')) {
|
|
// default value for input
|
|
TensorProto& tp = *initializers.Add();
|
|
tp.set_name(vi.name());
|
|
CHECK_PARSER_STATUS(Parse(tp, vi.type()));
|
|
}
|
|
} while (Matches(','));
|
|
MATCH(')');
|
|
}
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
// This is handled slightly different from the inputs.
|
|
// Each element is either a value-info or an initializer.
|
|
// A value-info is added to the "value_infos", while an initializer is added to initializers.
|
|
Status OnnxParser::ParseValueInfo(ValueInfoList& value_infos, TensorList& initializers) {
|
|
value_infos.Clear();
|
|
if (Matches('<')) {
|
|
if (!Matches('>')) {
|
|
do {
|
|
ValueInfoProto vi;
|
|
PARSE(vi);
|
|
if (Matches('=')) {
|
|
// initializer
|
|
TensorProto& tp = *initializers.Add();
|
|
tp.set_name(vi.name());
|
|
CHECK_PARSER_STATUS(Parse(tp, vi.type()));
|
|
} else {
|
|
// valueinfo
|
|
*value_infos.Add() = vi;
|
|
}
|
|
} while (Matches(','));
|
|
MATCH('>');
|
|
}
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
// Parses a non-empty, comma-separated list of `key : value` pairs, appending one entry to
// `stringStringList` per pair. Both key and value are parsed as string tokens.
Status OnnxParser::Parse(StringStringList& stringStringList) {
  std::string text;
  do {
    auto* entry = stringStringList.Add();

    PARSE_TOKEN(text);
    entry->set_key(text);

    MATCH(':');

    PARSE_TOKEN(text);
    entry->set_value(text);
  } while (Matches(','));
  return Status::OK();
}
|
|
|
|
// Parses a complete tensor literal: `type [name] [=] data`.
// `tensorProto` is reset first. The type must be a tensor type with numeric dimensions
// (this is enforced by the data-parsing overload this function delegates to).
Status OnnxParser::Parse(TensorProto& tensorProto) {
  tensorProto = TensorProto();
  // Parse the concrete tensor-type with numeric dimensions:
  TypeProto typeProto;
  PARSE(typeProto);
  // The name is optional, but a failure while scanning it must still be reported.
  CHECK_PARSER_STATUS(ParseOptionalIdentifier(*tensorProto.mutable_name()));
  (void)Matches('='); // Optional, to unify handling of initializers as well as tensor-protos in other contexts
  return Parse(tensorProto, typeProto);
}
|
|
|
|
// Parse TensorProto data given its type:
|
|
Status OnnxParser::Parse(TensorProto& tensorProto, const TypeProto& tensorTypeProto) {
|
|
if (!tensorTypeProto.has_tensor_type())
|
|
return ParseError("Error parsing TensorProto (expected a tensor type).");
|
|
auto elem_type = tensorTypeProto.tensor_type().elem_type();
|
|
tensorProto.set_data_type(elem_type);
|
|
if (!tensorTypeProto.tensor_type().has_shape())
|
|
return ParseError("Error parsing TensorProto (expected a tensor shape).");
|
|
for (auto& dim : tensorTypeProto.tensor_type().shape().dim()) {
|
|
if (!dim.has_dim_value())
|
|
return ParseError("Error parsing TensorProto shape (expected numeric dimension).");
|
|
auto dimval = dim.dim_value();
|
|
tensorProto.add_dims(dimval);
|
|
}
|
|
|
|
// tensorProto.mutable_int64_data()->Reserve(n);
|
|
// Parse the actual values:
|
|
|
|
int64_t intval;
|
|
uint64_t uintval = 0;
|
|
float floatval = 0.0;
|
|
double dblval = 0.0;
|
|
std::string strval;
|
|
if (Matches('{')) {
|
|
if (!Matches('}')) {
|
|
do {
|
|
switch (static_cast<TensorProto::DataType>(elem_type)) {
|
|
case TensorProto::DataType::TensorProto_DataType_INT4:
|
|
case TensorProto::DataType::TensorProto_DataType_INT8:
|
|
case TensorProto::DataType::TensorProto_DataType_INT16:
|
|
case TensorProto::DataType::TensorProto_DataType_INT32:
|
|
case TensorProto::DataType::TensorProto_DataType_UINT4:
|
|
case TensorProto::DataType::TensorProto_DataType_UINT8:
|
|
case TensorProto::DataType::TensorProto_DataType_UINT16:
|
|
case TensorProto::DataType::TensorProto_DataType_FLOAT16:
|
|
case TensorProto::DataType::TensorProto_DataType_BFLOAT16:
|
|
case TensorProto::DataType::TensorProto_DataType_FLOAT8E4M3FN:
|
|
case TensorProto::DataType::TensorProto_DataType_FLOAT8E4M3FNUZ:
|
|
case TensorProto::DataType::TensorProto_DataType_FLOAT8E5M2:
|
|
case TensorProto::DataType::TensorProto_DataType_FLOAT8E5M2FNUZ:
|
|
case TensorProto::DataType::TensorProto_DataType_BOOL:
|
|
PARSE_TOKEN(intval);
|
|
// TODO: check values are in the correct range.
|
|
tensorProto.add_int32_data(intval);
|
|
break;
|
|
case TensorProto::DataType::TensorProto_DataType_INT64:
|
|
PARSE_TOKEN(intval);
|
|
tensorProto.add_int64_data(intval);
|
|
break;
|
|
case TensorProto::DataType::TensorProto_DataType_UINT32:
|
|
case TensorProto::DataType::TensorProto_DataType_UINT64:
|
|
PARSE_TOKEN(uintval);
|
|
tensorProto.add_uint64_data(uintval);
|
|
break;
|
|
case TensorProto::DataType::TensorProto_DataType_COMPLEX64:
|
|
case TensorProto::DataType::TensorProto_DataType_FLOAT:
|
|
PARSE_TOKEN(floatval);
|
|
tensorProto.add_float_data(floatval);
|
|
break;
|
|
case TensorProto::DataType::TensorProto_DataType_COMPLEX128:
|
|
case TensorProto::DataType::TensorProto_DataType_DOUBLE:
|
|
PARSE_TOKEN(dblval);
|
|
tensorProto.add_double_data(dblval);
|
|
break;
|
|
case TensorProto::DataType::TensorProto_DataType_STRING:
|
|
PARSE_TOKEN(strval);
|
|
tensorProto.add_string_data(strval);
|
|
break;
|
|
default:
|
|
return ParseError("Unhandled type: %d", elem_type);
|
|
}
|
|
} while (Matches(','));
|
|
MATCH('}');
|
|
}
|
|
} else if (Matches('[')) {
|
|
tensorProto.set_data_location(TensorProto::DataLocation::TensorProto_DataLocation_EXTERNAL);
|
|
auto& externalData = *tensorProto.mutable_external_data();
|
|
PARSE(externalData);
|
|
MATCH(']');
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
// Reports whether PeekIdentifier yields a non-empty identifier at the current position.
bool OnnxParser::NextIsIdentifier() {
  std::string peeked;
  (void)PeekIdentifier(peeked);
  const bool found = !peeked.empty();
  return found;
}
|
|
|
|
bool OnnxParser::NextIsType() {
|
|
std::string id("");
|
|
(void)PeekIdentifier(id);
|
|
if (PrimitiveTypeNameMap::IsTypeName(id))
|
|
return true;
|
|
switch (KeyWordMap::Lookup(id)) {
|
|
case KeyWordMap::KeyWord::SEQ_TYPE:
|
|
case KeyWordMap::KeyWord::MAP_TYPE:
|
|
case KeyWordMap::KeyWord::OPTIONAL_TYPE:
|
|
case KeyWordMap::KeyWord::SPARSE_TENSOR_TYPE:
|
|
return true;
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Parses one (non-list) attribute value into `attr` and sets its type from the value:
//   type followed by '{', '=' or an identifier  -> TENSOR
//   type on its own                             -> TYPE_PROTO
//   nan / inf / infinity                        -> FLOAT
//   any other identifier                        -> GRAPH
//   '@' name                                    -> reference to another attribute
//   numeric or string literal                   -> INT, FLOAT or STRING
//
// `expected` is the type required by a type annotation, or UNDEFINED when there is none.
// An INT value is implicitly converted when FLOAT is expected; any other mismatch is a
// parse error. Errors from nested parsing steps are propagated to the caller.
Status OnnxParser::ParseSingleAttributeValue(AttributeProto& attr, AttributeProto_AttributeType expected) {
  // Parse a single-value
  auto next = NextChar();
  // <cctype> classifiers require values representable as unsigned char.
  if (std::isalpha(static_cast<unsigned char>(next)) || next == '_') {
    if (NextIsType()) {
      TypeProto typeProto;
      CHECK_PARSER_STATUS(Parse(typeProto));
      next = NextChar();
      if ((next == '{') || (next == '=') || (NextIsIdentifier())) {
        attr.set_type(AttributeProto_AttributeType_TENSOR);
        auto& tensorProto = *attr.mutable_t();
        CHECK_PARSER_STATUS(ParseOptionalIdentifier(*tensorProto.mutable_name()));
        (void)Matches('='); // Optional, to unify handling of initializers
        CHECK_PARSER_STATUS(Parse(tensorProto, typeProto));
      } else {
        attr.set_type(AttributeProto_AttributeType_TYPE_PROTO);
        attr.mutable_tp()->CopyFrom(typeProto);
      }
    } else if (NextIsValidFloatString()) {
      Literal literal;
      PARSE_TOKEN(literal);
      attr.set_type(AttributeProto_AttributeType_FLOAT);
      attr.set_f(static_cast<float>(std::stof(literal.value)));
    } else {
      attr.set_type(AttributeProto_AttributeType_GRAPH);
      PARSE(*attr.mutable_g());
    }
  } else if (Matches('@')) {
    std::string name;
    CHECK_PARSER_STATUS(ParseIdentifier(name));
    attr.set_ref_attr_name(name);
  } else {
    Literal literal;
    PARSE_TOKEN(literal);
    switch (literal.type) {
      case LiteralType::INT_LITERAL:
        attr.set_type(AttributeProto_AttributeType_INT);
        // stoll rather than stol: `long` is only 32 bits wide on some platforms, while the
        // attribute stores a 64-bit integer.
        attr.set_i(std::stoll(literal.value));
        break;
      case LiteralType::FLOAT_LITERAL:
        attr.set_type(AttributeProto_AttributeType_FLOAT);
        attr.set_f(static_cast<float>(std::stof(literal.value)));
        break;
      case LiteralType::STRING_LITERAL:
        attr.set_type(AttributeProto_AttributeType_STRING);
        attr.set_s(literal.value);
        break;
    }
  }
  if ((expected != AttributeProto_AttributeType_UNDEFINED) && (expected != attr.type())) {
    // Mismatch between type-annotation and attribute-value. We do an implicit cast
    // only in the special case of FLOAT type and integral value like 2
    if ((expected == AttributeProto_AttributeType_FLOAT) && (attr.type() == AttributeProto_AttributeType_INT)) {
      attr.set_type(AttributeProto_AttributeType_FLOAT);
      attr.set_f(static_cast<float>(attr.i()));
    } else {
      return ParseError(
          "Mismatch between expected type ",
          AttributeProto_AttributeType_Name(expected),
          " and specified value's type ",
          AttributeProto_AttributeType_Name(attr.type()));
    }
  }
  return Status::OK();
}
|
|
|
|
// Parses a full attribute definition, `name [: type] = value`, into `attr` (cleared first).
Status OnnxParser::Parse(AttributeProto& attr) {
  attr.Clear();
  std::string attr_name;
  CHECK_PARSER_STATUS(ParseIdentifier(attr_name));
  // The remainder (annotation and value) is handled by the name-taking overload.
  return Parse(attr, attr_name);
}
|
|
|
|
// Returns true for attribute types that hold a single value (FLOAT, INT, STRING, TENSOR,
// GRAPH, SPARSE_TENSOR, TYPE_PROTO) and false for every other type.
bool IsSingletonAttribute(AttributeProto_AttributeType type) {
  return type == AttributeProto_AttributeType_FLOAT || type == AttributeProto_AttributeType_INT ||
      type == AttributeProto_AttributeType_STRING || type == AttributeProto_AttributeType_TENSOR ||
      type == AttributeProto_AttributeType_GRAPH || type == AttributeProto_AttributeType_SPARSE_TENSOR ||
      type == AttributeProto_AttributeType_TYPE_PROTO;
}
|
|
|
|
// Maps a list attribute type to the type of its elements (e.g. INTS -> INT).
// Any type that is not a list type is returned unchanged.
AttributeProto_AttributeType ToSingletonType(AttributeProto_AttributeType type) {
  if (type == AttributeProto_AttributeType_FLOATS)
    return AttributeProto_AttributeType_FLOAT;
  if (type == AttributeProto_AttributeType_INTS)
    return AttributeProto_AttributeType_INT;
  if (type == AttributeProto_AttributeType_STRINGS)
    return AttributeProto_AttributeType_STRING;
  if (type == AttributeProto_AttributeType_TENSORS)
    return AttributeProto_AttributeType_TENSOR;
  if (type == AttributeProto_AttributeType_GRAPHS)
    return AttributeProto_AttributeType_GRAPH;
  if (type == AttributeProto_AttributeType_SPARSE_TENSORS)
    return AttributeProto_AttributeType_SPARSE_TENSOR;
  if (type == AttributeProto_AttributeType_TYPE_PROTOS)
    return AttributeProto_AttributeType_TYPE_PROTO;
  return type;
}
|
|
|
|
// Parses the remainder of an attribute definition whose name has already been read:
//   [: type] = value          single value
//   [: type] = [v1, v2, ...]  list of values
//
// `name` is stored as the attribute name. NOTE: when a type annotation is present, `name`
// is overwritten with the annotation's type name (it is reused as the scratch buffer).
//
// For lists, the list type is derived from the element type (INT -> INTS, FLOAT -> FLOATS,
// STRING -> STRINGS, TENSOR -> TENSORS, GRAPH -> GRAPHS, TYPE_PROTO -> TYPE_PROTOS). Once
// the list type is known, it is passed as the expected type for the following elements.
// An element whose type cannot be stored in a list is reported as a parse error instead of
// being dropped. An empty list requires a (list-typed) annotation.
Status OnnxParser::Parse(AttributeProto& attr, std::string& name) {
  attr.set_name(name);
  if (Matches(':')) {
    CHECK_PARSER_STATUS(ParseIdentifier(name));
    int attrtype = AttributeTypeNameMap::Lookup(name);
    if (attrtype != 0) {
      attr.set_type(static_cast<AttributeProto_AttributeType>(attrtype));
    } else {
      return ParseError("Unexpected attribute type.");
    }
  }
  MATCH('=');
  if (NextChar() == '[') {
    // Parse a list of values. For an empty list, the type MUST be specified
    // using the type-annotation syntax of ": type".
    MATCH('[');
    if (NextChar() != ']') {
      do {
        AttributeProto nextval;
        auto expected_type = ToSingletonType(attr.type());
        CHECK_PARSER_STATUS(ParseSingleAttributeValue(nextval, expected_type));
        switch (nextval.type()) {
          case AttributeProto_AttributeType_INT:
            attr.set_type(AttributeProto_AttributeType_INTS);
            attr.add_ints(nextval.i());
            break;
          case AttributeProto_AttributeType_FLOAT:
            attr.set_type(AttributeProto_AttributeType_FLOATS);
            attr.add_floats(nextval.f());
            break;
          case AttributeProto_AttributeType_STRING:
            attr.set_type(AttributeProto_AttributeType_STRINGS);
            attr.add_strings(nextval.s());
            break;
          case AttributeProto_AttributeType_TENSOR:
            attr.set_type(AttributeProto_AttributeType_TENSORS);
            attr.add_tensors()->CopyFrom(nextval.t());
            break;
          case AttributeProto_AttributeType_GRAPH:
            attr.set_type(AttributeProto_AttributeType_GRAPHS);
            attr.add_graphs()->CopyFrom(nextval.g());
            break;
          case AttributeProto_AttributeType_TYPE_PROTO:
            attr.set_type(AttributeProto_AttributeType_TYPE_PROTOS);
            attr.add_type_protos()->CopyFrom(nextval.tp());
            break;
          default:
            // Do not silently drop a value that was parsed but cannot be stored.
            return ParseError("Unsupported value type in attribute list.");
        }
      } while (Matches(','));
    } else {
      if (attr.type() == AttributeProto_AttributeType_UNDEFINED)
        return ParseError("Empty list attribute value requires type annotation.");
      if (IsSingletonAttribute(attr.type()))
        return ParseError("Singleton attribute value cannot be specified as a list.");
    }
    MATCH(']');
  } else {
    CHECK_PARSER_STATUS(ParseSingleAttributeValue(attr, attr.type()));
  }
  return Status::OK();
}
|
|
|
|
// Parses an optional attribute list: '<' attribute, attribute, ... '>'.
// `attrlist` is cleared first and stays empty when no '<' is present.
Status OnnxParser::Parse(AttrList& attrlist) {
  attrlist.Clear();
  if (!Matches('<')) {
    return Status::OK();
  }
  do {
    PARSE(*attrlist.Add());
  } while (Matches(','));
  return Match('>');
}
|
|
|
|
// Parses a node:
//   outputs = [domain.]op_type[:overload] [<attributes>] ( inputs ) [<attributes>]
//
// A dotted name is split so that everything before the last '.' becomes the domain and the
// last component becomes the op type. Attributes may be written before or after the input
// list; the trailing position is only consulted when no attributes were found before it.
// A missing op type, domain component or overload name is reported as a parse error.
Status OnnxParser::Parse(NodeProto& node) {
  PARSE(*node.mutable_output());
  MATCH('=');
  std::string domain("");
  std::string id;
  CHECK_PARSER_STATUS(ParseIdentifier(id));
  while (Matches('.')) {
    // `id` turned out to be a domain component; the op type is still to come.
    if (!domain.empty())
      domain += ".";
    domain += id;
    CHECK_PARSER_STATUS(ParseIdentifier(id));
  }
  node.set_domain(domain);
  node.set_op_type(id);

  if (Matches(':')) {
    std::string overload;
    CHECK_PARSER_STATUS(ParseIdentifier(overload));
    node.set_overload(overload);
  }
  PARSE(*node.mutable_attribute());
  MATCH('(');
  PARSE(*node.mutable_input());
  MATCH(')');
  if (node.attribute_size() == 0) {
    // Permit attributes to be specified before or after parameters.
    PARSE(*node.mutable_attribute());
  }
  return Status::OK();
}
|
|
|
|
// Parses a brace-enclosed sequence of nodes, '{' node* '}', replacing the contents of
// `nodelist`. Nodes are not separated by commas.
Status OnnxParser::Parse(NodeList& nodelist) {
  nodelist.Clear();
  MATCH('{');
  for (;;) {
    if (Matches('}')) {
      break;
    }
    PARSE(*nodelist.Add());
  }
  return Status::OK();
}
|
|
|
|
// Parses a named graph: the graph name followed by the graph body.
// A missing name is reported as a parse error instead of being silently ignored.
Status OnnxParser::Parse(GraphProto& graph) {
  std::string id;
  CHECK_PARSER_STATUS(ParseIdentifier(id));
  return Parse(id, graph);
}
|
|
|
|
// Parses a graph body for a graph called `name`:
//   ( inputs ) => ( outputs ) [< value-infos / initializers >] { nodes }
// Initializers are collected from both the input defaults and the '<' ... '>' section.
Status OnnxParser::Parse(std::string name, GraphProto& graph) {
  graph.set_name(name);

  auto* initializers = graph.mutable_initializer();
  initializers->Clear();

  CHECK_PARSER_STATUS(ParseInput(*graph.mutable_input(), *initializers));

  // The "=>" arrow is matched as two separate characters.
  MATCH('=');
  MATCH('>', false);

  CHECK_PARSER_STATUS(ParseGraphInputOutput(*graph.mutable_output()));
  CHECK_PARSER_STATUS(ParseValueInfo(*graph.mutable_value_info(), *initializers));

  PARSE(*graph.mutable_node());
  return Status::OK();
}
|
|
|
|
// Parses a function definition into `fn` (cleared first):
//   [< keyword : value, ... >] name [< attributes >] ( inputs ) => ( outputs )
//   [< value-infos >] { nodes }
//
// Supported header keywords: opset_import, doc_string, domain, overload; any other keyword
// is a parse error. Plain attribute names go to fn.attribute, attribute definitions with a
// value go to fn.attribute_proto. Typed inputs/outputs contribute entries to fn.value_info,
// as does the optional '<' ... '>' section after the outputs.
// A missing function name is reported as a parse error.
Status OnnxParser::Parse(FunctionProto& fn) {
  fn.Clear();
  std::string strval;
  if (Matches('<')) {
    do {
      KeyWordMap::KeyWord keyword = KeyWordMap::KeyWord::NONE;
      PARSE_TOKEN(keyword);
      MATCH(':');
      switch (keyword) {
        case KeyWordMap::KeyWord::OPSET_IMPORT:
          PARSE(*fn.mutable_opset_import());
          break;
        case KeyWordMap::KeyWord::DOC_STRING:
          PARSE_TOKEN(strval);
          fn.set_doc_string(strval);
          break;
        case KeyWordMap::KeyWord::DOMAIN_KW:
          PARSE_TOKEN(strval);
          fn.set_domain(strval);
          break;
        case KeyWordMap::KeyWord::OVERLOAD_KW:
          PARSE_TOKEN(strval);
          fn.set_overload(strval);
          break;
        default:
          return ParseError("Unhandled keyword.");
      }
    } while (Matches(','));
    MATCH('>');
  }
  std::string id;
  CHECK_PARSER_STATUS(ParseIdentifier(id));
  fn.set_name(id);

  PARSE('<', *fn.mutable_attribute(), *fn.mutable_attribute_proto(), '>');
  // value_info accumulates entries from inputs, outputs and the explicit section below.
  fn.mutable_value_info()->Clear();
  CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_input(), *fn.mutable_value_info()));
  MATCH('=');
  MATCH('>', false);
  CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_output(), *fn.mutable_value_info()));
  if (NextChar() == '<') {
    PARSE('<', *fn.mutable_value_info(), '>');
  }
  return Parse(*fn.mutable_node());
}
|
|
|
|
// Parses an opset-import list: '[' "domain" : version, ... ']' (possibly empty).
// Entries are appended to `opsets`; the list is not cleared here.
Status OnnxParser::Parse(OpsetIdList& opsets) {
  std::string domain;
  int64_t version = 0;
  MATCH('[');
  if (Matches(']')) {
    return Status::OK();
  }
  do {
    auto* entry = opsets.Add();

    PARSE_TOKEN(domain);
    entry->set_domain(domain);

    MATCH(':');

    PARSE_TOKEN(version);
    entry->set_version(version);
  } while (Matches(','));
  return Match(']');
}
|
|
|
|
// Parses a complete model into `model` (cleared first):
//   [< keyword : value, ... >] graph function*
//
// Supported header keywords: ir_version, opset_import, producer_name, producer_version,
// domain, model_version, doc_string and metadata_props; any other keyword is a parse error.
// Everything after the main graph, up to the end of the input, is parsed as a sequence of
// function definitions.
Status OnnxParser::Parse(ModelProto& model) {
  model.Clear();

  // Scratch values shared by all header entries.
  std::string text;
  int64_t number;

  if (Matches('<')) {
    do {
      KeyWordMap::KeyWord key = KeyWordMap::KeyWord::NONE;
      PARSE_TOKEN(key);
      MATCH(':');
      switch (key) {
        case KeyWordMap::KeyWord::IR_VERSION:
          PARSE_TOKEN(number);
          model.set_ir_version(number);
          break;
        case KeyWordMap::KeyWord::OPSET_IMPORT:
          PARSE(*model.mutable_opset_import());
          break;
        case KeyWordMap::KeyWord::PRODUCER_NAME:
          PARSE_TOKEN(text);
          model.set_producer_name(text);
          break;
        case KeyWordMap::KeyWord::PRODUCER_VERSION:
          PARSE_TOKEN(text);
          model.set_producer_version(text);
          break;
        case KeyWordMap::KeyWord::DOMAIN_KW:
          PARSE_TOKEN(text);
          model.set_domain(text);
          break;
        case KeyWordMap::KeyWord::MODEL_VERSION:
          PARSE_TOKEN(number);
          model.set_model_version(number);
          break;
        case KeyWordMap::KeyWord::DOC_STRING:
          PARSE_TOKEN(text);
          model.set_doc_string(text);
          break;
        case KeyWordMap::KeyWord::METADATA_PROPS: {
          // '[' "key" : "value", ... ']' (possibly empty)
          auto& props = *model.mutable_metadata_props();
          MATCH('[');
          if (!Matches(']')) {
            PARSE(props);
            MATCH(']');
          }
          break;
        }
        default:
          return ParseError("Unhandled keyword.");
      }
    } while (Matches(','));
    MATCH('>');
  }

  // The main graph.
  PARSE(*model.mutable_graph());

  // Zero or more model-local functions until the input is exhausted.
  auto* local_functions = model.mutable_functions();
  while (!EndOfInput()) {
    PARSE(*local_functions->Add());
  }
  return Status::OK();
}
|
|
|
|
} // namespace ONNX_NAMESPACE
|