458 lines
13 KiB
C++
458 lines
13 KiB
C++
/*
 * SPDX-License-Identifier: Apache-2.0
 */

// Experimental language syntax and parser for ONNX. Please note that the syntax as formalized
// by this parser is preliminary and may change.
|
|
|
|
#pragma once
|
|
|
|
#include <ctype.h>

#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

#include "onnx/common/status.h"
#include "onnx/onnx_pb.h"
#include "onnx/string_utils.h"
|
|
|
|
namespace ONNX_NAMESPACE {
|
|
|
|
using namespace ONNX_NAMESPACE::Common;
|
|
|
|
using IdList = google::protobuf::RepeatedPtrField<std::string>;
|
|
|
|
using NodeList = google::protobuf::RepeatedPtrField<NodeProto>;
|
|
|
|
using AttrList = google::protobuf::RepeatedPtrField<AttributeProto>;
|
|
|
|
using ValueInfoList = google::protobuf::RepeatedPtrField<ValueInfoProto>;
|
|
|
|
using TensorList = google::protobuf::RepeatedPtrField<TensorProto>;
|
|
|
|
using OpsetIdList = google::protobuf::RepeatedPtrField<OperatorSetIdProto>;
|
|
|
|
using StringStringList = google::protobuf::RepeatedPtrField<StringStringEntryProto>;
|
|
|
|
// Evaluates `status` exactly once. If the result reports !IsOK(), it is returned from the
// enclosing function; otherwise execution continues after the macro.
// The expansion is a plain brace-enclosed block, so a trailing ';' at the call site is optional.
#define CHECK_PARSER_STATUS(status)   \
  {                                   \
    auto parser_status_ = (status);   \
    if (!parser_status_.IsOK()) {     \
      return parser_status_;          \
    }                                 \
  }
|
|
|
|
// Base class for name -> integer-code tables.
//
// `Map` is the derived class: it is expected to fill `map_` in its default constructor.
// A single function-local static `Map` backs all the static accessors below.
template <typename Map>
class StringIntMap {
 public:
  // Returns the table owned by the function-local static `Map` object.
  static const std::unordered_map<std::string, int32_t>& Instance() {
    static Map singleton;
    return singleton.map_;
  }

  // Returns the code stored for `dtype`, or 0 when `dtype` is not a key of the table.
  static int32_t Lookup(const std::string& dtype) {
    const auto& table = Instance();
    const auto found = table.find(dtype);
    return (found == table.end()) ? 0 : found->second;
  }

  // Reverse lookup: returns the first name (in table iteration order) whose code equals
  // `dtype`, or the string "undefined" when no entry has that code.
  static const std::string& ToString(int32_t dtype) {
    static std::string undefined("undefined");
    const auto& table = Instance();
    for (auto entry = table.begin(); entry != table.end(); ++entry) {
      if (entry->second == dtype) {
        return entry->first;
      }
    }
    return undefined;
  }

 protected:
  // Name -> code entries; populated by the derived class.
  std::unordered_map<std::string, int32_t> map_;
};
|
|
|
|
// Table from the textual names of primitive element types to the corresponding
// TensorProto_DataType_* enumerator values.
class PrimitiveTypeNameMap : public StringIntMap<PrimitiveTypeNameMap> {
 public:
  PrimitiveTypeNameMap() {
    map_ = {
        {"float", TensorProto_DataType_FLOAT},
        {"uint8", TensorProto_DataType_UINT8},
        {"int8", TensorProto_DataType_INT8},
        {"uint16", TensorProto_DataType_UINT16},
        {"int16", TensorProto_DataType_INT16},
        {"int32", TensorProto_DataType_INT32},
        {"int64", TensorProto_DataType_INT64},
        {"string", TensorProto_DataType_STRING},
        {"bool", TensorProto_DataType_BOOL},
        {"float16", TensorProto_DataType_FLOAT16},
        {"double", TensorProto_DataType_DOUBLE},
        {"uint32", TensorProto_DataType_UINT32},
        {"uint64", TensorProto_DataType_UINT64},
        {"complex64", TensorProto_DataType_COMPLEX64},
        {"complex128", TensorProto_DataType_COMPLEX128},
        {"bfloat16", TensorProto_DataType_BFLOAT16},
        {"float8e4m3fn", TensorProto_DataType_FLOAT8E4M3FN},
        {"float8e4m3fnuz", TensorProto_DataType_FLOAT8E4M3FNUZ},
        {"float8e5m2", TensorProto_DataType_FLOAT8E5M2},
        {"float8e5m2fnuz", TensorProto_DataType_FLOAT8E5M2FNUZ},
        {"uint4", TensorProto_DataType_UINT4},
        {"int4", TensorProto_DataType_INT4},
    };
  }

