I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,96 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Version converter interface for ONNX models between different opset versions.
#pragma once
#include <stdlib.h>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include "onnx/common/ir.h"
#include "onnx/common/ir_pb_converter.h"
#include "onnx/defs/schema.h"
#include "onnx/proto_utils.h"
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
// TODO: Consider creating interface for this class.
class BaseVersionConverter {
// Schema for adapters: {<op_name>: {<from_domain>$<from_version>:
//   {<to_domain>$<to_version>: adapter}}}
protected:
std::unordered_map<
std::string,
std::unordered_map<std::string, std::unordered_map<std::string, std::unique_ptr<Adapter>>>>
adapters;
// Map of All Versions of format {op_name: {domain: {version: schema}}}
std::unordered_map<std::string, std::unordered_map<std::string, std::map<int64_t, const OpSchema*>>> all_schemas;
public:
BaseVersionConverter() = default;
virtual ~BaseVersionConverter() = default;
// adapter_lookup should be called in convert_version when the user would
// like to identify the proper registered adapter in the adapters map for
// a given Node from a certain version to another. It should only be called
// when the user knows that an adapter should exist for the given context.
const Adapter& adapter_lookup(const Node* op, const OpSetID& initial_version, const OpSetID& target_version) const {
const std::string op_name = op->kind().toString();
const std::string initial = initial_version.toString();
const std::string target = target_version.toString();
// Find appropriate adapter in adapters map for provided initial and target versions
// TODO: Consider abstracting elements of this that are specific to
// DefaultConverter to separate methods here and maintain the procedure in Base Converter
const auto op_adapters = adapters.find(op_name);
if (op_adapters != adapters.end()) {
// If we're adapting downwards, we just want to find the one downwards
// adapter implemented for initial_version. If we're adapting upwards, we
// want to actually use the SinceVersion value for the given op.
const auto target_map = op_adapters->second.find(initial);
if (target_map != op_adapters->second.end()) {
// Either adapt from SinceVersion or Incompatible Breaking Change
const auto adapter_ptr = target_map->second.find(target);
if (adapter_ptr != target_map->second.end()) {
return *(adapter_ptr->second);
} else {
ONNX_ASSERTM(false, "No Adapter To Version %s for %s", target.c_str(), op_name.c_str());
}
} else {
ONNX_ASSERTM(false, "No Adapter From Version %s for %s", initial.c_str(), op_name.c_str());
}
} else {
// No adapters exist for the given op
ONNX_ASSERTM(false, "No Adapter For %s", op_name.c_str());
}
}
virtual ModelProto
convert_version(const ModelProto& mp_in, const OpSetID& initial_version, const OpSetID& target_version) const = 0;
void registerAdapter(std::unique_ptr<Adapter> a_ptr) {
const OpSetID& iv = a_ptr->initial_version();
const OpSetID& tv = a_ptr->target_version();
adapters[a_ptr->name()][iv.toString()][tv.toString()] = std::move(a_ptr);
}
void registerAdapter(const char* op, int64_t from, int64_t to, NodeTransformerFunction transformer) {
registerAdapter(std::make_unique<GenericAdapter>(op, from, to, transformer));
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE
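
A usage sketch, not part of this commit: a concrete BaseVersionConverter subclass would populate the registry above via registerAdapter, with the lambda overload routing through GenericAdapter (declared in adapter.h). The op name and versions below are arbitrary placeholders.

using namespace ONNX_NAMESPACE::version_conversion;

// Hypothetical helper: a no-op transformer marks a compatible version bump.
void register_example_adapters(BaseVersionConverter& converter) {
converter.registerAdapter("Relu", 5, 6, [](std::shared_ptr<Graph>, Node* node) {
return node; // nothing to rewrite between these hypothetical versions
});
}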

View File

@ -0,0 +1,68 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Interface for Op Version Adapters
#pragma once
#include <functional>
#include <memory>
#include <string>
#include "onnx/onnx_pb.h"
#include "onnx/version_converter/helper.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Adapter {
private:
std::string name_;
OpSetID initial_version_;
OpSetID target_version_;
public:
virtual ~Adapter() noexcept = default;
explicit Adapter(const std::string& name, const OpSetID& initial_version, const OpSetID& target_version)
: name_(name), initial_version_(initial_version), target_version_(target_version) {}
// This will almost always return its own node argument after modifying it in place.
// The only exceptions are adapters for deprecated operators: in that case the
// input node must be destroyed and a newly created node returned. See e.g.
// upsample_9_10.h
virtual Node* adapt(std::shared_ptr<Graph> /*graph*/, Node* node) const = 0;
const std::string& name() const {
return name_;
}
const OpSetID& initial_version() const {
return initial_version_;
}
const OpSetID& target_version() const {
return target_version_;
}
};
using NodeTransformerFunction = std::function<Node*(std::shared_ptr<Graph>, Node* node)>;
class GenericAdapter final : public Adapter {
public:
GenericAdapter(const char* op, int64_t from, int64_t to, NodeTransformerFunction transformer)
: Adapter(op, OpSetID(from), OpSetID(to)), transformer_(transformer) {}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
return transformer_(graph, node);
}
private:
NodeTransformerFunction transformer_;
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE
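
For orientation, not part of the commit: a minimal hand-written Adapter has the shape below; the op name "Foo" and its versions are placeholders, and most of the adapters in this change follow this same modify-in-place pattern.

using namespace ONNX_NAMESPACE::version_conversion;

// Hypothetical in-place adapter for an op "Foo" from opset 1 to 2.
class Foo_1_2 final : public Adapter {
public:
Foo_1_2() : Adapter("Foo", OpSetID(1), OpSetID(2)) {}
Node* adapt(std::shared_ptr<Graph>, Node* node) const override {
// e.g. drop or rename an attribute here, then hand the node back
return node;
}
};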

View File

@ -0,0 +1,50 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for ops that convert the axes attribute into an input
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class AxesAttributeToInput : public Adapter {
public:
explicit AxesAttributeToInput(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
: Adapter(op_name, initial, target) {}
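// Copies the integer axes into an INT64 Constant node inserted before
// 'node' and appends its output as a new input; this is the attribute ->
// input migration used, e.g., when ReduceSum's axes moved in opset 13.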
void attrToInput(std::shared_ptr<Graph> graph, Node* node, std::vector<int64_t> axes) const {
Tensor t;
t.elem_type() = TensorProto_DataType_INT64;
t.sizes() = std::vector<int64_t>{static_cast<int64_t>(axes.size())};
auto& data = t.int64s();
for (auto a : axes) {
data.emplace_back(a);
}
Node* constant = graph->create(kConstant);
constant->insertBefore(node);
constant->t_(kvalue, t);
node->addInput(constant->output());
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
if (node->hasAttribute(kaxes)) {
attrToInput(graph, node, node->is(kaxes));
node->removeAttribute(kaxes);
}
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,71 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for ops that convert the axes input into an attribute
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class AxesInputToAttribute : public Adapter {
public:
explicit AxesInputToAttribute(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
: Adapter(op_name, initial, target) {}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
// Identify if axes is statically determined; if so, feed as attribute
const ArrayRef<Value*>& inputs = node->inputs();
// Get axes from initializer or constant operator
// Identify whether we have a Constant Op or an Initializer
Value* const_val = inputs[1];
Node* node_ptr = const_val->node();
if (node_ptr->kind() == kConstant) {
// Get value attribute of kConstant
const std::vector<int64_t>& int64s = node_ptr->t(kvalue).int64s();
if (int64s.empty()) {
// Also handle raw data
std::string raw_data = node_ptr->t(kvalue).raw();
ONNX_ASSERTM(
raw_data.size() != 0 && raw_data.size() % 8 == 0,
"Raw Data must be non-empty and size must be a multiple of 8");
const int64_t* raw = reinterpret_cast<const int64_t*>(raw_data.data());
node->is_(kaxes, std::vector<int64_t>(raw, raw + node_ptr->t(kvalue).size_from_dim(0)));
} else {
node->is_(kaxes, std::forward<const std::vector<int64_t>>(int64s));
}
// If Constant node isn't used anywhere else, remove it
node->removeInput(1);
if (const_val->uses().size() < 1) {
node_ptr->destroy();
}
} else {
// Get Value name, find Initializer with same name
for (const auto& initializer : graph->initializers()) {
if (initializer.name() == inputs[1]->uniqueName()) {
node->is_(kaxes, std::forward<const std::vector<int64_t>>(initializer.int64s()));
node->removeInput(1);
// Remove initializer
if (const_val->uses().size() < 1)
graph->eraseInitializerAndInput(const_val);
break;
}
}
}
ONNX_ASSERTM(node->hasAttribute(kaxes), "No initializer or constant input to node found");
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,74 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class AxisAttributeToInput : public Adapter {
public:
AxisAttributeToInput(
const std::string& op_name,
const OpSetID& initial,
const OpSetID& target,
size_t axis_index,
int64_t default_axis)
: Adapter(op_name, initial, target), axis_index(axis_index), default_axis(default_axis) {}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
if (node->hasAttribute(kaxis)) {
AttrToInput(graph, node, node->i(kaxis), this->axis_index);
node->removeAttribute(kaxis);
return node;
}
// Fill in the default value for axis
AttrToInput(graph, node, default_axis, this->axis_index);
return node;
}
private:
size_t axis_index;
int64_t default_axis;
void AttrToInput(std::shared_ptr<Graph> graph, Node* node, int64_t axis, size_t axis_index) const {
const ArrayRef<Value*>& inputs = node->inputs();
// Add the optional inputs if they don't exist
for (size_t i = inputs.size(); i < axis_index; ++i) {
Node* empty_input = graph->create(kUndefined);
empty_input->insertBefore(node);
node->addInput(empty_input->output());
}
// Add the axis input
Node* constant = CreateAxisInput(graph, node, axis);
node->addInput(constant->output());
}
Node* CreateAxisInput(std::shared_ptr<Graph> graph, Node* node, int64_t axis) const {
Tensor t;
t.elem_type() = TensorProto_DataType_INT64;
t.sizes() = std::vector<int64_t>{};
auto& data = t.int64s();
data.emplace_back(axis);
Node* constant = graph->create(kConstant);
constant->insertBefore(node);
constant->t_(kvalue, t);
return constant;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,99 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class AxisInputToAttribute : public Adapter {
public:
explicit AxisInputToAttribute(
const std::string& op_name,
const OpSetID& initial,
const OpSetID& target,
size_t axis_index,
int64_t default_axis)
: Adapter(op_name, initial, target), axis_index(axis_index), default_axis(default_axis) {}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
if (!HasAxisInput(node)) {
node->i_(kaxis, this->default_axis);
return EnsureAndReturnNode(node);
}
const ArrayRef<Value*>& inputs = node->inputs();
Value* axis_val = inputs[this->axis_index];
Node* axis_node = axis_val->node();
if (axis_node->kind() == kConstant) {
HandleConstantNode(node, axis_node, axis_val);
return EnsureAndReturnNode(node);
}
if (graph->is_constant_initializer(axis_val)) {
HandleInitializerNode(graph, node, axis_val);
return EnsureAndReturnNode(node);
}
ONNX_ASSERTM(false, "Axis input must be a constant or initializer for promotion to attribute.");
}
private:
size_t axis_index;
int64_t default_axis;
bool HasAxisInput(const Node* node) const {
const ArrayRef<const Value*>& inputs = node->inputs();
return inputs.size() > this->axis_index && inputs[this->axis_index]->node()->kind() != kUndefined;
}
void HandleConstantNode(Node* node, Node* axis_node, Value* axis_val) const {
const std::vector<int64_t>& int64s = axis_node->t(kvalue).int64s();
if (int64s.empty()) {
std::string raw_data = axis_node->t(kvalue).raw();
ONNX_ASSERTM(
raw_data.size() != 0 && raw_data.size() % 8 == 0,
"Raw Data must be non-empty and size must be a multiple of 8");
const int64_t* raw = reinterpret_cast<const int64_t*>(raw_data.c_str());
node->i_(kaxis, raw[0]);
} else {
node->i_(kaxis, int64s.at(0));
}
node->removeInput(this->axis_index);
if (axis_val->uses().size() < 1) {
axis_node->destroy();
}
}
void HandleInitializerNode(std::shared_ptr<Graph> graph, Node* node, Value* axis_val) const {
const std::string initializer_name = axis_val->uniqueName();
for (const auto& initializer : graph->initializers()) {
if (initializer.name() == initializer_name) {
node->i_(kaxis, initializer.int64s().at(0));
node->removeInput(this->axis_index);
// Remove initializer
if (axis_val->uses().size() < 1)
graph->eraseInitializer(initializer_name);
break;
}
}
}
inline Node* EnsureAndReturnNode(Node* node) const {
ONNX_ASSERTM(node->hasAttribute(kaxis), "Axis attribute not created. This may be a bug.");
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,34 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for BatchNormalization in default domain from version 13 to 14
#pragma once
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class BatchNormalization_13_14 final : public Adapter {
public:
explicit BatchNormalization_13_14() : Adapter("BatchNormalization", OpSetID(13), OpSetID(14)) {}
void adapt_batch_normalization_13_14(Node* node) const {
ONNX_ASSERTM(
node->outputs().size() < 4,
"BatchNormalization outputs 4 and 5 are not "
"supported in Opset 14.");
}
Node* adapt(std::shared_ptr<Graph>, Node* node) const override {
adapt_batch_normalization_13_14(node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,60 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for broadcasting ops in default domain from version 7 to 6
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class BroadcastBackwardCompatibility final : public Adapter {
public:
explicit BroadcastBackwardCompatibility(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
: Adapter(op_name, initial, target) {}
void adapt_broadcast_backward_compatibility(std::shared_ptr<Graph>, Node* node) const {
// Verify that broadcasts are allowed in limited spec of opset version 6
// Multidirectional broadcasting, as defined in Broadcasting.md
// MathDocGenerator provides differences
// Main change: encode broadcasting commands as explicit attribute
const ArrayRef<Value*>& inputs = node->inputs();
assertInputsAvailable(inputs, name().c_str(), 2);
const std::vector<Dimension>& A_sizes = inputs[0]->sizes();
const std::vector<Dimension>& B_sizes = inputs[1]->sizes();
// Ensure that first input is larger than or equal to the second
// numpy_unibroadcastable here is considered to be equivalent to opset1_broadcastable
// This is because backwards conversion does not allow for an axis that is not
// suffix matching
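// As used here, the helper returns -1 when the shapes are not
// unibroadcastable, 1 when implicit broadcasting would occur (so the
// attribute must be made explicit), and 0 when no broadcast is needed.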
int req_broadcast = check_numpy_unibroadcastable_and_require_broadcast(A_sizes, B_sizes);
ONNX_ASSERTM(
req_broadcast != -1,
"%s being converted from %d to %d does "
"not have broadcastable inputs.",
name().c_str(),
initial_version().version(),
target_version().version());
if (req_broadcast == 1) {
// A result of 1 means implicit broadcasting occurs, which opset 6
// requires to be stated explicitly via the broadcast attribute
node->i_(kbroadcast, 1);
}
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_broadcast_backward_compatibility(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,87 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for broadcasting ops in default domain from version 6 to 7
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class BroadcastForwardCompatibility final : public Adapter {
public:
explicit BroadcastForwardCompatibility(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
: Adapter(op_name, initial, target) {}
void adapt_broadcast_forward_compatibility(std::shared_ptr<Graph> graph, Node* node) const {
// Remove axis and broadcast attributes
// Assess whether axis requires reshaping
if (node->hasAttribute(kbroadcast)) {
const ArrayRef<Value*>& inputs = node->inputs();
assertInputsAvailable(inputs, name().c_str(), 2);
const std::vector<Dimension>& A_sizes = inputs[0]->sizes();
const std::vector<Dimension>& B_sizes = inputs[1]->sizes();
// Also assert that the broadcasting syntax is correct if axis is not present
if (node->hasAttribute(kaxis)) {
if (node->i(kaxis) != (int)(A_sizes.size() - B_sizes.size())) {
// Add a Reshape node before input B
Node* n = graph->create(kUnsqueeze);
n->addInput(inputs[1]);
std::vector<int64_t> axes;
std::vector<Dimension> new_sizes = B_sizes;
auto size = A_sizes.size() > B_sizes.size() ? A_sizes.size() - B_sizes.size() : 0;
axes.reserve(size);
new_sizes.reserve(new_sizes.size() + size);
for (size_t i = 0; i < size; i++) {
axes.emplace_back(B_sizes.size() + i);
new_sizes.emplace_back(Dimension(1));
}
if (target_version().version() >= 13) { // Unsqueeze takes 'axes' input
Tensor t;
t.elem_type() = TensorProto_DataType_INT64;
t.sizes() = std::vector<int64_t>{static_cast<int64_t>(axes.size())};
auto& data = t.int64s();
for (auto a : axes) {
data.emplace_back(a);
}
Node* constant = graph->create(kConstant);
constant->insertBefore(node);
constant->t_(kvalue, t);
n->addInput(constant->output()); // the axes tensor feeds the Unsqueeze node n
} else { // Unsqueeze takes 'axes' attribute
n->is_(kaxes, std::forward<const std::vector<int64_t>>(axes));
}
// Move n before node
n->insertBefore(node);
// Set 2nd input to node to 1st of n and output of n to 2nd input to node
n->output()->setSizes(new_sizes);
node->replaceInput(1, n->output());
}
}
node->removeAttribute(kbroadcast);
}
if (node->hasAttribute(kaxis))
node->removeAttribute(kaxis);
// Assert multi_broadcastable on inputs
const ArrayRef<Value*>& inputs = node->inputs();
assert_numpy_multibroadcastable(inputs[0]->sizes(), inputs[1]->sizes());
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_broadcast_forward_compatibility(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,34 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Cast in default domain from version 9 to 8
#pragma once
#include <memory>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Cast_9_8 final : public Adapter {
public:
explicit Cast_9_8() : Adapter("Cast", OpSetID(9), OpSetID(8)) {}
void adapt_cast_9_8(std::shared_ptr<Graph>, Node* node) const {
if (node->inputs()[0]->elemType() == TensorProto_DataType_STRING || node->i(kto) == TensorProto_DataType_STRING)
ONNX_ASSERTM(false, "Casting From/To STRING data type is not supported")
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_cast_9_8(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,56 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Clip in default domain from version 10 to 11
#pragma once
#include <limits>
#include <memory>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Clip_10_11 final : public Adapter {
public:
explicit Clip_10_11() : Adapter("Clip", OpSetID(10), OpSetID(11)) {}
void adapt_clip_10_11(std::shared_ptr<Graph> graph, Node* node) const {
bool has_min = node->hasAttribute(kmin);
bool has_max = node->hasAttribute(kmax);
// Turn min/max attributes into tensor (if present) and add value as input
if (has_min) {
attrToInput(graph, node, node->f(kmin));
node->removeAttribute(kmin);
}
if (has_max) {
if (!has_min) {
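// Clip-11 inputs are positional (min before max), so an explicit min of
// the lowest float stands in for the absent attribute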
attrToInput(graph, node, std::numeric_limits<float>::lowest());
}
attrToInput(graph, node, node->f(kmax));
node->removeAttribute(kmax);
}
}
void attrToInput(std::shared_ptr<Graph> graph, Node* node, float val) const {
Tensor t;
t.elem_type() = TensorProto_DataType_FLOAT;
auto& data = t.floats();
data.emplace_back(val);
Node* constant = graph->create(kConstant);
constant->insertBefore(node);
constant->t_(kvalue, t);
node->addInput(constant->output());
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_clip_10_11(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,30 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter indicating compatibility of op between opsets with separate
// definitions
#pragma once
#include <memory>
#include <string>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
struct CompatibleAdapter final : public Adapter {
explicit CompatibleAdapter(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
: Adapter(op_name, initial, target) {}
Node* adapt(std::shared_ptr<Graph>, Node* node) const override {
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,46 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Dropout in default domain from version 11 to 12
#pragma once
#include <memory>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Dropout_11_12 final : public Adapter {
public:
explicit Dropout_11_12() : Adapter("Dropout", OpSetID(11), OpSetID(12)) {}
void adapt_dropout_11_12(std::shared_ptr<Graph> graph, Node* node) const {
float ratio;
if (node->hasAttribute(kratio)) {
ratio = node->f(kratio);
node->removeAttribute(kratio);
} else {
ratio = 0.5;
}
Tensor t_ratio;
t_ratio.elem_type() = TensorProto_DataType_FLOAT;
auto& data_ratio = t_ratio.floats();
data_ratio.emplace_back(ratio);
Node* constant = graph->create(kConstant);
constant->insertBefore(node);
constant->t_(kvalue, t_ratio);
node->addInput(constant->output());
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_dropout_11_12(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,106 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for ops whose set of supported tensor types was extended in a
// later opset: unsupported inputs/outputs are bridged with Cast nodes
#pragma once
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
struct ExtendSupportedTypes final : public Adapter {
explicit ExtendSupportedTypes(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
: Adapter(op_name, initial, target) {}
Node* create_cast_op(
std::shared_ptr<Graph> graph,
ArrayRef<Value*> inputs,
const int to_type,
const std::vector<Dimension>& output_shape,
const std::string& name) const {
Node* node = graph->create(kCast, inputs);
node->i_(kto, to_type);
node->output()->setUniqueName(name);
node->output()->setSizes(output_shape);
node->output()->setElemType(to_type);
return node;
}
void adapt_type_extension(std::shared_ptr<Graph> graph, Node* node) const {
const ArrayRef<Value*>& inputs = node->inputs();
const ArrayRef<Value*>& outputs = node->outputs();
const std::string original_output_name = node->output()->uniqueName();
const int input_type = inputs.size() > 0 ? inputs[0]->elemType() : -1;
const int output_type = outputs[0]->elemType();
const std::unordered_set<int>& supported_version8_types = {
TensorProto_DataType::TensorProto_DataType_FLOAT,
TensorProto_DataType::TensorProto_DataType_FLOAT16,
TensorProto_DataType::TensorProto_DataType_DOUBLE,
};
const std::unordered_set<int>& unsupported_version9_types = {
TensorProto_DataType::TensorProto_DataType_COMPLEX128,
TensorProto_DataType::TensorProto_DataType_COMPLEX64,
TensorProto_DataType::TensorProto_DataType_STRING,
};
ONNX_ASSERTM(
unsupported_version9_types.find(input_type) == unsupported_version9_types.end(), "Unsupported Input Type");
ONNX_ASSERTM(
unsupported_version9_types.find(output_type) == unsupported_version9_types.end(), "Unsupported Output Type");
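// Constant has no inputs to cast; Greater and Less are excluded from
// output casting since they produce a boolean output in both opsets.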
bool castInput = (node->kind() != kConstant);
bool castOutput = (node->kind() != kGreater && node->kind() != kLess);
if (castInput && supported_version8_types.find(input_type) == supported_version8_types.end()) {
for (size_t i = 0; i < inputs.size(); i++) {
Node* pre_cast = create_cast_op(
graph,
inputs[i],
TensorProto_DataType::TensorProto_DataType_FLOAT,
inputs[i]->sizes(),
"pre_cast_" + ONNX_NAMESPACE::to_string(i));
pre_cast->insertBefore(node);
node->replaceInput(i, pre_cast->output());
}
}
if (castOutput && supported_version8_types.find(output_type) == supported_version8_types.end()) {
const use_list original_uses(node->output()->uses());
node->output()->setElemType(TensorProto_DataType::TensorProto_DataType_FLOAT);
node->output()->setUniqueName(original_output_name + "_intermediate_output");
Node* post_cast = create_cast_op(graph, outputs[0], output_type, outputs[0]->sizes(), original_output_name);
post_cast->insertAfter(node);
for (Use u : original_uses) {
u.user->replaceInputWith(node->output(), post_cast->output());
}
for (size_t i = 0; i < graph->outputs().size(); i++) {
if (graph->outputs()[i]->uniqueName() == node->output()->uniqueName()) {
graph->return_node()->replaceInput(i, post_cast->output());
}
}
}
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_type_extension(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,57 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Gemm in default domain from version 6 to 7
#pragma once
#include <memory>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Gemm_6_7 final : public Adapter {
public:
explicit Gemm_6_7() : Adapter("Gemm", OpSetID(6), OpSetID(7)) {}
void adapt_gemm_6_7(std::shared_ptr<Graph>, Node* node) const {
const ArrayRef<Value*>& inputs = node->inputs();
assertInputsAvailable(inputs, name().c_str(), 3);
const auto& A_shape = inputs[0]->sizes();
const auto& B_shape = inputs[1]->sizes();
// Determine if C is broadcastable
const auto& C_shape = inputs[2]->sizes();
// Build the (M, N) output shape that C is checked against with numpy_unibroadcastable
std::vector<Dimension> MN;
if (node->hasAttribute(ktransA) && node->i(ktransA) == 1) {
MN.emplace_back(A_shape[1]);
} else {
MN.emplace_back(A_shape[0]);
}
if (node->hasAttribute(ktransB) && node->i(ktransB) == 1) {
MN.emplace_back(B_shape[0]);
} else {
MN.emplace_back(B_shape[1]);
}
ONNX_ASSERTM(
check_numpy_unibroadcastable_and_require_broadcast(MN, C_shape) != -1,
"Gemm being converted from 6 to 7 does not have "
"broadcastable inputs.");
if (node->hasAttribute(kbroadcast))
node->removeAttribute(kbroadcast);
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_gemm_6_7(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,63 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Gemm in default domain from version 7 to 6
#pragma once
#include <memory>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Gemm_7_6 final : public Adapter {
public:
explicit Gemm_7_6() : Adapter("Gemm", OpSetID(7), OpSetID(6)) {}
void adapt_gemm_7_6(std::shared_ptr<Graph>, Node* node) const {
const ArrayRef<Value*>& inputs = node->inputs();
assertInputsAvailable(inputs, name().c_str(), 3);
const auto& A_shape = inputs[0]->sizes();
const auto& B_shape = inputs[1]->sizes();
// Determine if C is broadcastable
const auto& C_shape = inputs[2]->sizes();
// Build the (M, N) output shape that C is checked against with numpy_unibroadcastable
// TODO: Reconcile fact that shapes aren't determined for 1st 2 inputs
std::vector<Dimension> MN;
if (node->hasAttribute(ktransA) && node->i(ktransA) == 1) {
MN.emplace_back(A_shape[1]);
} else {
MN.emplace_back(A_shape[0]);
}
if (node->hasAttribute(ktransB) && node->i(ktransB) == 1) {
MN.emplace_back(B_shape[0]);
} else {
MN.emplace_back(B_shape[1]);
}
int req_broadcast = check_numpy_unibroadcastable_and_require_broadcast(MN, C_shape);
ONNX_ASSERTM(
req_broadcast != -1,
"%s being converted from %d to %d does "
"not have broadcastable inputs.",
name().c_str(),
initial_version().version(),
target_version().version());
if (req_broadcast == 1) {
node->i_(kbroadcast, 1);
}
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_gemm_7_6(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,36 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for GridSample in default domain from version 19 to 20
#pragma once
#include <memory>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class GridSample_19_20 final : public Adapter {
public:
explicit GridSample_19_20() : Adapter("GridSample", OpSetID(19), OpSetID(20)) {}
void adapt_gridsample_19_20(std::shared_ptr<Graph>, Node* node) const {
if (node->hasAttribute(kmode) && (node->s(kmode) == "bilinear")) {
node->s_(kmode, "linear");
}
if (node->hasAttribute(kmode) && (node->s(kmode) == "bicubic")) {
node->s_(kmode, "cubic");
}
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_gridsample_19_20(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,128 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for GroupNormalization in default domain from version 20 to 21
#pragma once
#include <memory>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class GroupNormalization_20_21 final : public Adapter {
public:
explicit GroupNormalization_20_21() : Adapter("GroupNormalization", OpSetID(20), OpSetID(21)) {}
void transform_input(
std::shared_ptr<Graph> graph,
Node* node,
int64_t input_id,
Value* reshape0_shape,
Value* reshape1_shape,
Value* expand_shape) const {
Node* reshape0 = graph->create(kReshape);
reshape0->addInput(node->inputs()[input_id]);
reshape0->addInput(reshape0_shape);
reshape0->insertBefore(node);
Node* expand = graph->create(kExpand);
expand->addInput(reshape0->output());
expand->addInput(expand_shape);
expand->insertBefore(node);
Node* reshape1 = graph->create(kReshape);
reshape1->addInput(expand->output());
reshape1->addInput(reshape1_shape);
reshape1->insertBefore(node);
node->replaceInput(input_id, reshape1->output());
}
void adapt_group_normalization_20_21(std::shared_ptr<Graph> graph, Node* node) const {
// Perform following sequence of ops on scale/bias, effect is similar to numpy.repeat()
//
// Shape<start=1,end=2>(input0) -- Div(Shape_out (C), num_groups)
// |
// Reshape(input1/2, [-1, 1]) ----------- Expand(Reshape_out, [1, Div_out]) -- Reshape(Expand_out, [-1])
//
// The helper function transform_input() implements the bottom row of the diagram
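// e.g. with C = 4, num_groups = 2 and scale = [s0, s1], the chain yields
// [[s0], [s1]] -> [[s0, s0], [s1, s1]] -> [s0, s0, s1, s1]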
// Get number of channels: C
Symbol kShape("Shape");
Node* C = graph->create(kShape);
C->i_(kstart, 1);
C->i_(kend, 2);
C->addInput(node->inputs()[0]);
C->insertBefore(node);
// Get number of channels per group
Tensor tensor_num_groups;
tensor_num_groups.elem_type() = TensorProto_DataType_INT64;
int64_t num_groups = node->i(knum_groups);
tensor_num_groups.sizes() = {1};
tensor_num_groups.int64s() = {num_groups};
Node* constant_num_groups = graph->create(kConstant);
constant_num_groups->t_(kvalue, tensor_num_groups);
constant_num_groups->insertBefore(node);
Node* div = graph->create(kDiv);
div->addInput(C->output());
div->addInput(constant_num_groups->output());
div->insertBefore(node);
// Get Expand shape: [1, Div_out]
Tensor tensor_one;
tensor_one.elem_type() = TensorProto_DataType_INT64;
tensor_one.sizes() = {1};
tensor_one.int64s() = {1};
Node* constant_one = graph->create(kConstant);
constant_one->t_(kvalue, tensor_one);
constant_one->insertBefore(node);
Node* concat = graph->create(kConcat);
concat->i_(kaxis, 0);
concat->addInput(constant_one->output());
concat->addInput(div->output());
concat->insertBefore(node);
// Get shape of first reshape: [-1, 1]
Tensor tensor_reshape0_shape;
tensor_reshape0_shape.elem_type() = TensorProto_DataType_INT64;
tensor_reshape0_shape.sizes() = {2};
tensor_reshape0_shape.int64s() = {-1, 1};
Node* constant_reshape0_shape = graph->create(kConstant);
constant_reshape0_shape->t_(kvalue, tensor_reshape0_shape);
constant_reshape0_shape->insertBefore(node);
// Get shape of last reshape: [-1]
Tensor tensor_reshape1_shape;
tensor_reshape1_shape.elem_type() = TensorProto_DataType_INT64;
tensor_reshape1_shape.sizes() = {1};
tensor_reshape1_shape.int64s() = {-1};
Node* constant_reshape1_shape = graph->create(kConstant);
constant_reshape1_shape->t_(kvalue, tensor_reshape1_shape);
constant_reshape1_shape->insertBefore(node);
// transform scale and bias
transform_input(
graph, node, 1, constant_reshape0_shape->output(), constant_reshape1_shape->output(), concat->output());
transform_input(
graph, node, 2, constant_reshape0_shape->output(), constant_reshape1_shape->output(), concat->output());
// Set stash_type
node->i_(kstash_type, node->inputs()[0]->elemType());
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_group_normalization_20_21(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,36 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for MaxPool in default domain from version 8 to 7
#pragma once
#include <memory>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class MaxPool_8_7 final : public Adapter {
public:
explicit MaxPool_8_7() : Adapter("MaxPool", OpSetID(8), OpSetID(7)) {}
void adapt_maxpool_8_7(std::shared_ptr<Graph>, Node* node) const {
const ArrayRef<Value*>& outputs = node->outputs();
ONNX_ASSERTM(outputs.size() != 2, "Opset version 7 of MaxPool cannot include Indices output");
if (node->hasAttribute(kstorage_order))
node->removeAttribute(kstorage_order);
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_maxpool_8_7(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,32 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter indicating lack of a previous version of some op before a given
// opset version.
#pragma once
#include <memory>
#include <string>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class NoPreviousVersionAdapter final : public Adapter {
public:
explicit NoPreviousVersionAdapter(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
: Adapter(op_name, initial, target) {}
Node* adapt(std::shared_ptr<Graph>, Node* node) const override {
ONNX_ASSERTM(false, "No Previous Version of %s exists", name().c_str());
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,56 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Pad in default domain from version 10 to 11
#pragma once
#include <memory>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Pad_10_11 final : public Adapter {
public:
explicit Pad_10_11() : Adapter("Pad", OpSetID(10), OpSetID(11)) {}
void adapt_pad_10_11(std::shared_ptr<Graph> graph, Node* node) const {
// Turn pads attribute into input
Tensor t_pads;
t_pads.elem_type() = TensorProto_DataType_INT64;
auto& data_pads = t_pads.int64s();
for (int64_t pad : node->is(kpads)) {
data_pads.emplace_back(pad);
}
t_pads.sizes() = std::vector<int64_t>{(int64_t)data_pads.size()};
Value* v_pads = graph->addInitializerAndCreateValue(t_pads);
node->addInput(v_pads);
node->removeAttribute(kpads);
// Turn value attribute into input
if (!node->hasAttribute(kmode) || node->s(kmode) == "constant") {
if (!node->hasAttribute(kvalue))
node->f_(kvalue, 0.);
Tensor t_value;
t_value.elem_type() = TensorProto_DataType_FLOAT;
auto& data_value = t_value.floats();
data_value.emplace_back(node->f(kvalue));
Node* constant = graph->create(kConstant);
constant->insertBefore(node);
constant->t_(kvalue, t_value);
node->addInput(constant->output());
node->removeAttribute(kvalue);
}
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_pad_10_11(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,77 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapters for QuantizeLinear and DequantizeLinear in default domain from version 21 to 20
#pragma once
#include <memory>
#include <vector>
#include "onnx/version_converter/adapters/type_restriction.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
static const std::vector<TensorProto_DataType> q_dq_20_unallowed_types = {
TensorProto_DataType_UINT16,
TensorProto_DataType_INT16,
TensorProto_DataType_UINT4,
TensorProto_DataType_INT4};
class QuantizeLinear_21_20 final : public TypeRestriction {
public:
explicit QuantizeLinear_21_20()
: TypeRestriction("QuantizeLinear", OpSetID(21), OpSetID(20), q_dq_20_unallowed_types) {}
void adapt_quantize_linear_21_20(std::shared_ptr<Graph>, Node* node) const {
if (node->hasAttribute(kblock_size)) {
if ((node->i(kblock_size) != 0)) {
ONNX_ASSERTM(false, "Blocked quantization is not supported for Opset Version %d.", target_version().version())
}
node->removeAttribute(kblock_size);
}
if (node->hasAttribute(koutput_dtype)) {
if (node->i(koutput_dtype) != TensorProto_DataType_UINT8 && node->inputs().size() < 3) {
ONNX_ASSERTM(
false,
"Attribute output_dtype is not supported for Opset Version %d, supply a zero-point tensor instead",
target_version().version())
}
node->removeAttribute(koutput_dtype);
}
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_type_restriction(graph, node);
adapt_quantize_linear_21_20(graph, node);
return node;
}
};
class DequantizeLinear_21_20 final : public TypeRestriction {
public:
explicit DequantizeLinear_21_20()
: TypeRestriction("DequantizeLinear", OpSetID(21), OpSetID(20), q_dq_20_unallowed_types) {}
void adapt_dequantize_linear_21_20(std::shared_ptr<Graph>, Node* node) const {
if (node->hasAttribute(kblock_size)) {
if ((node->i(kblock_size) != 0)) {
ONNX_ASSERTM(false, "Blocked quantization is not supported for Opset Version %d.", target_version().version())
}
node->removeAttribute(kblock_size);
}
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_type_restriction(graph, node);
adapt_dequantize_linear_21_20(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,32 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for all ops that remove consumed_inputs
#pragma once
#include <memory>
#include <string>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class RemoveConsumedInputs : public Adapter {
public:
explicit RemoveConsumedInputs(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
: Adapter(op_name, initial, target) {}
Node* adapt(std::shared_ptr<Graph>, Node* node) const override {
if (node->hasAttribute(kconsumed_inputs))
node->removeAttribute(kconsumed_inputs);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,49 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Reshape in default domain from version 4 to 5
#pragma once
#include <memory>
#include "onnx/version_converter/adapters/remove_consumed_inputs.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Reshape_4_5 final : public RemoveConsumedInputs {
public:
explicit Reshape_4_5() : RemoveConsumedInputs("Reshape", OpSetID(4), OpSetID(5)) {}
void adapt_reshape_4_5(std::shared_ptr<Graph> graph, Node* node) const {
// Create Input from Attribute - add as Initializer
// Create tensor for value attribute
Tensor t;
t.elem_type() = TensorProto_DataType_INT64;
auto& data = t.int64s();
// Turn shape attribute into tensor
for (int64_t shape : node->is(kshape)) {
data.emplace_back(shape);
}
// Add value as input to node
Node* constant = graph->create(kConstant);
constant->insertBefore(node);
constant->t_(kvalue, t);
node->addInput(constant->output());
// Remove kshape attribute
node->removeAttribute(kshape);
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
RemoveConsumedInputs::adapt(graph, node);
adapt_reshape_4_5(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,73 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Reshape in default domain from version 5 to 4
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Reshape_5_4 final : public Adapter {
public:
explicit Reshape_5_4() : Adapter("Reshape", OpSetID(5), OpSetID(4)) {}
void adapt_reshape_5_4(std::shared_ptr<Graph> graph, Node* node) const {
// Identify if shape is statically determined; if so, feed as attribute
const ArrayRef<Value*>& inputs = node->inputs();
// Get shape from initializer or constant operator, not actual shape
// Identify whether we have a Constant Op or an Initializer
Value* const_val = inputs[1];
Node* node_ptr = const_val->node();
if (node_ptr->kind() == kConstant) {
// Get value attribute of kConstant
const std::vector<int64_t>& int64s = node_ptr->t(kvalue).int64s();
if (int64s.empty()) {
// Also handle raw data
std::string raw_data = node_ptr->t(kvalue).raw();
ONNX_ASSERTM(
raw_data.size() != 0 && raw_data.size() % 8 == 0,
"Raw Data must be non-empty and size must be a multiple of 8");
const int64_t* raw = reinterpret_cast<const int64_t*>(raw_data.data());
node->is_(kshape, std::vector<int64_t>(raw, raw + node_ptr->t(kvalue).size_from_dim(0)));
} else {
node->is_(kshape, std::forward<const std::vector<int64_t>>(int64s));
}
// If Constant node isn't used anywhere else, remove it
node->removeInput(1);
if (const_val->uses().size() < 1) {
node_ptr->destroy();
}
} else {
// Get Value name, find Initializer with same name
for (const auto& initializer : graph->initializers()) {
if (initializer.name() == inputs[1]->uniqueName()) {
node->is_(kshape, std::forward<const std::vector<int64_t>>(initializer.int64s()));
node->removeInput(1);
// Remove initializer
if (const_val->uses().size() < 1)
graph->eraseInitializerAndInput(const_val);
break;
}
}
}
ONNX_ASSERTM(node->hasAttribute(kshape), "No initializer or constant input to Reshape node found");
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_reshape_5_4(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,50 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Resize in default domain from version 10 to 11
#pragma once
#include <memory>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Resize_10_11 final : public Adapter {
public:
explicit Resize_10_11() : Adapter("Resize", OpSetID(10), OpSetID(11)) {}
void adapt_resize_10_11(std::shared_ptr<Graph> graph, Node* node) const {
// Resize-11 takes (X, roi, scales): re-append the old scales input and
// splice a neutral roi tensor into its former slot
int input_rank = node->inputs()[0]->sizes().size();
Value* scales_input = node->inputs()[1];
node->addInput(scales_input);
// roi = [0, ..., 0, 1, ..., 1]: per-axis starts then ends, spanning each axis fully
Tensor t;
t.sizes() = std::vector<int64_t>{2 * input_rank};
t.elem_type() = TensorProto_DataType_FLOAT;
auto& data = t.floats();
for (int i = 0; i < input_rank; i++)
data.emplace_back(0);
for (int i = 0; i < input_rank; i++)
data.emplace_back(1);
Node* constant = graph->create(kConstant);
constant->insertBefore(node);
constant->t_(kvalue, t);
node->replaceInput(1, constant->output());
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_resize_10_11(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,64 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Scan in default domain from version 8 to 9
#pragma once
#include <memory>
#include <utility>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
struct Scan_8_9 final : public Adapter {
explicit Scan_8_9() : Adapter("Scan", OpSetID(8), OpSetID(9)) {}
void adapt_scan_8_9(std::shared_ptr<Graph>, Node* node) const {
const std::vector<Value*> inputs(node->inputs().vec());
const std::vector<Value*> outputs(node->outputs().vec());
// Handling Attribute Changes
Symbol dirs = Symbol("directions");
if (node->hasAttribute(dirs)) {
const std::vector<int64_t> directions(node->is(dirs));
node->removeAttribute(dirs);
node->is_(Symbol("scan_input_directions"), std::move(directions));
}
// Handling Input and Output Changes
node->removeAllInputs();
ONNX_ASSERTM(inputs[0]->uniqueName() == "", "Scan with an explicit sequence_lens input cannot be converted to opset 9");
for (Value* input : inputs) {
if (!input->sizes().empty()) {
std::vector<Dimension> new_sizes(input->sizes().begin() + 1, input->sizes().end());
input->setSizes(new_sizes);
node->addInput(input);
}
}
for (Value* output : outputs) {
if (!output->sizes().empty()) {
std::vector<Dimension> new_sizes(output->sizes().begin() + 1, output->sizes().end());
output->setSizes(new_sizes);
}
}
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_scan_8_9(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,93 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Scan in default domain from version 9 to 8
#pragma once
#include <memory>
#include <utility>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
struct Scan_9_8 final : public Adapter {
explicit Scan_9_8() : Adapter("Scan", OpSetID(9), OpSetID(8)) {}
void adapt_scan_9_8(std::shared_ptr<Graph>, Node* node) const {
const std::vector<Value*> inputs(node->inputs().vec());
const std::vector<Value*> outputs(node->outputs().vec());
// Handling Attribute Changes
Symbol input_dirs = Symbol("scan_input_directions");
if (node->hasAttribute(input_dirs)) {
const std::vector<int64_t> scan_input_directions(node->is(input_dirs));
node->removeAttribute(input_dirs);
node->is_(Symbol("directions"), std::move(scan_input_directions));
}
Symbol output_dirs = Symbol("scan_output_directions");
if (node->hasAttribute(output_dirs)) {
const std::vector<int64_t> scan_output_directions(node->is(output_dirs));
for (int64_t x : scan_output_directions) {
ONNX_ASSERTM(x == 0, "Unsupported output direction for Version 8");
}
node->removeAttribute(output_dirs);
}
Symbol input_axes = Symbol("scan_input_axes");
if (node->hasAttribute(input_axes)) {
const std::vector<int64_t> scan_input_axes(node->is(input_axes));
for (int64_t x : scan_input_axes) {
ONNX_ASSERTM(x == 0, "Unsupported input axes for Version 8");
}
node->removeAttribute(input_axes);
}
Symbol output_axes = Symbol("scan_output_axes");
if (node->hasAttribute(output_axes)) {
const std::vector<int64_t> scan_output_axes(node->is(output_axes));
for (int64_t x : scan_output_axes) {
ONNX_ASSERTM(x == 0, "Unsupported output axes for Version 8");
}
node->removeAttribute(output_axes);
}
// Handling Input and Output Changes
node->removeAllInputs();
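// Scan-8 has a leading optional sequence_lens input and an explicit batch
// dimension; synthesize an empty-named INT32 placeholder for the former and
// re-add a size-1 batch dim to every state/scan input and output.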
Value* v = new Value(node, 0);
v->setUniqueName("");
v->setElemType(TensorProto_DataType::TensorProto_DataType_INT32);
node->addInput(v);
for (Value* input : inputs) {
std::vector<Dimension> new_sizes{Dimension(1)};
new_sizes.insert(new_sizes.end(), input->sizes().begin(), input->sizes().end());
input->setSizes(new_sizes);
node->addInput(input);
}
for (Value* output : outputs) {
std::vector<Dimension> new_sizes{Dimension(1)};
new_sizes.insert(new_sizes.end(), output->sizes().begin(), output->sizes().end());
output->setSizes(new_sizes);
}
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_scan_9_8(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,43 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Scatter in default domain from version 10 to 11
#pragma once
#include <memory>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Scatter_10_11 final : public Adapter {
public:
explicit Scatter_10_11() : Adapter("Scatter", OpSetID(10), OpSetID(11)) {}
Node* adapt_scatter_10_11(std::shared_ptr<Graph> graph, Node* node) const {
int axis = node->hasAttribute(kaxis) ? node->i(kaxis) : 0;
// Replace the node with an equivalent ScatterElements node
Node* scatter_elements = graph->create(kScatterElements);
scatter_elements->i_(kaxis, axis);
scatter_elements->addInput(node->inputs()[0]);
scatter_elements->addInput(node->inputs()[1]);
scatter_elements->addInput(node->inputs()[2]);
node->replaceAllUsesWith(scatter_elements);
scatter_elements->insertBefore(node);
node->destroy();
return scatter_elements;
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
return adapt_scatter_10_11(graph, node);
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,54 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Slice in default domain from version 9 to 10
#pragma once
#include <memory>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Slice_9_10 final : public Adapter {
public:
explicit Slice_9_10() : Adapter("Slice", OpSetID(9), OpSetID(10)) {}
void attrToInput(std::shared_ptr<Graph> graph, Node* node, const std::vector<int64_t>& attr) const {
Tensor t;
t.elem_type() = TensorProto_DataType_INT64;
t.sizes() = std::vector<int64_t>{static_cast<int64_t>(attr.size())};
auto& data = t.int64s();
for (auto a : attr) {
data.emplace_back(a);
}
Node* constant = graph->create(kConstant);
constant->insertBefore(node);
constant->t_(kvalue, t);
node->addInput(constant->output());
}
void adapt_slice_9_10(std::shared_ptr<Graph> graph, Node* node) const {
attrToInput(graph, node, node->is(kstarts));
node->removeAttribute(kstarts);
attrToInput(graph, node, node->is(kends));
node->removeAttribute(kends);
if (node->hasAttribute(kaxes)) {
attrToInput(graph, node, node->is(kaxes));
node->removeAttribute(kaxes);
}
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_slice_9_10(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,87 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Softmax and LogSoftmax in default domain from version 12 to 13
#pragma once
#include <memory>
#include <string>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Softmax_12_13 final : public Adapter {
public:
explicit Softmax_12_13(const std::string& op_name) : Adapter(op_name, OpSetID(12), OpSetID(13)) {}
void adapt_softmax_12_13(std::shared_ptr<Graph> graph, Node* node) const {
int old_axis = node->hasAttribute(kaxis) ? node->i(kaxis) : 1;
int input_rank = node->inputs()[0]->sizes().size();
if (old_axis < 0)
old_axis = input_rank + old_axis;
if (old_axis == input_rank - 1)
node->i_(kaxis, -1);
else {
// -- shape ------------------
// / |
// ----- flatten -- softmax -- reshape
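// e.g. input shape (d0, d1, d2) with old axis 1: Flatten<axis=1> gives
// (d0, d1*d2), Softmax-13 over axis -1 reproduces the opset-12 coercion
// semantics, and the final Reshape restores (d0, d1, d2) from Shape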
// get original softmax's input shape
Symbol kShape("Shape");
Node* shape = graph->create(kShape);
shape->addInput(node->inputs()[0]);
shape->insertBefore(node);
// Insert Flatten node before softmax
Node* flatten = graph->create(kFlatten);
flatten->addInput(node->inputs()[0]);
flatten->insertBefore(node);
flatten->i_(kaxis, old_axis);
node->replaceInput(0, flatten->output());
// Softmax along the last axis of the flattened 2D tensor
node->i_(kaxis, -1);
// Insert Reshape node after softmax
const std::string original_output_name = node->output()->uniqueName();
const use_list original_uses(node->output()->uses());
node->output()->setUniqueName(original_output_name + "_intermediate");
Node* reshape = graph->create(kReshape);
reshape->addInput(node->outputs()[0]);
reshape->addInput(shape->output());
reshape->output()->setUniqueName(original_output_name);
reshape->insertAfter(node);
// Fix outputs & wiring
if (node->output()->sizes().size() != 0) {
reshape->output()->setSizes(node->output()->sizes());
}
reshape->output()->setElemType(node->output()->elemType());
node->output()->wipeSizes();
for (Use u : original_uses) {
u.user->replaceInputWith(node->output(), reshape->output());
}
for (size_t i = 0; i < graph->outputs().size(); i++) {
if (graph->outputs()[i]->uniqueName() == original_output_name) {
graph->return_node()->replaceInput(i, reshape->output());
}
}
}
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_softmax_12_13(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,47 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Split in default domain from version 12 to 13
#pragma once
#include <memory>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Split_12_13 : public Adapter {
public:
explicit Split_12_13() : Adapter("Split", OpSetID(12), OpSetID(13)) {}
void attrToInput(std::shared_ptr<Graph> graph, Node* node, const std::vector<int64_t>& split) const {
Tensor t;
t.elem_type() = TensorProto_DataType_INT64;
t.sizes() = std::vector<int64_t>{static_cast<int64_t>(split.size())};
auto& data = t.int64s();
for (auto s : split) {
data.emplace_back(s);
}
Node* constant = graph->create(kConstant);
constant->insertBefore(node);
constant->t_(kvalue, t);
node->addInput(constant->output());
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
if (node->hasAttribute(ksplit)) {
attrToInput(graph, node, node->is(ksplit));
node->removeAttribute(ksplit);
}
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,70 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Split in default domain from version 13 to 12
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Split_13_12 : public Adapter {
public:
explicit Split_13_12() : Adapter("Split", OpSetID(13), OpSetID(12)) {}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
// Identify if 'split' is statically determined; if so, feed as attribute
const ArrayRef<Value*>& inputs = node->inputs();
// The 'split' input is optional in opset 13; without it, the equal-split
// default carries over to opset 12 unchanged
if (inputs.size() < 2)
return node;
// Get 'split' from initializer or constant operator
// Identify whether we have a Constant Op or an Initializer
Value* const_val = inputs[1];
Node* node_ptr = const_val->node();
if (node_ptr->kind() == kConstant) {
// Get value attribute of kConstant
const std::vector<int64_t>& int64s = node_ptr->t(kvalue).int64s();
if (int64s.empty()) {
// Also handle raw data
std::string raw_data = node_ptr->t(kvalue).raw();
ONNX_ASSERTM(
raw_data.size() != 0 && raw_data.size() % 8 == 0,
"Raw Data must be non-empty and size must be a multiple of 8");
const int64_t* raw = reinterpret_cast<const int64_t*>(raw_data.c_str());
node->is_(ksplit, std::vector<int64_t>(raw, raw + node_ptr->t(kvalue).size_from_dim(0)));
} else {
node->is_(ksplit, std::vector<int64_t>(int64s));
}
// If Constant node isn't used anywhere else, remove it
node->removeInput(1);
if (const_val->uses().size() < 1) {
node_ptr->destroy();
}
} else {
// Get Value name, find Initializer with same name
for (const auto& initializer : graph->initializers()) {
if (initializer.name() == inputs[1]->uniqueName()) {
node->is_(ksplit, std::vector<int64_t>(initializer.int64s()));
node->removeInput(1);
// Remove initializer
if (const_val->uses().size() < 1)
graph->eraseInitializerAndInput(const_val);
break;
}
}
}
ONNX_ASSERTM(node->hasAttribute(ksplit), "No initializer or constant input to node found");
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE
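// Before/after sketch (illustrative): the 'split' input must be statically
// known, either from a Constant node or from a graph initializer:
//
//   opset 13:  S      = Constant(value=int64[2]{2, 3})
//              Y0, Y1 = Split(X, S)
//   opset 12:  Y0, Y1 = Split(X, split=[2, 3])   // S destroyed if unused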

View File

@ -0,0 +1,38 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Split in default domain from version 17 to 18
#pragma once
#include <memory>
#include "onnx/version_converter/adapters/adapter.h"
#include "onnx/version_converter/adapters/transformers.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Split_17_18 : public Adapter {
public:
explicit Split_17_18() : Adapter("Split", OpSetID(17), OpSetID(18)) {}
void adapt_split_17_18(std::shared_ptr<Graph>, Node* node) const {
const auto num_outputs = node->outputs().size();
node->i_(knum_outputs, num_outputs);
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
// If the node has neither a 'num_outputs' attribute nor a 'split' input,
// derive 'num_outputs' from the node's output count
if (!node->hasAttribute(knum_outputs) && node->inputs().size() != 2) {
adapt_split_17_18(graph, node);
}
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE
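// Illustrative example: opset 18 requires 'num_outputs' when no 'split' input
// is given, so a Split node with three outputs is rewritten as
//
//   opset 17:  Y0, Y1, Y2 = Split(X)
//   opset 18:  Y0, Y1, Y2 = Split(X, num_outputs=3)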

View File

@ -0,0 +1,41 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Sum in default domain from version 8 to 7
#pragma once
#include <memory>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Sum_8_7 final : public Adapter {
public:
explicit Sum_8_7() : Adapter("Sum", OpSetID(8), OpSetID(7)) {}
void adapt_sum_8_7(std::shared_ptr<Graph>, Node* node) const {
// Throw an exception if any broadcasting occurs
const ArrayRef<Value*>& inputs = node->inputs();
// Determine if inputs are of different sizes
for (int i = 1; i < (int)inputs.size(); i++) {
std::vector<Dimension> A_sizes = inputs[i - 1]->sizes();
std::vector<Dimension> B_sizes = inputs[i]->sizes();
assert_numpy_multibroadcastable(A_sizes, B_sizes);
}
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_sum_8_7(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,43 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for TopK in default domain from version 9 to 10
#pragma once
#include <memory>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class TopK_9_10 final : public Adapter {
public:
explicit TopK_9_10() : Adapter("TopK", OpSetID(9), OpSetID(10)) {}
void adapt_topk_9_10(std::shared_ptr<Graph> graph, Node* node) const {
Tensor t;
t.elem_type() = TensorProto_DataType_INT64;
t.sizes() = std::vector<int64_t>{1};
auto& data = t.int64s();
data.emplace_back(node->i(kk));
Node* constant = graph->create(kConstant);
constant->insertBefore(node);
constant->t_(kvalue, t);
node->addInput(constant->output());
node->removeAttribute(kk);
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_topk_9_10(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE
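// Before/after sketch (illustrative): the 'k' attribute becomes a second
// input fed by a new Constant node holding a one-element int64 tensor:
//
//   opset 9:   Values, Indices = TopK(X, k=5)
//   opset 10:  K               = Constant(value=int64[1]{5})
//              Values, Indices = TopK(X, K)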

View File

@ -0,0 +1,84 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <cinttypes>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
// Node transformers commonly used in version adapters.
// The NODE_TRANSFORMER macro expands to a lambda that captures its context by
// value; the Graph parameter is accepted but unused by these transformers.
#define NODE_TRANSFORMER(node) [=](std::shared_ptr<Graph>, Node * node)
namespace ONNX_NAMESPACE {
namespace version_conversion {
inline NodeTransformerFunction RemoveAttribute(Symbol attr) {
return NODE_TRANSFORMER(node) {
if (node->hasAttribute(attr)) {
node->removeAttribute(attr);
}
return node;
};
}
inline NodeTransformerFunction RemoveAttribute(Symbol attr, int64_t value) {
return NODE_TRANSFORMER(node) {
if (node->hasAttribute(attr)) {
ONNX_ASSERTM(node->i(attr) == value, "Attribute %s must have value %" PRId64, attr.toString(), value);
node->removeAttribute(attr);
}
return node;
};
}
inline NodeTransformerFunction RemoveAttributeNotEq(Symbol attr, int64_t value) {
return NODE_TRANSFORMER(node) {
if (node->hasAttribute(attr)) {
ONNX_ASSERTM(node->i(attr) != value, "Attribute %s must not have value %" PRId64, attr.toString(), value);
node->removeAttribute(attr);
}
return node;
};
}
inline NodeTransformerFunction SetAttribute(Symbol attr, int64_t value) {
return NODE_TRANSFORMER(node) {
node->i_(attr, value);
return node;
};
}
inline NodeTransformerFunction SetAttribute(Symbol attr, const std::string& value) {
return NODE_TRANSFORMER(node) {
node->s_(attr, value);
return node;
};
}
inline NodeTransformerFunction SetAttribute(Symbol attr, std::vector<int64_t> value) {
return NODE_TRANSFORMER(node) {
std::vector<int64_t> local(value);
node->is_(attr, std::move(local));
return node;
};
}
inline NodeTransformerFunction SetAttributeIfAbsent(Symbol attr, int64_t value) {
return NODE_TRANSFORMER(node) {
if (!node->hasAttribute(attr)) {
node->i_(attr, value);
}
return node;
};
}
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,61 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Add in default domain from version 6 to 5
#pragma once
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class TypeRestriction : public Adapter {
public:
explicit TypeRestriction(
const std::string& op_name,
const OpSetID& initial,
const OpSetID& target,
const std::vector<TensorProto_DataType>& unallowed_types)
: Adapter(op_name, initial, target), unallowed_types_(unallowed_types) {}
void adapt_type_restriction(std::shared_ptr<Graph>, Node* node) const {
// Since consumed_inputs is optional, no need to add it (as in batchnorm)
// Iterate over all inputs and outputs
for (Value* input : node->inputs()) {
assertAllowed(input);
}
for (Value* output : node->outputs()) {
assertAllowed(output);
}
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_type_restriction(graph, node);
return node;
}
private:
std::vector<TensorProto_DataType> unallowed_types_;
void assertAllowed(Value* val) const {
ONNX_ASSERTM(
std::find(std::begin(unallowed_types_), std::end(unallowed_types_), val->elemType()) ==
std::end(unallowed_types_),
"DataType (%d) of Input or Output"
" of operator '%s' is unallowed for Opset Version %d.",
val->elemType(),
name().c_str(),
target_version().version());
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,49 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Upsample in default domain from version 6 to 7
#pragma once
#include <memory>
#include <utility>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
struct Upsample_6_7 final : public Adapter {
explicit Upsample_6_7() : Adapter("Upsample", OpSetID(6), OpSetID(7)) {}
void adapt_upsample_6_7(std::shared_ptr<Graph>, Node* node) const {
Symbol width_scale_symbol = Symbol("width_scale");
Symbol height_scale_symbol = Symbol("height_scale");
ONNX_ASSERTM(
node->hasAttribute(width_scale_symbol) && node->hasAttribute(height_scale_symbol),
"Upsample in opset 1 needs to have width_scale and height_scale attributes");
auto width_scale = node->f(width_scale_symbol);
auto height_scale = node->f(height_scale_symbol);
auto input_shape = node->inputs()[0]->sizes();
ONNX_ASSERTM(input_shape.size() == 4, "Upsample before opset 7 supports only 4D input tensors");
std::vector<double> scales = {1.0, 1.0, height_scale, width_scale};
node->fs_(kscales, std::move(scales));
node->removeAttribute(width_scale_symbol);
node->removeAttribute(height_scale_symbol);
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_upsample_6_7(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE
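// Illustrative example: for a 4D input the two scalar attributes are folded
// into a single 'scales' attribute ordered as [N, C, H, W]:
//
//   opset 6:  Y = Upsample(X, height_scale=2.0, width_scale=3.0)
//   opset 7:  Y = Upsample(X, scales=[1.0, 1.0, 2.0, 3.0])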

View File

@ -0,0 +1,50 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Upsample in default domain from version 8 to 9
#pragma once
#include <memory>
#include <vector>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
struct Upsample_8_9 final : public Adapter {
explicit Upsample_8_9() : Adapter("Upsample", OpSetID(8), OpSetID(9)) {}
void adapt_upsample_8_9(std::shared_ptr<Graph> graph, Node* node) const {
// Turn the 'scales' attribute into a Constant input; guard against a missing
// attribute before reading it
if (node->hasAttribute(kscales)) {
const std::vector<double>& scales = node->fs(kscales);
Tensor t;
t.elem_type() = TensorProto_DataType_FLOAT;
t.sizes() = std::vector<int64_t>{static_cast<int64_t>(scales.size())};
auto& data = t.floats();
for (double scale : scales) {
data.emplace_back(static_cast<float>(scale));
}
Node* constant = graph->create(kConstant);
constant->insertBefore(node);
constant->t_(kvalue, t);
node->addInput(constant->output());
node->removeAttribute(kscales);
}
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_upsample_8_9(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,42 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Upsample in default domain from version 9 to 10
#pragma once
#include <memory>
#include <string>
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class Upsample_9_10 final : public Adapter {
public:
explicit Upsample_9_10() : Adapter("Upsample", OpSetID(9), OpSetID(10)) {}
Node* adapt_upsample_9_10(std::shared_ptr<Graph> graph, Node* node) const {
std::string mode = node->hasAttribute(kmode) ? node->s(kmode) : "nearest";
// Replace the node with an equivalent Resize node
Node* resize = graph->create(kResize);
resize->s_(kmode, mode);
resize->addInput(node->inputs()[0]);
resize->addInput(node->inputs()[1]);
node->replaceAllUsesWith(resize);
resize->insertBefore(node);
node->destroy();
return resize;
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
return adapt_upsample_9_10(graph, node);
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,79 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Adapter for Upsample in default domain from version 9 to 8
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "onnx/defs/tensor_proto_util.h"
#include "onnx/defs/tensor_util.h"
#include "onnx/version_converter/adapters/adapter.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
struct Upsample_9_8 final : public Adapter {
explicit Upsample_9_8() : Adapter("Upsample", OpSetID(9), OpSetID(8)) {}
void adapt_upsample_9_8(std::shared_ptr<Graph> graph, Node* node) const {
const ArrayRef<Value*>& inputs = node->inputs();
const std::vector<Tensor>& initializers = graph->initializers();
ONNX_ASSERTM(inputs.size() == 2, "Upsample in opset 9 needs to have 2 inputs.");
std::string scale_input_name = node->inputs()[1]->uniqueName();
for (size_t i = 0; i < initializers.size(); i++) {
if (initializers[i].name() == inputs[1]->uniqueName()) {
std::vector<float> value = ParseData<float>(&initializers[i]);
std::vector<double> d_values;
d_values.reserve(value.size());
for (size_t j = 0; j < value.size(); j++) {
d_values.push_back(static_cast<double>(value[j]));
}
node->fs_(kscales, std::move(d_values));
node->removeInput(1);
graph->eraseInitializer(initializers[i].name());
for (size_t j = 0; j < graph->inputs().size(); j++) {
if (graph->inputs()[j]->uniqueName() == scale_input_name) {
graph->eraseInput(j);
break;
}
}
return;
}
}
for (Node* op : graph->nodes()) {
if (op->kind() == kConstant && op->outputs()[0]->uniqueName() == scale_input_name) {
std::vector<float> value = ParseData<float>(&op->t(kvalue));
std::vector<double> d_values;
d_values.reserve(value.size());
for (size_t j = 0; j < value.size(); j++) {
d_values.push_back(static_cast<double>(value[j]));
}
node->fs_(kscales, std::move(d_values));
node->removeInput(1);
op->destroy();
return;
}
}
ONNX_ASSERTM(false, "Unsuppported conversion due to unavailable input: scale");
}
Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
adapt_upsample_9_8(graph, node);
return node;
}
};
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,148 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/version_converter/convert.h"
#include <memory>
#include <string>
namespace ONNX_NAMESPACE {
namespace version_conversion {
ModelProto ConvertVersion(const ModelProto& mp_in, int target_version) {
// Get initial_opsetid from mp_in
OpSetID initial_struct(0);
for (auto it = mp_in.opset_import().begin(); it != mp_in.opset_import().end(); ++it) {
if (it->domain() == "" || it->domain() == "ai.onnx") {
initial_struct.setVersion(it->version());
break;
}
}
OpSetID target_struct = OpSetID(target_version);
DefaultVersionConverter v;
return v.convert_version(mp_in, initial_struct, target_struct);
}
void DefaultVersionConverter::convert_graph(
std::shared_ptr<Graph> g,
const OpSetID& initial_version,
const OpSetID& target_version) const {
assertNonNull(g);
// TODO: Move to Inter-Domain Converter
// Get initial model versions
// std::vector<OpSetID> initial_versions = g->opset_versions_mutable();
// No conversion necessary if Model has single, equivalent opset version
// if (initial_versions.size() == 1 && initial_versions[0].version ==
// target_version.version && initial_versions[0].domain ==
// target_version.domain) {
// return mp_in;
// }
// Check if versions are valid
assertInVersionRange(initial_version.version());
assertInVersionRange(target_version.version());
// Iterate one opset version at a time from initial_version to target_version
int64_t curr_version = initial_version.version();
int64_t step;
if (target_version.version() > initial_version.version()) {
step = 1;
} else {
step = -1;
}
// Identify index of this domain in g.opset_versions
unsigned int domain_index = 0;
for (unsigned int i = 0; i < g->opset_versions_mutable().size(); i++) {
if (g->opset_versions_mutable()[i].domain() == "") {
domain_index = i;
}
}
while (curr_version != target_version.version()) {
debug(
"curr_version: " + ONNX_NAMESPACE::to_string(curr_version) +
", next_version: " + ONNX_NAMESPACE::to_string(curr_version + step));
Node* cur_op;
graph_node_list_iterator it = g->begin();
// Iterate through and call adapter returned by adapter_lookup for ops from
// current_version opset. We have to manipulate the iterator explicitly because cur_op
// might change when applying the adapter (e.g. for deprecated ops)
while (it != g->end()) {
cur_op = *it;
debug(std::string("Finding schema for ") + std::string(cur_op->kind().toString()));
const std::string op_name = cur_op->kind().toString();
if (op_name == "ConstantFill") {
if (DEBUG) {
std::cerr
<< "Warning: skipping schema search for experimental op 'ConstantFill' and keeping the op as is. "
"Please be advised that the converted model may not work properly if the target runtime does not "
"support this experimental op."
<< std::endl;
}
} else if (cur_op->domain() != "" && cur_op->domain() != "ai.onnx") {
if (DEBUG) {
std::cerr << "Warning: opset domain '" << cur_op->domain() << "' is not supported." << std::endl;
}
} else if (op_name != "Undefined" && op_name != "Captured") {
auto& op_domain_map = all_schemas.at(op_name);
OpSetID curr_id(curr_version);
OpSetID next_id(curr_version + step);
if (searchOpDomainMap(op_domain_map, curr_version, step)) {
// Op is specifically defined for this domain and version
auto& op_adapter = adapter_lookup(cur_op, curr_id, next_id);
// adapter_lookup errors out internally if no adapter is
// present, so the returned reference is always valid
if (DEBUG) {
std::cerr << "Applying adapter" << std::endl;
}
// adapt should handle replacing node in graph
cur_op = op_adapter.adapt(g, cur_op);
it = graph_node_list_iterator(cur_op, kNextDirection);
}
// Recursively convert any subgraph attributes
for (const auto& attr : cur_op->attributeNames()) {
if (cur_op->kindOf(attr) == AttributeKind::g) {
convert_graph(cur_op->g(attr), curr_id, next_id);
}
}
}
it++;
}
// Update model version
curr_version += step;
g->opset_versions_mutable()[domain_index].incrementVersion(step);
}
}
ModelProto DefaultVersionConverter::convert_version(
const ModelProto& mp_in,
const OpSetID& initial_version,
const OpSetID& target_version) const {
const std::string& initial_domain = initial_version.domain();
const std::string& target_domain = target_version.domain();
assertDefaultDomain(initial_domain, target_domain);
for (auto it = mp_in.opset_import().begin(); it != mp_in.opset_import().end(); ++it) {
if (it->domain() == initial_version.domain()) {
ONNX_ASSERTM(
initial_version.version() == it->version(), "initial_version does not reflect current state of model");
}
}
std::shared_ptr<Graph> g(ImportModelProto(mp_in));
convert_graph(g, initial_version, target_version);
// Export g as ModelProto
debug("Finished conversion; returning model");
ModelProto mp_out = PrepareOutput(mp_in);
ExportModelProto(&mp_out, g);
return mp_out;
}
} // namespace version_conversion
} // namespace ONNX_NAMESPACE
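// Usage sketch (illustrative, not part of this translation unit): converting a
// serialized model with the public ConvertVersion entry point. The loader
// shown here is an assumption; any means of obtaining a ModelProto works.
//
//   #include <fstream>
//   #include "onnx/version_converter/convert.h"
//
//   ONNX_NAMESPACE::ModelProto model;
//   std::ifstream in("model.onnx", std::ios::binary);
//   model.ParseFromIstream(&in);  // protobuf-generated parser
//   ONNX_NAMESPACE::ModelProto converted =
//       ONNX_NAMESPACE::version_conversion::ConvertVersion(model, /*target_version=*/13);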

View File

@ -0,0 +1,787 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Default converter for ONNX models between different opset versions
// in the default domain ("" or "ai.onnx").
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "onnx/version_converter/BaseConverter.h"
#include "onnx/version_converter/adapters/axes_attribute_to_input.h"
#include "onnx/version_converter/adapters/axes_input_to_attribute.h"
#include "onnx/version_converter/adapters/axis_attribute_to_input.h"
#include "onnx/version_converter/adapters/axis_input_to_attribute.h"
#include "onnx/version_converter/adapters/batch_normalization_13_14.h"
#include "onnx/version_converter/adapters/broadcast_backward_compatibility.h"
#include "onnx/version_converter/adapters/broadcast_forward_compatibility.h"
#include "onnx/version_converter/adapters/cast_9_8.h"
#include "onnx/version_converter/adapters/clip_10_11.h"
#include "onnx/version_converter/adapters/compatible.h"
#include "onnx/version_converter/adapters/dropout_11_12.h"
#include "onnx/version_converter/adapters/extend_supported_types.h"
#include "onnx/version_converter/adapters/gemm_6_7.h"
#include "onnx/version_converter/adapters/gemm_7_6.h"
#include "onnx/version_converter/adapters/gridsample_19_20.h"
#include "onnx/version_converter/adapters/group_normalization_20_21.h"
#include "onnx/version_converter/adapters/maxpool_8_7.h"
#include "onnx/version_converter/adapters/no_previous_version.h"
#include "onnx/version_converter/adapters/pad_10_11.h"
#include "onnx/version_converter/adapters/q_dq_21_20.h"
#include "onnx/version_converter/adapters/reshape_4_5.h"
#include "onnx/version_converter/adapters/reshape_5_4.h"
#include "onnx/version_converter/adapters/resize_10_11.h"
#include "onnx/version_converter/adapters/scan_8_9.h"
#include "onnx/version_converter/adapters/scan_9_8.h"
#include "onnx/version_converter/adapters/scatter_10_11.h"
#include "onnx/version_converter/adapters/slice_9_10.h"
#include "onnx/version_converter/adapters/softmax_12_13.h"
#include "onnx/version_converter/adapters/split_12_13.h"
#include "onnx/version_converter/adapters/split_13_12.h"
#include "onnx/version_converter/adapters/split_17_18.h"
#include "onnx/version_converter/adapters/sum_8_7.h"
#include "onnx/version_converter/adapters/topk_9_10.h"
#include "onnx/version_converter/adapters/transformers.h"
#include "onnx/version_converter/adapters/type_restriction.h"
#include "onnx/version_converter/adapters/upsample_6_7.h"
#include "onnx/version_converter/adapters/upsample_8_9.h"
#include "onnx/version_converter/adapters/upsample_9_10.h"
#include "onnx/version_converter/adapters/upsample_9_8.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
class DefaultVersionConverter : public BaseVersionConverter {
private:
bool DEBUG = false;
std::pair<int, int> version_range;
bool searchOpDomainMap(
const std::unordered_map<std::string, std::map<int64_t, const OpSchema*>>& op_domain_map,
int64_t curr_version,
int64_t step) const {
bool up = step == 1;
const auto version_it = op_domain_map.find("");
return version_it != op_domain_map.end() &&
((version_it->second.find(curr_version) != version_it->second.end() && !up) ||
(version_it->second.find(curr_version + step) != version_it->second.end() && up));
}
void debug(const std::string& str) const {
if (DEBUG)
std::cerr << str << std::endl;
}
void assertInVersionRange(int64_t version) const {
ONNX_ASSERTM(
version >= version_range.first && version <= version_range.second,
"Warning: invalid version (must be between %d and %d)",
version_range.first,
version_range.second);
}
void assertDefaultDomain(const std::string& initial_domain, const std::string& target_domain) const {
ONNX_ASSERTM(
(initial_domain == "" || initial_domain == "ai.onnx") && (target_domain == "" || target_domain == "ai.onnx"),
"Warning: default onnx version converter can only convert "
" between default domain opset versions ('' or 'ai.onnx')\n");
ONNX_ASSERTM(initial_domain == target_domain, "initial_version and target_version must have the same domains");
}
void convert_graph(std::shared_ptr<Graph> g, const OpSetID& initial_version, const OpSetID& target_version) const;
public:
DefaultVersionConverter() {
const std::unordered_map<std::string, std::pair<int, int>>& versions_map =
OpSchemaRegistry::DomainToVersionRange::Instance().Map();
version_range = versions_map.at("");
// Register adapters to the version converter
const std::vector<OpSchema> all_opschemas = OpSchemaRegistry::get_all_schemas_with_history();
for (const OpSchema& schema : all_opschemas) {
all_schemas[schema.Name()][schema.domain()][(int64_t)schema.since_version()] = &schema;
}
// Iterate through all_schemas to determine NoPreviousVersionAdapters
for (auto& op_pair : all_schemas) {
const auto default_versions = op_pair.second.find("");
if (default_versions != op_pair.second.end()) {
int64_t min_version = version_range.second;
for (auto& version_pair : default_versions->second) {
if (version_pair.first < min_version) {
min_version = version_pair.first;
}
}
if (min_version > 1) {
registerAdapter(std::make_unique<NoPreviousVersionAdapter>(
op_pair.first, OpSetID(min_version), OpSetID(min_version - 1)));
}
}
}
/******** 1 -> 2 ********/
// Missing in this group: GlobalLpPool, LpPool, Pad, Split
/******** 2 -> 3 ********/
// Missing in this group: GRU
/******** 3 -> 4 ********/
registerAdapter("Concat", 3, 4, SetAttributeIfAbsent(kaxis, 1));
/******** 4 -> 3 ********/
std::vector<TensorProto_DataType> concat_unallowed_types = {
TensorProto_DataType_INT32,
TensorProto_DataType_INT64,
TensorProto_DataType_UINT32,
TensorProto_DataType_UINT64,
TensorProto_DataType_UINT8,
TensorProto_DataType_UINT16,
TensorProto_DataType_INT8,
TensorProto_DataType_INT16,
TensorProto_DataType_STRING,
TensorProto_DataType_BOOL};
registerAdapter(std::make_unique<TypeRestriction>("Concat", OpSetID(4), OpSetID(3), concat_unallowed_types));
/******** 4 -> 5 ********/
registerAdapter(std::make_unique<Reshape_4_5>());
/******** 5 -> 4 ********/
registerAdapter(std::make_unique<Reshape_5_4>());
/******** 5 -> 6 ********/
// Missing in this group: Cast, Tile
auto removeConsumedInputs = RemoveAttribute(kconsumed_inputs);
registerAdapter("Add", 5, 6, removeConsumedInputs);
registerAdapter("Mul", 5, 6, removeConsumedInputs);
registerAdapter(std::make_unique<CompatibleAdapter>("Gemm", OpSetID(5), OpSetID(6)));
registerAdapter("Relu", 5, 6, removeConsumedInputs);
registerAdapter("BatchNormalization", 5, 6, removeConsumedInputs);
registerAdapter("Sum", 5, 6, removeConsumedInputs);
registerAdapter("Dropout", 5, 6, removeConsumedInputs);
registerAdapter("Abs", 5, 6, removeConsumedInputs);
registerAdapter("Ceil", 5, 6, removeConsumedInputs);
registerAdapter("Clip", 5, 6, removeConsumedInputs);
registerAdapter("Div", 5, 6, removeConsumedInputs);
registerAdapter("Elu", 5, 6, removeConsumedInputs);
registerAdapter("Exp", 5, 6, removeConsumedInputs);
registerAdapter("Floor", 5, 6, removeConsumedInputs);
registerAdapter("HardSigmoid", 5, 6, removeConsumedInputs);
registerAdapter("InstanceNormalization", 5, 6, removeConsumedInputs);
registerAdapter("LeakyRelu", 5, 6, removeConsumedInputs);
registerAdapter("Log", 5, 6, removeConsumedInputs);
registerAdapter("Max", 5, 6, removeConsumedInputs);
registerAdapter("Mean", 5, 6, removeConsumedInputs);
registerAdapter("Min", 5, 6, removeConsumedInputs);
registerAdapter("Neg", 5, 6, removeConsumedInputs);
registerAdapter("PRelu", 5, 6, removeConsumedInputs);
registerAdapter("Reciprocal", 5, 6, removeConsumedInputs);
registerAdapter("Selu", 5, 6, removeConsumedInputs);
registerAdapter("Sigmoid", 5, 6, removeConsumedInputs);
registerAdapter("Sqrt", 5, 6, removeConsumedInputs);
registerAdapter("Sub", 5, 6, removeConsumedInputs);
registerAdapter("Tanh", 5, 6, removeConsumedInputs);
/******** 6 -> 5 ********/
std::vector<TensorProto_DataType> broadcast_unallowed_types = {
TensorProto_DataType_INT32,
TensorProto_DataType_INT64,
TensorProto_DataType_UINT32,
TensorProto_DataType_UINT64};
std::vector<TensorProto_DataType> int_unallowed_types = {
TensorProto_DataType_UINT8,
TensorProto_DataType_UINT16,
TensorProto_DataType_UINT32,
TensorProto_DataType_UINT64,
TensorProto_DataType_INT8,
TensorProto_DataType_INT16,
TensorProto_DataType_INT32,
TensorProto_DataType_INT64};
std::vector<TensorProto_DataType> neg_unallowed_types = {
TensorProto_DataType_INT32, TensorProto_DataType_INT8, TensorProto_DataType_UINT16, TensorProto_DataType_INT64};
registerAdapter(std::make_unique<TypeRestriction>("Add", OpSetID(6), OpSetID(5), broadcast_unallowed_types));
registerAdapter(std::make_unique<TypeRestriction>("Mul", OpSetID(6), OpSetID(5), broadcast_unallowed_types));
registerAdapter(std::make_unique<TypeRestriction>("Sub", OpSetID(6), OpSetID(5), broadcast_unallowed_types));
registerAdapter(std::make_unique<TypeRestriction>("Div", OpSetID(6), OpSetID(5), broadcast_unallowed_types));
registerAdapter(std::make_unique<TypeRestriction>("Abs", OpSetID(6), OpSetID(5), int_unallowed_types));
registerAdapter(std::make_unique<TypeRestriction>("Neg", OpSetID(6), OpSetID(5), neg_unallowed_types));
registerAdapter("BatchNormalization", 6, 5, SetAttribute(kconsumed_inputs, std::vector<int64_t>({0, 0})));
registerAdapter(std::make_unique<CompatibleAdapter>("Gemm", OpSetID(6), OpSetID(5)));
registerAdapter(std::make_unique<CompatibleAdapter>("Relu", OpSetID(6), OpSetID(5)));
registerAdapter(std::make_unique<CompatibleAdapter>("Sum", OpSetID(6), OpSetID(5)));
registerAdapter(std::make_unique<CompatibleAdapter>("Dropout", OpSetID(6), OpSetID(5)));
/******** 6 -> 7 ********/
// Missing in this group: And, Equal, Greater, GRU, Less, LSTM, Or, RNN, Upsample, Xor
registerAdapter(std::make_unique<BroadcastForwardCompatibility>("Add", OpSetID(6), OpSetID(7)));
registerAdapter(std::make_unique<CompatibleAdapter>("AveragePool", OpSetID(6), OpSetID(7)));
registerAdapter(std::make_unique<BroadcastForwardCompatibility>("Div", OpSetID(6), OpSetID(7)));
registerAdapter(std::make_unique<BroadcastForwardCompatibility>("Mul", OpSetID(6), OpSetID(7)));
registerAdapter(std::make_unique<BroadcastForwardCompatibility>("Pow", OpSetID(6), OpSetID(7)));
registerAdapter(std::make_unique<CompatibleAdapter>("PRelu", OpSetID(6), OpSetID(7)));
registerAdapter(std::make_unique<BroadcastForwardCompatibility>("Sub", OpSetID(6), OpSetID(7)));
registerAdapter(std::make_unique<Gemm_6_7>());
registerAdapter("BatchNormalization", 6, 7, RemoveAttributeNotEq(kis_test, 0));
registerAdapter("Dropout", 6, 7, RemoveAttributeNotEq(kis_test, 0));
registerAdapter(std::make_unique<Upsample_6_7>());
/******** 7 -> 6 ********/
registerAdapter(std::make_unique<BroadcastBackwardCompatibility>("Add", OpSetID(7), OpSetID(6)));
registerAdapter(std::make_unique<BroadcastBackwardCompatibility>("Div", OpSetID(7), OpSetID(6)));
registerAdapter(std::make_unique<BroadcastBackwardCompatibility>("Mul", OpSetID(7), OpSetID(6)));
registerAdapter(std::make_unique<BroadcastBackwardCompatibility>("Pow", OpSetID(7), OpSetID(6)));
registerAdapter(std::make_unique<CompatibleAdapter>("PRelu", OpSetID(7), OpSetID(6)));
registerAdapter(std::make_unique<BroadcastBackwardCompatibility>("Sub", OpSetID(7), OpSetID(6)));
registerAdapter("BatchNormalization", 7, 6, SetAttribute(kis_test, 1));
registerAdapter("Dropout", 7, 6, SetAttribute(kis_test, 1));
registerAdapter(std::make_unique<Gemm_7_6>());
registerAdapter("AveragePool", 7, 6, RemoveAttribute(kcount_include_pad, 0));
/******** 7 -> 8 ********/
registerAdapter(std::make_unique<CompatibleAdapter>("Max", OpSetID(7), OpSetID(8)));
registerAdapter(std::make_unique<CompatibleAdapter>("Min", OpSetID(7), OpSetID(8)));
registerAdapter(std::make_unique<CompatibleAdapter>("Mean", OpSetID(7), OpSetID(8)));
registerAdapter(std::make_unique<CompatibleAdapter>("Sum", OpSetID(7), OpSetID(8)));
registerAdapter(std::make_unique<CompatibleAdapter>("MaxPool", OpSetID(7), OpSetID(8)));
/******** 8 -> 7 ********/
registerAdapter(std::make_unique<BroadcastBackwardCompatibility>("Max", OpSetID(8), OpSetID(7)));
registerAdapter(std::make_unique<BroadcastBackwardCompatibility>("Min", OpSetID(8), OpSetID(7)));
registerAdapter(std::make_unique<BroadcastBackwardCompatibility>("Mean", OpSetID(8), OpSetID(7)));
registerAdapter(std::make_unique<Sum_8_7>());
registerAdapter(std::make_unique<MaxPool_8_7>());
/******** 8 -> 9 ********/
registerAdapter(std::make_unique<CompatibleAdapter>("Flatten", OpSetID(8), OpSetID(9)));
registerAdapter(std::make_unique<CompatibleAdapter>("Constant", OpSetID(8), OpSetID(9)));
registerAdapter(std::make_unique<CompatibleAdapter>("MatMul", OpSetID(8), OpSetID(9)));
registerAdapter(std::make_unique<CompatibleAdapter>("Gemm", OpSetID(8), OpSetID(9)));
registerAdapter(std::make_unique<CompatibleAdapter>("PRelu", OpSetID(8), OpSetID(9)));
registerAdapter(std::make_unique<CompatibleAdapter>("Greater", OpSetID(8), OpSetID(9)));
registerAdapter(std::make_unique<CompatibleAdapter>("Less", OpSetID(8), OpSetID(9)));
registerAdapter(std::make_unique<CompatibleAdapter>("Cast", OpSetID(8), OpSetID(9)));
registerAdapter("BatchNormalization", 8, 9, RemoveAttribute(kspatial, 1));
registerAdapter(std::make_unique<Scan_8_9>());
registerAdapter(std::make_unique<Upsample_8_9>());
/******** 9 -> 8 ********/
registerAdapter(std::make_unique<CompatibleAdapter>("BatchNormalization", OpSetID(9), OpSetID(8)));
registerAdapter(std::make_unique<ExtendSupportedTypes>("Flatten", OpSetID(9), OpSetID(8)));
registerAdapter(std::make_unique<ExtendSupportedTypes>("Constant", OpSetID(9), OpSetID(8)));
registerAdapter(std::make_unique<ExtendSupportedTypes>("MatMul", OpSetID(9), OpSetID(8)));
registerAdapter(std::make_unique<ExtendSupportedTypes>("Gemm", OpSetID(9), OpSetID(8)));
registerAdapter(std::make_unique<ExtendSupportedTypes>("PRelu", OpSetID(9), OpSetID(8)));
registerAdapter(std::make_unique<ExtendSupportedTypes>("Greater", OpSetID(9), OpSetID(8)));
registerAdapter(std::make_unique<ExtendSupportedTypes>("Less", OpSetID(9), OpSetID(8)));
registerAdapter(std::make_unique<Cast_9_8>());
registerAdapter(std::make_unique<Scan_9_8>());
registerAdapter(std::make_unique<Upsample_9_8>());
/******** 9 -> 10 ********/
registerAdapter(std::make_unique<CompatibleAdapter>("AveragePool", OpSetID(9), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("MaxPool", OpSetID(9), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("Dropout", OpSetID(9), OpSetID(10)));
registerAdapter(std::make_unique<Slice_9_10>());
registerAdapter(std::make_unique<TopK_9_10>());
registerAdapter(std::make_unique<Upsample_9_10>());
/******** 10 -> 9 ********/
registerAdapter(std::make_unique<CompatibleAdapter>("Dropout", OpSetID(10), OpSetID(9)));
/******** 10 -> 11 ********/
registerAdapter(std::make_unique<CompatibleAdapter>("ArgMax", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("ArgMin", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("AveragePool", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Concat", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Constant", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Compress", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Conv", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("ConvTranspose", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("DepthToSpace", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Equal", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Flatten", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Gather", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Gemm", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Hardmax", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("If", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("LogSoftmax", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Loop", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("LpPool", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("MaxPool", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("MaxUnpool", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("NonMaxSuppression", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("OneHot", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceL1", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceL2", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceLogSum", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceLogSumExp", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceMax", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceMean", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceMin", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceProd", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceSum", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceSumSquare", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Scan", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Softmax", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Slice", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Split", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Squeeze", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("TopK", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<CompatibleAdapter>("Unsqueeze", OpSetID(10), OpSetID(11)));
registerAdapter(std::make_unique<Clip_10_11>());
registerAdapter(std::make_unique<Pad_10_11>());
registerAdapter(std::make_unique<Resize_10_11>());
registerAdapter(std::make_unique<Scatter_10_11>());
/******** 11 -> 10 ********/
std::vector<TensorProto_DataType> equal_unallowed_types = {
TensorProto_DataType_UINT8,
TensorProto_DataType_UINT16,
TensorProto_DataType_UINT32,
TensorProto_DataType_UINT64,
TensorProto_DataType_INT8,
TensorProto_DataType_INT16,
TensorProto_DataType_FLOAT16,
TensorProto_DataType_FLOAT,
TensorProto_DataType_DOUBLE};
registerAdapter(std::make_unique<CompatibleAdapter>("ArgMax", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("ArgMin", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("AveragePool", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("Concat", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("Constant", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("Conv", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("ConvTranspose", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<TypeRestriction>("Equal", OpSetID(11), OpSetID(10), equal_unallowed_types));
registerAdapter(std::make_unique<CompatibleAdapter>("Flatten", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("LogSoftmax", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("MaxPool", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceL1", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceL2", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceLogSum", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceLogSumExp", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceMax", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceMean", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceMin", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceProd", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceSum", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceSumSquare", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("Softmax", OpSetID(11), OpSetID(10)));
registerAdapter(std::make_unique<CompatibleAdapter>("Unsqueeze", OpSetID(11), OpSetID(10)));
/******** 11 -> 12 ********/
registerAdapter(std::make_unique<CompatibleAdapter>("ArgMax", OpSetID(11), OpSetID(12)));
registerAdapter(std::make_unique<CompatibleAdapter>("ArgMin", OpSetID(11), OpSetID(12)));
registerAdapter(std::make_unique<CompatibleAdapter>("BatchNormalization", OpSetID(11), OpSetID(12)));
registerAdapter(std::make_unique<CompatibleAdapter>("Constant", OpSetID(11), OpSetID(12)));
registerAdapter(std::make_unique<CompatibleAdapter>("Clip", OpSetID(11), OpSetID(12)));
registerAdapter(std::make_unique<CompatibleAdapter>("GatherND", OpSetID(11), OpSetID(12)));
registerAdapter(std::make_unique<CompatibleAdapter>("Min", OpSetID(11), OpSetID(12)));
registerAdapter(std::make_unique<CompatibleAdapter>("Max", OpSetID(11), OpSetID(12)));
registerAdapter(std::make_unique<CompatibleAdapter>("MaxPool", OpSetID(11), OpSetID(12)));
registerAdapter(std::make_unique<CompatibleAdapter>("Pow", OpSetID(11), OpSetID(12)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceMax", OpSetID(11), OpSetID(12)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceMin", OpSetID(11), OpSetID(12)));
registerAdapter(std::make_unique<Dropout_11_12>());
/******** 12 -> 11 ********/
std::vector<TensorProto_DataType> maxpool_unallowed_types = {TensorProto_DataType_UINT8, TensorProto_DataType_INT8};
registerAdapter("ArgMax", 12, 11, RemoveAttribute(kselect_last_index, 0));
registerAdapter("ArgMin", 12, 11, RemoveAttribute(kselect_last_index, 0));
registerAdapter(std::make_unique<CompatibleAdapter>("BatchNormalization", OpSetID(12), OpSetID(11)));
registerAdapter(std::make_unique<TypeRestriction>("Clip", OpSetID(12), OpSetID(11), int_unallowed_types));
registerAdapter(std::make_unique<TypeRestriction>("Min", OpSetID(12), OpSetID(11), int_unallowed_types));
registerAdapter(std::make_unique<TypeRestriction>("Max", OpSetID(12), OpSetID(11), int_unallowed_types));
registerAdapter(std::make_unique<TypeRestriction>("MaxPool", OpSetID(12), OpSetID(11), maxpool_unallowed_types));
registerAdapter(std::make_unique<TypeRestriction>("ReduceMax", OpSetID(12), OpSetID(11), maxpool_unallowed_types));
registerAdapter(std::make_unique<TypeRestriction>("ReduceMin", OpSetID(12), OpSetID(11), maxpool_unallowed_types));
/******** 12 -> 13 ********/
registerAdapter(std::make_unique<CompatibleAdapter>("Abs", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Add", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("ArgMin", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("ArgMax", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Cast", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Ceil", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Clip", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Concat", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Constant", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("DepthToSpace", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("DequantizeLinear", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Div", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Dropout", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Equal", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Erf", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Exp", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Expand", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Flatten", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Floor", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Gather", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("GatherElements", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("GatherND", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Gemm", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Greater", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Hardmax", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Identity", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("If", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("IsNaN", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Less", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Log", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Loop", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("LRN", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("NegativeLogLikelihoodLoss", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("MatMul", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Max", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Mean", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("MeanVarianceNormalization", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Min", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Mod", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Mul", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Neg", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("NonZero", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Pow", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Pad", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("QuantizeLinear", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Reciprocal", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceL1", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceL2", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceLogSum", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceLogSumExp", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceMean", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceMax", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceMin", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceProd", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceSumSquare", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Relu", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Reshape", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Resize", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("ScatterElements", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("ScatterND", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Shape", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Sigmoid", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Sign", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Size", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Slice", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("SoftmaxCrossEntropyLoss", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("SpaceToDepth", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Sqrt", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Sub", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Sum", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Tanh", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Tile", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<CompatibleAdapter>("Transpose", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<AxesAttributeToInput>("ReduceSum", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<AxesAttributeToInput>("Squeeze", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<AxesAttributeToInput>("Unsqueeze", OpSetID(12), OpSetID(13)));
registerAdapter(std::make_unique<Split_12_13>());
registerAdapter(std::make_unique<Softmax_12_13>("Softmax"));
registerAdapter(std::make_unique<Softmax_12_13>("LogSoftmax"));
/******** 13 -> 12 ********/
registerAdapter(std::make_unique<CompatibleAdapter>("Constant", OpSetID(13), OpSetID(12)));
registerAdapter(std::make_unique<AxesInputToAttribute>("ReduceSum", OpSetID(13), OpSetID(12)));
registerAdapter(std::make_unique<AxesInputToAttribute>("Squeeze", OpSetID(13), OpSetID(12)));
registerAdapter(std::make_unique<AxesInputToAttribute>("Unsqueeze", OpSetID(13), OpSetID(12)));
registerAdapter(std::make_unique<Split_13_12>());
/******** 13 -> 14 ********/
registerAdapter(std::make_unique<CompatibleAdapter>("Add", OpSetID(13), OpSetID(14)));
registerAdapter(std::make_unique<CompatibleAdapter>("CumSum", OpSetID(13), OpSetID(14)));
registerAdapter(std::make_unique<CompatibleAdapter>("Div", OpSetID(13), OpSetID(14)));
registerAdapter(std::make_unique<CompatibleAdapter>("Identity", OpSetID(13), OpSetID(14)));
registerAdapter(std::make_unique<CompatibleAdapter>("Mul", OpSetID(13), OpSetID(14)));
registerAdapter(std::make_unique<CompatibleAdapter>("Relu", OpSetID(13), OpSetID(14)));
registerAdapter(std::make_unique<CompatibleAdapter>("Reshape", OpSetID(13), OpSetID(14)));
registerAdapter(std::make_unique<CompatibleAdapter>("Sub", OpSetID(13), OpSetID(14)));
registerAdapter("GRU", 13, 14, SetAttribute(klayout, 0));
registerAdapter("LSTM", 13, 14, SetAttribute(klayout, 0));
registerAdapter("RNN", 13, 14, SetAttribute(klayout, 0));
registerAdapter(std::make_unique<BatchNormalization_13_14>());
/******** 14 -> 13 ********/
registerAdapter("GRU", 14, 13, RemoveAttribute(klayout, 0));
registerAdapter("LSTM", 14, 13, RemoveAttribute(klayout, 0));
registerAdapter("RNN", 14, 13, RemoveAttribute(klayout, 0));
/******** 14 -> 15 ********/
registerAdapter(std::make_unique<CompatibleAdapter>("BatchNormalization", OpSetID(14), OpSetID(15)));
registerAdapter(std::make_unique<CompatibleAdapter>("Pow", OpSetID(14), OpSetID(15)));
registerAdapter(std::make_unique<CompatibleAdapter>("Shape", OpSetID(14), OpSetID(15)));
/******** 15 -> 16 ********/
registerAdapter("RoiAlign", 15, 16, SetAttribute(kcoordinate_transformation_mode, "output_half_pixel"));
registerAdapter(std::make_unique<CompatibleAdapter>("ScatterElements", OpSetID(15), OpSetID(16)));
registerAdapter(std::make_unique<CompatibleAdapter>("ScatterND", OpSetID(15), OpSetID(16)));
registerAdapter(std::make_unique<CompatibleAdapter>("Identity", OpSetID(15), OpSetID(16)));
registerAdapter(std::make_unique<CompatibleAdapter>("Loop", OpSetID(15), OpSetID(16)));
registerAdapter(std::make_unique<CompatibleAdapter>("If", OpSetID(15), OpSetID(16)));
registerAdapter(std::make_unique<CompatibleAdapter>("Where", OpSetID(15), OpSetID(16)));
registerAdapter(std::make_unique<CompatibleAdapter>("Scan", OpSetID(15), OpSetID(16)));
registerAdapter(std::make_unique<CompatibleAdapter>("LessOrEqual", OpSetID(15), OpSetID(16)));
registerAdapter(std::make_unique<CompatibleAdapter>("GreaterOrEqual", OpSetID(15), OpSetID(16)));
registerAdapter(std::make_unique<CompatibleAdapter>("LeakyRelu", OpSetID(15), OpSetID(16)));
registerAdapter(std::make_unique<CompatibleAdapter>("PRelu", OpSetID(15), OpSetID(16)));
/******** 17 -> 18 ********/
registerAdapter(std::make_unique<CompatibleAdapter>("Pad", OpSetID(17), OpSetID(18)));
registerAdapter(std::make_unique<CompatibleAdapter>("Resize", OpSetID(17), OpSetID(18)));
registerAdapter(std::make_unique<CompatibleAdapter>("OptionalGetElement", OpSetID(17), OpSetID(18)));
registerAdapter(std::make_unique<CompatibleAdapter>("OptionalHasElement", OpSetID(17), OpSetID(18)));
registerAdapter(std::make_unique<Split_17_18>());
registerAdapter(std::make_unique<CompatibleAdapter>("ScatterND", OpSetID(17), OpSetID(18)));
registerAdapter(std::make_unique<CompatibleAdapter>("ScatterElements", OpSetID(17), OpSetID(18)));
registerAdapter("LpPool", 17, 18, SetAttribute(kceil_mode, 0));
registerAdapter(std::make_unique<AxesAttributeToInput>("ReduceL1", OpSetID(17), OpSetID(18)));
registerAdapter(std::make_unique<AxesAttributeToInput>("ReduceL2", OpSetID(17), OpSetID(18)));
registerAdapter(std::make_unique<AxesAttributeToInput>("ReduceLogSum", OpSetID(17), OpSetID(18)));
registerAdapter(std::make_unique<AxesAttributeToInput>("ReduceLogSumExp", OpSetID(17), OpSetID(18)));
registerAdapter(std::make_unique<AxesAttributeToInput>("ReduceMax", OpSetID(17), OpSetID(18)));
registerAdapter(std::make_unique<AxesAttributeToInput>("ReduceMean", OpSetID(17), OpSetID(18)));
registerAdapter(std::make_unique<AxesAttributeToInput>("ReduceMin", OpSetID(17), OpSetID(18)));
registerAdapter(std::make_unique<AxesAttributeToInput>("ReduceProd", OpSetID(17), OpSetID(18)));
registerAdapter(std::make_unique<AxesAttributeToInput>("ReduceSumSquare", OpSetID(17), OpSetID(18)));
/******** 18 -> 17 ********/
registerAdapter(std::make_unique<AxesInputToAttribute>("ReduceL1", OpSetID(18), OpSetID(17)));
registerAdapter(std::make_unique<AxesInputToAttribute>("ReduceL2", OpSetID(18), OpSetID(17)));
registerAdapter(std::make_unique<AxesInputToAttribute>("ReduceLogSum", OpSetID(18), OpSetID(17)));
registerAdapter(std::make_unique<AxesInputToAttribute>("ReduceLogSumExp", OpSetID(18), OpSetID(17)));
registerAdapter(std::make_unique<AxesInputToAttribute>("ReduceMax", OpSetID(18), OpSetID(17)));
registerAdapter(std::make_unique<AxesInputToAttribute>("ReduceMean", OpSetID(18), OpSetID(17)));
registerAdapter(std::make_unique<AxesInputToAttribute>("ReduceMin", OpSetID(18), OpSetID(17)));
registerAdapter(std::make_unique<AxesInputToAttribute>("ReduceProd", OpSetID(18), OpSetID(17)));
registerAdapter(std::make_unique<AxesInputToAttribute>("ReduceSumSquare", OpSetID(18), OpSetID(17)));
/******** 18 -> 19 ********/
registerAdapter(std::make_unique<CompatibleAdapter>("Equal", OpSetID(18), OpSetID(19)));
registerAdapter(std::make_unique<CompatibleAdapter>("AveragePool", OpSetID(18), OpSetID(19)));
registerAdapter(std::make_unique<CompatibleAdapter>("Cast", OpSetID(18), OpSetID(19)));
registerAdapter(std::make_unique<CompatibleAdapter>("CastLike", OpSetID(18), OpSetID(19)));
registerAdapter(std::make_unique<CompatibleAdapter>("Constant", OpSetID(18), OpSetID(19)));
registerAdapter(std::make_unique<CompatibleAdapter>("DequantizeLinear", OpSetID(18), OpSetID(19)));
registerAdapter(std::make_unique<CompatibleAdapter>("Identity", OpSetID(18), OpSetID(19)));
registerAdapter(std::make_unique<CompatibleAdapter>("If", OpSetID(18), OpSetID(19)));
registerAdapter(std::make_unique<CompatibleAdapter>("Loop", OpSetID(18), OpSetID(19)));
registerAdapter(std::make_unique<CompatibleAdapter>("Pad", OpSetID(18), OpSetID(19)));
registerAdapter(std::make_unique<CompatibleAdapter>("QuantizeLinear", OpSetID(18), OpSetID(19)));
registerAdapter(std::make_unique<CompatibleAdapter>("Reshape", OpSetID(18), OpSetID(19)));
registerAdapter(std::make_unique<CompatibleAdapter>("Resize", OpSetID(18), OpSetID(19)));
registerAdapter(std::make_unique<CompatibleAdapter>("Scan", OpSetID(18), OpSetID(19)));
registerAdapter(std::make_unique<CompatibleAdapter>("Shape", OpSetID(18), OpSetID(19)));
registerAdapter(std::make_unique<CompatibleAdapter>("Size", OpSetID(18), OpSetID(19)));
/******** 19 -> 20 ********/
registerAdapter(std::make_unique<AxisAttributeToInput>("DFT", OpSetID(19), OpSetID(20), 2, 1));
registerAdapter(std::make_unique<CompatibleAdapter>("ConstantOfShape", OpSetID(19), OpSetID(20)));
registerAdapter(std::make_unique<CompatibleAdapter>("IsInf", OpSetID(19), OpSetID(20)));
registerAdapter(std::make_unique<CompatibleAdapter>("IsNaN", OpSetID(19), OpSetID(20)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceMax", OpSetID(19), OpSetID(20)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceMin", OpSetID(19), OpSetID(20)));
registerAdapter(std::make_unique<GridSample_19_20>());
/******** 20 -> 19 ********/
const std::vector<TensorProto_DataType> is_nan_13_unallowed_types = {
TensorProto_DataType_FLOAT8E4M3FN,
TensorProto_DataType_FLOAT8E4M3FNUZ,
TensorProto_DataType_FLOAT8E5M2,
TensorProto_DataType_FLOAT8E5M2FNUZ};
registerAdapter(std::make_unique<TypeRestriction>("IsNaN", OpSetID(20), OpSetID(19), is_nan_13_unallowed_types));
const std::vector<TensorProto_DataType> is_inf_10_unallowed_types = {
TensorProto_DataType_FLOAT16,
TensorProto_DataType_BFLOAT16,
TensorProto_DataType_FLOAT8E4M3FN,
TensorProto_DataType_FLOAT8E4M3FNUZ,
TensorProto_DataType_FLOAT8E5M2,
TensorProto_DataType_FLOAT8E5M2FNUZ};
registerAdapter(std::make_unique<TypeRestriction>("IsInf", OpSetID(20), OpSetID(19), is_inf_10_unallowed_types));
registerAdapter(std::make_unique<AxisInputToAttribute>("DFT", OpSetID(20), OpSetID(19), 2, -2));
const std::vector<TensorProto_DataType> reduce_min_max_18_unallowed_types = {TensorProto_DataType_BOOL};
registerAdapter(
std::make_unique<TypeRestriction>("ReduceMax", OpSetID(20), OpSetID(19), reduce_min_max_18_unallowed_types));
registerAdapter(
std::make_unique<TypeRestriction>("ReduceMin", OpSetID(20), OpSetID(19), reduce_min_max_18_unallowed_types));
/******** 20 -> 21 ********/
registerAdapter(std::make_unique<CompatibleAdapter>("Cast", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("CastLike", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("Constant", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("ConstantOfShape", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("DequantizeLinear", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("Flatten", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<GroupNormalization_20_21>());
registerAdapter(std::make_unique<CompatibleAdapter>("Identity", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("If", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("Loop", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("Pad", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("QLinearMatMul", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("QuantizeLinear", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("Reshape", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("Scan", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("Shape", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("Size", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("Squeeze", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("Transpose", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("Unsqueeze", OpSetID(20), OpSetID(21)));
registerAdapter(std::make_unique<CompatibleAdapter>("GroupNormalization", OpSetID(20), OpSetID(21)));
/******** 21 -> 20 ********/
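// The type vectors below list tensor types that opset 21 (IR version 10)
// supports but the previous definitions of these ops do not: UINT4/INT4 were
// added in IR version 10, and the FLOAT8 family in IR version 9.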
const std::vector<TensorProto_DataType> q_dqmm_20_unallowed_types = {
TensorProto_DataType_BFLOAT16,
TensorProto_DataType_FLOAT16,
TensorProto_DataType_UINT4,
TensorProto_DataType_INT4};
const std::vector<TensorProto_DataType> ir10_types_not_in_ir9 = {
TensorProto_DataType_UINT4, TensorProto_DataType_INT4};
const std::vector<TensorProto_DataType> ir10_types_not_in_ir4 = {
TensorProto_DataType_FLOAT8E4M3FN,
TensorProto_DataType_FLOAT8E4M3FNUZ,
TensorProto_DataType_FLOAT8E5M2,
TensorProto_DataType_FLOAT8E5M2FNUZ,
TensorProto_DataType_UINT4,
TensorProto_DataType_INT4};
registerAdapter(std::make_unique<TypeRestriction>("Cast", OpSetID(21), OpSetID(20), ir10_types_not_in_ir9));
registerAdapter(std::make_unique<TypeRestriction>("CastLike", OpSetID(21), OpSetID(20), ir10_types_not_in_ir9));
registerAdapter(std::make_unique<TypeRestriction>("Constant", OpSetID(21), OpSetID(20), ir10_types_not_in_ir9));
registerAdapter(
std::make_unique<TypeRestriction>("ConstantOfShape", OpSetID(21), OpSetID(20), ir10_types_not_in_ir9));
registerAdapter(std::make_unique<DequantizeLinear_21_20>());
registerAdapter(std::make_unique<TypeRestriction>("Flatten", OpSetID(21), OpSetID(20), ir10_types_not_in_ir4));
registerAdapter(std::make_unique<TypeRestriction>("Identity", OpSetID(21), OpSetID(20), ir10_types_not_in_ir9));
registerAdapter(std::make_unique<TypeRestriction>("If", OpSetID(21), OpSetID(20), ir10_types_not_in_ir9));
registerAdapter(std::make_unique<TypeRestriction>("Loop", OpSetID(21), OpSetID(20), ir10_types_not_in_ir9));
registerAdapter(std::make_unique<TypeRestriction>("Pad", OpSetID(21), OpSetID(20), ir10_types_not_in_ir4));
registerAdapter(
std::make_unique<TypeRestriction>("QLinearMatMul", OpSetID(21), OpSetID(20), q_dqmm_20_unallowed_types));
registerAdapter(std::make_unique<QuantizeLinear_21_20>());
registerAdapter(std::make_unique<TypeRestriction>("Reshape", OpSetID(21), OpSetID(20), ir10_types_not_in_ir9));
registerAdapter(std::make_unique<TypeRestriction>("Scan", OpSetID(21), OpSetID(20), ir10_types_not_in_ir9));
registerAdapter(std::make_unique<TypeRestriction>("Shape", OpSetID(21), OpSetID(20), ir10_types_not_in_ir9));
registerAdapter(std::make_unique<TypeRestriction>("Size", OpSetID(21), OpSetID(20), ir10_types_not_in_ir9));
registerAdapter(std::make_unique<TypeRestriction>("Squeeze", OpSetID(21), OpSetID(20), ir10_types_not_in_ir4));
registerAdapter(std::make_unique<TypeRestriction>("Transpose", OpSetID(21), OpSetID(20), ir10_types_not_in_ir9));
registerAdapter(std::make_unique<TypeRestriction>("Unsqueeze", OpSetID(21), OpSetID(20), ir10_types_not_in_ir4));
/******** 21 -> 22 ********/
registerAdapter(std::make_unique<CompatibleAdapter>("EyeLike", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("RandomUniform", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("RandomNormal", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("RandomUniformLike", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("RandomNormalLike", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Multinomial", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Bernoulli", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("ThresholdedRelu", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Selu", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Elu", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Mish", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("HardSigmoid", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("HardSwish", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Softsign", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Softplus", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Sin", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Cos", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Tan", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Asin", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Acos", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Atan", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Sinh", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Cosh", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Asinh", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Acosh", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Atanh", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Round", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Det", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("NegativeLogLikelihoodLoss", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("AveragePool", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("MaxPool", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("MaxUnpool", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("LpPool", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("MaxRoiPool", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Conv", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("ConvTranspose", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("DeformConv", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("GlobalAveragePool", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("GlobalMaxPool", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("GlobalLpPool", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("InstanceNormalization", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("LpNormalization", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("Dropout", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("RoiAlign", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("RNN", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("GRU", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("LSTM", OpSetID(21), OpSetID(22)));
registerAdapter(std::make_unique<CompatibleAdapter>("GridSample", OpSetID(21), OpSetID(22)));
/******** 22 -> 21 ********/
const std::vector<TensorProto_DataType> bfloat16_not_allowed = {TensorProto_DataType_BFLOAT16};
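// Opset 22 extended these ops to bfloat16, so a 22 -> 21 downgrade is only
// valid when no bfloat16 tensors are involved.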
registerAdapter(std::make_unique<TypeRestriction>("EyeLike", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("AveragePool", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("MaxPool", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("RandomUniform", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("RandomNormal", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(
std::make_unique<TypeRestriction>("RandomNormalLike", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(
std::make_unique<TypeRestriction>("RandomUniformLike", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Multinomial", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Bernoulli", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(
std::make_unique<TypeRestriction>("ThresholdedRelu", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Selu", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Elu", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Mish", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("HardSigmoid", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("HardSwish", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Softsign", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Softplus", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Sin", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Cos", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Tan", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Asin", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Acos", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Atan", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Sinh", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Cosh", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Asinh", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Acosh", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Atanh", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Round", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Det", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(
std::make_unique<TypeRestriction>("NegativeLogLikelihoodLoss", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("MaxUnpool", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("LpPool", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("MaxRoiPool", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Conv", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("ConvTranspose", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("DeformConv", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(
std::make_unique<TypeRestriction>("GlobalAveragePool", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("GlobalLpPool", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(
std::make_unique<TypeRestriction>("InstanceNormalization", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(
std::make_unique<TypeRestriction>("LpNormalization", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("Dropout", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("RoiAlign", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("RNN", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("GRU", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("LSTM", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
registerAdapter(std::make_unique<TypeRestriction>("GridSample", OpSetID(22), OpSetID(21), bfloat16_not_allowed));
}
ModelProto convert_version(const ModelProto& mp_in, const OpSetID& initial_version, const OpSetID& target_version)
const override;
};
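// Convenience entry point, presumably targeting the default ONNX domain since
// only an integer opset version is taken. A minimal usage sketch
// (illustrative; model loading and error handling omitted):
//
//   ModelProto model = ...;  // parsed from an .onnx file elsewhere
//   ModelProto converted = ConvertVersion(model, /*target_version=*/21);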
ModelProto ConvertVersion(const ModelProto& mp_in, int target_version);
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

@ -0,0 +1,88 @@
// Copyright (c) ONNX Project Contributors
//
// SPDX-License-Identifier: Apache-2.0
// Helper Methods for Adapters
#include "onnx/version_converter/helper.h"
#include <vector>
namespace ONNX_NAMESPACE {
namespace version_conversion {
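// Returns 1 if input2 is unibroadcastable onto input1 and broadcasting is
// actually needed, 0 if the shapes already match, and -1 if the shapes are
// not unibroadcastable. Shapes are aligned on trailing dimensions, e.g.:
//   {2,3,4} vs {3,4} -> 1    {3,4} vs {3,1} -> 1
//   {3,4}   vs {3,4} -> 0    {3,4} vs {2,4} -> -1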
int check_numpy_unibroadcastable_and_require_broadcast(
const std::vector<Dimension>& input1_sizes,
const std::vector<Dimension>& input2_sizes) {
// input1 must have at least as many dimensions as input2
if (input1_sizes.size() < input2_sizes.size())
return -1;
// Align input2 with the trailing dimensions of input1 (axis = rank difference)
bool broadcast = false;
int axis = (int)(input1_sizes.size() - input2_sizes.size());
for (int i = 0; i < (int)input2_sizes.size(); i++) {
if (input2_sizes[i].dim != input1_sizes[axis + i].dim && input2_sizes[i].dim != 1)
return -1;
if (input2_sizes[i].dim != input1_sizes[axis + i].dim)
broadcast = true;
}
// Return 1 if broadcasting is required, 0 if the shapes already match
if (input1_sizes.size() > input2_sizes.size() || broadcast)
return 1;
else
return 0;
}
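// Checks that the two shapes are (numpy) multidirectionally broadcastable.
// For example, {2, 1, 4} and {3, 1} are multibroadcastable (every aligned
// pair matches or contains a 1), whereas {2, 3} and {4} are not (3 vs 4).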
void assert_numpy_multibroadcastable(
const std::vector<Dimension>& input1_sizes,
const std::vector<Dimension>& input2_sizes) {
// Generalizes the unibroadcastable check above to the multidirectional case
const std::vector<Dimension>* A_ptr;
const std::vector<Dimension>* B_ptr;
int A;
int B;
if (input1_sizes.size() < input2_sizes.size()) {
A_ptr = &input2_sizes;
B_ptr = &input1_sizes;
A = 2;
B = 1;
} else {
A_ptr = &input1_sizes;
B_ptr = &input2_sizes;
A = 1;
B = 2;
}
const std::vector<Dimension>& A_sizes = *A_ptr;
const std::vector<Dimension>& B_sizes = *B_ptr;
int axis = (int)(A_sizes.size() - B_sizes.size());
for (int i = 0; i < (int)B_sizes.size(); i++) {
ONNX_ASSERTM(
B_sizes[i].dim == A_sizes[axis + i].dim || B_sizes[i].dim == 1 || A_sizes[axis + i].dim == 1,
"Dimension %d of input %d does not match "
"dimension %d of input %d, and neither's value is 1",
i,
B,
axis + i,
A);
}
}
void assertNotParams(const std::vector<Dimension>& sizes) {
for (const Dimension& dim : sizes) {
ONNX_ASSERTM(dim.is_int, "Dimension %s is a param instead of an int.", dim.param.c_str());
}
}
void assertInputsAvailable(const ArrayRef<Value*>& inputs, const char* name, uint64_t num_inputs) {
ONNX_ASSERTM(
inputs.size() == num_inputs,
"%s in opset version 6 can only broadcast"
" between %d inputs",
name,
(int)num_inputs);
for (int i = 0; i < (int)num_inputs; i++) {
ONNX_ASSERTM(inputs[i]->has_sizes(), "Shape of input %d is not available.", i);
assertNotParams(inputs[i]->sizes());
}
}
} // namespace version_conversion
} // namespace ONNX_NAMESPACE

@ -0,0 +1,27 @@
// Copyright (c) ONNX Project Contributors
//
// SPDX-License-Identifier: Apache-2.0
// Helper Methods for Adapters
#pragma once
#include <vector>
#include "onnx/common/ir.h"
namespace ONNX_NAMESPACE {
namespace version_conversion {
int check_numpy_unibroadcastable_and_require_broadcast(
const std::vector<Dimension>& input1_sizes,
const std::vector<Dimension>& input2_sizes);
void assert_numpy_multibroadcastable(
const std::vector<Dimension>& input1_sizes,
const std::vector<Dimension>& input2_sizes);
void assertNotParams(const std::vector<Dimension>& sizes);
void assertInputsAvailable(const ArrayRef<Value*>& inputs, const char* name, uint64_t num_inputs);
} // namespace version_conversion
} // namespace ONNX_NAMESPACE