// Copyright (c) ONNX Project Contributors /* * SPDX-License-Identifier: Apache-2.0 */ #include "onnx/version_converter/convert.h" #include #include namespace ONNX_NAMESPACE { namespace version_conversion { ModelProto ConvertVersion(const ModelProto& mp_in, int target_version) { // Get initial_opsetid from mp_in OpSetID initial_struct(0); for (auto it = mp_in.opset_import().begin(); it != mp_in.opset_import().end(); ++it) { if (it->domain() == "" || it->domain() == "ai.onnx") { initial_struct.setVersion(it->version()); break; } } OpSetID target_struct = OpSetID(target_version); DefaultVersionConverter v; return v.convert_version(mp_in, initial_struct, target_struct); } void DefaultVersionConverter::convert_graph( std::shared_ptr g, const OpSetID& initial_version, const OpSetID& target_version) const { assertNonNull(g); // TODO: Move to Inter-Domain Converter // Get initial model versions // std::vector initial_versions = g->opset_versions_mutable(); // No conversion necessary if Model has single, equivalent opset version // if (initial_versions.size() == 1 && initial_versions[0].version == // target_version.version && initial_versions[0].domain == // target_version.domain) { // return mp_in; // } // Check if versions are valid assertInVersionRange(initial_version.version()); assertInVersionRange(target_version.version()); // Iterate over all versions to target_version for specified int64_t curr_version = initial_version.version(); int64_t step; if (target_version.version() > initial_version.version()) { step = 1; } else { step = -1; } // Identify index of this domain in g.opset_versions unsigned int domain_index = 0; for (unsigned int i = 0; i < g->opset_versions_mutable().size(); i++) { if (g->opset_versions_mutable()[i].domain() == "") { domain_index = i; } } while (curr_version != target_version.version()) { debug( "curr_version: " + ONNX_NAMESPACE::to_string(curr_version) + ", next_version: " + ONNX_NAMESPACE::to_string(curr_version + step)); Node* cur_op; graph_node_list_iterator it = g->begin(); // Iterate through and call adapter returned by adapter_lookup for ops from // current_version opset. We have to manipulate the iterator explicitly because cur_op // might change when applying the adapter (e.g. for deprecated ops) while (it != g->end()) { cur_op = *it; debug(std::string("Finding schema for ") + std::string(cur_op->kind().toString())); const std::string op_name = cur_op->kind().toString(); if (op_name == "ConstantFill") { if (DEBUG) { std::cerr << "Warning: skipping schema search for experimental op 'ConstantFill' and keeping the op as is. " "Please be advised the converted model may not be working properly if target runtime does not support this " "experimental op." << std::endl; } } else if (cur_op->domain() != "" && cur_op->domain() != "ai.onnx") { if (DEBUG) { std::cerr << "Warning: opset domain '" << cur_op->domain() << "' is not supported." << std::endl; } } else if (op_name != "Undefined" && op_name != "Captured") { auto& op_domain_map = all_schemas.at(op_name); OpSetID curr_id(curr_version); OpSetID next_id(curr_version + step); if (searchOpDomainMap(op_domain_map, curr_version, step)) { // Op is specifically defined for this domain and version auto& op_adapter = adapter_lookup(cur_op, curr_id, next_id); // If adapter_lookup returns null, no adapter is present. // Error thrown by adapter_lookup if (DEBUG) { std::cerr << "Applying adapter" << std::endl; } // adapt should handle replacing node in graph cur_op = op_adapter.adapt(g, cur_op); it = graph_node_list_iterator(cur_op, kNextDirection); } // Recursively convert any subgraph attributes for (const auto& attr : cur_op->attributeNames()) { if (cur_op->kindOf(attr) == AttributeKind::g) { convert_graph(cur_op->g(attr), curr_id, next_id); } } } it++; } // Update model version curr_version += step; g->opset_versions_mutable()[domain_index].incrementVersion(step); } } ModelProto DefaultVersionConverter::convert_version( const ModelProto& mp_in, const OpSetID& initial_version, const OpSetID& target_version) const { const std::string& initial_domain = initial_version.domain(); const std::string& target_domain = target_version.domain(); assertDefaultDomain(initial_domain, target_domain); for (auto it = mp_in.opset_import().begin(); it != mp_in.opset_import().end(); ++it) { if (it->domain() == initial_version.domain()) { ONNX_ASSERTM( initial_version.version() == it->version(), "initial_version does not reflect current state of model"); } } std::shared_ptr g(ImportModelProto(mp_in)); convert_graph(g, initial_version, target_version); // Export g as ModelProto debug("Finished conversion; returning model"); ModelProto mp_out = PrepareOutput(mp_in); ExportModelProto(&mp_out, g); return mp_out; } } // namespace version_conversion } // namespace ONNX_NAMESPACE