474 lines
13 KiB
C++
474 lines
13 KiB
C++
/*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*/
|
|
|
|
#include "onnx/defs/printer.h"
|
|
|
|
#include <iomanip>
|
|
#include <vector>
|
|
|
|
#include "onnx/defs/tensor_proto_util.h"
|
|
|
|
namespace ONNX_NAMESPACE {
|
|
|
|
// Repeated (key, value) string-pair entries. The printer uses this for e.g.
// ModelProto::metadata_props and TensorProto::external_data.
using StringStringEntryProtos = google::protobuf::RepeatedPtrField<StringStringEntryProto>;
|
|
|
|
class ProtoPrinter {
|
|
public:
|
|
ProtoPrinter(std::ostream& os) : output_(os) {}
|
|
|
|
void print(const TensorShapeProto_Dimension& dim);
|
|
|
|
void print(const TensorShapeProto& shape);
|
|
|
|
void print(const TypeProto_Tensor& tensortype);
|
|
|
|
void print(const TypeProto& type);
|
|
|
|
void print(const TypeProto_Sequence& seqType);
|
|
|
|
void print(const TypeProto_Map& mapType);
|
|
|
|
void print(const TypeProto_Optional& optType);
|
|
|
|
void print(const TypeProto_SparseTensor& sparseType);
|
|
|
|
void print(const TensorProto& tensor, bool is_initializer = false);
|
|
|
|
void print(const ValueInfoProto& value_info);
|
|
|
|
void print(const ValueInfoList& vilist);
|
|
|
|
void print(const AttributeProto& attr);
|
|
|
|
void print(const AttrList& attrlist);
|
|
|
|
void print(const NodeProto& node);
|
|
|
|
void print(const NodeList& nodelist);
|
|
|
|
void print(const GraphProto& graph);
|
|
|
|
void print(const FunctionProto& fn);
|
|
|
|
void print(const ModelProto& model);
|
|
|
|
void print(const OperatorSetIdProto& opset);
|
|
|
|
void print(const OpsetIdList& opsets);
|
|
|
|
void print(const StringStringEntryProtos& stringStringProtos) {
|
|
printSet("[", ", ", "]", stringStringProtos);
|
|
}
|
|
|
|
void print(const StringStringEntryProto& metadata) {
|
|
printQuoted(metadata.key());
|
|
output_ << ": ";
|
|
printQuoted(metadata.value());
|
|
}
|
|
|
|
private:
|
|
template <typename T>
|
|
inline void print(T prim) {
|
|
output_ << prim;
|
|
}
|
|
|
|
void printQuoted(const std::string& str) {
|
|
output_ << "\"";
|
|
for (const char* p = str.c_str(); *p; ++p) {
|
|
if ((*p == '\\') || (*p == '"'))
|
|
output_ << '\\';
|
|
output_ << *p;
|
|
}
|
|
output_ << "\"";
|
|
}
|
|
|
|
template <typename T>
|
|
inline void printKeyValuePair(KeyWordMap::KeyWord key, const T& val, bool addsep = true) {
|
|
if (addsep)
|
|
output_ << "," << std::endl;
|
|
output_ << std::setw(indent_level) << ' ' << KeyWordMap::ToString(key) << ": ";
|
|
print(val);
|
|
}
|
|
|
|
inline void printKeyValuePair(KeyWordMap::KeyWord key, const std::string& val) {
|
|
output_ << "," << std::endl;
|
|
output_ << std::setw(indent_level) << ' ' << KeyWordMap::ToString(key) << ": ";
|
|
printQuoted(val);
|
|
}
|
|
|
|
template <typename Collection>
|
|
inline void printSet(const char* open, const char* separator, const char* close, Collection coll) {
|
|
const char* sep = "";
|
|
output_ << open;
|
|
for (auto& elt : coll) {
|
|
output_ << sep;
|
|
print(elt);
|
|
sep = separator;
|
|
}
|
|
output_ << close;
|
|
}
|
|
|
|
std::ostream& output_;
|
|
int indent_level = 3;
|
|
|
|
void indent() {
|
|
indent_level += 3;
|
|
}
|
|
|
|
void outdent() {
|
|
indent_level -= 3;
|
|
}
|
|
};
|
|
|
|
void ProtoPrinter::print(const TensorShapeProto_Dimension& dim) {
|
|
if (dim.has_dim_value())
|
|
output_ << dim.dim_value();
|
|
else if (dim.has_dim_param())
|
|
output_ << dim.dim_param();
|
|
else
|
|
output_ << "?";
|
|
}
|
|
|
|
// Prints the dimensions of a shape as "[d0,d1,...]" (comma, no space).
void ProtoPrinter::print(const TensorShapeProto& shape) {
  const auto& dims = shape.dim();
  printSet("[", ",", "]", dims);
}
|
|
|
|
void ProtoPrinter::print(const TypeProto_Tensor& tensortype) {
|
|
output_ << PrimitiveTypeNameMap::ToString(tensortype.elem_type());
|
|
if (tensortype.has_shape()) {
|
|
if (tensortype.shape().dim_size() > 0)
|
|
print(tensortype.shape());
|
|
} else
|
|
output_ << "[]";
|
|
}
|
|
|
|
void ProtoPrinter::print(const TypeProto_Sequence& seqType) {
|
|
output_ << "seq(";
|
|
print(seqType.elem_type());
|
|
output_ << ")";
|
|
}
|
|
|
|
void ProtoPrinter::print(const TypeProto_Map& mapType) {
|
|
output_ << "map(" << PrimitiveTypeNameMap::ToString(mapType.key_type()) << ", ";
|
|
print(mapType.value_type());
|
|
output_ << ")";
|
|
}
|
|
|
|
void ProtoPrinter::print(const TypeProto_Optional& optType) {
|
|
output_ << "optional(";
|
|
print(optType.elem_type());
|
|
output_ << ")";
|
|
}
|
|
|
|
void ProtoPrinter::print(const TypeProto_SparseTensor& sparseType) {
|
|
output_ << "sparse_tensor(" << PrimitiveTypeNameMap::ToString(sparseType.elem_type());
|
|
if (sparseType.has_shape()) {
|
|
if (sparseType.shape().dim_size() > 0)
|
|
print(sparseType.shape());
|
|
} else
|
|
output_ << "[]";
|
|
|
|
output_ << ")";
|
|
}
|
|
|
|
void ProtoPrinter::print(const TypeProto& type) {
|
|
if (type.has_tensor_type())
|
|
print(type.tensor_type());
|
|
else if (type.has_sequence_type())
|
|
print(type.sequence_type());
|
|
else if (type.has_map_type())
|
|
print(type.map_type());
|
|
else if (type.has_optional_type())
|
|
print(type.optional_type());
|
|
else if (type.has_sparse_tensor_type())
|
|
print(type.sparse_tensor_type());
|
|
}
|
|
|
|
void ProtoPrinter::print(const TensorProto& tensor, bool is_initializer) {
|
|
output_ << PrimitiveTypeNameMap::ToString(tensor.data_type());
|
|
if (tensor.dims_size() > 0)
|
|
printSet("[", ",", "]", tensor.dims());
|
|
|
|
if (!tensor.name().empty()) {
|
|
output_ << " " << tensor.name();
|
|
}
|
|
if (is_initializer) {
|
|
output_ << " = ";
|
|
}
|
|
// TODO: does not yet handle all types
|
|
if (tensor.has_data_location() && tensor.data_location() == TensorProto_DataLocation_EXTERNAL) {
|
|
print(tensor.external_data());
|
|
} else if (tensor.has_raw_data()) {
|
|
switch (static_cast<TensorProto::DataType>(tensor.data_type())) {
|
|
case TensorProto::DataType::TensorProto_DataType_INT32:
|
|
printSet(" {", ",", "}", ParseData<int32_t>(&tensor));
|
|
break;
|
|
case TensorProto::DataType::TensorProto_DataType_INT64:
|
|
printSet(" {", ",", "}", ParseData<int64_t>(&tensor));
|
|
break;
|
|
case TensorProto::DataType::TensorProto_DataType_FLOAT:
|
|
printSet(" {", ",", "}", ParseData<float>(&tensor));
|
|
break;
|
|
case TensorProto::DataType::TensorProto_DataType_DOUBLE:
|
|
printSet(" {", ",", "}", ParseData<double>(&tensor));
|
|
break;
|
|
default:
|
|
output_ << "..."; // ParseData not instantiated for other types.
|
|
break;
|
|
}
|
|
} else {
|
|
switch (static_cast<TensorProto::DataType>(tensor.data_type())) {
|
|
case TensorProto::DataType::TensorProto_DataType_INT8:
|
|
case TensorProto::DataType::TensorProto_DataType_INT16:
|
|
case TensorProto::DataType::TensorProto_DataType_INT32:
|
|
case TensorProto::DataType::TensorProto_DataType_UINT8:
|
|
case TensorProto::DataType::TensorProto_DataType_UINT16:
|
|
case TensorProto::DataType::TensorProto_DataType_BOOL:
|
|
printSet(" {", ",", "}", tensor.int32_data());
|
|
break;
|
|
case TensorProto::DataType::TensorProto_DataType_INT64:
|
|
printSet(" {", ",", "}", tensor.int64_data());
|
|
break;
|
|
case TensorProto::DataType::TensorProto_DataType_UINT32:
|
|
case TensorProto::DataType::TensorProto_DataType_UINT64:
|
|
printSet(" {", ",", "}", tensor.uint64_data());
|
|
break;
|
|
case TensorProto::DataType::TensorProto_DataType_FLOAT:
|
|
printSet(" {", ",", "}", tensor.float_data());
|
|
break;
|
|
case TensorProto::DataType::TensorProto_DataType_DOUBLE:
|
|
printSet(" {", ",", "}", tensor.double_data());
|
|
break;
|
|
case TensorProto::DataType::TensorProto_DataType_STRING: {
|
|
const char* sep = "{";
|
|
for (auto& elt : tensor.string_data()) {
|
|
output_ << sep;
|
|
printQuoted(elt);
|
|
sep = ", ";
|
|
}
|
|
output_ << "}";
|
|
break;
|
|
}
|
|
default:
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
void ProtoPrinter::print(const ValueInfoProto& value_info) {
|
|
print(value_info.type());
|
|
output_ << " " << value_info.name();
|
|
}
|
|
|
|
// Prints value declarations as a parenthesized, ", "-separated list.
void ProtoPrinter::print(const ValueInfoList& vilist) {
  const char* const open_paren = "(";
  const char* const close_paren = ")";
  printSet(open_paren, ", ", close_paren, vilist);
}
|
|
|
|
void ProtoPrinter::print(const AttributeProto& attr) {
|
|
// Special case of attr-ref:
|
|
if (attr.has_ref_attr_name()) {
|
|
output_ << attr.name() << ": " << AttributeTypeNameMap::ToString(attr.type()) << " = @" << attr.ref_attr_name();
|
|
return;
|
|
}
|
|
// General case:
|
|
output_ << attr.name() << ": " << AttributeTypeNameMap::ToString(attr.type()) << " = ";
|
|
switch (attr.type()) {
|
|
case AttributeProto_AttributeType_INT:
|
|
output_ << attr.i();
|
|
break;
|
|
case AttributeProto_AttributeType_INTS:
|
|
printSet("[", ", ", "]", attr.ints());
|
|
break;
|
|
case AttributeProto_AttributeType_FLOAT:
|
|
output_ << attr.f();
|
|
break;
|
|
case AttributeProto_AttributeType_FLOATS:
|
|
printSet("[", ", ", "]", attr.floats());
|
|
break;
|
|
case AttributeProto_AttributeType_STRING:
|
|
output_ << "\"" << attr.s() << "\"";
|
|
break;
|
|
case AttributeProto_AttributeType_STRINGS: {
|
|
const char* sep = "[";
|
|
for (auto& elt : attr.strings()) {
|
|
output_ << sep << "\"" << elt << "\"";
|
|
sep = ", ";
|
|
}
|
|
output_ << "]";
|
|
break;
|
|
}
|
|
case AttributeProto_AttributeType_GRAPH:
|
|
indent();
|
|
print(attr.g());
|
|
outdent();
|
|
break;
|
|
case AttributeProto_AttributeType_GRAPHS:
|
|
indent();
|
|
printSet("[", ", ", "]", attr.graphs());
|
|
outdent();
|
|
break;
|
|
case AttributeProto_AttributeType_TENSOR:
|
|
print(attr.t());
|
|
break;
|
|
case AttributeProto_AttributeType_TENSORS:
|
|
printSet("[", ", ", "]", attr.tensors());
|
|
break;
|
|
case AttributeProto_AttributeType_TYPE_PROTO:
|
|
print(attr.tp());
|
|
break;
|
|
case AttributeProto_AttributeType_TYPE_PROTOS:
|
|
printSet("[", ", ", "]", attr.type_protos());
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Prints attributes as " <a1, a2, ...>" (note the leading space).
void ProtoPrinter::print(const AttrList& attrlist) {
  const char* const open_angle = " <";
  const char* const close_angle = ">";
  printSet(open_angle, ", ", close_angle, attrlist);
}
|
|
|
|
void ProtoPrinter::print(const NodeProto& node) {
|
|
output_ << std::setw(indent_level) << ' ';
|
|
printSet("", ", ", "", node.output());
|
|
output_ << " = ";
|
|
if (node.domain() != "")
|
|
output_ << node.domain() << ".";
|
|
output_ << node.op_type();
|
|
if (node.overload() != "")
|
|
output_ << ":" << node.overload();
|
|
bool has_subgraph = false;
|
|
for (auto attr : node.attribute())
|
|
if (attr.has_g() || (attr.graphs_size() > 0))
|
|
has_subgraph = true;
|
|
if ((!has_subgraph) && (node.attribute_size() > 0))
|
|
print(node.attribute());
|
|
printSet(" (", ", ", ")", node.input());
|
|
if ((has_subgraph) && (node.attribute_size() > 0))
|
|
print(node.attribute());
|
|
output_ << "\n";
|
|
}
|
|
|
|
void ProtoPrinter::print(const NodeList& nodelist) {
|
|
output_ << "{\n";
|
|
for (auto& node : nodelist) {
|
|
print(node);
|
|
}
|
|
if (indent_level > 3)
|
|
output_ << std::setw(indent_level - 3) << " ";
|
|
output_ << "}";
|
|
}
|
|
|
|
void ProtoPrinter::print(const GraphProto& graph) {
|
|
output_ << graph.name() << " " << graph.input() << " => " << graph.output() << " ";
|
|
if ((graph.initializer_size() > 0) || (graph.value_info_size() > 0)) {
|
|
output_ << std::endl << std::setw(indent_level) << ' ' << '<';
|
|
const char* sep = "";
|
|
for (auto& init : graph.initializer()) {
|
|
output_ << sep;
|
|
print(init, true);
|
|
sep = ", ";
|
|
}
|
|
for (auto& vi : graph.value_info()) {
|
|
output_ << sep;
|
|
print(vi);
|
|
sep = ", ";
|
|
}
|
|
output_ << ">" << std::endl;
|
|
}
|
|
print(graph.node());
|
|
}
|
|
|
|
void ProtoPrinter::print(const ModelProto& model) {
|
|
output_ << "<\n";
|
|
printKeyValuePair(KeyWordMap::KeyWord::IR_VERSION, model.ir_version(), false);
|
|
printKeyValuePair(KeyWordMap::KeyWord::OPSET_IMPORT, model.opset_import());
|
|
if (model.has_producer_name())
|
|
printKeyValuePair(KeyWordMap::KeyWord::PRODUCER_NAME, model.producer_name());
|
|
if (model.has_producer_version())
|
|
printKeyValuePair(KeyWordMap::KeyWord::PRODUCER_VERSION, model.producer_version());
|
|
if (model.has_domain())
|
|
printKeyValuePair(KeyWordMap::KeyWord::DOMAIN_KW, model.domain());
|
|
if (model.has_model_version())
|
|
printKeyValuePair(KeyWordMap::KeyWord::MODEL_VERSION, model.model_version());
|
|
if (model.has_doc_string())
|
|
printKeyValuePair(KeyWordMap::KeyWord::DOC_STRING, model.doc_string());
|
|
if (model.metadata_props_size() > 0)
|
|
printKeyValuePair(KeyWordMap::KeyWord::METADATA_PROPS, model.metadata_props());
|
|
output_ << std::endl << ">" << std::endl;
|
|
|
|
print(model.graph());
|
|
for (const auto& fn : model.functions()) {
|
|
output_ << std::endl;
|
|
print(fn);
|
|
}
|
|
}
|
|
|
|
void ProtoPrinter::print(const OperatorSetIdProto& opset) {
|
|
output_ << "\"" << opset.domain() << "\" : " << opset.version();
|
|
}
|
|
|
|
// Prints opset imports as a bracketed, ", "-separated list.
void ProtoPrinter::print(const OpsetIdList& opsets) {
  const char* const open_bracket = "[";
  const char* const close_bracket = "]";
  printSet(open_bracket, ", ", close_bracket, opsets);
}
|
|
|
|
void ProtoPrinter::print(const FunctionProto& fn) {
|
|
output_ << "<\n";
|
|
output_ << " "
|
|
<< "domain: \"" << fn.domain() << "\",\n";
|
|
if (!fn.overload().empty())
|
|
output_ << " "
|
|
<< "overload: \"" << fn.overload() << "\",\n";
|
|
|
|
output_ << " "
|
|
<< "opset_import: ";
|
|
printSet("[", ",", "]", fn.opset_import());
|
|
output_ << "\n>\n";
|
|
output_ << fn.name() << " ";
|
|
if (fn.attribute_size() > 0)
|
|
printSet("<", ",", ">", fn.attribute());
|
|
printSet("(", ", ", ")", fn.input());
|
|
output_ << " => ";
|
|
printSet("(", ", ", ")", fn.output());
|
|
output_ << "\n";
|
|
print(fn.node());
|
|
}
|
|
|
|
#define DEF_OP(T) \
|
|
std::ostream& operator<<(std::ostream& os, const T& proto) { \
|
|
ProtoPrinter printer(os); \
|
|
printer.print(proto); \
|
|
return os; \
|
|
};
|
|
|
|
DEF_OP(TensorShapeProto_Dimension)
|
|
|
|
DEF_OP(TensorShapeProto)
|
|
|
|
DEF_OP(TypeProto_Tensor)
|
|
|
|
DEF_OP(TypeProto)
|
|
|
|
DEF_OP(TensorProto)
|
|
|
|
DEF_OP(ValueInfoProto)
|
|
|
|
DEF_OP(ValueInfoList)
|
|
|
|
DEF_OP(AttributeProto)
|
|
|
|
DEF_OP(AttrList)
|
|
|
|
DEF_OP(NodeProto)
|
|
|
|
DEF_OP(NodeList)
|
|
|
|
DEF_OP(GraphProto)
|
|
|
|
DEF_OP(FunctionProto)
|
|
|
|
DEF_OP(ModelProto)
|
|
|
|
} // namespace ONNX_NAMESPACE
|