176 lines
5.4 KiB
C++
176 lines
5.4 KiB
C++
/*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*/
|
|
|
|
#include "onnx/defs/function.h"
|
|
|
|
#include <map>
|
|
|
|
#include "onnx/defs/schema.h"
|
|
#include "onnx/string_utils.h"
|
|
|
|
namespace ONNX_NAMESPACE {
|
|
// Produces the graph-level name of a tensor that is internal to an expanded
// function body. The result is the literal prefix "Func_", followed by the
// owning node's name, followed by the tensor's name inside the function body
// (no separator between the two names).
std::string InteralTensorNameGenerator(const std::string& node_name, const std::string& internal_name) {
  std::string result("Func_");
  result += node_name;
  result += internal_name;
  return result;
}
|
|
|
|
void FunctionExpandHelper(
|
|
const NodeProto& node,
|
|
const FunctionProto& func,
|
|
GraphProto& g,
|
|
const std::string& node_prefix) {
|
|
// Create a temporary unique node prefix for tensor names
|
|
std::string uniq_prefix = node_prefix;
|
|
if (uniq_prefix.empty()) {
|
|
const void* address = static_cast<const void*>(&node);
|
|
std::stringstream ss;
|
|
ss << address;
|
|
uniq_prefix = ss.str();
|
|
}
|
|
std::string node_name = node.has_name() ? node.name() : func.name() + uniq_prefix;
|
|
std::unordered_map<std::string, std::string> io_names_map;
|
|
std::unordered_map<std::string, AttributeProto> attr_map;
|
|
|
|
for (int idx = 0; idx < node.input_size(); ++idx) {
|
|
if (idx >= func.input_size()) {
|
|
ONNX_THROW("Input for function node " + node_name + " is out of bounds");
|
|
}
|
|
io_names_map[func.input().Get(idx)] = node.input().Get(idx);
|
|
}
|
|
for (int idx = 0; idx < node.output_size(); ++idx) {
|
|
if (idx >= func.output_size()) {
|
|
ONNX_THROW("Output for function node " + node_name + " is out of bounds");
|
|
}
|
|
// If the node output is missing, the corresponding function output should
|
|
// be treated as an internal value (not as missing) because it could also be
|
|
// an intermediate value.
|
|
if (node.output().Get(idx) == "") {
|
|
continue;
|
|
}
|
|
io_names_map[func.output().Get(idx)] = node.output().Get(idx);
|
|
}
|
|
|
|
for (auto& attr : node.attribute()) {
|
|
attr_map[attr.name()] = attr;
|
|
}
|
|
|
|
// For undefined attributes of the function node
|
|
// add default values obtained from the function schema.
|
|
// get the domain version for function schema
|
|
int domain_version = -1;
|
|
for (const auto& opset_import : func.opset_import()) {
|
|
if (opset_import.domain() == node.domain()) {
|
|
domain_version = static_cast<int>(opset_import.version());
|
|
}
|
|
}
|
|
if (domain_version == -1) {
|
|
ONNX_THROW("No opset import registered for domain '" + node.domain() + "' in function proto");
|
|
}
|
|
|
|
const OpSchemaRegistry* schema_registry = OpSchemaRegistry::Instance();
|
|
const auto schema = schema_registry->GetSchema(node.op_type(), domain_version, node.domain());
|
|
std::map<std::string, OpSchema::Attribute> default_attrs = schema->attributes();
|
|
|
|
for (const auto& pair : default_attrs) {
|
|
const auto& attr_name = pair.first;
|
|
const auto& attr = pair.second;
|
|
if (!attr_map.count(attr_name)) {
|
|
attr_map[attr_name] = attr.default_value;
|
|
}
|
|
}
|
|
|
|
for (auto& function_node : func.node()) {
|
|
NodeProto* new_node = g.add_node();
|
|
new_node->CopyFrom(function_node);
|
|
new_node->clear_input();
|
|
new_node->clear_output();
|
|
new_node->clear_attribute();
|
|
for (auto& input : function_node.input()) {
|
|
if (io_names_map.count(input)) {
|
|
new_node->add_input(io_names_map[input]);
|
|
} else {
|
|
new_node->add_input(InteralTensorNameGenerator(node_name, input));
|
|
}
|
|
}
|
|
for (auto& output : function_node.output()) {
|
|
if (io_names_map.count(output)) {
|
|
new_node->add_output(io_names_map[output]);
|
|
} else {
|
|
new_node->add_output(InteralTensorNameGenerator(node_name, output));
|
|
}
|
|
}
|
|
for (auto& attr : function_node.attribute()) {
|
|
if (attr.has_ref_attr_name()) {
|
|
if (attr_map.count(attr.ref_attr_name())) {
|
|
AttributeProto* new_attr = new_node->add_attribute();
|
|
new_attr->CopyFrom(attr_map[attr.ref_attr_name()]);
|
|
new_attr->set_name(attr.name());
|
|
}
|
|
} else {
|
|
AttributeProto* new_attr = new_node->add_attribute();
|
|
new_attr->CopyFrom(attr);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
std::vector<NodeProto> FunctionBodyHelper::BuildNodes(const std::vector<NodeDef>& node_defs) {
|
|
std::vector<NodeProto> nodes(node_defs.size());
|
|
|
|
for (size_t i = 0; i < node_defs.size(); i++) {
|
|
const NodeDef& node = node_defs[i];
|
|
NodeProto& n = nodes[i];
|
|
|
|
n.set_op_type(node.op_type);
|
|
n.set_domain(node.domain);
|
|
for (const auto& i : node.inputs) {
|
|
n.add_input(i);
|
|
}
|
|
for (const auto& o : node.outputs) {
|
|
n.add_output(o);
|
|
}
|
|
for (const auto& attr : node.attributes) {
|
|
*(n.add_attribute()) = attr.proto;
|
|
}
|
|
}
|
|
|
|
return nodes;
|
|
}
|
|
|
|
// Appends one NodeProto to `functionProto` for every NodeDef, in order.
// Each appended node receives the op_type, domain, inputs, outputs and
// attribute protos of its NodeDef.
void FunctionBodyHelper::BuildNodes(FunctionProto& functionProto, const std::vector<NodeDef>& node_defs) {
  for (const NodeDef& node_def : node_defs) {
    NodeProto* const proto = functionProto.add_node();
    proto->set_op_type(node_def.op_type);
    proto->set_domain(node_def.domain);

    for (const auto& input_name : node_def.inputs)
      proto->add_input(input_name);

    for (const auto& output_name : node_def.outputs)
      proto->add_output(output_name);

    for (const auto& attribute : node_def.attributes)
      *(proto->add_attribute()) = attribute.proto;
  }
}
|
|
|
|
// Populates `functionProto` in three steps:
//  1. Append the nodes described by `node_defs`.
//  2. Append every opset in `relied_opsets` to the opset imports.
//  3. Let `schema` finalize the function via OpSchema::BuildFunction.
// Always returns true.
bool FunctionBodyHelper::BuildFunctionProto(
    FunctionProto& functionProto,
    const OpSchema& schema,
    const std::vector<NodeDef>& node_defs,
    const std::vector<OperatorSetIdProto>& relied_opsets) {
  BuildNodes(functionProto, node_defs);

  for (const OperatorSetIdProto& opset : relied_opsets)
    *functionProto.add_opset_import() = opset;

  schema.BuildFunction(functionProto);
  return true;
}
|
|
|
|
} // namespace ONNX_NAMESPACE
|