  // True when `dtype` is one of the names registered above. Lookup() yields 0 for
  // names that are not in the table, so a non-zero code means the name is known.
  static bool IsTypeName(const std::string& dtype) {
    const int32_t code = Lookup(dtype);
    return code != 0;
  }
};
|
|
|
|
// Table from the textual names of attribute kinds to the corresponding
// AttributeProto_AttributeType_* enumerator values. Each singular name has a plural
// counterpart for the repeated form.
class AttributeTypeNameMap : public StringIntMap<AttributeTypeNameMap> {
 public:
  AttributeTypeNameMap() {
    map_ = {
        // Single-valued attribute kinds.
        {"float", AttributeProto_AttributeType_FLOAT},
        {"int", AttributeProto_AttributeType_INT},
        {"string", AttributeProto_AttributeType_STRING},
        {"tensor", AttributeProto_AttributeType_TENSOR},
        {"graph", AttributeProto_AttributeType_GRAPH},
        {"sparse_tensor", AttributeProto_AttributeType_SPARSE_TENSOR},
        {"type_proto", AttributeProto_AttributeType_TYPE_PROTO},
        // Repeated attribute kinds.
        {"floats", AttributeProto_AttributeType_FLOATS},
        {"ints", AttributeProto_AttributeType_INTS},
        {"strings", AttributeProto_AttributeType_STRINGS},
        {"tensors", AttributeProto_AttributeType_TENSORS},
        {"graphs", AttributeProto_AttributeType_GRAPHS},
        {"sparse_tensors", AttributeProto_AttributeType_SPARSE_TENSORS},
        {"type_protos", AttributeProto_AttributeType_TYPE_PROTOS},
    };
  }
};
|
|
|
|
// Table of the reserved words recognised by the parser, with lookup in both directions.
class KeyWordMap {
 public:
  // One enumerator per reserved word; NONE means "not a keyword".
  enum class KeyWord {
    NONE,
    IR_VERSION,
    OPSET_IMPORT,
    PRODUCER_NAME,
    PRODUCER_VERSION,
    DOMAIN_KW,
    MODEL_VERSION,
    DOC_STRING,
    METADATA_PROPS,
    SEQ_TYPE,
    MAP_TYPE,
    OPTIONAL_TYPE,
    SPARSE_TENSOR_TYPE,
    OVERLOAD_KW
  };

  // Builds the keyword-text -> KeyWord table. KeyWord::NONE has no textual form.
  KeyWordMap()
      : map_({
            {"ir_version", KeyWord::IR_VERSION},
            {"opset_import", KeyWord::OPSET_IMPORT},
            {"producer_name", KeyWord::PRODUCER_NAME},
            {"producer_version", KeyWord::PRODUCER_VERSION},
            {"domain", KeyWord::DOMAIN_KW},
            {"model_version", KeyWord::MODEL_VERSION},
            {"doc_string", KeyWord::DOC_STRING},
            {"metadata_props", KeyWord::METADATA_PROPS},
            {"seq", KeyWord::SEQ_TYPE},
            {"map", KeyWord::MAP_TYPE},
            {"optional", KeyWord::OPTIONAL_TYPE},
            {"sparse_tensor", KeyWord::SPARSE_TENSOR_TYPE},
            {"overload", KeyWord::OVERLOAD_KW},
        }) {}

  // Returns the table of a function-local static KeyWordMap, built on first call.
  static const std::unordered_map<std::string, KeyWord>& Instance() {
    static KeyWordMap singleton;
    return singleton.map_;
  }

  // Returns the keyword spelled `id`, or KeyWord::NONE when `id` is not a keyword.
  static KeyWord Lookup(const std::string& id) {
    const auto& table = Instance();
    const auto found = table.find(id);
    return (found == table.end()) ? KeyWord::NONE : found->second;
  }

  // Returns the spelling of `kw`, or the string "undefined" when no table entry maps
  // to `kw` (which is the case for KeyWord::NONE).
  static const std::string& ToString(KeyWord kw) {
    static std::string undefined("undefined");
    const auto& table = Instance();
    for (auto entry = table.begin(); entry != table.end(); ++entry) {
      if (entry->second == kw) {
        return entry->first;
      }
    }
    return undefined;
  }

