I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

@@ -0,0 +1,199 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
//===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
// ONNX: modified from llvm::ArrayRef.
// removed llvm-specific functionality
// removed some implicit const -> non-const conversions that rely on
// complicated std::enable_if meta-programming
// removed a bunch of slice variants for simplicity...
#pragma once
#include <assert.h>
#include <algorithm>
#include <array>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <vector>
namespace ONNX_NAMESPACE {
/// ArrayRef - Represent a constant reference to an array (0 or more elements
/// consecutively in memory), i.e. a start pointer and a length. It allows
/// various APIs to take consecutive elements easily and conveniently.
///
/// This class does not own the underlying data, it is expected to be used in
/// situations where the data resides in some other buffer, whose lifetime
/// extends past that of the ArrayRef. For this reason, it is not in general
/// safe to store an ArrayRef.
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
template <typename T>
class ArrayRef {
public:
typedef const T* iterator;
typedef const T* const_iterator;
typedef size_t size_type;
typedef std::reverse_iterator<iterator> reverse_iterator;
private:
/// The start of the array, in an external buffer.
const T* Data;
/// The number of elements.
size_type Length;
public:
/// @name Constructors
/// @{
/// Construct an empty ArrayRef.
/*implicit*/ ArrayRef() : Data(nullptr), Length(0) {}
/// Construct an ArrayRef from a single element.
/*implicit*/ ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}
/// Construct an ArrayRef from a pointer and length.
/*implicit*/ ArrayRef(const T* data, size_t length) : Data(data), Length(length) {}
/// Construct an ArrayRef from a range.
ArrayRef(const T* begin, const T* end) : Data(begin), Length(end - begin) {}
/// Construct an ArrayRef from a std::vector.
template <typename A>
/*implicit*/ ArrayRef(const std::vector<T, A>& Vec) : Data(Vec.data()), Length(Vec.size()) {}
/// Construct an ArrayRef from a std::array
template <size_t N>
/*implicit*/ constexpr ArrayRef(const std::array<T, N>& Arr) : Data(Arr.data()), Length(N) {}
/// Construct an ArrayRef from a C array.
template <size_t N>
/*implicit*/ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}
/// Construct an ArrayRef from a std::initializer_list.
/*implicit*/ ArrayRef(const std::initializer_list<T>& Vec)
: Data(Vec.begin() == Vec.end() ? (T*)nullptr : Vec.begin()), Length(Vec.size()) {}
/// @}
/// @name Simple Operations
/// @{
iterator begin() const {
return Data;
}
iterator end() const {
return Data + Length;
}
reverse_iterator rbegin() const {
return reverse_iterator(end());
}
reverse_iterator rend() const {
return reverse_iterator(begin());
}
/// empty - Check if the array is empty.
bool empty() const {
return Length == 0;
}
const T* data() const {
return Data;
}
/// size - Get the array size.
size_t size() const {
return Length;
}
/// front - Get the first element.
const T& front() const {
assert(!empty());
return Data[0];
}
/// back - Get the last element.
const T& back() const {
assert(!empty());
return Data[Length - 1];
}
/// equals - Check for element-wise equality.
bool equals(ArrayRef RHS) const {
if (Length != RHS.Length)
return false;
return std::equal(begin(), end(), RHS.begin());
}
/// slice(n, m) - Chop off the first N elements of the array, and keep M
/// elements in the array.
ArrayRef<T> slice(size_t N, size_t M) const {
assert(N + M <= size() && "Invalid specifier");
return ArrayRef<T>(data() + N, M);
}
/// slice(n) - Chop off the first N elements of the array.
ArrayRef<T> slice(size_t N) const {
return slice(N, size() - N);
}
/// @}
/// @name Operator Overloads
/// @{
const T& operator[](size_t Index) const {
assert(Index < Length && "Invalid index!");
return Data[Index];
}
/// Vector compatibility
const T& at(size_t Index) const {
assert(Index < Length && "Invalid index!");
return Data[Index];
}
/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type& operator=(U&& Temporary) = delete;
/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
typename std::enable_if<std::is_same<U, T>::value, ArrayRef<T>>::type& operator=(std::initializer_list<U>) = delete;
/// @}
/// @name Expensive Operations
/// @{
std::vector<T> vec() const {
return std::vector<T>(Data, Data + Length);
}
/// @}
/// @name Conversion operators
/// @{
operator std::vector<T>() const {
return std::vector<T>(Data, Data + Length);
}
/// @}
};
} // namespace ONNX_NAMESPACE
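
A minimal usage sketch of ArrayRef (the "onnx/common/array_ref.h" include path and the helper name are assumptions, not part of this commit). Because the constructors are implicit, callers pass containers as lightweight views without copying the underlying buffer.

#include <cstdint>
#include <vector>
#include "onnx/common/array_ref.h" // assumed location of the header above

// Sums a view over int64 values; any contiguous container can be passed.
static int64_t SumDims(ONNX_NAMESPACE::ArrayRef<int64_t> dims) {
  int64_t total = 0;
  for (int64_t d : dims) { // begin()/end() walk the borrowed buffer
    total += d;
  }
  return total;
}

// Example call sites (the vector outlives the call, so the view is safe):
//   std::vector<int64_t> shape{1, 3, 224, 224};
//   int64_t n = SumDims(shape);     // views the vector's storage
//   int64_t m = SumDims({2, 4, 8}); // views the temporary initializer list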

@@ -0,0 +1,45 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
#include "onnx/common/assertions.h"
#include <array>
#include <cstdarg>
#include <cstdio>
#include "onnx/common/common.h"
namespace ONNX_NAMESPACE {
std::string barf(const char* fmt, ...) {
constexpr size_t buffer_size = 2048;
std::array<char, buffer_size> msg{};
va_list args;
va_start(args, fmt);
// use fixed length for buffer "msg" to avoid buffer overflow
vsnprintf(static_cast<char*>(msg.data()), msg.size() - 1, fmt, args);
// ensure null-terminated string to avoid out of bounds read
msg.back() = '\0';
va_end(args);
return std::string(msg.data());
}
void throw_assert_error(std::string& msg) {
ONNX_THROW_EX(assert_error(msg));
}
void throw_tensor_error(std::string& msg) {
ONNX_THROW_EX(tensor_error(msg));
}
} // namespace ONNX_NAMESPACE

@@ -0,0 +1,72 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
#pragma once
#include <stdexcept>
#include <string>
namespace ONNX_NAMESPACE {
struct assert_error : public std::runtime_error {
public:
explicit assert_error(const std::string& msg) : runtime_error(msg) {}
};
struct tensor_error : public assert_error {
public:
explicit tensor_error(const std::string& msg) : assert_error(msg) {}
};
std::string barf(const char* fmt, ...);
[[noreturn]] void throw_assert_error(std::string&);
[[noreturn]] void throw_tensor_error(std::string&);
} // namespace ONNX_NAMESPACE
#if defined(__GNUC__) || defined(__ICL) || defined(__clang__)
#define _ONNX_EXPECT(x, y) (__builtin_expect((x), (y)))
#else
#define _ONNX_EXPECT(x, y) (x)
#endif
#define ONNX_ASSERT(cond) \
if (_ONNX_EXPECT(!(cond), 0)) { \
std::string error_msg = \
::ONNX_NAMESPACE::barf("%s:%u: %s: Assertion `%s` failed.", __FILE__, __LINE__, __func__, #cond); \
throw_assert_error(error_msg); \
}
// The following is used to prevent MSVC from passing the whole __VA_ARGS__ list
// as the first parameter value to a macro call.
#define ONNX_EXPAND(x) x
// Note: msg must be a string literal
#define _ONNX_ASSERTM(cond, msg, ...) \
if (_ONNX_EXPECT(!(cond), 0)) { \
std::string error_msg = ::ONNX_NAMESPACE::barf( \
"%s:%u: %s: Assertion `%s` failed: " msg, __FILE__, __LINE__, __func__, #cond, __VA_ARGS__); \
throw_assert_error(error_msg); \
}
// The trailing ' ' argument is a hack to deal with the extra comma when ... is empty.
// Another way to solve this is ##__VA_ARGS__ in _ONNX_ASSERTM, but this is a non-portable
// extension we shouldn't use.
#define ONNX_ASSERTM(...) ONNX_EXPAND(_ONNX_ASSERTM(__VA_ARGS__, " "))
#define _TENSOR_ASSERTM(cond, msg, ...) \
if (_ONNX_EXPECT(!(cond), 0)) { \
std::string error_msg = ::ONNX_NAMESPACE::barf( \
"%s:%u: %s: Assertion `%s` failed: " msg, __FILE__, __LINE__, __func__, #cond, __VA_ARGS__); \
throw_tensor_error(error_msg); \
}
#define TENSOR_ASSERTM(...) ONNX_EXPAND(_TENSOR_ASSERTM(__VA_ARGS__, " "))
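
A small usage sketch of the assertion macros above (the include path and the checking function are illustrative). Note that the expansions call throw_assert_error unqualified, so they are intended to be used inside ONNX_NAMESPACE or with that namespace in scope.

#include <cstddef>
#include "onnx/common/assertions.h" // assumed location of this header

namespace ONNX_NAMESPACE {
inline void CheckRank(size_t rank) {
  // Bare condition: file, line, function and the condition text are
  // formatted into the thrown assert_error automatically.
  ONNX_ASSERT(rank > 0);
  // printf-style message; the format string must be a literal so it can be
  // concatenated into barf()'s format string at preprocessing time.
  ONNX_ASSERTM(rank <= 8, "unexpected rank %zu", rank);
}
} // namespace ONNX_NAMESPACE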

@@ -0,0 +1,55 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#define ONNX_UNUSED_PARAMETER(x) (void)(x)
#ifdef ONNX_NO_EXCEPTIONS
#include <iostream>
#define ONNX_THROW(...) \
do { \
std::cerr << ONNX_NAMESPACE::MakeString(__VA_ARGS__); \
abort(); \
} while (false)
#define ONNX_THROW_EX(ex) \
do { \
std::cerr << ex.what() << std::endl; \
abort(); \
} while (false)
#define ONNX_TRY if (true)
#define ONNX_CATCH(x) else if (false)
#define ONNX_HANDLE_EXCEPTION(func)
#else
#define ONNX_THROW(...) throw std::runtime_error(ONNX_NAMESPACE::MakeString(__VA_ARGS__))
#define ONNX_THROW_EX(ex) throw ex
#define ONNX_TRY try
#define ONNX_CATCH(x) catch (x)
#define ONNX_HANDLE_EXCEPTION(func) func()
#endif
// Macros to disable the copy and/or assignment methods
// These are usually placed in the private: declarations for a class.
#define ONNX_DISALLOW_COPY(TypeName) TypeName(const TypeName&) = delete
#define ONNX_DISALLOW_ASSIGNMENT(TypeName) TypeName& operator=(const TypeName&) = delete
#define ONNX_DISALLOW_COPY_AND_ASSIGNMENT(TypeName) \
ONNX_DISALLOW_COPY(TypeName); \
ONNX_DISALLOW_ASSIGNMENT(TypeName)
#define ONNX_DISALLOW_MOVE(TypeName) \
TypeName(TypeName&&) = delete; \
TypeName& operator=(TypeName&&) = delete
#define ONNX_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TypeName) \
ONNX_DISALLOW_COPY_AND_ASSIGNMENT(TypeName); \
ONNX_DISALLOW_MOVE(TypeName)
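
A sketch of how the disallow macros are typically applied (the class is a placeholder, not part of this commit).

#include "onnx/common/common.h"

// Hypothetical resource wrapper that must never be copied or moved.
class GraphSession {
 public:
  GraphSession() = default;

 private:
  // Deletes the copy constructor, copy assignment, move constructor and
  // move assignment, so instances can only be used in place.
  ONNX_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphSession);
};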

@@ -0,0 +1,45 @@
// Copyright (c) ONNX Project Contributors
//
// SPDX-License-Identifier: Apache-2.0
#pragma once
#include <string>
namespace ONNX_NAMESPACE {
// For ONNX op/function registration.
// ONNX domains.
constexpr const char* AI_ONNX_ML_DOMAIN = "ai.onnx.ml";
constexpr const char* AI_ONNX_TRAINING_DOMAIN = "ai.onnx.training";
constexpr const char* AI_ONNX_PREVIEW_TRAINING_DOMAIN = "ai.onnx.preview.training";
// The following two are equivalent in an onnx proto representation.
constexpr const char* ONNX_DOMAIN = "";
constexpr const char* AI_ONNX_DOMAIN = "ai.onnx";
inline std::string NormalizeDomain(const std::string& domain) {
return (domain == AI_ONNX_DOMAIN) ? ONNX_DOMAIN : domain;
}
inline bool IsOnnxDomain(const std::string& domain) {
return (domain == AI_ONNX_DOMAIN) || (domain == ONNX_DOMAIN);
}
constexpr bool OPTIONAL_VALUE = false;
// For dimension denotation.
constexpr const char* DATA_BATCH = "DATA_BATCH";
constexpr const char* DATA_CHANNEL = "DATA_CHANNEL";
constexpr const char* DATA_TIME = "DATA_TIME";
constexpr const char* DATA_FEATURE = "DATA_FEATURE";
constexpr const char* FILTER_IN_CHANNEL = "FILTER_IN_CHANNEL";
constexpr const char* FILTER_OUT_CHANNEL = "FILTER_OUT_CHANNEL";
constexpr const char* FILTER_SPATIAL = "FILTER_SPATIAL";
// For type denotation.
constexpr const char* TENSOR = "TENSOR";
constexpr const char* IMAGE = "IMAGE";
constexpr const char* AUDIO = "AUDIO";
constexpr const char* TEXT = "TEXT";
} // namespace ONNX_NAMESPACE
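
A few concrete checks that illustrate the equivalence noted above; the sketch uses <cassert> only for illustration.

#include <cassert>
#include "onnx/common/constants.h"

inline void DomainExamples() {
  using namespace ONNX_NAMESPACE;
  assert(IsOnnxDomain(ONNX_DOMAIN));    // "" is the default domain
  assert(IsOnnxDomain(AI_ONNX_DOMAIN)); // "ai.onnx" is its alias
  assert(NormalizeDomain(AI_ONNX_DOMAIN) == ONNX_DOMAIN);          // the alias collapses to ""
  assert(NormalizeDomain(AI_ONNX_ML_DOMAIN) == AI_ONNX_ML_DOMAIN); // other domains pass through
}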

@@ -0,0 +1,31 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <filesystem>
#include <fstream>
#include <string>
#include "onnx/checker.h"
#include "onnx/common/path.h"
namespace ONNX_NAMESPACE {
template <typename T>
void LoadProtoFromPath(const std::string proto_path, T& proto) {
std::filesystem::path proto_u8_path = std::filesystem::u8path(proto_path);
std::fstream proto_stream(proto_u8_path, std::ios::in | std::ios::binary);
if (!proto_stream.good()) {
fail_check("Unable to open proto file: ", proto_path, ". Please check if it is a valid proto. ");
}
std::string data{std::istreambuf_iterator<char>{proto_stream}, std::istreambuf_iterator<char>{}};
if (!ParseProtoFromBytes(&proto, data.c_str(), data.size())) {
fail_check(
"Unable to parse proto from file: ", proto_path, ". Please check if it is a valid protobuf file of proto. ");
}
}
} // namespace ONNX_NAMESPACE
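
A hypothetical caller of the template above (the function name and the "onnx/common/file_utils.h" include path are assumptions). ModelProto comes from onnx/onnx_pb.h, and failures surface through fail_check, which throws a checker error (or aborts when ONNX_NO_EXCEPTIONS is defined).

#include <string>
#include "onnx/common/file_utils.h" // assumed location of this header
#include "onnx/onnx_pb.h"

inline ONNX_NAMESPACE::ModelProto LoadModelOrThrow(const std::string& model_path) {
  ONNX_NAMESPACE::ModelProto model;
  // Reads the file in binary mode and parses it; on I/O or parse failure the
  // fail_check calls above report an error instead of returning normally.
  ONNX_NAMESPACE::LoadProtoFromPath(model_path, model);
  return model;
}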

@@ -0,0 +1,167 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
#include "onnx/common/assertions.h"
namespace ONNX_NAMESPACE {
// Intrusive doubly linked lists with sane reverse iterators.
// The header file is named graph_node_list.h because it is ONLY
// used for Graph's Node lists, and if you want to use it for other
// things, you will have to do some refactoring.
//
// At the moment, the templated type T must support a few operations:
//
// - It must have a field: T* next_in_graph[2] = { nullptr, nullptr };
// which are used for the intrusive linked list pointers.
//
// - It must have a method 'destroy()', which removes T from the
// list and frees a T.
//
// In practice, we are only using it with Node and const Node. 'destroy()'
// needs to be renegotiated if you want to use this somewhere else.
//
// Besides the benefits of being intrusive, unlike std::list, these lists handle
// forward and backward iteration uniformly because we require a
// "before-first-element" sentinel. This means that reverse iterators
// physically point to the element they logically point to, rather than
// the off-by-one behavior for all standard library reverse iterators.
static constexpr size_t kNextDirection = 0;
static constexpr size_t kPrevDirection = 1;
template <typename T>
struct generic_graph_node_list;
template <typename T>
struct generic_graph_node_list_iterator;
struct Node;
using graph_node_list = generic_graph_node_list<Node>;
using const_graph_node_list = generic_graph_node_list<const Node>;
using graph_node_list_iterator = generic_graph_node_list_iterator<Node>;
using const_graph_node_list_iterator = generic_graph_node_list_iterator<const Node>;
template <typename T>
struct generic_graph_node_list_iterator final {
generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {}
generic_graph_node_list_iterator(T* cur, size_t d) : cur(cur), d(d) {}
T* operator*() const {
return cur;
}
T* operator->() const {
return cur;
}
generic_graph_node_list_iterator& operator++() {
ONNX_ASSERT(cur);
cur = cur->next_in_graph[d];
return *this;
}
generic_graph_node_list_iterator operator++(int) {
generic_graph_node_list_iterator old = *this;
++(*this);
return old;
}
generic_graph_node_list_iterator& operator--() {
ONNX_ASSERT(cur);
cur = cur->next_in_graph[reverseDir()];
return *this;
}
generic_graph_node_list_iterator operator--(int) {
generic_graph_node_list_iterator old = *this;
--(*this);
return old;
}
// erase cur without invalidating this iterator
// named differently from destroy so that ->/. bugs do not
// silently cause the wrong one to be called.
// iterator will point to the previous entry after call
void destroyCurrent() {
T* n = cur;
cur = cur->next_in_graph[reverseDir()];
n->destroy();
}
generic_graph_node_list_iterator reverse() {
return generic_graph_node_list_iterator(cur, reverseDir());
}
private:
size_t reverseDir() {
return d == kNextDirection ? kPrevDirection : kNextDirection;
}
T* cur;
size_t d; // direction 0 is forward 1 is reverse, see next_in_graph
};
template <typename T>
struct generic_graph_node_list final {
using iterator = generic_graph_node_list_iterator<T>;
using const_iterator = generic_graph_node_list_iterator<const T>;
generic_graph_node_list_iterator<T> begin() {
return generic_graph_node_list_iterator<T>(head->next_in_graph[d], d);
}
generic_graph_node_list_iterator<const T> begin() const {
return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d);
}
generic_graph_node_list_iterator<T> end() {
return generic_graph_node_list_iterator<T>(head, d);
}
generic_graph_node_list_iterator<const T> end() const {
return generic_graph_node_list_iterator<const T>(head, d);
}
generic_graph_node_list_iterator<T> rbegin() {
return reverse().begin();
}
generic_graph_node_list_iterator<const T> rbegin() const {
return reverse().begin();
}
generic_graph_node_list_iterator<T> rend() {
return reverse().end();
}
generic_graph_node_list_iterator<const T> rend() const {
return reverse().end();
}
generic_graph_node_list reverse() {
return generic_graph_node_list(head, d == kNextDirection ? kPrevDirection : kNextDirection);
}
const generic_graph_node_list reverse() const {
return generic_graph_node_list(head, d == kNextDirection ? kPrevDirection : kNextDirection);
}
generic_graph_node_list(T* head, size_t d) : head(head), d(d) {}
private:
T* head;
size_t d;
};
template <typename T>
static inline bool operator==(generic_graph_node_list_iterator<T> a, generic_graph_node_list_iterator<T> b) {
return *a == *b;
}
template <typename T>
static inline bool operator!=(generic_graph_node_list_iterator<T> a, generic_graph_node_list_iterator<T> b) {
return *a != *b;
}
} // namespace ONNX_NAMESPACE
namespace std {
template <typename T>
struct iterator_traits<ONNX_NAMESPACE::generic_graph_node_list_iterator<T>> {
using difference_type = int64_t;
using value_type = T*;
using pointer = T**;
using reference = T*&;
using iterator_category = bidirectional_iterator_tag;
};
} // namespace std
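
A self-contained sketch of the contract described in the comment above, using a hypothetical MiniNode type (not part of ONNX) in place of Node; in the real IR, Graph owns the sentinel and maintains the links. The include path is an assumption.

#include "onnx/common/graph_node_list.h" // assumed location of this header

// Hypothetical element type: next_in_graph[2] plus destroy() is all
// generic_graph_node_list requires.
struct MiniNode {
  MiniNode* next_in_graph[2] = {nullptr, nullptr}; // [0] = next, [1] = prev
  int id = 0;
  void destroy() {
    // unlink from the circular list, then free
    next_in_graph[0]->next_in_graph[1] = next_in_graph[1];
    next_in_graph[1]->next_in_graph[0] = next_in_graph[0];
    delete this;
  }
};

inline void MiniNodeDemo() {
  using ONNX_NAMESPACE::generic_graph_node_list;
  using ONNX_NAMESPACE::kNextDirection;
  // Build the ring sentinel <-> a <-> b; the sentinel is the
  // "before-first-element" node the comment above requires.
  MiniNode* sentinel = new MiniNode();
  MiniNode* a = new MiniNode();
  MiniNode* b = new MiniNode();
  a->id = 1;
  b->id = 2;
  sentinel->next_in_graph[0] = a;
  a->next_in_graph[1] = sentinel;
  a->next_in_graph[0] = b;
  b->next_in_graph[1] = a;
  b->next_in_graph[0] = sentinel;
  sentinel->next_in_graph[1] = b;

  generic_graph_node_list<MiniNode> list(sentinel, kNextDirection);
  for (MiniNode* n : list) { // visits a then b, stopping at the sentinel
    (void)n->id;
  }
  for (auto it = list.rbegin(); it != list.rend(); ++it) { // visits b then a
    (void)(*it)->id;
  }
  a->destroy(); // unlinks and frees a; iterators to other nodes stay valid
  b->destroy();
  delete sentinel;
}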

@@ -0,0 +1,81 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
#include "onnx/common/interned_strings.h"
#include <stdint.h>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "onnx/common/assertions.h"
namespace ONNX_NAMESPACE {
struct InternedStrings {
InternedStrings() : next_sym(kLastSymbol) {
#define REGISTER_SYMBOL(s) \
string_to_sym_[#s] = k##s; \
sym_to_string_[k##s] = #s;
FORALL_BUILTIN_SYMBOLS(REGISTER_SYMBOL)
#undef REGISTER_SYMBOL
}
uint32_t symbol(const std::string& s) {
std::lock_guard<std::mutex> guard(mutex_);
auto it = string_to_sym_.find(s);
if (it != string_to_sym_.end())
return it->second;
uint32_t k = next_sym++;
string_to_sym_[s] = k;
sym_to_string_[k] = s;
return k;
}
const char* string(Symbol sym) {
// Builtin Symbols are also in the maps, but
// we can bypass the need to acquire a lock
// to read the map for Builtins because we already
// know their string value
switch (sym) {
#define DEFINE_CASE(s) \
case k##s: \
return #s;
FORALL_BUILTIN_SYMBOLS(DEFINE_CASE)
#undef DEFINE_CASE
default:
return customString(sym);
}
}
private:
const char* customString(Symbol sym) {
std::lock_guard<std::mutex> guard(mutex_);
auto it = sym_to_string_.find(sym);
ONNX_ASSERT(it != sym_to_string_.end());
return it->second.c_str();
}
std::unordered_map<std::string, uint32_t> string_to_sym_;
std::unordered_map<uint32_t, std::string> sym_to_string_;
uint32_t next_sym;
std::mutex mutex_;
};
static InternedStrings& globalStrings() {
static InternedStrings s;
return s;
}
const char* Symbol::toString() const {
return globalStrings().string(*this);
}
Symbol::Symbol(const std::string& s) : value(globalStrings().symbol(s)) {}
} // namespace ONNX_NAMESPACE

@@ -0,0 +1,244 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
#pragma once
#include <stdint.h>
#include <string>
#include <unordered_map>
#include <vector>
namespace ONNX_NAMESPACE {
#define FORALL_BUILTIN_SYMBOLS(_) \
_(spatial) \
_(select_last_index) \
_(coordinate_transformation_mode) \
_(PythonOp) \
_(CppOp) \
_(Param) \
_(Select) \
_(Return) \
_(Eval) \
_(add) \
_(Add) \
_(Div) \
_(Mul) \
_(Neg) \
_(Sub) \
_(Pow) \
_(Sigmoid) \
_(ArgMax) \
_(Concat) \
_(Softmax) \
_(LogSoftmax) \
_(Dropout) \
_(Tanh) \
_(mul) \
_(neg) \
_(sigmoid) \
_(tanh) \
_(Constant) \
_(cat) \
_(Slice) \
_(Squeeze) \
_(Undefined) \
_(FusionGroup) \
_(MatMul) \
_(Gemm) \
_(Tile) \
_(SubConstant) \
_(Scale) \
_(Transpose) \
_(Pad) \
_(Reshape) \
_(split) \
_(chunk) \
_(Offset) \
_(value) \
_(Subgraph) \
_(BatchNormalization) \
_(Conv) \
_(ConvTranspose) \
_(is_test) \
_(epsilon) \
_(expand) \
_(Expand) \
_(order) \
_(momentum) \
_(consumed_inputs) \
_(kernels) \
_(kernel_shape) \
_(kernel) \
_(scale) \
_(strides) \
_(stride) \
_(pads) \
_(pad) \
_(beta) \
_(alpha) \
_(dilations) \
_(dilation) \
_(broadcast) \
_(axis) \
_(ratio) \
_(size) \
_(dim) \
_(keepdims) \
_(perm) \
_(shape) \
_(axes) \
_(group) \
_(inplace) \
_(transA) \
_(transB) \
_(other) \
_(__and__) \
_(__lshift__) \
_(__or__) \
_(__rshift__) \
_(__xor__) \
_(abs) \
_(acos) \
_(asin) \
_(atan) \
_(atan2) \
_(ceil) \
_(clamp) \
_(cos) \
_(cosh) \
_(div) \
_(eq) \
_(equal) \
_(Exp) \
_(ends) \
_(expm1) \
_(floor) \
_(fmod) \
_(frac) \
_(ge) \
_(gt) \
_(le) \
_(lerp) \
_(lgamma) \
_(Log) \
_(log1p) \
_(lt) \
_(max) \
_(min) \
_(ne) \
_(ones) \
_(pow) \
_(reciprocal) \
_(remainder) \
_(round) \
_(rsqrt) \
_(sin) \
_(sinh) \
_(Sqrt) \
_(sub) \
_(starts) \
_(tan) \
_(trunc) \
_(zeros) \
_(exponent) \
_(device) \
_(mode) \
_(Identity) \
_(Loop) \
_(If) \
_(body) \
_(then_branch) \
_(else_branch) \
_(Captured) \
_(__control_inputs) \
_(count_include_pad) \
_(storage_order) \
_(Unsqueeze) \
_(ReduceL1) \
_(ReduceL2) \
_(ReduceLogSum) \
_(ReduceLogSumExp) \
_(ReduceMax) \
_(ReduceMean) \
_(ReduceMin) \
_(ReduceProd) \
_(ReduceSum) \
_(ReduceSumSquare) \
_(Cast) \
_(to) \
_(PRelu) \
_(Greater) \
_(Less) \
_(scales) \
_(Upsample) \
_(RNN) \
_(layout) \
_(k) \
_(Flatten) \
_(ScatterElements) \
_(Resize) \
_(ceil_mode) \
_(num_outputs) \
_(start) \
_(end) \
_(num_groups) \
_(stash_type) \
_(block_size) \
_(output_dtype)
enum BuiltinSymbol {
#define DEFINE_SYMBOL(s) k##s,
FORALL_BUILTIN_SYMBOLS(DEFINE_SYMBOL)
#undef DEFINE_SYMBOL
kLastSymbol, // where we start counting for new symbols
};
struct Symbol {
Symbol() {}
/*implicit*/ Symbol(BuiltinSymbol value) : value(value) {}
explicit Symbol(const std::string& s);
explicit Symbol(uint32_t value) : value(value) {}
operator uint32_t() const {
return value;
}
const char* toString() const;
private:
uint32_t value;
};
static inline bool operator==(Symbol lhs, Symbol rhs) {
return static_cast<uint32_t>(lhs) == static_cast<uint32_t>(rhs);
}
// necessary to prevent ambiguous overload resolutions
static inline bool operator==(BuiltinSymbol lhs, Symbol rhs) {
return static_cast<uint32_t>(lhs) == static_cast<uint32_t>(rhs);
}
static inline bool operator==(Symbol lhs, BuiltinSymbol rhs) {
return static_cast<uint32_t>(lhs) == static_cast<uint32_t>(rhs);
}
inline Symbol operator"" _sym(const char* s, size_t) {
return Symbol(s);
}
} // namespace ONNX_NAMESPACE
// make symbol behave like an integer in hash tables
namespace std {
template <>
struct hash<ONNX_NAMESPACE::Symbol> {
std::size_t operator()(ONNX_NAMESPACE::Symbol s) const {
return std::hash<uint32_t>()(static_cast<uint32_t>(s));
}
};
} // namespace std
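
A short sketch of how symbols behave (the include path is an assumption, "MyCustomOp" is a made-up name, and the example assumes linking against the implementation in interned_strings.cc). Builtin symbols are plain enum values; any other string is interned to a stable 32-bit id by the registry above.

#include <cassert>
#include <cstring>
#include <string>
#include "onnx/common/interned_strings.h" // assumed location of this header

inline void SymbolDemo() {
  using namespace ONNX_NAMESPACE;
  Symbol conv = kConv;                     // builtin: resolved without locking
  Symbol custom("MyCustomOp");             // interned on first use
  Symbol again(std::string("MyCustomOp")); // same string -> same id
  assert(custom == again);
  assert(std::strcmp(conv.toString(), "Conv") == 0);
  assert("Conv"_sym == kConv);             // the _sym literal defined above
}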

File diff suppressed because it is too large

@@ -0,0 +1,721 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
#include "onnx/common/ir_pb_converter.h"
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace ONNX_NAMESPACE {
// Part 1: convert ONNX Protobuf to IR
std::unique_ptr<Graph> graphProtoToGraph(const GraphProto& gp, bool nested, const int ir_version = IR_VERSION);
Tensor tensorProtoToTensor(const ONNX_NAMESPACE::TensorProto& tp) {
Tensor ret;
ret.sizes().reserve(tp.dims_size());
for (int i = 0; i < tp.dims_size(); i++) {
ret.sizes().push_back(tp.dims(i));
}
ret.elem_type() = tp.data_type();
switch (tp.data_type()) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64: {
ret.floats().reserve(tp.float_data_size());
for (int i = 0; i < tp.float_data_size(); i++) {
ret.floats().push_back(tp.float_data(i));
}
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_INT16:
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT16:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: {
ret.int32s().reserve(tp.int32_data_size());
for (int i = 0; i < tp.int32_data_size(); i++) {
ret.int32s().push_back(tp.int32_data(i));
}
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_INT64: {
ret.int64s().reserve(tp.int64_data_size());
for (int i = 0; i < tp.int64_data_size(); i++) {
ret.int64s().push_back(tp.int64_data(i));
}
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
case ONNX_NAMESPACE::TensorProto_DataType_UINT64: {
ret.uint64s().reserve(tp.uint64_data_size());
for (int i = 0; i < tp.uint64_data_size(); i++) {
ret.uint64s().push_back(tp.uint64_data(i));
}
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128: {
ret.doubles().reserve(tp.double_data_size());
for (int i = 0; i < tp.double_data_size(); i++) {
ret.doubles().push_back(tp.double_data(i));
}
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_STRING: {
ret.strings().reserve(tp.string_data_size());
for (int i = 0; i < tp.string_data_size(); i++) {
ret.strings().push_back(tp.string_data(i));
}
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED:
fail_convert("Unknown tensor data type");
}
// The only way to know if we should be using raw_data or
// <type>_data is to look at which of them is size zero.
if (tp.has_raw_data()) {
ret.set_raw_data(tp.raw_data());
}
if (tp.has_name()) {
ret.setName(tp.name());
}
if (tp.has_segment()) {
ret.set_segment_begin_and_end(tp.segment().begin(), tp.segment().end());
}
return ret;
}
void convertAttribute(const ONNX_NAMESPACE::AttributeProto& ap, Node* n, const int ir_version = IR_VERSION) {
Symbol sym = Symbol(ap.name());
switch (ap.type()) {
case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT:
n->f_(sym, ap.f());
break;
case ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS: {
std::vector<double> floats;
floats.reserve(ap.floats_size());
for (int i = 0; i < ap.floats_size(); i++) {
floats.push_back(ap.floats(i));
}
n->fs_(sym, std::move(floats));
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType_INT:
n->i_(sym, ap.i());
break;
case ONNX_NAMESPACE::AttributeProto_AttributeType_INTS: {
std::vector<int64_t> ints;
ints.reserve(ap.ints_size());
for (int i = 0; i < ap.ints_size(); i++) {
ints.push_back(ap.ints(i));
}
n->is_(sym, std::move(ints));
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType_STRING:
n->s_(sym, ap.s());
break;
case ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS: {
std::vector<std::string> strings;
strings.reserve(ap.strings_size());
for (int i = 0; i < ap.strings_size(); i++) {
strings.push_back(ap.strings(i));
}
n->ss_(sym, std::move(strings));
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR:
n->t_(sym, tensorProtoToTensor(ap.t()));
break;
case ONNX_NAMESPACE::AttributeProto_AttributeType_TENSORS: {
std::vector<Tensor> tensors;
tensors.reserve(ap.tensors_size());
for (int i = 0; i < ap.tensors_size(); i++) {
tensors.push_back(tensorProtoToTensor(ap.tensors(i)));
}
n->ts_(sym, std::move(tensors));
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTO:
n->tp_(sym, ap.tp());
break;
case ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTOS: {
std::vector<TypeProto> types;
types.reserve(ap.type_protos_size());
for (int i = 0; i < ap.type_protos_size(); i++) {
types.push_back(ap.type_protos(i));
}
n->tps_(sym, std::move(types));
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH:
n->g_(sym, graphProtoToGraph(ap.g(), true, ir_version));
break;
case ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPHS: {
std::vector<std::shared_ptr<Graph>> graphs;
graphs.reserve(ap.graphs_size());
for (int i = 0; i < ap.graphs_size(); i++) {
graphs.push_back(graphProtoToGraph(ap.graphs(i), true, ir_version));
}
n->gs_(sym, std::move(graphs));
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSOR:
case ONNX_NAMESPACE::AttributeProto_AttributeType_SPARSE_TENSORS:
fail_convert("Sparse tensors not supported.");
break;
case ONNX_NAMESPACE::AttributeProto_AttributeType_UNDEFINED:
fail_convert("Unknown tensor data type");
break;
}
}
void convertAttributes(ONNX_NAMESPACE::NodeProto& np, Node* n, const int ir_version = IR_VERSION) {
for (int i = 0; i < np.attribute_size(); i++) {
convertAttribute(np.attribute(i), n, ir_version);
}
}
std::vector<Dimension> tensorShapeProtoToDimensions(const ONNX_NAMESPACE::TensorShapeProto& tsp) {
std::vector<Dimension> dims;
dims.reserve(tsp.dim_size());
for (int i = 0; i < tsp.dim_size(); i++) {
if (tsp.dim(i).has_dim_value()) {
dims.emplace_back(tsp.dim(i).dim_value());
} else if (tsp.dim(i).has_dim_param()) {
dims.emplace_back(tsp.dim(i).dim_param());
} else {
// a dimension that has neither dim_value nor dim_param set
// represents an unknown dimension unrelated to other unknown dimensions.
dims.emplace_back();
}
}
return dims;
}
void createDummyValue(
std::unique_ptr<Graph>& g,
const std::string& name,
std::unordered_map<std::string, Value*>& value_by_name_of) {
auto* undef = g->create(kCaptured, 1);
g->appendNode(undef);
undef->outputs()[0]->setUniqueName(name);
value_by_name_of[name] = undef->outputs()[0];
}
std::unique_ptr<Graph> graphProtoToGraph(const ONNX_NAMESPACE::GraphProto& gp, bool nested, const int ir_version) {
std::unique_ptr<Graph> g(new Graph());
if (gp.has_name()) {
g->setName(gp.name());
}
if (gp.has_doc_string()) {
g->setDocString(gp.doc_string());
}
// Values are created (as in `new Value(..)`) by the Node that
// outputs them. Therefore we initialize the Nodes and Values in
// several stages.
//
// 1) add all input (to the graph) Values, owned by the sentinel Param node
// 2) add all Nodes and their output Values, but don't initialize inputs
// 3) initialize inputs of all Nodes
// 4) initialize inputs of the Return sentinel node
// 5) fill in type info for graph outputs, and register them as outputs
// 6) fill in type info for Values from the value_info list in the graph
// In ONNX proto land, Values are just strings. We are going to make
// objects out of them, and equal strings must be mapped to the same
// Value object.
std::unordered_map<std::string, Value*> value_by_name_of;
// We initialize Node inputs in a separate pass from the Nodes
// themselves. To do so, we need to have access to the names of the
// inputs.
std::unordered_map<Node*, std::vector<std::string>> inputs_by_node;
{
// ONNX represents optional arguments in two ways
// - they are simply not provided
// - OR the empty string is passed as the input name
// This is to handle that second case, which needs a dummy node to
// be representable in the graph IR.
auto* n = g->create(kUndefined, 1);
g->appendNode(n);
n->outputs()[0]->setUniqueName("");
value_by_name_of[""] = n->outputs()[0];
}
for (int i = 0; i < gp.input_size(); i++) {
const auto& vip = gp.input(i);
auto v = g->addInput();
const auto& tensor_type = vip.type().tensor_type();
if (tensor_type.has_elem_type()) {
v->setElemType(tensor_type.elem_type());
}
if (tensor_type.has_shape()) {
v->setSizes(tensorShapeProtoToDimensions(tensor_type.shape()));
}
v->setUniqueName(vip.name());
value_by_name_of[vip.name()] = v;
}
// initializers should be added before all nodes,
// otherwise getNextUnique() may conflict with an existing initializer name.
for (int i = 0; i < gp.initializer_size(); ++i) {
auto init = tensorProtoToTensor(gp.initializer(i));
// If ir_version >= 4, an initializer does not have to appear among the graph inputs.
// When its name is not already an input, create a Value for it via addInitializerAndCreateValue
// and save it into value_by_name_of for later use (node input)
if (ir_version >= 4 && value_by_name_of.count(init.name()) == 0) {
value_by_name_of[init.name()] = g->addInitializerAndCreateValue(init);
} else {
// If ir_version < 4, or the initializer also exists among the inputs,
// simply add the initializer without creating a new value,
// which means the input value takes priority over the initializer value when both exist
g->addInitializer(init);
}
}
for (int i = 0; i < gp.node_size(); i++) {
auto np = gp.node(i);
auto* n = g->create(Symbol(np.op_type()), /* num_outputs = */ np.output_size());
g->appendNode(n);
for (int j = 0; j < np.output_size(); j++) {
auto out = n->outputs()[j];
// we don't know the real type here, so that's done in a later pass
out->setElemType(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
out->setUniqueName(np.output(j));
value_by_name_of[np.output(j)] = out;
}
convertAttributes(np, n, ir_version);
std::vector<std::string> inputs;
inputs.reserve(np.input_size());
for (int j = 0; j < np.input_size(); j++) {
inputs.push_back(np.input(j));
}
inputs_by_node[n] = inputs;
if (np.has_doc_string()) {
n->setDocString(np.doc_string());
}
if (np.has_name()) {
n->setName(np.name());
}
if (np.has_domain()) {
n->setDomain(np.domain());
}
if (np.has_overload()) {
n->setOverload(np.overload());
}
}
for (auto n : g->nodes()) {
auto search = inputs_by_node.find(n);
if (search == inputs_by_node.end()) {
continue;
}
for (const auto& input : search->second) {
if (!value_by_name_of.count(input) && nested) {
// Undefined reference to an input in a nested block. This may be a
// captured value. Create a dummy node that we ignore later.
createDummyValue(g, input, value_by_name_of);
}
if (!value_by_name_of.count(input)) {
std::ostringstream msg;
msg << "Input " << input << " is undefined!";
ONNX_THROW_EX(std::out_of_range(msg.str()));
}
n->addInput(value_by_name_of.at(input));
}
}
for (int i = 0; i < gp.output_size(); i++) {
if (!value_by_name_of.count(gp.output(i).name()) && nested) {
// Same captured value logic as above. We can consider outputs of a
// graph to be "inputs" of a dummy "output" node. The same lexical
// scoping rules are valid here, thus we need to add a dummy node
// in the case of an undefined reference
createDummyValue(g, gp.output(i).name(), value_by_name_of);
}
const auto& output_tensor_type = gp.output(i).type().tensor_type();
if (output_tensor_type.has_elem_type()) {
value_by_name_of[gp.output(i).name()]->setElemType(output_tensor_type.elem_type());
}
if (output_tensor_type.has_shape()) {
value_by_name_of[gp.output(i).name()]->setSizes(tensorShapeProtoToDimensions(output_tensor_type.shape()));
}
g->registerOutput(value_by_name_of[gp.output(i).name()]);
}
for (int i = 0; i < gp.value_info_size(); i++) {
const auto& tensor_type = gp.value_info(i).type().tensor_type();
if (!value_by_name_of.count(gp.value_info(i).name())) {
// Ideally the model should not have a value_info whose name does not exist in the graph (unused); simply skip it
continue;
}
if (tensor_type.has_elem_type()) {
value_by_name_of[gp.value_info(i).name()]->setElemType(tensor_type.elem_type());
}
if (tensor_type.has_shape()) {
value_by_name_of[gp.value_info(i).name()]->setSizes(tensorShapeProtoToDimensions(tensor_type.shape()));
}
}
return g;
}
std::unique_ptr<Graph> ImportModelProto(const ModelProto& mp) {
if (!mp.has_ir_version()) {
return nullptr;
} else if (mp.ir_version() <= 1) {
// ir_version=1 is not supported and ir_version=0 is illegal
return nullptr;
}
std::unique_ptr<Graph> g(graphProtoToGraph(mp.graph(), false, mp.ir_version()));
for (int i = 0; i < mp.opset_import_size(); i++) {
OpSetID new_opset_version(mp.opset_import(i).domain(), mp.opset_import(i).version());
g->forSelfAndEachSubGraph(
[&new_opset_version](Graph* graph) { graph->opset_versions_mutable().emplace_back(new_opset_version); });
}
return g;
}
// Part 2: convert IR to ONNX Protobuf
std::string value_name(Value* n) {
return n->uniqueName();
}
void encodeGraph(GraphProto* p_g, const std::shared_ptr<Graph>& g);
void encodeTensor(ONNX_NAMESPACE::TensorProto* p, const Tensor& tensor) {
if (tensor.hasName()) {
p->set_name(tensor.name());
}
if (tensor.is_segment()) {
ONNX_NAMESPACE::TensorProto_Segment segment;
segment.set_begin(tensor.segment_begin());
segment.set_end(tensor.segment_end());
p->mutable_segment()->CopyFrom(segment);
}
for (auto d : tensor.sizes()) {
p->add_dims(d);
}
p->set_data_type(tensor.elem_type());
switch (tensor.elem_type()) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64: {
for (float x : tensor.floats()) {
p->add_float_data(x);
}
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType_INT16:
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
for (int32_t x : tensor.int32s()) {
p->add_int32_data(x);
}
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_INT64: {
for (int64_t x : tensor.int64s()) {
p->add_int64_data(x);
}
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
case ONNX_NAMESPACE::TensorProto_DataType_UINT64: {
for (uint64_t x : tensor.uint64s()) {
p->add_uint64_data(x);
}
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
case ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128: {
for (double x : tensor.doubles()) {
p->add_double_data(x);
}
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_STRING: {
for (const std::string& x : tensor.strings()) {
p->add_string_data(x);
}
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED:
fail_convert("Unknown tensor data type");
}
if (tensor.is_raw_data()) {
p->set_raw_data(tensor.raw());
}
}
void addAttribute(ONNX_NAMESPACE::NodeProto* n_p, Node* n, Symbol name) {
auto attr = n_p->add_attribute();
attr->set_name(name.toString());
switch (n->kindOf(name)) {
case AttributeKind::f: {
attr->set_f(static_cast<float>(n->f(name)));
attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT);
} break;
case AttributeKind::fs: {
attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS);
for (auto& v : n->fs(name))
attr->add_floats(static_cast<float>(v));
} break;
case AttributeKind::i: {
attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT);
attr->set_i(n->i(name));
} break;
case AttributeKind::is: {
attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INTS);
for (auto& v : n->is(name))
attr->add_ints(v);
} break;
case AttributeKind::s: {
attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRING);
attr->set_s(n->s(name));
} break;
case AttributeKind::ss: {
attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS);
for (auto& v : n->ss(name))
attr->add_strings(v);
} break;
case AttributeKind::t: {
attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR);
auto t = attr->mutable_t();
encodeTensor(t, n->t(name));
} break;
case AttributeKind::ts: {
attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSORS);
for (auto& v : n->ts(name)) {
auto t = attr->add_tensors();
encodeTensor(t, v);
}
} break;
case AttributeKind::g: {
attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPH);
auto g = attr->mutable_g();
encodeGraph(g, n->g(name));
} break;
case AttributeKind::gs: {
attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_GRAPHS);
for (auto& v : n->gs(name)) {
auto g = attr->add_graphs();
encodeGraph(g, v);
}
} break;
case AttributeKind::tp: {
attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTO);
auto tp = attr->mutable_tp();
tp->CopyFrom(n->tp(name));
} break;
case AttributeKind::tps: {
attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TYPE_PROTOS);
for (auto& v : n->tps(name)) {
auto tp = attr->add_type_protos();
tp->CopyFrom(v);
}
} break;
}
}
void encodeTypeProtoTensorType(ONNX_NAMESPACE::TypeProto_Tensor* tensor_type, Value* n) {
if (n->elemType() != 0) {
tensor_type->set_elem_type(n->elemType());
}
if (n->has_sizes()) {
ONNX_NAMESPACE::TensorShapeProto* shape = tensor_type->mutable_shape();
for (const Dimension& d : n->sizes()) {
auto dim = shape->add_dim();
if (!d.is_unknown) {
if (d.is_int) {
dim->set_dim_value(d.dim);
} else {
dim->set_dim_param(d.param);
}
}
}
}
}
void encodeValueInfo(ONNX_NAMESPACE::ValueInfoProto* v, Value* n) {
v->set_name(value_name(n));
if (n->elemType() != 0 || n->has_sizes()) {
ONNX_NAMESPACE::TypeProto* t = v->mutable_type();
ONNX_NAMESPACE::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
encodeTypeProtoTensorType(tensor_type, n);
}
}
void encodeGraph(GraphProto* p_g, const std::shared_ptr<Graph>& g) {
ONNX_ASSERT(p_g != nullptr);
if (g->has_name()) {
p_g->set_name(g->name());
}
if (g->has_doc_string()) {
p_g->set_doc_string(g->docString());
}
for (auto input : g->inputs()) {
ONNX_NAMESPACE::ValueInfoProto* v = p_g->add_input();
encodeValueInfo(v, input);
}
for (auto output : g->outputs()) {
ONNX_NAMESPACE::ValueInfoProto* v = p_g->add_output();
encodeValueInfo(v, output);
}
std::unordered_set<Value*> graph_outputs(g->outputs().begin(), g->outputs().end());
for (auto node : g->nodes()) {
if (node->kind() == kUndefined || node->kind() == kCaptured) {
// Undefined nodes stand in for optional inputs that are not provided, and
// Captured nodes are dummies for values captured from an enclosing scope;
// neither is exported back to the proto.
continue;
}
auto p_n = p_g->add_node();
for (auto input : node->inputs()) {
if (input->node()->kind() == kUndefined) {
p_n->add_input("");
} else {
p_n->add_input(value_name(input));
}
}
for (auto output : node->outputs()) {
p_n->add_output(value_name(output));
// only save it if
// - it has actual information worth saving
// - it's not already saved in the graph outputs value info
if (graph_outputs.find(output) != graph_outputs.end()) {
continue;
}
if (output->elemType() == TensorProto_DataType_UNDEFINED && output->sizes().empty()) {
continue;
}
ValueInfoProto* v = p_g->add_value_info();
encodeValueInfo(v, output);
}
p_n->set_op_type(node->kind().toString());
for (auto attr_name : node->attributeNames()) {
addAttribute(p_n, node, attr_name);
}
if (node->has_doc_string()) {
p_n->set_doc_string(node->docString());
}
if (node->has_name()) {
p_n->set_name(node->name());
}
if (node->has_domain()) {
p_n->set_domain(node->domain());
}
if (node->has_overload()) {
p_n->set_overload(node->overload());
}
}
auto num_initializers = g->initializers().size();
for (unsigned int i = 0; i < num_initializers; i++) {
auto p = p_g->add_initializer();
p->set_name(g->initializer_names()[i]);
encodeTensor(p, g->initializers()[i]);
}
}
void ExportModelProto(ModelProto* p_m, const std::shared_ptr<Graph>& g) {
GraphProto* p_g = p_m->mutable_graph();
encodeGraph(p_g, g);
// Add new opset_versions
p_m->clear_opset_import();
for (const OpSetID& opset : g->opset_versions_mutable()) {
OperatorSetIdProto* opset_version_output = p_m->add_opset_import();
opset_version_output->set_domain(opset.domain());
opset_version_output->set_version(opset.version());
}
}
ModelProto PrepareOutput(const ModelProto& mp_in) {
ModelProto mp_out{};
if (mp_in.has_ir_version()) {
mp_out.set_ir_version(mp_in.ir_version());
}
if (mp_in.has_producer_name()) {
mp_out.set_producer_name(mp_in.producer_name());
}
if (mp_in.has_producer_version()) {
mp_out.set_producer_version(mp_in.producer_version());
}
if (mp_in.has_domain()) {
mp_out.set_domain(mp_in.domain());
}
if (mp_in.has_model_version()) {
mp_out.set_model_version(mp_in.model_version());
}
if (mp_in.has_doc_string()) {
mp_out.set_doc_string(mp_in.doc_string());
}
for (int i = 0; i < mp_in.opset_import_size(); i++) {
auto& oi_in = mp_in.opset_import(i);
auto* oi_out = mp_out.add_opset_import();
if (oi_in.has_domain()) {
oi_out->set_domain(oi_in.domain());
}
if (oi_in.has_version()) {
oi_out->set_version(oi_in.version());
}
}
for (int i = 0; i < mp_in.metadata_props_size(); i++) {
auto& pp_in = mp_in.metadata_props(i);
auto* pp_out = mp_out.add_metadata_props();
if (pp_in.has_key()) {
pp_out->set_key(pp_in.key());
}
if (pp_in.has_value()) {
pp_out->set_value(pp_in.value());
}
}
return mp_out;
}
void assertNonNull(const std::shared_ptr<Graph>& g) {
ONNX_ASSERTM(
g.get() != nullptr,
"Warning: onnx version converter is unable to parse input model. "
"(The IR version of the ONNX model may be too old.)");
}
} // namespace ONNX_NAMESPACE

@@ -0,0 +1,50 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
#pragma once
#include <memory>
#include <string>
#include "onnx/common/common.h"
#include "onnx/common/ir.h"
#include "onnx/onnx_pb.h"
namespace ONNX_NAMESPACE {
class ConvertError final : public std::runtime_error {
public:
using std::runtime_error::runtime_error;
explicit ConvertError(const std::string& message) : std::runtime_error(message) {}
const char* what() const noexcept override {
if (!expanded_message_.empty()) {
return expanded_message_.c_str();
}
return std::runtime_error::what();
}
void AppendContext(const std::string& context) {
expanded_message_ = MakeString(std::runtime_error::what(), "\n\n==> Context: ", context);
}
private:
std::string expanded_message_;
};
#define fail_convert(...) ONNX_THROW_EX(ConvertError(MakeString(__VA_ARGS__)));
void ExportModelProto(ModelProto* p_m, const std::shared_ptr<Graph>& g);
std::unique_ptr<Graph> ImportModelProto(const ModelProto& mp);
ModelProto PrepareOutput(const ModelProto& mp_in);
void assertNonNull(const std::shared_ptr<Graph>& g);
} // namespace ONNX_NAMESPACE
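
A sketch of the intended round trip through the IR using the four entry points declared above (assumes linking against the ONNX library; the wrapper function is illustrative).

#include <memory>
#include "onnx/common/ir_pb_converter.h" // assumed location of this header

inline ONNX_NAMESPACE::ModelProto RoundTrip(const ONNX_NAMESPACE::ModelProto& model_in) {
  using namespace ONNX_NAMESPACE;
  // ImportModelProto returns nullptr when ir_version is missing or <= 1;
  // assertNonNull turns that into an assertion failure with a readable message.
  std::shared_ptr<Graph> graph = ImportModelProto(model_in);
  assertNonNull(graph);
  // PrepareOutput copies model-level metadata (producer, opset imports,
  // metadata_props) but leaves the graph empty; ExportModelProto then fills
  // the graph back in from the IR.
  ModelProto model_out = PrepareOutput(model_in);
  ExportModelProto(&model_out, graph);
  return model_out;
}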

@@ -0,0 +1,37 @@
// Copyright (c) ONNX Project Contributors
//
// SPDX-License-Identifier: Apache-2.0
#include "onnx/common/model_helpers.h"
#include "onnx/checker.h"
#include "onnx/defs/schema.h"
#include "onnx/string_utils.h"
namespace ONNX_NAMESPACE {
Common::Status BuildNode(
const std::string& name,
const std::string& domain,
const std::string& doc_string,
const std::string& op_type,
std::vector<std::string> const& inputs,
std::vector<std::string> const& outputs,
NodeProto* node) {
if (node == nullptr) {
return Common::Status(Common::CHECKER, Common::INVALID_ARGUMENT, "node_proto should not be nullptr.");
}
node->set_name(name);
node->set_domain(domain);
node->set_doc_string(doc_string);
node->set_op_type(op_type);
for (auto& input : inputs) {
node->add_input(input);
}
for (auto& output : outputs) {
node->add_output(output);
}
return Common::Status::OK();
}
} // namespace ONNX_NAMESPACE

@@ -0,0 +1,26 @@
// Copyright (c) ONNX Project Contributors
//
// SPDX-License-Identifier: Apache-2.0
#pragma once
#include <string>
#include <vector>
#include "onnx/common/status.h"
#include "onnx/onnx-operators_pb.h"
namespace ONNX_NAMESPACE {
// Helper function for registering nodes in
// a FunctionProto. Attributes need to be
// registered separately.
Common::Status BuildNode(
const std::string& name,
const std::string& domain,
const std::string& doc_string,
const std::string& op_type,
std::vector<std::string> const& inputs,
std::vector<std::string> const& outputs,
/*OUT*/ NodeProto* node);
} // namespace ONNX_NAMESPACE
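
A hypothetical use of BuildNode when assembling a FunctionProto body (the node, tensor, and function names are placeholders; the include path is an assumption).

#include "onnx/common/model_helpers.h" // assumed location of this header

inline ONNX_NAMESPACE::Common::Status AddReluNode(ONNX_NAMESPACE::FunctionProto* fn) {
  // Attributes, if any, still have to be added to the new NodeProto afterwards.
  return ONNX_NAMESPACE::BuildNode(
      "relu_0",          // node name
      "",                // default ONNX domain
      "apply Relu to X", // doc string
      "Relu",            // op type
      {"X"},             // inputs
      {"Y"},             // outputs
      fn->add_node());
}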

@@ -0,0 +1,82 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Copyright (c) ONNX Project Contributors.
#include "onnx/common/path.h"
namespace ONNX_NAMESPACE {
#ifdef _WIN32
#else
std::string path_join(const std::string& origin, const std::string& append) {
if (origin.find_last_of(k_preferred_path_separator) != origin.length() - 1) {
return origin + k_preferred_path_separator + append;
}
return origin + append;
}
std::string clean_relative_path(const std::string& path) {
if (path.empty()) {
return ".";
}
std::string out;
size_t n = path.size();
size_t r = 0;
size_t dotdot = 0;
while (r < n) {
if (path[r] == k_preferred_path_separator) {
r++;
continue;
}
if (path[r] == '.' && (r + 1 == n || path[r + 1] == k_preferred_path_separator)) {
r++;
continue;
}
if (path[r] == '.' && path[r + 1] == '.' && (r + 2 == n || path[r + 2] == k_preferred_path_separator)) {
r += 2;
if (out.size() > dotdot) {
while (out.size() > dotdot && out.back() != k_preferred_path_separator) {
out.pop_back();
}
if (!out.empty())
out.pop_back();
} else {
if (!out.empty()) {
out.push_back(k_preferred_path_separator);
}
out.push_back('.');
out.push_back('.');
dotdot = out.size();
}
continue;
}
if (!out.empty() && out.back() != k_preferred_path_separator) {
out.push_back(k_preferred_path_separator);
}
for (; r < n && path[r] != k_preferred_path_separator; r++) {
out.push_back(path[r]);
}
}
if (out.empty()) {
out.push_back('.');
}
return out;
}
#endif
} // namespace ONNX_NAMESPACE

@@ -0,0 +1,64 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Copyright (c) ONNX Project Contributors.
#pragma once
#include <string>
#ifdef _WIN32
// windows.h has preprocessor definitions for min and max, which prevent the use of std::min and std::max.
// Defining NOMINMAX disables those macros.
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <filesystem>
#include "onnx/checker.h"
#endif
namespace ONNX_NAMESPACE {
#ifdef _WIN32
constexpr const char k_preferred_path_separator = '\\';
#else // POSIX
constexpr const char k_preferred_path_separator = '/';
#endif
#ifdef _WIN32
inline std::wstring path_join(const std::wstring& origin, const std::wstring& append) {
return (std::filesystem::path(origin) / std::filesystem::path(append)).wstring();
}
inline std::wstring utf8str_to_wstring(const std::string& utf8str) {
if (utf8str.size() > INT_MAX) {
fail_check("utf8str_to_wstring: string is too long for converting to wstring.");
}
int size_required = MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), static_cast<int>(utf8str.size()), NULL, 0);
std::wstring ws_str(size_required, 0);
MultiByteToWideChar(CP_UTF8, 0, utf8str.c_str(), static_cast<int>(utf8str.size()), &ws_str[0], size_required);
return ws_str;
}
inline std::string wstring_to_utf8str(const std::wstring& ws_str) {
if (ws_str.size() > INT_MAX) {
fail_check("wstring_to_utf8str: string is too long for converting to UTF-8.");
}
int size_required =
WideCharToMultiByte(CP_UTF8, 0, ws_str.c_str(), static_cast<int>(ws_str.size()), NULL, 0, NULL, NULL);
std::string utf8str(size_required, 0);
WideCharToMultiByte(
CP_UTF8, 0, ws_str.c_str(), static_cast<int>(ws_str.size()), &utf8str[0], size_required, NULL, NULL);
return utf8str;
}
#else
std::string path_join(const std::string& origin, const std::string& append);
// TODO: also use std::filesystem::path for clean_relative_path once ONNX requires C++17 on POSIX
// Cleans up a relative path containing "..", e.g. a/b/../c -> a/c
// It does not work with absolute paths
std::string clean_relative_path(const std::string& path);
#endif
} // namespace ONNX_NAMESPACE
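
A sketch (POSIX branch) that combines the two helpers, e.g. to resolve a tensor's relative external-data location against the model directory; the function and example paths are illustrative.

#include <string>
#include "onnx/common/path.h"

inline std::string ResolveRelativeLocation(const std::string& model_dir, const std::string& location) {
  // e.g. model_dir = "models/resnet", location = "../weights/./w.bin"
  //   path_join           -> "models/resnet/../weights/./w.bin"
  //   clean_relative_path -> "models/weights/w.bin"
  return ONNX_NAMESPACE::clean_relative_path(ONNX_NAMESPACE::path_join(model_dir, location));
}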

@@ -0,0 +1,19 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <cstdint>
namespace ONNX_NAMESPACE {
// Determine if the processor is little endian or not
inline bool is_processor_little_endian() {
constexpr std::int32_t value = 1;
return reinterpret_cast<const std::uint8_t*>(&value)[0] == 1;
}
} // namespace ONNX_NAMESPACE
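
A tiny sketch of where this check matters: ONNX stores raw tensor bytes (TensorProto.raw_data) in little-endian order, so a big-endian host must byte-swap before reinterpreting them. The helper name and include path are assumptions.

#include "onnx/common/platform_helpers.h" // assumed location of this header

inline bool NeedsByteSwapForRawData() {
  // raw_data is little-endian on the wire; swap only on big-endian hosts.
  return !ONNX_NAMESPACE::is_processor_little_endian();
}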

@@ -0,0 +1,43 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <string>
#include "onnx/common/constants.h"
#include "onnx/onnx_pb.h"
namespace ONNX_NAMESPACE {
// ONNX (model-local) function identifiers are a tuple (domain, op, overload).
// The pair (domain, op) represents a specification of a function, while
// overload is used to disambiguate between multiple (specialized) implementations of
// the same specification. Overload is optional and can be empty.
// Multiple overloads may be used to distinguish implementations specialized
// for a specific type or rank of input tensors or for specific attribute values.
// A single string representation of (domain, op)
using FunctionSpecId = std::string;
// A single string representation of (domain, op, overload)
using FunctionImplId = std::string;
inline FunctionImplId GetFunctionImplId(const std::string& domain, const std::string& op, const std::string& overload) {
if (overload.empty())
return NormalizeDomain(domain) + "::" + op;
return NormalizeDomain(domain) + "::" + op + "::" + overload;
}
inline FunctionImplId GetFunctionImplId(const FunctionProto& function) {
return GetFunctionImplId(function.domain(), function.name(), function.overload());
}
inline FunctionImplId GetCalleeId(const NodeProto& node) {
return GetFunctionImplId(node.domain(), node.op_type(), node.overload());
}
} // namespace ONNX_NAMESPACE
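
A few concrete examples of the identifier forms described above (op and domain names are placeholders; <cassert> is used only for illustration, and the header above is assumed to be included).

#include <cassert>

inline void FunctionIdExamples() {
  using namespace ONNX_NAMESPACE;
  // The default domain ("ai.onnx" or "") normalizes to the empty string.
  assert(GetFunctionImplId("ai.onnx", "Gelu", "") == "::Gelu");
  assert(GetFunctionImplId("", "Gelu", "approx") == "::Gelu::approx");
  // Custom domains are kept verbatim.
  assert(GetFunctionImplId("com.example", "MyOp", "") == "com.example::MyOp");
}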

@@ -0,0 +1,89 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "status.h"
#include <assert.h>
#include "onnx/string_utils.h"
namespace ONNX_NAMESPACE {
namespace Common {
Status::Status(StatusCategory category, int code, const std::string& msg) {
assert(static_cast<int>(StatusCode::OK) != code);
state_.reset(new State(category, code, msg));
}
Status::Status(StatusCategory category, int code) : Status(category, code, EmptyString()) {}
bool Status::IsOK() const noexcept {
return (state_ == nullptr);
}
StatusCategory Status::Category() const noexcept {
return IsOK() ? StatusCategory::NONE : state_->category;
}
int Status::Code() const noexcept {
return IsOK() ? static_cast<int>(StatusCode::OK) : state_->code;
}
const std::string& Status::ErrorMessage() const {
return IsOK() ? EmptyString() : state_->msg;
}
std::string Status::ToString() const {
if (state_ == nullptr) {
return std::string("OK");
}
std::string result;
if (StatusCategory::CHECKER == state_->category) {
result += "[CheckerError]";
} else if (StatusCategory::OPTIMIZER == state_->category) {
result += "[OptimizerError]";
}
result += " : ";
result += ONNX_NAMESPACE::to_string(Code());
std::string msg;
switch (static_cast<StatusCode>(Code())) {
case INVALID_ARGUMENT:
msg = "INVALID_ARGUMENT";
break;
case INVALID_PROTOBUF:
msg = "INVALID_PROTOBUF";
break;
case FAIL:
msg = "FAIL";
break;
default:
msg = "GENERAL ERROR";
break;
}
result += " : ";
result += msg;
result += " : ";
result += state_->msg;
return result;
}
const Status& Status::OK() noexcept {
static Status s_ok;
return s_ok;
}
const std::string& Status::EmptyString() {
static std::string empty_str;
return empty_str;
}
} // namespace Common
} // namespace ONNX_NAMESPACE

@@ -0,0 +1,96 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <memory>
#include <ostream>
#include <string>
#include <utility>
namespace ONNX_NAMESPACE {
namespace Common {
enum StatusCategory {
NONE = 0,
CHECKER = 1,
OPTIMIZER = 2,
};
enum StatusCode {
OK = 0,
FAIL = 1,
INVALID_ARGUMENT = 2,
INVALID_PROTOBUF = 3,
};
class Status {
public:
Status() noexcept {}
Status(StatusCategory category, int code, const std::string& msg);
Status(StatusCategory category, int code);
Status(const Status& other) {
*this = other;
}
void operator=(const Status& other) {
if (&other != this) {
if (nullptr == other.state_) {
state_.reset();
} else if (state_ != other.state_) {
state_.reset(new State(*other.state_));
}
}
}
Status(Status&&) = default;
Status& operator=(Status&&) = default;
~Status() = default;
bool IsOK() const noexcept;
int Code() const noexcept;
StatusCategory Category() const noexcept;
const std::string& ErrorMessage() const;
std::string ToString() const;
bool operator==(const Status& other) const {
return (this->state_ == other.state_) || (ToString() == other.ToString());
}
bool operator!=(const Status& other) const {
return !(*this == other);
}
static const Status& OK() noexcept;
private:
struct State {
State(StatusCategory cat_, int code_, std::string msg_) : category(cat_), code(code_), msg(std::move(msg_)) {}
StatusCategory category = StatusCategory::NONE;
int code = 0;
std::string msg;
};
static const std::string& EmptyString();
// state_ is nullptr when the status code is OK.
std::unique_ptr<State> state_;
};
inline std::ostream& operator<<(std::ostream& out, const Status& status) {
return out << status.ToString();
}
} // namespace Common
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,271 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.
#pragma once
#include <cmath>
#include <functional>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "onnx/common/assertions.h"
#include "onnx/onnx_pb.h"
#include "onnx/string_utils.h"
namespace ONNX_NAMESPACE {
struct Tensor final {
private:
bool is_segment_;
int64_t segment_begin_;
int64_t segment_end_;
bool has_name_;
std::string name_;
int32_t elem_type_;
std::vector<int64_t> sizes_;
std::vector<float> float_data_;
std::vector<double> double_data_;
std::vector<int32_t> int32_data_;
std::vector<int64_t> int64_data_;
std::vector<uint64_t> uint64_data_;
std::vector<std::string> string_data_;
bool is_raw_data_;
std::string raw_data_;
public:
Tensor()
: is_segment_(false),
segment_begin_(0),
segment_end_(0),
has_name_(false),
elem_type_(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED),
is_raw_data_(false) {}
Tensor(const Tensor& other)
: is_segment_(other.is_segment_),
segment_begin_(other.segment_begin_),
segment_end_(other.segment_end_),
has_name_(other.has_name_),
elem_type_(other.elem_type_),
sizes_(other.sizes_),
float_data_(other.float_data_),
double_data_(other.double_data_),
int32_data_(other.int32_data_),
int64_data_(other.int64_data_),
uint64_data_(other.uint64_data_),
is_raw_data_(other.is_raw_data_) {
// Deep copy to avoid copy-on-write string sharing with gcc < 5.0.
string_data_.resize(other.string_data_.size());
for (unsigned int i = 0; i < other.string_data_.size(); ++i) {
string_data_[i] = std::string(other.string_data_[i].data(), other.string_data_[i].size());
}
name_ = std::string(other.name_.data(), other.name_.size());
raw_data_ = std::string(other.raw_data_.data(), other.raw_data_.size());
}
Tensor(Tensor&&) = default;
~Tensor() = default;
friend void swap(Tensor& first, Tensor& second) {
using std::swap;
swap(first.is_segment_, second.is_segment_);
swap(first.segment_begin_, second.segment_begin_);
swap(first.segment_end_, second.segment_end_);
swap(first.has_name_, second.has_name_);
swap(first.name_, second.name_);
swap(first.elem_type_, second.elem_type_);
swap(first.sizes_, second.sizes_);
swap(first.float_data_, second.float_data_);
swap(first.double_data_, second.double_data_);
swap(first.int32_data_, second.int32_data_);
swap(first.int64_data_, second.int64_data_);
swap(first.uint64_data_, second.uint64_data_);
swap(first.is_raw_data_, second.is_raw_data_);
swap(first.string_data_, second.string_data_);
swap(first.raw_data_, second.raw_data_);
}
Tensor& operator=(Tensor other) noexcept {
swap(*this, other);
return *this;
}
const std::vector<int64_t>& sizes() const {
return sizes_;
}
std::vector<int64_t>& sizes() {
return sizes_;
}
/// If the tensor is a scalar, sizes() is empty, but the element count is actually 1.
/// size_from_dim() cannot handle this case, while elem_num() handles it correctly.
int64_t elem_num() const {
return std::accumulate(sizes_.begin(), sizes_.end(), (int64_t)1, std::multiplies<int64_t>{});
}
int64_t size_from_dim(int dim) const {
if (dim < 0) {
dim += (int)sizes_.size();
}
ONNX_ASSERT(dim >= 0 && (size_t)dim < sizes_.size());
return std::accumulate(sizes_.begin() + dim, sizes_.end(), (int64_t)1, std::multiplies<int64_t>{});
}
int32_t elem_type() const {
return elem_type_;
}
int32_t& elem_type() {
return elem_type_;
}
std::vector<std::string>& strings() {
return string_data_;
}
const std::vector<std::string>& strings() const {
return string_data_;
}
std::vector<float>& floats() {
return float_data_;
}
const std::vector<float>& floats() const {
return float_data_;
}
std::vector<double>& doubles() {
return double_data_;
}
const std::vector<double>& doubles() const {
return double_data_;
}
std::vector<int32_t>& int32s() {
return int32_data_;
}
const std::vector<int32_t>& int32s() const {
return int32_data_;
}
std::vector<int64_t>& int64s() {
return int64_data_;
}
const std::vector<int64_t>& int64s() const {
return int64_data_;
}
std::vector<uint64_t>& uint64s() {
return uint64_data_;
}
const std::vector<uint64_t>& uint64s() const {
return uint64_data_;
}
const std::string& raw() const {
return raw_data_;
}
void set_raw_data(std::string raw_data) {
is_raw_data_ = true;
raw_data_ = std::move(raw_data);
}
template <typename T>
T* data();
template <typename T>
const T* data() const;
bool is_segment() const {
return is_segment_;
}
int64_t segment_begin() const {
return segment_begin_;
}
int64_t segment_end() const {
return segment_end_;
}
void set_segment_begin_and_end(int64_t begin, int64_t end) {
is_segment_ = true;
segment_begin_ = begin;
segment_end_ = end;
}
bool hasName() const {
return has_name_;
}
const std::string& name() const {
return name_;
}
void setName(std::string name) {
has_name_ = true;
name_ = std::move(name);
}
bool is_raw_data() const {
return is_raw_data_;
}
};
template <>
inline std::string* Tensor::data<std::string>() {
ONNX_ASSERTM(
!is_raw_data(),
"data type is string. string content is required to be stored in repeated bytes string_data field."
"raw_data type cannot be string.");
return string_data_.data();
}
template <>
inline const std::string* Tensor::data<std::string>() const {
ONNX_ASSERTM(
!is_raw_data(),
"data type is string. string content is required to be stored in repeated bytes string_data field."
"raw_data type cannot be string.");
return string_data_.data();
}
#define define_data(type, field) \
template <> \
inline type* Tensor::data<type>() { \
if (is_raw_data_) { \
return (type*)const_cast<char*>(&raw_data_.data()[0]); \
} else { \
return field.data(); \
} \
} \
\
template <> \
inline const type* Tensor::data<type>() const { \
if (is_raw_data_) { \
return (const type*)(raw_data_.data()); \
} else { \
return field.data(); \
} \
}
define_data(float, float_data_);
define_data(double, double_data_);
define_data(int32_t, int32_data_);
define_data(int64_t, int64_data_);
define_data(uint64_t, uint64_data_);
#undef define_data
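// Illustrative usage sketch, not part of the original ONNX sources: it builds a small
// float tensor through the typed accessors and reads it back via data<float>().
// ExampleTensorUsage is a hypothetical helper name introduced only for this illustration.
inline float ExampleTensorUsage() {
  Tensor t;
  t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
  t.sizes() = {2, 2};                     // 2x2 tensor, so elem_num() == 4
  t.floats() = {1.0f, 2.0f, 3.0f, 4.0f};  // typed storage; is_raw_data() stays false
  // For a scalar, sizes() would be empty while elem_num() would still be 1.
  const float* p = t.data<float>();       // same buffer as floats() here
  // If set_raw_data() had been used instead, data<float>() would reinterpret the raw bytes.
  return p[0] + p[3];                     // 1.0f + 4.0f
}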
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,14 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
namespace ONNX_NAMESPACE {
// Represents the most recent release version. Updated with every release.
constexpr const char* LAST_RELEASE_VERSION = "1.17.0";
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,129 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include "onnx/common/common.h"
#include "onnx/onnx_pb.h"
namespace ONNX_NAMESPACE {
namespace internal {
// Visitor: a read-only visitor class for ONNX proto objects.
// This class is restricted to Nodes, Graphs, Attributes, and Functions.
// Each VisitX method invokes ProcessX and, if that returns true,
// continues by visiting all children of that X.
struct Visitor {
virtual void VisitGraph(const GraphProto& graph) {
if (ProcessGraph(graph))
for (auto& node : graph.node())
VisitNode(node);
}
virtual void VisitFunction(const FunctionProto& function) {
if (ProcessFunction(function))
for (auto& node : function.node())
VisitNode(node);
}
virtual void VisitNode(const NodeProto& node) {
if (ProcessNode(node)) {
for (auto& attr : node.attribute()) {
VisitAttribute(attr);
}
}
}
virtual void VisitAttribute(const AttributeProto& attr) {
if (ProcessAttribute(attr)) {
if (attr.has_g()) {
VisitGraph(attr.g());
}
for (auto& graph : attr.graphs())
VisitGraph(graph);
}
}
virtual bool ProcessGraph(const GraphProto& graph) {
ONNX_UNUSED_PARAMETER(graph);
return true;
}
virtual bool ProcessFunction(const FunctionProto& function) {
ONNX_UNUSED_PARAMETER(function);
return true;
}
virtual bool ProcessNode(const NodeProto& node) {
ONNX_UNUSED_PARAMETER(node);
return true;
}
virtual bool ProcessAttribute(const AttributeProto& attr) {
ONNX_UNUSED_PARAMETER(attr);
return true;
}
virtual ~Visitor() {}
};
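// Illustrative sketch, not part of the original ONNX sources: a minimal Visitor
// subclass that counts every NodeProto reachable from a graph, including nodes inside
// subgraph attributes. NodeCounter is a hypothetical name used only for illustration.
struct NodeCounter : public Visitor {
  int node_count = 0;
  bool ProcessNode(const NodeProto& node) override {
    ONNX_UNUSED_PARAMETER(node);
    ++node_count;
    return true;  // keep descending into the node's attributes (and their subgraphs)
  }
};
// Typical use: NodeCounter counter; counter.VisitGraph(model.graph()); then read counter.node_count.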
// MutableVisitor: A version of Visitor that allows mutation of the visited objects.
struct MutableVisitor {
virtual void VisitGraph(GraphProto* graph) {
if (ProcessGraph(graph))
for (auto& node : *(graph->mutable_node()))
VisitNode(&node);
}
virtual void VisitFunction(FunctionProto* function) {
if (ProcessFunction(function))
for (auto& node : *(function->mutable_node()))
VisitNode(&node);
}
virtual void VisitNode(NodeProto* node) {
if (ProcessNode(node)) {
for (auto& attr : *(node->mutable_attribute())) {
VisitAttribute(&attr);
}
}
}
virtual void VisitAttribute(AttributeProto* attr) {
if (ProcessAttribute(attr)) {
if (attr->has_g()) {
VisitGraph(attr->mutable_g());
}
for (auto& graph : *(attr->mutable_graphs()))
VisitGraph(&graph);
}
}
virtual bool ProcessGraph(GraphProto* graph) {
ONNX_UNUSED_PARAMETER(graph);
return true;
}
virtual bool ProcessFunction(FunctionProto* function) {
ONNX_UNUSED_PARAMETER(function);
return true;
}
virtual bool ProcessNode(NodeProto* node) {
ONNX_UNUSED_PARAMETER(node);
return true;
}
virtual bool ProcessAttribute(AttributeProto* attr) {
ONNX_UNUSED_PARAMETER(attr);
return true;
}
virtual ~MutableVisitor() {}
};
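// Illustrative sketch, not part of the original ONNX sources: a minimal MutableVisitor
// subclass that clears the doc_string of every node it reaches. DocStripper is a
// hypothetical name used only for illustration; doc_string is a standard NodeProto field.
struct DocStripper : public MutableVisitor {
  bool ProcessNode(NodeProto* node) override {
    node->clear_doc_string();
    return true;  // continue into subgraph attributes as well
  }
};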
} // namespace internal
} // namespace ONNX_NAMESPACE