 private:
  // Keyword text -> enumerator.
  std::unordered_map<std::string, KeyWord> map_;
};
|
|
|
|
class ParserBase {
|
|
public:
|
|
ParserBase(const std::string& str)
|
|
: start_(str.data()), next_(str.data()), end_(str.data() + str.length()), saved_pos_(next_) {}
|
|
|
|
ParserBase(const char* cstr) : start_(cstr), next_(cstr), end_(cstr + strlen(cstr)), saved_pos_(next_) {}
|
|
|
|
void SavePos() {
|
|
saved_pos_ = next_;
|
|
}
|
|
|
|
void RestorePos() {
|
|
next_ = saved_pos_;
|
|
}
|
|
|
|
std::string GetCurrentPos() {
|
|
uint32_t line = 1, col = 1;
|
|
for (const char* p = start_; p < next_; ++p) {
|
|
if (*p == '\n') {
|
|
++line;
|
|
col = 1;
|
|
} else {
|
|
++col;
|
|
}
|
|
}
|
|
return ONNX_NAMESPACE::MakeString("(line: ", line, " column: ", col, ")");
|
|
}
|
|
|
|
// Return a suitable suffix of what has been parsed to provide error message context:
|
|
// return the line containing the last non-space character preceding the error (if it exists).
|
|
std::string GetErrorContext() {
|
|
// Special cases: empty input string, and parse-error at first character.
|
|
const char* p = next_ < end_ ? next_ : next_ - 1;
|
|
while ((p > start_) && isspace(*p))
|
|
--p;
|
|
while ((p > start_) && (*p != '\n'))
|
|
--p;
|
|
// Start at character after '\n' unless we are at start of input
|
|
const char* context_start = (p > start_) ? (p + 1) : start_;
|
|
for (p = context_start; (p < end_) && (*p != '\n'); ++p)
|
|
;
|
|
return std::string(context_start, p - context_start);
|
|
}
|
|
|
|
template <typename... Args>
|
|
Status ParseError(const Args&... args) {
|
|
return Status(
|
|
NONE,
|
|
FAIL,
|
|
ONNX_NAMESPACE::MakeString(
|
|
"[ParseError at position ", GetCurrentPos(), "]\n", "Error context: ", GetErrorContext(), "\n", args...));
|
|
}
|
|
|
|
void SkipWhiteSpace() {
|
|
do {
|
|
while ((next_ < end_) && (isspace(*next_)))
|
|
++next_;
|
|
if ((next_ >= end_) || ((*next_) != '#'))
|
|
return;
|
|
// Skip rest of the line:
|
|
while ((next_ < end_) && ((*next_) != '\n'))
|
|
++next_;
|
|
} while (true);
|
|
}
|
|
|
|
int NextChar(bool skipspace = true) {
|
|
if (skipspace)
|
|
SkipWhiteSpace();
|
|
return (next_ < end_) ? *next_ : 0;
|
|
}
|
|
|
|
bool Matches(char ch, bool skipspace = true) {
|
|
if (skipspace)
|
|
SkipWhiteSpace();
|
|
if ((next_ < end_) && (*next_ == ch)) {
|
|
++next_;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
Status Match(char ch, bool skipspace = true) {
|
|
if (!Matches(ch, skipspace))
|
|
return ParseError("Expected character ", ch, " not found.");
|
|
return Status::OK();
|
|
}
|
|
|
|
bool EndOfInput() {
|
|
SkipWhiteSpace();
|
|
return (next_ >= end_);
|
|
}
|
|
|
|
enum class LiteralType { INT_LITERAL, FLOAT_LITERAL, STRING_LITERAL };
|
|
|
|
struct Literal {
|
|
LiteralType type;
|
|
std::string value;
|
|
};
|
|
|
|
Status Parse(Literal& result);
|
|
|
|
Status Parse(int64_t& val) {
|
|
Literal literal;
|
|
CHECK_PARSER_STATUS(Parse(literal));
|
|
if (literal.type != LiteralType::INT_LITERAL)
|
|
return ParseError("Integer value expected, but not found.");
|
|
std::string s = literal.value;
|
|
val = std::stoll(s);
|
|
return Status::OK();
|
|
}
|
|
|
|
Status Parse(uint64_t& val) {
|
|
Literal literal;
|
|
CHECK_PARSER_STATUS(Parse(literal));
|
|
if (literal.type != LiteralType::INT_LITERAL)
|
|
return ParseError("Integer value expected, but not found.");
|
|
std::string s = literal.value;
|
|
val = std::stoull(s);
|
|
return Status::OK();
|
|
}
|
|
|
|
Status Parse(float& val) {
|
|
Literal literal;
|
|
CHECK_PARSER_STATUS(Parse(literal));
|
|
switch (literal.type) {
|
|
case LiteralType::INT_LITERAL:
|
|
case LiteralType::FLOAT_LITERAL:
|
|
val = std::stof(literal.value);
|
|
break;
|
|
default:
|
|
return ParseError("Unexpected literal type.");
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
Status Parse(double& val) {
|
|
Literal literal;
|
|
CHECK_PARSER_STATUS(Parse(literal));
|
|
switch (literal.type) {
|
|
case LiteralType::INT_LITERAL:
|
|
case LiteralType::FLOAT_LITERAL:
|
|
val = std::stod(literal.value);
|
|
break;
|
|
default:
|
|
return ParseError("Unexpected literal type.");
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
// Parse a string-literal enclosed within doube-quotes.
|
|
Status Parse(std::string& val) {
|
|
Literal literal;
|
|
CHECK_PARSER_STATUS(Parse(literal));
|
|
if (literal.type != LiteralType::STRING_LITERAL)
|
|
return ParseError("String value expected, but not found.");
|
|
val = literal.value;
|
|
return Status::OK();
|
|
}
|
|
|
|
// Parse an identifier, including keywords. If none found, this will
|
|
// return an empty-string identifier.
|
|
Status ParseOptionalIdentifier(std::string& id) {
|
|
SkipWhiteSpace();
|
|
auto from = next_;
|
|
if ((next_ < end_) && (isalpha(*next_) || (*next_ == '_'))) {
|
|
++next_;
|
|
while ((next_ < end_) && (isalnum(*next_) || (*next_ == '_')))
|
|
++next_;
|
|
}
|
|
id = std::string(from, next_ - from);
|
|
return Status::OK();
|
|
}
|
|
|
|
Status ParseIdentifier(std::string& id) {
|
|
ParseOptionalIdentifier(id);
|
|
if (id.empty())
|
|
return ParseError("Identifier expected but not found.");
|
|
return Status::OK();
|
|
}
|
|
|
|
Status PeekIdentifier(std::string& id) {
|
|
SavePos();
|
|
ParseOptionalIdentifier(id);
|
|
RestorePos();
|
|
return Status::OK();
|
|
}
|
|
|
|
Status Parse(KeyWordMap::KeyWord& keyword) {
|
|
std::string id;
|
|
CHECK_PARSER_STATUS(ParseIdentifier(id));
|
|
keyword = KeyWordMap::Lookup(id);
|
|
return Status::OK();
|
|
}
|
|
|
|
protected:
|
|
const char* start_;
|
|
const char* next_;
|
|
const char* end_;
|
|
const char* saved_pos_;
|
|
|
|
bool NextIsValidFloatString();
|
|
};
|
|
|
|
class OnnxParser : public ParserBase {
|
|
public:
|
|
OnnxParser(const char* cstr) : ParserBase(cstr) {}
|
|
|
|
Status Parse(TensorShapeProto& shape);
|
|
|
|
Status Parse(TypeProto& typeProto);
|
|
|
|
Status Parse(StringStringList& stringStringList);
|
|
|
|
Status Parse(TensorProto& tensorProto);
|
|
|
|
Status Parse(AttributeProto& attr);
|
|
|
|
Status Parse(AttributeProto& attr, std::string& name);
|
|
|
|
Status Parse(AttrList& attrlist);
|
|
|
|
Status Parse(NodeProto& node);
|
|
|
|
Status Parse(NodeList& nodelist);
|
|
|
|
Status Parse(GraphProto& graph);
|
|
|
|
Status Parse(FunctionProto& fn);
|
|
|
|
Status Parse(ModelProto& model);
|
|
|
|
template <typename T>
|
|
static Status Parse(T& parsedData, const char* input) {
|
|
OnnxParser parser(input);
|
|
return parser.Parse(parsedData);
|
|
}
|
|
|
|
private:
|
|
Status Parse(std::string name, GraphProto& graph);
|
|
|
|
Status Parse(IdList& idlist);
|
|
|
|
Status Parse(char open, IdList& idlist, char close);
|
|
|
|
Status Parse(IdList& idlist, AttrList& attrlist);
|
|
|
|
Status Parse(char open, IdList& idlist, AttrList& attrlist, char close);
|
|
|
|
Status ParseSingleAttributeValue(AttributeProto& attr, AttributeProto_AttributeType expected);
|
|
|
|
Status Parse(ValueInfoProto& valueinfo);
|
|
|
|
Status ParseGraphInputOutput(ValueInfoList& vilist);
|
|
|
|
Status ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilist);
|
|
|
|
Status Parse(char open, ValueInfoList& vilist, char close);
|
|
|
|
Status ParseInput(ValueInfoList& vilist, TensorList& initializers);
|
|
|
|
Status ParseValueInfo(ValueInfoList& vilist, TensorList& initializers);
|
|
|
|
Status Parse(TensorProto& tensorProto, const TypeProto& tensorTypeProto);
|
|
|
|
Status Parse(OpsetIdList& opsets);
|
|
|
|
bool NextIsType();
|
|
|
|
bool NextIsIdentifier();
|
|
};
|
|
|
|
} // namespace ONNX_NAMESPACE
|