I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,144 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
__all__ = [
"C",
"ONNX_DOMAIN",
"ONNX_ML_DOMAIN",
"AI_ONNX_PREVIEW_TRAINING_DOMAIN",
"has",
"register_schema",
"deregister_schema",
"get_schema",
"get_all_schemas",
"get_all_schemas_with_history",
"onnx_opset_version",
"get_function_ops",
"OpSchema",
"SchemaError",
]
import onnx.onnx_cpp2py_export.defs as C # noqa: N812
from onnx import AttributeProto, FunctionProto
ONNX_DOMAIN = ""
ONNX_ML_DOMAIN = "ai.onnx.ml"
AI_ONNX_PREVIEW_TRAINING_DOMAIN = "ai.onnx.preview.training"
has = C.has_schema
get_schema = C.get_schema
get_all_schemas = C.get_all_schemas
get_all_schemas_with_history = C.get_all_schemas_with_history
deregister_schema = C.deregister_schema
def onnx_opset_version() -> int:
"""Return current opset for domain `ai.onnx`."""
return C.schema_version_map()[ONNX_DOMAIN][1]
def onnx_ml_opset_version() -> int:
"""Return current opset for domain `ai.onnx.ml`."""
return C.schema_version_map()[ONNX_ML_DOMAIN][1]
@property # type: ignore
def _function_proto(self): # type: ignore
func_proto = FunctionProto()
func_proto.ParseFromString(self._function_body)
return func_proto
OpSchema = C.OpSchema # type: ignore
OpSchema.function_body = _function_proto # type: ignore
@property # type: ignore
def _attribute_default_value(self): # type: ignore
attr = AttributeProto()
attr.ParseFromString(self._default_value)
return attr
OpSchema.Attribute.default_value = _attribute_default_value # type: ignore
def _op_schema_repr(self) -> str:
return f"""\
OpSchema(
name={self.name!r},
domain={self.domain!r},
since_version={self.since_version!r},
doc={self.doc!r},
type_constraints={self.type_constraints!r},
inputs={self.inputs!r},
outputs={self.outputs!r},
attributes={self.attributes!r}
)"""
OpSchema.__repr__ = _op_schema_repr # type: ignore
def _op_schema_formal_parameter_repr(self) -> str:
return (
f"OpSchema.FormalParameter(name={self.name!r}, type_str={self.type_str!r}, "
f"description={self.description!r}, param_option={self.option!r}, "
f"is_homogeneous={self.is_homogeneous!r}, min_arity={self.min_arity!r}, "
f"differentiation_category={self.differentiation_category!r})"
)
OpSchema.FormalParameter.__repr__ = _op_schema_formal_parameter_repr # type: ignore
def _op_schema_type_constraint_param_repr(self) -> str:
return (
f"OpSchema.TypeConstraintParam(type_param_str={self.type_param_str!r}, "
f"allowed_type_strs={self.allowed_type_strs!r}, description={self.description!r})"
)
OpSchema.TypeConstraintParam.__repr__ = _op_schema_type_constraint_param_repr # type: ignore
def _op_schema_attribute_repr(self) -> str:
return (
f"OpSchema.Attribute(name={self.name!r}, type={self.type!r}, description={self.description!r}, "
f"default_value={self.default_value!r}, required={self.required!r})"
)
OpSchema.Attribute.__repr__ = _op_schema_attribute_repr # type: ignore
def get_function_ops() -> list[OpSchema]:
"""Return operators defined as functions."""
schemas = C.get_all_schemas()
return [schema for schema in schemas if schema.has_function or schema.has_context_dependent_function] # type: ignore
SchemaError = C.SchemaError
def register_schema(schema: OpSchema) -> None:
"""Register a user provided OpSchema.
The function extends available operator set versions for the provided domain if necessary.
Args:
schema: The OpSchema to register.
"""
version_map = C.schema_version_map()
domain = schema.domain
version = schema.since_version
min_version, max_version = version_map.get(domain, (version, version))
if domain not in version_map or not (min_version <= version <= max_version):
min_version = min(min_version, version)
max_version = max(max_version, version)
C.set_domain_to_version(schema.domain, min_version, max_version)
C.register_schema(schema)

View File

@ -0,0 +1,71 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "attr_proto_util.h"
#include <string>
#include <vector>
namespace ONNX_NAMESPACE {
#define ADD_BASIC_ATTR_IMPL(type, enumType, field) \
AttributeProto MakeAttribute(const std::string& attr_name, const type& value) { \
AttributeProto a; \
a.set_name(attr_name); \
a.set_type(enumType); \
a.set_##field(value); \
return a; \
}
#define ADD_ATTR_IMPL(type, enumType, field) \
AttributeProto MakeAttribute(const std::string& attr_name, const type& value) { \
AttributeProto a; \
a.set_name(attr_name); \
a.set_type(enumType); \
*(a.mutable_##field()) = value; \
return a; \
}
#define ADD_LIST_ATTR_IMPL(type, enumType, field) \
AttributeProto MakeAttribute(const std::string& attr_name, const std::vector<type>& values) { \
AttributeProto a; \
a.set_name(attr_name); \
a.set_type(enumType); \
for (const auto& val : values) { \
*(a.mutable_##field()->Add()) = val; \
} \
return a; \
}
ADD_BASIC_ATTR_IMPL(float, AttributeProto_AttributeType_FLOAT, f)
ADD_BASIC_ATTR_IMPL(int64_t, AttributeProto_AttributeType_INT, i)
ADD_BASIC_ATTR_IMPL(std::string, AttributeProto_AttributeType_STRING, s)
ADD_ATTR_IMPL(TensorProto, AttributeProto_AttributeType_TENSOR, t)
ADD_ATTR_IMPL(GraphProto, AttributeProto_AttributeType_GRAPH, g)
ADD_ATTR_IMPL(TypeProto, AttributeProto_AttributeType_TYPE_PROTO, tp)
ADD_LIST_ATTR_IMPL(float, AttributeProto_AttributeType_FLOATS, floats)
ADD_LIST_ATTR_IMPL(int64_t, AttributeProto_AttributeType_INTS, ints)
ADD_LIST_ATTR_IMPL(std::string, AttributeProto_AttributeType_STRINGS, strings)
ADD_LIST_ATTR_IMPL(TensorProto, AttributeProto_AttributeType_TENSORS, tensors)
ADD_LIST_ATTR_IMPL(GraphProto, AttributeProto_AttributeType_GRAPHS, graphs)
ADD_LIST_ATTR_IMPL(TypeProto, AttributeProto_AttributeType_TYPE_PROTOS, type_protos)
AttributeProto MakeRefAttribute(const std::string& attr_name, AttributeProto_AttributeType type) {
return MakeRefAttribute(attr_name, attr_name, type);
}
AttributeProto MakeRefAttribute(
const std::string& attr_name,
const std::string& referred_attr_name,
AttributeProto_AttributeType type) {
AttributeProto a;
a.set_name(attr_name);
a.set_ref_attr_name(referred_attr_name);
a.set_type(type);
return a;
}
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,43 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <string>
#include <vector>
#include "onnx/onnx-operators_pb.h"
namespace ONNX_NAMESPACE {
AttributeProto MakeAttribute(const std::string& attr_name, const float& value);
AttributeProto MakeAttribute(const std::string& attr_name, const int64_t& value);
AttributeProto MakeAttribute(const std::string& attr_name, const std::string& value);
AttributeProto MakeAttribute(const std::string& attr_name, const TensorProto& value);
AttributeProto MakeAttribute(const std::string& attr_name, const GraphProto& value);
AttributeProto MakeAttribute(const std::string& attr_name, const std::vector<float>& values);
AttributeProto MakeAttribute(const std::string& attr_name, const std::vector<int64_t>& values);
AttributeProto MakeAttribute(const std::string& attr_name, const std::vector<std::string>& values);
AttributeProto MakeAttribute(const std::string& attr_name, const std::vector<TensorProto>& values);
AttributeProto MakeAttribute(const std::string& attr_name, const std::vector<GraphProto>& values);
// Make a "reference" attribute for a node in a function body.
// <attr_name> specifies the attribute name of both the function node and its
// function body node. They're using the same attribute name.
// <type> specifies the attribute type.
AttributeProto MakeRefAttribute(const std::string& attr_name, AttributeProto_AttributeType type);
// Make a "reference" attribute for a node in a function body.
// <attr_name> specifies the attribute name of the function body node.
// <referred_attr_name> specifies the referred attribute name of the function
// node.
// <type> specifies the attribute type.
AttributeProto MakeRefAttribute(
const std::string& attr_name,
const std::string& referred_attr_name,
AttributeProto_AttributeType type);
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,454 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <assert.h>
#include "onnx/defs/controlflow/utils.h"
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
using SupportType = OpSchema::SupportType;
static std::vector<std::string> control_flow_types_ir10() {
auto t = OpSchema::all_tensor_types_ir10();
auto s = OpSchema::all_tensor_sequence_types_ir10();
auto o = OpSchema::all_optional_types_ir10();
t.insert(t.end(), s.begin(), s.end());
t.insert(t.end(), o.begin(), o.end());
return t;
}
ONNX_OPERATOR_SET_SCHEMA(
If,
21,
OpSchema()
.SetDoc("If conditional")
.Input(0, "cond", "Condition for the if. The tensor must contain a single element.", "B")
.Output(
0,
"outputs",
"Values that are live-out to the enclosing scope. The return values in "
"the `then_branch` and `else_branch` must be of the same data type. "
"The `then_branch` and `else_branch` may produce tensors with the same "
"element type and different shapes. "
"If corresponding outputs from the then-branch and the else-branch have "
"static shapes S1 and S2, then the shape of the corresponding output "
"variable of the if-node (if present) must be compatible with both S1 "
"and S2 as it represents the union of both possible shapes."
"For example, if in a model file, the first "
"output of `then_branch` is typed float tensor with shape [2] and the "
"first output of `else_branch` is another float tensor with shape [3], "
"If's first output should have (a) no shape set, or (b) "
"a shape of rank 1 with neither `dim_value` nor `dim_param` set, or (c) "
"a shape of rank 1 with a unique `dim_param`. "
"In contrast, the first output cannot have the shape [2] since [2] and "
"[3] are not compatible.",
"V",
OpSchema::Variadic,
false)
.Attr(
"then_branch",
"Graph to run if condition is true. Has N outputs: values you wish to "
"be live-out to the enclosing scope. The number of outputs must match"
" the number of outputs in the else_branch.",
AttributeProto::GRAPH)
.Attr(
"else_branch",
"Graph to run if condition is false. Has N outputs: values you wish to"
" be live-out to the enclosing scope. The number of outputs must match"
" the number of outputs in the then_branch.",
AttributeProto::GRAPH)
.TypeConstraint(
"V",
control_flow_types_ir10(),
"All Tensor, Sequence(Tensor), Optional(Tensor), and Optional(Sequence(Tensor)) types up to IRv10.")
.TypeConstraint("B", {"tensor(bool)"}, "Only bool")
.TypeAndShapeInferenceFunction(IfInferenceFunction));
static const char* Loop_ver16_doc = R"DOC(
Generic Looping construct. This loop has multiple termination conditions:
1) Trip count. Iteration count specified at runtime. Set by
specifying the input M. Optional. Set to empty string to omit.
Note that a static trip count (specified at graph construction time) can be
specified by passing in a constant node for input M.
2) Loop termination condition. This is an input to the op that determines
whether to run the first iteration and also a loop-carried dependency for
the body graph. The body graph must yield a value for the condition variable,
whether this input is provided or not.
This table summarizes the operating modes of this operator with equivalent
C-style code:
Operator inputs defined as (max_trip_count, condition_var).
* input ("", ""):
for (int i=0; ; ++i) {
cond = ... // Note this value is ignored, but is required in the body
}
* input ("", cond) // Note this is analogous to a while loop
bool cond = ...;
for (int i=0; cond; ++i) {
cond = ...;
}
* input ("", 1) // Note this is analogous to a do-while loop
bool cond = true
for (int i=0; cond; ++i) {
cond = ...;
}
* input (trip_count, "") // Note this is analogous to a for loop
int trip_count = ...
for (int i=0; i < trip_count; ++i) {
cond = ...; // ignored
}
* input (trip_count, cond)
int trip_count = ...;
bool cond = ...;
for (int i=0; i < trip_count && cond; ++i) {
cond = ...;
}
*Sample usage - cond as well as trip count*
graph predict-net {
%a = Constant[value = <Scalar Tensor [3]>]()
%b = Constant[value = <Scalar Tensor [6]>]()
%keepgoing = Constant[value = <Scalar Tensor [1]>]()
%max_trip_count = Constant[value = <Scalar Tensor [10]>]()
%keepgoing_out, %b_out, %user_defined_vals = Loop[body = <graph body-net>](%max_trip_count, %keepgoing, %b)
return
}
graph body-net (
%i[INT32, scalar] // iteration number
%keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used
%b_in[INT32, scalar] // incoming value of loop-carried-dependency b
) {
%my_local = Add(%a, %b_in)
%b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b
%keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition
%user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated
return %keepgoing_out, %b_out, %user_defined_val
}
*Sample equivalent C code*
{
/* User-defined code (enclosing scope) */
int a = 3, b = 6;
bool keepgoing = true; // Analogous to input cond
/* End user-defined code */
/* Implicitly-defined code */
const int max_trip_count = 10; // Analogous to input M
int user_defined_vals[]; // Imagine this is resizable
/* End implicitly-defined code */
/* initialize loop-carried variables and scan-output variables */
bool keepgoing_out = keepgoing
int b_out = b
for (int i=0; i < max_trip_count && keepgoing_out; ++i) {
/* Implicitly-defined code: bind actual parameter values
to formal parameter variables of loop-body */
bool keepgoing_in = keepgoing_out;
bool b_in = b_out;
/* User-defined code (loop body) */
int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine
b_out = a - b_in;
keepgoing_out = my_local > b_out;
user_defined_val = b_in + b_in; // b_in and b_out are different variables
/* End user-defined code */
/* Implicitly defined-code */
user_defined_vals[i] = user_defined_val // accumulate scan-output values
}
// int t = my_local; // Can't do this. my_local is not accessible here.
// The values below are bound to the output variables of the loop and therefore accessible
// b_out; user_defined_vals; keepgoing_out;
}
There are several things of note in this code snippet:
1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can
be referenced in the inputs of the loop.
2) Any values computed in the loop body that needs to be used in a subsequent
iteration or after the loop are modelled using a pair of variables in the loop-body,
consisting of an input variable (eg., b_in) and an output variable (eg., b_out).
These are referred to as loop-carried dependences. The loop operation node
supplies the input value of the input variable for the first iteration, and
returns the output value of the output variable produced by the final
iteration.
3) Scan_output variables are used to implicitly concatenate values computed across
all the iterations. In the above example, the value of user_defined_val computed
over all iterations are concatenated and returned as the value of user_defined_vals
after the loop.
4) Values created in the body cannot be accessed in the enclosing scope,
except using the mechanism described above.
Note that the semantics of this op support "diagonal" or "wavefront" execution.
(See Step 3 here for an example:
https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/).
Frontends should emit multi-layer RNNs as a series of While operators (with
time being the inner looping dimension), with each successive layer consuming
the scan_outputs from the previous layer, possibly going through several
point-wise operators (e.g. dropout, residual connections, linear layer).
The input/output of subgraph (produced by loop node) matching is based on order instead of name. The implementation will figure out the names based on this order.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
Loop,
21,
OpSchema()
.SetDoc(Loop_ver16_doc)
.Input(
0,
"M",
"A maximum trip-count for the loop specified at runtime. Optional."
" Pass empty string to skip.",
"I",
OpSchema::Optional)
.Input(
1,
"cond",
"A boolean termination condition. Optional. Pass empty string to skip.",
"B",
OpSchema::Optional)
.Input(
2,
"v_initial",
"The initial values of any loop-carried dependencies (values that "
"change across loop iterations)",
"V",
OpSchema::Variadic,
false,
0)
.Output(
0,
"v_final_and_scan_outputs",
"Final N loop carried dependency values then K scan_outputs. "
"Scan outputs must be Tensors.",
"V",
OpSchema::Variadic,
false)
.Attr(
"body",
"The graph run each iteration. It has 2+N inputs: (iteration_num, "
"condition, loop carried dependencies...). It has 1+N+K outputs: "
"(condition, loop carried dependencies..., scan_outputs...). Each "
"scan_output is created by concatenating the value of the specified "
"output value at the end of each iteration of the loop. It is an error"
" if the dimensions or data type of these scan_outputs change across loop"
" iterations.",
AttributeProto::GRAPH)
.TypeConstraint(
"V",
control_flow_types_ir10(),
"All Tensor, Sequence(Tensor), Optional(Tensor), and Optional(Sequence(Tensor)) types up to IRv10.")
.TypeConstraint("I", {"tensor(int64)"}, "tensor of int64, which should be a scalar.")
.TypeConstraint("B", {"tensor(bool)"}, "tensor of bool, which should be a scalar.")
.TypeAndShapeInferenceFunction(LoopInferenceFunction));
static const char* scan_16_doc = R"DOC(
Scan can be used to iterate over one or more scan_input tensors,
constructing zero or more scan_output tensors. It combines ideas from general recurrences,
functional programming constructs such as scan, fold, map, and zip, and is intended to enable
generalizations of RNN-like constructs for sequence-to-sequence processing.
Other tensors (referred to as state_variables here) can be used to carry a state
when iterating from one element to another (similar to hidden-state in RNNs, also referred
to as loop-carried dependences in the context of loops).
Many common usages involve a single scan_input tensor (where functionality
similar to scan, fold and map can be obtained). When more than one scan_input is used,
a behavior similar to zip is obtained.
The attribute body must be a graph, specifying the computation to be performed in
every iteration. It takes as input the current values of the state_variables and
the current iterated element of the scan_inputs. It must return the (updated) values
of the state_variables and zero or more scan_output_element tensors. The values of the
scan_output_element tensors are concatenated over all the iterations to produce the
scan_output values of the scan construct (similar to the concatenated intermediate
hidden-state values of RNN-like constructs). All the output tensors (state_variables as
well as scan_output_element tensors) are required to have the same shape in each iteration
of the loop (a restriction imposed to enable efficient memory allocation).
Note that the iterated element passed to the body subgraph does not have a sequence
axis. It will have a rank one less than the rank of the corresponding scan_input.
The scan operation returns the final values of the state_variables as well as the
scan_outputs.
The optional attribute scan_input_directions specifies the direction (forward or backward)
for each scan input. If this attribute is omitted, all sequences are scanned in the forward
direction. A bidirectional scan may be performed by specifying the same tensor input twice
in the scan_inputs, once with a forward direction, and once with a backward direction.
The scan_output of the operation is produced by concatenating the scan_output_element
values produced by the body in each iteration. The optional attribute scan_output_directions
specifies the direction in which scan_output is constructed (by appending or prepending the
scan_output_element to scan_output in each iteration) for each scan_output. If this attribute
is omitted, the scan_output_element is appended to the scan_output in each iteration.
The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input.
If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the
batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1.
Note that scanning a non-zero axis may be less efficient than scanning axis zero.
The optional attribute scan_output_axes specifies the axis along which the scan_outputs
are accumulated for each scan_output. For example, if axis 1 is the time axis (to be
scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis
value of 1.
Note that because of the ONNX restriction that only the last parameter of an operator can
be variadic, the initial-states and scan-inputs are listed together as one input parameter.
Similarly, the final-states and scan-outputs are listed together as one output parameter.
The attribute num_scan_inputs indicates the number M of scan-inputs.
The behavior of
Scan <
num_scan_inputs = m,
body = loop-body,
scan_input_axes = [axis_1, ..., axis_m]
> (init_1, ..., init_n, scan_1, ..., scan_m)
is equivalent to the following pseudo-code:
// scan_i.shape[axis_i] denotes the (max) sequence-length of scan_i
// scan_i.shape[axis_i] is required to be equal to scan_j.shape[axis_j] for all i,j.
sequence_length = scan_1.shape[axis_1];
// initialize state-variables
st_1 = init_1; ... st_n = init_n;
// initialize scan-output variables: [] denotes an empty tensor
scan_out_1 = []; ...; scan_out_k = [];
// identify number of iterations:
// execute loop
for (int t = 0; t < sequence_length; ++t) {
// generate the scan-input elements: the notation T<axis=k>[t] indicates the sub-tensor
// of rank one less than T obtained by indexing T at position t along axis k.
si_1 = scan_1<axis=axis_1>[t];
... ;
si_m = scan_m<axis=axis_m>[t];
// execute loop-body
st_1, ..., st_n, so_1, ..., so_k = loop-body(st_1, ..., st_n, si_1, ..., si_m)
// accumulate the scan-output elements
scan_out_1 = Concat<axis=0>(scan_out_1, so_1); ... ; scan_out_k = Concat<axis=0>(scan_out_k, so_k);
}
return st_1, ..., st_n, scan_out_1, ..., scan_out_k;
*Sample usage: Encoding RNN using a Scan*
The following example shows how a simple RNN over an input tensor %X, with weight tensor %Wi,
recurrence weight tensor %Ri, bias tensors %Wbi and %Rbi, and initial hidden-state %H_0 can
be encoded as a ScanLoop. Note that the loop-body is a nested graph, and it directly computes
%Wi, %Ri, %Wbi, and %Rbi (typically constants or initializers in the body graph). If these
values are computed in the outer graph, they need to be passed in as extra state_variables.
graph rnn-encoding {
%H_0 = ...
%X = ...
%Y_h, %Y = Scan[body = <graph rnn-cell-1>, num_scan_inputs=1](%H_0, %X)
return %Y, %Y_h
}
graph rnn-cell-1 (
%H_tminus1[FLOAT, tensor]
%X_t[FLOAT, tensor]
) {
%Wi = ...
%Ri = ...
%Wbi = ...
%Rbi = ...
%t1 = X_t * (Wi^T)
%t2 = H_tminus1*(Ri^T)
%t3 = Add(%t1, %t2)
%t4 = Add(%t3, %Wbi)
%t5 = Add(%t4, %Rbi)
%Ht = Tanh(%t5)
%Accumulate = Identity(%Ht)
return %Ht, %Accumulate
}
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
Scan,
21,
OpSchema()
.SetDoc(scan_16_doc)
.Input(
0,
"initial_state_and_scan_inputs",
"Initial values of the loop's N state variables followed by M scan_inputs",
"V",
OpSchema::Variadic,
false)
.Output(
0,
"final_state_and_scan_outputs",
"Final values of the loop's N state variables followed by K scan_outputs",
"V",
OpSchema::Variadic,
false)
.Attr(
"body",
"The graph run each iteration. It has N+M inputs: "
"(loop state variables..., scan_input_elts...). It has N+K outputs: "
"(loop state variables..., scan_output_elts...). Each "
"scan_output is created by concatenating the value of the specified "
"scan_output_elt value at the end of each iteration of the loop. It is an error"
" if the dimensions of these values change across loop iterations.",
AttributeProto::GRAPH,
true)
.Attr("num_scan_inputs", "An attribute specifying the number of scan_inputs M. ", AttributeProto::INT, true)
.Attr(
"scan_input_directions",
"An optional list of M flags. The i-th element of the list specifies the direction "
"to be scanned for the i-th scan_input tensor: 0 indicates forward direction and 1 "
"indicates reverse direction. "
"If omitted, all scan_input tensors will be scanned in the forward direction.",
AttributeProto::INTS,
false)
.Attr(
"scan_output_directions",
"An optional list of K flags, one for each scan_output. The i-th element of the list "
"specifies whether the i-th scan_output should be constructed by appending or "
"prepending a new value in each iteration: 0 indicates appending and 1 "
"indicates prepending. "
"If omitted, all scan_output tensors will be produced by appending a value "
"in each iteration.",
AttributeProto::INTS,
false)
.Attr(
"scan_input_axes",
"An optional list of M flags. The i-th element of the list specifies the axis "
"to be scanned (the sequence axis) for the i-th scan_input. If omitted, 0 will "
"be used as the scan axis for every scan_input. Negative value for an axis means "
"counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(input).",
AttributeProto::INTS,
false)
.Attr(
"scan_output_axes",
"An optional list of K flags. The i-th element of the list specifies the axis "
"for the i-th scan_output. The scan outputs are accumulated along the specified "
"axis. If omitted, 0 will be used as the scan axis for every scan_output. "
"Negative value for an axis means counting dimensions from the back. Accepted "
"range is [-r, r-1].",
AttributeProto::INTS,
false)
.TypeConstraint("V", OpSchema::all_tensor_types_ir10(), "All Tensor types up to IRv10.")
.TypeAndShapeInferenceFunction(ScanInferenceFunction)); // Shares same shape inference as opset 11
} // namespace ONNX_NAMESPACE

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,359 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/controlflow/utils.h"
#include <string>
#include <vector>
namespace ONNX_NAMESPACE {
void ClearShape(TypeProto& input_type) {
if (input_type.has_tensor_type()) {
input_type.mutable_tensor_type()->clear_shape();
} else if (input_type.has_sequence_type()) {
auto& seq_type = *input_type.mutable_sequence_type();
if (seq_type.has_elem_type()) {
ClearShape(*(seq_type.mutable_elem_type()));
}
} else if (input_type.has_optional_type()) {
auto& opt_type = *input_type.mutable_optional_type();
if (opt_type.has_elem_type()) {
ClearShape(*(opt_type.mutable_elem_type()));
}
}
}
void IfInferenceFunction(InferenceContext& ctx) {
// there are no inputs so we just need to run the subgraph inferencing for
// then/else subgraphs and apply those to the outputs.
std::vector<const TypeProto*> subgraph_input_types; // none
std::vector<const TensorProto*> input_data; // none
std::vector<const TypeProto*> then_output_types;
std::vector<const TypeProto*> else_output_types;
// Run inferencing on the subgraph
GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("then_branch");
if (graphInferencer) {
then_output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
}
graphInferencer = ctx.getGraphAttributeInferencer("else_branch");
if (graphInferencer) {
else_output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
}
auto num_outputs = ctx.getNumOutputs();
auto num_then_outputs = then_output_types.size();
auto num_else_outputs = else_output_types.size();
// the output types for then and else should be the same
if (num_then_outputs != num_else_outputs) {
fail_type_inference(
"then_branch and else_branch produce different number of outputs. ",
num_then_outputs,
" != ",
num_else_outputs);
}
if (num_then_outputs != num_outputs) {
fail_type_inference("If node has ", num_outputs, " but subgraphs produce ", num_then_outputs);
}
for (size_t i = 0, end = then_output_types.size(); i < end; ++i) {
auto then_output = then_output_types[i];
auto else_output = else_output_types[i];
auto* if_output = ctx.getOutputType(i);
*if_output = *then_output;
UnionTypeInfo(*else_output, *if_output);
}
}
void LoopInferenceFunction(InferenceContext& ctx) {
auto num_inputs = ctx.getNumInputs();
assert(num_inputs >= 2);
auto num_loop_state_vars = num_inputs - 2; // skip 'M' and 'cond'
std::vector<const TypeProto*> subgraph_input_types;
subgraph_input_types.reserve(num_inputs);
std::vector<TypeProto> temporary_type_protos;
temporary_type_protos.reserve(num_inputs - 2);
// create TypeProto to validate iteration number type is the same as the
// optional 'M' input for max iterations.
TypeProto iter_num_type;
iter_num_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
subgraph_input_types.push_back(&iter_num_type);
// 'cond'
subgraph_input_types.push_back(ctx.getInputType(1));
// loop state value types get propagated to outputs, but shape may change
// across iterations so don't propagate it to the outputs and don't pass it
// into the subgraph inferencing
for (size_t i = 2; i < num_inputs; ++i) {
propagateElemTypeFromInputToOutput(ctx, i, i - 2);
// copy so we can remove the shape before passing to the subgraph
// inferencing
temporary_type_protos.push_back(*ctx.getInputType(i));
auto& input_type = temporary_type_protos.back();
ClearShape(input_type);
subgraph_input_types.push_back(&input_type);
}
// Run inferencing on the subgraph
std::vector<const TypeProto*> subgraph_output_types;
GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
if (graphInferencer) {
std::vector<const TensorProto*> input_data;
input_data.push_back(nullptr); // iteration number
for (size_t i = 1; i < num_inputs; ++i) {
input_data.push_back(ctx.getInputData(i));
}
subgraph_output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
}
// if empty(), assume inferencing was skipped
if (!subgraph_output_types.empty()) {
auto num_outputs = ctx.getNumOutputs();
// subgraph outputs the condition value first but that is only used
// internally and not returned by Loop.
if (subgraph_output_types.size() != num_outputs + 1) {
fail_type_inference(
"Graph attribute inferencing returned type information for ",
subgraph_output_types.size(),
" outputs. Expected ",
num_outputs + 1);
}
// check loop state values match. we should already have type/shape info
for (size_t i = 0; i < num_outputs; ++i) {
auto* subgraph_output_type = subgraph_output_types[i + 1]; // skip 'cond'
auto* loop_output_type = ctx.getOutputType(i);
const bool is_loop_state_var = i < num_loop_state_vars;
if (!subgraph_output_type->has_tensor_type() && !subgraph_output_type->has_sequence_type() &&
!subgraph_output_type->has_optional_type()) {
fail_type_inference(
"Loop 'body' subgraph outputs should all be tensors or sequences or optionals, but output ",
i,
" was ",
subgraph_output_type->value_case());
}
if (!is_loop_state_var && !subgraph_output_type->has_tensor_type()) {
fail_type_inference(
"Loop 'body' subgraph scan outputs should all be tensors but output ",
i,
" was ",
subgraph_output_type->value_case());
}
// if there's an existing type check it matches. otherwise propagate
propagateElemTypeWithValidation(subgraph_output_type, loop_output_type);
if (is_loop_state_var) {
// shape may change across iterations so ignore.
} else {
// propagate shape
if (subgraph_output_type->tensor_type().has_shape()) {
// per iteration output. first dimension will be number of iterations
// but we don't know that value yet
TypeProto inferred_type(*subgraph_output_type);
auto* mutable_inferred_tensor_type = inferred_type.mutable_tensor_type();
auto* mutable_inferred_shape = mutable_inferred_tensor_type->mutable_shape();
mutable_inferred_shape->clear_dim();
// add empty dimension for number of iterations
mutable_inferred_shape->add_dim();
// add dimensions from subgraph output shape
for (const auto& dim : subgraph_output_type->tensor_type().shape().dim()) {
(*mutable_inferred_shape->add_dim()) = dim;
}
mergeInShapeInfo(*mutable_inferred_tensor_type, *loop_output_type->mutable_tensor_type());
}
}
}
}
}
int handle_negative_axis_validate(const std::string& attrib, int axis, int rank) {
if (!(-rank <= axis && axis < rank)) {
fail_shape_inference(attrib, " axis value ", axis, " is invalid for a tensor of rank ", rank);
}
return (axis >= 0 ? axis : axis + rank);
}
void ScanInferenceFunction(InferenceContext& ctx) {
auto num_inputs = ctx.getNumInputs();
auto num_scan_inputs = narrow_cast<size_t>(ctx.getAttribute("num_scan_inputs")->i());
auto num_loop_state_vars = num_inputs - num_scan_inputs;
auto num_outputs = ctx.getNumOutputs();
auto num_scan_outputs = num_outputs - num_loop_state_vars;
std::vector<int64_t> axes, output_axes;
if (getRepeatedAttribute(ctx, "scan_input_axes", axes)) {
if (axes.size() != num_scan_inputs) {
fail_shape_inference(
"Number of scan input axes specified (",
axes.size(),
") is not equal to number of scan inputs (",
num_scan_inputs,
").");
}
} else {
axes.insert(axes.end(), num_scan_inputs, 0);
}
if (getRepeatedAttribute(ctx, "scan_output_axes", output_axes)) {
if (output_axes.size() != num_scan_outputs) {
fail_shape_inference(
"Number of scan output axes specified (",
output_axes.size(),
") is not equal to number of scan outputs (",
num_scan_outputs,
").");
}
} else {
output_axes.insert(output_axes.end(), num_scan_outputs, 0);
}
std::vector<TypeProto> temporary_type_protos;
temporary_type_protos.reserve(num_inputs);
std::vector<const TypeProto*> subgraph_input_types;
subgraph_input_types.reserve(num_inputs);
TensorShapeProto_Dimension sequence_len_dim;
for (size_t i = 0; i < num_inputs; ++i) {
bool is_loop_state_var = i < num_loop_state_vars;
bool has_shape = hasInputShape(ctx, i);
const auto* input_type = ctx.getInputType(i);
// Enforce type constraint for inputs
if (!input_type || !input_type->has_tensor_type()) {
fail_type_inference("Scan input ", i, " was not a tensor.");
}
if (is_loop_state_var) {
// If it's a loop state variable we can propagate type and shape 1:1 to
// the matching Scan output.
// We can also pass through the type and shape to the subgraph but need to
// remove the batch size dimension from the shape.
propagateElemTypeFromInputToOutput(ctx, i, i);
if (has_shape)
propagateShapeFromInputToOutput(ctx, i, i);
subgraph_input_types.push_back(input_type);
} else {
// For other inputs there is no fixed relationships to the Scan outputs,
// so we don't propagate type/shape information.
// We can pass through the type and shape to the subgraph inputs but
// need to remove the sequence length dimensions from the shape.
if (has_shape) {
const auto& shape = input_type->tensor_type().shape();
// remove sequence length dimensions and add to subgraph_input_types
int axis = static_cast<int>(axes[i - num_loop_state_vars]);
axis = handle_negative_axis_validate("scan_input_axes", axis, shape.dim_size());
// update sequence_len if a value is available
const auto& dims = shape.dim();
mergeInDimensionInfo(dims.Get(axis), sequence_len_dim, 1);
temporary_type_protos.push_back(RemoveIthDimensionFromShape(*input_type, axis));
subgraph_input_types.push_back(&temporary_type_protos.back());
} else {
subgraph_input_types.push_back(input_type);
}
}
}
// Run inferencing on the subgraph
std::vector<const TypeProto*> output_types;
GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
if (graphInferencer) {
std::vector<const TensorProto*> input_data;
input_data.reserve(num_inputs);
for (size_t i = 0; i < num_inputs; ++i) {
// ctx.getInputData(i), the input to scan, does not represent the input to
// scan body. So, we pass in null, to represent an unknown value.
input_data.push_back(nullptr);
}
output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
}
// if empty(), assume inferencing was skipped
if (!output_types.empty()) {
if (output_types.size() != num_outputs) {
fail_type_inference(
"Graph attribute inferencing returned type information for ",
output_types.size(),
" outputs. Expected ",
num_outputs);
}
// propagate type/shape information for loop state variables and outputs
for (size_t i = 0; i < num_outputs; ++i) {
const bool is_loop_state_var = i < num_loop_state_vars;
auto* subgraph_output_type = output_types[i];
auto* scan_output_type = ctx.getOutputType(i);
auto* mutable_scan_output_tensor_type = scan_output_type->mutable_tensor_type();
if (!subgraph_output_type->has_tensor_type()) {
fail_type_inference("Scan 'body' subgraph outputs should all be tensors but output ", i, " was not");
}
auto& subgraph_output_tensor_type = subgraph_output_type->tensor_type();
if (is_loop_state_var) {
// merge shape; type already propagated
mergeInShapeInfo(subgraph_output_tensor_type, *mutable_scan_output_tensor_type);
} else {
scan_output_type->mutable_tensor_type()->set_elem_type(subgraph_output_tensor_type.elem_type());
// propagate shape
if (subgraph_output_tensor_type.has_shape()) {
// infer shape of scan-output from the shape of scan-output-element
// by adding sequence-length at the correct axis position
const TensorShapeProto& subgraph_output_shape = subgraph_output_tensor_type.shape();
TensorShapeProto inferred_shape;
auto subgraph_output_rank = subgraph_output_shape.dim_size();
auto output_rank = subgraph_output_rank + 1;
int output_axis = static_cast<int>(output_axes[i - num_loop_state_vars]);
output_axis = handle_negative_axis_validate("scan_output_axes", output_axis, output_rank);
for (int j = 0; j < output_axis; ++j)
*(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);
*(inferred_shape.add_dim()) = sequence_len_dim;
for (int j = output_axis; j < subgraph_output_rank; ++j)
*(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);
// Merge inferred shape with existing shape information
mergeInShapeInfo(inferred_shape, *mutable_scan_output_tensor_type);
}
}
}
}
}
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,23 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <string>
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
void ClearShape(TypeProto& input_type);
int handle_negative_axis_validate(const std::string& attrib, int axis, int rank);
void IfInferenceFunction(InferenceContext& ctx);
void LoopInferenceFunction(InferenceContext& ctx);
void ScanInferenceFunction(InferenceContext& ctx);
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,87 @@
// Copyright (c) ONNX Project Contributors
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <utility>
#include "onnx/defs/shape_inference.h"
namespace ONNX_NAMESPACE {
inline void appendDimToTensorShapeProto(TensorShapeProto& tsp, const TensorShapeProto* input_data, int index) {
if (index >= input_data->dim_size() || index < -input_data->dim_size()) {
fail_shape_inference("indices must be in [-rank, rank-1].");
} else {
*tsp.add_dim() = input_data->dim((index < 0) ? input_data->dim_size() + index : index);
}
}
// Returns true if the given axis attribute is 0
inline bool axisIsZero(DataPropagationContext& ctx, bool defaultZero = false) {
auto axisAttr = ctx.getAttribute("axis");
// if axis is not defined
if (!axisAttr) {
if (defaultZero) {
return true;
} else {
fail_shape_inference("Required attribute axis is missing");
return false;
}
}
int axis = static_cast<int>(axisAttr->i());
auto input_data_0 = ctx.getInputData(0);
if (input_data_0 == nullptr) {
return false;
}
int rank = input_data_0->dim_size();
if (axis < -rank || axis >= rank) {
fail_shape_inference("axis must be in [-rank, rank-1].");
return false;
}
if (axis < 0) {
axis += rank;
}
// Only supports axis = 0 since the data comes from Shape
return axis == 0;
}
inline void PropagateShapeDataFromInputToOutput(DataPropagationContext& ctx, int idx) {
// propagate input data
const auto input_data = ctx.getInputData(idx);
if (input_data != nullptr) {
TensorShapeProto tsp;
tsp.CopyFrom(*input_data);
ctx.addOutputData(0, std::move(tsp));
}
}
inline void GatherOp13DataPropagator(DataPropagationContext& ctx) {
if (!axisIsZero(ctx, true)) {
return;
}
const auto input_data = ctx.getInputData(0);
if (input_data == nullptr) {
return;
}
const auto input_indices = ctx.getInputData(1);
if (input_data == nullptr || input_indices == nullptr) {
return;
}
TensorShapeProto tsp;
for (int i = 0; i < input_indices->dim_size(); ++i) {
if (input_indices->dim(i).has_dim_value()) {
appendDimToTensorShapeProto(tsp, input_data, input_indices->dim(i).dim_value());
} else {
return;
}
}
if (tsp.dim_size() > 0) {
ctx.addOutputData(0, std::move(tsp));
}
}
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,451 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "data_type_utils.h"
#include <cctype>
#include <iostream>
#include <iterator>
#include <sstream>
namespace ONNX_NAMESPACE {
namespace Utils {
// Singleton wrapper around allowed data types.
// This implements construct on first use which is needed to ensure
// static objects are initialized before use. Ops registration does not work
// properly without this.
class TypesWrapper final {
public:
static TypesWrapper& GetTypesWrapper();
std::unordered_set<std::string>& GetAllowedDataTypes();
std::unordered_map<std::string, int32_t>& TypeStrToTensorDataType();
std::unordered_map<int32_t, std::string>& TensorDataTypeToTypeStr();
~TypesWrapper() = default;
TypesWrapper(const TypesWrapper&) = delete;
void operator=(const TypesWrapper&) = delete;
private:
TypesWrapper();
std::unordered_map<std::string, int> type_str_to_tensor_data_type_;
std::unordered_map<int, std::string> tensor_data_type_to_type_str_;
std::unordered_set<std::string> allowed_data_types_;
};
// Simple class which contains pointers to external string buffer and a size.
// This can be used to track a "valid" range/slice of the string.
// Caller should ensure StringRange is not used after external storage has
// been freed.
class StringRange final {
public:
StringRange();
StringRange(const char* data, size_t size);
StringRange(const std::string& str);
StringRange(const char* data);
const char* Data() const;
size_t Size() const;
bool Empty() const;
char operator[](size_t idx) const;
void Reset();
void Reset(const char* data, size_t size);
void Reset(const std::string& str);
bool StartsWith(const StringRange& str) const;
bool EndsWith(const StringRange& str) const;
bool LStrip();
bool LStrip(size_t size);
bool LStrip(StringRange str);
bool RStrip();
bool RStrip(size_t size);
bool RStrip(StringRange str);
bool LAndRStrip();
void ParensWhitespaceStrip();
size_t Find(const char ch) const;
// These methods provide a way to return the range of the string
// which was discarded by LStrip(). i.e. We capture the string
// range which was discarded.
StringRange GetCaptured();
void RestartCapture();
private:
// data_ + size tracks the "valid" range of the external string buffer.
const char* data_;
size_t size_;
// start_ and end_ track the captured range.
// end_ advances when LStrip() is called.
const char* start_;
const char* end_;
};
std::unordered_map<std::string, TypeProto>& DataTypeUtils::GetTypeStrToProtoMap() {
static std::unordered_map<std::string, TypeProto> map;
return map;
}
std::mutex& DataTypeUtils::GetTypeStrLock() {
static std::mutex lock;
return lock;
}
DataType DataTypeUtils::ToType(const TypeProto& type_proto) {
auto typeStr = ToString(type_proto);
std::lock_guard<std::mutex> lock(GetTypeStrLock());
if (GetTypeStrToProtoMap().find(typeStr) == GetTypeStrToProtoMap().end()) {
TypeProto type;
FromString(typeStr, type);
GetTypeStrToProtoMap()[typeStr] = type;
}
return &(GetTypeStrToProtoMap().find(typeStr)->first);
}
DataType DataTypeUtils::ToType(const std::string& type_str) {
TypeProto type;
FromString(type_str, type);
return ToType(type);
}
const TypeProto& DataTypeUtils::ToTypeProto(const DataType& data_type) {
std::lock_guard<std::mutex> lock(GetTypeStrLock());
auto it = GetTypeStrToProtoMap().find(*data_type);
if (GetTypeStrToProtoMap().end() == it) {
ONNX_THROW_EX(std::invalid_argument("Invalid data type " + *data_type));
}
return it->second;
}
std::string DataTypeUtils::ToString(const TypeProto& type_proto, const std::string& left, const std::string& right) {
switch (type_proto.value_case()) {
case TypeProto::ValueCase::kTensorType: {
// Note: We do not distinguish tensors with zero rank (a shape consisting
// of an empty sequence of dimensions) here.
return left + "tensor(" + ToDataTypeString(type_proto.tensor_type().elem_type()) + ")" + right;
}
case TypeProto::ValueCase::kSequenceType: {
return ToString(type_proto.sequence_type().elem_type(), left + "seq(", ")" + right);
}
case TypeProto::ValueCase::kOptionalType: {
return ToString(type_proto.optional_type().elem_type(), left + "optional(", ")" + right);
}
case TypeProto::ValueCase::kMapType: {
std::string map_str = "map(" + ToDataTypeString(type_proto.map_type().key_type()) + ",";
return ToString(type_proto.map_type().value_type(), left + map_str, ")" + right);
}
#ifdef ONNX_ML
case TypeProto::ValueCase::kOpaqueType: {
std::string result;
const auto& op_type = type_proto.opaque_type();
result.append(left).append("opaque(");
if (op_type.has_domain() && !op_type.domain().empty()) {
result.append(op_type.domain()).append(",");
}
if (op_type.has_name() && !op_type.name().empty()) {
result.append(op_type.name());
}
result.append(")").append(right);
return result;
}
#endif
case TypeProto::ValueCase::kSparseTensorType: {
// Note: We do not distinguish tensors with zero rank (a shape consisting
// of an empty sequence of dimensions) here.
return left + "sparse_tensor(" + ToDataTypeString(type_proto.sparse_tensor_type().elem_type()) + ")" + right;
}
default:
ONNX_THROW_EX(std::invalid_argument("Unsuported type proto value case."));
}
}
std::string DataTypeUtils::ToDataTypeString(int32_t tensor_data_type) {
TypesWrapper& t = TypesWrapper::GetTypesWrapper();
auto iter = t.TensorDataTypeToTypeStr().find(tensor_data_type);
if (t.TensorDataTypeToTypeStr().end() == iter) {
ONNX_THROW_EX(std::invalid_argument("Invalid tensor data type " + std::to_string(tensor_data_type) + "."));
}
return iter->second;
}
void DataTypeUtils::FromString(const std::string& type_str, TypeProto& type_proto) {
StringRange s(type_str);
type_proto.Clear();
if (s.LStrip("seq")) {
s.ParensWhitespaceStrip();
return FromString(std::string(s.Data(), s.Size()), *type_proto.mutable_sequence_type()->mutable_elem_type());
} else if (s.LStrip("optional")) {
s.ParensWhitespaceStrip();
return FromString(std::string(s.Data(), s.Size()), *type_proto.mutable_optional_type()->mutable_elem_type());
} else if (s.LStrip("map")) {
s.ParensWhitespaceStrip();
size_t key_size = s.Find(',');
StringRange k(s.Data(), key_size);
std::string key(k.Data(), k.Size());
s.LStrip(key_size);
s.LStrip(",");
StringRange v(s.Data(), s.Size());
int32_t key_type;
FromDataTypeString(key, key_type);
type_proto.mutable_map_type()->set_key_type(key_type);
return FromString(std::string(v.Data(), v.Size()), *type_proto.mutable_map_type()->mutable_value_type());
} else
#ifdef ONNX_ML
if (s.LStrip("opaque")) {
auto* opaque_type = type_proto.mutable_opaque_type();
s.ParensWhitespaceStrip();
if (!s.Empty()) {
size_t cm = s.Find(',');
if (cm != std::string::npos) {
if (cm > 0) {
opaque_type->mutable_domain()->assign(s.Data(), cm);
}
s.LStrip(cm + 1); // skip comma
}
if (!s.Empty()) {
opaque_type->mutable_name()->assign(s.Data(), s.Size());
}
}
} else
#endif
if (s.LStrip("sparse_tensor")) {
s.ParensWhitespaceStrip();
int32_t e;
FromDataTypeString(std::string(s.Data(), s.Size()), e);
type_proto.mutable_sparse_tensor_type()->set_elem_type(e);
} else if (s.LStrip("tensor")) {
s.ParensWhitespaceStrip();
int32_t e;
FromDataTypeString(std::string(s.Data(), s.Size()), e);
type_proto.mutable_tensor_type()->set_elem_type(e);
} else {
// Scalar
int32_t e;
FromDataTypeString(std::string(s.Data(), s.Size()), e);
TypeProto::Tensor* t = type_proto.mutable_tensor_type();
t->set_elem_type(e);
// Call mutable_shape() to initialize a shape with no dimension.
t->mutable_shape();
}
} // namespace Utils
bool DataTypeUtils::IsValidDataTypeString(const std::string& type_str) {
TypesWrapper& t = TypesWrapper::GetTypesWrapper();
const auto& allowedSet = t.GetAllowedDataTypes();
return (allowedSet.find(type_str) != allowedSet.end());
}
void DataTypeUtils::FromDataTypeString(const std::string& type_str, int32_t& tensor_data_type) {
if (!IsValidDataTypeString(type_str)) {
ONNX_THROW_EX(std::invalid_argument(
"DataTypeUtils::FromDataTypeString - Received invalid data type string '" + type_str + "'."));
}
TypesWrapper& t = TypesWrapper::GetTypesWrapper();
tensor_data_type = t.TypeStrToTensorDataType()[type_str];
}
StringRange::StringRange() : data_(""), size_(0), start_(data_), end_(data_) {}
StringRange::StringRange(const char* p_data, size_t p_size) : data_(p_data), size_(p_size), start_(data_), end_(data_) {
assert(p_data != nullptr);
LAndRStrip();
}
StringRange::StringRange(const std::string& p_str)
: data_(p_str.data()), size_(p_str.size()), start_(data_), end_(data_) {
LAndRStrip();
}
StringRange::StringRange(const char* p_data) : data_(p_data), size_(strlen(p_data)), start_(data_), end_(data_) {
LAndRStrip();
}
const char* StringRange::Data() const {
return data_;
}
size_t StringRange::Size() const {
return size_;
}
bool StringRange::Empty() const {
return size_ == 0;
}
char StringRange::operator[](size_t idx) const {
return data_[idx];
}
void StringRange::Reset() {
data_ = "";
size_ = 0;
start_ = end_ = data_;
}
void StringRange::Reset(const char* data, size_t size) {
data_ = data;
size_ = size;
start_ = end_ = data_;
}
void StringRange::Reset(const std::string& str) {
data_ = str.data();
size_ = str.size();
start_ = end_ = data_;
}
bool StringRange::StartsWith(const StringRange& str) const {
return ((size_ >= str.size_) && (memcmp(data_, str.data_, str.size_) == 0));
}
bool StringRange::EndsWith(const StringRange& str) const {
return ((size_ >= str.size_) && (memcmp(data_ + (size_ - str.size_), str.data_, str.size_) == 0));
}
bool StringRange::LStrip() {
size_t count = 0;
const char* ptr = data_;
while (count < size_ && isspace(*ptr)) {
count++;
ptr++;
}
if (count > 0) {
return LStrip(count);
}
return false;
}
bool StringRange::LStrip(size_t size) {
if (size <= size_) {
data_ += size;
size_ -= size;
end_ += size;
return true;
}
return false;
}
bool StringRange::LStrip(StringRange str) {
if (StartsWith(str)) {
return LStrip(str.size_);
}
return false;
}
bool StringRange::RStrip() {
size_t count = 0;
const char* ptr = data_ + size_ - 1;
while (count < size_ && isspace(*ptr)) {
++count;
--ptr;
}
if (count > 0) {
return RStrip(count);
}
return false;
}
bool StringRange::RStrip(size_t size) {
if (size_ >= size) {
size_ -= size;
return true;
}
return false;
}
bool StringRange::RStrip(StringRange str) {
if (EndsWith(str)) {
return RStrip(str.size_);
}
return false;
}
bool StringRange::LAndRStrip() {
bool l = LStrip();
bool r = RStrip();
return l || r;
}
void StringRange::ParensWhitespaceStrip() {
LStrip();
LStrip("(");
LAndRStrip();
RStrip(")");
RStrip();
}
size_t StringRange::Find(const char ch) const {
size_t idx = 0;
while (idx < size_) {
if (data_[idx] == ch) {
return idx;
}
idx++;
}
return std::string::npos;
}
void StringRange::RestartCapture() {
start_ = data_;
end_ = data_;
}
StringRange StringRange::GetCaptured() {
return StringRange(start_, end_ - start_);
}
TypesWrapper& TypesWrapper::GetTypesWrapper() {
static TypesWrapper types;
return types;
}
std::unordered_set<std::string>& TypesWrapper::GetAllowedDataTypes() {
return allowed_data_types_;
}
std::unordered_map<std::string, int>& TypesWrapper::TypeStrToTensorDataType() {
return type_str_to_tensor_data_type_;
}
std::unordered_map<int, std::string>& TypesWrapper::TensorDataTypeToTypeStr() {
return tensor_data_type_to_type_str_;
}
TypesWrapper::TypesWrapper() {
// DataType strings. These should match the DataTypes defined in onnx.proto
type_str_to_tensor_data_type_["float"] = TensorProto_DataType_FLOAT;
type_str_to_tensor_data_type_["float16"] = TensorProto_DataType_FLOAT16;
type_str_to_tensor_data_type_["bfloat16"] = TensorProto_DataType_BFLOAT16;
type_str_to_tensor_data_type_["double"] = TensorProto_DataType_DOUBLE;
type_str_to_tensor_data_type_["int8"] = TensorProto_DataType_INT8;
type_str_to_tensor_data_type_["int16"] = TensorProto_DataType_INT16;
type_str_to_tensor_data_type_["int32"] = TensorProto_DataType_INT32;
type_str_to_tensor_data_type_["int64"] = TensorProto_DataType_INT64;
type_str_to_tensor_data_type_["uint8"] = TensorProto_DataType_UINT8;
type_str_to_tensor_data_type_["uint16"] = TensorProto_DataType_UINT16;
type_str_to_tensor_data_type_["uint32"] = TensorProto_DataType_UINT32;
type_str_to_tensor_data_type_["uint64"] = TensorProto_DataType_UINT64;
type_str_to_tensor_data_type_["complex64"] = TensorProto_DataType_COMPLEX64;
type_str_to_tensor_data_type_["complex128"] = TensorProto_DataType_COMPLEX128;
type_str_to_tensor_data_type_["string"] = TensorProto_DataType_STRING;
type_str_to_tensor_data_type_["bool"] = TensorProto_DataType_BOOL;
type_str_to_tensor_data_type_["float8e4m3fn"] = TensorProto_DataType_FLOAT8E4M3FN;
type_str_to_tensor_data_type_["float8e4m3fnuz"] = TensorProto_DataType_FLOAT8E4M3FNUZ;
type_str_to_tensor_data_type_["float8e5m2"] = TensorProto_DataType_FLOAT8E5M2;
type_str_to_tensor_data_type_["float8e5m2fnuz"] = TensorProto_DataType_FLOAT8E5M2FNUZ;
type_str_to_tensor_data_type_["uint4"] = TensorProto_DataType_UINT4;
type_str_to_tensor_data_type_["int4"] = TensorProto_DataType_INT4;
for (auto& str_type_pair : type_str_to_tensor_data_type_) {
tensor_data_type_to_type_str_[str_type_pair.second] = str_type_pair.first;
allowed_data_types_.insert(str_type_pair.first);
}
}
} // namespace Utils
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,73 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#ifndef ONNX_DATA_TYPE_UTILS_H
#define ONNX_DATA_TYPE_UTILS_H
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "onnx/common/common.h"
#include "onnx/onnx_pb.h"
namespace ONNX_NAMESPACE {
// String pointer as unique TypeProto identifier.
using DataType = const std::string*;
namespace Utils {
// Data type utility, which maintains a global type string to TypeProto map.
// DataType (string pointer) is used as unique data type identifier for
// efficiency.
//
// Grammar for data type string:
// <type> ::= <data_type> |
// tensor(<data_type>) |
// seq(<type>) |
// map(<data_type>, <type>)
// <data_type> :: = float | int32 | string | bool | uint8
// | int8 | uint16 | int16 | int64 | float16 | double
//
// NOTE: <type> ::= <data_type> means the data is scalar (zero dimension).
//
// Example: float, tensor(float), etc.
//
class DataTypeUtils final {
public:
// If the DataType input is invalid, this function will throw std::invalid_argument exception.
// If ONNX_NO_EXCEPTIONS is set it will abort.
static DataType ToType(const std::string& type_str);
// If the DataType input is invalid, this function will throw std::invalid_argument exception.
// If ONNX_NO_EXCEPTIONS is set it will abort.
static DataType ToType(const TypeProto& type_proto);
// If the DataType input is invalid, this function will throw std::invalid_argument exception.
// If ONNX_NO_EXCEPTIONS is set it will abort.
static const TypeProto& ToTypeProto(const DataType& data_type);
static std::string ToDataTypeString(int32_t tensor_data_type);
private:
static void FromString(const std::string& type_str, TypeProto& type_proto);
static void FromDataTypeString(const std::string& type_str, int32_t& tensor_data_type);
static std::string ToString(const TypeProto& type_proto, const std::string& left = "", const std::string& right = "");
// If int32_t input is invalid, this function will throw an exception.
// If ONNX_NO_EXCEPTIONS is set it will abort.
static bool IsValidDataTypeString(const std::string& type_str);
static std::unordered_map<std::string, TypeProto>& GetTypeStrToProtoMap();
// Returns lock used for concurrent updates to TypeStrToProtoMap.
static std::mutex& GetTypeStrLock();
};
} // namespace Utils
} // namespace ONNX_NAMESPACE
#endif // ! ONNX_DATA_TYPE_UTILS_H

View File

@ -0,0 +1,175 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/function.h"
#include <map>
#include "onnx/defs/schema.h"
#include "onnx/string_utils.h"
namespace ONNX_NAMESPACE {
std::string InteralTensorNameGenerator(const std::string& node_name, const std::string& internal_name) {
std::string new_name = "Func_" + node_name + internal_name;
return new_name;
}
void FunctionExpandHelper(
const NodeProto& node,
const FunctionProto& func,
GraphProto& g,
const std::string& node_prefix) {
// Create a temporary unique node prefix for tensor names
std::string uniq_prefix = node_prefix;
if (uniq_prefix.empty()) {
const void* address = static_cast<const void*>(&node);
std::stringstream ss;
ss << address;
uniq_prefix = ss.str();
}
std::string node_name = node.has_name() ? node.name() : func.name() + uniq_prefix;
std::unordered_map<std::string, std::string> io_names_map;
std::unordered_map<std::string, AttributeProto> attr_map;
for (int idx = 0; idx < node.input_size(); ++idx) {
if (idx >= func.input_size()) {
ONNX_THROW("Input for function node " + node_name + " is out of bounds");
}
io_names_map[func.input().Get(idx)] = node.input().Get(idx);
}
for (int idx = 0; idx < node.output_size(); ++idx) {
if (idx >= func.output_size()) {
ONNX_THROW("Output for function node " + node_name + " is out of bounds");
}
// If the node output is missing, the corresponding function output should
// be treated as an internal value (not as missing) because it could also be
// an intermediate value.
if (node.output().Get(idx) == "") {
continue;
}
io_names_map[func.output().Get(idx)] = node.output().Get(idx);
}
for (auto& attr : node.attribute()) {
attr_map[attr.name()] = attr;
}
// For undefined attributes of the function node
// add default values obtained from the function schema.
// get the domain version for function schema
int domain_version = -1;
for (const auto& opset_import : func.opset_import()) {
if (opset_import.domain() == node.domain()) {
domain_version = static_cast<int>(opset_import.version());
}
}
if (domain_version == -1) {
ONNX_THROW("No opset import registered for domain '" + node.domain() + "' in function proto");
}
const OpSchemaRegistry* schema_registry = OpSchemaRegistry::Instance();
const auto schema = schema_registry->GetSchema(node.op_type(), domain_version, node.domain());
std::map<std::string, OpSchema::Attribute> default_attrs = schema->attributes();
for (const auto& pair : default_attrs) {
const auto& attr_name = pair.first;
const auto& attr = pair.second;
if (!attr_map.count(attr_name)) {
attr_map[attr_name] = attr.default_value;
}
}
for (auto& function_node : func.node()) {
NodeProto* new_node = g.add_node();
new_node->CopyFrom(function_node);
new_node->clear_input();
new_node->clear_output();
new_node->clear_attribute();
for (auto& input : function_node.input()) {
if (io_names_map.count(input)) {
new_node->add_input(io_names_map[input]);
} else {
new_node->add_input(InteralTensorNameGenerator(node_name, input));
}
}
for (auto& output : function_node.output()) {
if (io_names_map.count(output)) {
new_node->add_output(io_names_map[output]);
} else {
new_node->add_output(InteralTensorNameGenerator(node_name, output));
}
}
for (auto& attr : function_node.attribute()) {
if (attr.has_ref_attr_name()) {
if (attr_map.count(attr.ref_attr_name())) {
AttributeProto* new_attr = new_node->add_attribute();
new_attr->CopyFrom(attr_map[attr.ref_attr_name()]);
new_attr->set_name(attr.name());
}
} else {
AttributeProto* new_attr = new_node->add_attribute();
new_attr->CopyFrom(attr);
}
}
}
}
std::vector<NodeProto> FunctionBodyHelper::BuildNodes(const std::vector<NodeDef>& node_defs) {
std::vector<NodeProto> nodes(node_defs.size());
for (size_t i = 0; i < node_defs.size(); i++) {
const NodeDef& node = node_defs[i];
NodeProto& n = nodes[i];
n.set_op_type(node.op_type);
n.set_domain(node.domain);
for (const auto& i : node.inputs) {
n.add_input(i);
}
for (const auto& o : node.outputs) {
n.add_output(o);
}
for (const auto& attr : node.attributes) {
*(n.add_attribute()) = attr.proto;
}
}
return nodes;
}
void FunctionBodyHelper::BuildNodes(FunctionProto& functionProto, const std::vector<NodeDef>& node_defs) {
for (size_t i = 0; i < node_defs.size(); i++) {
const NodeDef& node = node_defs[i];
auto* np = functionProto.add_node();
np->set_op_type(node.op_type);
np->set_domain(node.domain);
for (const auto& inp : node.inputs) {
np->add_input(inp);
}
for (const auto& o : node.outputs) {
np->add_output(o);
}
for (const auto& attr : node.attributes) {
*(np->add_attribute()) = attr.proto;
}
}
}
bool FunctionBodyHelper::BuildFunctionProto(
FunctionProto& functionProto,
const OpSchema& schema,
const std::vector<NodeDef>& node_defs,
const std::vector<OperatorSetIdProto>& relied_opsets) {
BuildNodes(functionProto, node_defs);
for (auto& relied_opset : relied_opsets) {
*(functionProto.mutable_opset_import()->Add()) = relied_opset;
}
schema.BuildFunction(functionProto);
return true;
}
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,196 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "attr_proto_util.h"
#include "onnx/common/constants.h"
#include "onnx/common/status.h"
#include "onnx/defs/parser.h"
#include "onnx/defs/schema.h"
#include "tensor_proto_util.h"
namespace ONNX_NAMESPACE {
// Helper function to expand a function node given the function proto
void FunctionExpandHelper(
const NodeProto& node,
const FunctionProto& func,
GraphProto& g,
const std::string& node_prefix = "");
class FunctionBodyHelper {
public:
struct AttributeProtoWrapper {
AttributeProto proto;
AttributeProtoWrapper() {}
AttributeProtoWrapper(const AttributeProto& attr_prot) {
proto = attr_prot;
}
template <typename T>
AttributeProtoWrapper(const std::string& attr_name, const T& value) {
proto = MakeAttribute(attr_name, value);
}
};
struct NodeDef {
NodeDef(
std::vector<std::string> outputs,
std::string op_type,
std::vector<std::string> inputs,
std::vector<AttributeProtoWrapper> attributes = {},
std::string domain = "")
: outputs(std::move(outputs)),
op_type(std::move(op_type)),
inputs(std::move(inputs)),
attributes(std::move(attributes)),
domain(std::move(domain)) {}
std::vector<std::string> outputs;
std::string op_type;
std::vector<std::string> inputs;
std::vector<AttributeProtoWrapper> attributes;
std::string domain;
};
/*
BuildNodes() is an utility function for easily define a Function Body.
To build a simple node:
{{"Z"}, "Add", {"X", "Y"}} represents Z = Add(X,Y)
To build a node with attribute:
{{"Y"}, "Concat", {"X1", "X2", "X3"}, {{"axis", 1}}}
represents Y = Concat(X1,X2,X3) with axis = 1
The attribute type are infered from the attribute value's c++ type
Supported value types are
int64_t -> int, vector<int64_t> -> ints
float -> float, vector<float> -> floats
string -> string, vector<string> ->strings
For refering an attribute from parent, use:
{MakeRefAttribute("axes", AttributeProto::INTS)}}
To build a node which belongs to a domain other than onnx standard domain:
{{"Z"}, "Foo", {"X", "Y"}, "customdomain"} represents Z = customdomain.Foo(X,Y)
or
{{"Y"}, "Bar", {"X1", "X2", "X3"}, {{"axis", 1}}, "customdomain"}
represents Y = customdomain.Bar(X1,X2,X3) with axis = 1
For more examples, please find the references of this function
*/
static std::vector<NodeProto> BuildNodes(const std::vector<NodeDef>& node_defs);
static void BuildNodes(FunctionProto& functionProto, const std::vector<NodeDef>& node_defs);
static bool BuildFunctionProto(
FunctionProto& functionProto,
const OpSchema& schema,
const std::vector<NodeDef>& node_defs,
const std::vector<OperatorSetIdProto>& relied_opsets);
template <typename T>
static NodeDef Const(const std::string& name, const T& value) {
return NodeDef{{name}, "Constant", {}, {{"value", ToTensor<T>(value)}}};
}
template <typename T>
static NodeDef Const(const std::string& name, const std::vector<T>& values) {
return NodeDef{{name}, "Constant", {}, {{"value", ToTensor<T>(values)}}};
}
};
class FunctionBuilder {
public:
FunctionBuilder(FunctionProto& funProto_) : funProto(funProto_) {}
FunctionBuilder& Add(const char* nodes_txt) {
OnnxParser parser(nodes_txt);
auto& nodes = *funProto.mutable_node();
while (!parser.EndOfInput()) {
auto status = parser.Parse(*nodes.Add());
if (!status.IsOK())
ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
}
return *this;
}
FunctionBuilder& Add(const char* node_txt, const AttributeProto& attr) {
OnnxParser parser(node_txt);
auto& node = *funProto.add_node();
auto status = parser.Parse(node);
if (!status.IsOK()) {
ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
}
if (!parser.EndOfInput()) {
ONNX_THROW_EX(std::logic_error("Error unexpected extra input in node:" + status.ErrorMessage()));
}
*node.add_attribute() = attr;
return *this;
}
template <typename T>
FunctionBuilder& Add(const char* node_txt, const std::string& attr_name, const T& attr_value) {
return Add(node_txt, MakeAttribute(attr_name, attr_value));
}
FunctionBuilder& Const(const std::string& name, const TensorProto& tensor) {
std::string constant_op(name);
constant_op += " = Constant()";
return Add(constant_op.c_str(), MakeAttribute("value", tensor));
}
// Creates a scalar constant (a tensor of rank zero).
template <typename T>
FunctionBuilder& Const(const std::string& name, T const_value) {
std::string constant_op(name);
constant_op += " = Constant()";
return Add(constant_op.c_str(), MakeAttribute("value", ToTensor(const_value)));
}
// Creates a 1D tensor constant consisting of a single value.
template <typename T>
FunctionBuilder& Const1D(const std::string& name, T const_value) {
std::string constant_op(name);
constant_op += " = Constant()";
auto tensor = ToTensor(const_value);
tensor.add_dims(1);
return Add(constant_op.c_str(), MakeAttribute("value", tensor));
}
// Creates a 1D tensor constant consisting of zero or more values.
template <typename T>
FunctionBuilder& Const(const std::string& name, const std::vector<T>& values) {
std::string constant_op(name);
constant_op += " = Constant()";
auto tensor = ToTensor(values);
tensor.add_dims(values.size()); // Treat as 1D tensor.
return Add(constant_op.c_str(), MakeAttribute("value", tensor));
}
FunctionBuilder& AddOpset(const char* domain, int version) {
auto* opset = funProto.add_opset_import();
opset->set_domain(domain);
opset->set_version(version);
return *this;
}
private:
FunctionProto& funProto;
};
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,428 @@
#!/usr/bin/env python
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import os
from collections import defaultdict
from typing import Any, NamedTuple, Sequence
import numpy as np
from onnx import defs, helper
from onnx.backend.sample.ops import collect_sample_implementations
from onnx.backend.test.case import collect_snippets
from onnx.defs import ONNX_ML_DOMAIN, OpSchema
SNIPPETS = collect_snippets()
SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
ONNX_ML = not bool(os.getenv("ONNX_ML") == "0")
def display_number(v: int) -> str:
if defs.OpSchema.is_infinite(v):
return "&#8734;"
return str(v)
def should_render_domain(domain: str, output: str) -> bool:
is_ml = "-ml" in output
if domain == ONNX_ML_DOMAIN:
return is_ml
else:
return not is_ml
def format_name_with_domain(domain: str, schema_name: str) -> str:
if domain:
return f"{domain}.{schema_name}"
return schema_name
def format_function_versions(function_versions: Sequence[int]) -> str:
return f"{', '.join([str(v) for v in function_versions])}"
def format_versions(versions: Sequence[OpSchema], changelog: str) -> str:
return f"{', '.join(display_version_link(format_name_with_domain(v.domain, v.name), v.since_version, changelog) for v in versions[::-1])}"
def display_attr_type(v: OpSchema.AttrType) -> str:
assert isinstance(v, OpSchema.AttrType)
s = str(v)
s = s[s.rfind(".") + 1 :].lower()
if s[-1] == "s":
s = "list of " + s
return s
def display_domain(domain: str) -> str:
if domain:
return f"the '{domain}' operator set"
return "the default ONNX operator set"
def display_domain_short(domain: str) -> str:
if domain:
return domain
return "ai.onnx (default)"
def display_version_link(name: str, version: int, changelog: str) -> str:
name_with_ver = f"{name}-{version}"
return f'<a href="{changelog}#{name_with_ver}">{version}</a>'
def generate_formal_parameter_tags(formal_parameter: OpSchema.FormalParameter) -> str:
tags: list[str] = []
if OpSchema.FormalParameterOption.Optional == formal_parameter.option:
tags = ["optional"]
elif OpSchema.FormalParameterOption.Variadic == formal_parameter.option:
if formal_parameter.is_homogeneous:
tags = ["variadic"]
else:
tags = ["variadic", "heterogeneous"]
differentiable: OpSchema.DifferentiationCategory = (
OpSchema.DifferentiationCategory.Differentiable
)
non_differentiable: OpSchema.DifferentiationCategory = (
OpSchema.DifferentiationCategory.NonDifferentiable
)
if differentiable == formal_parameter.differentiation_category:
tags.append("differentiable")
elif non_differentiable == formal_parameter.differentiation_category:
tags.append("non-differentiable")
return "" if len(tags) == 0 else " (" + ", ".join(tags) + ")"
def display_schema(
schema: OpSchema, versions: Sequence[OpSchema], changelog: str
) -> str:
s = ""
# doc
if schema.doc:
s += "\n"
s += "\n".join(
(" " + line).rstrip() for line in schema.doc.lstrip().splitlines()
)
s += "\n"
# since version
s += "\n#### Version\n"
if schema.support_level == OpSchema.SupportType.EXPERIMENTAL:
s += "\nNo versioning maintained for experimental ops."
else:
s += (
"\nThis version of the operator has been "
+ ("deprecated" if schema.deprecated else "available")
+ f" since version {schema.since_version}"
)
s += f" of {display_domain(schema.domain)}.\n"
if len(versions) > 1:
# TODO: link to the Changelog.md
s += "\nOther versions of this operator: {}\n".format(
", ".join(
display_version_link(
format_name_with_domain(v.domain, v.name),
v.since_version,
changelog,
)
for v in versions[:-1]
)
)
# If this schema is deprecated, don't display any of the following sections
if schema.deprecated:
return s
# attributes
if schema.attributes:
s += "\n#### Attributes\n\n"
s += "<dl>\n"
for _, attr in sorted(schema.attributes.items()):
# option holds either required or default value
opt = ""
if attr.required:
opt = "required"
elif attr.default_value.name:
default_value = helper.get_attribute_value(attr.default_value)
doc_string = attr.default_value.doc_string
def format_value(value: Any) -> str:
if isinstance(value, float):
formatted = str(np.round(value, 5))
# use default formatting, unless too long.
if len(formatted) > 10: # noqa: PLR2004
formatted = str(f"({value:e})")
return formatted
if isinstance(value, (bytes, bytearray)):
return str(value.decode("utf-8"))
return str(value)
if isinstance(default_value, list):
default_value = [format_value(val) for val in default_value]
else:
default_value = format_value(default_value)
opt = f"default is {default_value}{doc_string}"
s += f"<dt><tt>{attr.name}</tt> : {display_attr_type(attr.type)}{f' ({opt})' if opt else ''}</dt>\n"
s += f"<dd>{attr.description}</dd>\n"
s += "</dl>\n"
# inputs
s += "\n#### Inputs"
if schema.min_input != schema.max_input:
s += f" ({display_number(schema.min_input)} - {display_number(schema.max_input)})"
s += "\n\n"
if schema.inputs:
s += "<dl>\n"
for input_ in schema.inputs:
option_str = generate_formal_parameter_tags(input_)
s += f"<dt><tt>{input_.name}</tt>{option_str} : {input_.type_str}</dt>\n"
s += f"<dd>{input_.description}</dd>\n"
s += "</dl>\n"
# outputs
s += "\n#### Outputs"
if schema.min_output != schema.max_output:
s += f" ({display_number(schema.min_output)} - {display_number(schema.max_output)})"
s += "\n\n"
if schema.outputs:
s += "<dl>\n"
for output in schema.outputs:
option_str = generate_formal_parameter_tags(output)
s += f"<dt><tt>{output.name}</tt>{option_str} : {output.type_str}</dt>\n"
s += f"<dd>{output.description}</dd>\n"
s += "</dl>\n"
# type constraints
s += "\n#### Type Constraints"
s += "\n\n"
if schema.type_constraints:
s += "<dl>\n"
for type_constraint in schema.type_constraints:
allowedTypes = type_constraint.allowed_type_strs
if len(allowedTypes) > 0:
allowedTypeStr = allowedTypes[0]
for allowedType in allowedTypes[1:]:
allowedTypeStr += ", " + allowedType
s += f"<dt><tt>{type_constraint.type_param_str}</tt> : {allowedTypeStr}</dt>\n"
s += f"<dd>{type_constraint.description}</dd>\n"
s += "</dl>\n"
# Function Body
# TODO: this should be refactored to show the function body graph's picture (DAG).
# if schema.has_function or schema.has_context_dependent_function: # type: ignore
# s += '\n#### Function\n'
# s += '\nThe Function can be represented as a function.\n'
return s
def support_level_str(level: OpSchema.SupportType) -> str:
return (
"<sub>experimental</sub> " if level == OpSchema.SupportType.EXPERIMENTAL else ""
)
class Args(NamedTuple):
output: str
changelog: str
def main(args: Args) -> None:
base_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
)
docs_dir = os.path.join(base_dir, "docs")
with open(
os.path.join(docs_dir, args.changelog), "w", newline="", encoding="utf-8"
) as fout:
fout.write("<!--- SPDX-License-Identifier: Apache-2.0 -->\n")
fout.write("## Operator Changelog\n")
fout.write(
"*This file is automatically generated from the\n"
" [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
" Do not modify directly and instead edit operator definitions.*\n"
"\n"
"For an operator input/output's differentiability, it can be differentiable,\n"
" non-differentiable, or undefined. If a variable's differentiability\n"
" is not specified, that variable has undefined differentiability.\n"
)
# domain -> version -> [schema]
dv_index: dict[str, dict[int, list[OpSchema]]] = defaultdict(
lambda: defaultdict(list)
)
for schema in defs.get_all_schemas_with_history():
dv_index[schema.domain][schema.since_version].append(schema)
fout.write("\n")
for domain, versionmap in sorted(dv_index.items()):
if not should_render_domain(domain, args.output):
continue
s = f"# {display_domain_short(domain)}\n"
for version, unsorted_schemas in sorted(versionmap.items()):
s += f"## Version {version} of {display_domain(domain)}\n"
for schema in sorted(unsorted_schemas, key=lambda s: s.name):
name_with_ver = f"{format_name_with_domain(domain, schema.name)}-{schema.since_version}"
s += (
'### <a name="{}"></a>**{}**'
+ (" (deprecated)" if schema.deprecated else "")
+ "</a>\n"
).format(name_with_ver, name_with_ver)
s += display_schema(schema, [schema], args.changelog)
s += "\n"
fout.write(s)
with open(
os.path.join(docs_dir, args.output), "w", newline="", encoding="utf-8"
) as fout:
fout.write("<!--- SPDX-License-Identifier: Apache-2.0 -->\n")
fout.write("## Operator Schemas\n")
fout.write(
"*This file is automatically generated from the\n"
" [def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).\n"
" Do not modify directly and instead edit operator definitions.*\n"
"\n"
"For an operator input/output's differentiability, it can be differentiable,\n"
" non-differentiable, or undefined. If a variable's differentiability\n"
" is not specified, that variable has undefined differentiability.\n"
)
# domain -> support level -> name -> [schema]
index: dict[str, dict[int, dict[str, list[OpSchema]]]] = defaultdict(
lambda: defaultdict(lambda: defaultdict(list))
)
for schema in defs.get_all_schemas_with_history():
index[schema.domain][int(schema.support_level)][schema.name].append(schema)
fout.write("\n")
# Preprocess the Operator Schemas
# [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])]
operator_schemas: list[
tuple[str, list[tuple[int, list[tuple[str, OpSchema, list[OpSchema]]]]]]
] = []
existing_ops: set[str] = set()
for domain, _supportmap in sorted(index.items()):
if not should_render_domain(domain, args.output):
continue
processed_supportmap = []
for _support, _namemap in sorted(_supportmap.items()):
processed_namemap = []
for n, unsorted_versions in sorted(_namemap.items()):
versions = sorted(unsorted_versions, key=lambda s: s.since_version)
schema = versions[-1]
if schema.name in existing_ops:
continue
existing_ops.add(schema.name)
processed_namemap.append((n, schema, versions))
processed_supportmap.append((_support, processed_namemap))
operator_schemas.append((domain, processed_supportmap))
# Table of contents
for domain, supportmap in operator_schemas:
s = f"### {display_domain_short(domain)}\n"
fout.write(s)
fout.write("|**Operator**|**Since version**||\n")
fout.write("|-|-|-|\n")
function_ops = []
for _, namemap in supportmap:
for n, schema, versions in namemap:
if schema.has_function or schema.has_context_dependent_function: # type: ignore
function_versions = schema.all_function_opset_versions # type: ignore
function_ops.append((n, schema, versions, function_versions))
continue
s = '|{}<a href="#{}">{}</a>{}|{}|\n'.format(
support_level_str(schema.support_level),
format_name_with_domain(domain, n),
format_name_with_domain(domain, n),
" (deprecated)" if schema.deprecated else "",
format_versions(versions, args.changelog),
)
fout.write(s)
if function_ops:
fout.write("|**Function**|**Since version**|**Function version**|\n")
for n, schema, versions, function_versions in function_ops:
s = '|{}<a href="#{}">{}</a>|{}|{}|\n'.format( # noqa: UP032
support_level_str(schema.support_level),
format_name_with_domain(domain, n),
format_name_with_domain(domain, n),
format_versions(versions, args.changelog),
format_function_versions(function_versions),
)
fout.write(s)
fout.write("\n")
fout.write("\n")
for domain, supportmap in operator_schemas:
s = f"## {display_domain_short(domain)}\n"
fout.write(s)
for _, namemap in supportmap:
for op_type, schema, versions in namemap:
# op_type
s = (
'### {}<a name="{}"></a><a name="{}">**{}**'
+ (" (deprecated)" if schema.deprecated else "")
+ "</a>\n"
).format(
support_level_str(schema.support_level),
format_name_with_domain(domain, op_type),
format_name_with_domain(domain, op_type.lower()),
format_name_with_domain(domain, op_type),
)
s += display_schema(schema, versions, args.changelog)
s += "\n\n"
if op_type in SNIPPETS:
s += "#### Examples\n\n"
for summary, code in sorted(SNIPPETS[op_type]):
s += "<details>\n"
s += f"<summary>{summary}</summary>\n\n"
s += f"```python\n{code}\n```\n\n"
s += "</details>\n"
s += "\n\n"
if op_type.lower() in SAMPLE_IMPLEMENTATIONS:
s += "#### Sample Implementation\n\n"
s += "<details>\n"
s += f"<summary>{op_type}</summary>\n\n"
s += f"```python\n{SAMPLE_IMPLEMENTATIONS[op_type.lower()]}\n```\n\n"
s += "</details>\n"
s += "\n\n"
fout.write(s)
if __name__ == "__main__":
if ONNX_ML:
main(
Args(
"Operators-ml.md",
"Changelog-ml.md",
)
)
main(
Args(
"Operators.md",
"Changelog.md",
)
)

View File

@ -0,0 +1,31 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from onnx import defs
def main() -> None:
# domain -> support level -> name -> [schema]
with_inference = []
without_inference = []
for schema in defs.get_all_schemas():
domain, name, has_inference = (
schema.domain,
schema.name,
schema.has_type_and_shape_inference_function,
)
elem = (domain, name)
if has_inference:
with_inference.append(elem)
else:
without_inference.append(elem)
print(len(with_inference), "operators have a type/shape inference function.")
print(len(without_inference), "do not. These are:")
for domain, name in sorted(without_inference):
print(domain, name)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,591 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <algorithm>
#include <cmath>
#include "onnx/defs/function.h"
#include "onnx/defs/generator/utils.h"
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
static const char* Constant_ver19_doc = R"DOC(
This operator produces a constant tensor. Exactly one of the provided attributes, either value, sparse_value,
or value_* must be specified.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
Constant,
21,
OpSchema()
.SetDoc(Constant_ver19_doc)
.Attr("value", "The value for the elements of the output tensor.", AttributeProto::TENSOR, false)
.Attr(
"sparse_value",
"The value for the elements of the output tensor in sparse format.",
AttributeProto::SPARSE_TENSOR,
false)
.Attr(
"value_int",
"The value for the sole element for the scalar, int64, output tensor.",
AttributeProto::INT,
false)
.Attr(
"value_ints",
"The values for the elements for the 1D, int64, output tensor.",
AttributeProto::INTS,
false)
.Attr(
"value_float",
"The value for the sole element for the scalar, float32, output tensor.",
AttributeProto::FLOAT,
false)
.Attr(
"value_floats",
"The values for the elements for the 1D, float32, output tensor.",
AttributeProto::FLOATS,
false)
.Attr(
"value_string",
"The value for the sole element for the scalar, UTF-8 string, output tensor.",
AttributeProto::STRING,
false)
.Attr(
"value_strings",
"The values for the elements for the 1D, UTF-8 string, output tensor.",
AttributeProto::STRINGS,
false)
.Output(0, "output", "Output tensor containing the same value of the provided tensor.", "T")
.TypeConstraint("T", OpSchema::all_tensor_types_ir10(), "Constrain input and output types to all tensor types.")
.TypeAndShapeInferenceFunction(ConstantOpInference));
static const char* ConstantOfShape_ver20_doc = R"DOC(
Generate a tensor with given value and shape.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
ConstantOfShape,
21,
OpSchema()
.SetDoc(ConstantOfShape_ver20_doc)
.Attr(
"value",
"(Optional) The value of the output elements."
"Should be a one-element tensor. If not specified, it defaults to a tensor of value 0 and datatype float32",
AttributeProto::TENSOR,
OPTIONAL_VALUE)
.Input(
0,
"input",
"1D tensor. The shape of the expected output tensor. If empty tensor is given, the output would be a scalar."
" All values must be >= 0.",
"T1")
.Output(
0,
"output",
"Output tensor of shape specified by 'input'."
"If attribute 'value' is specified, the value and datatype of the output tensor is taken from 'value'."
"If attribute 'value' is not specified, the value in the output defaults to 0, and the datatype "
"defaults to float32.",
"T2")
.TypeConstraint("T1", {"tensor(int64)"}, "Constrain input types.")
.TypeConstraint(
"T2",
{"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(uint4)",
"tensor(int4)",
"tensor(bool)",
"tensor(bfloat16)",
"tensor(float8e4m3fn)",
"tensor(float8e4m3fnuz)",
"tensor(float8e5m2)",
"tensor(float8e5m2fnuz)"},
"Constrain output types to be numerics or boolean.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getAttribute("value") != nullptr) {
propagateElemTypeFromDtypeToOutput(ctx, ctx.getAttribute("value"), 0);
} else {
propagateElemTypeFromDtypeToOutput(ctx, TensorProto::FLOAT, 0);
}
bool found = false;
TensorShapeProto output_shape = getShapeInput(ctx, 0, found);
if (found) {
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = output_shape;
}
}));
static const char* EyeLike_ver22_doc = R"DOC(
Generate a 2D tensor (matrix) with ones on the diagonal and zeros everywhere else. Only 2D
tensors are supported, i.e. input T1 must be of rank 2. The shape of the output tensor is the
same as the input tensor. The data type can be specified by the 'dtype' argument. If
'dtype' is not specified, then the type of input tensor is used. By default, the main diagonal
is populated with ones, but attribute 'k' can be used to populate upper or lower diagonals.
The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the
TensorProto message and be valid as an output type.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
EyeLike,
22,
OpSchema()
.SetDoc(EyeLike_ver22_doc)
.Attr(
"k",
"(Optional) Index of the diagonal to be populated with ones. Default is 0."
" If T2 is the output, this op sets T2[i, i+k] = 1. k = 0 populates the main diagonal, "
"k > 0 populates an upper diagonal, and k < 0 populates a lower diagonal.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr(
"dtype",
"(Optional) The data type for the elements of the output tensor. If not specified,"
"the data type of the input tensor T1 is used. If input tensor T1 is also not"
"specified, then type defaults to 'float'.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0, "input", "2D input tensor to copy shape, and optionally, type information from.", "T1")
.Output(0, "output", "Output tensor, same shape as input tensor T1.", "T2")
.TypeConstraint(
"T1",
OpSchema::all_non_complex_numeric_types_plus_bool_ir4(),
"Constrain input types. Strings and complex are not supported.")
.TypeConstraint(
"T2",
OpSchema::all_non_complex_numeric_types_plus_bool_ir4(),
"Constrain output types. Strings and complex are not supported.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr) {
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0);
} else {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
}
if (hasInputShape(ctx, 0)) {
auto& input_shape = getInputShape(ctx, 0);
if (input_shape.dim_size() != 2) {
fail_shape_inference("Input tensor must be 2-dimensional");
}
}
propagateShapeFromInputToOutput(ctx, 0, 0);
}));
static const char* RandomUniform_ver22_doc = R"DOC(
Generate a tensor with random values drawn from a uniform distribution. The shape
of the tensor is specified by the `shape` argument and the range by `low` and `high`.
The data type is specified by the 'dtype' argument. The 'dtype' argument must
be one of the data types specified in the 'DataType' enum field in the
TensorProto message.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
RandomUniform,
22,
OpSchema()
.SetDoc(RandomUniform_ver22_doc)
.Attr("low", "Lower boundary of the output values.", AttributeProto::FLOAT, 0.0f)
.Attr("high", "Upper boundary of the output values.", AttributeProto::FLOAT, 1.0f)
.Attr(
"seed",
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr(
"dtype",
"The data type for the elements of the output tensor. If not specified, default is TensorProto::FLOAT.",
AttributeProto::INT,
static_cast<int64_t>(TensorProto::FLOAT))
.Attr("shape", "The shape of the output tensor.", AttributeProto::INTS)
.Output(0, "output", "Output tensor of random values drawn from uniform distribution", "T")
.TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0, TensorProto::FLOAT);
propagateShapeFromAttributeToOutput(ctx, "shape", 0);
}));
static const char* RandomNormal_ver22_doc = R"DOC(
Generate a tensor with random values drawn from a normal distribution. The shape
of the tensor is specified by the `shape` argument and the parameter of the normal distribution
specified by `mean` and `scale`.
The data type is specified by the 'dtype' argument. The 'dtype' argument must
be one of the data types specified in the 'DataType' enum field in the
TensorProto message.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
RandomNormal,
22,
OpSchema()
.SetDoc(RandomNormal_ver22_doc)
.Attr("mean", "The mean of the normal distribution.", AttributeProto::FLOAT, 0.0f)
.Attr("scale", "The standard deviation of the normal distribution.", AttributeProto::FLOAT, 1.0f)
.Attr(
"seed",
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr(
"dtype",
"The data type for the elements of the output tensor. Default is TensorProto::FLOAT.",
AttributeProto::INT,
static_cast<int64_t>(TensorProto::FLOAT))
.Attr("shape", "The shape of the output tensor.", AttributeProto::INTS)
.Output(0, "output", "Output tensor of random values drawn from normal distribution", "T")
.TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0, TensorProto::FLOAT);
propagateShapeFromAttributeToOutput(ctx, "shape", 0);
}));
static const char* RandomUniformLike_ver22_doc = R"DOC(
Generate a tensor with random values drawn from a uniform distribution.
The shape of the output tensor is copied from the shape of the input tensor,
and the parameters of the uniform distribution are specified by `low` and `high`.
The data type is specified by the 'dtype' argument, or copied from the input tensor if not provided.
The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the
TensorProto message and be valid as an output type.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
RandomUniformLike,
22,
OpSchema()
.SetDoc(RandomUniformLike_ver22_doc)
.Attr("low", "Lower boundary of the output values.", AttributeProto::FLOAT, 0.0f)
.Attr("high", "Upper boundary of the output values.", AttributeProto::FLOAT, 1.0f)
.Attr(
"seed",
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr(
"dtype",
"(Optional) The data type for the elements of the output tensor, if not specified, we will use "
"the data type of the input tensor.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0, "input", "Input tensor to copy shape and optionally type information from.", "T1")
.Output(0, "output", "Output tensor of random values drawn from uniform distribution", "T2")
.TypeConstraint(
"T1",
OpSchema::all_tensor_types_ir4(),
"Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.")
.TypeConstraint("T2", OpSchema::all_float_types_ir4(), "Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr)
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0);
else
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 1)) {
return;
}
propagateShapeFromInputToOutput(ctx, 0, 0);
}));
static const char* RandomNormalLike_ver22_doc = R"DOC(
Generate a tensor with random values drawn from a normal distribution.
The shape of the output tensor is copied from the shape of the input tensor,
and the parameters of the normal distribution are specified by `mean` and `scale`.
The data type is specified by the 'dtype' argument, or copied from the input tensor if not provided.
The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the
TensorProto message, and be valid as an output type.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
RandomNormalLike,
22,
OpSchema()
.SetDoc(RandomNormalLike_ver22_doc)
.Attr("mean", "The mean of the normal distribution.", AttributeProto::FLOAT, 0.0f)
.Attr("scale", "The standard deviation of the normal distribution.", AttributeProto::FLOAT, 1.0f)
.Attr(
"seed",
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr(
"dtype",
"(Optional) The data type for the elements of the output tensor, if not specified, we will use "
"the data type of the input tensor.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0, "input", "Input tensor to copy shape and optionally type information from.", "T1")
.Output(0, "output", "Output tensor of random values drawn from normal distribution", "T2")
.TypeConstraint(
"T1",
OpSchema::all_tensor_types_ir4(),
"Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.")
.TypeConstraint("T2", OpSchema::all_float_types_ir4(), "Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr)
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0);
else
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 1)) {
return;
}
propagateShapeFromInputToOutput(ctx, 0, 0);
}));
static const char* Multinomial_ver22_doc = R"DOC(
Generate a tensor of samples from a multinomial distribution according to the probabilities
of each of the possible outcomes.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
Multinomial,
22,
OpSchema()
.SetDoc(Multinomial_ver22_doc)
.Attr("sample_size", "Number of times to sample.", AttributeProto::INT, static_cast<int64_t>(1))
.Attr(
"seed",
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr(
"dtype",
"(Optional) The data type for the elements of the output tensor, if not specified, we will use int32.",
AttributeProto::INT,
static_cast<int64_t>(TensorProto::INT32))
.Input(
0,
"input",
"Input tensor with shape [batch_size, class_size], where class_size is the number of all possible outcomes. Each value along the axis zero represents the unnormalized log-probability of each corresponding outcome in a batch.",
"T1")
.Output(
0,
"output",
"Output tensor with shape [batch_size, sample_size], where sample_size is the number of times to sample. Each value along the axis zero represents the outcome of the corresponding sample in a batch.",
"T2")
.TypeConstraint("T1", OpSchema::all_float_types_ir4(), "Constrain input types to float tensors.")
.TypeConstraint("T2", {"tensor(int32)", "tensor(int64)"}, "Constrain output types to integral tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
auto dtype = ctx.getAttribute("dtype");
auto dataType = TensorProto_DataType::TensorProto_DataType_INT32;
if (dtype != nullptr) {
dataType = static_cast<TensorProto_DataType>(dtype->i());
if (dataType != TensorProto_DataType::TensorProto_DataType_INT32 &&
dataType != TensorProto_DataType::TensorProto_DataType_INT64) {
fail_type_inference("Output type must be int32 or int64");
}
}
updateOutputElemType(ctx, 0, dataType);
TensorShapeProto::Dimension batch_size, sample_size;
if (hasInputShape(ctx, 0)) {
auto& input_shape = getInputShape(ctx, 0);
if (input_shape.dim_size() != 2) {
fail_shape_inference("Input tensor must have rank 2");
}
batch_size = input_shape.dim(0);
} // else statically-unknown batch-size
sample_size.set_dim_value(getAttribute(ctx, "sample_size", 1));
updateOutputShape(ctx, 0, {batch_size, sample_size});
}));
static const char* Range_ver11_doc = R"DOC(
Generate a tensor containing a sequence of numbers that begin at `start` and extends by increments of `delta`
up to `limit` (exclusive).
The number of elements in the output of range is computed as below:
```
number_of_elements = max( ceil( (limit - start) / delta ) , 0 )
```
The pseudocode determining the contents of the output is shown below:
```
for(int i=0; i<number_of_elements; ++i) {
output[i] = start + (i * delta);
}
```
Example 1
```
Inputs: start = 3, limit = 9, delta = 3
Output: [3, 6]
```
Example 2
```
Inputs: start = 10, limit = 4, delta = -2
Output: [10, 8, 6]
```
)DOC";
template <typename T>
inline int64_t
compute_output_dim_for_range(const TensorProto* start, const TensorProto* limit, const TensorProto* delta) {
if (start->dims().size() != 0 || limit->dims().size() != 0 || delta->dims().size() != 0) {
fail_shape_inference("Input to 'Range' op should be scalars (Tensor with only one element and shape empty)");
}
const auto& start_data = ParseData<T>(start);
const auto& limit_data = ParseData<T>(limit);
const auto& delta_data = ParseData<T>(delta);
int64_t n = static_cast<int64_t>(ceil((1.0 * (limit_data[0] - start_data[0])) / delta_data[0]));
if (n < 0)
n = 0;
return n;
}
ONNX_OPERATOR_SET_SCHEMA(
Range,
11,
OpSchema()
.SetDoc(Range_ver11_doc)
.Input(0, "start", "Scalar. First entry for the range of output values.", "T")
.Input(1, "limit", "Scalar. Exclusive upper limit for the range of output values.", "T")
.Input(2, "delta", "Scalar. Value to step by.", "T")
.Output(0, "output", "A 1-D tensor with same type as the inputs containing generated range of values.", "T")
.TypeConstraint(
"T",
{"tensor(float)", "tensor(double)", "tensor(int16)", "tensor(int32)", "tensor(int64)"},
"Constrain input types to common numeric type tensors.")
.FunctionBody(R"ONNX(
{
sub_result = Sub (limit, start)
sub_result_casted = Cast <to = 1> (sub_result)
delta_casted = Cast <to = 1> (delta)
div_result = Div (sub_result_casted, delta_casted)
ceil_result = Ceil (div_result)
ceil_result_relu = Relu (ceil_result)
ceil_result_relu_int = Cast <to = 7> (ceil_result_relu)
ceil_result_relu_bool = Cast <to = 9> (ceil_result_relu)
variadic_output, output = Loop (ceil_result_relu_int, ceil_result_relu_bool, start)
<body = loop_body_attribute (int64 i, bool cond, prev) => (cond_out, current, range) {
cond_out = Identity (cond)
current = Add (prev, delta)
range = Identity (prev)
}>
}
)ONNX")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// Type inference
propagateElemTypeFromInputToOutput(ctx, 0, 0);
// Shape inference
const auto* start_initializer = ctx.getInputData(0);
const auto* limit_initializer = ctx.getInputData(1);
const auto* delta_initializer = ctx.getInputData(2);
// Output is always 1-D
auto* output_dim = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape()->add_dim();
// If any of Range's inputs are not initializers, the output dimension
// value would remain unknown.
if (start_initializer != nullptr && limit_initializer != nullptr && delta_initializer != nullptr) {
// Make sure the input types are homogeneous
if ((start_initializer->data_type() != limit_initializer->data_type()) ||
(start_initializer->data_type() != delta_initializer->data_type())) {
fail_shape_inference("All inputs to 'Range' op must be of the same type");
}
// Explicitly compute the output dimension if Range's inputs are
// stored in initializer list.
if (start_initializer->data_type() == TensorProto::FLOAT) {
output_dim->set_dim_value(
compute_output_dim_for_range<float>(start_initializer, limit_initializer, delta_initializer));
} else if (start_initializer->data_type() == TensorProto::INT32) {
output_dim->set_dim_value(
compute_output_dim_for_range<int32_t>(start_initializer, limit_initializer, delta_initializer));
} else if (start_initializer->data_type() == TensorProto::INT64) {
output_dim->set_dim_value(
compute_output_dim_for_range<int64_t>(start_initializer, limit_initializer, delta_initializer));
} else if (start_initializer->data_type() == TensorProto::DOUBLE) {
output_dim->set_dim_value(
compute_output_dim_for_range<double>(start_initializer, limit_initializer, delta_initializer));
} else {
// 'float16' has no native CPU type -
// stop with rank inference, no action here
}
return;
}
}));
static const char* Bernoulli_ver22_doc = R"DOC(
Draws binary random numbers (0 or 1) from a Bernoulli distribution. The input tensor should be a tensor
containing probabilities p (a value in the range [0,1]) to be used for drawing the binary random number,
where an output of 1 is produced with probability p and an output of 0 is produced with probability (1-p).
This operator is non-deterministic and may not produce the same values in different
implementations (even if a seed is specified).
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
Bernoulli,
22,
OpSchema()
.SetDoc(Bernoulli_ver22_doc)
.Attr(
"seed",
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr(
"dtype",
"The data type for the elements of the output tensor. if not specified, we will use "
"the data type of the input tensor.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0, "input", "All values in input have to be in the range:[0, 1].", "T1")
.Output(0, "output", "The returned output tensor only has values 0 or 1, same shape as input tensor.", "T2")
.TypeConstraint("T1", OpSchema::all_float_types_ir4(), "Constrain input types to float tensors.")
.TypeConstraint(
"T2",
OpSchema::all_non_complex_numeric_types_plus_bool_ir4(),
"Constrain output types to all numeric tensors and bool tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr)
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0);
else
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 1)) {
return;
}
propagateShapeFromInputToOutput(ctx, 0, 0);
})
.SetContextDependentFunctionBodyBuilder(
[](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) -> bool {
if (ctx.getInputType(0) == nullptr) {
// we cannot create a correct function body without knowing the input type
return false;
}
auto input_type = ctx.getInputType(0)->tensor_type().elem_type();
auto dtype = ctx.getAttribute("dtype") != nullptr
? static_cast<TensorProto_DataType>(ctx.getAttribute("dtype")->i())
: input_type;
FunctionBuilder builder(functionProto);
builder
.Add(
"X_random = RandomUniformLike <low = 0.0, high = 1.0, seed = @seed> (input)",
"dtype",
int64_t(input_type))
.Add("X_greater = Greater (X_random, input)")
.Add("output = Cast (X_greater)", "to", int64_t(dtype));
schema.BuildFunction(functionProto);
return true;
}));
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,768 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <algorithm>
#include <functional>
#include "onnx/defs/function.h"
#include "onnx/defs/generator/utils.h"
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
static const char* Bernoulli_ver15_doc = R"DOC(
Draws binary random numbers (0 or 1) from a Bernoulli distribution. The input tensor should be a tensor
containing probabilities p (a value in the range [0,1]) to be used for drawing the binary random number,
where an output of 1 is produced with probability p and an output of 0 is produced with probability (1-p).
This operator is non-deterministic and may not produce the same values in different
implementations (even if a seed is specified).
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
Bernoulli,
15,
OpSchema()
.SetDoc(Bernoulli_ver15_doc)
.Attr(
"seed",
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr(
"dtype",
"The data type for the elements of the output tensor. if not specified, we will use "
"the data type of the input tensor.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0, "input", "All values in input have to be in the range:[0, 1].", "T1")
.Output(0, "output", "The returned output tensor only has values 0 or 1, same shape as input tensor.", "T2")
.TypeConstraint(
"T1",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input types to float tensors.")
.TypeConstraint(
"T2",
{"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(bfloat16)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(bool)"},
"Constrain output types to all numeric tensors and bool tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr)
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0);
else
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 1)) {
return;
}
propagateShapeFromInputToOutput(ctx, 0, 0);
})
.SetContextDependentFunctionBodyBuilder(
[](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) -> bool {
if (ctx.getInputType(0) == nullptr) {
// we cannot create a correct function body without knowing the input type
return false;
}
auto input_type = ctx.getInputType(0)->tensor_type().elem_type();
auto dtype = ctx.getAttribute("dtype") != nullptr
? static_cast<TensorProto_DataType>(ctx.getAttribute("dtype")->i())
: input_type;
FunctionBuilder builder(functionProto);
builder
.Add(
"X_random = RandomUniformLike <low = 0.0, high = 1.0, seed = @seed> (input)",
"dtype",
int64_t(input_type))
.Add("X_greater = Greater (X_random, input)")
.Add("output = Cast (X_greater)", "to", int64_t(dtype));
schema.BuildFunction(functionProto);
return true;
}));
static const char* Multinomial_ver7_doc = R"DOC(
Generate a tensor of samples from a multinomial distribution according to the probabilities
of each of the possible outcomes.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
Multinomial,
7,
OpSchema()
.SetDoc(Multinomial_ver7_doc)
.Attr("sample_size", "Number of times to sample.", AttributeProto::INT, static_cast<int64_t>(1))
.Attr(
"seed",
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr(
"dtype",
"(Optional) The data type for the elements of the output tensor, if not specified, we will use int32.",
AttributeProto::INT,
static_cast<int64_t>(TensorProto::INT32))
.Input(
0,
"input",
"Input tensor with shape [batch_size, class_size], where class_size is the number of all possible outcomes. Each value along the axis zero represents the unnormalized log-probability of each corresponding outcome in a batch.",
"T1")
.Output(
0,
"output",
"Output tensor with shape [batch_size, sample_size], where sample_size is the number of times to sample. Each value along the axis zero represents the outcome of the corresponding sample in a batch.",
"T2")
.TypeConstraint(
"T1",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input types to float tensors.")
.TypeConstraint("T2", {"tensor(int32)", "tensor(int64)"}, "Constrain output types to integral tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
auto dtype = ctx.getAttribute("dtype");
auto dataType = TensorProto_DataType::TensorProto_DataType_INT32;
if (dtype != nullptr) {
dataType = static_cast<TensorProto_DataType>(dtype->i());
if (dataType != TensorProto_DataType::TensorProto_DataType_INT32 &&
dataType != TensorProto_DataType::TensorProto_DataType_INT64) {
fail_type_inference("Output type must be int32 or int64");
}
}
updateOutputElemType(ctx, 0, dataType);
TensorShapeProto::Dimension batch_size, sample_size;
if (hasInputShape(ctx, 0)) {
auto& input_shape = getInputShape(ctx, 0);
if (input_shape.dim_size() != 2) {
fail_shape_inference("Input tensor must have rank 2");
}
batch_size = input_shape.dim(0);
} // else statically-unknown batch-size
sample_size.set_dim_value(getAttribute(ctx, "sample_size", 1));
updateOutputShape(ctx, 0, {batch_size, sample_size});
}));
static const char* RandomNormalLike_ver1_doc = R"DOC(
Generate a tensor with random values drawn from a normal distribution.
The shape of the output tensor is copied from the shape of the input tensor,
and the parameters of the normal distribution are specified by `mean` and `scale`.
The data type is specified by the 'dtype' argument, or copied from the input tensor if not provided.
The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the
TensorProto message, and be valid as an output type.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
RandomNormalLike,
1,
OpSchema()
.SetDoc(RandomNormalLike_ver1_doc)
.Attr("mean", "The mean of the normal distribution.", AttributeProto::FLOAT, 0.0f)
.Attr("scale", "The standard deviation of the normal distribution.", AttributeProto::FLOAT, 1.0f)
.Attr(
"seed",
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr(
"dtype",
"(Optional) The data type for the elements of the output tensor, if not specified, we will use "
"the data type of the input tensor.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0, "input", "Input tensor to copy shape and optionally type information from.", "T1")
.Output(0, "output", "Output tensor of random values drawn from normal distribution", "T2")
.TypeConstraint(
"T1",
OpSchema::all_tensor_types(),
"Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.")
.TypeConstraint(
"T2",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr)
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0);
else
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 1)) {
return;
}
propagateShapeFromInputToOutput(ctx, 0, 0);
}));
static const char* RandomUniformLike_ver1_doc = R"DOC(
Generate a tensor with random values drawn from a uniform distribution.
The shape of the output tensor is copied from the shape of the input tensor,
and the parameters of the uniform distribution are specified by `low` and `high`.
The data type is specified by the 'dtype' argument, or copied from the input tensor if not provided.
The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the
TensorProto message and be valid as an output type.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
RandomUniformLike,
1,
OpSchema()
.SetDoc(RandomUniformLike_ver1_doc)
.Attr("low", "Lower boundary of the output values.", AttributeProto::FLOAT, 0.0f)
.Attr("high", "Upper boundary of the output values.", AttributeProto::FLOAT, 1.0f)
.Attr(
"seed",
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr(
"dtype",
"(Optional) The data type for the elements of the output tensor, if not specified, we will use "
"the data type of the input tensor.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0, "input", "Input tensor to copy shape and optionally type information from.", "T1")
.Output(0, "output", "Output tensor of random values drawn from uniform distribution", "T2")
.TypeConstraint(
"T1",
OpSchema::all_tensor_types(),
"Constrain to any tensor type. If the dtype attribute is not provided this must be a valid output type.")
.TypeConstraint(
"T2",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr)
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0);
else
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 1)) {
return;
}
propagateShapeFromInputToOutput(ctx, 0, 0);
}));
static const char* RandomNormal_ver1_doc = R"DOC(
Generate a tensor with random values drawn from a normal distribution. The shape
of the tensor is specified by the `shape` argument and the parameter of the normal distribution
specified by `mean` and `scale`.
The data type is specified by the 'dtype' argument. The 'dtype' argument must
be one of the data types specified in the 'DataType' enum field in the
TensorProto message.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
RandomNormal,
1,
OpSchema()
.SetDoc(RandomNormal_ver1_doc)
.Attr("mean", "The mean of the normal distribution.", AttributeProto::FLOAT, 0.0f)
.Attr("scale", "The standard deviation of the normal distribution.", AttributeProto::FLOAT, 1.0f)
.Attr(
"seed",
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr(
"dtype",
"The data type for the elements of the output tensor. Default is TensorProto::FLOAT.",
AttributeProto::INT,
static_cast<int64_t>(TensorProto::FLOAT))
.Attr("shape", "The shape of the output tensor.", AttributeProto::INTS)
.Output(0, "output", "Output tensor of random values drawn from normal distribution", "T")
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0, TensorProto::FLOAT);
propagateShapeFromAttributeToOutput(ctx, "shape", 0);
}));
static const char* RandomUniform_ver1_doc = R"DOC(
Generate a tensor with random values drawn from a uniform distribution. The shape
of the tensor is specified by the `shape` argument and the range by `low` and `high`.
The data type is specified by the 'dtype' argument. The 'dtype' argument must
be one of the data types specified in the 'DataType' enum field in the
TensorProto message.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
RandomUniform,
1,
OpSchema()
.SetDoc(RandomUniform_ver1_doc)
.Attr("low", "Lower boundary of the output values.", AttributeProto::FLOAT, 0.0f)
.Attr("high", "Upper boundary of the output values.", AttributeProto::FLOAT, 1.0f)
.Attr(
"seed",
"(Optional) Seed to the random generator, if not specified we will auto generate one.",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr(
"dtype",
"The data type for the elements of the output tensor. If not specified, default is TensorProto::FLOAT.",
AttributeProto::INT,
static_cast<int64_t>(TensorProto::FLOAT))
.Attr("shape", "The shape of the output tensor.", AttributeProto::INTS)
.Output(0, "output", "Output tensor of random values drawn from uniform distribution", "T")
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0, TensorProto::FLOAT);
propagateShapeFromAttributeToOutput(ctx, "shape", 0);
}));
static const char* EyeLike_ver9_doc = R"DOC(
Generate a 2D tensor (matrix) with ones on the diagonal and zeros everywhere else. Only 2D
tensors are supported, i.e. input T1 must be of rank 2. The shape of the output tensor is the
same as the input tensor. The data type can be specified by the 'dtype' argument. If
'dtype' is not specified, then the type of input tensor is used. By default, the main diagonal
is populated with ones, but attribute 'k' can be used to populate upper or lower diagonals.
The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the
TensorProto message and be valid as an output type.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
EyeLike,
9,
OpSchema()
.SetDoc(EyeLike_ver9_doc)
.Attr(
"k",
"(Optional) Index of the diagonal to be populated with ones. Default is 0."
" If T2 is the output, this op sets T2[i, i+k] = 1. k = 0 populates the main diagonal, "
"k > 0 populates an upper diagonal, and k < 0 populates a lower diagonal.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr(
"dtype",
"(Optional) The data type for the elements of the output tensor. If not specified,"
"the data type of the input tensor T1 is used. If input tensor T1 is also not"
"specified, then type defaults to 'float'.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0, "input", "2D input tensor to copy shape, and optionally, type information from.", "T1")
.Output(0, "output", "Output tensor, same shape as input tensor T1.", "T2")
.TypeConstraint(
"T1",
{"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(bool)"},
"Constrain input types. Strings and complex are not supported.")
.TypeConstraint(
"T2",
{"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(bool)"},
"Constrain output types. Strings and complex are not supported.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getAttribute("dtype") != nullptr) {
propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0);
} else {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
}
if (hasInputShape(ctx, 0)) {
auto& input_shape = getInputShape(ctx, 0);
if (input_shape.dim_size() != 2) {
fail_shape_inference("Input tensor must be 2-dimensional");
}
}
propagateShapeFromInputToOutput(ctx, 0, 0);
}));
static const char* Constant_ver19_doc = R"DOC(
This operator produces a constant tensor. Exactly one of the provided attributes, either value, sparse_value,
or value_* must be specified.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
Constant,
19,
OpSchema()
.SetDoc(Constant_ver19_doc)
.Attr("value", "The value for the elements of the output tensor.", AttributeProto::TENSOR, false)
.Attr(
"sparse_value",
"The value for the elements of the output tensor in sparse format.",
AttributeProto::SPARSE_TENSOR,
false)
.Attr(
"value_int",
"The value for the sole element for the scalar, int64, output tensor.",
AttributeProto::INT,
false)
.Attr(
"value_ints",
"The values for the elements for the 1D, int64, output tensor.",
AttributeProto::INTS,
false)
.Attr(
"value_float",
"The value for the sole element for the scalar, float32, output tensor.",
AttributeProto::FLOAT,
false)
.Attr(
"value_floats",
"The values for the elements for the 1D, float32, output tensor.",
AttributeProto::FLOATS,
false)
.Attr(
"value_string",
"The value for the sole element for the scalar, UTF-8 string, output tensor.",
AttributeProto::STRING,
false)
.Attr(
"value_strings",
"The values for the elements for the 1D, UTF-8 string, output tensor.",
AttributeProto::STRINGS,
false)
.Output(0, "output", "Output tensor containing the same value of the provided tensor.", "T")
.TypeConstraint("T", OpSchema::all_tensor_types_ir9(), "Constrain input and output types to all tensor types.")
.TypeAndShapeInferenceFunction(ConstantOpInference));
static const char* Constant_ver13_doc = R"DOC(
This operator produces a constant tensor. Exactly one of the provided attributes, either value, sparse_value,
or value_* must be specified.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
Constant,
13,
OpSchema()
.SetDoc(Constant_ver13_doc)
.Attr("value", "The value for the elements of the output tensor.", AttributeProto::TENSOR, false)
.Attr(
"sparse_value",
"The value for the elements of the output tensor in sparse format.",
AttributeProto::SPARSE_TENSOR,
false)
.Attr(
"value_int",
"The value for the sole element for the scalar, int64, output tensor.",
AttributeProto::INT,
false)
.Attr(
"value_ints",
"The values for the elements for the 1D, int64, output tensor.",
AttributeProto::INTS,
false)
.Attr(
"value_float",
"The value for the sole element for the scalar, float32, output tensor.",
AttributeProto::FLOAT,
false)
.Attr(
"value_floats",
"The values for the elements for the 1D, float32, output tensor.",
AttributeProto::FLOATS,
false)
.Attr(
"value_string",
"The value for the sole element for the scalar, UTF-8 string, output tensor.",
AttributeProto::STRING,
false)
.Attr(
"value_strings",
"The values for the elements for the 1D, UTF-8 string, output tensor.",
AttributeProto::STRINGS,
false)
.Output(0, "output", "Output tensor containing the same value of the provided tensor.", "T")
.TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types.")
.TypeAndShapeInferenceFunction(ConstantOpInference));
static const char* Constant_ver12_doc = R"DOC(
This operator produces a constant tensor. Exactly one of the provided attributes, either value, sparse_value,
or value_* must be specified.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
Constant,
12,
OpSchema()
.SetDoc(Constant_ver12_doc)
.Attr("value", "The value for the elements of the output tensor.", AttributeProto::TENSOR, false)
.Attr(
"sparse_value",
"The value for the elements of the output tensor in sparse format.",
AttributeProto::SPARSE_TENSOR,
false)
.Attr(
"value_int",
"The value for the sole element for the scalar, int64, output tensor.",
AttributeProto::INT,
false)
.Attr(
"value_ints",
"The values for the elements for the 1D, int64, output tensor.",
AttributeProto::INTS,
false)
.Attr(
"value_float",
"The value for the sole element for the scalar, float32, output tensor.",
AttributeProto::FLOAT,
false)
.Attr(
"value_floats",
"The values for the elements for the 1D, float32, output tensor.",
AttributeProto::FLOATS,
false)
.Attr(
"value_string",
"The value for the sole element for the scalar, UTF-8 string, output tensor.",
AttributeProto::STRING,
false)
.Attr(
"value_strings",
"The values for the elements for the 1D, UTF-8 string, output tensor.",
AttributeProto::STRINGS,
false)
.Output(0, "output", "Output tensor containing the same value of the provided tensor.", "T")
.TypeConstraint("T", OpSchema::all_tensor_types(), "Constrain input and output types to all tensor types.")
.TypeAndShapeInferenceFunction(ConstantOpInference));
static const char* Constant_ver1_doc = R"DOC(A constant tensor.)DOC";
ONNX_OPERATOR_SET_SCHEMA(
Constant,
1,
OpSchema()
.SetDoc(Constant_ver1_doc)
.Attr("value", "The value for the elements of the output tensor.", AttributeProto::TENSOR)
.Output(0, "output", "Output tensor containing the same value of the provided tensor.", "T")
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
auto attr_proto = ctx.getAttribute("value");
if (nullptr == attr_proto)
return; // attribute not present
if (!attr_proto->has_t())
return; // attribute has no tensor value
const TensorProto& tensor_proto = attr_proto->t();
updateOutputElemType(ctx, 0, tensor_proto.data_type());
updateOutputShape(ctx, 0, tensor_proto);
}));
static const char* Constant_ver9_doc = R"DOC(A constant tensor.)DOC";
ONNX_OPERATOR_SET_SCHEMA(
Constant,
9,
OpSchema()
.SetDoc(Constant_ver9_doc)
.Attr("value", "The value for the elements of the output tensor.", AttributeProto::TENSOR)
.Output(0, "output", "Output tensor containing the same value of the provided tensor.", "T")
.TypeConstraint("T", OpSchema::all_tensor_types(), "Constrain input and output types to all tensor types.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
auto attr_proto = ctx.getAttribute("value");
if (nullptr == attr_proto || !attr_proto->has_t())
fail_shape_inference("Attribute 'value' of Constant node must exist with 'Tensor' data.");
const TensorProto& tensor_proto = attr_proto->t();
updateOutputElemType(ctx, 0, tensor_proto.data_type());
updateOutputShape(ctx, 0, tensor_proto);
}));
static const char* Constant_ver11_doc = R"DOC(
A constant tensor. Exactly one of the two attributes, either value or sparse_value,
must be specified.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
Constant,
11,
OpSchema()
.SetDoc(Constant_ver11_doc)
.Attr("value", "The value for the elements of the output tensor.", AttributeProto::TENSOR, false)
.Attr(
"sparse_value",
"The value for the elements of the output tensor in sparse format.",
AttributeProto::SPARSE_TENSOR,
false)
.Output(0, "output", "Output tensor containing the same value of the provided tensor.", "T")
.TypeConstraint("T", OpSchema::all_tensor_types(), "Constrain input and output types to all tensor types.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
auto* value = ctx.getAttribute("value");
auto* sparse_value = ctx.getAttribute("sparse_value");
if ((nullptr != value) && (nullptr != sparse_value))
fail_shape_inference(
"Only one of the attributes 'value' or 'sparse_value' must be specified for a Constant node.");
if (nullptr != value) {
// OpSchema::Verify check ensures that the attribute value has_t():
const TensorProto& tensor_proto = value->t();
updateOutputElemType(ctx, 0, tensor_proto.data_type());
updateOutputShape(ctx, 0, tensor_proto);
return;
}
if (nullptr != sparse_value) {
// OpSchema::Verify check ensures that the attribute value
// has_sparse_tensor():
const SparseTensorProto& sparse = sparse_value->sparse_tensor();
// checker.cc::check_sparse_tensor checks that the sparse-value is
// well-formed
updateOutputElemType(ctx, 0, sparse.values().data_type());
auto* output_shape = getOutputShape(ctx, 0);
for (int i = 0; i < sparse.dims_size(); ++i)
appendDim(output_shape, sparse.dims(i));
return;
}
fail_shape_inference(
"One of the attributes 'value' or 'sparse_value' must be specified for a Constant node.");
}));
static const char* ConstantOfShape_ver20_doc = R"DOC(
Generate a tensor with given value and shape.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
ConstantOfShape,
20,
OpSchema()
.SetDoc(ConstantOfShape_ver20_doc)
.Attr(
"value",
"(Optional) The value of the output elements."
"Should be a one-element tensor. If not specified, it defaults to a tensor of value 0 and datatype float32",
AttributeProto::TENSOR,
OPTIONAL_VALUE)
.Input(
0,
"input",
"1D tensor. The shape of the expected output tensor. If empty tensor is given, the output would be a scalar."
" All values must be >= 0.",
"T1")
.Output(
0,
"output",
"Output tensor of shape specified by 'input'."
"If attribute 'value' is specified, the value and datatype of the output tensor is taken from 'value'."
"If attribute 'value' is not specified, the value in the output defaults to 0, and the datatype "
"defaults to float32.",
"T2")
.TypeConstraint("T1", {"tensor(int64)"}, "Constrain input types.")
.TypeConstraint(
"T2",
{"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(bool)",
"tensor(bfloat16)",
"tensor(float8e4m3fn)",
"tensor(float8e4m3fnuz)",
"tensor(float8e5m2)",
"tensor(float8e5m2fnuz)"},
"Constrain output types to be numerics.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getAttribute("value") != nullptr) {
propagateElemTypeFromDtypeToOutput(ctx, ctx.getAttribute("value"), 0);
} else {
propagateElemTypeFromDtypeToOutput(ctx, TensorProto::FLOAT, 0);
}
bool found = false;
TensorShapeProto output_shape = getShapeInput(ctx, 0, found);
if (found) {
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = output_shape;
}
}));
static const char* ConstantOfShape_ver9_doc = R"DOC(
Generate a tensor with given value and shape.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
ConstantOfShape,
9,
OpSchema()
.SetDoc(ConstantOfShape_ver9_doc)
.Attr(
"value",
"(Optional) The value of the output elements."
"Should be a one-element tensor. If not specified, it defaults to a tensor of value 0 and datatype float32",
AttributeProto::TENSOR,
OPTIONAL_VALUE)
.Input(
0,
"input",
"1D tensor. The shape of the expected output tensor. If empty tensor is given, the output would be a scalar."
" All values must be >= 0.",
"T1")
.Output(
0,
"output",
"Output tensor of shape specified by 'input'."
"If attribute 'value' is specified, the value and datatype of the output tensor is taken from 'value'."
"If attribute 'value' is not specified, the value in the output defaults to 0, and the datatype "
"defaults to float32.",
"T2")
.TypeConstraint("T1", {"tensor(int64)"}, "Constrain input types.")
.TypeConstraint(
"T2",
{"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(bool)"},
"Constrain output types to be numerics.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (ctx.getAttribute("value") != nullptr) {
propagateElemTypeFromDtypeToOutput(ctx, ctx.getAttribute("value"), 0);
} else {
propagateElemTypeFromDtypeToOutput(ctx, TensorProto::FLOAT, 0);
}
bool found = false;
TensorShapeProto output_shape = getShapeInput(ctx, 0, found);
if (found) {
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = output_shape;
}
}));
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,111 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/generator/utils.h"
#include <algorithm>
#include <cmath>
namespace ONNX_NAMESPACE {
void ConstantOpInference(InferenceContext& ctx) {
auto* value = ctx.getAttribute("value");
auto* sparse_value = ctx.getAttribute("sparse_value");
auto* value_int = ctx.getAttribute("value_int");
auto* value_ints = ctx.getAttribute("value_ints");
auto* value_float = ctx.getAttribute("value_float");
auto* value_floats = ctx.getAttribute("value_floats");
auto* value_string = ctx.getAttribute("value_string");
auto* value_strings = ctx.getAttribute("value_strings");
std::vector<bool> non_null_attr = {
(nullptr != value),
(nullptr != sparse_value),
(nullptr != value_int),
(nullptr != value_ints),
(nullptr != value_float),
(nullptr != value_floats),
(nullptr != value_string),
(nullptr != value_strings)};
if (std::count(non_null_attr.begin(), non_null_attr.end(), true) != 1) {
fail_shape_inference(
"One and only one of the attributes 'value', 'value_*' or 'sparse_value' must be specified for a Constant node.");
}
if (nullptr != value) {
// OpSchema::Verify check ensures that the attribute value has_t():
const TensorProto& tensor_proto = value->t();
updateOutputElemType(ctx, 0, tensor_proto.data_type());
updateOutputShape(ctx, 0, tensor_proto);
return;
}
if (nullptr != value_int) {
// OpSchema::Verify check ensures that the attribute value has_i():
if (!value_int->has_i()) {
fail_shape_inference("Attribute 'value_int' expect an integer.")
}
updateOutputElemType(ctx, 0, TensorProto::INT64);
updateOutputShape(ctx, 0, TensorShapeProto());
return;
}
if (nullptr != value_ints) {
updateOutputElemType(ctx, 0, TensorProto::INT64);
appendDim(getOutputShape(ctx, 0), value_ints->ints_size());
return;
}
if (nullptr != value_float) {
// OpSchema::Verify check ensures that the attribute value has_i():
if (!value_float->has_f()) {
fail_shape_inference("Attribute 'value_float' expect a float.");
}
updateOutputElemType(ctx, 0, TensorProto::FLOAT);
updateOutputShape(ctx, 0, TensorShapeProto());
return;
}
if (nullptr != value_floats) {
updateOutputElemType(ctx, 0, TensorProto::FLOAT);
appendDim(getOutputShape(ctx, 0), value_floats->floats_size());
return;
}
if (nullptr != value_string) {
// OpSchema::Verify check ensures that the attribute value has_i():
if (!value_string->has_s()) {
fail_shape_inference("Attribute 'value_string' expect a string.");
}
updateOutputElemType(ctx, 0, TensorProto::STRING);
updateOutputShape(ctx, 0, TensorShapeProto());
return;
}
if (nullptr != value_strings) {
updateOutputElemType(ctx, 0, TensorProto::STRING);
appendDim(getOutputShape(ctx, 0), value_strings->strings_size());
return;
}
if (nullptr != sparse_value) {
// OpSchema::Verify check ensures that the attribute value
// has_sparse_tensor():
const SparseTensorProto& sparse = sparse_value->sparse_tensor();
// checker.cc::check_sparse_tensor checks that the sparse-value is
// well-formed
updateOutputElemType(ctx, 0, sparse.values().data_type());
auto* output_shape = getOutputShape(ctx, 0);
for (int i = 0; i < sparse.dims_size(); ++i)
appendDim(output_shape, sparse.dims(i));
return;
}
fail_shape_inference(
"TypeAndShapeInferenceFunction implementation incomplete: "
"this line should never be reached.");
}
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,13 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
void ConstantOpInference(InferenceContext& ctx);
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,69 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <functional>
#include "onnx/defs/data_type_utils.h"
#include "onnx/defs/schema.h"
#include "onnx/defs/tensor_proto_util.h"
namespace ONNX_NAMESPACE {
static const char* ImageDecoder_ver20_doc =
R"DOC(Loads and decodes and image from a file. If it can't decode for any reason (e.g. corrupted encoded
stream, invalid format, it will return an empty matrix).
The following image formats are supported:
* BMP
* JPEG (note: Lossless JPEG support is optional)
* JPEG2000
* TIFF
* PNG
* WebP
* Portable image format (PBM, PGM, PPM, PXM, PNM)
Decoded images follow a channel-last layout: (Height, Width, Channels).
**JPEG chroma upsampling method:**
When upsampling the chroma components by a factor of 2, the pixels are linearly interpolated so that the
centers of the output pixels are 1/4 and 3/4 of the way between input pixel centers.
When rounding, 0.5 is rounded down and up at alternative pixels locations to prevent bias towards
larger values (ordered dither pattern).
Considering adjacent input pixels A, B, and C, B is upsampled to pixels B0 and B1 so that
```
B0 = round_half_down((1/4) * A + (3/4) * B)
B1 = round_half_up((3/4) * B + (1/4) * C)
```
This method, is the default chroma upsampling method in the well-established libjpeg-turbo library,
also referred as "smooth" or "fancy" upsampling.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
ImageDecoder,
20,
OpSchema()
.SetDoc(ImageDecoder_ver20_doc)
.Attr(
"pixel_format",
"Pixel format. Can be one of \"RGB\", \"BGR\", or \"Grayscale\".",
AttributeProto::STRING,
std::string("RGB"))
.Input(0, "encoded_stream", "Encoded stream", "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
.Output(0, "image", "Decoded image", "T2", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
.TypeConstraint("T1", {"tensor(uint8)"}, "Constrain input types to 8-bit unsigned integer tensor.")
.TypeConstraint("T2", {"tensor(uint8)"}, "Constrain output types to 8-bit unsigned integer tensor.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (hasInputShape(ctx, 0)) {
auto& input_shape = getInputShape(ctx, 0);
if (input_shape.dim_size() != 1) {
fail_shape_inference("Input tensor must be 1-dimensional");
}
}
propagateElemTypeFromDtypeToOutput(ctx, TensorProto::UINT8, 0);
auto output_type = ctx.getOutputType(0);
auto* sh = output_type->mutable_tensor_type()->mutable_shape();
sh->clear_dim();
sh->add_dim();
sh->add_dim();
sh->add_dim();
}));
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,342 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
inline void unaryLogicalOpInference(InferenceContext& ctx) {
// Type inference
updateOutputElemType(ctx, 0, TensorProto::BOOL);
// Shape inference
if (hasInputShape(ctx, 0)) {
propagateShapeFromInputToOutput(ctx, 0, 0);
}
}
std::function<void(OpSchema&)> BinaryLogicDocGenerator(const char* name) {
return [=](OpSchema& schema) {
std::string doc;
POPULATE_OP_DOC_STR(doc = R"DOC(
Returns the tensor resulted from performing the `{name}` logical operation
elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support).
{broadcast_doc}
)DOC";
ReplaceAll(doc, "{name}", name);
ReplaceAll(doc, "{broadcast_doc}", GenerateBroadcastingDocMul().c_str()););
schema.SetDoc(doc);
schema.Input(
0,
"A",
"First input operand for the logical operator.",
"T",
OpSchema::Single,
true,
1,
OpSchema::NonDifferentiable);
schema.Input(
1,
"B",
"Second input operand for the logical operator.",
"T",
OpSchema::Single,
true,
1,
OpSchema::NonDifferentiable);
schema.Output(0, "C", "Result tensor.", "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable);
schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// Type inference
updateOutputElemType(ctx, 0, TensorProto::BOOL);
// Shape inference
if (hasNInputShapes(ctx, 2))
bidirectionalBroadcastShapeInference(
ctx.getInputType(0)->tensor_type().shape(),
ctx.getInputType(1)->tensor_type().shape(),
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
});
};
}
ONNX_OPERATOR_SET_SCHEMA(
And,
7,
OpSchema()
.FillUsing(BinaryLogicDocGenerator("and"))
.TypeConstraint("T", {"tensor(bool)"}, "Constrain input to boolean tensor.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
ONNX_OPERATOR_SET_SCHEMA(
Or,
7,
OpSchema()
.FillUsing(BinaryLogicDocGenerator("or"))
.TypeConstraint("T", {"tensor(bool)"}, "Constrain input to boolean tensor.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
ONNX_OPERATOR_SET_SCHEMA(
Xor,
7,
OpSchema()
.FillUsing(BinaryLogicDocGenerator("xor"))
.TypeConstraint("T", {"tensor(bool)"}, "Constrain input to boolean tensor.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
ONNX_OPERATOR_SET_SCHEMA(
Greater,
13,
OpSchema()
.FillUsing(BinaryLogicDocGenerator("greater"))
.TypeConstraint("T", OpSchema::all_numeric_types_ir4(), "Constrain input types to all numeric tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
ONNX_OPERATOR_SET_SCHEMA(
Less,
13,
OpSchema()
.FillUsing(BinaryLogicDocGenerator("less"))
.TypeConstraint("T", OpSchema::all_numeric_types_ir4(), "Constrain input types to all numeric tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
ONNX_OPERATOR_SET_SCHEMA(
Equal,
19,
OpSchema()
.FillUsing(BinaryLogicDocGenerator("equal"))
.TypeConstraint(
"T",
{"tensor(bool)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(bfloat16)",
"tensor(string)"},
"Constrain input types to all (non-complex) tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
static const char* Not_ver1_doc = R"DOC(
Returns the negation of the input tensor element-wise.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
Not,
1,
OpSchema()
.SetDoc(Not_ver1_doc)
.Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
.Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
.TypeConstraint("T", {"tensor(bool)"}, "Constrain input/output to boolean tensors.")
.TypeAndShapeInferenceFunction(unaryLogicalOpInference));
static const char* BitShift_ver11_doc = R"DOC(
Bitwise shift operator performs element-wise operation. For each input element, if the
attribute "direction" is "RIGHT", this operator moves its binary representation toward
the right side so that the input value is effectively decreased. If the attribute "direction"
is "LEFT", bits of binary representation moves toward the left side, which results the
increase of its actual value. The input X is the tensor to be shifted and another input
Y specifies the amounts of shifting. For example, if "direction" is "Right", X is [1, 4],
and S is [1, 1], the corresponding output Z would be [0, 2]. If "direction" is "LEFT" with
X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8].
Because this operator supports Numpy-style broadcasting, X's and Y's shapes are
not necessarily identical.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
BitShift,
11,
OpSchema()
.SetDoc(GET_OP_DOC_STR(std::string(BitShift_ver11_doc) + GenerateBroadcastingDocMul()))
.Input(
0,
"X",
"First operand, input to be shifted.",
"T",
OpSchema::Single,
true,
1,
OpSchema::NonDifferentiable)
.Input(1, "Y", "Second operand, amounts of shift.", "T", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
.Output(0, "Z", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
.TypeConstraint(
"T",
{"tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)"},
"Constrain input and output types to integer tensors.")
.Attr(
"direction",
"Direction of moving bits. It can be either \"RIGHT\" (for right shift) "
"or \"LEFT\" (for left shift).",
AttributeProto::STRING)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// Type inference
propagateElemTypeFromInputToOutput(ctx, 0, 0);
// Shape inference
if (hasNInputShapes(ctx, 2))
bidirectionalBroadcastShapeInference(
ctx.getInputType(0)->tensor_type().shape(),
ctx.getInputType(1)->tensor_type().shape(),
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
}));
ONNX_OPERATOR_SET_SCHEMA(
LessOrEqual,
16,
OpSchema()
.FillUsing(BinaryLogicDocGenerator("less_equal"))
.TypeConstraint("T", OpSchema::all_numeric_types_ir4(), "Constrain input types to all numeric tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor.")
.TypeAndShapeInferenceFunction(InferenceFunction())
.FunctionBody(R"ONNX(
{
O1 = Less (A, B)
O2 = Equal (A, B)
C = Or (O1, O2)
}
)ONNX"));
ONNX_OPERATOR_SET_SCHEMA(
GreaterOrEqual,
16,
OpSchema()
.FillUsing(BinaryLogicDocGenerator("greater_equal"))
.TypeConstraint("T", OpSchema::all_numeric_types_ir4(), "Constrain input types to all numeric tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor.")
.TypeAndShapeInferenceFunction(InferenceFunction())
.FunctionBody(R"ONNX(
{
O1 = Greater (A, B)
O2 = Equal (A, B)
C = Or (O1, O2)
}
)ONNX"));
static const char* BitwiseNot_ver18_doc = R"DOC(
Returns the bitwise not of the input tensor element-wise.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
BitwiseNot,
18,
OpSchema()
.SetDoc(BitwiseNot_ver18_doc)
.Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
.Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
.TypeConstraint(
"T",
{"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)"},
"Constrain input/output to integer tensors.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
std::function<void(OpSchema&)> BinaryBitwiseDocGenerator(const char* name) {
return [=](OpSchema& schema) {
std::string doc;
POPULATE_OP_DOC_STR(doc = R"DOC(
Returns the tensor resulting from performing the bitwise `{name}` operation
elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support).
{broadcast_doc}
)DOC";
ReplaceAll(doc, "{name}", name);
ReplaceAll(doc, "{broadcast_doc}", GenerateBroadcastingDocMul().c_str()););
schema.SetDoc(doc);
schema.Input(
0,
"A",
"First input operand for the bitwise operator.",
"T",
OpSchema::Single,
true,
1,
OpSchema::NonDifferentiable);
schema.Input(
1,
"B",
"Second input operand for the bitwise operator.",
"T",
OpSchema::Single,
true,
1,
OpSchema::NonDifferentiable);
schema.Output(0, "C", "Result tensor.", "T", OpSchema::Single, true, 1, OpSchema::NonDifferentiable);
schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// Type inference
propagateElemTypeFromInputToOutput(ctx, 0, 0);
// Shape inference
if (hasNInputShapes(ctx, 2))
bidirectionalBroadcastShapeInference(
ctx.getInputType(0)->tensor_type().shape(),
ctx.getInputType(1)->tensor_type().shape(),
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
});
};
}
ONNX_OPERATOR_SET_SCHEMA(
BitwiseAnd,
18,
OpSchema()
.FillUsing(BinaryBitwiseDocGenerator("and"))
.TypeConstraint(
"T",
{"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)"},
"Constrain input to integer tensors."));
ONNX_OPERATOR_SET_SCHEMA(
BitwiseOr,
18,
OpSchema()
.FillUsing(BinaryBitwiseDocGenerator("or"))
.TypeConstraint(
"T",
{"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)"},
"Constrain input to integer tensors."));
ONNX_OPERATOR_SET_SCHEMA(
BitwiseXor,
18,
OpSchema()
.FillUsing(BinaryBitwiseDocGenerator("xor"))
.TypeConstraint(
"T",
{"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)"},
"Constrain input to integer tensors."));
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,274 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/schema.h"
using namespace ONNX_NAMESPACE;
namespace ONNX_NAMESPACE {
std::function<void(OpSchema&)> BinaryLogicDocGenerator_opset12(const char* name) {
return [=](OpSchema& schema) {
std::string doc;
POPULATE_OP_DOC_STR(doc = R"DOC(
Returns the tensor resulted from performing the `{name}` logical operation
elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support).
{broadcast_doc}
)DOC";
ReplaceAll(doc, "{name}", name);
ReplaceAll(doc, "{broadcast_doc}", GenerateBroadcastingDocMul().c_str()););
schema.SetDoc(doc);
schema.Input(0, "A", "First input operand for the logical operator.", "T");
schema.Input(1, "B", "Second input operand for the logical operator.", "T");
schema.Output(0, "C", "Result tensor.", "T1");
schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// Type inference
updateOutputElemType(ctx, 0, TensorProto::BOOL);
// Shape inference
if (hasNInputShapes(ctx, 2))
bidirectionalBroadcastShapeInference(
ctx.getInputType(0)->tensor_type().shape(),
ctx.getInputType(1)->tensor_type().shape(),
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
});
};
}
ONNX_OPERATOR_SET_SCHEMA(
Greater,
9,
OpSchema()
.FillUsing(BinaryLogicDocGenerator_opset12("greater"))
.TypeConstraint("T", OpSchema::all_numeric_types(), "Constrain input types to all numeric tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
ONNX_OPERATOR_SET_SCHEMA(
Less,
9,
OpSchema()
.FillUsing(BinaryLogicDocGenerator_opset12("less"))
.TypeConstraint("T", OpSchema::all_numeric_types(), "Constrain input types to all numeric tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
ONNX_OPERATOR_SET_SCHEMA(
Equal,
11,
OpSchema()
.FillUsing(BinaryLogicDocGenerator_opset12("equal"))
.TypeConstraint(
"T",
{"tensor(bool)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(float16)",
"tensor(float)",
"tensor(double)"},
"Constrain input types to all numeric tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
inline void logicalOpInference_opset1(InferenceContext& ctx) {
updateOutputElemType(ctx, 0, TensorProto::BOOL);
if (hasInputShape(ctx, 0)) {
propagateShapeFromInputToOutput(ctx, 0, 0);
}
}
std::function<void(OpSchema&)> BinaryLogicDocGenerator_opset1(const char* name) {
return [=](OpSchema& schema) {
std::string doc;
POPULATE_OP_DOC_STR(doc = R"DOC(
Returns the tensor resulted from performing the `{name}` logical operation
elementwise on the input tensors `A` and `B`.
If broadcasting is enabled, the right-hand-side argument will be broadcasted
to match the shape of left-hand-side argument. See the doc of `Add` for a
detailed description of the broadcasting rules.
)DOC";
ReplaceAll(doc, "{name}", name););
schema.SetDoc(doc);
schema.Attr("broadcast", "Enable broadcasting", AttributeProto::INT, static_cast<int64_t>(0));
schema.Attr("axis", "If set, defines the broadcast dimensions.", AttributeProto::INT, OPTIONAL_VALUE);
schema.Input(0, "A", "Left input tensor for the logical operator.", "T");
schema.Input(1, "B", "Right input tensor for the logical operator.", "T");
schema.Output(0, "C", "Result tensor.", "T1");
schema.TypeAndShapeInferenceFunction(logicalOpInference_opset1);
};
}
std::function<void(OpSchema&)> BinaryLogicDocGenerator_opset7(const char* name) {
return [=](OpSchema& schema) {
std::string doc;
POPULATE_OP_DOC_STR(doc = R"DOC(
Returns the tensor resulted from performing the `{name}` logical operation
elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support).
{broadcast_doc}
)DOC";
ReplaceAll(doc, "{name}", name);
ReplaceAll(doc, "{broadcast_doc}", GenerateBroadcastingDocMul().c_str()););
schema.SetDoc(doc);
schema.Input(0, "A", "First input operand for the logical operator.", "T");
schema.Input(1, "B", "Second input operand for the logical operator.", "T");
schema.Output(0, "C", "Result tensor.", "T1");
schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
updateOutputElemType(ctx, 0, TensorProto::BOOL);
if (hasNInputShapes(ctx, 2))
bidirectionalBroadcastShapeInference(
ctx.getInputType(0)->tensor_type().shape(),
ctx.getInputType(1)->tensor_type().shape(),
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
});
};
}
ONNX_OPERATOR_SET_SCHEMA(
And,
1,
OpSchema()
.FillUsing(BinaryLogicDocGenerator_opset1("and"))
.TypeConstraint("T", {"tensor(bool)"}, "Constrain input to boolean tensor.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
ONNX_OPERATOR_SET_SCHEMA(
Or,
1,
OpSchema()
.FillUsing(BinaryLogicDocGenerator_opset1("or"))
.TypeConstraint("T", {"tensor(bool)"}, "Constrain input to boolean tensor.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
ONNX_OPERATOR_SET_SCHEMA(
Xor,
1,
OpSchema()
.FillUsing(BinaryLogicDocGenerator_opset1("xor"))
.TypeConstraint("T", {"tensor(bool)"}, "Constrain input to boolean tensor.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
ONNX_OPERATOR_SET_SCHEMA(
Greater,
1,
OpSchema()
.FillUsing(BinaryLogicDocGenerator_opset1("greater"))
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input to float tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
ONNX_OPERATOR_SET_SCHEMA(
Less,
1,
OpSchema()
.FillUsing(BinaryLogicDocGenerator_opset1("less"))
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input to float tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
ONNX_OPERATOR_SET_SCHEMA(
Equal,
1,
OpSchema()
.FillUsing(BinaryLogicDocGenerator_opset1("equal"))
.TypeConstraint("T", {"tensor(bool)", "tensor(int32)", "tensor(int64)"}, "Constrain input to integral tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
ONNX_OPERATOR_SET_SCHEMA(
Equal,
7,
OpSchema()
.FillUsing(BinaryLogicDocGenerator_opset7("equal"))
.TypeConstraint("T", {"tensor(bool)", "tensor(int32)", "tensor(int64)"}, "Constrain input to integral tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
ONNX_OPERATOR_SET_SCHEMA(
Greater,
7,
OpSchema()
.FillUsing(BinaryLogicDocGenerator_opset7("greater"))
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input to float tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
ONNX_OPERATOR_SET_SCHEMA(
Less,
7,
OpSchema()
.FillUsing(BinaryLogicDocGenerator_opset7("less"))
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input to float tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
// Shares same doc generator as newer opset 16 version.
extern std::function<void(OpSchema&)> BinaryLogicDocGenerator(const char* name);
ONNX_OPERATOR_SET_SCHEMA(
LessOrEqual,
12,
OpSchema()
.FillUsing(BinaryLogicDocGenerator("less_equal"))
.TypeConstraint("T", OpSchema::all_numeric_types(), "Constrain input types to all numeric tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor.")
.TypeAndShapeInferenceFunction(InferenceFunction())
.FunctionBody(R"ONNX(
{
O1 = Less (A, B)
O2 = Equal (A, B)
C = Or (O1, O2)
}
)ONNX"));
ONNX_OPERATOR_SET_SCHEMA(
GreaterOrEqual,
12,
OpSchema()
.FillUsing(BinaryLogicDocGenerator("greater_equal"))
.TypeConstraint("T", OpSchema::all_numeric_types(), "Constrain input types to all numeric tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor.")
.TypeAndShapeInferenceFunction(InferenceFunction())
.FunctionBody(R"ONNX(
{
O1 = Greater (A, B)
O2 = Equal (A, B)
C = Or (O1, O2)
}
)ONNX"));
ONNX_OPERATOR_SET_SCHEMA(
Equal,
13,
OpSchema()
.FillUsing(BinaryLogicDocGenerator("equal"))
.TypeConstraint(
"T",
{"tensor(bool)",
"tensor(uint8)",
"tensor(uint16)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(int8)",
"tensor(int16)",
"tensor(int32)",
"tensor(int64)",
"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(bfloat16)"},
"Constrain input types to all numeric tensors.")
.TypeConstraint("T1", {"tensor(bool)"}, "Constrain output to boolean tensor."));
} // namespace ONNX_NAMESPACE

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,138 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/math/utils.h"
#include <string>
namespace ONNX_NAMESPACE {
namespace defs {
namespace math {
namespace utils {
int MathOpTwoIntegers(std::string op_type, int a, int b) {
if (op_type == "Add") {
return a + b;
} else if (op_type == "Sub") {
return a - b;
} else if (op_type == "Mul") {
return a * b;
}
fail_shape_inference("Wrong op_type name for running propagation: ", op_type);
}
void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx) {
if (!hasInputShape(ctx, input1Idx) || !hasInputShape(ctx, input2Idx)) {
return;
}
const auto shape0 = ctx.getInputType(input1Idx)->tensor_type().shape();
const auto shape1 = ctx.getInputType(input2Idx)->tensor_type().shape();
if (shape0.dim_size() == 0 || shape1.dim_size() == 0) {
fail_shape_inference("Input tensors of wrong rank (0).");
}
ONNX_NAMESPACE::TensorShapeProto shapeL, shapeR;
// First promote each shape to at least rank-2. This logic is
// specific to matmul, not generic broadcasting.
{
if (shape0.dim_size() == 1) {
shapeL.add_dim()->set_dim_value(1);
*shapeL.add_dim() = shape0.dim(0);
} else {
*shapeL.mutable_dim() = shape0.dim();
}
if (shape1.dim_size() == 1) {
*shapeR.add_dim() = shape1.dim(0);
shapeR.add_dim()->set_dim_value(1);
} else {
*shapeR.mutable_dim() = shape1.dim();
}
}
// Check for compatible matrix multiply dimensions
{
auto dimL = shapeL.dim(shapeL.dim_size() - 1);
auto dimR = shapeR.dim(shapeR.dim_size() - 2);
if (dimL.has_dim_value() && dimR.has_dim_value() && dimL.dim_value() != dimR.dim_value()) {
fail_shape_inference("Incompatible dimensions for matrix multiplication");
}
}
ONNX_NAMESPACE::TensorShapeProto resultShape;
// Now call out to generic multidimensional broadcasting for
// the broadcastable prefixes.
{
ONNX_NAMESPACE::TensorShapeProto prefixShapeL, prefixShapeR;
for (int i = 0; i < shapeL.dim_size() - 2; ++i) {
*prefixShapeL.add_dim() = shapeL.dim(i);
}
for (int i = 0; i < shapeR.dim_size() - 2; ++i) {
*prefixShapeR.add_dim() = shapeR.dim(i);
}
bidirectionalBroadcastShapeInference(prefixShapeL, prefixShapeR, resultShape);
}
// Back to matmul-specific. Add the trailing dimensions back in.
{
if (shape0.dim_size() != 1) {
*resultShape.add_dim() = shapeL.dim(shapeL.dim_size() - 2);
}
if (shape1.dim_size() != 1) {
*resultShape.add_dim() = shapeR.dim(shapeR.dim_size() - 1);
}
}
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = resultShape;
}
void QLinearMatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
auto a_type = ctx.getInputType(0);
auto b_type = ctx.getInputType(3);
if (nullptr == a_type || nullptr == b_type || a_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType ||
b_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType) {
fail_type_inference("inputs are expected to have tensor type.");
}
auto a_zero_point_type = ctx.getInputType(2);
if (nullptr == a_zero_point_type ||
a_zero_point_type->tensor_type().elem_type() != a_type->tensor_type().elem_type()) {
fail_type_inference("input and zero_point pair is expected to have be same type.");
}
auto b_zero_point_type = ctx.getInputType(5);
if (nullptr == b_zero_point_type ||
b_zero_point_type->tensor_type().elem_type() != b_type->tensor_type().elem_type()) {
fail_type_inference("input and zero_point pair is expected to have same type.");
}
propagateElemTypeFromInputToOutput(ctx, 7, 0);
MatMulShapeInference(ctx, 0, 3);
}
const char* QLinearMatMulDoc() {
static const char* QLinearMatMul_doc = R"DOC(
Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html).
It consumes two quantized input tensors, their scales and zero points, scale and zero point of output,
and computes the quantized output. The quantization formula is y = saturate((x / y_scale) + y_zero_point).
For (x / y_scale), it is rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
Scale and zero point must have same shape. They must be either scalar (per tensor) or N-D tensor
(per row for 'a' and per column for 'b'). Scalar refers to per tensor quantization whereas N-D refers to per row
or per column quantization. If the input is 2D of shape [M, K] then zero point and scale tensor may be
an M element vector [v_1, v_2, ..., v_M] for per row quantization and K element vector of shape [v_1, v_2, ..., v_K]
for per column quantization. If the input is N-D tensor with shape [D1, D2, M, K] then zero point and scale tensor may
have shape [D1, D2, M, 1] for per row quantization and shape [D1, D2, 1, K] for per column quantization.
Production must never overflow, and accumulation may overflow if and only if in 32 bits.
)DOC";
return QLinearMatMul_doc;
}
} // namespace utils
} // namespace math
} // namespace defs
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,48 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include "onnx/defs/shape_inference.h"
#include "onnx/defs/tensor_proto_util.h"
#include "onnx/onnx_pb.h"
namespace ONNX_NAMESPACE {
namespace defs {
namespace math {
namespace utils {
template <typename T>
T GetScalarValueFromTensor(const ONNX_NAMESPACE::TensorProto* t) {
if (t == nullptr) {
return T{};
}
auto data_type = t->data_type();
switch (data_type) {
case ONNX_NAMESPACE::TensorProto::FLOAT:
return static_cast<T>(ONNX_NAMESPACE::ParseData<float>(t).at(0));
case ONNX_NAMESPACE::TensorProto::DOUBLE:
return static_cast<T>(ONNX_NAMESPACE::ParseData<double>(t).at(0));
case ONNX_NAMESPACE::TensorProto::INT32:
return static_cast<T>(ONNX_NAMESPACE::ParseData<int32_t>(t).at(0));
case ONNX_NAMESPACE::TensorProto::INT64:
return static_cast<T>(ONNX_NAMESPACE::ParseData<int64_t>(t).at(0));
default:
fail_shape_inference("Unsupported input data type of ", data_type);
}
}
void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx);
void QLinearMatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx);
const char* QLinearMatMulDoc();
int MathOpTwoIntegers(std::string op_type, int a, int b);
} // namespace utils
} // namespace math
} // namespace defs
} // namespace ONNX_NAMESPACE

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,199 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/schema.h"
using namespace ONNX_NAMESPACE;
namespace ONNX_NAMESPACE {
static const char* RoiAlign_ver22_doc = R"DOC(
Region of Interest (RoI) align operation described in the
[Mask R-CNN paper](https://arxiv.org/abs/1703.06870).
RoiAlign consumes an input tensor X and region of interests (rois)
to apply pooling across each RoI; it produces a 4-D tensor of shape
(num_rois, C, output_height, output_width).
RoiAlign is proposed to avoid the misalignment by removing
quantizations while converting from original image into feature
map and from feature map into RoI feature; in each ROI bin,
the value of the sampled locations are computed directly
through bilinear interpolation.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
RoiAlign,
22,
OpSchema()
.SetDoc(RoiAlign_ver22_doc)
.Attr(
"spatial_scale",
"Multiplicative spatial scale factor to translate ROI coordinates "
"from their input spatial scale to the scale used when pooling, "
"i.e., spatial scale of the input feature map X relative to the "
"input image. E.g.; default is 1.0f. ",
AttributeProto::FLOAT,
1.f)
.Attr("output_height", "default 1; Pooled output Y's height.", AttributeProto::INT, static_cast<int64_t>(1))
.Attr("output_width", "default 1; Pooled output Y's width.", AttributeProto::INT, static_cast<int64_t>(1))
.Attr(
"sampling_ratio",
"Number of sampling points in the interpolation grid used to compute "
"the output value of each pooled output bin. If > 0, then exactly "
"sampling_ratio x sampling_ratio grid points are used. If == 0, then "
"an adaptive number of grid points are used (computed as "
"ceil(roi_width / output_width), and likewise for height). Default is 0.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr(
"mode",
"The pooling method. Two modes are supported: 'avg' and 'max'. "
"Default is 'avg'.",
AttributeProto::STRING,
std::string("avg"))
.Attr(
"coordinate_transformation_mode",
"Allowed values are 'half_pixel' and 'output_half_pixel'. "
"Use the value 'half_pixel' to pixel shift the input coordinates by -0.5 (the recommended behavior). "
"Use the value 'output_half_pixel' to omit the pixel shift for the input (use this for a "
"backward-compatible behavior).",
AttributeProto::STRING,
std::string("half_pixel"))
.Input(
0,
"X",
"Input data tensor from the previous operator; "
"4-D feature map of shape (N, C, H, W), "
"where N is the batch size, C is the number of channels, "
"and H and W are the height and the width of the data.",
"T1")
.Input(
1,
"rois",
"RoIs (Regions of Interest) to pool over; rois is "
"2-D input of shape (num_rois, 4) given as "
"[[x1, y1, x2, y2], ...]. "
"The RoIs' coordinates are in the coordinate system of the input image. "
"Each coordinate set has a 1:1 correspondence with the 'batch_indices' input.",
"T1")
.Input(
2,
"batch_indices",
"1-D tensor of shape (num_rois,) with each element denoting "
"the index of the corresponding image in the batch.",
"T2")
.Output(
0,
"Y",
"RoI pooled output, 4-D tensor of shape "
"(num_rois, C, output_height, output_width). The r-th batch element Y[r-1] "
"is a pooled feature map corresponding to the r-th RoI X[r-1].",
"T1")
.TypeConstraint("T1", OpSchema::all_float_types_ir4(), "Constrain types to float tensors.")
.TypeConstraint("T2", {"tensor(int64)"}, "Constrain types to int tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
size_t input_param = 0, rois_param = 1, batch_index_param = 2;
checkInputRank(ctx, input_param, 4);
checkInputRank(ctx, rois_param, 2);
checkInputRank(ctx, batch_index_param, 1);
// Output dimensions, initialized to an unknown-dimension-value
Dim num_rois, C, ht, width;
// Get value of C from dim 1 of input_param, if available
unifyInputDim(ctx, input_param, 1, C);
// Get value of num_rois from dim 0 of rois_param, if available
unifyInputDim(ctx, rois_param, 0, num_rois);
// ... or from dim 0 of batch_index_param, if available
unifyInputDim(ctx, batch_index_param, 0, num_rois);
// Get height from attribute, using default-value of 1
unifyDim(ht, getAttribute(ctx, "output_height", 1));
// Get width from attribute, using default-value of 1
unifyDim(width, getAttribute(ctx, "output_width", 1));
// set output shape:
updateOutputShape(ctx, 0, {num_rois, C, ht, width});
}));
static const char* NonMaxSuppression_ver11_doc = R"DOC(
Filter out boxes that have high intersection-over-union (IOU) overlap with previously selected boxes.
Bounding boxes with score less than score_threshold are removed. Bounding box format is indicated by attribute center_point_box.
Note that this algorithm is agnostic to where the origin is in the coordinate system and more generally is invariant to
orthogonal transformations and translations of the coordinate system; thus translating or reflections of the coordinate system
result in the same boxes being selected by the algorithm.
The selected_indices output is a set of integers indexing into the input collection of bounding boxes representing the selected boxes.
The bounding box coordinates corresponding to the selected indices can then be obtained using the Gather or GatherND operation.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
NonMaxSuppression,
11,
OpSchema()
.Input(
0,
"boxes",
"An input tensor with shape [num_batches, spatial_dimension, 4]. The single box data format is indicated by center_point_box.",
"tensor(float)")
.Input(1, "scores", "An input tensor with shape [num_batches, num_classes, spatial_dimension]", "tensor(float)")
.Input(
2,
"max_output_boxes_per_class",
"Integer representing the maximum number of boxes to be selected per batch per class. It is a scalar. Default to 0, which means no output.",
"tensor(int64)",
OpSchema::Optional)
.Input(
3,
"iou_threshold",
"Float representing the threshold for deciding whether boxes overlap too much with respect to IOU. It is scalar. Value range [0, 1]. Default to 0.",
"tensor(float)",
OpSchema::Optional)
.Input(
4,
"score_threshold",
"Float representing the threshold for deciding when to remove boxes based on score. It is a scalar.",
"tensor(float)",
OpSchema::Optional)
.Output(
0,
"selected_indices",
"selected indices from the boxes tensor. [num_selected_indices, 3], the selected index format is [batch_index, class_index, box_index].",
"tensor(int64)")
.Attr(
"center_point_box",
"Integer indicate the format of the box data. The default is 0. "
"0 - the box data is supplied as [y1, x1, y2, x2] where (y1, x1) and (y2, x2) are the coordinates of any diagonal pair of box corners "
"and the coordinates can be provided as normalized (i.e., lying in the interval [0, 1]) or absolute. Mostly used for TF models. "
"1 - the box data is supplied as [x_center, y_center, width, height]. Mostly used for Pytorch models.",
AttributeProto::INT,
static_cast<int64_t>(0))
.SetDoc(NonMaxSuppression_ver11_doc)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// Type inference - Output is always of type INT64
auto* selected_indices_type = ctx.getOutputType(0)->mutable_tensor_type();
selected_indices_type->set_elem_type(TensorProto_DataType::TensorProto_DataType_INT64);
// Shape inference
// The exact shape cannot be determined as it depends on the input and
// other input configurations for the op But part of the shape can be
// established
auto* selected_indices_shape = getOutputShape(ctx, 0);
selected_indices_shape->clear_dim();
// Output is 2D always
// The value of the first dim is determined by input data
// hence its value cannot be determined statically
selected_indices_shape->add_dim();
// The value of the second dim is 3
selected_indices_shape->add_dim()->set_dim_value(3);
}));
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,293 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/schema.h"
using namespace ONNX_NAMESPACE;
namespace ONNX_NAMESPACE {
static const char* RoiAlign_ver16_doc = R"DOC(
Region of Interest (RoI) align operation described in the
[Mask R-CNN paper](https://arxiv.org/abs/1703.06870).
RoiAlign consumes an input tensor X and region of interests (rois)
to apply pooling across each RoI; it produces a 4-D tensor of shape
(num_rois, C, output_height, output_width).
RoiAlign is proposed to avoid the misalignment by removing
quantizations while converting from original image into feature
map and from feature map into RoI feature; in each ROI bin,
the value of the sampled locations are computed directly
through bilinear interpolation.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
RoiAlign,
16,
OpSchema()
.SetDoc(RoiAlign_ver16_doc)
.Attr(
"spatial_scale",
"Multiplicative spatial scale factor to translate ROI coordinates "
"from their input spatial scale to the scale used when pooling, "
"i.e., spatial scale of the input feature map X relative to the "
"input image. E.g.; default is 1.0f. ",
AttributeProto::FLOAT,
1.f)
.Attr("output_height", "default 1; Pooled output Y's height.", AttributeProto::INT, static_cast<int64_t>(1))
.Attr("output_width", "default 1; Pooled output Y's width.", AttributeProto::INT, static_cast<int64_t>(1))
.Attr(
"sampling_ratio",
"Number of sampling points in the interpolation grid used to compute "
"the output value of each pooled output bin. If > 0, then exactly "
"sampling_ratio x sampling_ratio grid points are used. If == 0, then "
"an adaptive number of grid points are used (computed as "
"ceil(roi_width / output_width), and likewise for height). Default is 0.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr(
"mode",
"The pooling method. Two modes are supported: 'avg' and 'max'. "
"Default is 'avg'.",
AttributeProto::STRING,
std::string("avg"))
.Attr(
"coordinate_transformation_mode",
"Allowed values are 'half_pixel' and 'output_half_pixel'. "
"Use the value 'half_pixel' to pixel shift the input coordinates by -0.5 (the recommended behavior). "
"Use the value 'output_half_pixel' to omit the pixel shift for the input (use this for a "
"backward-compatible behavior).",
AttributeProto::STRING,
std::string("half_pixel"))
.Input(
0,
"X",
"Input data tensor from the previous operator; "
"4-D feature map of shape (N, C, H, W), "
"where N is the batch size, C is the number of channels, "
"and H and W are the height and the width of the data.",
"T1")
.Input(
1,
"rois",
"RoIs (Regions of Interest) to pool over; rois is "
"2-D input of shape (num_rois, 4) given as "
"[[x1, y1, x2, y2], ...]. "
"The RoIs' coordinates are in the coordinate system of the input image. "
"Each coordinate set has a 1:1 correspondence with the 'batch_indices' input.",
"T1")
.Input(
2,
"batch_indices",
"1-D tensor of shape (num_rois,) with each element denoting "
"the index of the corresponding image in the batch.",
"T2")
.Output(
0,
"Y",
"RoI pooled output, 4-D tensor of shape "
"(num_rois, C, output_height, output_width). The r-th batch element Y[r-1] "
"is a pooled feature map corresponding to the r-th RoI X[r-1].",
"T1")
.TypeConstraint(
"T1",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain types to float tensors.")
.TypeConstraint("T2", {"tensor(int64)"}, "Constrain types to int tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
size_t input_param = 0, rois_param = 1, batch_index_param = 2;
checkInputRank(ctx, input_param, 4);
checkInputRank(ctx, rois_param, 2);
checkInputRank(ctx, batch_index_param, 1);
// Output dimensions, initialized to an unknown-dimension-value
Dim num_rois, C, ht, width;
// Get value of C from dim 1 of input_param, if available
unifyInputDim(ctx, input_param, 1, C);
// Get value of num_rois from dim 0 of rois_param, if available
unifyInputDim(ctx, rois_param, 0, num_rois);
// ... or from dim 0 of batch_index_param, if available
unifyInputDim(ctx, batch_index_param, 0, num_rois);
// Get height from attribute, using default-value of 1
unifyDim(ht, getAttribute(ctx, "output_height", 1));
// Get width from attribute, using default-value of 1
unifyDim(width, getAttribute(ctx, "output_width", 1));
// set output shape:
updateOutputShape(ctx, 0, {num_rois, C, ht, width});
}));
static const char* RoiAlign_ver10_doc = R"DOC(
Region of Interest (RoI) align operation described in the
[Mask R-CNN paper](https://arxiv.org/abs/1703.06870).
RoiAlign consumes an input tensor X and region of interests (rois)
to apply pooling across each RoI; it produces a 4-D tensor of shape
(num_rois, C, output_height, output_width).
RoiAlign is proposed to avoid the misalignment by removing
quantizations while converting from original image into feature
map and from feature map into RoI feature; in each ROI bin,
the value of the sampled locations are computed directly
through bilinear interpolation.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
RoiAlign,
10,
OpSchema()
.SetDoc(RoiAlign_ver10_doc)
.Attr(
"spatial_scale",
"Multiplicative spatial scale factor to translate ROI coordinates "
"from their input spatial scale to the scale used when pooling, "
"i.e., spatial scale of the input feature map X relative to the "
"input image. E.g.; default is 1.0f. ",
AttributeProto::FLOAT,
1.f)
.Attr("output_height", "default 1; Pooled output Y's height.", AttributeProto::INT, static_cast<int64_t>(1))
.Attr("output_width", "default 1; Pooled output Y's width.", AttributeProto::INT, static_cast<int64_t>(1))
.Attr(
"sampling_ratio",
"Number of sampling points in the interpolation grid used to compute "
"the output value of each pooled output bin. If > 0, then exactly "
"sampling_ratio x sampling_ratio grid points are used. If == 0, then "
"an adaptive number of grid points are used (computed as "
"ceil(roi_width / output_width), and likewise for height). Default is 0.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr(
"mode",
"The pooling method. Two modes are supported: 'avg' and 'max'. "
"Default is 'avg'.",
AttributeProto::STRING,
std::string("avg"))
.Input(
0,
"X",
"Input data tensor from the previous operator; "
"4-D feature map of shape (N, C, H, W), "
"where N is the batch size, C is the number of channels, "
"and H and W are the height and the width of the data.",
"T1")
.Input(
1,
"rois",
"RoIs (Regions of Interest) to pool over; rois is "
"2-D input of shape (num_rois, 4) given as "
"[[x1, y1, x2, y2], ...]. "
"The RoIs' coordinates are in the coordinate system of the input image. "
"Each coordinate set has a 1:1 correspondence with the 'batch_indices' input.",
"T1")
.Input(
2,
"batch_indices",
"1-D tensor of shape (num_rois,) with each element denoting "
"the index of the corresponding image in the batch.",
"T2")
.Output(
0,
"Y",
"RoI pooled output, 4-D tensor of shape "
"(num_rois, C, output_height, output_width). The r-th batch element Y[r-1] "
"is a pooled feature map corresponding to the r-th RoI X[r-1].",
"T1")
.TypeConstraint(
"T1",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain types to float tensors.")
.TypeConstraint("T2", {"tensor(int64)"}, "Constrain types to int tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
size_t input_param = 0, rois_param = 1, batch_index_param = 2;
checkInputRank(ctx, input_param, 4);
checkInputRank(ctx, rois_param, 2);
checkInputRank(ctx, batch_index_param, 1);
// Output dimensions, initialized to an unknown-dimension-value
Dim num_rois, C, ht, width;
// Get value of C from dim 1 of input_param, if available
unifyInputDim(ctx, input_param, 1, C);
// Get value of num_rois from dim 0 of rois_param, if available
unifyInputDim(ctx, rois_param, 0, num_rois);
// ... or from dim 0 of batch_index_param, if available
unifyInputDim(ctx, batch_index_param, 0, num_rois);
// Get height from attribute, using default-value of 1
unifyDim(ht, getAttribute(ctx, "output_height", 1));
// Get width from attribute, using default-value of 1
unifyDim(width, getAttribute(ctx, "output_width", 1));
// set output shape:
updateOutputShape(ctx, 0, {num_rois, C, ht, width});
}));
static const char* NonMaxSuppression_ver10_doc = R"DOC(
Filter out boxes that have high intersection-over-union (IOU) overlap with previously selected boxes.
Bounding boxes with score less than score_threshold are removed. Bounding box format is indicated by attribute center_point_box.
Note that this algorithm is agnostic to where the origin is in the coordinate system and more generally is invariant to
orthogonal transformations and translations of the coordinate system; thus translating or reflections of the coordinate system
result in the same boxes being selected by the algorithm.
The selected_indices output is a set of integers indexing into the input collection of bounding boxes representing the selected boxes.
The bounding box coordinates corresponding to the selected indices can then be obtained using the Gather or GatherND operation.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
NonMaxSuppression,
10,
OpSchema()
.Input(
0,
"boxes",
"An input tensor with shape [num_batches, spatial_dimension, 4]. The single box data format is indicated by center_point_box.",
"tensor(float)")
.Input(1, "scores", "An input tensor with shape [num_batches, num_classes, spatial_dimension]", "tensor(float)")
.Input(
2,
"max_output_boxes_per_class",
"Integer representing the maximum number of boxes to be selected per batch per class. It is a scalar. Default to 0, which means no output.",
"tensor(int64)",
OpSchema::Optional)
.Input(
3,
"iou_threshold",
"Float representing the threshold for deciding whether boxes overlap too much with respect to IOU. It is scalar. Value range [0, 1]. Default to 0.",
"tensor(float)",
OpSchema::Optional)
.Input(
4,
"score_threshold",
"Float representing the threshold for deciding when to remove boxes based on score. It is a scalar.",
"tensor(float)",
OpSchema::Optional)
.Output(
0,
"selected_indices",
"selected indices from the boxes tensor. [num_selected_indices, 3], the selected index format is [batch_index, class_index, box_index].",
"tensor(int64)")
.Attr(
"center_point_box",
"Integer indicate the format of the box data. The default is 0. "
"0 - the box data is supplied as [y1, x1, y2, x2] where (y1, x1) and (y2, x2) are the coordinates of any diagonal pair of box corners "
"and the coordinates can be provided as normalized (i.e., lying in the interval [0, 1]) or absolute. Mostly used for TF models. "
"1 - the box data is supplied as [x_center, y_center, width, height]. Mostly used for Pytorch models.",
AttributeProto::INT,
static_cast<int64_t>(0))
.SetDoc(NonMaxSuppression_ver10_doc)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
auto selected_indices_type = ctx.getOutputType(0)->mutable_tensor_type();
selected_indices_type->set_elem_type(::ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64);
}));
} // namespace ONNX_NAMESPACE

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,109 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#ifdef ONNX_ML
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
// Forward declarations for ai.onnx.ml version 1
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, ArrayFeatureExtractor);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Binarizer);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, CastMap);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, CategoryMapper);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, DictVectorizer);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, FeatureVectorizer);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Imputer);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, LabelEncoder);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, LinearClassifier);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, LinearRegressor);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Normalizer);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, OneHotEncoder);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, SVMClassifier);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, SVMRegressor);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Scaler);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, TreeEnsembleClassifier);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, TreeEnsembleRegressor);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, ZipMap);
// Iterate over schema from ai.onnx.ml version 1
class OpSet_OnnxML_ver1 {
public:
static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, ArrayFeatureExtractor)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Binarizer)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, CastMap)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, CategoryMapper)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, DictVectorizer)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, FeatureVectorizer)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Imputer)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, LabelEncoder)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, LinearClassifier)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, LinearRegressor)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Normalizer)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, OneHotEncoder)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, SVMClassifier)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, SVMRegressor)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, Scaler)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, TreeEnsembleClassifier)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, TreeEnsembleRegressor)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 1, ZipMap)>());
}
};
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 2, LabelEncoder);
class OpSet_OnnxML_ver2 {
public:
static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 2, LabelEncoder)>());
}
};
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 3, TreeEnsembleClassifier);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 3, TreeEnsembleRegressor);
class OpSet_OnnxML_ver3 {
public:
static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 3, TreeEnsembleClassifier)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 3, TreeEnsembleRegressor)>());
}
};
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 4, LabelEncoder);
class OpSet_OnnxML_ver4 {
public:
static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 4, LabelEncoder)>());
}
};
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 5, TreeEnsemble);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 5, TreeEnsembleRegressor);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 5, TreeEnsembleClassifier);
class OpSet_OnnxML_ver5 {
public:
static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 5, TreeEnsemble)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 5, TreeEnsembleRegressor)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(OnnxML, 5, TreeEnsembleClassifier)>());
}
};
inline void RegisterOnnxMLOperatorSetSchema() {
RegisterOpSetSchema<OpSet_OnnxML_ver1>();
RegisterOpSetSchema<OpSet_OnnxML_ver2>();
RegisterOpSetSchema<OpSet_OnnxML_ver3>();
RegisterOpSetSchema<OpSet_OnnxML_ver4>();
RegisterOpSetSchema<OpSet_OnnxML_ver5>();
}
} // namespace ONNX_NAMESPACE
#endif

View File

@ -0,0 +1,37 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
// Declare training operators.
class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Gradient);
class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Momentum);
class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adagrad);
class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adam);
// Iterate over schema from ai.onnx.training version 1
class OpSet_OnnxPreview_ver1 {
public:
static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Gradient)>());
fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Momentum)>());
fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adagrad)>());
fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adam)>());
}
};
// Register preview operators.
inline void RegisterOnnxPreviewOperatorSetSchema() {
// Preview operators should have only one version.
// If changes are needed for a specific preview operator,
// its spec should be modified without increasing its version.
RegisterOpSetSchema<OpSet_OnnxPreview_ver1>();
}
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,24 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
// Declare training operators.
// Iterate over schema from ai.onnx.training version 1
class OpSet_OnnxTraining_ver1 {
public:
static void ForEachSchema(std::function<void(OpSchema&&)> /* fn */) {}
};
// Register training operators.
inline void RegisterOnnxTrainingOperatorSetSchema() {
RegisterOpSetSchema<OpSet_OnnxTraining_ver1>();
}
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,154 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <algorithm>
#include <numeric>
#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
static std::vector<std::string> optional_and_tensor_types() {
auto optional_types = OpSchema::all_optional_types();
auto tensor_types = OpSchema::all_tensor_types();
auto sequence_types = OpSchema::all_tensor_sequence_types();
optional_types.insert(optional_types.end(), tensor_types.begin(), tensor_types.end());
optional_types.insert(optional_types.end(), sequence_types.begin(), sequence_types.end());
return optional_types;
}
static const char* Optional_ver15_doc = R"DOC(
Constructs an optional-type value containing either an empty optional of a certain type specified by the attribute,
or a non-empty value containing the input element.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
Optional,
15,
OpSchema()
.SetDoc(Optional_ver15_doc)
.Input(0, "input", "The input element.", "V", OpSchema::Optional)
.Attr("type", "Type of the element in the optional output", AttributeProto::TYPE_PROTO, OPTIONAL_VALUE)
.Output(0, "output", "The optional output enclosing the input element.", "O")
.TypeConstraint(
"V",
[]() {
auto t = OpSchema::all_tensor_types();
auto s = OpSchema::all_tensor_sequence_types();
t.insert(t.end(), s.begin(), s.end());
return t;
}(),
"Constrain input type to all tensor and sequence types.")
.TypeConstraint(
"O",
OpSchema::all_optional_types(),
"Constrain output type to all optional tensor or optional sequence types.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
const size_t numOutputs = ctx.getNumOutputs();
if (numOutputs != 1) {
fail_type_inference("Optional is expected to have an output.");
}
const size_t numInputs = ctx.getNumInputs();
const auto* attr_proto = ctx.getAttribute("type");
if ((numInputs == 0) && (attr_proto != nullptr)) {
if (!attr_proto->has_tp())
fail_type_inference("Attribute 'type' should be a TypeProto and it should specify a type.");
auto attr_tp = attr_proto->tp();
ctx.getOutputType(0)->mutable_optional_type()->mutable_elem_type()->CopyFrom(attr_tp);
} else if (numInputs == 1) {
auto input_type = ctx.getInputType(0);
if (input_type == nullptr) {
fail_type_inference("Input type is null. Type information is expected for the input.");
}
ctx.getOutputType(0)->mutable_optional_type()->mutable_elem_type()->CopyFrom(*input_type);
} else {
fail_type_inference("Optional is expected to have either an input or the type attribute set.");
}
}));
static const char* OptionalHasElement_ver18_doc = R"DOC(
Returns true if (1) the input is an optional-type and contains an element,
or, (2) the input is a tensor or sequence type.
If the input is not provided or is an empty optional-type, this op returns false.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
OptionalHasElement,
18,
OpSchema()
.SetDoc(OptionalHasElement_ver18_doc)
.Input(0, "input", "The optional input.", "O", OpSchema::Optional)
.Output(
0,
"output",
"A scalar boolean tensor. If true, it indicates that optional-type input contains an element. Otherwise, it is empty.",
"B")
.TypeConstraint(
"O",
optional_and_tensor_types(),
"Constrain input type to optional tensor and optional sequence types.")
.TypeConstraint("B", {"tensor(bool)"}, "Constrain output to a boolean tensor.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
const size_t numInputs = ctx.getNumInputs();
if (numInputs != 0 && numInputs != 1) {
fail_type_inference("OptionalHasElement is expected to have 0 or 1 input.");
}
const size_t numOutputs = ctx.getNumOutputs();
if (numOutputs != 1) {
fail_type_inference("OptionalHasElement is expected to have 1 output.");
}
auto* output_tensor_type = ctx.getOutputType(0)->mutable_tensor_type();
output_tensor_type->set_elem_type(TensorProto::BOOL);
output_tensor_type->mutable_shape()->Clear();
}));
static const char* OptionalGetElement_ver18_doc = R"DOC(
If the input is a tensor or sequence type, it returns the input.
If the input is an optional type, it outputs the element in the input.
It is an error if the input is an empty optional-type (i.e. does not have an element) and the behavior is undefined in this case.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
OptionalGetElement,
18,
OpSchema()
.SetDoc(OptionalGetElement_ver18_doc)
.Input(0, "input", "The optional input.", "O")
.Output(0, "output", "Output element in the optional input.", "V")
.TypeConstraint(
"O",
optional_and_tensor_types(),
"Constrain input type to optional tensor and optional sequence types.")
.TypeConstraint(
"V",
[]() {
auto t = OpSchema::all_tensor_types();
auto s = OpSchema::all_tensor_sequence_types();
t.insert(t.end(), s.begin(), s.end());
return t;
}(),
"Constrain output type to all tensor or sequence types.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
const size_t numInputs = ctx.getNumInputs();
if (numInputs != 1) {
fail_type_inference("OptionalGetElement must have an input element.");
}
auto input_type = ctx.getInputType(0);
if (input_type == nullptr) {
fail_type_inference("Input type is null. Input must have Type information.");
}
if (input_type->has_optional_type()) {
if (!input_type->optional_type().has_elem_type()) {
fail_type_inference("Optional-type input must contain an element with type information.");
}
ctx.getOutputType(0)->CopyFrom(input_type->optional_type().elem_type());
} else {
propagateShapeAndTypeFromFirstInput(ctx);
}
}));
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,86 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <algorithm>
#include <numeric>
#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
static const char* OptionalHasElement_ver1_doc = R"DOC(
Returns true if the optional-type input contains an element. If it is an empty optional-type, this op returns false.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
OptionalHasElement,
15,
OpSchema()
.SetDoc(OptionalHasElement_ver1_doc)
.Input(0, "input", "The optional input.", "O")
.Output(
0,
"output",
"A scalar boolean tensor. If true, it indicates that optional-type input contains an element. Otherwise, it is empty.",
"B")
.TypeConstraint(
"O",
OpSchema::all_optional_types(),
"Constrain input type to optional tensor and optional sequence types.")
.TypeConstraint("B", {"tensor(bool)"}, "Constrain output to a boolean tensor.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
const size_t numInputs = ctx.getNumInputs();
if (numInputs != 1) {
fail_type_inference("OptionalHasElement is expected to have 1 input.");
}
const size_t numOutputs = ctx.getNumOutputs();
if (numOutputs != 1) {
fail_type_inference("OptionalHasElement is expected to have 1 output.");
}
auto* output_tensor_type = ctx.getOutputType(0)->mutable_tensor_type();
output_tensor_type->set_elem_type(TensorProto::BOOL);
output_tensor_type->mutable_shape()->Clear();
}));
static const char* OptionalGetElement_ver1_doc = R"DOC(
Outputs the element in the optional-type input. It is an error if the input value does not have an element
and the behavior is undefined in this case.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
OptionalGetElement,
15,
OpSchema()
.SetDoc(OptionalGetElement_ver1_doc)
.Input(0, "input", "The optional input.", "O")
.Output(0, "output", "Output element in the optional input.", "V")
.TypeConstraint(
"O",
OpSchema::all_optional_types(),
"Constrain input type to optional tensor and optional sequence types.")
.TypeConstraint(
"V",
[]() {
auto t = OpSchema::all_tensor_types();
auto s = OpSchema::all_tensor_sequence_types();
t.insert(t.end(), s.begin(), s.end());
return t;
}(),
"Constrain output type to all tensor or sequence types.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
const size_t numInputs = ctx.getNumInputs();
if (numInputs != 1) {
fail_type_inference("OptionalGetElement must have an input element.");
}
auto input_type = ctx.getInputType(0);
if (input_type == nullptr) {
fail_type_inference("Input type is null. Input must have Type information.");
}
if (!input_type->has_optional_type() || !input_type->optional_type().has_elem_type()) {
fail_type_inference("Input must be an optional-type value containing an element with type information.");
}
ctx.getOutputType(0)->CopyFrom(input_type->optional_type().elem_type());
}));
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,895 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Experimental language syntax and parser for ONNX. Please note that the syntax as formalized
// by this parser is preliminary and may change.
#include "onnx/defs/parser.h"
#include <cctype>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include "onnx/common/common.h"
#include "onnx/onnx_pb.h"
#include "onnx/string_utils.h"
#define PARSE_TOKEN(x) CHECK_PARSER_STATUS(ParserBase::Parse(x))
#define PARSE(...) CHECK_PARSER_STATUS(Parse(__VA_ARGS__))
#define MATCH(...) CHECK_PARSER_STATUS(Match(__VA_ARGS__))
namespace ONNX_NAMESPACE {
Status ParserBase::Parse(Literal& result) {
bool decimal_point = false;
auto nextch = NextChar();
auto from = next_;
if (nextch == '"') {
++next_;
bool has_escape = false;
while ((next_ < end_) && (*next_ != '"')) {
if (*next_ == '\\') {
has_escape = true;
++next_;
if (next_ >= end_)
return ParseError("Incomplete string literal.");
}
++next_;
}
if (next_ >= end_)
return ParseError("Incomplete string literal.");
++next_;
result.type = LiteralType::STRING_LITERAL;
if (has_escape) {
std::string& target = result.value;
target.clear();
target.reserve(next_ - from - 2); // upper bound
// *from is the starting quote. *(next_-1) is the ending quote.
// Copy what is in-between, except for the escape character
while (++from < next_ - 1) {
// Copy current char, if not escape, or next char otherwise.
target.push_back(*from != '\\' ? (*from) : *(++from));
}
} else
result.value = std::string(from + 1, next_ - from - 2); // skip enclosing quotes
return Status::OK();
}
// Simplify the next ifs by consuming a possible negative sign.
if (nextch == '-') {
++next_;
nextch = NextChar();
}
// Check for float literals that start with alphabet characters.
if (isalpha(nextch)) {
// Has to be a special float literal now: (-)*(nan|inf|infinity).
if (NextIsValidFloatString()) {
while (next_ < end_ && isalpha(*next_)) {
++next_;
}
ONNX_TRY {
static_cast<void>(std::stof(std::string(from, next_ - from)));
result.type = LiteralType::FLOAT_LITERAL;
result.value = std::string(from, next_ - from);
}
ONNX_CATCH(...) {
ONNX_HANDLE_EXCEPTION([&]() { return ParseError("Encountered invalid float literal!"); });
}
} else {
return ParseError("Encountered invalid float literal!");
}
return Status::OK();
}
// Checking for numeric ints or float literal.
if (isdigit(nextch)) {
++next_;
while ((next_ < end_) && (isdigit(*next_) || (*next_ == '.'))) {
if (*next_ == '.') {
if (decimal_point)
break; // Only one decimal point allowed in numeric literal
decimal_point = true;
}
++next_;
}
if (next_ == from)
return ParseError("Value expected but not found.");
// Optional exponent syntax: (e|E)(+|-)?[0-9]+
if ((next_ < end_) && ((*next_ == 'e') || (*next_ == 'E'))) {
decimal_point = true; // treat as float-literal
++next_;
if ((next_ < end_) && ((*next_ == '+') || (*next_ == '-')))
++next_;
while ((next_ < end_) && (isdigit(*next_)))
++next_;
}
result.value = std::string(from, next_ - from);
result.type = decimal_point ? LiteralType::FLOAT_LITERAL : LiteralType::INT_LITERAL;
}
return Status::OK();
}
bool ParserBase::NextIsValidFloatString() {
auto nextch = NextChar();
auto from = next_;
constexpr int INFINITY_LENGTH = 8;
if (isalpha(nextch)) {
while (next_ < end_ && isalpha(*next_) && (next_ - from) <= INFINITY_LENGTH) {
++next_;
}
if (isdigit(*next_)) { // No trailing digits
next_ = from;
return false;
}
std::string candidate = std::string(from, next_ - from);
// Reset parser location before continuing.
next_ = from;
std::transform(
candidate.begin(), candidate.end(), candidate.begin(), [](unsigned char c) { return std::tolower(c); });
if (candidate == std::string("inf") || candidate == std::string("infinity") || candidate == std::string("nan")) {
return true;
}
}
return false;
}
Status OnnxParser::Parse(IdList& idlist) {
idlist.Clear();
std::string id;
ParseOptionalIdentifier(id);
if (id.empty())
return Status::OK(); // Treat as empty list of identifiers
*idlist.Add() = id;
while (Matches(',')) {
ParseOptionalIdentifier(id);
*idlist.Add() = id;
}
return Status::OK();
}
Status OnnxParser::Parse(char open, IdList& idlist, char close) {
idlist.Clear();
if (Matches(open)) {
PARSE(idlist);
MATCH(close);
}
return Status::OK();
}
Status OnnxParser::Parse(IdList& idlist, AttrList& attrlist) {
idlist.Clear();
attrlist.Clear();
do {
std::string id;
ParseIdentifier(id);
auto next = NextChar();
if (next == ':' || next == '=')
Parse(*attrlist.Add(), id);
else
*idlist.Add() = id;
} while (Matches(','));
return Status::OK();
}
Status OnnxParser::Parse(char open, IdList& idlist, AttrList& attrlist, char close) {
if (Matches(open)) {
PARSE(idlist, attrlist);
MATCH(close);
} else {
idlist.Clear();
attrlist.Clear();
}
return Status::OK();
}
Status OnnxParser::Parse(TensorShapeProto& shape) {
shape.clear_dim();
do {
if (Matches('?')) {
shape.add_dim();
} else {
// Check for a symbolic identifier ...
std::string id;
CHECK_PARSER_STATUS(ParseOptionalIdentifier(id));
if (!id.empty()) {
shape.add_dim()->set_dim_param(id);
} else {
// ...or a integer value
int64_t dimval = 0;
PARSE_TOKEN(dimval);
shape.add_dim()->set_dim_value(dimval);
}
}
} while (Matches(','));
return Status::OK();
}
Status OnnxParser::Parse(TypeProto& typeProto) {
std::string id;
CHECK_PARSER_STATUS(ParseIdentifier(id));
int dtype = PrimitiveTypeNameMap::Lookup(id);
if (dtype != 0) {
auto* tensortype = typeProto.mutable_tensor_type();
tensortype->set_elem_type(dtype);
tensortype->clear_shape();
// Grammar:
// float indicates scalar (rank 0)
// float [] indicates unknown rank tensor (not a zero rank tensor)
// float [one-or-more-dimensions] indicates tensor of known rank > 0.
if (Matches('[')) {
if (!Matches(']')) {
PARSE(*tensortype->mutable_shape());
MATCH(']');
}
} else {
// Create shape with zero dimensions for scalar
(void)(tensortype->mutable_shape());
}
} else {
switch (KeyWordMap::Lookup(id)) {
case KeyWordMap::KeyWord::SEQ_TYPE: {
// Grammar: seq ( type )
MATCH('(');
auto* seqtype = typeProto.mutable_sequence_type();
PARSE(*seqtype->mutable_elem_type());
MATCH(')');
break;
}
case KeyWordMap::KeyWord::MAP_TYPE: {
// Grammar: map ( prim-type , type )
MATCH('(');
auto* maptype = typeProto.mutable_map_type();
CHECK_PARSER_STATUS(ParseIdentifier(id));
dtype = PrimitiveTypeNameMap::Lookup(id);
if (dtype == 0) {
return ParseError("Expecting primitive type as map key type.");
}
maptype->set_key_type(dtype);
MATCH(',');
PARSE(*maptype->mutable_value_type());
MATCH(')');
break;
}
case KeyWordMap::KeyWord::OPTIONAL_TYPE: {
// Grammar: optional ( type )
MATCH('(');
auto* opttype = typeProto.mutable_optional_type();
PARSE(*opttype->mutable_elem_type());
MATCH(')');
break;
}
case KeyWordMap::KeyWord::SPARSE_TENSOR_TYPE: {
// Grammar: sparse_tensor ( tensor-type )
MATCH('(');
CHECK_PARSER_STATUS(ParseIdentifier(id));
dtype = PrimitiveTypeNameMap::Lookup(id);
if (dtype != 0) {
auto* sparsetype = typeProto.mutable_sparse_tensor_type();
sparsetype->set_elem_type(dtype);
sparsetype->clear_shape();
// Grammar:
// float indicates scalar (rank 0)
// float [] indicates unknown rank tensor (not a zero rank tensor)
// float [one-or-more-dimensions] indicates tensor of known rank > 0.
if (Matches('[')) {
if (!Matches(']')) {
PARSE(*sparsetype->mutable_shape());
MATCH(']');
}
} else {
// Create shape with zero dimensions for scalar
(void)(sparsetype->mutable_shape());
}
} else {
return ParseError("Unexpected type in sparse-tensor element type.");
}
MATCH(')');
break;
}
default:
return ParseError("Unexpected type.");
}
}
return Status::OK();
}
Status OnnxParser::Parse(ValueInfoProto& valueinfo) {
if (NextIsType())
PARSE(*valueinfo.mutable_type());
std::string name;
CHECK_PARSER_STATUS(ParseIdentifier(name));
valueinfo.set_name(name);
return Status::OK();
}
Status OnnxParser::Parse(char open, ValueInfoList& vilist, char close) {
MATCH(open);
if (!Matches(close)) {
do {
PARSE(*vilist.Add());
} while (Matches(','));
MATCH(close);
}
return Status::OK();
}
Status OnnxParser::ParseGraphInputOutput(ValueInfoList& vilist) {
vilist.Clear();
PARSE('(', vilist, ')');
return Status::OK();
}
Status OnnxParser::ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilist) {
// Do not clear vilist, as it accumulates values over inputs and outputs.
idlist.Clear();
MATCH('(');
if (!Matches(')')) {
do {
// Function inputs/outputs can be optionally typed.
// Syntax: Name | Type Name
// The name is added to idlist. If the optional type is present, an entry is
// added to vilist.
std::string* name = idlist.Add();
ValueInfoProto* vi = nullptr;
if (NextIsType()) {
vi = vilist.Add();
PARSE(*(vi->mutable_type()));
}
CHECK_PARSER_STATUS(ParseIdentifier(*name));
if (vi != nullptr)
vi->set_name(*name);
} while (Matches(','));
MATCH(')');
}
return Status::OK();
}
// Each input element is a value-info with an optional initializer of the form "= initial-value".
// The value-info is added to the "inputs", while the initializer is added to initializers.
Status OnnxParser::ParseInput(ValueInfoList& inputs, TensorList& initializers) {
inputs.Clear();
if (Matches('(')) {
if (!Matches(')')) {
do {
ValueInfoProto vi;
PARSE(vi);
*inputs.Add() = vi;
if (Matches('=')) {
// default value for input
TensorProto& tp = *initializers.Add();
tp.set_name(vi.name());
CHECK_PARSER_STATUS(Parse(tp, vi.type()));
}
} while (Matches(','));
MATCH(')');
}
}
return Status::OK();
}
// This is handled slightly different from the inputs.
// Each element is either a value-info or an initializer.
// A value-info is added to the "value_infos", while an initializer is added to initializers.
Status OnnxParser::ParseValueInfo(ValueInfoList& value_infos, TensorList& initializers) {
value_infos.Clear();
if (Matches('<')) {
if (!Matches('>')) {
do {
ValueInfoProto vi;
PARSE(vi);
if (Matches('=')) {
// initializer
TensorProto& tp = *initializers.Add();
tp.set_name(vi.name());
CHECK_PARSER_STATUS(Parse(tp, vi.type()));
} else {
// valueinfo
*value_infos.Add() = vi;
}
} while (Matches(','));
MATCH('>');
}
}
return Status::OK();
}
Status OnnxParser::Parse(StringStringList& stringStringList) {
std::string strval;
do {
auto* metadata = stringStringList.Add();
PARSE_TOKEN(strval);
metadata->set_key(strval);
MATCH(':');
PARSE_TOKEN(strval);
metadata->set_value(strval);
} while (Matches(','));
return Status::OK();
}
Status OnnxParser::Parse(TensorProto& tensorProto) {
tensorProto = TensorProto();
// Parse the concrete tensor-type with numeric dimensions:
TypeProto typeProto;
PARSE(typeProto);
ParseOptionalIdentifier(*tensorProto.mutable_name());
(void)Matches('='); // Optional, to unify handling of initializers as well as tensor-protos in other contexts
return Parse(tensorProto, typeProto);
}
// Parse TensorProto data given its type:
Status OnnxParser::Parse(TensorProto& tensorProto, const TypeProto& tensorTypeProto) {
if (!tensorTypeProto.has_tensor_type())
return ParseError("Error parsing TensorProto (expected a tensor type).");
auto elem_type = tensorTypeProto.tensor_type().elem_type();
tensorProto.set_data_type(elem_type);
if (!tensorTypeProto.tensor_type().has_shape())
return ParseError("Error parsing TensorProto (expected a tensor shape).");
for (auto& dim : tensorTypeProto.tensor_type().shape().dim()) {
if (!dim.has_dim_value())
return ParseError("Error parsing TensorProto shape (expected numeric dimension).");
auto dimval = dim.dim_value();
tensorProto.add_dims(dimval);
}
// tensorProto.mutable_int64_data()->Reserve(n);
// Parse the actual values:
int64_t intval;
uint64_t uintval = 0;
float floatval = 0.0;
double dblval = 0.0;
std::string strval;
if (Matches('{')) {
if (!Matches('}')) {
do {
switch (static_cast<TensorProto::DataType>(elem_type)) {
case TensorProto::DataType::TensorProto_DataType_INT4:
case TensorProto::DataType::TensorProto_DataType_INT8:
case TensorProto::DataType::TensorProto_DataType_INT16:
case TensorProto::DataType::TensorProto_DataType_INT32:
case TensorProto::DataType::TensorProto_DataType_UINT4:
case TensorProto::DataType::TensorProto_DataType_UINT8:
case TensorProto::DataType::TensorProto_DataType_UINT16:
case TensorProto::DataType::TensorProto_DataType_FLOAT16:
case TensorProto::DataType::TensorProto_DataType_BFLOAT16:
case TensorProto::DataType::TensorProto_DataType_FLOAT8E4M3FN:
case TensorProto::DataType::TensorProto_DataType_FLOAT8E4M3FNUZ:
case TensorProto::DataType::TensorProto_DataType_FLOAT8E5M2:
case TensorProto::DataType::TensorProto_DataType_FLOAT8E5M2FNUZ:
case TensorProto::DataType::TensorProto_DataType_BOOL:
PARSE_TOKEN(intval);
// TODO: check values are in the correct range.
tensorProto.add_int32_data(intval);
break;
case TensorProto::DataType::TensorProto_DataType_INT64:
PARSE_TOKEN(intval);
tensorProto.add_int64_data(intval);
break;
case TensorProto::DataType::TensorProto_DataType_UINT32:
case TensorProto::DataType::TensorProto_DataType_UINT64:
PARSE_TOKEN(uintval);
tensorProto.add_uint64_data(uintval);
break;
case TensorProto::DataType::TensorProto_DataType_COMPLEX64:
case TensorProto::DataType::TensorProto_DataType_FLOAT:
PARSE_TOKEN(floatval);
tensorProto.add_float_data(floatval);
break;
case TensorProto::DataType::TensorProto_DataType_COMPLEX128:
case TensorProto::DataType::TensorProto_DataType_DOUBLE:
PARSE_TOKEN(dblval);
tensorProto.add_double_data(dblval);
break;
case TensorProto::DataType::TensorProto_DataType_STRING:
PARSE_TOKEN(strval);
tensorProto.add_string_data(strval);
break;
default:
return ParseError("Unhandled type: %d", elem_type);
}
} while (Matches(','));
MATCH('}');
}
} else if (Matches('[')) {
tensorProto.set_data_location(TensorProto::DataLocation::TensorProto_DataLocation_EXTERNAL);
auto& externalData = *tensorProto.mutable_external_data();
PARSE(externalData);
MATCH(']');
}
return Status::OK();
}
bool OnnxParser::NextIsIdentifier() {
std::string id("");
(void)PeekIdentifier(id);
return !(id.empty());
}
bool OnnxParser::NextIsType() {
std::string id("");
(void)PeekIdentifier(id);
if (PrimitiveTypeNameMap::IsTypeName(id))
return true;
switch (KeyWordMap::Lookup(id)) {
case KeyWordMap::KeyWord::SEQ_TYPE:
case KeyWordMap::KeyWord::MAP_TYPE:
case KeyWordMap::KeyWord::OPTIONAL_TYPE:
case KeyWordMap::KeyWord::SPARSE_TENSOR_TYPE:
return true;
default:
return false;
}
}
Status OnnxParser::ParseSingleAttributeValue(AttributeProto& attr, AttributeProto_AttributeType expected) {
// Parse a single-value
auto next = NextChar();
if (isalpha(next) || next == '_') {
if (NextIsType()) {
TypeProto typeProto;
Parse(typeProto);
next = NextChar();
if ((next == '{') || (next == '=') || (NextIsIdentifier())) {
attr.set_type(AttributeProto_AttributeType_TENSOR);
auto& tensorProto = *attr.mutable_t();
ParseOptionalIdentifier(*tensorProto.mutable_name());
(void)Matches('='); // Optional, to unify handling of initializers
Parse(tensorProto, typeProto);
} else {
attr.set_type(AttributeProto_AttributeType_TYPE_PROTO);
attr.mutable_tp()->CopyFrom(typeProto);
}
} else {
if (NextIsValidFloatString()) {
Literal literal;
PARSE_TOKEN(literal);
attr.set_type(AttributeProto_AttributeType_FLOAT);
attr.set_f(static_cast<float>(std::stof(literal.value)));
} else {
attr.set_type(AttributeProto_AttributeType_GRAPH);
PARSE(*attr.mutable_g());
}
}
} else if (Matches('@')) {
std::string name;
CHECK_PARSER_STATUS(ParseIdentifier(name));
attr.set_ref_attr_name(name);
} else {
Literal literal;
PARSE_TOKEN(literal);
switch (literal.type) {
case LiteralType::INT_LITERAL:
attr.set_type(AttributeProto_AttributeType_INT);
attr.set_i(std::stol(literal.value));
break;
case LiteralType::FLOAT_LITERAL:
attr.set_type(AttributeProto_AttributeType_FLOAT);
attr.set_f(static_cast<float>(std::stof(literal.value)));
break;
case LiteralType::STRING_LITERAL:
attr.set_type(AttributeProto_AttributeType_STRING);
attr.set_s(literal.value);
break;
}
}
if ((expected != AttributeProto_AttributeType_UNDEFINED) && (expected != attr.type())) {
// Mismatch between type-annotation and attribute-value. We do an implicit cast
// only in the special case of FLOAT type and integral value like 2
if ((expected == AttributeProto_AttributeType_FLOAT) && (attr.type() == AttributeProto_AttributeType_INT)) {
attr.set_type(AttributeProto_AttributeType_FLOAT);
attr.set_f(static_cast<float>(attr.i()));
} else {
return ParseError(
"Mismatch between expected type ",
AttributeProto_AttributeType_Name(expected),
" and specified value's type",
AttributeProto_AttributeType_Name(attr.type()));
}
}
return Status::OK();
}
Status OnnxParser::Parse(AttributeProto& attr) {
attr.Clear();
std::string name;
CHECK_PARSER_STATUS(ParseIdentifier(name));
return Parse(attr, name);
}
bool IsSingletonAttribute(AttributeProto_AttributeType type) {
switch (type) {
case AttributeProto_AttributeType_FLOAT:
case AttributeProto_AttributeType_INT:
case AttributeProto_AttributeType_STRING:
case AttributeProto_AttributeType_TENSOR:
case AttributeProto_AttributeType_GRAPH:
case AttributeProto_AttributeType_SPARSE_TENSOR:
case AttributeProto_AttributeType_TYPE_PROTO:
return true;
default:
return false;
}
}
AttributeProto_AttributeType ToSingletonType(AttributeProto_AttributeType type) {
switch (type) {
case AttributeProto_AttributeType_FLOATS:
return AttributeProto_AttributeType_FLOAT;
case AttributeProto_AttributeType_INTS:
return AttributeProto_AttributeType_INT;
case AttributeProto_AttributeType_STRINGS:
return AttributeProto_AttributeType_STRING;
case AttributeProto_AttributeType_TENSORS:
return AttributeProto_AttributeType_TENSOR;
case AttributeProto_AttributeType_GRAPHS:
return AttributeProto_AttributeType_GRAPH;
case AttributeProto_AttributeType_SPARSE_TENSORS:
return AttributeProto_AttributeType_SPARSE_TENSOR;
case AttributeProto_AttributeType_TYPE_PROTOS:
return AttributeProto_AttributeType_TYPE_PROTO;
default:
return type;
}
}
Status OnnxParser::Parse(AttributeProto& attr, std::string& name) {
attr.set_name(name);
if (Matches(':')) {
CHECK_PARSER_STATUS(ParseIdentifier(name));
int attrtype = AttributeTypeNameMap::Lookup(name);
if (attrtype != 0) {
attr.set_type(static_cast<AttributeProto_AttributeType>(attrtype));
} else {
return ParseError("Unexpected attribute type.");
}
}
MATCH('=');
if (NextChar() == '[') {
// Parse a list of values. For an empty list, the type MUST be specified
// using the type-annotation syntax of ": type".
MATCH('[');
if (NextChar() != ']') {
do {
AttributeProto nextval;
auto expected_type = ToSingletonType(attr.type());
CHECK_PARSER_STATUS(ParseSingleAttributeValue(nextval, expected_type));
switch (nextval.type()) {
case AttributeProto_AttributeType_INT:
attr.set_type(AttributeProto_AttributeType_INTS);
attr.add_ints(nextval.i());
break;
case AttributeProto_AttributeType_FLOAT:
attr.set_type(AttributeProto_AttributeType_FLOATS);
attr.add_floats(nextval.f());
break;
case AttributeProto_AttributeType_STRING:
attr.add_strings(nextval.s());
attr.set_type(AttributeProto_AttributeType_STRINGS);
break;
default:
break;
}
} while (Matches(','));
} else {
if (attr.type() == AttributeProto_AttributeType_UNDEFINED)
return ParseError("Empty list attribute value requires type annotation.");
if (IsSingletonAttribute(attr.type()))
return ParseError("Singleton attribute value cannot be specified as a list.");
}
MATCH(']');
} else {
CHECK_PARSER_STATUS(ParseSingleAttributeValue(attr, attr.type()));
}
return Status::OK();
}
Status OnnxParser::Parse(AttrList& attrlist) {
attrlist.Clear();
if (Matches('<')) {
do {
PARSE(*attrlist.Add());
} while (Matches(','));
MATCH('>');
}
return Status::OK();
}
Status OnnxParser::Parse(NodeProto& node) {
PARSE(*node.mutable_output());
MATCH('=');
std::string domain("");
std::string id;
ParseIdentifier(id);
while (Matches('.')) {
if (!domain.empty())
domain += ".";
domain += id;
ParseIdentifier(id);
}
node.set_domain(domain);
node.set_op_type(id);
if (Matches(':')) {
std::string overload;
ParseIdentifier(overload);
node.set_overload(overload);
}
PARSE(*node.mutable_attribute());
MATCH('(');
PARSE(*node.mutable_input());
MATCH(')');
if (node.attribute_size() == 0) {
// Permit attributes to be specified before or after parameters.
PARSE(*node.mutable_attribute());
}
return Status::OK();
}
Status OnnxParser::Parse(NodeList& nodelist) {
nodelist.Clear();
MATCH('{');
while (!Matches('}')) {
PARSE(*nodelist.Add());
}
return Status::OK();
}
Status OnnxParser::Parse(GraphProto& graph) {
std::string id;
ParseIdentifier(id);
return Parse(id, graph);
}
Status OnnxParser::Parse(std::string name, GraphProto& graph) {
graph.set_name(name);
graph.mutable_initializer()->Clear();
CHECK_PARSER_STATUS(ParseInput(*graph.mutable_input(), *graph.mutable_initializer()));
MATCH('=');
MATCH('>', false);
CHECK_PARSER_STATUS(ParseGraphInputOutput(*graph.mutable_output()));
CHECK_PARSER_STATUS(ParseValueInfo(*graph.mutable_value_info(), *graph.mutable_initializer()));
return Parse(*graph.mutable_node());
}
Status OnnxParser::Parse(FunctionProto& fn) {
fn.Clear();
std::string strval;
if (Matches('<')) {
do {
KeyWordMap::KeyWord keyword = KeyWordMap::KeyWord::NONE;
PARSE_TOKEN(keyword);
MATCH(':');
switch (keyword) {
case KeyWordMap::KeyWord::OPSET_IMPORT:
PARSE(*fn.mutable_opset_import());
break;
case KeyWordMap::KeyWord::DOC_STRING:
PARSE_TOKEN(strval);
fn.set_doc_string(strval);
break;
case KeyWordMap::KeyWord::DOMAIN_KW:
PARSE_TOKEN(strval);
fn.set_domain(strval);
break;
case KeyWordMap::KeyWord::OVERLOAD_KW:
PARSE_TOKEN(strval);
fn.set_overload(strval);
break;
default:
return ParseError("Unhandled keyword.");
}
} while (Matches(','));
MATCH('>');
}
std::string id;
ParseIdentifier(id);
fn.set_name(id);
PARSE('<', *fn.mutable_attribute(), *fn.mutable_attribute_proto(), '>');
fn.mutable_value_info()->Clear();
CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_input(), *fn.mutable_value_info()));
MATCH('=');
MATCH('>', false);
CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_output(), *fn.mutable_value_info()));
if (NextChar() == '<') {
PARSE('<', *fn.mutable_value_info(), '>');
}
return Parse(*fn.mutable_node());
}
Status OnnxParser::Parse(OpsetIdList& opsets) {
std::string strval;
int64_t intval = 0;
MATCH('[');
if (!Matches(']')) {
do {
auto* import = opsets.Add();
PARSE_TOKEN(strval);
import->set_domain(strval);
MATCH(':');
PARSE_TOKEN(intval);
import->set_version(intval);
} while (Matches(','));
MATCH(']');
}
return Status::OK();
}
Status OnnxParser::Parse(ModelProto& model) {
model.Clear();
std::string strval;
int64_t intval;
if (Matches('<')) {
do {
KeyWordMap::KeyWord keyword = KeyWordMap::KeyWord::NONE;
PARSE_TOKEN(keyword);
MATCH(':');
switch (keyword) {
case KeyWordMap::KeyWord::IR_VERSION:
PARSE_TOKEN(intval);
model.set_ir_version(intval);
break;
case KeyWordMap::KeyWord::OPSET_IMPORT:
PARSE(*model.mutable_opset_import());
break;
case KeyWordMap::KeyWord::PRODUCER_NAME:
PARSE_TOKEN(strval);
model.set_producer_name(strval);
break;
case KeyWordMap::KeyWord::PRODUCER_VERSION:
PARSE_TOKEN(strval);
model.set_producer_version(strval);
break;
case KeyWordMap::KeyWord::DOMAIN_KW:
PARSE_TOKEN(strval);
model.set_domain(strval);
break;
case KeyWordMap::KeyWord::MODEL_VERSION:
PARSE_TOKEN(intval);
model.set_model_version(intval);
break;
case KeyWordMap::KeyWord::DOC_STRING:
PARSE_TOKEN(strval);
model.set_doc_string(strval);
break;
case KeyWordMap::KeyWord::METADATA_PROPS: {
auto& metadata_props = *model.mutable_metadata_props();
MATCH('[');
if (!Matches(']')) {
PARSE(metadata_props);
MATCH(']');
}
break;
}
default:
return ParseError("Unhandled keyword.");
}
} while (Matches(','));
MATCH('>');
}
PARSE(*model.mutable_graph());
auto* functions = model.mutable_functions();
while (!EndOfInput()) {
PARSE(*functions->Add());
}
return Status::OK();
}
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,457 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
// Experimental language syntax and parser for ONNX. Please note that the syntax as formalized
// by this parser is preliminary and may change.
#pragma once
#include <ctype.h>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include "onnx/common/status.h"
#include "onnx/onnx_pb.h"
#include "onnx/string_utils.h"
namespace ONNX_NAMESPACE {
using namespace ONNX_NAMESPACE::Common;
using IdList = google::protobuf::RepeatedPtrField<std::string>;
using NodeList = google::protobuf::RepeatedPtrField<NodeProto>;
using AttrList = google::protobuf::RepeatedPtrField<AttributeProto>;
using ValueInfoList = google::protobuf::RepeatedPtrField<ValueInfoProto>;
using TensorList = google::protobuf::RepeatedPtrField<TensorProto>;
using OpsetIdList = google::protobuf::RepeatedPtrField<OperatorSetIdProto>;
using StringStringList = google::protobuf::RepeatedPtrField<StringStringEntryProto>;
#define CHECK_PARSER_STATUS(status) \
{ \
auto local_status_ = status; \
if (!local_status_.IsOK()) \
return local_status_; \
}
template <typename Map>
class StringIntMap {
public:
static const std::unordered_map<std::string, int32_t>& Instance() {
static Map instance;
return instance.map_;
}
static int32_t Lookup(const std::string& dtype) {
auto it = Instance().find(dtype);
if (it != Instance().end())
return it->second;
return 0;
}
static const std::string& ToString(int32_t dtype) {
static std::string undefined("undefined");
for (const auto& pair : Instance()) {
if (pair.second == dtype)
return pair.first;
}
return undefined;
}
protected:
std::unordered_map<std::string, int32_t> map_;
};
class PrimitiveTypeNameMap : public StringIntMap<PrimitiveTypeNameMap> {
public:
PrimitiveTypeNameMap() : StringIntMap() {
map_["float"] = TensorProto_DataType_FLOAT;
map_["uint8"] = TensorProto_DataType_UINT8;
map_["int8"] = TensorProto_DataType_INT8;
map_["uint16"] = TensorProto_DataType_UINT16;
map_["int16"] = TensorProto_DataType_INT16;
map_["int32"] = TensorProto_DataType_INT32;
map_["int64"] = TensorProto_DataType_INT64;
map_["string"] = TensorProto_DataType_STRING;
map_["bool"] = TensorProto_DataType_BOOL;
map_["float16"] = TensorProto_DataType_FLOAT16;
map_["double"] = TensorProto_DataType_DOUBLE;
map_["uint32"] = TensorProto_DataType_UINT32;
map_["uint64"] = TensorProto_DataType_UINT64;
map_["complex64"] = TensorProto_DataType_COMPLEX64;
map_["complex128"] = TensorProto_DataType_COMPLEX128;
map_["bfloat16"] = TensorProto_DataType_BFLOAT16;
map_["float8e4m3fn"] = TensorProto_DataType_FLOAT8E4M3FN;
map_["float8e4m3fnuz"] = TensorProto_DataType_FLOAT8E4M3FNUZ;
map_["float8e5m2"] = TensorProto_DataType_FLOAT8E5M2;
map_["float8e5m2fnuz"] = TensorProto_DataType_FLOAT8E5M2FNUZ;
map_["uint4"] = TensorProto_DataType_UINT4;
map_["int4"] = TensorProto_DataType_INT4;
}
static bool IsTypeName(const std::string& dtype) {
return Lookup(dtype) != 0;
}
};
class AttributeTypeNameMap : public StringIntMap<AttributeTypeNameMap> {
public:
AttributeTypeNameMap() : StringIntMap() {
map_["float"] = AttributeProto_AttributeType_FLOAT;
map_["int"] = AttributeProto_AttributeType_INT;
map_["string"] = AttributeProto_AttributeType_STRING;
map_["tensor"] = AttributeProto_AttributeType_TENSOR;
map_["graph"] = AttributeProto_AttributeType_GRAPH;
map_["sparse_tensor"] = AttributeProto_AttributeType_SPARSE_TENSOR;
map_["type_proto"] = AttributeProto_AttributeType_TYPE_PROTO;
map_["floats"] = AttributeProto_AttributeType_FLOATS;
map_["ints"] = AttributeProto_AttributeType_INTS;
map_["strings"] = AttributeProto_AttributeType_STRINGS;
map_["tensors"] = AttributeProto_AttributeType_TENSORS;
map_["graphs"] = AttributeProto_AttributeType_GRAPHS;
map_["sparse_tensors"] = AttributeProto_AttributeType_SPARSE_TENSORS;
map_["type_protos"] = AttributeProto_AttributeType_TYPE_PROTOS;
}
};
class KeyWordMap {
public:
enum class KeyWord {
NONE,
IR_VERSION,
OPSET_IMPORT,
PRODUCER_NAME,
PRODUCER_VERSION,
DOMAIN_KW,
MODEL_VERSION,
DOC_STRING,
METADATA_PROPS,
SEQ_TYPE,
MAP_TYPE,
OPTIONAL_TYPE,
SPARSE_TENSOR_TYPE,
OVERLOAD_KW
};
KeyWordMap() {
map_["ir_version"] = KeyWord::IR_VERSION;
map_["opset_import"] = KeyWord::OPSET_IMPORT;
map_["producer_name"] = KeyWord::PRODUCER_NAME;
map_["producer_version"] = KeyWord::PRODUCER_VERSION;
map_["domain"] = KeyWord::DOMAIN_KW;
map_["model_version"] = KeyWord::MODEL_VERSION;
map_["doc_string"] = KeyWord::DOC_STRING;
map_["metadata_props"] = KeyWord::METADATA_PROPS;
map_["seq"] = KeyWord::SEQ_TYPE;
map_["map"] = KeyWord::MAP_TYPE;
map_["optional"] = KeyWord::OPTIONAL_TYPE;
map_["sparse_tensor"] = KeyWord::SPARSE_TENSOR_TYPE;
map_["overload"] = KeyWord::OVERLOAD_KW;
}
static const std::unordered_map<std::string, KeyWord>& Instance() {
static KeyWordMap instance;
return instance.map_;
}
static KeyWord Lookup(const std::string& id) {
auto it = Instance().find(id);
if (it != Instance().end())
return it->second;
return KeyWord::NONE;
}
static const std::string& ToString(KeyWord kw) {
static std::string undefined("undefined");
for (const auto& pair : Instance()) {
if (pair.second == kw)
return pair.first;
}
return undefined;
}
private:
std::unordered_map<std::string, KeyWord> map_;
};
class ParserBase {
public:
ParserBase(const std::string& str)
: start_(str.data()), next_(str.data()), end_(str.data() + str.length()), saved_pos_(next_) {}
ParserBase(const char* cstr) : start_(cstr), next_(cstr), end_(cstr + strlen(cstr)), saved_pos_(next_) {}
void SavePos() {
saved_pos_ = next_;
}
void RestorePos() {
next_ = saved_pos_;
}
std::string GetCurrentPos() {
uint32_t line = 1, col = 1;
for (const char* p = start_; p < next_; ++p) {
if (*p == '\n') {
++line;
col = 1;
} else {
++col;
}
}
return ONNX_NAMESPACE::MakeString("(line: ", line, " column: ", col, ")");
}
// Return a suitable suffix of what has been parsed to provide error message context:
// return the line containing the last non-space character preceding the error (if it exists).
std::string GetErrorContext() {
// Special cases: empty input string, and parse-error at first character.
const char* p = next_ < end_ ? next_ : next_ - 1;
while ((p > start_) && isspace(*p))
--p;
while ((p > start_) && (*p != '\n'))
--p;
// Start at character after '\n' unless we are at start of input
const char* context_start = (p > start_) ? (p + 1) : start_;
for (p = context_start; (p < end_) && (*p != '\n'); ++p)
;
return std::string(context_start, p - context_start);
}
template <typename... Args>
Status ParseError(const Args&... args) {
return Status(
NONE,
FAIL,
ONNX_NAMESPACE::MakeString(
"[ParseError at position ", GetCurrentPos(), "]\n", "Error context: ", GetErrorContext(), "\n", args...));
}
void SkipWhiteSpace() {
do {
while ((next_ < end_) && (isspace(*next_)))
++next_;
if ((next_ >= end_) || ((*next_) != '#'))
return;
// Skip rest of the line:
while ((next_ < end_) && ((*next_) != '\n'))
++next_;
} while (true);
}
int NextChar(bool skipspace = true) {
if (skipspace)
SkipWhiteSpace();
return (next_ < end_) ? *next_ : 0;
}
bool Matches(char ch, bool skipspace = true) {
if (skipspace)
SkipWhiteSpace();
if ((next_ < end_) && (*next_ == ch)) {
++next_;
return true;
}
return false;
}
Status Match(char ch, bool skipspace = true) {
if (!Matches(ch, skipspace))
return ParseError("Expected character ", ch, " not found.");
return Status::OK();
}
bool EndOfInput() {
SkipWhiteSpace();
return (next_ >= end_);
}
enum class LiteralType { INT_LITERAL, FLOAT_LITERAL, STRING_LITERAL };
struct Literal {
LiteralType type;
std::string value;
};
Status Parse(Literal& result);
Status Parse(int64_t& val) {
Literal literal;
CHECK_PARSER_STATUS(Parse(literal));
if (literal.type != LiteralType::INT_LITERAL)
return ParseError("Integer value expected, but not found.");
std::string s = literal.value;
val = std::stoll(s);
return Status::OK();
}
Status Parse(uint64_t& val) {
Literal literal;
CHECK_PARSER_STATUS(Parse(literal));
if (literal.type != LiteralType::INT_LITERAL)
return ParseError("Integer value expected, but not found.");
std::string s = literal.value;
val = std::stoull(s);
return Status::OK();
}
Status Parse(float& val) {
Literal literal;
CHECK_PARSER_STATUS(Parse(literal));
switch (literal.type) {
case LiteralType::INT_LITERAL:
case LiteralType::FLOAT_LITERAL:
val = std::stof(literal.value);
break;
default:
return ParseError("Unexpected literal type.");
}
return Status::OK();
}
Status Parse(double& val) {
Literal literal;
CHECK_PARSER_STATUS(Parse(literal));
switch (literal.type) {
case LiteralType::INT_LITERAL:
case LiteralType::FLOAT_LITERAL:
val = std::stod(literal.value);
break;
default:
return ParseError("Unexpected literal type.");
}
return Status::OK();
}
// Parse a string-literal enclosed within doube-quotes.
Status Parse(std::string& val) {
Literal literal;
CHECK_PARSER_STATUS(Parse(literal));
if (literal.type != LiteralType::STRING_LITERAL)
return ParseError("String value expected, but not found.");
val = literal.value;
return Status::OK();
}
// Parse an identifier, including keywords. If none found, this will
// return an empty-string identifier.
Status ParseOptionalIdentifier(std::string& id) {
SkipWhiteSpace();
auto from = next_;
if ((next_ < end_) && (isalpha(*next_) || (*next_ == '_'))) {
++next_;
while ((next_ < end_) && (isalnum(*next_) || (*next_ == '_')))
++next_;
}
id = std::string(from, next_ - from);
return Status::OK();
}
Status ParseIdentifier(std::string& id) {
ParseOptionalIdentifier(id);
if (id.empty())
return ParseError("Identifier expected but not found.");
return Status::OK();
}
Status PeekIdentifier(std::string& id) {
SavePos();
ParseOptionalIdentifier(id);
RestorePos();
return Status::OK();
}
Status Parse(KeyWordMap::KeyWord& keyword) {
std::string id;
CHECK_PARSER_STATUS(ParseIdentifier(id));
keyword = KeyWordMap::Lookup(id);
return Status::OK();
}
protected:
const char* start_;
const char* next_;
const char* end_;
const char* saved_pos_;
bool NextIsValidFloatString();
};
class OnnxParser : public ParserBase {
public:
OnnxParser(const char* cstr) : ParserBase(cstr) {}
Status Parse(TensorShapeProto& shape);
Status Parse(TypeProto& typeProto);
Status Parse(StringStringList& stringStringList);
Status Parse(TensorProto& tensorProto);
Status Parse(AttributeProto& attr);
Status Parse(AttributeProto& attr, std::string& name);
Status Parse(AttrList& attrlist);
Status Parse(NodeProto& node);
Status Parse(NodeList& nodelist);
Status Parse(GraphProto& graph);
Status Parse(FunctionProto& fn);
Status Parse(ModelProto& model);
template <typename T>
static Status Parse(T& parsedData, const char* input) {
OnnxParser parser(input);
return parser.Parse(parsedData);
}
private:
Status Parse(std::string name, GraphProto& graph);
Status Parse(IdList& idlist);
Status Parse(char open, IdList& idlist, char close);
Status Parse(IdList& idlist, AttrList& attrlist);
Status Parse(char open, IdList& idlist, AttrList& attrlist, char close);
Status ParseSingleAttributeValue(AttributeProto& attr, AttributeProto_AttributeType expected);
Status Parse(ValueInfoProto& valueinfo);
Status ParseGraphInputOutput(ValueInfoList& vilist);
Status ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilist);
Status Parse(char open, ValueInfoList& vilist, char close);
Status ParseInput(ValueInfoList& vilist, TensorList& initializers);
Status ParseValueInfo(ValueInfoList& vilist, TensorList& initializers);
Status Parse(TensorProto& tensorProto, const TypeProto& tensorTypeProto);
Status Parse(OpsetIdList& opsets);
bool NextIsType();
bool NextIsIdentifier();
};
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,473 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/printer.h"
#include <iomanip>
#include <vector>
#include "onnx/defs/tensor_proto_util.h"
namespace ONNX_NAMESPACE {
using StringStringEntryProtos = google::protobuf::RepeatedPtrField<StringStringEntryProto>;
class ProtoPrinter {
public:
ProtoPrinter(std::ostream& os) : output_(os) {}
void print(const TensorShapeProto_Dimension& dim);
void print(const TensorShapeProto& shape);
void print(const TypeProto_Tensor& tensortype);
void print(const TypeProto& type);
void print(const TypeProto_Sequence& seqType);
void print(const TypeProto_Map& mapType);
void print(const TypeProto_Optional& optType);
void print(const TypeProto_SparseTensor& sparseType);
void print(const TensorProto& tensor, bool is_initializer = false);
void print(const ValueInfoProto& value_info);
void print(const ValueInfoList& vilist);
void print(const AttributeProto& attr);
void print(const AttrList& attrlist);
void print(const NodeProto& node);
void print(const NodeList& nodelist);
void print(const GraphProto& graph);
void print(const FunctionProto& fn);
void print(const ModelProto& model);
void print(const OperatorSetIdProto& opset);
void print(const OpsetIdList& opsets);
void print(const StringStringEntryProtos& stringStringProtos) {
printSet("[", ", ", "]", stringStringProtos);
}
void print(const StringStringEntryProto& metadata) {
printQuoted(metadata.key());
output_ << ": ";
printQuoted(metadata.value());
}
private:
template <typename T>
inline void print(T prim) {
output_ << prim;
}
void printQuoted(const std::string& str) {
output_ << "\"";
for (const char* p = str.c_str(); *p; ++p) {
if ((*p == '\\') || (*p == '"'))
output_ << '\\';
output_ << *p;
}
output_ << "\"";
}
template <typename T>
inline void printKeyValuePair(KeyWordMap::KeyWord key, const T& val, bool addsep = true) {
if (addsep)
output_ << "," << std::endl;
output_ << std::setw(indent_level) << ' ' << KeyWordMap::ToString(key) << ": ";
print(val);
}
inline void printKeyValuePair(KeyWordMap::KeyWord key, const std::string& val) {
output_ << "," << std::endl;
output_ << std::setw(indent_level) << ' ' << KeyWordMap::ToString(key) << ": ";
printQuoted(val);
}
template <typename Collection>
inline void printSet(const char* open, const char* separator, const char* close, Collection coll) {
const char* sep = "";
output_ << open;
for (auto& elt : coll) {
output_ << sep;
print(elt);
sep = separator;
}
output_ << close;
}
std::ostream& output_;
int indent_level = 3;
void indent() {
indent_level += 3;
}
void outdent() {
indent_level -= 3;
}
};
void ProtoPrinter::print(const TensorShapeProto_Dimension& dim) {
if (dim.has_dim_value())
output_ << dim.dim_value();
else if (dim.has_dim_param())
output_ << dim.dim_param();
else
output_ << "?";
}
void ProtoPrinter::print(const TensorShapeProto& shape) {
printSet("[", ",", "]", shape.dim());
}
void ProtoPrinter::print(const TypeProto_Tensor& tensortype) {
output_ << PrimitiveTypeNameMap::ToString(tensortype.elem_type());
if (tensortype.has_shape()) {
if (tensortype.shape().dim_size() > 0)
print(tensortype.shape());
} else
output_ << "[]";
}
void ProtoPrinter::print(const TypeProto_Sequence& seqType) {
output_ << "seq(";
print(seqType.elem_type());
output_ << ")";
}
void ProtoPrinter::print(const TypeProto_Map& mapType) {
output_ << "map(" << PrimitiveTypeNameMap::ToString(mapType.key_type()) << ", ";
print(mapType.value_type());
output_ << ")";
}
void ProtoPrinter::print(const TypeProto_Optional& optType) {
output_ << "optional(";
print(optType.elem_type());
output_ << ")";
}
void ProtoPrinter::print(const TypeProto_SparseTensor& sparseType) {
output_ << "sparse_tensor(" << PrimitiveTypeNameMap::ToString(sparseType.elem_type());
if (sparseType.has_shape()) {
if (sparseType.shape().dim_size() > 0)
print(sparseType.shape());
} else
output_ << "[]";
output_ << ")";
}
void ProtoPrinter::print(const TypeProto& type) {
if (type.has_tensor_type())
print(type.tensor_type());
else if (type.has_sequence_type())
print(type.sequence_type());
else if (type.has_map_type())
print(type.map_type());
else if (type.has_optional_type())
print(type.optional_type());
else if (type.has_sparse_tensor_type())
print(type.sparse_tensor_type());
}
void ProtoPrinter::print(const TensorProto& tensor, bool is_initializer) {
output_ << PrimitiveTypeNameMap::ToString(tensor.data_type());
if (tensor.dims_size() > 0)
printSet("[", ",", "]", tensor.dims());
if (!tensor.name().empty()) {
output_ << " " << tensor.name();
}
if (is_initializer) {
output_ << " = ";
}
// TODO: does not yet handle all types
if (tensor.has_data_location() && tensor.data_location() == TensorProto_DataLocation_EXTERNAL) {
print(tensor.external_data());
} else if (tensor.has_raw_data()) {
switch (static_cast<TensorProto::DataType>(tensor.data_type())) {
case TensorProto::DataType::TensorProto_DataType_INT32:
printSet(" {", ",", "}", ParseData<int32_t>(&tensor));
break;
case TensorProto::DataType::TensorProto_DataType_INT64:
printSet(" {", ",", "}", ParseData<int64_t>(&tensor));
break;
case TensorProto::DataType::TensorProto_DataType_FLOAT:
printSet(" {", ",", "}", ParseData<float>(&tensor));
break;
case TensorProto::DataType::TensorProto_DataType_DOUBLE:
printSet(" {", ",", "}", ParseData<double>(&tensor));
break;
default:
output_ << "..."; // ParseData not instantiated for other types.
break;
}
} else {
switch (static_cast<TensorProto::DataType>(tensor.data_type())) {
case TensorProto::DataType::TensorProto_DataType_INT8:
case TensorProto::DataType::TensorProto_DataType_INT16:
case TensorProto::DataType::TensorProto_DataType_INT32:
case TensorProto::DataType::TensorProto_DataType_UINT8:
case TensorProto::DataType::TensorProto_DataType_UINT16:
case TensorProto::DataType::TensorProto_DataType_BOOL:
printSet(" {", ",", "}", tensor.int32_data());
break;
case TensorProto::DataType::TensorProto_DataType_INT64:
printSet(" {", ",", "}", tensor.int64_data());
break;
case TensorProto::DataType::TensorProto_DataType_UINT32:
case TensorProto::DataType::TensorProto_DataType_UINT64:
printSet(" {", ",", "}", tensor.uint64_data());
break;
case TensorProto::DataType::TensorProto_DataType_FLOAT:
printSet(" {", ",", "}", tensor.float_data());
break;
case TensorProto::DataType::TensorProto_DataType_DOUBLE:
printSet(" {", ",", "}", tensor.double_data());
break;
case TensorProto::DataType::TensorProto_DataType_STRING: {
const char* sep = "{";
for (auto& elt : tensor.string_data()) {
output_ << sep;
printQuoted(elt);
sep = ", ";
}
output_ << "}";
break;
}
default:
break;
}
}
}
void ProtoPrinter::print(const ValueInfoProto& value_info) {
print(value_info.type());
output_ << " " << value_info.name();
}
void ProtoPrinter::print(const ValueInfoList& vilist) {
printSet("(", ", ", ")", vilist);
}
void ProtoPrinter::print(const AttributeProto& attr) {
// Special case of attr-ref:
if (attr.has_ref_attr_name()) {
output_ << attr.name() << ": " << AttributeTypeNameMap::ToString(attr.type()) << " = @" << attr.ref_attr_name();
return;
}
// General case:
output_ << attr.name() << ": " << AttributeTypeNameMap::ToString(attr.type()) << " = ";
switch (attr.type()) {
case AttributeProto_AttributeType_INT:
output_ << attr.i();
break;
case AttributeProto_AttributeType_INTS:
printSet("[", ", ", "]", attr.ints());
break;
case AttributeProto_AttributeType_FLOAT:
output_ << attr.f();
break;
case AttributeProto_AttributeType_FLOATS:
printSet("[", ", ", "]", attr.floats());
break;
case AttributeProto_AttributeType_STRING:
output_ << "\"" << attr.s() << "\"";
break;
case AttributeProto_AttributeType_STRINGS: {
const char* sep = "[";
for (auto& elt : attr.strings()) {
output_ << sep << "\"" << elt << "\"";
sep = ", ";
}
output_ << "]";
break;
}
case AttributeProto_AttributeType_GRAPH:
indent();
print(attr.g());
outdent();
break;
case AttributeProto_AttributeType_GRAPHS:
indent();
printSet("[", ", ", "]", attr.graphs());
outdent();
break;
case AttributeProto_AttributeType_TENSOR:
print(attr.t());
break;
case AttributeProto_AttributeType_TENSORS:
printSet("[", ", ", "]", attr.tensors());
break;
case AttributeProto_AttributeType_TYPE_PROTO:
print(attr.tp());
break;
case AttributeProto_AttributeType_TYPE_PROTOS:
printSet("[", ", ", "]", attr.type_protos());
break;
default:
break;
}
}
void ProtoPrinter::print(const AttrList& attrlist) {
printSet(" <", ", ", ">", attrlist);
}
void ProtoPrinter::print(const NodeProto& node) {
output_ << std::setw(indent_level) << ' ';
printSet("", ", ", "", node.output());
output_ << " = ";
if (node.domain() != "")
output_ << node.domain() << ".";
output_ << node.op_type();
if (node.overload() != "")
output_ << ":" << node.overload();
bool has_subgraph = false;
for (auto attr : node.attribute())
if (attr.has_g() || (attr.graphs_size() > 0))
has_subgraph = true;
if ((!has_subgraph) && (node.attribute_size() > 0))
print(node.attribute());
printSet(" (", ", ", ")", node.input());
if ((has_subgraph) && (node.attribute_size() > 0))
print(node.attribute());
output_ << "\n";
}
void ProtoPrinter::print(const NodeList& nodelist) {
output_ << "{\n";
for (auto& node : nodelist) {
print(node);
}
if (indent_level > 3)
output_ << std::setw(indent_level - 3) << " ";
output_ << "}";
}
void ProtoPrinter::print(const GraphProto& graph) {
output_ << graph.name() << " " << graph.input() << " => " << graph.output() << " ";
if ((graph.initializer_size() > 0) || (graph.value_info_size() > 0)) {
output_ << std::endl << std::setw(indent_level) << ' ' << '<';
const char* sep = "";
for (auto& init : graph.initializer()) {
output_ << sep;
print(init, true);
sep = ", ";
}
for (auto& vi : graph.value_info()) {
output_ << sep;
print(vi);
sep = ", ";
}
output_ << ">" << std::endl;
}
print(graph.node());
}
void ProtoPrinter::print(const ModelProto& model) {
output_ << "<\n";
printKeyValuePair(KeyWordMap::KeyWord::IR_VERSION, model.ir_version(), false);
printKeyValuePair(KeyWordMap::KeyWord::OPSET_IMPORT, model.opset_import());
if (model.has_producer_name())
printKeyValuePair(KeyWordMap::KeyWord::PRODUCER_NAME, model.producer_name());
if (model.has_producer_version())
printKeyValuePair(KeyWordMap::KeyWord::PRODUCER_VERSION, model.producer_version());
if (model.has_domain())
printKeyValuePair(KeyWordMap::KeyWord::DOMAIN_KW, model.domain());
if (model.has_model_version())
printKeyValuePair(KeyWordMap::KeyWord::MODEL_VERSION, model.model_version());
if (model.has_doc_string())
printKeyValuePair(KeyWordMap::KeyWord::DOC_STRING, model.doc_string());
if (model.metadata_props_size() > 0)
printKeyValuePair(KeyWordMap::KeyWord::METADATA_PROPS, model.metadata_props());
output_ << std::endl << ">" << std::endl;
print(model.graph());
for (const auto& fn : model.functions()) {
output_ << std::endl;
print(fn);
}
}
void ProtoPrinter::print(const OperatorSetIdProto& opset) {
output_ << "\"" << opset.domain() << "\" : " << opset.version();
}
void ProtoPrinter::print(const OpsetIdList& opsets) {
printSet("[", ", ", "]", opsets);
}
void ProtoPrinter::print(const FunctionProto& fn) {
output_ << "<\n";
output_ << " "
<< "domain: \"" << fn.domain() << "\",\n";
if (!fn.overload().empty())
output_ << " "
<< "overload: \"" << fn.overload() << "\",\n";
output_ << " "
<< "opset_import: ";
printSet("[", ",", "]", fn.opset_import());
output_ << "\n>\n";
output_ << fn.name() << " ";
if (fn.attribute_size() > 0)
printSet("<", ",", ">", fn.attribute());
printSet("(", ", ", ")", fn.input());
output_ << " => ";
printSet("(", ", ", ")", fn.output());
output_ << "\n";
print(fn.node());
}
#define DEF_OP(T) \
std::ostream& operator<<(std::ostream& os, const T& proto) { \
ProtoPrinter printer(os); \
printer.print(proto); \
return os; \
};
DEF_OP(TensorShapeProto_Dimension)
DEF_OP(TensorShapeProto)
DEF_OP(TypeProto_Tensor)
DEF_OP(TypeProto)
DEF_OP(TensorProto)
DEF_OP(ValueInfoProto)
DEF_OP(ValueInfoList)
DEF_OP(AttributeProto)
DEF_OP(AttrList)
DEF_OP(NodeProto)
DEF_OP(NodeList)
DEF_OP(GraphProto)
DEF_OP(FunctionProto)
DEF_OP(ModelProto)
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,50 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <iostream>
#include <string>
#include "onnx/defs/parser.h"
#include "onnx/onnx_pb.h"
namespace ONNX_NAMESPACE {
std::ostream& operator<<(std::ostream& os, const TensorShapeProto_Dimension& dim);
std::ostream& operator<<(std::ostream& os, const TensorShapeProto& shape);
std::ostream& operator<<(std::ostream& os, const TypeProto_Tensor& tensortype);
std::ostream& operator<<(std::ostream& os, const TypeProto& type);
std::ostream& operator<<(std::ostream& os, const TensorProto& tensor);
std::ostream& operator<<(std::ostream& os, const ValueInfoProto& value_info);
std::ostream& operator<<(std::ostream& os, const ValueInfoList& vilist);
std::ostream& operator<<(std::ostream& os, const AttributeProto& attr);
std::ostream& operator<<(std::ostream& os, const AttrList& attrlist);
std::ostream& operator<<(std::ostream& os, const NodeProto& node);
std::ostream& operator<<(std::ostream& os, const NodeList& nodelist);
std::ostream& operator<<(std::ostream& os, const GraphProto& graph);
std::ostream& operator<<(std::ostream& os, const FunctionProto& fn);
std::ostream& operator<<(std::ostream& os, const ModelProto& model);
template <typename ProtoType>
std::string ProtoToString(const ProtoType& proto) {
std::stringstream ss;
ss << proto;
return ss.str();
}
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,295 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
static const char* QuantizeLinear_ver21_doc = R"DOC(
The linear quantization operator consumes a high-precision tensor, a scale, and a zero point to compute the
low-precision/quantized tensor. The scale factor and zero point must have the same shape, determining the quantization
granularity. The quantization formula is `y = saturate((x / y_scale) + y_zero_point)`.
Saturation is done according to:
- uint16: [0, 65535]
- int16: [-32768, 32767]
- uint8: [0, 255]
- int8: [-128, 127]
- uint4: [0, 15]
- int4: [-8, 7]
For `(x / y_scale)`, it rounds to the nearest even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
`y_zero_point` and `y` must have the same type. `y_zero_point` is usually not used for quantization to float8 types, but the quantization
formula remains the same for consistency, and the type of the attribute `y_zero_point` still determines the quantization type.
There are three supported quantization granularities, determined by the shape of `y_scale`.
In all cases, `y_zero_point` must have the same shape as `y_scale`.
- Per-tensor (per-layer) quantization: `y_scale` is a scalar.
- Per-axis quantization: The scale must be a 1-D tensor, with the length of the quantization axis. For an input shape
`(D0, ..., Di, ..., Dn)` and `axis=i`, `y_scale` is a 1-D tensor of length `Di`.
- Blocked quantization: The scale's shape is identical to the input's shape, except for one dimension, in which
blocking is performed. Given `x` shape `(D0, ..., Di, ..., Dn)`, `axis=i`, and block size `B`: `y_scale` shape is
`(D0, ..., ceil(Di/B), ..., Dn)`.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
QuantizeLinear,
21,
OpSchema()
.Input(0, "x", "N-D full precision Input tensor to be quantized.", "T1")
.Input(
1,
"y_scale",
"Scale for doing quantization to get `y`. For per-tensor/layer quantization the scale is a scalar, for "
"per-axis quantization it is a 1-D Tensor and for blocked quantization it has the same shape as the "
"input, except for one dimension in which blocking is performed.",
"T1")
.Input(
2,
"y_zero_point",
"Zero point for doing quantization to get `y`. Shape must match `y_scale`."
"Default is uint8 with zero point of 0 if it's not specified.",
"T2",
OpSchema::Optional)
.Output(0, "y", "N-D quantized output tensor. It has same shape as input `x`.", "T2")
.Attr(
"axis",
"(Optional) The axis of the dequantizing dimension of the input tensor. Used only for per-axis and blocked "
"quantization. Negative value means counting dimensions from the back. Accepted range is `[-r, r-1]` "
"where `r = rank(input)`. When the rank of the input is 1, per-tensor quantization is applied, "
"rendering the axis unnecessary in this scenario.",
AttributeProto::INT,
static_cast<int64_t>(1))
.Attr(
"saturate",
"The parameter defines how the conversion behaves if an input value is out of "
"range of the destination type. It only applies for float 8 quantization "
"(float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz). It is true by default. "
"All cases are fully described in two tables inserted in the operator description.",
AttributeProto::INT,
static_cast<int64_t>(1))
.Attr(
"block_size",
"(Optional) The size of the quantization block (number of times every scale is replicated). Used only for "
"blocked quantization. The block size is a positive integer. Given `x` shape `(D0, ..., Di, ..., Dn)`, "
"`y_scale` shape `(S0, ... Si, ...Sn)` and `axis=i`, the accepted range is "
"`[ceil(Di/Si), ceil(Di/(Si-1))-1]`",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr(
"output_dtype",
"(Optional) The output data type. If not supplied, the output data type is inferred from `y_zero_point` data type (`T2`). "
"If neither `output_dtype` nor `y_zero_point` are supplied, output data type is uint8. "
"If both `output_dtype` and `y_zero_point` are specified, `output_dtype` must be `T2`.",
AttributeProto::INT,
static_cast<int64_t>(0))
.TypeConstraint(
"T1",
{"tensor(float)", "tensor(float16)", "tensor(bfloat16)", "tensor(int32)"},
"The type of the input 'x'.")
.TypeConstraint(
"T2",
{"tensor(int8)",
"tensor(uint8)",
"tensor(int16)",
"tensor(uint16)",
"tensor(float8e4m3fn)",
"tensor(float8e4m3fnuz)",
"tensor(float8e5m2)",
"tensor(float8e5m2fnuz)",
"tensor(uint4)",
"tensor(int4)"},
"The type of the input `y_zero_point` and the output `y`.")
.SetDoc(QuantizeLinear_ver21_doc)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
auto const zp_type = ctx.hasInput(2) ? ctx.getInputType(2) : nullptr;
auto const output_dtype =
static_cast<TensorProto_DataType>(getAttribute(ctx, "output_dtype", TensorProto::UNDEFINED));
if (zp_type != nullptr) {
auto const zp_elem_type = static_cast<TensorProto_DataType>(getTensorElementType(*zp_type));
if (output_dtype != TensorProto::UNDEFINED && output_dtype != zp_elem_type) {
fail_type_inference(
"output_dtype ",
TensorProto_DataType_Name(output_dtype),
" does not match y_zero_point type ",
TensorProto_DataType_Name(zp_elem_type),
".");
}
propagateElemTypeFromInputToOutput(ctx, 2, 0);
} else if (output_dtype != TensorProto::UNDEFINED) {
propagateElemTypeFromAttributeToOutput(ctx, "output_dtype", 0);
} else {
updateOutputElemType(ctx, 0, TensorProto::UINT8);
}
if (!hasInputShape(ctx, 0)) {
return;
}
auto& input_shape = getInputShape(ctx, 0);
updateOutputShape(ctx, 0, input_shape);
}));
static const char* DequantizeLinear_ver21_doc = R"DOC(
The linear dequantization operator. It consumes a quantized tensor, a scale, and a zero point to compute the
full-precision tensor. The dequantization formula is `y = (x - x_zero_point) * x_scale`. `x_scale` and `x_zero_point`
must have the same shape, determining the quantization's granularity: a scalar for per-tensor/per-layer quantization,
a 1-D tensor for per-axis quantization, or have a rank identical to the input for blocked quantization.
See QuantizeLinear for details on quantization granularity.
`x_zero_point` and `x` must have the same type. `x` and `y` must have the same shape. In the case of dequantizing
`int32`, there's no zero point (zero point is supposed to be 0).
`zero-point` is usually not used in the case of float8 types quantization, but the dequantization formula remains the same
for consistency, and `x_scale` still determines the output type.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
DequantizeLinear,
21,
OpSchema()
.Input(0, "x", "N-D quantized input tensor to be de-quantized.", "T1")
.Input(
1,
"x_scale",
"Scale for input `x`. For per-tensor/layer dequantization the scale is a scalar, for "
"per per-axis dequantization it is a 1-D Tensor and for blocked dequantization it has the same shape as "
"the input, except for one dimension in which blocking is performed.",
"T2")
.Input(
2,
"x_zero_point",
"Zero point for input `x`. Shape must match x_scale. "
"It's optional. Zero point is 0 when it's not specified.",
"T1",
OpSchema::Optional)
.Output(0, "y", "N-D full precision output tensor. It has same shape as input `x`.", "T2")
.Attr(
"axis",
"(Optional) The axis of the dequantizing dimension of the input tensor. Used for per-axis and blocked "
"quantization. Negative value means counting dimensions from the back. Accepted range is `[-r, r-1]` "
"where `r = rank(input)`.",
AttributeProto::INT,
static_cast<int64_t>(1))
.Attr(
"block_size",
"(Optional) The size of the quantization block (number of times every scale is replicated). Used only for "
"blocked quantization. The block size is a positive integer. Given `x` shape `(D0, ..., Di, ..., Dn)`, "
"`y_scale` shape `(S0, ... Si, ...Sn)` and `axis=i`, the accepted range is "
"`[ceil(Di/Si), ceil(Di/(Si-1))-1]`",
AttributeProto::INT,
static_cast<int64_t>(0))
.TypeConstraint(
"T1",
{"tensor(int8)",
"tensor(uint8)",
"tensor(int16)",
"tensor(uint16)",
"tensor(int32)",
"tensor(float8e4m3fn)",
"tensor(float8e4m3fnuz)",
"tensor(float8e5m2)",
"tensor(float8e5m2fnuz)",
"tensor(uint4)",
"tensor(int4)"},
"The type of the inputs 'x_zero_point' and 'x'.")
.TypeConstraint(
"T2",
{"tensor(float)", "tensor(float16)", "tensor(bfloat16)"},
"'x_scale' determines the output type.")
.SetDoc(DequantizeLinear_ver21_doc)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 1, 0);
if (!hasInputShape(ctx, 0)) {
return;
}
auto& input_shape = getInputShape(ctx, 0);
updateOutputShape(ctx, 0, input_shape);
}));
static const char* DynamicQuantizeLinear_ver11_doc = R"DOC(
A Function to fuse calculation for Scale, Zero Point and FP32->8Bit conversion of FP32 Input data.
Outputs Scale, ZeroPoint and Quantized Input for a given FP32 Input.
Scale is calculated as:
```
y_scale = (maximum(0, max(x)) - minimum(0, min(x))) / (qmax - qmin)
```
* where qmax and qmin are max and min values for quantization range i.e. [0, 255] in case of uint8
* data range is adjusted to include 0.
Zero point is calculated as:
```
intermediate_zero_point = qmin - min(x)/y_scale
y_zero_point = cast(round(saturate(itermediate_zero_point)))
```
* where qmax and qmin are max and min values for quantization range .i.e [0, 255] in case of uint8
* for saturation, it saturates to [0, 255] if it's uint8, or [-127, 127] if it's int8. Right now only uint8 is supported.
* rounding to nearest ties to even.
Data quantization formula is:
```
y = saturate (round (x / y_scale) + y_zero_point)
```
* for saturation, it saturates to [0, 255] if it's uint8, or [-127, 127] if it's int8. Right now only uint8 is supported.
* rounding to nearest ties to even.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
DynamicQuantizeLinear,
11,
OpSchema()
.SetDoc(DynamicQuantizeLinear_ver11_doc)
.Input(0, "x", "Input tensor", "T1")
.Output(0, "y", "Quantized output tensor", "T2")
.Output(
1,
"y_scale",
"Output scale. It's a scalar, which means a per-tensor/layer quantization.",
"tensor(float)")
.Output(
2,
"y_zero_point",
"Output zero point. It's a scalar, which means a per-tensor/layer quantization.",
"T2")
.TypeConstraint("T1", {"tensor(float)"}, "Constrain 'x' to float tensor.")
.TypeConstraint("T2", {"tensor(uint8)"}, "Constrain 'y_zero_point' and 'y' to 8-bit unsigned integer tensor.")
.FunctionBody(R"ONNX(
{
Q_Min = Constant<value = float {0.0}>()
Q_Max = Constant<value = float {255.0}>()
X_Min = ReduceMin <keepdims = 0> (x)
X_Min_Adjusted = Min (X_Min, Q_Min)
X_Max = ReduceMax <keepdims = 0> (x)
X_Max_Adjusted = Max (X_Max, Q_Min)
X_Range = Sub (X_Max_Adjusted, X_Min_Adjusted)
Scale = Div (X_Range, Q_Max)
Min_Scaled = Div (X_Min_Adjusted, Scale)
Initial_ZeroPoint_FP = Sub (Q_Min, Min_Scaled)
Clipped_ZeroPoint_FP = Clip (Initial_ZeroPoint_FP, Q_Min, Q_Max)
Rounded_ZeroPoint_FP = Round (Clipped_ZeroPoint_FP)
Zeropoint = Cast <to = 2> (Rounded_ZeroPoint_FP)
y_scale = Identity (Scale)
y_zero_point = Identity (Zeropoint)
y = QuantizeLinear (x, Scale, Zeropoint)
}
)ONNX")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
updateOutputElemType(ctx, 0, TensorProto::UINT8);
updateOutputElemType(ctx, 1, TensorProto::FLOAT);
updateOutputElemType(ctx, 2, TensorProto::UINT8);
ctx.getOutputType(1)->mutable_tensor_type()->mutable_shape();
ctx.getOutputType(2)->mutable_tensor_type()->mutable_shape();
if (!hasInputShape(ctx, 0))
return;
auto& input_shape = getInputShape(ctx, 0);
updateOutputShape(ctx, 0, input_shape);
}));
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,329 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
static const char* QuantizeLinear_ver19_doc = R"DOC(
The linear quantization operator. It consumes a high precision tensor, a scale, and a zero point to compute the low precision / quantized tensor.
The scale factor and zero point must have same shape, and can be either a scalar for per-tensor / per layer quantization, or a 1-D tensor for per-axis quantization.
The quantization formula is `y = saturate ((x / y_scale) + y_zero_point)`.
For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8.
For (x / y_scale), it's rounding to the nearest even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
'y_zero_point' and 'y' must have same type.
'y_zero_point' is usually not used for quantization to float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz,
but the quantization formula remains the same for consistency and
the type of the attribute 'y_zero_point' still determines the quantization type.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
QuantizeLinear,
19,
OpSchema()
.Input(0, "x", "N-D full precision Input tensor to be quantized.", "T1")
.Input(
1,
"y_scale",
"Scale for doing quantization to get 'y'. It can be a scalar, which means per-tensor/layer quantization, "
"or a 1-D Tensor for per-axis quantization.",
"T1")
.Input(
2,
"y_zero_point",
"Zero point for doing quantization to get 'y'. Shape must match y_scale. "
"Default is uint8 with zero point of 0 if it's not specified.",
"T2",
OpSchema::Optional)
.Output(0, "y", "N-D quantized output tensor. It has same shape as input 'x'.", "T2")
.Attr(
"axis",
"(Optional) The axis of the quantization dimension of the input tensor. Ignored for per-tensor quantization. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(input).",
AttributeProto::INT,
static_cast<int64_t>(1))
.Attr(
"saturate",
"The parameter defines how the conversion behaves if an input value is out of "
"range of the destination type. It only applies for float 8 quantization "
"(float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz). It is true by default. "
"All cases are fully described in two tables inserted in the operator description.",
AttributeProto::INT,
static_cast<int64_t>(1))
.TypeConstraint(
"T1",
{"tensor(float)", "tensor(float16)", "tensor(bfloat16)", "tensor(int32)"},
"Constrain 'x' to float, float16, bfloat16 or int32 tensor.")
.TypeConstraint(
"T2",
{"tensor(int8)",
"tensor(uint8)",
"tensor(float8e4m3fn)",
"tensor(float8e4m3fnuz)",
"tensor(float8e5m2)",
"tensor(float8e5m2fnuz)"},
"Constrain 'y_zero_point' and 'y' to 8-bit integer/float tensor.")
.SetDoc(QuantizeLinear_ver19_doc)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
if (ctx.hasInput(2)) {
propagateElemTypeFromInputToOutput(ctx, 2, 0);
} else {
updateOutputElemType(ctx, 0, TensorProto::UINT8);
}
if (!hasInputShape(ctx, 0)) {
return;
}
auto& input_shape = getInputShape(ctx, 0);
updateOutputShape(ctx, 0, input_shape);
}));
static const char* DequantizeLinear_ver19_doc = R"DOC(
The linear dequantization operator. It consumes a quantized tensor, a scale, and a zero point to compute the full precision tensor.
The dequantization formula is `y = (x - x_zero_point) * x_scale`. `x_scale` and `x_zero_point` must have same shape, and can be either a scalar
for per-tensor / per layer quantization, or a 1-D tensor for per-axis quantization.
`x_zero_point` and `x` must have same type. `x` and `y` must have same shape. In the case of dequantizing int32,
there's no zero point (zero point is supposed to be 0).
`zero-point` is usually not used in the case of float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz quantization,
but the dequantization formula remains the same for consistency and 'x_scale' still determines the output type.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
DequantizeLinear,
19,
OpSchema()
.Input(0, "x", "N-D quantized input tensor to be de-quantized.", "T1")
.Input(
1,
"x_scale",
"Scale for input 'x'. It can be a scalar, which means a per-tensor/layer dequantization, "
"or a 1-D tensor for per-axis dequantization.",
"T2")
.Input(
2,
"x_zero_point",
"Zero point for input 'x'. Shape must match x_scale. "
"It's optional. Zero point is 0 when it's not specified.",
"T1",
OpSchema::Optional)
.Output(0, "y", "N-D full precision output tensor. It has same shape as input 'x'.", "T2")
.Attr(
"axis",
"(Optional) The axis of the dequantizing dimension of the input tensor. Used only for per-axis quantization. "
"Negative value means counting dimensions from the back. Accepted range is `[-r, r-1]` "
"where `r = rank(input)`. When the rank of the input is 1, per-tensor quantization is applied, "
"rendering the axis unnecessary in this scenario.",
AttributeProto::INT,
static_cast<int64_t>(1))
.TypeConstraint(
"T1",
{"tensor(int8)",
"tensor(uint8)",
"tensor(int32)",
"tensor(float8e4m3fn)",
"tensor(float8e4m3fnuz)",
"tensor(float8e5m2)",
"tensor(float8e5m2fnuz)"},
"Constrain 'x_zero_point' and 'x' to 8-bit integer or float, or /32-bit integer tensor.")
.TypeConstraint(
"T2",
{"tensor(float)", "tensor(float16)", "tensor(bfloat16)"},
"'x_scale' determines the output type.")
.SetDoc(DequantizeLinear_ver19_doc)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 1, 0);
if (!hasInputShape(ctx, 0)) {
return;
}
auto& input_shape = getInputShape(ctx, 0);
updateOutputShape(ctx, 0, input_shape);
}));
static const char* QuantizeLinear_ver13_doc = R"DOC(
The linear quantization operator. It consumes a high precision tensor, a scale, and a zero point to compute the low precision / quantized tensor.
The scale factor and zero point must have same shape, and can be either a scalar for per-tensor / per layer quantization, or a 1-D tensor for per-axis quantization.
The quantization formula is y = saturate ((x / y_scale) + y_zero_point).
For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8.
For (x / y_scale), it's rounding to the nearest even. Refer to https://en.wikipedia.org/wiki/Rounding for details. 'y_zero_point' and 'y' must have same type.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
QuantizeLinear,
13,
OpSchema()
.Input(0, "x", "N-D full precision Input tensor to be quantized.", "T1")
.Input(
1,
"y_scale",
"Scale for doing quantization to get 'y'. It can be a scalar, which means per-tensor/layer quantization, "
"or a 1-D Tensor for per-axis quantization.",
"tensor(float)")
.Input(
2,
"y_zero_point",
"Zero point for doing quantization to get 'y'. Shape must match y_scale. "
"Default is uint8 with zero point of 0 if it's not specified.",
"T2",
OpSchema::Optional)
.Output(0, "y", "N-D quantized output tensor. It has same shape as input 'x'.", "T2")
.Attr(
"axis",
"(Optional) The axis of the quantization dimension of the input tensor. Ignored for per-tensor quantization. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(input).",
AttributeProto::INT,
static_cast<int64_t>(1))
.TypeConstraint("T1", {"tensor(float)", "tensor(int32)"}, "Constrain 'x' to float or int32 tensor.")
.TypeConstraint(
"T2",
{"tensor(int8)", "tensor(uint8)"},
"Constrain 'y_zero_point' and 'y' to 8-bit integer tensor.")
.SetDoc(QuantizeLinear_ver13_doc)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
if (ctx.hasInput(2)) {
propagateElemTypeFromInputToOutput(ctx, 2, 0);
} else {
updateOutputElemType(ctx, 0, TensorProto::UINT8);
}
if (!hasInputShape(ctx, 0)) {
return;
}
auto& input_shape = getInputShape(ctx, 0);
updateOutputShape(ctx, 0, input_shape);
}));
static const char* DequantizeLinear_ver13_doc = R"DOC(
The linear dequantization operator. It consumes a quantized tensor, a scale, and a zero point to compute the full precision tensor.
The dequantization formula is `y = (x - x_zero_point) * x_scale`. `x_scale` and `x_zero_point` must have same shape, and can be either a scalar
for per-tensor / per layer quantization, or a 1-D tensor for per-axis quantization.
`x_zero_point` and `x` must have same type. `x` and `y` must have same shape. In the case of dequantizing int32,
there's no zero point (zero point is supposed to be 0).
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
DequantizeLinear,
13,
OpSchema()
.Input(0, "x", "N-D quantized input tensor to be de-quantized.", "T")
.Input(
1,
"x_scale",
"Scale for input 'x'. It can be a scalar, which means a per-tensor/layer dequantization, "
"or a 1-D tensor for per-axis dequantization.",
"tensor(float)")
.Input(
2,
"x_zero_point",
"Zero point for input 'x'. Shape must match x_scale. "
"It's optional. Zero point is 0 when it's not specified.",
"T",
OpSchema::Optional)
.Output(0, "y", "N-D full precision output tensor. It has same shape as input 'x'.", "tensor(float)")
.Attr(
"axis",
"(Optional) The axis of the dequantizing dimension of the input tensor. Ignored for per-tensor quantization. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(input).",
AttributeProto::INT,
static_cast<int64_t>(1))
.TypeConstraint(
"T",
{"tensor(int8)", "tensor(uint8)", "tensor(int32)"},
"Constrain 'x_zero_point' and 'x' to 8-bit/32-bit integer tensor.")
.SetDoc(DequantizeLinear_ver13_doc)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
auto y_type = ctx.getOutputType(0);
// only float is supported
y_type->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto::FLOAT);
if (!hasInputShape(ctx, 0))
return;
auto& input_shape = getInputShape(ctx, 0);
updateOutputShape(ctx, 0, input_shape);
}));
static const char* QuantizeLinear_ver10_doc = R"DOC(
The linear per-tensor/layer quantization operator. It consumes a high precision tensor, a scale, a zero point to compute the low precision / quantized tensor.
The quantization formula is y = saturate ((x / y_scale) + y_zero_point). For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8.
For (x / y_scale), it's rounding to the nearest even. Refer to https://en.wikipedia.org/wiki/Rounding for details. 'y_zero_point' and 'y' must have same type.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
QuantizeLinear,
10,
OpSchema()
.Input(0, "x", "N-D full precision Input tensor to be quantized.", "T1")
.Input(
1,
"y_scale",
"Scale for doing quantization to get 'y'. It's a scalar, which means a per-tensor/layer quantization.",
"tensor(float)")
.Input(
2,
"y_zero_point",
"Zero point for doing quantization to get 'y'. It's a scalar, which means a per-tensor/layer quantization. "
"Default value is uint8 typed 0 if it's not specified.",
"T2",
OpSchema::Optional)
.Output(0, "y", "N-D quantized output tensor. It has same shape as input 'x'.", "T2")
.TypeConstraint("T1", {"tensor(float)", "tensor(int32)"}, "Constrain 'x' to float or int32 tensor.")
.TypeConstraint(
"T2",
{"tensor(int8)", "tensor(uint8)"},
"Constrain 'y_zero_point' and 'y' to 8-bit integer tensor.")
.SetDoc(QuantizeLinear_ver10_doc)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
if (ctx.hasInput(2)) {
propagateElemTypeFromInputToOutput(ctx, 2, 0);
} else {
updateOutputElemType(ctx, 0, TensorProto::UINT8);
}
if (!hasInputShape(ctx, 0)) {
return;
}
auto& input_shape = getInputShape(ctx, 0);
updateOutputShape(ctx, 0, input_shape);
}));
static const char* DequantizeLinear_ver10_doc = R"DOC(
The linear dequantization operator. It consumes a quantized tensor, a scale, a zero point to compute the full precision tensor.
The dequantization formula is y = (x - x_zero_point) * x_scale. 'x_scale' and 'x_zero_point' are both scalars.
'x_zero_point' and 'x' must have same type. 'x' and 'y' must have same shape. In the case of dequantizing int32,
there's no zero point (zero point is supposed to be 0).
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
DequantizeLinear,
10,
OpSchema()
.Input(0, "x", "N-D quantized input tensor to be de-quantized.", "T")
.Input(
1,
"x_scale",
"Scale for input 'x'. It's a scalar, which means a per-tensor/layer quantization.",
"tensor(float)")
.Input(
2,
"x_zero_point",
"Zero point for input 'x'. It's a scalar, which means a per-tensor/layer quantization. "
"It's optional. 0 is the default value when it's not specified.",
"T",
OpSchema::Optional)
.Output(0, "y", "N-D full precision output tensor. It has same shape as input 'x'.", "tensor(float)")
.TypeConstraint(
"T",
{"tensor(int8)", "tensor(uint8)", "tensor(int32)"},
"Constrain 'x_zero_point' and 'x' to 8-bit/32-bit integer tensor.")
.SetDoc(DequantizeLinear_ver10_doc)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
auto y_type = ctx.getOutputType(0);
// only float is supported
y_type->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto::FLOAT);
if (!hasInputShape(ctx, 0))
return;
auto& input_shape = getInputShape(ctx, 0);
updateOutputShape(ctx, 0, input_shape);
}));
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,184 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <algorithm>
#include <functional>
#include "onnx/defs/function.h"
#include "onnx/defs/reduction/utils.h"
#include "onnx/defs/schema.h"
#include "onnx/defs/tensor_proto_util.h"
namespace ONNX_NAMESPACE {
ONNX_OPERATOR_SET_SCHEMA(
ReduceMax,
20,
OpSchema().FillUsing(ReduceOpGenerator("max", EMPTY_MIN, true, true, nullptr, nullptr, true)));
ONNX_OPERATOR_SET_SCHEMA(
ReduceMin,
20,
OpSchema().FillUsing(ReduceOpGenerator("min", EMPTY_MAX, true, true, nullptr, nullptr, true)));
ONNX_OPERATOR_SET_SCHEMA(ReduceSum, 13, OpSchema().FillUsing(ReduceOpDynamicAxes("sum", EMPTY_ZERO)));
const char* reduce_sum_square_func_body = R"ONNX(
{
data_square = Mul(data, data)
reduced = ReduceSum<keepdims: int = @keepdims>(data_square, axes)
}
)ONNX";
ONNX_OPERATOR_SET_SCHEMA(
ReduceSumSquare,
18,
OpSchema().FillUsing(ReduceFunctionOp("sum square", EMPTY_ZERO, reduce_sum_square_func_body)));
ONNX_OPERATOR_SET_SCHEMA(ReduceMean, 18, OpSchema().FillUsing(ReduceOpDynamicAxes("mean", EMPTY_UNDEFINED)));
ONNX_OPERATOR_SET_SCHEMA(ReduceProd, 18, OpSchema().FillUsing(ReduceOpDynamicAxes("product", EMPTY_ONE)));
const char* reduce_log_sum_func_body = R"ONNX(
{
reduced_sum = ReduceSum<keepdims: int = @keepdims>(data, axes)
reduced = Log (reduced_sum)
}
)ONNX";
ONNX_OPERATOR_SET_SCHEMA(
ReduceLogSum,
18,
OpSchema().FillUsing(ReduceFunctionOp("log sum", EMPTY_MINUS_INF, reduce_log_sum_func_body)));
const char* reduce_log_sum_exp_func_body = R"ONNX(
{
data_double = Cast<to = 11>(data)
data_exp = Exp (data_double)
reduced_sum = ReduceSum<keepdims: int = @keepdims>(data_exp, axes)
reduced_double = Log (reduced_sum)
reduced = CastLike(reduced_double, data)
}
)ONNX";
ONNX_OPERATOR_SET_SCHEMA(
ReduceLogSumExp,
18,
OpSchema().FillUsing(ReduceFunctionOp("log sum exponent", EMPTY_MINUS_INF, reduce_log_sum_exp_func_body)));
const char* reduce_l1_func_body = R"ONNX(
{
data_abs = Abs(data)
reduced = ReduceSum<keepdims: int = @keepdims>(data_abs, axes)
}
)ONNX";
ONNX_OPERATOR_SET_SCHEMA(
ReduceL1,
18,
OpSchema().FillUsing(ReduceFunctionOp("L1 norm", EMPTY_ZERO, reduce_l1_func_body)));
const char* reduce_l2_func_body = R"ONNX(
{
data_square = Mul(data, data)
sum_square = ReduceSum<keepdims: int = @keepdims>(data_square, axes)
sum_square_dbl = Cast <to = 1>(sum_square)
sqrt = Sqrt(sum_square_dbl)
reduced = CastLike(sqrt, data)
}
)ONNX";
ONNX_OPERATOR_SET_SCHEMA(
ReduceL2,
18,
OpSchema().FillUsing(ReduceFunctionOp("L2 norm", EMPTY_ZERO, reduce_l2_func_body)));
std::function<void(OpSchema&)> ArgReduceDocGenerator(const char* name) {
return [=](OpSchema& schema) {
std::string doc;
POPULATE_OP_DOC_STR(doc = R"DOC(
Computes the indices of the {name} elements of the input tensor's element along the
provided axis. The resulting tensor has the same rank as the input if keepdims equals 1.
If keepdims equals 0, then the resulting tensor has the reduced dimension pruned.
If select_last_index is True (default False), the index of the last occurrence of the {name}
is selected if the {name} appears more than once in the input. Otherwise the index of the
first occurrence is selected.
The type of the output tensor is integer.)DOC";
ReplaceAll(doc, "{name}", name););
schema.SetDoc(doc.c_str());
schema.Attr(
"axis",
"The axis in which to compute the arg indices. Accepted range is [-r, r-1] where r = rank(data).",
AttributeProto::INT,
static_cast<int64_t>(0));
schema.Attr(
"keepdims",
"Keep the reduced dimension or not, default 1 means keep reduced dimension.",
AttributeProto::INT,
static_cast<int64_t>(1));
schema.Attr(
"select_last_index",
"Whether to select the last index or the first index if the {name} appears in multiple indices, default is False (first index).",
AttributeProto::INT,
static_cast<int64_t>(0));
schema.Input(0, "data", "An input tensor.", "T", OpSchema::Single, true, 1, OpSchema::NonDifferentiable);
schema.Output(
0,
"reduced",
"Reduced output tensor with integer data type.",
"tensor(int64)",
OpSchema::Single,
true,
1,
OpSchema::NonDifferentiable);
schema.TypeConstraint(
"T", OpSchema::all_numeric_types_ir4(), "Constrain input and output types to all numeric tensors.");
schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// set output element type to int64
updateOutputElemType(ctx, 0, TensorProto_DataType_INT64);
if (!hasNInputShapes(ctx, 1)) {
return;
}
auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
int64_t input_ndim = input_shape.dim_size();
int64_t axis = 0; // default to 0
auto axis_proto = ctx.getAttribute("axis");
if (axis_proto) {
axis = axis_proto->i();
if (axis < -input_ndim || axis >= input_ndim) {
fail_shape_inference("'axis' must be in [-rank(indices), rank(indices)-1]");
}
if (axis < 0)
axis += input_ndim;
}
int64_t keep_dims = 1;
auto attr_proto = ctx.getAttribute("keepdims");
if (attr_proto) {
keep_dims = attr_proto->i();
}
// do we need handle negative axis?
for (int i = 0; i < input_ndim; ++i) {
if (i != axis) {
auto dim = output_shape->add_dim();
dim->CopyFrom(input_shape.dim(i));
} else {
if (keep_dims == 1) {
auto dim = output_shape->add_dim();
dim->set_dim_value(1);
}
}
}
});
};
}
ONNX_OPERATOR_SET_SCHEMA(ArgMax, 13, OpSchema().FillUsing(ArgReduceDocGenerator("max")));
ONNX_OPERATOR_SET_SCHEMA(ArgMin, 13, OpSchema().FillUsing(ArgReduceDocGenerator("min")));
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,446 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <algorithm>
#include <functional>
#include "onnx/defs/reduction/utils.h"
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
std::vector<std::string> GetSupportedDataTypesForReductionOps_opset12(bool supports8bit) {
if (supports8bit) {
auto data_types = OpSchema::numeric_types_for_math_reduction();
data_types.push_back("tensor(uint8)");
data_types.push_back("tensor(int8)");
return data_types;
}
return OpSchema::numeric_types_for_math_reduction();
}
std::function<void(OpSchema&)> ReduceDocGenerator_opset12(const char* name, bool supports_8bit_datatypes = false) {
return [=](OpSchema& schema) {
std::string doc;
POPULATE_OP_DOC_STR(doc = R"DOC(
Computes the {name} of the input tensor's element along the provided axes. The resulting
tensor has the same rank as the input if keepdims equals 1. If keepdims equal 0, then
the resulted tensor have the reduced dimension pruned.
The above behavior is similar to numpy, with the exception that numpy defaults keepdims to
False instead of True.)DOC";
ReplaceAll(doc, "{name}", name););
schema.SetDoc(doc.c_str());
schema.Attr(
"axes",
"A list of integers, along which to reduce. The default is to reduce over "
"all the dimensions of the input tensor. Accepted range is [-r, r-1] where r = rank(data).",
AttributeProto::INTS,
OPTIONAL_VALUE);
schema.Attr(
"keepdims",
"Keep the reduced dimension or not, default 1 means keep reduced dimension.",
AttributeProto::INT,
static_cast<int64_t>(1));
schema.Input(0, "data", "An input tensor.", "T");
schema.Output(0, "reduced", "Reduced output tensor.", "T");
schema.TypeConstraint(
"T",
GetSupportedDataTypesForReductionOps_opset12(supports_8bit_datatypes),
supports_8bit_datatypes ? "Constrain input and output types to high-precision and 8 bit numeric tensors."
: "Constrain input and output types to high-precision numeric tensors.");
schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 1)) {
return;
}
int64_t keep_dims = 1;
auto attr_proto = ctx.getAttribute("keepdims");
if (attr_proto) {
keep_dims = attr_proto->i();
}
auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
int64_t input_ndim = input_shape.dim_size();
auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
std::vector<int64_t> axes;
auto axes_proto = ctx.getAttribute("axes");
if (axes_proto)
axes.assign(axes_proto->ints().begin(), axes_proto->ints().end());
for (size_t i = 0; i < axes.size(); ++i) {
if (axes[i] < -input_ndim || axes[i] >= input_ndim) {
fail_shape_inference("axis must be in [-rank, rank-1]. input rank was ", input_ndim);
}
if (axes[i] < 0)
axes[i] += input_ndim;
}
// do we need handle negative axis?
for (int i = 0; i < input_ndim; ++i) {
// axes empty means reduce all dim
if (!axes.empty() && std::find(axes.begin(), axes.end(), i) == axes.end()) {
auto dim = output_shape->add_dim();
dim->CopyFrom(input_shape.dim(i));
} else {
if (keep_dims == 1) {
auto dim = output_shape->add_dim();
dim->set_dim_value(1);
}
}
}
});
};
}
ONNX_OPERATOR_SET_SCHEMA(ReduceMax, 12, OpSchema().FillUsing(ReduceDocGenerator_opset12("max", true)));
ONNX_OPERATOR_SET_SCHEMA(ReduceMin, 12, OpSchema().FillUsing(ReduceDocGenerator_opset12("min", true)));
ONNX_OPERATOR_SET_SCHEMA(ReduceSum, 11, OpSchema().FillUsing(ReduceDocGenerator_opset12("sum")));
ONNX_OPERATOR_SET_SCHEMA(ReduceSumSquare, 11, OpSchema().FillUsing(ReduceDocGenerator_opset12("sum square")));
ONNX_OPERATOR_SET_SCHEMA(ReduceMean, 11, OpSchema().FillUsing(ReduceDocGenerator_opset12("mean")));
ONNX_OPERATOR_SET_SCHEMA(ReduceProd, 11, OpSchema().FillUsing(ReduceDocGenerator_opset12("product")));
ONNX_OPERATOR_SET_SCHEMA(ReduceLogSum, 11, OpSchema().FillUsing(ReduceDocGenerator_opset12("log sum")));
ONNX_OPERATOR_SET_SCHEMA(ReduceLogSumExp, 11, OpSchema().FillUsing(ReduceDocGenerator_opset12("log sum exponent")));
ONNX_OPERATOR_SET_SCHEMA(ReduceL1, 11, OpSchema().FillUsing(ReduceDocGenerator_opset12("L1 norm")));
ONNX_OPERATOR_SET_SCHEMA(ReduceL2, 11, OpSchema().FillUsing(ReduceDocGenerator_opset12("L2 norm")));
std::function<void(OpSchema&)> ArgReduceDocGenerator_opset12(const char* name) {
return [=](OpSchema& schema) {
std::string doc;
POPULATE_OP_DOC_STR(doc = R"DOC(
Computes the indices of the {name} elements of the input tensor's element along the
provided axis. The resulting tensor has the same rank as the input if keepdims equals 1.
If keepdims equal 0, then the resulting tensor has the reduced dimension pruned.
If select_last_index is True (default False), the index of the last occurrence of the {name}
is selected if the {name} appears more than once in the input. Otherwise the index of the
first occurrence is selected.
The type of the output tensor is integer.)DOC";
ReplaceAll(doc, "{name}", name););
schema.SetDoc(doc.c_str());
schema.Attr(
"axis",
"The axis in which to compute the arg indices. Accepted range is [-r, r-1] where r = rank(data).",
AttributeProto::INT,
static_cast<int64_t>(0));
schema.Attr(
"keepdims",
"Keep the reduced dimension or not, default 1 means keep reduced dimension.",
AttributeProto::INT,
static_cast<int64_t>(1));
schema.Attr(
"select_last_index",
"Whether to select the last index or the first index if the {name} appears in multiple indices, default is False (first index).",
AttributeProto::INT,
static_cast<int64_t>(0));
schema.Input(0, "data", "An input tensor.", "T");
schema.Output(0, "reduced", "Reduced output tensor with integer data type.", "tensor(int64)");
schema.TypeConstraint(
"T", OpSchema::all_numeric_types(), "Constrain input and output types to all numeric tensors.");
schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// set output element type to int64
updateOutputElemType(ctx, 0, TensorProto_DataType_INT64);
if (!hasNInputShapes(ctx, 1)) {
return;
}
auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
int64_t input_ndim = input_shape.dim_size();
int64_t axis = 0; // default to 0
auto axis_proto = ctx.getAttribute("axis");
if (axis_proto) {
axis = axis_proto->i();
if (axis < -input_ndim || axis >= input_ndim) {
fail_shape_inference("'axis' must be in [-rank(indices), rank(indices)-1]");
}
if (axis < 0)
axis += input_ndim;
}
int64_t keep_dims = 1;
auto attr_proto = ctx.getAttribute("keepdims");
if (attr_proto) {
keep_dims = attr_proto->i();
}
// do we need handle negative axis?
for (int i = 0; i < input_ndim; ++i) {
if (i != axis) {
auto dim = output_shape->add_dim();
dim->CopyFrom(input_shape.dim(i));
} else {
if (keep_dims == 1) {
auto dim = output_shape->add_dim();
dim->set_dim_value(1);
}
}
}
});
};
} // namespace ONNX_NAMESPACE
ONNX_OPERATOR_SET_SCHEMA(ArgMax, 12, OpSchema().FillUsing(ArgReduceDocGenerator_opset12("max")));
ONNX_OPERATOR_SET_SCHEMA(ArgMin, 12, OpSchema().FillUsing(ArgReduceDocGenerator_opset12("min")));
std::function<void(OpSchema&)> ReduceDocGenerator_opset1(const char* name, const char* empty_value, int opset = 1) {
return [=](OpSchema& schema) {
std::string doc;
POPULATE_OP_DOC_STR(doc = R"DOC(
Computes the {name} of the input tensor's element along the provided axes. The resulting
tensor has the same rank as the input if keepdims equals 1. If keepdims equal 0, then
the resulted tensor have the reduced dimension pruned. Input tensors of rank zero are
valid. Reduction over an empty set of values yields {empty_value}.
The above behavior is similar to numpy, with the exception that numpy defaults keepdims to
False instead of True.)DOC";
ReplaceAll(doc, "{name}", name););
ReplaceAll(doc, "{empty_value}", empty_value);
schema.SetDoc(doc.c_str());
schema.Attr(
"axes",
opset >= 11 ? "A list of integers, along which to reduce. The default is to reduce over "
"all the dimensions of the input tensor. Accepted range is [-r, r-1] where r = rank(data)."
: "A list of integers, along which to reduce. The default is to reduce over "
"all the dimensions of the input tensor.",
AttributeProto::INTS,
OPTIONAL_VALUE);
schema.Attr(
"keepdims",
"Keep the reduced dimension or not, default 1 means keep reduced dimension.",
AttributeProto::INT,
static_cast<int64_t>(1));
schema.Input(0, "data", "An input tensor.", "T");
schema.Output(0, "reduced", "Reduced output tensor.", "T");
schema.TypeConstraint(
"T",
OpSchema::numeric_types_for_math_reduction(),
"Constrain input and output types to high-precision numeric tensors.");
schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 1)) {
return;
}
int64_t keep_dims = 1;
auto attr_proto = ctx.getAttribute("keepdims");
if (attr_proto) {
keep_dims = attr_proto->i();
}
auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
int64_t input_ndim = input_shape.dim_size();
auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
std::vector<int64_t> axes;
auto axes_proto = ctx.getAttribute("axes");
if (axes_proto)
axes.assign(axes_proto->ints().begin(), axes_proto->ints().end());
for (size_t i = 0; i < axes.size(); ++i) {
if (axes[i] < 0)
axes[i] += input_ndim;
}
// do we need handle negative axis?
for (int i = 0; i < input_ndim; ++i) {
// axes empty means reduce all dim
if (!axes.empty() && std::find(axes.begin(), axes.end(), i) == axes.end()) {
auto dim = output_shape->add_dim();
dim->CopyFrom(input_shape.dim(i));
} else {
if (keep_dims == 1) {
auto dim = output_shape->add_dim();
dim->set_dim_value(1);
}
}
}
});
};
}
ONNX_OPERATOR_SET_SCHEMA(ReduceMax, 1, OpSchema().FillUsing(ReduceDocGenerator_opset1("max", EMPTY_MIN)));
ONNX_OPERATOR_SET_SCHEMA(ReduceMin, 1, OpSchema().FillUsing(ReduceDocGenerator_opset1("min", EMPTY_MAX)));
ONNX_OPERATOR_SET_SCHEMA(ReduceSum, 1, OpSchema().FillUsing(ReduceDocGenerator_opset1("sum", EMPTY_ZERO)));
ONNX_OPERATOR_SET_SCHEMA(ReduceSumSquare, 1, OpSchema().FillUsing(ReduceDocGenerator_opset1("sum square", EMPTY_ZERO)));
ONNX_OPERATOR_SET_SCHEMA(ReduceMean, 1, OpSchema().FillUsing(ReduceDocGenerator_opset1("mean", EMPTY_UNDEFINED)));
ONNX_OPERATOR_SET_SCHEMA(ReduceProd, 1, OpSchema().FillUsing(ReduceDocGenerator_opset1("product", EMPTY_ONE)));
ONNX_OPERATOR_SET_SCHEMA(ReduceLogSum, 1, OpSchema().FillUsing(ReduceDocGenerator_opset1("log sum", EMPTY_MINUS_INF)));
ONNX_OPERATOR_SET_SCHEMA(
ReduceLogSumExp,
1,
OpSchema().FillUsing(ReduceDocGenerator_opset1("log sum exponent", EMPTY_MINUS_INF)));
ONNX_OPERATOR_SET_SCHEMA(ReduceL1, 1, OpSchema().FillUsing(ReduceDocGenerator_opset1("L1 norm", EMPTY_ZERO)));
ONNX_OPERATOR_SET_SCHEMA(ReduceL2, 1, OpSchema().FillUsing(ReduceDocGenerator_opset1("L2 norm", EMPTY_ZERO)));
ONNX_OPERATOR_SET_SCHEMA(ReduceMax, 11, OpSchema().FillUsing(ReduceDocGenerator_opset1("max", EMPTY_MIN, 11)));
ONNX_OPERATOR_SET_SCHEMA(ReduceMin, 11, OpSchema().FillUsing(ReduceDocGenerator_opset1("min", EMPTY_MAX, 11)));
std::function<void(OpSchema&)> ArgReduceDocGenerator_opset1(const char* name) {
return [=](OpSchema& schema) {
std::string doc;
POPULATE_OP_DOC_STR(doc = R"DOC(
Computes the indices of the {name} elements of the input tensor's element along the
provided axis. The resulting tensor has the same rank as the input if keepdims equals 1.
If keepdims equal 0, then the resulted tensor have the reduced dimension pruned.
The type of the output tensor is integer.)DOC";
ReplaceAll(doc, "{name}", name););
schema.SetDoc(doc.c_str());
schema.Attr("axis", "The axis in which to compute the arg indices.", AttributeProto::INT, static_cast<int64_t>(0));
schema.Attr(
"keepdims",
"Keep the reduced dimension or not, default 1 means keep reduced dimension.",
AttributeProto::INT,
static_cast<int64_t>(1));
schema.Input(0, "data", "An input tensor.", "T");
schema.Output(0, "reduced", "Reduced output tensor with integer data type.", "tensor(int64)");
schema.TypeConstraint(
"T", OpSchema::all_numeric_types(), "Constrain input and output types to all numeric tensors.");
schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// set output element type to int64
updateOutputElemType(ctx, 0, TensorProto_DataType_INT64);
if (!hasNInputShapes(ctx, 1)) {
return;
}
auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
int64_t input_ndim = input_shape.dim_size();
int64_t axis = 0; // default to 0
auto axis_proto = ctx.getAttribute("axis");
if (axis_proto) {
axis = axis_proto->i();
if (axis < 0)
axis += input_ndim;
}
int64_t keep_dims = 1;
auto attr_proto = ctx.getAttribute("keepdims");
if (attr_proto) {
keep_dims = attr_proto->i();
}
// do we need handle negative axis?
for (int i = 0; i < input_ndim; ++i) {
if (i != axis) {
auto dim = output_shape->add_dim();
dim->CopyFrom(input_shape.dim(i));
} else {
if (keep_dims == 1) {
auto dim = output_shape->add_dim();
dim->set_dim_value(1);
}
}
}
});
};
} // namespace ONNX_NAMESPACE
ONNX_OPERATOR_SET_SCHEMA(ArgMax, 1, OpSchema().FillUsing(ArgReduceDocGenerator_opset1("max")));
ONNX_OPERATOR_SET_SCHEMA(ArgMin, 1, OpSchema().FillUsing(ArgReduceDocGenerator_opset1("min")));
std::function<void(OpSchema&)> ArgReduceDocGenerator_opset11(const char* name) {
return [=](OpSchema& schema) {
std::string doc = R"DOC(
Computes the indices of the {name} elements of the input tensor's element along the
provided axis. The resulting tensor has the same rank as the input if keepdims equals 1.
If keepdims equal 0, then the resulting tensor has the reduced dimension pruned.
The input tensor must not be empty.
The type of the output tensor is integer.)DOC";
ReplaceAll(doc, "{name}", name);
schema.SetDoc(doc.c_str());
schema.Attr(
"axis",
"The axis in which to compute the arg indices. Accepted range is [-r, r-1] where r = rank(data).",
AttributeProto::INT,
static_cast<int64_t>(0));
schema.Attr(
"keepdims",
"Keep the reduced dimension or not, default 1 means keep reduced dimension.",
AttributeProto::INT,
static_cast<int64_t>(1));
schema.Input(0, "data", "An input tensor.", "T");
schema.Output(0, "reduced", "Reduced output tensor with integer data type.", "tensor(int64)");
schema.TypeConstraint(
"T", OpSchema::all_numeric_types(), "Constrain input and output types to all numeric tensors.");
schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// set output element type to int64
updateOutputElemType(ctx, 0, TensorProto_DataType_INT64);
if (!hasNInputShapes(ctx, 1)) {
return;
}
auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
int64_t input_ndim = input_shape.dim_size();
int64_t axis = 0; // default to 0
auto axis_proto = ctx.getAttribute("axis");
if (axis_proto) {
axis = axis_proto->i();
if (axis < -input_ndim || axis >= input_ndim) {
fail_shape_inference("'axis' must be in [-rank(indices), rank(indices)-1]");
}
if (axis < 0)
axis += input_ndim;
}
int64_t keep_dims = 1;
auto attr_proto = ctx.getAttribute("keepdims");
if (attr_proto) {
keep_dims = attr_proto->i();
}
// do we need handle negative axis?
for (int i = 0; i < input_ndim; ++i) {
if (i != axis) {
auto dim = output_shape->add_dim();
dim->CopyFrom(input_shape.dim(i));
} else {
if (keep_dims == 1) {
auto dim = output_shape->add_dim();
dim->set_dim_value(1);
}
}
}
});
};
} // namespace ONNX_NAMESPACE
ONNX_OPERATOR_SET_SCHEMA(ArgMax, 11, OpSchema().FillUsing(ArgReduceDocGenerator_opset11("max")));
ONNX_OPERATOR_SET_SCHEMA(ArgMin, 11, OpSchema().FillUsing(ArgReduceDocGenerator_opset11("min")));
ONNX_OPERATOR_SET_SCHEMA(ReduceMax, 13, OpSchema().FillUsing(ReduceOpGenerator("max", EMPTY_MIN, true)));
ONNX_OPERATOR_SET_SCHEMA(ReduceMin, 13, OpSchema().FillUsing(ReduceOpGenerator("min", EMPTY_MAX, true)));
ONNX_OPERATOR_SET_SCHEMA(ReduceSumSquare, 13, OpSchema().FillUsing(ReduceOpGenerator("sum square", EMPTY_ZERO)));
ONNX_OPERATOR_SET_SCHEMA(ReduceMean, 13, OpSchema().FillUsing(ReduceOpGenerator("mean", EMPTY_UNDEFINED)));
ONNX_OPERATOR_SET_SCHEMA(ReduceProd, 13, OpSchema().FillUsing(ReduceOpGenerator("product", EMPTY_ONE)));
ONNX_OPERATOR_SET_SCHEMA(ReduceLogSum, 13, OpSchema().FillUsing(ReduceOpGenerator("log sum", EMPTY_MINUS_INF)));
ONNX_OPERATOR_SET_SCHEMA(
ReduceLogSumExp,
13,
OpSchema().FillUsing(ReduceOpGenerator("log sum exponent", EMPTY_MINUS_INF)));
ONNX_OPERATOR_SET_SCHEMA(ReduceL1, 13, OpSchema().FillUsing(ReduceOpGenerator("L1 norm", EMPTY_ZERO)));
ONNX_OPERATOR_SET_SCHEMA(ReduceL2, 13, OpSchema().FillUsing(ReduceOpGenerator("L2 norm", EMPTY_ZERO)));
ONNX_OPERATOR_SET_SCHEMA(ReduceMax, 18, OpSchema().FillUsing(ReduceOpGenerator("max", EMPTY_MIN, true, true)));
ONNX_OPERATOR_SET_SCHEMA(ReduceMin, 18, OpSchema().FillUsing(ReduceOpGenerator("min", EMPTY_MAX, true, true)));
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,163 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/reduction/utils.h"
#include <algorithm>
#include <string>
#include <vector>
namespace ONNX_NAMESPACE {
std::vector<std::string> GetSupportedDataTypesForReductionOps(bool supports8bit, bool supports_bool) {
auto data_types = OpSchema::numeric_types_for_math_reduction_ir4();
if (supports8bit) {
data_types.push_back("tensor(uint8)");
data_types.push_back("tensor(int8)");
}
if (supports_bool) {
data_types.push_back("tensor(bool)");
}
return data_types;
}
std::function<void(OpSchema&)> ReduceOpGenerator(
const char* name,
const char* empty_value,
bool supports_8bit_datatypes,
bool axes_input,
const char* func_body,
ContextDependentFunctionBodyBuilder function_builder,
bool supports_boolean_datatype /* = false */) {
return [=](OpSchema& schema) {
std::string doc = R"DOC(
Computes the {name} of the input tensor's elements along the provided axes. The resulting
tensor has the same rank as the input if `keepdims` equals 1. If `keepdims` equals 0, then
the resulting tensor has the reduced dimension pruned. Input tensors of rank zero are
valid. Reduction over an empty set of values yields {empty_value}.
)DOC";
if (supports_boolean_datatype) {
doc += R"DOC(
If the input data type is Boolean, the comparison should consider `False < True`.)DOC";
}
doc += R"DOC(
The above behavior is similar to numpy, with the exception that numpy defaults `keepdims`
to `False` instead of `True`.)DOC";
ReplaceAll(doc, "{name}", name);
ReplaceAll(doc, "{empty_value}", empty_value);
POPULATE_OP_DOC_STR(doc = doc;);
schema.SetDoc(doc.c_str());
schema.Attr(
"keepdims",
"Keep the reduced dimension or not, default 1 means keep reduced dimension.",
AttributeProto::INT,
static_cast<int64_t>(1));
schema.Input(0, "data", "An input tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable);
if (axes_input) {
schema.Attr(
"noop_with_empty_axes",
"Defines behavior if 'axes' is empty. Default behavior with 'false' is to reduce all axes. "
"When axes is empty and this attribute is set to true, input tensor will not be reduced,"
"and the output tensor would be equivalent to input tensor.",
AttributeProto::INT,
static_cast<int64_t>(0));
schema.Input(
1,
"axes",
"Optional input list of integers, along which to reduce. "
"The default is to reduce over all the dimensions of the input tensor if 'noop_with_empty_axes' is false, "
"else act as an Identity op when 'noop_with_empty_axes' is true. "
"Accepted range is [-r, r-1] where r = rank(data).",
"tensor(int64)",
OpSchema::Optional,
true,
1,
OpSchema::NonDifferentiable);
} else {
schema.Attr(
"axes",
"A list of integers, along which to reduce. The default is to reduce over "
"all the dimensions of the input tensor. Accepted range is [-r, r-1] where r = rank(data).",
AttributeProto::INTS,
OPTIONAL_VALUE);
}
schema.Output(0, "reduced", "Reduced output tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable);
schema.TypeConstraint(
"T",
GetSupportedDataTypesForReductionOps(supports_8bit_datatypes, supports_boolean_datatype),
supports_boolean_datatype ? "Constrain input and output types to numeric and Boolean tensors."
: "Constrain input and output types to numeric tensors.");
if (func_body) {
schema.FunctionBody(func_body);
} else if (function_builder) {
schema.SetContextDependentFunctionBodyBuilder(function_builder);
}
schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 1)) {
return;
}
int64_t keep_dims = 1, noop_with_empty_axes = 0;
auto attr_proto = ctx.getAttribute("keepdims");
if (attr_proto) {
keep_dims = attr_proto->i();
}
auto noop_attr_proto = ctx.getAttribute("noop_with_empty_axes");
if (noop_attr_proto) {
noop_with_empty_axes = noop_attr_proto->i();
}
std::vector<int64_t> axes;
if (ctx.hasInput(1)) { // axes is input
if (ctx.getAttribute("axes")) {
fail_shape_inference("axes as an input and attribute cannot be specified at the same time.");
}
const TensorProto* axesInitializer = ctx.getInputData(1);
if (axesInitializer == nullptr) {
// skip if axes is not an initializer
return;
}
std::vector<int64_t> axes_values = ParseData<int64_t>(axesInitializer);
axes.assign(axes_values.begin(), axes_values.end());
} else { // axes is attribute
auto axes_proto = ctx.getAttribute("axes");
if (axes_proto)
axes.assign(axes_proto->ints().begin(), axes_proto->ints().end());
}
auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
if (noop_with_empty_axes && axes.empty()) {
propagateShapeFromInputToOutput(ctx, 0, 0);
return;
}
int64_t input_ndim = input_shape.dim_size();
auto output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
for (size_t i = 0; i < axes.size(); ++i) {
if (axes[i] < -input_ndim || axes[i] >= input_ndim) {
fail_shape_inference("axis must be in [-rank, rank-1]. input rank was ", input_ndim);
}
if (axes[i] < 0)
axes[i] += input_ndim;
}
for (int i = 0; i < input_ndim; ++i) {
// axes empty means reduce all dim
if (!axes.empty() && std::find(axes.begin(), axes.end(), i) == axes.end()) {
auto dim = output_shape->add_dim();
dim->CopyFrom(input_shape.dim(i));
} else {
if (keep_dims == 1) {
auto dim = output_shape->add_dim();
dim->set_dim_value(1);
}
}
}
});
};
}
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,42 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <cmath>
#include "onnx/defs/schema.h"
#include "onnx/defs/tensor_proto_util.h"
namespace ONNX_NAMESPACE {
// Constants used to indicate value returned by reduction of an empty set of values.
constexpr const char* EMPTY_ZERO = "0";
constexpr const char* EMPTY_ONE = "1";
constexpr const char* EMPTY_UNDEFINED = "undefined";
constexpr const char* EMPTY_MIN =
"minus infinity (if supported by the datatype) or the minimum value of the data type otherwise";
constexpr const char* EMPTY_MAX =
"plus infinity (if supported by the datatype) or the maximum value of the data type otherwise";
constexpr const char* EMPTY_MINUS_INF = "minus infinity (if supported by the datatype) or undefined otherwise";
std::function<void(OpSchema&)> ReduceOpGenerator(
const char* name,
const char* empty_value,
bool supports_8bit_datatypes = false,
bool axes_input = false,
const char* func_body = nullptr,
ContextDependentFunctionBodyBuilder function_builder = nullptr,
bool supports_boolean_datatype = false);
inline std::function<void(OpSchema&)> ReduceOpDynamicAxes(const char* name, const char* empty_value) {
return ReduceOpGenerator(name, empty_value, false, true, nullptr, nullptr, false);
}
inline std::function<void(OpSchema&)>
ReduceFunctionOp(const char* name, const char* empty_value, const char* func_body) {
return ReduceOpGenerator(name, empty_value, false, true, func_body);
}
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,519 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
void RNNShapeInference(InferenceContext& ctx) {
TensorShapeProto::Dimension num_directions, seq_length, batch_size, hidden_size;
auto direction = getAttribute(ctx, "direction", "forward");
if ((direction == "forward") || (direction == "reverse"))
num_directions.set_dim_value(1);
else if (direction == "bidirectional")
num_directions.set_dim_value(2);
// else leave num_directions unknown in case of incorrect attribute value
auto hidden_size_value = getAttribute(ctx, "hidden_size", -1);
if (hidden_size_value > 0)
hidden_size.set_dim_value(hidden_size_value);
auto layout_value = getAttribute(ctx, "layout", 0);
if (hasInputShape(ctx, 0)) {
auto& first_input_shape = getInputShape(ctx, 0);
if (first_input_shape.dim_size() != 3) {
fail_shape_inference("First input tensor must have rank 3");
}
seq_length = first_input_shape.dim((layout_value == 0) ? 0 : 1);
batch_size = first_input_shape.dim((layout_value == 0) ? 1 : 0);
}
auto num_outputs = ctx.getNumOutputs();
if (num_outputs > 0) {
// Y
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (layout_value == 0) {
auto dims = {seq_length, num_directions, batch_size, hidden_size};
updateOutputShape(ctx, 0, dims);
} else {
auto dims = {batch_size, seq_length, num_directions, hidden_size};
updateOutputShape(ctx, 0, dims);
}
}
if (num_outputs > 1) {
// Y_h
propagateElemTypeFromInputToOutput(ctx, 0, 1);
if (layout_value == 0) {
auto dims = {num_directions, batch_size, hidden_size};
updateOutputShape(ctx, 1, dims);
} else {
auto dims = {batch_size, num_directions, hidden_size};
updateOutputShape(ctx, 1, dims);
}
}
if (num_outputs > 2) {
// Y_c : only in the case of LSTM
propagateElemTypeFromInputToOutput(ctx, 0, 2);
if (layout_value == 0) {
auto dims = {num_directions, batch_size, hidden_size};
updateOutputShape(ctx, 2, dims);
} else {
auto dims = {batch_size, num_directions, hidden_size};
updateOutputShape(ctx, 2, dims);
}
}
}
std::function<void(OpSchema&)> RNNDocGenerator(const char* /*name*/) {
return [=](OpSchema& schema) {
schema.Attr(
"direction",
"Specify if the RNN is forward, reverse, or bidirectional. "
"Must be one of forward (default), reverse, or bidirectional.",
AttributeProto::STRING,
std::string("forward"));
schema.Attr(
"layout",
"The shape format of inputs X, initial_h and outputs Y, Y_h. "
"If 0, the following shapes are expected: "
"X.shape = [seq_length, batch_size, input_size], "
"Y.shape = [seq_length, num_directions, batch_size, hidden_size], "
"initial_h.shape = Y_h.shape = [num_directions, batch_size, hidden_size]. "
"If 1, the following shapes are expected: "
"X.shape = [batch_size, seq_length, input_size], "
"Y.shape = [batch_size, seq_length, num_directions, hidden_size], "
"initial_h.shape = Y_h.shape = [batch_size, num_directions, hidden_size].",
AttributeProto::INT,
static_cast<int64_t>(0));
schema.Attr("hidden_size", "Number of neurons in the hidden layer", AttributeProto::INT, OPTIONAL_VALUE);
schema.Attr(
"activation_alpha",
"Optional scaling values used by some activation functions. The values "
"are consumed in the order of activation functions, for example (f, g, h) "
"in LSTM. Default values are the same as of corresponding ONNX operators."
"For example with LeakyRelu, the default alpha is 0.01.",
AttributeProto::FLOATS,
OPTIONAL_VALUE);
schema.Attr(
"activation_beta",
"Optional scaling values used by some activation functions. The values "
"are consumed in the order of activation functions, for example (f, g, h) "
"in LSTM. Default values are the same as of corresponding ONNX operators.",
AttributeProto::FLOATS,
OPTIONAL_VALUE);
schema.Attr(
"clip",
"Cell clip threshold. Clipping bounds the elements of a tensor "
"in the range of [-threshold, +threshold] and is applied to the input "
"of activations. No clip if not specified.",
AttributeProto::FLOAT,
OPTIONAL_VALUE);
schema.Input(
0,
"X",
"The input sequences packed (and potentially padded) into one 3-D "
"tensor with the shape of `[seq_length, batch_size, input_size]`.",
"T",
OpSchema::Single,
true,
1,
OpSchema::Differentiable);
schema.Input(
4,
"sequence_lens",
"Optional tensor specifying lengths of the sequences in a batch. "
"If not specified - assumed all sequences in the batch to have "
"length `seq_length`. It has shape `[batch_size]`.",
"T1",
OpSchema::Optional,
true,
1,
OpSchema::NonDifferentiable);
schema.Input(
5,
"initial_h",
"Optional initial value of the hidden. If not specified - assumed "
"to be 0. It has shape `[num_directions, batch_size, hidden_size]`.",
"T",
OpSchema::Optional,
true,
1,
OpSchema::NonDifferentiable);
schema.Output(
0,
"Y",
"A tensor that concats all the intermediate output values of the hidden. "
"It has shape `[seq_length, num_directions, batch_size, hidden_size]`. ",
"T",
OpSchema::Optional,
true,
1,
OpSchema::Differentiable);
schema.Output(
1,
"Y_h",
"The last output value of the hidden. It has shape "
"`[num_directions, batch_size, hidden_size]`.",
"T",
OpSchema::Optional,
true,
1,
OpSchema::Differentiable);
schema.TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.");
schema.TypeConstraint("T1", {"tensor(int32)"}, "Constrain seq_lens to integer tensor.");
schema.TypeAndShapeInferenceFunction(RNNShapeInference);
};
}
static const char* RNN_ver22_doc = R"DOC(
Computes an one-layer simple RNN. This operator is usually supported
via some custom implementation such as CuDNN.
Notations:
* `X` - input tensor
* `i` - input gate
* `t` - time step (t-1 means previous time step)
* `Wi` - W parameter weight matrix for input gate
* `Ri` - R recurrence weight matrix for input gate
* `Wbi` - W parameter bias vector for input gate
* `Rbi` - R parameter bias vector for input gate
* `WBi` - W parameter weight matrix for backward input gate
* `RBi` - R recurrence weight matrix for backward input gate
* `WBbi` - WR bias vectors for backward input gate
* `RBbi` - RR bias vectors for backward input gate
* `H` - Hidden state
* `num_directions` - 2 if direction == bidirectional else 1
Activation functions:
* Relu(x) - max(0, x)
* Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x})
* Sigmoid(x) - 1/(1 + e^{-x})
NOTE: Below are optional
* Affine(x) - alpha*x + beta
* LeakyRelu(x) - x if x >= 0 else alpha * x
* ThresholdedRelu(x) - x if x >= alpha else 0
* ScaledTanh(x) - alpha*Tanh(beta*x)
* HardSigmoid(x) - min(max(alpha*x + beta, 0), 1)
* Elu(x) - x if x >= 0 else alpha*(e^x - 1)
* Softsign(x) - x/(1 + |x|)
* Softplus(x) - log(1 + e^x)
Equations (Default: f=Tanh):
* Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
RNN,
22,
OpSchema()
.SetDoc(GET_OP_DOC_STR(std::string(RNN_ver22_doc) + GenerateOptionalArgumentsDoc()))
.Attr(
"activations",
"One (or two if bidirectional) activation function for "
"input gate. The activation function must be one of the activation "
"functions specified above. Optional: Default `Tanh` if not specified.",
AttributeProto::STRINGS,
std::vector<std::string>{"Tanh", "Tanh"})
.Input(
1,
"W",
"The weight tensor for input gate. Concatenation of `Wi` and `WBi` "
"(if bidirectional). The tensor has shape "
"`[num_directions, hidden_size, input_size]`.",
"T",
OpSchema::Single,
true,
1,
OpSchema::Differentiable)
.Input(
2,
"R",
"The recurrence weight tensor. Concatenation of `Ri` and `RBi` "
"(if bidirectional). The tensor has shape "
"`[num_directions, hidden_size, hidden_size]`.",
"T",
OpSchema::Single,
true,
1,
OpSchema::Differentiable)
.Input(
3,
"B",
"The bias tensor for input gate. Concatenation of `[Wbi, Rbi]` "
"and `[WBbi, RBbi]` (if bidirectional). The tensor has shape "
"`[num_directions, 2*hidden_size]`. Optional: If not specified - assumed "
"to be 0.",
"T",
OpSchema::Optional,
true,
1,
OpSchema::Differentiable)
.FillUsing(RNNDocGenerator("RNN")));
static const char* GRU_ver22_doc = R"DOC(
Computes an one-layer GRU. This operator is usually supported via some custom
implementation such as CuDNN.
Notations:
* `X` - input tensor
* `z` - update gate
* `r` - reset gate
* `h` - hidden gate
* `t` - time step (t-1 means previous time step)
* `W[zrh]` - W parameter weight matrix for update, reset, and hidden gates
* `R[zrh]` - R recurrence weight matrix for update, reset, and hidden gates
* `Wb[zrh]` - W bias vectors for update, reset, and hidden gates
* `Rb[zrh]` - R bias vectors for update, reset, and hidden gates
* `WB[zrh]` - W parameter weight matrix for backward update, reset, and hidden gates
* `RB[zrh]` - R recurrence weight matrix for backward update, reset, and hidden gates
* `WBb[zrh]` - W bias vectors for backward update, reset, and hidden gates
* `RBb[zrh]` - R bias vectors for backward update, reset, and hidden gates
* `H` - Hidden state
* `num_directions` - 2 if direction == bidirectional else 1
Activation functions:
* Relu(x) - max(0, x)
* Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x})
* Sigmoid(x) - 1/(1 + e^{-x})
NOTE:
Below are optional
* Affine(x) - alpha * x + beta
* LeakyRelu(x) - x if x >= 0 else alpha * x
* ThresholdedRelu(x) - x if x >= alpha else 0
* ScaledTanh(x) - alpha * Tanh(beta * x)
* HardSigmoid(x) - min(max(alpha * x + beta, 0), 1)
* Elu(x) - x if x >= 0 else alpha * (e^x - 1)
* Softsign(x) - x/(1 + |x|)
* Softplus(x) - log(1 + e^x)
Equations (Default: f=Sigmoid, g=Tanh):
* zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
* rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
* ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # default, when linear_before_reset = 0
* ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset != 0
* Ht = (1 - zt) (.) ht + zt (.) Ht-1
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
GRU,
22,
OpSchema()
.SetDoc(GET_OP_DOC_STR(std::string(GRU_ver22_doc) + GenerateOptionalArgumentsDoc()))
.Attr(
"activations",
"A list of 2 (or 4 if bidirectional) activation functions "
"for update, reset, and hidden gates. The activation functions must be one "
"of the activation functions specified above. Optional: See the equations "
"for default if not specified.",
AttributeProto::STRINGS,
OPTIONAL_VALUE)
.Attr(
"linear_before_reset",
"When computing the output of the hidden gate, "
"apply the linear transformation before multiplying by the output of the "
"reset gate.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Input(
1,
"W",
"The weight tensor for the gates. Concatenation of `W[zrh]` and `WB[zrh]` "
"(if bidirectional) along dimension 0. This tensor has shape "
"`[num_directions, 3*hidden_size, input_size]`.",
"T",
OpSchema::Single,
true,
1,
OpSchema::Differentiable)
.Input(
2,
"R",
"The recurrence weight tensor. Concatenation of `R[zrh]` and `RB[zrh]` "
"(if bidirectional) along dimension 0. This tensor has shape "
"`[num_directions, 3*hidden_size, hidden_size]`.",
"T",
OpSchema::Single,
true,
1,
OpSchema::Differentiable)
.Input(
3,
"B",
"The bias tensor for the gates. Concatenation of `[Wb[zrh], Rb[zrh]]` and "
"`[WBb[zrh], RBb[zrh]]` (if bidirectional) along dimension 0. This tensor "
"has shape `[num_directions, 6*hidden_size]`. Optional: If not specified "
"- assumed to be 0",
"T",
OpSchema::Optional,
true,
1,
OpSchema::Differentiable)
.FillUsing(RNNDocGenerator("GRU")));
static const char* LSTM_ver22_doc = R"DOC(
Computes an one-layer LSTM. This operator is usually supported via some
custom implementation such as CuDNN.
Notations:
* `X` - input tensor
* `i` - input gate
* `o` - output gate
* `f` - forget gate
* `c` - cell gate
* `t` - time step (t-1 means previous time step)
* `W[iofc]` - W parameter weight matrix for input, output, forget, and cell gates
* `R[iofc]` - R recurrence weight matrix for input, output, forget, and cell gates
* `Wb[iofc]` - W bias vectors for input, output, forget, and cell gates
* `Rb[iofc]` - R bias vectors for input, output, forget, and cell gates
* `P[iof]` - P peephole weight vector for input, output, and forget gates
* `WB[iofc]` - W parameter weight matrix for backward input, output, forget, and cell gates
* `RB[iofc]` - R recurrence weight matrix for backward input, output, forget, and cell gates
* `WBb[iofc]` - W bias vectors for backward input, output, forget, and cell gates
* `RBb[iofc]` - R bias vectors for backward input, output, forget, and cell gates
* `PB[iof]` - P peephole weight vector for backward input, output, and forget gates
* `H` - Hidden state
* `num_directions` - 2 if direction == bidirectional else 1
Activation functions:
* Relu(x) - max(0, x)
* Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x})
* Sigmoid(x) - 1/(1 + e^{-x})
NOTE: Below are optional
* Affine(x) - alpha*x + beta
* LeakyRelu(x) - x if x >= 0 else alpha * x
* ThresholdedRelu(x) - x if x >= alpha else 0
* ScaledTanh(x) - alpha*Tanh(beta*x)
* HardSigmoid(x) - min(max(alpha*x + beta, 0), 1)
* Elu(x) - x if x >= 0 else alpha*(e^x - 1)
* Softsign(x) - x/(1 + |x|)
* Softplus(x) - log(1 + e^x)
Equations (Default: f=Sigmoid, g=Tanh, h=Tanh):
* it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
* ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
* ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
* Ct = ft (.) Ct-1 + it (.) ct
* ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
* Ht = ot (.) h(Ct)
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
LSTM,
22,
OpSchema()
.SetDoc(GET_OP_DOC_STR(std::string(LSTM_ver22_doc) + GenerateOptionalArgumentsDoc()))
.Attr(
"activations",
"A list of 3 (or 6 if bidirectional) activation functions "
"for input, output, forget, cell, and hidden. The activation functions must "
"be one of the activation functions specified above. Optional: See the equations "
"for default if not specified.",
AttributeProto::STRINGS,
OPTIONAL_VALUE)
.Attr(
"layout",
"The shape format of inputs X, initial_h, initial_c and outputs Y, Y_h, Y_c. "
"If 0, the following shapes are expected: "
"X.shape = [seq_length, batch_size, input_size], "
"Y.shape = [seq_length, num_directions, batch_size, hidden_size], "
"initial_h.shape = Y_h.shape = initial_c.shape = Y_c.shape = "
"[num_directions, batch_size, hidden_size]. "
"If 1, the following shapes are expected: "
"X.shape = [batch_size, seq_length, input_size], "
"Y.shape = [batch_size, seq_length, num_directions, hidden_size], "
"initial_h.shape = Y_h.shape = initial_c.shape = Y_c.shape = "
"[batch_size, num_directions, hidden_size].",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr("input_forget", "Couple the input and forget gates if 1.", AttributeProto::INT, static_cast<int64_t>(0))
.Input(
1,
"W",
"The weight tensor for the gates. Concatenation of `W[iofc]` and "
"`WB[iofc]` (if bidirectional) along dimension 0. The tensor has shape "
"`[num_directions, 4*hidden_size, input_size]`.",
"T",
OpSchema::Single,
true,
1,
OpSchema::Differentiable)
.Input(
2,
"R",
"The recurrence weight tensor. Concatenation of `R[iofc]` and "
"`RB[iofc]` (if bidirectional) along dimension 0. This tensor has shape "
"`[num_directions, 4*hidden_size, hidden_size]`.",
"T",
OpSchema::Single,
true,
1,
OpSchema::Differentiable)
.Input(
3,
"B",
"The bias tensor for input gate. Concatenation of `[Wb[iofc], Rb[iofc]]`, "
"and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0. This "
"tensor has shape `[num_directions, 8*hidden_size]`. Optional: If not "
"specified - assumed to be 0.",
"T",
OpSchema::Optional,
true,
1,
OpSchema::Differentiable)
.Input(
6,
"initial_c",
"Optional initial value of the cell. If not specified - assumed "
"to be 0. It has shape `[num_directions, batch_size, hidden_size]`.",
"T",
OpSchema::Optional,
true,
1,
OpSchema::NonDifferentiable)
.Input(
7,
"P",
"The weight tensor for peepholes. Concatenation of `P[iof]` and "
"`PB[iof]` (if bidirectional) along dimension 0. It has shape "
"`[num_directions, 3*hidde_size]`. Optional: If not specified - "
"assumed to be 0.",
"T",
OpSchema::Optional,
true,
1,
OpSchema::Differentiable)
.FillUsing(RNNDocGenerator("LSTM"))
.Output(
2,
"Y_c",
"The last output value of the cell. It has shape "
"`[num_directions, batch_size, hidden_size]`.",
"T",
OpSchema::Optional,
true,
1,
OpSchema::Differentiable));
} // namespace ONNX_NAMESPACE

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,788 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <algorithm>
#include <numeric>
#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
static const char* SequenceEmpty_ver11_doc = R"DOC(
Construct an empty tensor sequence, with given data type.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
SequenceEmpty,
11,
OpSchema()
.SetDoc(SequenceEmpty_ver11_doc)
.Attr(
"dtype",
"(Optional) The data type of the tensors in the output sequence. "
"The default type is 'float'.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Output(0, "output", "Empty sequence.", "S")
.TypeConstraint("S", OpSchema::all_tensor_sequence_types(), "Constrain output types to any tensor type.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
const auto* attr_proto = ctx.getAttribute("dtype");
auto elem_type = TensorProto::FLOAT;
if (nullptr != attr_proto) {
if (!attr_proto->has_i()) {
fail_type_inference("Attribute dtype should be of integer type and specify a type.");
}
auto attr_value = attr_proto->i();
elem_type = static_cast<TensorProto_DataType>(attr_value);
}
ctx.getOutputType(0)->mutable_sequence_type()->mutable_elem_type()->mutable_tensor_type()->set_elem_type(
elem_type);
}));
static const char* SequenceConstruct_ver11_doc = R"DOC(
Construct a tensor sequence containing 'inputs' tensors.
All tensors in 'inputs' must have the same data type.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
SequenceConstruct,
11,
OpSchema()
.SetDoc(SequenceConstruct_ver11_doc)
.Input(0, "inputs", "Tensors.", "T", OpSchema::Variadic)
.Output(0, "output_sequence", "Sequence enclosing the input tensors.", "S")
.TypeConstraint("T", OpSchema::all_tensor_types(), "Constrain input types to any tensor type.")
.TypeConstraint("S", OpSchema::all_tensor_sequence_types(), "Constrain output types to any tensor type.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
const size_t numInputs = ctx.getNumInputs();
if (numInputs < 1) {
fail_type_inference("SequenceConstruct is expected to have at least 1 input.");
}
std::vector<int> input_elem_types;
input_elem_types.reserve(numInputs);
for (size_t i = 0; i < numInputs; ++i) {
auto input_type = ctx.getInputType(i);
if (nullptr == input_type) {
fail_type_inference("Input type for input at index ", i, " is null. Type info is expected.");
}
input_elem_types.emplace_back(input_type->tensor_type().elem_type());
}
if (std::adjacent_find(input_elem_types.begin(), input_elem_types.end(), std::not_equal_to<int>()) !=
input_elem_types.end()) {
// not all input elem types are the same.
fail_type_inference("Element type of inputs are expected to be the same.");
}
auto* output_tensor_type =
ctx.getOutputType(0)->mutable_sequence_type()->mutable_elem_type()->mutable_tensor_type();
output_tensor_type->set_elem_type(static_cast<TensorProto_DataType>(input_elem_types[0]));
if (!hasNInputShapes(ctx, static_cast<int>(numInputs))) {
return;
}
*(output_tensor_type->mutable_shape()) = ctx.getInputType(0)->tensor_type().shape();
for (size_t i = 1; i < numInputs; ++i) {
const auto& input_shape = ctx.getInputType(i)->tensor_type().shape();
UnionShapeInfo(input_shape, *output_tensor_type);
}
}));
static const char* SequenceInsert_ver11_doc = R"DOC(
Outputs a tensor sequence that inserts 'tensor' into 'input_sequence' at 'position'.
'tensor' must have the same data type as 'input_sequence'.
Accepted range for 'position' is in `[-n, n]`, where `n` is the number of tensors in 'input_sequence'.
Negative value means counting positions from the back.
'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
SequenceInsert,
11,
OpSchema()
.SetDoc(SequenceInsert_ver11_doc)
.Input(0, "input_sequence", "Input sequence.", "S")
.Input(1, "tensor", "Input tensor to be inserted into the input sequence.", "T")
.Input(
2,
"position",
"Position in the sequence where the new tensor is inserted. "
"It is optional and default is to insert to the back of the sequence. "
"Negative value means counting positions from the back. "
"Accepted range in `[-n, n]`, "
"where `n` is the number of tensors in 'input_sequence'. "
"It is an error if any of the index values are out of bounds. "
"It must be a scalar(tensor of empty shape).",
"I",
OpSchema::Optional)
.Output(0, "output_sequence", "Output sequence that contains the inserted tensor at given position.", "S")
.TypeConstraint("T", OpSchema::all_tensor_types(), "Constrain to any tensor type.")
.TypeConstraint("S", OpSchema::all_tensor_sequence_types(), "Constrain to any tensor type.")
.TypeConstraint(
"I",
{"tensor(int32)", "tensor(int64)"},
"Constrain position to integral tensor. It must be a scalar(tensor of empty shape).")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
const auto input0_type = ctx.getInputType(0);
const auto input1_type = ctx.getInputType(1);
if (nullptr == input0_type || nullptr == input1_type) {
fail_type_inference("Input Sequence and Tensor are expected to have type info. Current type is null.");
}
const auto seq_elem_type = input0_type->sequence_type().elem_type().tensor_type().elem_type();
const auto tensor_elem_type = input1_type->tensor_type().elem_type();
if (seq_elem_type != tensor_elem_type) {
fail_type_inference(
"Input Sequence and Tensor are expected to have the same elem type. Sequence=",
seq_elem_type,
" Tensor=",
tensor_elem_type);
}
auto* output_tensor_type =
ctx.getOutputType(0)->mutable_sequence_type()->mutable_elem_type()->mutable_tensor_type();
output_tensor_type->set_elem_type(seq_elem_type);
if (!hasNInputShapes(ctx, 2)) {
return;
}
*(output_tensor_type->mutable_shape()) = input0_type->sequence_type().elem_type().tensor_type().shape();
UnionShapeInfo(input1_type->tensor_type().shape(), *output_tensor_type);
}));
static const char* SequenceAt_ver11_doc = R"DOC(
Outputs a tensor copy from the tensor at 'position' in 'input_sequence'.
Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'.
Negative value means counting positions from the back.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
SequenceAt,
11,
OpSchema()
.SetDoc(SequenceAt_ver11_doc)
.Input(0, "input_sequence", "Input sequence.", "S")
.Input(
1,
"position",
"Position of the tensor in the sequence. "
"Negative value means counting positions from the back. "
"Accepted range in `[-n, n - 1]`, "
"where `n` is the number of tensors in 'input_sequence'. "
"It is an error if any of the index values are out of bounds. "
"It must be a scalar(tensor of empty shape).",
"I")
.Output(0, "tensor", "Output tensor at the specified position in the input sequence.", "T")
.TypeConstraint("S", OpSchema::all_tensor_sequence_types(), "Constrain to any tensor type.")
.TypeConstraint("T", OpSchema::all_tensor_types(), "Constrain to any tensor type.")
.TypeConstraint(
"I",
{"tensor(int32)", "tensor(int64)"},
"Constrain position to integral tensor. It must be a scalar(tensor of empty shape).")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
const auto input0_type = ctx.getInputType(0);
if (nullptr == input0_type) {
fail_type_inference("Input type for input at index 0 is null. Type info is expected.")
}
ctx.getOutputType(0)->CopyFrom(input0_type->sequence_type().elem_type());
}));
static const char* SequenceErase_ver11_doc = R"DOC(
Outputs a tensor sequence that removes the tensor at 'position' from 'input_sequence'.
Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'.
Negative value means counting positions from the back.
'position' is optional, by default it erases the last tensor from 'input_sequence'.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
SequenceErase,
11,
OpSchema()
.SetDoc(SequenceErase_ver11_doc)
.Input(0, "input_sequence", "Input sequence.", "S")
.Input(
1,
"position",
"Position of the tensor in the sequence. "
"Negative value means counting positions from the back. "
"Accepted range in `[-n, n - 1]`, "
"where `n` is the number of tensors in 'input_sequence'. "
"It is an error if any of the index values are out of bounds. "
"It must be a scalar(tensor of empty shape).",
"I",
OpSchema::Optional)
.Output(0, "output_sequence", "Output sequence that has the tensor at the specified position removed.", "S")
.TypeConstraint("S", OpSchema::all_tensor_sequence_types(), "Constrain to any tensor type.")
.TypeConstraint(
"I",
{"tensor(int32)", "tensor(int64)"},
"Constrain position to integral tensor. It must be a scalar(tensor of empty shape).")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
const auto input0_type = ctx.getInputType(0);
if (nullptr == input0_type) {
fail_type_inference("Input type for input at index 0 is null. Type info is expected.")
}
ctx.getOutputType(0)->CopyFrom(*input0_type);
}));
static const char* SequenceLength_ver11_doc = R"DOC(
Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
SequenceLength,
11,
OpSchema()
.SetDoc(SequenceLength_ver11_doc)
.Input(0, "input_sequence", "Input sequence.", "S")
.Output(0, "length", "Length of input sequence. It must be a scalar(tensor of empty shape).", "I")
.TypeConstraint("S", OpSchema::all_tensor_sequence_types(), "Constrain to any tensor type.")
.TypeConstraint(
"I",
{"tensor(int64)"},
"Constrain output to integral tensor. It must be a scalar(tensor of empty shape).")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
auto* output_tensor_type = ctx.getOutputType(0)->mutable_tensor_type();
output_tensor_type->set_elem_type(TensorProto::INT64);
output_tensor_type->mutable_shape()->Clear();
}));
// Updated operators that consume/produce sequence of tensors.
static const char* SplitToSequence_ver11_doc =
R"DOC(
Split a tensor into a sequence of tensors, along the specified 'axis'.
Lengths of the parts can be specified using the optional argument 'split'.
If the argument `split' is not specified, a default scalar value of 1
is used as the value of `split'.
'split' must contain only positive numbers.
'split' is either a scalar (tensor of empty shape), or a 1-D tensor.
If 'split' is a scalar, then 'input' will be split into chunks all of size 'split'
if possible. The last chunk alone may be smaller than 'split' if the 'input' size
along the given axis 'axis' is not divisible by 'split'.
If 'split' is a 1-dimensional tensor, the input tensor is split into 'size(split)' chunks,
with lengths of the parts on 'axis' specified in 'split'. In this scenario, the sum of entries
in 'split' must be equal to the dimension size of input tensor on 'axis'.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
SplitToSequence,
11,
OpSchema()
.Input(0, "input", "The tensor to split", "T")
.Input(
1,
"split",
"Length of each output. "
"It can be either a scalar(tensor of empty shape), or a 1-D tensor. All values must be >= 0. ",
"I",
OpSchema::Optional)
.Output(0, "output_sequence", "One or more outputs forming a sequence of tensors after splitting", "S")
.TypeConstraint("T", OpSchema::all_tensor_types(), "Constrain input types to all tensor types.")
.TypeConstraint("I", {"tensor(int32)", "tensor(int64)"}, "Constrain split size to integral tensor.")
.TypeConstraint("S", OpSchema::all_tensor_sequence_types(), "Constrain output types to all tensor types.")
.Attr(
"axis",
"Which axis to split on. "
"A negative value means counting dimensions from the back. Accepted range is [-rank, rank-1].",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr(
"keepdims",
"Keep the split dimension or not. Default 1, which means we keep split dimension. "
"If input 'split' is specified, this attribute is ignored.",
AttributeProto::INT,
static_cast<int64_t>(1))
.SetDoc(SplitToSequence_ver11_doc)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
const auto input0_type = ctx.getInputType(0);
if (nullptr == input0_type) {
fail_type_inference("Input type for input at index 0 is null. Type info is expected.")
}
ctx.getOutputType(0)->mutable_sequence_type()->mutable_elem_type()->mutable_tensor_type()->set_elem_type(
input0_type->tensor_type().elem_type());
if (!hasInputShape(ctx, 0)) {
return;
}
const auto& inputShape = input0_type->tensor_type().shape();
int r = inputShape.dim_size();
int axis = static_cast<int>(getAttribute(ctx, "axis", 0));
if (axis < -r || axis > r - 1) {
fail_shape_inference("Invalid value of attribute 'axis'. Rank=", r, " Value=", axis);
}
if (axis < 0) {
axis += r;
}
size_t num_inputs = ctx.getNumInputs();
int64_t splitSize = 1;
int64_t keepdims = 1;
if (num_inputs == 1) {
// input split is omitted, default to split by 1.
auto attr_proto = ctx.getAttribute("keepdims");
if (attr_proto) {
keepdims = attr_proto->i();
}
} else {
splitSize = [&]() -> int64_t {
// Need input split shape info and initializer data to infer split sizes.
if (!hasInputShape(ctx, 1)) {
return -1;
}
const TensorProto* splitInitializer = ctx.getInputData(1);
if (nullptr == splitInitializer || !splitInitializer->has_data_type()) {
return -1;
}
std::vector<int64_t> splitSizes;
if (splitInitializer->data_type() == TensorProto::INT64) {
const auto& data = ParseData<int64_t>(splitInitializer);
splitSizes.insert(splitSizes.end(), data.begin(), data.end());
} else if (splitInitializer->data_type() == TensorProto::INT32) {
const auto& data = ParseData<int32_t>(splitInitializer);
splitSizes.insert(splitSizes.end(), data.begin(), data.end());
} else {
// unaccepted data type
fail_shape_inference("Only supports `int32_t` or `int64_t` inputs for split");
}
if (splitSizes.size() == 0) {
fail_shape_inference("Input 'split' can not be empty.");
}
const auto& splitDim = inputShape.dim(axis);
if (!splitDim.has_dim_value()) {
// Unable to verify nor infer exact split dimension size.
return -1;
}
int64_t splitDimValue = splitDim.dim_value();
const auto& splitShape = getInputShape(ctx, 1);
if (splitShape.dim_size() == 0) {
// split is scalar
if (splitDimValue % splitSizes[0] == 0) {
// all output chunks have the same shape, assign that to output sequence shape.
return splitSizes[0];
}
return -1;
} else {
// split is 1-D tensor
int64_t splitSizesSum = std::accumulate(splitSizes.begin(), splitSizes.end(), (int64_t)0);
if (splitDimValue != splitSizesSum) {
fail_shape_inference(
"Sum of split values not equal to 'input' dim size on 'axis'. 'axis' dim size=",
splitDimValue,
" sum of split values=",
splitSizesSum);
}
if (std::adjacent_find(splitSizes.begin(), splitSizes.end(), std::not_equal_to<int64_t>()) ==
splitSizes.end()) {
// all split sizes are the same.
return splitSizes[0];
}
return -1;
}
}();
}
if (keepdims) {
auto* outputShape = ctx.getOutputType(0)
->mutable_sequence_type()
->mutable_elem_type()
->mutable_tensor_type()
->mutable_shape();
*outputShape = inputShape;
auto* dim = outputShape->mutable_dim(axis);
// Tensors in sequence could not have different shapes explicitly.
// Only assign dim_value when all chunks have the same shape.
if (splitSize > 0) {
dim->set_dim_value(splitSize);
} else {
dim->clear_dim_value();
dim->clear_dim_param();
}
} else {
TensorShapeProto* outputShape = ctx.getOutputType(0)
->mutable_sequence_type()
->mutable_elem_type()
->mutable_tensor_type()
->mutable_shape();
for (int i = 0; i < inputShape.dim_size(); ++i) {
if (i != axis) {
auto* dim = outputShape->add_dim();
dim->CopyFrom(inputShape.dim(i));
}
}
}
}));
static const char* ConcatFromSequence_ver11_doc = R"DOC(
Concatenate a sequence of tensors into a single tensor.
All input tensors must have the same shape, except for the dimension size of the axis to concatenate on.
By default 'new_axis' is 0, the behavior is similar to numpy.concatenate.
When 'new_axis' is 1, the behavior is similar to numpy.stack.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
ConcatFromSequence,
11,
OpSchema()
.Attr(
"axis",
"Which axis to concat on. Accepted range in `[-r, r - 1]`, "
"where `r` is the rank of input tensors. "
"When `new_axis` is 1, accepted range is `[-r - 1, r]`. ",
AttributeProto::INT)
.Attr(
"new_axis",
"Insert and concatenate on a new axis or not, "
"default 0 means do not insert new axis.",
AttributeProto::INT,
static_cast<int64_t>(0))
.SetDoc(ConcatFromSequence_ver11_doc)
.Input(0, "input_sequence", "Sequence of tensors for concatenation", "S")
.Output(0, "concat_result", "Concatenated tensor", "T")
.TypeConstraint("S", OpSchema::all_tensor_sequence_types(), "Constrain input types to any tensor type.")
.TypeConstraint("T", OpSchema::all_tensor_types(), "Constrain output types to any tensor type.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
const auto input0_type = ctx.getInputType(0);
if (nullptr == input0_type) {
fail_type_inference("Input type for input at index 0 is null. Type info is expected.")
}
auto elem_type = input0_type->sequence_type().elem_type().tensor_type().elem_type();
ctx.getOutputType(0)->mutable_tensor_type()->set_elem_type(elem_type);
if (!hasInputShape(ctx, 0)) {
return;
}
auto axis_attr = ctx.getAttribute("axis");
if (!axis_attr) {
fail_shape_inference("Required attribute axis is missing");
}
int axis = static_cast<int>(axis_attr->i());
int new_axis = 0;
auto new_axis_attr = ctx.getAttribute("new_axis");
if (new_axis_attr) {
new_axis = static_cast<int>(new_axis_attr->i());
}
const auto& input_shape = ctx.getInputType(0)->sequence_type().elem_type().tensor_type().shape();
auto rank = input_shape.dim_size();
if (1 != new_axis && 0 != new_axis) {
fail_shape_inference("new_axis must be either 0 or 1");
}
auto upper_bound = 1 == new_axis ? rank : rank - 1;
auto lower_bound = 1 == new_axis ? -rank - 1 : -rank;
if (axis < lower_bound || axis > upper_bound) {
fail_shape_inference(
"Invalid value of attribute 'axis'. Accepted range=[",
lower_bound,
", ",
upper_bound,
"], Value=",
axis);
}
if (axis < 0) {
axis += (upper_bound + 1);
}
auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
for (int i = 0; i <= upper_bound; ++i) {
output_shape->add_dim();
if (i != axis) {
output_shape->mutable_dim(i)->CopyFrom(input_shape.dim((i > axis && new_axis) ? i - 1 : i));
}
}
}));
static const char* SequenceMap_ver17_doc = R"DOC(
Applies a sub-graph to each sample in the input sequence(s).
Inputs can be either tensors or sequences, with the exception of the first input which must
be a sequence. The length of the first input sequence will determine the number of samples in the
outputs. Any other sequence inputs should have the same number of samples. The number of inputs
and outputs, should match the one of the subgraph.
For each i-th element in the output, a sample will be extracted from the input sequence(s) at
the i-th position and the sub-graph will be applied to it.
The outputs will contain the outputs of the sub-graph for each sample, in the same order as in
the input.
This operator assumes that processing each sample is independent and could executed in parallel
or in any order. Users cannot expect any specific ordering in which each subgraph is computed.)DOC";
void SequenceMapInferenceFunction(InferenceContext& ctx) {
auto num_inputs = ctx.getNumInputs();
assert(num_inputs > 0);
auto num_outputs = ctx.getNumOutputs();
assert(num_outputs > 0);
std::vector<TypeProto> tmp_type_protos(num_inputs);
std::vector<const TypeProto*> subgraph_input_types;
subgraph_input_types.reserve(num_inputs);
for (size_t inputIndex = 0; inputIndex < num_inputs; inputIndex++) {
auto input_type = ctx.getInputType(inputIndex);
if (input_type == nullptr) {
fail_type_inference("Input ", inputIndex, " expected to have type info");
}
if (input_type->value_case() == TypeProto::kSequenceType) {
tmp_type_protos[inputIndex].CopyFrom(input_type->sequence_type().elem_type());
subgraph_input_types.push_back(&tmp_type_protos[inputIndex]);
} else {
if (inputIndex == 0)
fail_type_inference("Input ", inputIndex, " expected to be a sequence type");
subgraph_input_types.push_back(input_type);
}
}
GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
if (!graphInferencer)
fail_type_inference("Graph attribute inferencer for \"body\" not available");
std::vector<const TensorProto*> input_data(num_inputs, nullptr);
std::vector<const TypeProto*> subgraph_output_types =
graphInferencer->doInferencing(subgraph_input_types, input_data);
// if empty(), assume inferencing was skipped
if (!subgraph_output_types.empty()) {
if (subgraph_output_types.size() != num_outputs) {
fail_type_inference(
"Graph attribute inferencing returned type information for ",
subgraph_output_types.size(),
" outputs. Expected ",
num_outputs);
}
for (size_t outputIndex = 0; outputIndex < num_outputs; outputIndex++) {
auto* subgraph_output_type = subgraph_output_types[outputIndex];
ctx.getOutputType(outputIndex)->mutable_sequence_type()->mutable_elem_type()->CopyFrom(*subgraph_output_type);
}
}
}
bool BuildSequenceMapBodyFunc(
const FunctionBodyBuildContext& ctx,
const OpSchema& schema,
FunctionProto& functionProto) {
schema.BuildFunction(functionProto);
// variadic input/outputs will be expanded
functionProto.clear_input();
functionProto.clear_output();
auto body_attr = ctx.getAttribute("body");
if (!body_attr || !body_attr->has_g())
ONNX_THROW_EX(std::invalid_argument("Invalid ``body`` argument. Expected a graph"));
const GraphProto& body = body_attr->g();
auto g_inputs = body.input();
int ninputs = g_inputs.size();
if (ninputs < 1)
ONNX_THROW_EX(std::invalid_argument("Expected 1 or more inputs."));
auto g_outputs = body.output();
int noutputs = g_outputs.size();
if (noutputs < 1)
ONNX_THROW_EX(std::invalid_argument("Expected 1 or more outputs."));
if (!ctx.hasInput(0))
ONNX_THROW_EX(std::invalid_argument(MakeString("Input 0 expected but not provided")));
const auto* first_input_type = ctx.getInputType(0);
assert(first_input_type);
if (!first_input_type->has_sequence_type())
ONNX_THROW_EX(std::invalid_argument("Expected a sequence type for input 0"));
auto schema_inputs = schema.inputs();
auto input_0_name = schema_inputs[0].GetName();
auto input_1_name = schema_inputs[1].GetName(); // variadic input
*functionProto.add_input() = input_0_name;
for (int i = 1; i < ninputs; i++) {
if (!ctx.hasInput(i))
ONNX_THROW_EX(std::invalid_argument(MakeString("Input ", i, " expected but not provided")));
*functionProto.add_input() = MakeString(input_1_name, "_", i);
}
auto schema_outputs = schema.outputs();
auto output_0_name = schema_outputs[0].GetName();
for (int i = 0; i < noutputs; i++) {
if (!ctx.hasOutput(i))
ONNX_THROW_EX(std::invalid_argument(MakeString("Output ", i, " expected but not provided")));
*functionProto.add_output() = MakeString(output_0_name, "_", i);
}
// Loop body subgraph
std::string loopbody_graph_name("SequenceMap_loop_body");
GraphProto loopbody_graph;
loopbody_graph.set_name(loopbody_graph_name);
{
TypeProto int64_type;
int64_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
int64_type.mutable_tensor_type()->mutable_shape()->Clear();
TypeProto bool_type;
bool_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL);
bool_type.mutable_tensor_type()->mutable_shape()->Clear();
ValueInfoProto iter_count;
std::string iter_count_name = MakeString(loopbody_graph_name, "_itercount");
iter_count.set_name(iter_count_name);
*iter_count.mutable_type() = int64_type;
*loopbody_graph.add_input() = iter_count;
ValueInfoProto cond_in;
std::string cond_in_name = MakeString(loopbody_graph_name, "_cond_in");
cond_in.set_name(cond_in_name);
*cond_in.mutable_type() = bool_type;
*loopbody_graph.add_input() = cond_in;
ValueInfoProto cond_out;
std::string cond_out_name = MakeString(loopbody_graph_name, "_cond_out");
cond_out.set_name(cond_out_name);
*cond_out.mutable_type() = bool_type;
*loopbody_graph.add_output() = cond_out;
NodeProto cond_identity;
cond_identity.set_domain(ONNX_DOMAIN);
cond_identity.set_op_type("Identity");
cond_identity.add_input(cond_in_name);
cond_identity.add_output(cond_out_name);
*loopbody_graph.add_node() = cond_identity;
for (int inputIndex = 0; inputIndex < ninputs; inputIndex++) {
const auto* input_type = ctx.getInputType(inputIndex);
if (input_type && input_type->has_sequence_type()) {
// If it's a sequence input, extract ``iter_count`` element
NodeProto seq_at_node;
seq_at_node.set_domain(ONNX_DOMAIN);
seq_at_node.set_op_type("SequenceAt");
seq_at_node.add_input(functionProto.input(inputIndex));
seq_at_node.add_input(iter_count_name);
seq_at_node.add_output(g_inputs.Get(inputIndex).name());
*loopbody_graph.add_node() = seq_at_node;
} else {
// If not a sequence, simply connect
NodeProto identity;
identity.set_domain(ONNX_DOMAIN);
identity.set_op_type("Identity");
identity.add_input(functionProto.input(inputIndex));
identity.add_output(g_inputs.Get(inputIndex).name());
*loopbody_graph.add_node() = identity;
}
}
for (const auto& item : body.node())
*loopbody_graph.add_node() = item;
for (const auto& item : body.value_info())
*loopbody_graph.add_value_info() = item;
for (const auto& item : body.initializer())
*loopbody_graph.add_initializer() = item;
for (const auto& item : body.sparse_initializer())
*loopbody_graph.add_sparse_initializer() = item;
for (int outputIndex = 0; outputIndex < noutputs; outputIndex++) {
const auto& body_out_i = body.output(outputIndex);
assert(body_out_i.type().has_tensor_type());
std::string prefix = MakeString(loopbody_graph_name, "_", body_out_i.name());
std::string loopbody_in_name = MakeString(prefix, "_in");
ValueInfoProto tmp;
*tmp.mutable_type()->mutable_sequence_type()->mutable_elem_type()->mutable_tensor_type() =
body_out_i.type().tensor_type();
tmp.set_name(loopbody_in_name);
*loopbody_graph.add_input() = tmp;
std::string loopbody_out_name = MakeString(prefix, "_out");
tmp.set_name(loopbody_out_name);
*loopbody_graph.add_output() = tmp;
NodeProto seq_insert_node;
seq_insert_node.set_domain(ONNX_DOMAIN);
seq_insert_node.set_op_type("SequenceInsert");
seq_insert_node.add_input(loopbody_in_name);
seq_insert_node.add_input(body_out_i.name());
seq_insert_node.add_output(loopbody_out_name);
*loopbody_graph.add_node() = seq_insert_node;
}
}
std::vector<FunctionBodyHelper::NodeDef> nodes;
// TODO: figure out a way to prevent name collisions?
auto first_input_name = functionProto.input(0);
std::string prefix = MakeString("SequenceMap_", first_input_name);
std::string seqlen = MakeString(prefix, "_seqlen");
nodes.push_back({{seqlen}, "SequenceLength", {first_input_name}});
std::string cond_bool = MakeString(prefix, "_cond");
nodes.push_back(FunctionBodyHelper::Const<bool>(cond_bool, true));
std::vector<std::string> loop_node_inputs = {seqlen, cond_bool};
std::vector<std::string> loop_node_outputs;
for (int outputIndex = 0; outputIndex < noutputs; outputIndex++) {
auto output_name = functionProto.output(outputIndex);
std::string out_prefix = MakeString("SequenceMap_", output_name);
std::string seqempty_name = MakeString(out_prefix, "_seqempty");
int64_t dtype = g_outputs.Get(outputIndex).type().tensor_type().elem_type();
nodes.push_back({{seqempty_name}, "SequenceEmpty", {}, {MakeAttribute("dtype", dtype)}});
loop_node_inputs.push_back(seqempty_name);
loop_node_outputs.push_back(output_name);
}
nodes.push_back({loop_node_outputs, "Loop", loop_node_inputs, {MakeAttribute("body", loopbody_graph)}});
auto func_nodes = FunctionBodyHelper::BuildNodes(nodes);
for (const auto& node : func_nodes) {
auto new_node = functionProto.add_node();
new_node->CopyFrom(node);
}
return true;
}
ONNX_OPERATOR_SET_SCHEMA(
SequenceMap,
17,
OpSchema()
.SetDoc(SequenceMap_ver17_doc)
.Attr(
"body",
"The graph to be run for each sample in the sequence(s). "
"It should have as many inputs and outputs as inputs and "
"outputs to the SequenceMap function.",
AttributeProto::GRAPH)
.Input(0, "input_sequence", "Input sequence.", "S")
.Input(1, "additional_inputs", "Additional inputs to the graph", "V", OpSchema::Variadic, false, 0)
.Output(0, "out_sequence", "Output sequence(s)", "S", OpSchema::Variadic, false)
.TypeConstraint("S", OpSchema::all_tensor_sequence_types(), "Constrain input types to any sequence type.")
.TypeConstraint(
"V",
[]() {
auto t = OpSchema::all_tensor_types();
auto s = OpSchema::all_tensor_sequence_types();
t.insert(t.end(), s.begin(), s.end());
return t;
}(),
"Constrain to any tensor or sequence type.")
.SetContextDependentFunctionBodyBuilder(BuildSequenceMapBodyFunc)
.TypeAndShapeInferenceFunction(SequenceMapInferenceFunction));
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,550 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "shape_inference.h"
#include <vector>
#include "onnx/defs/tensor_proto_util.h"
namespace ONNX_NAMESPACE {
// Note: for all methods below for propagating type or shape, callers are
// responsible to handle optional inputs/outputs and ensure that the specified
// index value is less than NumInputs/NumOutputs.
// Supports mixed tensor and sparse tensor
void propagateElemTypeFromTensorInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
auto input_type = ctx.getInputType(inputIndex);
if (nullptr == input_type) {
fail_type_inference("Input type was null");
}
const auto input_value_case = input_type->value_case();
if (input_value_case != TypeProto::kTensorType && input_value_case != TypeProto::kSparseTensorType) {
fail_type_inference(
"Input ", inputIndex, " expected to have tensor or sparse tensor type. Got: ", input_value_case);
}
const auto input_elem_type = getTensorElementType(*input_type);
if (input_elem_type == TensorProto::UNDEFINED) {
fail_type_inference("Element type of input ", inputIndex, " unknown");
}
auto output_type = ctx.getOutputType(outputIndex);
const auto output_value_case = output_type->value_case();
if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) {
setTensorElementType(input_elem_type, output_value_case, *output_type);
} else if (output_value_case == TypeProto::VALUE_NOT_SET) {
// Assume output will have the same type
setTensorElementType(input_elem_type, input_value_case, *output_type);
} else {
// This is not expected to happen
fail_type_inference(
"Output ", outputIndex, " expected to have tensor or sparse tensor type. Got: ", output_value_case);
}
}
void propagateElemTypeFromSequenceInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
auto input_type = ctx.getInputType(inputIndex);
if (nullptr == input_type || input_type->value_case() != TypeProto::kSequenceType) {
fail_type_inference("Input ", inputIndex, " expected to have sequence type");
}
auto input_seq_type = input_type->sequence_type();
if (!input_seq_type.has_elem_type()) {
fail_type_inference("Element type of sequence input ", inputIndex, " unknown");
}
auto output_type = ctx.getOutputType(outputIndex);
output_type->mutable_sequence_type()->mutable_elem_type()->CopyFrom(input_seq_type.elem_type());
}
void propagateElemTypeFromOptionalInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
auto input_type = ctx.getInputType(inputIndex);
if (nullptr == input_type || input_type->value_case() != TypeProto::kOptionalType) {
fail_type_inference("Input ", inputIndex, " expected to have optional type");
}
auto input_opt_type = input_type->optional_type();
if (!input_opt_type.has_elem_type()) {
fail_type_inference("Element type of optional input ", inputIndex, " unknown");
}
auto output_type = ctx.getOutputType(outputIndex);
output_type->mutable_optional_type()->mutable_elem_type()->CopyFrom(input_opt_type.elem_type());
}
void propagateElemTypeFromMapInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
auto input_type = ctx.getInputType(inputIndex);
if (nullptr == input_type || input_type->value_case() != TypeProto::kMapType) {
fail_type_inference("Input ", inputIndex, " expected to have map type");
}
auto input_map_type = input_type->map_type();
if (!input_map_type.has_key_type()) {
fail_type_inference("Key type of map input ", inputIndex, " unknown");
}
if (!input_map_type.has_value_type()) {
fail_type_inference("Value type of map input ", inputIndex, " unknown");
}
auto output_type = ctx.getOutputType(outputIndex);
output_type->mutable_map_type()->set_key_type(input_map_type.key_type());
output_type->mutable_map_type()->mutable_value_type()->CopyFrom(input_map_type.value_type());
}
void propagateElemTypeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
auto input_type = ctx.getInputType(inputIndex);
if (nullptr == input_type) {
fail_type_inference("Input ", inputIndex, " expected to have type but instead is null");
}
const auto input_value_case = input_type->value_case();
if (input_value_case == TypeProto::kTensorType || input_value_case == TypeProto::kSparseTensorType) {
propagateElemTypeFromTensorInputToOutput(ctx, inputIndex, outputIndex);
} else if (input_value_case == TypeProto::kSequenceType) {
propagateElemTypeFromSequenceInputToOutput(ctx, inputIndex, outputIndex);
} else if (input_value_case == TypeProto::kOptionalType) {
propagateElemTypeFromOptionalInputToOutput(ctx, inputIndex, outputIndex);
} else if (input_value_case == TypeProto::kMapType) {
propagateElemTypeFromMapInputToOutput(ctx, inputIndex, outputIndex);
}
}
/*
Merge shape information from a source shape into a target shape.
* merges each TensorShapeProto_Dimension separately.
* prefer values over params.
* If both have values, values must match.
* prefer target param over source param if mismatched.
* Fail if there are mismatches in number of dimensions or dimension values.
*/
void mergeInShapeInfo(const TensorShapeProto& source, TensorShapeProto& target) {
auto num_source_dims = source.dim_size();
auto num_target_dims = target.dim_size();
if (num_source_dims != num_target_dims) {
fail_shape_inference(
"Mismatch between number of inferred and declared dimensions. inferred=",
num_source_dims,
" declared=",
num_target_dims);
}
auto& source_dims = source.dim();
auto* target_dims = target.mutable_dim();
for (int i = 0, end = source_dims.size(); i < end; ++i) {
auto& source_dim = source_dims.Get(i);
auto& target_dim = *target_dims->Mutable(i);
mergeInDimensionInfo(source_dim, target_dim, i);
}
}
void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type) {
if (target_type.has_shape()) {
// merge with existing info.
mergeInShapeInfo(source_shape, *target_type.mutable_shape());
} else {
// copy to target
(*target_type.mutable_shape()) = source_shape;
}
}
void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type) {
if (target_type.has_shape()) {
// merge with existing info.
mergeInShapeInfo(source_shape, *target_type.mutable_shape());
} else {
// copy to target
(*target_type.mutable_shape()) = source_shape;
}
}
/*
Merge the shape information from two TypeProto_Tensor instances.
Values are merged into target from source.
If target has no shape information, copy from source.
If source has no shape information, ignore source.
If both have shape information:
- merge each TensorShapeProto_Dimension separately.
- Prefer values over params. If both have values, values must match.
- Prefer target param over source param if mismatched.
Fail if there are mismatches in number of dimensions or dimension values.
*/
void mergeInShapeInfo(const TypeProto_Tensor& source, TypeProto_Tensor& target) {
if (source.has_shape())
mergeInShapeInfo(source.shape(), target);
}
void mergeInShapeInfo(const TypeProto_SparseTensor& source, TypeProto_SparseTensor& target) {
if (source.has_shape())
mergeInShapeInfo(source.shape(), target);
}
/// <summary>
/// Utility function for UnionShapeInfoForTensor.
/// Both shapes must be of the same rank
/// </summary>
/// <param name="source_shape"></param>
/// <param name="target_shape">destination shape</param>
void UnionShapeInfo(const TensorShapeProto& source_shape, TensorShapeProto& target_shape) {
auto source_rank = source_shape.dim_size();
for (int i = 0; i < source_rank; ++i) {
const auto source_dim = source_shape.dim(i);
const auto target_dim = target_shape.dim(i);
bool is_dims_conflict = [&]() {
if (source_dim.has_dim_value()) {
if (target_dim.has_dim_value() && target_dim.dim_value() == source_dim.dim_value()) {
return false;
}
return true;
}
if (source_dim.has_dim_param()) {
if (target_dim.has_dim_param() && target_dim.dim_param() == source_dim.dim_param()) {
return false;
}
return true;
}
return (target_dim.has_dim_value() || target_dim.has_dim_param());
}();
if (is_dims_conflict && (target_dim.has_dim_value() || target_dim.has_dim_param())) {
auto dim = target_shape.mutable_dim(i);
dim->clear_dim_value();
dim->clear_dim_param();
}
}
}
template <typename TENSOR_TYPE>
void UnionShapeInfoForTensor(const TensorShapeProto& source_shape, TENSOR_TYPE& target_type) {
if (target_type.has_shape()) {
TensorShapeProto* target_shape = target_type.mutable_shape();
auto source_rank = source_shape.dim_size();
auto target_rank = target_shape->dim_size();
if (source_rank != target_rank) {
target_type.clear_shape();
return;
}
UnionShapeInfo(source_shape, *target_shape);
}
}
void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type) {
UnionShapeInfoForTensor(source_shape, target_type);
}
void UnionShapeInfo(const TypeProto_Tensor& source_type, TypeProto_Tensor& target_type) {
// The union of a tensor of unknown rank and a tensor of known rank is a tensor of unknown rank.
// Hence, if the source_type had unknown rank, we clear the shape of the target_type.
// Otherwise, UnionShapeInfoForTensor handles the rest.
if (source_type.has_shape()) {
UnionShapeInfoForTensor(source_type.shape(), target_type);
} else {
target_type.clear_shape();
}
}
void UnionShapeInfo(const TypeProto_SparseTensor& source_type, TypeProto_SparseTensor& target_type) {
// The union of a tensor of unknown rank and a tensor of known rank is a tensor of unknown rank.
// Hence, if the source_type had unknown rank, we clear the shape of the target_type.
// Otherwise, UnionShapeInfoForTensor handles the rest.
if (source_type.has_shape()) {
UnionShapeInfoForTensor(source_type.shape(), target_type);
} else {
target_type.clear_shape();
}
}
void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type) {
UnionShapeInfoForTensor(source_shape, target_type);
}
void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type) {
if (source_type.value_case() != target_type.value_case()) {
fail_type_inference(
"Mismatched type:", " inferred=", source_type.value_case(), " declared=", target_type.value_case());
}
const auto target_case = target_type.value_case();
if (target_case == TypeProto::ValueCase::kTensorType) {
auto source_elem_type = source_type.tensor_type().elem_type();
auto target_elem_type = target_type.tensor_type().elem_type();
if (source_elem_type != target_elem_type) {
fail_type_inference(
"Mismatched tensor element type:",
" inferred=",
Utils::DataTypeUtils::ToDataTypeString(source_elem_type),
" declared=",
Utils::DataTypeUtils::ToDataTypeString(target_elem_type));
}
UnionShapeInfo(source_type.tensor_type(), *target_type.mutable_tensor_type());
} else if (target_case == TypeProto::ValueCase::kSparseTensorType) {
auto source_elem_type = source_type.sparse_tensor_type().elem_type();
auto target_elem_type = target_type.sparse_tensor_type().elem_type();
if (source_elem_type != target_elem_type) {
fail_type_inference(
"Mismatched sparse tensor element type:",
" inferred=",
Utils::DataTypeUtils::ToDataTypeString(source_elem_type),
" declared=",
Utils::DataTypeUtils::ToDataTypeString(target_elem_type));
}
UnionShapeInfo(source_type.sparse_tensor_type(), *target_type.mutable_sparse_tensor_type());
} else if (target_case == TypeProto::ValueCase::kSequenceType) {
if (!source_type.sequence_type().has_elem_type()) {
fail_type_inference("source sequence type missing element type.");
}
if (!target_type.sequence_type().has_elem_type()) {
fail_type_inference("target sequence type missing element type.");
}
UnionTypeInfo(source_type.sequence_type().elem_type(), *target_type.mutable_sequence_type()->mutable_elem_type());
} else if (target_case == TypeProto::ValueCase::kOptionalType) {
if (!source_type.optional_type().has_elem_type()) {
fail_type_inference("source optional type missing element type.");
}
if (!target_type.optional_type().has_elem_type()) {
fail_type_inference("target optional type missing element type.");
}
UnionTypeInfo(source_type.optional_type().elem_type(), *target_type.mutable_optional_type()->mutable_elem_type());
} else if (target_case == TypeProto::ValueCase::kMapType) {
if (!source_type.map_type().has_key_type()) {
fail_type_inference("source map type missing key type.");
}
if (!target_type.map_type().has_key_type()) {
fail_type_inference("target map type missing key type.");
}
auto source_key_type = source_type.map_type().key_type();
auto target_key_type = target_type.map_type().key_type();
if (source_key_type != target_key_type) {
fail_type_inference(
"Mismatched map tensor key type:",
" inferred=",
Utils::DataTypeUtils::ToDataTypeString(source_key_type),
" declared=",
Utils::DataTypeUtils::ToDataTypeString(target_key_type));
}
if (!source_type.map_type().has_value_type()) {
fail_type_inference("source map type missing value type.");
}
if (!target_type.map_type().has_value_type()) {
fail_type_inference("target map type missing value type.");
}
UnionTypeInfo(source_type.map_type().value_type(), *target_type.mutable_map_type()->mutable_value_type());
}
}
// Supports both Tensor and SparseTensor
// This does not fail if input_type is Tensor and output type is SparseTensor
// or the other way around. This is to support mixed cases when an op receives
// sparse input and outputs dense or vice-versa.
// If the output value_case is not set, then
// the input value_case is propagated.
void propagateTensorElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) {
if (nullptr == input_type) {
fail_type_inference("Input type was null");
}
int32_t input_elem_type = TensorProto::UNDEFINED;
const auto input_value_case = input_type->value_case();
if (input_value_case == TypeProto::kTensorType || input_value_case == TypeProto::kSparseTensorType) {
input_elem_type = getTensorElementType(*input_type);
if (input_elem_type == TensorProto::UNDEFINED) {
fail_type_inference("Element type of tensor or sparse tensor input was unknown");
}
} else {
fail_type_inference("Input was expected to have tensor or sparse tensor type. Got ", input_value_case);
}
const auto output_value_case = output_type->value_case();
if (output_value_case == TypeProto::VALUE_NOT_SET) {
setTensorElementType(input_elem_type, input_value_case, *output_type);
} else if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) {
const auto output_elem_type = getTensorElementType(*output_type);
if (output_elem_type != TensorProto::UNDEFINED) {
if (input_elem_type != output_elem_type) {
fail_type_inference(
"Input element type of ", input_elem_type, " does not match existing output type of ", output_elem_type);
}
} else {
setTensorElementType(input_elem_type, output_value_case, *output_type);
}
} else {
// This is not expected to happen
fail_type_inference("Output was expected to have tensor type. Got ", output_value_case);
}
}
void propagateSequenceElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) {
if (nullptr == input_type) {
fail_type_inference("Input type was null");
}
if (input_type->value_case() != TypeProto::kSequenceType) {
fail_type_inference("Input was expected to have sequence type. Got ", input_type->value_case());
}
auto input_seq_type = input_type->sequence_type();
if (input_seq_type.has_elem_type()) {
propagateElemTypeWithValidation(
&input_seq_type.elem_type(), output_type->mutable_sequence_type()->mutable_elem_type());
} else {
fail_type_inference("Element type of sequence input was unknown");
}
}
void propagateOptionalElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) {
if (nullptr == input_type) {
fail_type_inference("Input type was null");
}
if (input_type->value_case() != TypeProto::kOptionalType) {
fail_type_inference("Input was expected to have optional type. Got ", input_type->value_case());
}
auto input_opt_type = input_type->optional_type();
if (input_opt_type.has_elem_type()) {
propagateElemTypeWithValidation(
&input_opt_type.elem_type(), output_type->mutable_optional_type()->mutable_elem_type());
} else {
fail_type_inference("Element type of optional input was unknown");
}
}
void propagateMapElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) {
if (nullptr == input_type) {
fail_type_inference("Input type was null");
}
if (input_type->value_case() != TypeProto::kMapType) {
fail_type_inference("Input was expected to have map type. Got ", input_type->value_case());
}
auto input_map_type = input_type->map_type();
if (!input_map_type.has_key_type()) {
fail_type_inference("Key type of map input was unknown");
}
if (!input_map_type.has_value_type()) {
fail_type_inference("Value type of map input was unknown");
}
output_type->mutable_map_type()->set_key_type(input_map_type.key_type());
propagateElemTypeWithValidation(&input_map_type.value_type(), output_type->mutable_map_type()->mutable_value_type());
}
// propagate the element type from an input type to an output type.
// if an existing output element type exists, validate it matches.
void propagateElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type) {
if (nullptr == input_type) {
fail_type_inference("Input type was null");
}
const auto input_value_case = input_type->value_case();
if (input_value_case == TypeProto::kTensorType || input_value_case == TypeProto::kSparseTensorType) {
propagateTensorElemTypeWithValidation(input_type, output_type);
} else if (input_value_case == TypeProto::kSequenceType) {
propagateSequenceElemTypeWithValidation(input_type, output_type);
} else if (input_value_case == TypeProto::kOptionalType) {
propagateOptionalElemTypeWithValidation(input_type, output_type);
} else if (input_value_case == TypeProto::kMapType) {
propagateMapElemTypeWithValidation(input_type, output_type);
} else {
fail_type_inference(
"Input was expected to have either tensor, sequence, optional or map type. Got ", input_value_case);
}
}
TensorShapeProto getShapeInput(const InferenceContext& ctx, size_t input_index, bool& found) {
TensorShapeProto shape_input;
// First, check initializer.
const TensorProto* shape_initializer = ctx.getInputData(input_index);
if (shape_initializer) {
const std::vector<int64_t>& shape_data = ParseData<int64_t>(shape_initializer);
for (const int64_t& e : shape_data) {
shape_input.add_dim()->set_dim_value(e);
}
found = true;
return shape_input;
}
// Then, check symbolic input.
const TensorShapeProto* symbolic_input = ctx.getSymbolicInput(input_index);
if (symbolic_input) {
shape_input.CopyFrom(*symbolic_input);
found = true;
return shape_input;
}
// Try rank inference.
if (hasInputShape(ctx, input_index)) {
const TensorShapeProto& shape_input_shape = getInputShape(ctx, input_index);
if (shape_input_shape.dim_size() != 1) {
fail_shape_inference("shape input must be 1D tensor");
}
if (shape_input_shape.dim(0).has_dim_value()) {
// Attempt rank inference using shape of shape input
int64_t dim_value = shape_input_shape.dim(0).dim_value();
for (int64_t i = 0; i < dim_value; ++i) {
shape_input.add_dim();
}
found = true;
return shape_input;
}
}
// Shape input was not found.
found = false;
return shape_input;
}
template <typename Container>
std::string stringify(const Container& elements) {
std::stringstream ss;
for (const auto& element : elements) {
ss << element << ", ";
}
return ss.str();
}
std::pair<int, int> getAttributeProtoElemTypeAndLength(const AttributeProto* attr_proto) {
if (attr_proto->ints_size()) {
return {TensorProto_DataType_INT64, attr_proto->ints_size()};
} else if (attr_proto->floats_size()) {
return {TensorProto_DataType_FLOAT, attr_proto->floats_size()};
} else if (attr_proto->strings_size()) {
return {TensorProto_DataType_STRING, attr_proto->strings_size()};
} else if (attr_proto->has_t()) {
if (attr_proto->t().dims_size() != 1) {
fail_type_inference(
"Attribute ", attr_proto->name(), " expected to be a 1D tensor but was ", attr_proto->t().dims_size(), "D");
}
return {attr_proto->t().data_type(), attr_proto->t().dims(0)};
}
return {TensorProto::UNDEFINED, 0};
}
std::pair<int, int> getAttributeElementTypeAndLength(
const InferenceContext& ctx,
const std::initializer_list<std::string>& attribute_names) {
// Get element type and lengths of 1D attribute lists
int32_t elem_type = TensorProto::UNDEFINED;
int32_t length = 0;
for (const auto& attribute : attribute_names) {
const AttributeProto* attr_proto = ctx.getAttribute(attribute);
if (attr_proto != nullptr) {
if (elem_type != TensorProto::UNDEFINED) {
// Another attribute was already set
fail_shape_inference("One and only one attribute must be set out of ", stringify(attribute_names));
}
std::tie(elem_type, length) = getAttributeProtoElemTypeAndLength(attr_proto);
}
}
return {elem_type, length};
}
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,920 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <algorithm>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "onnx/defs/data_type_utils.h"
#include "onnx/proto_utils.h"
#include "onnx/string_utils.h"
namespace ONNX_NAMESPACE {
using Dim = TensorShapeProto_Dimension;
struct ShapeInferenceOptions {
// Checks the type-equality for input and output
bool check_type;
// 1: Will throw any node level shape infer errors
// 0: Won't throw node-level shape infer errors, but other errors
// like merging existing shape with inferred etc are thrown
int error_mode;
// Enables data propagation for limited operators
// to perform shape computation
bool enable_data_propagation;
ShapeInferenceOptions(bool check_type_val = false, int strict_mode_val = 0, bool data_prop_val = false)
: check_type(check_type_val), error_mode(strict_mode_val), enable_data_propagation(data_prop_val){};
};
// Maintains a SymbolTable for symbolic shape inference
class SymbolTable {
public:
// Adds existing symbols from a main graph or subgraph
virtual void addFromGraph(const GraphProto& g) = 0;
// Creates a new symbol which is not duplicate as any existing one
std::string createNew() {
return createNew("unk__");
}
virtual std::string createNew(const std::string& symbol_prefix) = 0;
virtual ~SymbolTable() = default;
};
class GraphInferencer {
public:
// Perform inferencing on the graph contained in GraphInferencer.
// Returns the graph output types post-inferencing.
virtual std::vector<const TypeProto*> doInferencing(
const std::vector<const TypeProto*>& inputTypes,
const std::vector<const TensorProto*>& inputData) = 0;
virtual ~GraphInferencer() = default;
};
// Exception class used for handling errors in type and shape inference
class InferenceError final : public std::runtime_error {
public:
using std::runtime_error::runtime_error;
InferenceError(const std::string& message) : std::runtime_error(message) {}
const char* what() const noexcept override {
if (!expanded_message_.empty()) {
return expanded_message_.c_str();
}
return std::runtime_error::what();
}
void AppendContext(const std::string& context) {
expanded_message_ = ONNX_NAMESPACE::MakeString(std::runtime_error::what(), "\n\n==> Context: ", context);
}
private:
std::string expanded_message_;
};
#define fail_type_inference(...) \
ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[TypeInferenceError] ", __VA_ARGS__)));
#define fail_shape_inference(...) \
ONNX_THROW_EX(ONNX_NAMESPACE::InferenceError(ONNX_NAMESPACE::MakeString("[ShapeInferenceError] ", __VA_ARGS__)));
struct InferenceContext {
virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
virtual size_t getNumInputs() const = 0;
virtual const TypeProto* getInputType(size_t index) const = 0;
virtual bool hasInput(size_t index) const {
// The default implementation below is used for backward-compatibility
// for implementations of InferenceContext that don't provide an explicit
// implementation. This works for normal usage, but may be imprecise in
// the edge-case where an input is supplied but has no known type.
// However, inference-methods work only under the assumption that the
// input-types of all inputs are known.
return ((index < getNumInputs()) && (getInputType(index) != nullptr));
}
virtual const TensorProto* getInputData(size_t index) const = 0;
virtual size_t getNumOutputs() const = 0;
virtual TypeProto* getOutputType(size_t index) = 0;
virtual GraphInferencer* getGraphAttributeInferencer(const std::string& attribute_name) = 0;
virtual ~InferenceContext() {}
virtual const SparseTensorProto* getInputSparseData(size_t index) const = 0;
// Gets the shape inputs computed by partial data propagation.
virtual const TensorShapeProto* getSymbolicInput(size_t index) const = 0;
// To display a name the user can use to narrow its search.
virtual std::string getDisplayName() const {
return "";
}
};
// We use data propagation to perform partial evaluation of the model, to compute statically
// known information about tensor values. It is intended to improve the precision of shape
// inference. We reuse TensorShapeProto to represent the statically known values. One
// limitation of this is that TensorShapeProto can represent only integer values.
// As an example, data-propagation is intended to handle code-fragments like below:
// shape = Shape(X)
// batchsize = Slice(shape, [0], [1])
// newshape = Concat (batchsize, [1024, 1024])
// Z = Reshape(Y, newshape)
// If the shape of X is statically known, then data-propagation should be able to determine
// the value of newshape, as well as the shape of Z.
struct DataPropagationContext {
virtual const AttributeProto* getAttribute(const std::string& name) const = 0;
virtual size_t getNumInputs() const = 0;
virtual const TypeProto* getInputType(size_t index) const = 0;
virtual size_t getNumOutputs() const = 0;
virtual const TypeProto* getOutputType(size_t index) const = 0;
virtual ~DataPropagationContext() {}
virtual const TensorShapeProto* getInputData(size_t index) = 0;
virtual void addOutputData(size_t index, TensorShapeProto&& tp) = 0;
};
using InferenceFunction = std::function<void(InferenceContext&)>;
using DataPropagationFunction = std::function<void(DataPropagationContext&)>;
// This no-op inference function is used for operators without an
// inference implementation.
inline void dummyInferenceFunction(InferenceContext&){};
// This no-op data propagation function is used for operators without a defined data propagator
inline void dummyDataPropagationFunction(DataPropagationContext&){};
template <typename T>
inline bool getRepeatedAttribute(InferenceContext& ctx, std::string attr_name, std::vector<T>& values) {
const auto* attr = ctx.getAttribute(attr_name);
if (attr) {
values = RetrieveValues<T>(*attr);
return true;
} else {
return false;
}
}
inline int64_t getAttribute(InferenceContext& ctx, const std::string& attributeName, int64_t defaultValue) {
auto attr_proto = ctx.getAttribute(attributeName);
if ((nullptr != attr_proto) && attr_proto->has_i())
return attr_proto->i();
return defaultValue;
}
inline int64_t getAttribute(DataPropagationContext& ctx, const std::string& attributeName, int64_t defaultValue) {
auto attr_proto = ctx.getAttribute(attributeName);
if ((nullptr != attr_proto) && attr_proto->has_i())
return attr_proto->i();
return defaultValue;
}
inline std::string
getAttribute(InferenceContext& ctx, const std::string& attributeName, const std::string& defaultValue) {
auto attr_proto = ctx.getAttribute(attributeName);
if ((nullptr != attr_proto) && attr_proto->has_s())
return attr_proto->s();
return defaultValue;
}
inline TensorShapeProto::Dimension operator*(TensorShapeProto::Dimension dim1, TensorShapeProto::Dimension dim2) {
TensorShapeProto::Dimension result;
if (dim1.has_dim_value() && dim2.has_dim_value()) {
result.set_dim_value(dim1.dim_value() * dim2.dim_value());
} else if (dim1.has_dim_value() && (dim1.dim_value() == 1)) {
return dim2;
} else if (dim2.has_dim_value() && (dim2.dim_value() == 1)) {
return dim1;
}
return result;
}
template <typename Container>
std::string stringify(const Container& elements);
std::pair<int, int> getAttributeProtoElemTypeAndLength(const AttributeProto* attr_proto);
std::pair<int, int> getAttributeElementTypeAndLength(
const InferenceContext& ctx,
const std::initializer_list<std::string>& attribute_names);
inline TensorShapeProto::Dimension operator*(TensorShapeProto::Dimension dim1, int64_t dim2) {
TensorShapeProto::Dimension result;
if (dim1.has_dim_value()) {
result.set_dim_value(dim1.dim_value() * dim2);
} else if (dim2 == 1) {
return dim1;
}
return result;
}
inline TensorShapeProto::Dimension operator/(TensorShapeProto::Dimension dim1, int64_t dim2) {
TensorShapeProto::Dimension result;
if (dim1.has_dim_value()) {
result.set_dim_value(dim1.dim_value() / dim2);
} else if (dim2 == 1) {
return dim1;
}
return result;
}
// if from >= upto_exclusive, return 1.
// Caller must make sure upto_exclusive is less than or equal to shape.size()
// Caller must make sure from>=0
inline TensorShapeProto::Dimension multiplyDims(const TensorShapeProto& shape, int from, int upto_exclusive) {
TensorShapeProto::Dimension dim;
dim.set_dim_value(1);
for (int i = from; i < upto_exclusive; ++i) {
dim = dim * shape.dim(i);
}
return dim;
}
inline int32_t getTensorElementType(const TypeProto& type) {
int32_t result = TensorProto::UNDEFINED;
const auto value_case = type.value_case();
if (value_case == TypeProto::kTensorType) {
result = type.tensor_type().elem_type();
} else if (value_case == TypeProto::kSparseTensorType) {
result = type.sparse_tensor_type().elem_type();
}
return result;
}
inline void setTensorElementType(int32_t elem_type, TypeProto::ValueCase value_case, TypeProto& type) {
if (value_case == TypeProto::kTensorType) {
type.mutable_tensor_type()->set_elem_type(elem_type);
} else if (value_case == TypeProto::kSparseTensorType) {
type.mutable_sparse_tensor_type()->set_elem_type(elem_type);
}
}
void propagateElemTypeWithValidation(const TypeProto* input_type, TypeProto* output_type);
void propagateElemTypeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex);
void propagateElemTypeFromTensorInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex);
inline void propagateElemTypeFromDtypeToOutput(
InferenceContext& ctx,
const int data_type,
size_t outputIndex,
TypeProto::ValueCase expected_value_case) {
const auto attribute_tensor_datatype = data_type;
auto output_type = ctx.getOutputType(outputIndex);
const auto output_value_case = output_type->value_case();
if (output_value_case == TypeProto::VALUE_NOT_SET || output_value_case == expected_value_case) {
setTensorElementType(attribute_tensor_datatype, expected_value_case, *output_type);
} else {
// This is not expected to happen
fail_type_inference(
"Output ",
outputIndex,
" expected to have: ",
expected_value_case,
" or UNDEFINED. Got: ",
output_value_case,
" in ",
ctx.getDisplayName(),
".");
}
}
inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const int data_type, size_t outputIndex) {
propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, TypeProto::kTensorType);
}
inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const AttributeProto* attr, size_t outputIndex) {
int32_t data_type = TensorProto::UNDEFINED;
TypeProto::ValueCase expected_value_case = TypeProto::VALUE_NOT_SET;
const auto attr_type = attr->type();
if (attr_type == AttributeProto::TENSOR) {
if (attr->t().dims().size() != 1) {
fail_type_inference("Attribute expected to have a one-dim tensor in ", ctx.getDisplayName(), ".");
}
data_type = attr->t().data_type();
expected_value_case = TypeProto::kTensorType;
} else if (attr_type == AttributeProto::SPARSE_TENSOR) {
if (attr->sparse_tensor().dims().size() != 1) {
fail_type_inference("Attribute expected to have a one-dim sparse tensor in ", ctx.getDisplayName(), ".");
}
data_type = attr->sparse_tensor().values().data_type();
expected_value_case = TypeProto::kSparseTensorType;
} else {
fail_type_inference("Attribute expected to have tensor or sparse tensor type in ", ctx.getDisplayName(), ".");
}
propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, expected_value_case);
}
inline bool hasShape(const TypeProto& type) {
if (type.has_tensor_type()) {
return type.tensor_type().has_shape();
} else if (type.has_sparse_tensor_type()) {
return type.sparse_tensor_type().has_shape();
} else if (type.has_sequence_type() && type.sequence_type().has_elem_type()) {
return hasShape(type.sequence_type().elem_type());
} else if (type.has_optional_type() && type.optional_type().has_elem_type()) {
return hasShape(type.optional_type().elem_type());
}
return false;
}
template <typename Context>
inline bool hasInputShape(const Context& ctx, size_t n) {
return ctx.getNumInputs() > static_cast<size_t>(n) && ctx.getInputType(n) && hasShape(*ctx.getInputType(n));
}
template <typename Context>
inline bool hasNInputShapes(const Context& ctx, size_t n) {
for (size_t i = 0; i < n; i++) {
if (!hasInputShape(ctx, i)) {
return false;
}
}
return true;
}
inline const TensorShapeProto& getInputShape(const InferenceContext& ctx, size_t n) {
const auto* input_type = ctx.getInputType(n);
const auto value_case = input_type->value_case();
if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), ".");
}
if (!hasShape(*input_type)) {
fail_shape_inference("Input ", n, " must have a non null shape in ", ctx.getDisplayName(), ".");
}
if (value_case == TypeProto::kTensorType) {
return input_type->tensor_type().shape();
} else {
return input_type->sparse_tensor_type().shape();
}
}
inline const TensorShapeProto* getOptionalInputShape(InferenceContext& ctx, size_t n) {
const auto* input_type = ctx.getInputType(n);
if (input_type == nullptr) {
return nullptr;
}
const auto value_case = input_type->value_case();
if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), ".");
}
if (value_case == TypeProto::kTensorType) {
return &input_type->tensor_type().shape();
} else {
return &input_type->sparse_tensor_type().shape();
}
}
// Caller must make sure fromDimIndex is strictly less than shape.dim_size()
inline void appendSingleDimCopiedFromInputTypeToOutputType(
InferenceContext& ctx,
size_t inputIndex,
size_t outputIndex,
size_t fromDimIndex) {
auto output_type = ctx.getOutputType(outputIndex);
const auto output_value_case = output_type->value_case();
auto input_type = ctx.getInputType(inputIndex);
const auto input_value_case = input_type->value_case();
if (output_value_case != input_value_case) {
fail_type_inference(
"Input: ",
inputIndex,
" type: ",
input_value_case,
" does not match type of output: ",
outputIndex,
"type: ",
output_value_case,
" in ",
ctx.getDisplayName(),
".");
}
if (TypeProto::kTensorType == input_value_case) {
auto* dim = output_type->mutable_tensor_type()->mutable_shape()->add_dim();
*dim = input_type->tensor_type().shape().dim(static_cast<int>(fromDimIndex));
} else if (TypeProto::kSparseTensorType == input_value_case) {
auto* dim = output_type->mutable_sparse_tensor_type()->mutable_shape()->add_dim();
*dim = input_type->sparse_tensor_type().shape().dim(static_cast<int>(fromDimIndex));
} else {
fail_type_inference(
"Input ",
inputIndex,
" and Output ",
outputIndex,
" expected to have tensor or sparse tensor type in ",
ctx.getDisplayName(),
".");
}
}
inline void propagateShape(const TypeProto* from_type, TypeProto* to_type) {
const auto from_type_case = from_type->value_case();
const auto to_type_case = to_type->value_case();
if (from_type_case != to_type_case) {
fail_shape_inference(
"Mismatch between inferred and declared type. Inferred=", from_type_case, " Declared=", to_type_case);
}
if (TypeProto::kTensorType == from_type_case || TypeProto::kSparseTensorType == from_type_case) {
// If input shape is "unknown", the corresponding should be "unknown" too.
// The way to make output shape unknown is not to assign it any value.
if (hasShape(*from_type)) {
if (TypeProto::kTensorType == from_type_case) {
*to_type->mutable_tensor_type()->mutable_shape() = from_type->tensor_type().shape();
} else {
*to_type->mutable_sparse_tensor_type()->mutable_shape() = from_type->sparse_tensor_type().shape();
}
}
} else if (TypeProto::kSequenceType == from_type_case) {
propagateShape(&from_type->sequence_type().elem_type(), to_type->mutable_sequence_type()->mutable_elem_type());
} else if (TypeProto::kOptionalType == from_type_case) {
propagateShape(&from_type->optional_type().elem_type(), to_type->mutable_optional_type()->mutable_elem_type());
} else if (TypeProto::kMapType == from_type_case) {
propagateShape(&from_type->map_type().value_type(), to_type->mutable_map_type()->mutable_value_type());
} else {
fail_shape_inference("Unsupported Source/Target type=", from_type_case);
}
}
inline void propagateShapeFromInputToOutput(InferenceContext& ctx, size_t inputIndex, size_t outputIndex) {
auto output_type = ctx.getOutputType(outputIndex);
auto input_type = ctx.getInputType(inputIndex);
propagateShape(input_type, output_type);
}
inline void propagateShapeAndTypeFromFirstInput(InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 1)) {
return;
}
propagateShapeFromInputToOutput(ctx, 0, 0);
}
inline void
updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType, TypeProto::ValueCase expected_type) {
auto output_type = ctx.getOutputType(outputIndex);
if (output_type == nullptr) {
fail_type_inference("Output ", outputIndex, " is null");
}
if (output_type->value_case() == expected_type || output_type->value_case() == TypeProto::VALUE_NOT_SET) {
setTensorElementType(elemType, expected_type, *output_type);
} else {
// This is not expected to happen
fail_type_inference(
"Output ",
outputIndex,
" expected to have tensor or sparse tensor type: ",
expected_type,
" in ",
ctx.getDisplayName(),
".");
}
}
inline void updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType) {
updateOutputElemType(ctx, outputIndex, elemType, TypeProto::kTensorType);
}
// Infer type of an output from the value of a specified attribute, which is
// expected to have a valid value representing a TensorProto_DataType.
inline void propagateElemTypeFromAttributeToOutput(
InferenceContext& ctx,
const std::string& attributeName,
size_t outputIndex,
TypeProto::ValueCase expected_type,
TensorProto_DataType default_value = TensorProto::UNDEFINED) {
auto attr_proto = ctx.getAttribute(attributeName);
if (nullptr == attr_proto) { // attribute not present
if (default_value != TensorProto::UNDEFINED) {
updateOutputElemType(ctx, outputIndex, default_value, expected_type);
return;
} else {
fail_type_inference("Value of attribute ", attributeName, " not specified in ", ctx.getDisplayName(), ".");
}
}
if (!attr_proto->has_i()) {
fail_type_inference(
"Attribute ", attributeName, " should be of integer type and specify a type in ", ctx.getDisplayName(), ".");
}
auto attr_value = attr_proto->i();
auto elem_type = static_cast<TensorProto_DataType>(attr_value);
if (!TensorProto_DataType_IsValid(elem_type)) {
fail_type_inference("Attribute ", attributeName, " does not specify a valid type in ", ctx.getDisplayName(), ".");
}
updateOutputElemType(ctx, outputIndex, elem_type, expected_type);
}
inline void propagateElemTypeFromAttributeToOutput(
InferenceContext& ctx,
const std::string& attributeName,
size_t outputIndex,
TensorProto_DataType default_value = TensorProto::UNDEFINED) {
propagateElemTypeFromAttributeToOutput(ctx, attributeName, outputIndex, TypeProto::kTensorType, default_value);
}
inline TensorShapeProto* getTensorMutableShape(TypeProto::ValueCase value_case, TypeProto& type) {
if (value_case == TypeProto::kTensorType) {
return type.mutable_tensor_type()->mutable_shape();
} else if (value_case == TypeProto::kSparseTensorType) {
return type.mutable_tensor_type()->mutable_shape();
}
return nullptr;
}
inline TensorShapeProto*
getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType) {
auto output_type = ctx.getOutputType(n);
if (output_type == nullptr) {
fail_type_inference("Output ", n, " expected to have tensor or sparse type in ", ctx.getDisplayName(), ".");
}
const auto output_value_case = output_type->value_case();
if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) {
return getTensorMutableShape(output_value_case, *output_type);
} else if (output_value_case == TypeProto::VALUE_NOT_SET) {
return getTensorMutableShape(default_type, *output_type);
} else {
fail_type_inference("Output ", n, " expected to have tensor type in ", ctx.getDisplayName(), ".");
}
}
inline void appendDim(TensorShapeProto* shape, int64_t dim_value) {
shape->add_dim()->set_dim_value(dim_value);
}
inline void updateOutputShape(
InferenceContext& ctx,
size_t outputIndex,
const TensorShapeProto& shape,
TypeProto::ValueCase default_type = TypeProto::kTensorType) {
auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
*output_shape = shape;
}
inline void updateOutputShape(
InferenceContext& ctx,
size_t outputIndex,
const TensorProto& tensorProto,
TypeProto::ValueCase default_type = TypeProto::kTensorType) {
auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
for (auto d : tensorProto.dims()) {
auto* dim = output_shape->add_dim();
dim->set_dim_value(d);
}
}
inline void updateOutputShape(
InferenceContext& ctx,
size_t outputIndex,
std::initializer_list<TensorShapeProto::Dimension> dims,
TypeProto::ValueCase default_type = TypeProto::kTensorType) {
auto* output_shape = getOutputShape(ctx, outputIndex, default_type);
for (auto& d : dims) {
auto* dim = output_shape->add_dim();
*dim = d;
}
}
// Get shape input by first checking initializer and then propagated symbolic data.
// If neither is available, try rank inference.
// When one of above succeeds, `true` is stored in `found`.
// Otherwise, `false` is stored, which means that returned TensorShapeProto does not make sense.
TensorShapeProto getShapeInput(const InferenceContext& ctx, size_t input_index, bool& found);
// Infer shape of an output from the value of a specified attribute, which is
// expected to be a list of integers specifying a valid shape.
inline void propagateShapeFromAttributeToOutput(
InferenceContext& ctx,
const std::string& attributeName,
size_t outputIndex,
TypeProto::ValueCase default_type = TypeProto::kTensorType) {
auto attr_proto = ctx.getAttribute(attributeName);
if ((nullptr == attr_proto) || (!attr_proto->has_type()) ||
(attr_proto->type() != AttributeProto_AttributeType_INTS)) {
fail_shape_inference("Attribute ", attributeName, " should specify a shape in ", ctx.getDisplayName(), ".");
}
auto& int_list = attr_proto->ints();
TensorShapeProto shape;
for (auto dim_size : int_list) {
if (dim_size < 0) {
fail_shape_inference("Negative values are not allowed in a shape specification in ", ctx.getDisplayName(), ".");
}
shape.add_dim()->set_dim_value(dim_size);
}
updateOutputShape(ctx, outputIndex, shape, default_type);
}
inline void multidirectionalBroadcastShapeInference(
const std::vector<const TensorShapeProto*>& shapes,
TensorShapeProto& resultShape) {
int result_shape_size = 0;
// Get the result shape size.
for (size_t i = 0; i < shapes.size(); ++i) {
if (shapes[i]->dim_size() > result_shape_size) {
result_shape_size = shapes[i]->dim_size();
}
}
for (int i = 0; i < result_shape_size; ++i) {
int64_t dim_value = 1;
TensorShapeProto_Dimension symbolic_dim;
int num_symbolic_dims = 0;
for (size_t j = 0; j < shapes.size(); ++j) {
if (i < result_shape_size - shapes[j]->dim_size()) {
// Shape j will be filled with 1 at dimension i;
continue;
}
auto dim_i_j = shapes[j]->dim(i - result_shape_size + shapes[j]->dim_size());
if (dim_i_j.has_dim_value()) {
if (dim_i_j.dim_value() != 1) {
if (dim_value != dim_i_j.dim_value() && dim_value != 1) {
fail_shape_inference("Incompatible dimensions");
} else {
dim_value = dim_i_j.dim_value();
}
}
} else {
if (num_symbolic_dims == 0) {
symbolic_dim = dim_i_j;
++num_symbolic_dims;
} else if (dim_i_j.dim_param() != symbolic_dim.dim_param()) {
++num_symbolic_dims;
}
}
}
if (dim_value != 1 || num_symbolic_dims == 0) {
resultShape.add_dim()->set_dim_value(dim_value);
} else if (num_symbolic_dims == 1) {
*resultShape.add_dim() = symbolic_dim;
} else {
resultShape.add_dim();
}
}
}
inline void bidirectionalBroadcastShapeInference(
const TensorShapeProto& shapeL,
const TensorShapeProto& shapeR,
TensorShapeProto& resultShape) {
std::vector<const TensorShapeProto*> shapes;
shapes.push_back(&shapeL);
shapes.push_back(&shapeR);
multidirectionalBroadcastShapeInference(shapes, resultShape);
}
/*
Merge the dimension information from two TensorShapeProto_Dimension instances.
Values are merged into target from source.
If target has no dimension information, copy from source.
If source has no dimension information, ignore source.
If both have dimension information:
- Prefer values over params. If both have values, values must match.
- Prefer target param over source param if mismatched.
Fail if there are mismatches in dimension values.
Currently, there is no way to refine/update dimension information for the
source from information available in the target.
*/
inline void mergeInDimensionInfo(
const TensorShapeProto_Dimension& source_dim,
TensorShapeProto_Dimension& target_dim,
int dim_index) {
// if source has value, merge into target
// else if target has value, preserve it
// else merge params
if (source_dim.has_dim_value()) {
auto source_value = source_dim.dim_value();
if (target_dim.has_dim_value()) {
auto target_value = target_dim.dim_value();
if (target_value != source_value) {
fail_shape_inference(
"Can't merge shape info. "
"Both inferred and declared dimension have values but they differ. Inferred=",
source_value,
" Declared=",
target_value,
" Dimension=",
dim_index);
}
} else {
target_dim.set_dim_value(source_value);
}
} else if (target_dim.has_dim_value()) {
// if target has a value we preserve it so do nothing
} else if (target_dim.has_dim_param()) {
// prefer target param over source
} else if (source_dim.has_dim_param()) {
target_dim.set_dim_param(source_dim.dim_param());
}
}
void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type);
void mergeInShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type);
/*
Merge the shape information from two TypeProto_Tensor instances.
Values are merged into target from source.
If target has no shape information, copy from source.
If source has no shape information, ignore source.
If both have shape information:
- merge each TensorShapeProto_Dimension separately.
- Prefer values over params. If both have values, values must match.
- Prefer target param over source param if mismatched.
Fail if there are mismatches in number of dimensions or dimension values.
*/
void mergeInShapeInfo(const TypeProto_Tensor& source, TypeProto_Tensor& target);
void mergeInShapeInfo(const TypeProto_SparseTensor& source, TypeProto_SparseTensor& target);
// Return a copy of a type, with a specified dimension removed from its shape.
inline TypeProto RemoveIthDimensionFromShape(const TypeProto& proto, int removed_dim) {
TypeProto t(proto);
auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
mutable_shape->clear_dim();
const auto& dims = proto.tensor_type().shape().dim();
for (int j = 0, end = dims.size(); j < end; ++j) {
if (j != removed_dim)
(*mutable_shape->add_dim()) = dims.Get(j);
}
return t;
}
// Return a copy of a type, with specified number of dimensions removed from the
// beginning.
inline TypeProto RemoveDimensionsFromShape(const TypeProto& proto, int num_dimensions) {
TypeProto t(proto);
auto mutable_shape = t.mutable_tensor_type()->mutable_shape();
mutable_shape->clear_dim();
const auto& dims = proto.tensor_type().shape().dim();
// skip first num_dimensions
for (int j = num_dimensions, end = dims.size(); j < end; ++j) {
(*mutable_shape->add_dim()) = dims.Get(j);
}
return t;
}
// copied from GSL:
// https://github.com/microsoft/GSL/blob/main/include/gsl/util
template <class T, class U>
static constexpr T narrow_cast(U&& u) noexcept {
return static_cast<T>(std::forward<U>(u));
}
inline void checkInputRank(InferenceContext& ctx, size_t input_index, int expected_rank) {
// We check the rank only if a rank is known for the input:
if (hasInputShape(ctx, input_index)) {
auto rank = getInputShape(ctx, input_index).dim_size();
if (rank != expected_rank) {
fail_shape_inference(
"Input ",
input_index,
" expected to have rank ",
expected_rank,
" but has rank ",
rank,
" in ",
ctx.getDisplayName(),
".");
}
}
}
// Unification (between dimensions and/or shapes) is at the heart of
// shape-inference. The current inference algorithm can check input
// shapes/dimensions of a node and update the output shapes/dimensions. It
// cannot currently update input shapes and dimensions (even though in some
// contexts this inference is possible). Hence, we have the variants below to
// support "const" and "mutable" dimensions/shapes in unification.
inline void checkDimEquality(int64_t value1, int64_t value2) {
if (value1 != value2) {
fail_shape_inference("Dimension mismatch in unification between ", value1, " and ", value2);
}
}
inline void unifyDim(const Dim& dim1, const Dim& dim2) {
if (dim1.has_dim_value() && dim2.has_dim_value())
checkDimEquality(dim1.dim_value(), dim2.dim_value());
}
// TODO: The functionality of unifyDim is similar to that of
// mergeInDimensionInfo. However, the error messages are different. Leaving this
// duplication in-place to preserve error message content.
inline void unifyDim(const Dim& source_dim, Dim& target_dim) {
if (source_dim.has_dim_value()) {
auto source_value = source_dim.dim_value();
if (target_dim.has_dim_value()) {
auto target_value = target_dim.dim_value();
checkDimEquality(source_value, target_value);
} else {
target_dim.set_dim_value(source_value);
}
} else if (target_dim.has_dim_value()) {
// if target has a value we preserve it.
// we cannot set source dim value.
} else if (target_dim.has_dim_param()) {
// prefer target param over source
// we cannot currently unify the dim_params
} else if (source_dim.has_dim_param()) {
target_dim.set_dim_param(source_dim.dim_param());
}
}
inline void unifyInputDim(InferenceContext& ctx, size_t input_index, int dim_index, Dim& dim) {
// We unify the dimensions only if it is available for specified input:
if (hasInputShape(ctx, input_index)) {
auto& input_shape = getInputShape(ctx, input_index);
// This shape is expected to have rank > dim_index:
if (input_shape.dim_size() <= dim_index) {
fail_shape_inference(
"Input ",
input_index,
" expected to have rank >",
dim_index,
" but has rank ",
input_shape.dim_size(),
" in ",
ctx.getDisplayName(),
".");
}
const Dim& input_dim = input_shape.dim(dim_index);
// Now, unify dim and input_dim:
unifyDim(input_dim, dim);
}
}
// unifyDim: unifies a dimension with a constant value. If the dimension
// already has a value, we check for equality of new value with old value.
inline void unifyDim(Dim& dim, int64_t value) {
if (dim.has_dim_value()) {
checkDimEquality(dim.dim_value(), value);
} else
dim.set_dim_value(value);
}
// target-shape = Union (target-shape, source_shape)
// Example 1: same rank, different dimensions
// input1 shape: (2, 3, 4, 'x')
// input2 shape: (2, 'y', 5, 'x')
// output shape: (2, None, None, 'x')
// Example 2: different rank
// input1 shape: (2, 3, 4, 'x')
// input2 shape: (2, 3, 4)
// output shape: None
void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_Tensor& target_type);
void UnionShapeInfo(const TensorShapeProto& source_shape, TypeProto_SparseTensor& target_type);
// target-type = Union (target-type, source-type)
// target and source are required to have the same type.
// Example 1: same tensor type, different shape
// source: tensor elem_type: int64, shape: (2, 3, 4, 'x')
// target: tensor elem_type: int64, shape: (2, 'y', 5, 'x')
// output: tensor elem_type: int64, shape: (2, None, None, 'x')
// Example 2: same sequence type, different shape
// source: sequence of tensor, elem_type: float, shape: (2, 3, 4)
// target: sequence of tensor, elem_type: float, shape: None
// output: sequence of tensor, elem_type: float, shape: None
void UnionTypeInfo(const TypeProto& source_type, TypeProto& target_type);
// adjustNegativeAxes: Negative axes values are translated to the right axis in the positive range
template <typename Axes>
void adjustNegativeAxes(Axes& axes, int rank) {
std::transform(
axes.begin(), axes.end(), axes.begin(), [&](int64_t axis) -> int64_t { return axis < 0 ? axis + rank : axis; });
}
// checkAxesRange: Checks that values are within the range [-rank, rank)
template <typename Axes>
void checkAxesRange(Axes& axes, int rank) {
for (auto axis : axes) {
if (axis < -rank || axis > (rank - 1))
fail_shape_inference("Unexpected axis value: ", axis, ". Expected range [", -rank, ", ", rank, ")");
}
}
// checkDuplicateAxes: Check that there are no duplicated axes
template <typename Axes>
void checkDuplicateAxes(Axes& axes, int rank) {
std::vector<bool> tmp(rank, false);
for (auto axis : axes) {
int actual_axis = axis < 0 ? axis + rank : axis;
if (tmp[actual_axis])
fail_shape_inference("Axis ", axis, " is referred to more than once.");
tmp[actual_axis] = true;
}
}
} // namespace ONNX_NAMESPACE

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,518 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/tensor/utils.h"
#include <algorithm>
#include <limits>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
namespace ONNX_NAMESPACE {
void resizeShapeInferenceHelper(
const TensorShapeProto& input_shape,
const std::vector<int64_t>& sizes_data,
TensorShapeProto* output_shape) {
if (!sizes_data.empty()) {
for (int i = 0; i < input_shape.dim_size(); ++i) {
auto* dim = output_shape->mutable_dim(i);
if (sizes_data[i] > 0) {
dim->set_dim_value(sizes_data[i]);
}
}
return;
}
}
void KeepAspectRatioHelper(
KeepAspectRatioPolicy policy,
const TensorShapeProto& input_shape,
const std::vector<int64_t>& axes,
std::vector<int64_t>& sizes_data) {
if (policy != KeepAspectRatioPolicy::NOT_LARGER && policy != KeepAspectRatioPolicy::NOT_SMALLER) {
return;
}
float scale = policy == KeepAspectRatioPolicy::NOT_LARGER ? std::numeric_limits<float>::max()
: std::numeric_limits<float>::min();
std::function<float(float, float)> reduce_f;
if (policy == KeepAspectRatioPolicy::NOT_LARGER) {
reduce_f = [](float a, float b) { return std::min(a, b); };
} else {
reduce_f = [](float a, float b) { return std::max(a, b); };
}
bool has_unknown_dim = false;
for (size_t i = 0; i < sizes_data.size(); i++) {
int d = axes.empty() ? i : axes[i];
if (!input_shape.dim(d).has_dim_value()) {
has_unknown_dim = true;
break;
}
float s = sizes_data[i] / static_cast<float>(input_shape.dim(d).dim_value());
scale = reduce_f(scale, s);
}
// If there's at least one unknown dim we can't infer the output shape, since it
// will depend on the original aspect ratio of the input.
for (size_t i = 0; i < sizes_data.size(); i++) {
int d = axes.empty() ? i : axes[i];
sizes_data[i] = has_unknown_dim ? -1 : std::roundf(scale * input_shape.dim(d).dim_value());
}
}
void gridSampleShapeInference(InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
// If there is any input shape unknown, skip the shape inference.
if (!hasNInputShapes(ctx, 2)) {
return;
}
// Grid sample input tensor indices.
size_t const input_param = 0, grid_param = 1;
auto const& input_shape = getInputShape(ctx, input_param);
auto const& grid_shape = getInputShape(ctx, grid_param);
if (input_shape.dim_size() != grid_shape.dim_size()) {
fail_shape_inference(
"The input tensor and grid tensor must have the same rank for GridSample. ",
"Got input tensor rank: ",
input_shape.dim_size(),
". ",
"Got grid tensor rank: ",
grid_shape.dim_size(),
". ");
}
int const num_dims = input_shape.dim_size();
if (num_dims < 3) {
fail_shape_inference(
"The input tensor and grid tensor ranks must be >= 3. ",
"Got input tensor and grid tensor ranks: ",
num_dims,
". ");
}
auto const& last_dim = grid_shape.dim(num_dims - 1);
if (last_dim.has_dim_value() && (last_dim.dim_value() != num_dims - 2)) {
fail_shape_inference(
"The last dimension of the grid tensor must be the rank of the grid tensor - 2. ",
"Got grid tensor rank: ",
num_dims,
"Got the last dimension of the grid tensor: ",
last_dim.dim_value(),
". ");
}
auto* output_shape = getOutputShape(ctx, 0);
// N
Dim& N = *(output_shape->add_dim());
// The first call sets the dimension using the dimensions from input_shape.
unifyDim(input_shape.dim(0), N);
// The second call checks the dimension using the dimensions from grid_shape.
unifyDim(grid_shape.dim(0), N);
// C
Dim& C = *(output_shape->add_dim());
unifyDim(input_shape.dim(1), C);
// Other Dimensions.
for (int i = 0; i < num_dims - 2; ++i) {
Dim& D = *(output_shape->add_dim());
unifyDim(grid_shape.dim(1 + i), D);
}
}
void resizeShapeInferenceHelper(
const TensorShapeProto& input_shape,
const std::vector<float>& scales_data,
TensorShapeProto* output_shape) {
for (int i = 0; i < input_shape.dim_size(); ++i) {
auto* dim = output_shape->mutable_dim(i);
// If input_shape has dim_value, we calculate the scaled result
// If input_shape doesn's have one, we leave it here
if (input_shape.dim(i).has_dim_value()) {
int64_t dim_value =
static_cast<int64_t>(std::floor(static_cast<float>(input_shape.dim(i).dim_value()) * scales_data[i]));
// If output_shape has dim_value, we validate the caculated result
// If output_shape doesn's have one, we set it to the scaled result
if (dim->has_dim_value()) {
if (static_cast<int64_t>(dim->dim_value()) != dim_value) {
fail_shape_inference(
"Dimension value inferred (",
dim_value,
") is not equal to the existing dim value (",
dim->dim_value(),
").");
}
} else {
dim->set_dim_value(static_cast<int64_t>(dim_value));
} // dim->has_dim_value()
} // input_shape.dim(i).has_dim_value()
}
}
void resizeShapeInferenceVersioned(InferenceContext& ctx, int opset_version) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 1)) {
return;
}
const auto& input_shape = getInputShape(ctx, 0);
auto* output_shape = getOutputShape(ctx, 0);
bool hasScalesInput = ctx.hasInput(2);
bool hasSizesInput = ctx.hasInput(3);
const TensorProto* scales = 2 < ctx.getNumInputs() ? ctx.getInputData(2) : nullptr;
std::vector<int64_t> sizes_data;
if (3 < ctx.getNumInputs()) {
bool found_sizes = false;
const auto sizes_shape = getShapeInput(ctx, 3, found_sizes);
// If sizes is an empty shape, assume it's not provided
if (found_sizes) {
if (sizes_shape.dim_size() == 0) {
hasSizesInput = false;
} else {
for (int i = 0; i < sizes_shape.dim_size(); ++i) {
sizes_data.push_back(sizes_shape.dim(i).dim_value());
}
}
}
}
// If scales is an empty constant, assume it's not provided
if (scales && ParseData<float>(scales).empty()) {
hasScalesInput = false;
scales = nullptr;
}
if (opset_version >= 13) {
if (hasScalesInput + hasSizesInput != 1) {
fail_shape_inference("Either `sizes` or `scales` must be provided, but not both of them");
}
}
auto keep_aspect_ratio_policy_attr = ctx.getAttribute("keep_aspect_ratio_policy");
KeepAspectRatioPolicy keep_aspect_ratio_policy = KeepAspectRatioPolicy::STRETCH;
if (keep_aspect_ratio_policy_attr && keep_aspect_ratio_policy_attr->has_s()) {
auto str = keep_aspect_ratio_policy_attr->s();
if (str == "stretch") {
keep_aspect_ratio_policy = KeepAspectRatioPolicy::STRETCH;
} else if (str == "not_larger") {
keep_aspect_ratio_policy = KeepAspectRatioPolicy::NOT_LARGER;
} else if (str == "not_smaller") {
keep_aspect_ratio_policy = KeepAspectRatioPolicy::NOT_SMALLER;
} else {
fail_shape_inference("Unknown value for `keep_aspect_ratio_policy`: ", str, ".");
}
}
if (hasScalesInput && keep_aspect_ratio_policy != KeepAspectRatioPolicy::STRETCH) {
fail_shape_inference(
"Providing `scales` is incompatible with a `keep_aspect_ratio_policy` other than \"stretch\".");
}
if (output_shape->dim_size() > 0) {
if (output_shape->dim_size() != input_shape.dim_size()) {
fail_shape_inference(
"Ranks inferred (",
input_shape.dim_size(),
") is not equal to the existing rank value (",
output_shape->dim_size(),
").");
}
} else { // Infer the rank of output anyway
for (int i = 0; i < input_shape.dim_size(); ++i) {
output_shape->add_dim();
}
}
auto axes_attr = ctx.getAttribute("axes");
size_t rank_x = input_shape.dim_size();
std::vector<int64_t> axes;
if (axes_attr) {
axes = RetrieveValues<int64_t>(*axes_attr);
checkAxesRange(axes, rank_x);
adjustNegativeAxes(axes, rank_x);
checkDuplicateAxes(axes, rank_x);
}
if (hasSizesInput) {
if (!axes.empty()) {
if (sizes_data.size() != axes.size()) {
fail_shape_inference(
"Number of elements of input 'sizes' (",
sizes_data.size(),
") does not match the number of axes (",
axes.size(),
").");
}
} else {
// sizes_data contains scales for all axes
if (sizes_data.size() != rank_x) {
fail_shape_inference(
"Number of elements of input 'sizes' (",
sizes_data.size(),
") must be same as rank of input 'X' (",
rank_x,
").");
}
}
// Process sizes_data according to the selected policy
KeepAspectRatioHelper(keep_aspect_ratio_policy, input_shape, axes, sizes_data);
// If axes subset is provided, populate new sizes_data with all dims
if (!axes.empty()) {
std::vector<int64_t> tmp(rank_x);
for (size_t i = 0; i < rank_x; i++) {
tmp[i] = input_shape.dim(i).has_dim_value() ? input_shape.dim(i).dim_value() : -1;
}
for (size_t i = 0; i < axes.size(); i++) {
int d = axes[i];
tmp[d] = sizes_data[i];
}
std::swap(tmp, sizes_data);
}
resizeShapeInferenceHelper(input_shape, sizes_data, output_shape);
} else if (nullptr != scales) {
// Infer output shape's dimension value if 'scales' is known.
if (scales->data_type() == TensorProto::FLOAT) {
auto scales_data = ParseData<float>(scales);
if (!axes.empty()) {
// scales_data contains scales for a subset of axes. The rest should not be resized
if (scales_data.size() != axes.size()) {
fail_shape_inference(
"Number of elements of input 'scales' (",
scales_data.size(),
") does not match the number of axes (",
axes.size(),
").");
}
std::vector<float> tmp(rank_x, 1.0f);
for (size_t i = 0; i < axes.size(); i++) {
int d = axes[i];
tmp[d] = scales_data[i];
}
std::swap(tmp, scales_data);
} else {
// scales_data contains scales for all axes
if (scales_data.size() != static_cast<size_t>(input_shape.dim_size())) {
fail_shape_inference("Number of elements of input 'scales' must be same as rank of input 'X'");
}
}
resizeShapeInferenceHelper(input_shape, scales_data, output_shape);
} else {
fail_shape_inference("Input 'scales' must have float element type.");
}
} // nullptr != scales
}
void resizeShapeInference_opset18_to_19(InferenceContext& ctx) {
resizeShapeInferenceVersioned(ctx, 19);
}
void resizeShapeInference_opset13_to_18(InferenceContext& ctx) {
resizeShapeInferenceVersioned(ctx, 13);
}
void resizeShapeInference_opset11_to_12(InferenceContext& ctx) {
resizeShapeInferenceVersioned(ctx, 11);
}
void resizeShapeInferenceHelper_opset7_to_10(
const TensorShapeProto& input_shape,
const std::vector<float>& scales_data,
TensorShapeProto* output_shape) {
for (int i = 0; i < input_shape.dim_size(); ++i) {
auto* dim = output_shape->mutable_dim(i);
// If input_shape has dim_value, we calculate the scaled result
// If input_shape doesn's have one, we leave it here
if (input_shape.dim(i).has_dim_value()) {
int64_t dim_value =
static_cast<int64_t>(std::floor(static_cast<float>(input_shape.dim(i).dim_value()) * scales_data[i]));
// If output_shape has dim_value, we validate the caculated result
// If output_shape doesn's have one, we set it to the scaled result
if (dim->has_dim_value()) {
if (static_cast<int64_t>(dim->dim_value()) != dim_value) {
fail_shape_inference(
"Dimension value inferred (",
dim_value,
") is not equal to the existing dim value (",
dim->dim_value(),
").");
}
} else {
dim->set_dim_value(static_cast<int64_t>(dim_value));
} // dim->has_dim_value()
} // input_shape.dim(i).has_dim_value()
}
}
void resizeShapeInference_opset7_to_10(InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 1)) {
return;
}
const auto& input_shape = getInputShape(ctx, 0);
auto* output_shape = getOutputShape(ctx, 0);
const auto scales = ctx.getInputData(1);
if (output_shape->dim_size() > 0) {
if (output_shape->dim_size() != input_shape.dim_size()) {
fail_shape_inference(
"Ranks inferred (",
input_shape.dim_size(),
") is not equal to the existing rank value (",
output_shape->dim_size(),
").");
}
} else { // Infer the rank of output anyway
for (int i = 0; i < input_shape.dim_size(); ++i) {
output_shape->add_dim();
}
}
if (nullptr != scales) {
// Infer output shape's dimension value if 'scales' is known.
if (scales->data_type() == TensorProto::FLOAT) {
const auto& scales_data = ParseData<float>(scales);
if (scales_data.size() != static_cast<size_t>(input_shape.dim_size())) {
fail_shape_inference("Number of elements of input 'scales' must be same as rank of input 'X'");
}
resizeShapeInferenceHelper_opset7_to_10(input_shape, scales_data, output_shape);
} else {
fail_shape_inference("Input 'scales' must have float element type.");
} // nullptr != scales
}
}
std::function<void(OpSchema&)> PadDocGenerator(
const char* description,
const char* mode_description,
const std::vector<std::string> op_schema,
const std::string op_schema_description) {
return [=](OpSchema& schema) {
schema.SetDoc(description);
schema.Attr("mode", mode_description, AttributeProto::STRING, std::string("constant"));
schema.Input(0, "data", "Input tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable);
schema.Input(
1,
"pads",
"Tensor of integers indicating the number of padding elements to add or remove (if negative) "
"at the beginning and end of each axis. For 2D input tensor, it is the number of pixels. "
"`pads` should be a 1D tensor of shape [2 * num_axes] where `num_axes` refers to the number "
"of elements in the `axes` input or the input rank if `axes` are not provided explicitly. "
"`pads` format should be: [x1_begin, x2_begin, ..., x1_end, x2_end,...], "
"where xi_begin is the number of pad values added at the beginning of axis `axes[i]` and "
"xi_end, the number of pad values added at the end of axis `axes[i]`.",
"tensor(int64)",
OpSchema::Single,
true,
1,
OpSchema::NonDifferentiable);
schema.Input(
2,
"constant_value",
"(Optional) A scalar value to be used if the mode chosen is `constant` (by default it is 0, "
"empty string or False).",
"T",
OpSchema::Optional,
true,
1,
OpSchema::NonDifferentiable);
schema.Input(
3,
"axes",
"1-D tensor of axes that `pads` apply to. Negative value means counting dimensions "
"from the back. Accepted range is [-r, r-1] where r = rank(data). Behavior is undefined if an "
"axis is repeated. If not provided, all axes are assumed (`[0, 1, ..., input_rank-1]`).",
"Tind",
OpSchema::Optional,
true,
1,
OpSchema::NonDifferentiable);
schema.Output(0, "output", "Tensor after padding.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable);
schema.TypeConstraint("T", op_schema, op_schema_description);
schema.TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types");
schema.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// Type inference
propagateElemTypeFromInputToOutput(ctx, 0, 0);
// Shape inference needs the input data shape
if (!hasNInputShapes(ctx, 1)) {
return;
}
const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
const auto input_rank = input_shape.dim_size();
std::vector<int64_t> axes;
if (hasInputShape(ctx, 3)) { //'axes' input
auto axes_initializer = ctx.getInputData(3);
if (axes_initializer == nullptr)
return; // can't do shape inference then
axes = ParseData<int64_t>(axes_initializer);
checkAxesRange(axes, input_rank);
adjustNegativeAxes(axes, input_rank);
checkDuplicateAxes(axes, input_rank);
} else {
axes.resize(input_rank);
std::iota(axes.begin(), axes.end(), 0);
}
int num_axes = axes.size();
auto* output_shape = ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
// Populating default dims
std::vector<TensorShapeProto_Dimension*> out_dims(input_rank);
for (int i = 0; i < input_rank; ++i) {
out_dims[i] = output_shape->add_dim();
}
// Shape Inference if
// 1. 'pads' are available.
// and 2. 'axes' are available, or default.
const TensorProto* pads_initializer = ctx.getInputData(1);
if (nullptr != pads_initializer && !axes.empty()) {
if (pads_initializer->dims_size() != 1 || pads_initializer->data_type() != TensorProto::INT64) {
fail_shape_inference("'pads' input must be a 1D (shape: [2 * num_axes]) tensor of type int64");
}
const auto& pads_data = ParseData<int64_t>(pads_initializer);
if (pads_data.size() != static_cast<size_t>(2 * num_axes)) {
fail_shape_inference(
"Pads has incorrect number of values. Expected 2 * ",
num_axes,
" values. Got ",
pads_data.size(),
" values.");
}
// Set default dim values
for (int i = 0; i < input_rank; ++i) {
const auto& input_dim = input_shape.dim(i);
if (input_dim.has_dim_value()) {
out_dims[i]->set_dim_value(input_dim.dim_value());
}
}
for (int i = 0; i < num_axes; ++i) {
auto axis = axes[i];
const auto& input_dim = input_shape.dim(axis);
auto& out_dim = *out_dims[axis];
auto total_pad = pads_data[i] + pads_data[num_axes + i];
if (input_dim.has_dim_value()) {
out_dim.set_dim_value(input_dim.dim_value() + total_pad);
} else if (total_pad == 0) {
out_dim = input_dim;
}
}
}
});
};
};
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,59 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <cmath>
#include <vector>
#include "onnx/defs/schema.h"
#include "onnx/defs/tensor_proto_util.h"
namespace ONNX_NAMESPACE {
// The below is called by ops after opset 11, inclusively.
void resizeShapeInference(InferenceContext& ctx);
void gridSampleShapeInference(InferenceContext& ctx);
void resizeShapeInferenceHelper(
const TensorShapeProto& input_shape,
const std::vector<float>& scales_data,
TensorShapeProto* output_shape);
void resizeShapeInferenceHelper(
const TensorShapeProto& input_shape,
const std::vector<int64_t>& sizes_data,
TensorShapeProto* output_shape);
// Belows are called by ops between opset versions in the name inclusively.
void resizeShapeInference_opset7_to_10(InferenceContext& ctx);
void resizeShapeInference_opset11_to_12(InferenceContext& ctx);
void resizeShapeInference_opset13_to_18(InferenceContext& ctx);
void resizeShapeInference_opset18_to_19(InferenceContext& ctx);
void resizeShapeInferenceHelper_opset7_to_10(
const TensorShapeProto& input_shape,
const std::vector<float>& scales_data,
TensorShapeProto* output_shape);
enum class KeepAspectRatioPolicy {
STRETCH,
NOT_LARGER,
NOT_SMALLER,
};
void KeepAspectRatioHelper(
KeepAspectRatioPolicy policy,
const TensorShapeProto& input_shape,
const std::vector<int64_t>& axes,
std::vector<int64_t>& sizes_data);
extern const char* NonZero_ver9_doc;
std::function<void(OpSchema&)> PadDocGenerator(
const char* description,
const char* mode_description,
const std::vector<std::string> op_schema = OpSchema::all_tensor_types_ir4(),
const std::string op_schema_description = "Constrain input and output types to all tensor types.");
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,141 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "tensor_proto_util.h"
#include <string>
#include <vector>
#include "onnx/common/platform_helpers.h"
#include "onnx/defs/data_type_utils.h"
#include "onnx/defs/shape_inference.h"
namespace ONNX_NAMESPACE {
#define DEFINE_TO_TENSOR_ONE(type, enumType, field) \
template <> \
TensorProto ToTensor<type>(const type& value) { \
TensorProto t; \
t.set_data_type(enumType); \
t.add_##field##_data(value); \
return t; \
}
#define DEFINE_TO_TENSOR_LIST(type, enumType, field) \
template <> \
TensorProto ToTensor<type>(const std::vector<type>& values) { \
TensorProto t; \
t.clear_##field##_data(); \
t.set_data_type(enumType); \
for (const type& val : values) { \
t.add_##field##_data(val); \
} \
return t; \
}
#define DEFINE_PARSE_DATA(type, typed_data_fetch, tensorproto_datatype) \
template <> \
const std::vector<type> ParseData(const TensorProto* tensor_proto) { \
if (!tensor_proto->has_data_type() || tensor_proto->data_type() == TensorProto_DataType_UNDEFINED) { \
fail_shape_inference("The type of tensor: ", tensor_proto->name(), " is undefined so it cannot be parsed."); \
} else if (tensor_proto->data_type() != tensorproto_datatype) { \
fail_shape_inference( \
"ParseData type mismatch for tensor: ", \
tensor_proto->name(), \
". Expected:", \
Utils::DataTypeUtils::ToDataTypeString(tensorproto_datatype), \
" Actual:", \
Utils::DataTypeUtils::ToDataTypeString(tensor_proto->data_type())); \
} \
std::vector<type> res; \
if (tensor_proto->has_data_location() && tensor_proto->data_location() == TensorProto_DataLocation_EXTERNAL) { \
fail_shape_inference( \
"Cannot parse data from external tensors. Please ", \
"load external data into raw data for tensor: ", \
tensor_proto->name()); \
} else if (!tensor_proto->has_raw_data()) { \
const auto& data = tensor_proto->typed_data_fetch(); \
int expected_size = 1; \
for (int i = 0; i < tensor_proto->dims_size(); ++i) { \
expected_size *= tensor_proto->dims(i); \
} \
if (tensor_proto->dims_size() != 0 && data.size() != expected_size) { \
fail_shape_inference( \
"Data size mismatch. Tensor: ", \
tensor_proto->name(), \
" expected size ", \
expected_size, \
" does not match the actual size", \
data.size()); \
} \
res.insert(res.end(), data.begin(), data.end()); \
return res; \
} \
if (tensor_proto->data_type() == TensorProto_DataType_STRING) { \
fail_shape_inference( \
tensor_proto->name(), \
" data type is string. string", \
" content is required to be stored in repeated bytes string_data field.", \
" raw_data type cannot be string."); \
} \
/* The given tensor does have raw_data itself so parse it by given type */ \
/* make copy as we may have to reverse bytes */ \
std::string raw_data = tensor_proto->raw_data(); \
if (raw_data.empty()) { \
return res; \
} \
/* okay to remove const qualifier as we have already made a copy */ \
char* bytes = const_cast<char*>(raw_data.c_str()); \
/* onnx is little endian serialized always-tweak byte order if needed */ \
if (!is_processor_little_endian()) { \
const size_t element_size = sizeof(type); \
const size_t num_elements = raw_data.size() / element_size; \
for (size_t i = 0; i < num_elements; ++i) { \
char* start_byte = bytes + i * element_size; \
char* end_byte = start_byte + element_size - 1; \
/* keep swapping */ \
for (size_t count = 0; count < element_size / 2; ++count) { \
char temp = *start_byte; \
*start_byte = *end_byte; \
*end_byte = temp; \
++start_byte; \
--end_byte; \
} \
} \
} \
/* raw_data.c_str()/bytes is a byte array and may not be properly */ \
/* aligned for the underlying type */ \
/* We need to copy the raw_data.c_str()/bytes as byte instead of */ \
/* copying as the underlying type, otherwise we may hit memory */ \
/* misalignment issues on certain platforms, such as arm32-v7a */ \
const size_t raw_data_size = raw_data.size(); \
res.resize(raw_data_size / sizeof(type)); \
memcpy(reinterpret_cast<char*>(res.data()), bytes, raw_data_size); \
return res; \
}
DEFINE_TO_TENSOR_ONE(float, TensorProto_DataType_FLOAT, float)
DEFINE_TO_TENSOR_ONE(bool, TensorProto_DataType_BOOL, int32)
DEFINE_TO_TENSOR_ONE(int32_t, TensorProto_DataType_INT32, int32)
DEFINE_TO_TENSOR_ONE(int64_t, TensorProto_DataType_INT64, int64)
DEFINE_TO_TENSOR_ONE(uint64_t, TensorProto_DataType_UINT64, uint64)
DEFINE_TO_TENSOR_ONE(double, TensorProto_DataType_DOUBLE, double)
DEFINE_TO_TENSOR_ONE(std::string, TensorProto_DataType_STRING, string)
DEFINE_TO_TENSOR_LIST(float, TensorProto_DataType_FLOAT, float)
DEFINE_TO_TENSOR_LIST(bool, TensorProto_DataType_BOOL, int32)
DEFINE_TO_TENSOR_LIST(int32_t, TensorProto_DataType_INT32, int32)
DEFINE_TO_TENSOR_LIST(int64_t, TensorProto_DataType_INT64, int64)
DEFINE_TO_TENSOR_LIST(uint64_t, TensorProto_DataType_UINT64, uint64)
DEFINE_TO_TENSOR_LIST(double, TensorProto_DataType_DOUBLE, double)
DEFINE_TO_TENSOR_LIST(std::string, TensorProto_DataType_STRING, string)
DEFINE_PARSE_DATA(int32_t, int32_data, TensorProto_DataType_INT32)
DEFINE_PARSE_DATA(int64_t, int64_data, TensorProto_DataType_INT64)
DEFINE_PARSE_DATA(float, float_data, TensorProto_DataType_FLOAT)
DEFINE_PARSE_DATA(double, double_data, TensorProto_DataType_DOUBLE)
#undef DEFINE_PARSE_DATA
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,22 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <vector>
#include "onnx/onnx-operators_pb.h"
namespace ONNX_NAMESPACE {
template <typename T>
TensorProto ToTensor(const T& value);
template <typename T>
TensorProto ToTensor(const std::vector<T>& values);
template <typename T>
const std::vector<T> ParseData(const TensorProto* tensor_proto);
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,63 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "tensor_util.h"
#include <string>
#include <vector>
#include "onnx/common/platform_helpers.h"
namespace ONNX_NAMESPACE {
#define DEFINE_PARSE_DATA(type, typed_data_fetch) \
template <> \
const std::vector<type> ParseData(const Tensor* tensor) { \
std::vector<type> res; \
if (!tensor->is_raw_data()) { \
const auto& data = tensor->typed_data_fetch(); \
res.insert(res.end(), data.begin(), data.end()); \
return res; \
} \
/* make copy as we may have to reverse bytes */ \
std::string raw_data = tensor->raw(); \
/* okay to remove const qualifier as we have already made a copy */ \
char* bytes = const_cast<char*>(raw_data.c_str()); \
/*onnx is little endian serialized always-tweak byte order if needed*/ \
if (!is_processor_little_endian()) { \
const size_t element_size = sizeof(type); \
const size_t num_elements = raw_data.size() / element_size; \
for (size_t i = 0; i < num_elements; ++i) { \
char* start_byte = bytes + i * element_size; \
char* end_byte = start_byte + element_size - 1; \
/* keep swapping */ \
for (size_t count = 0; count < element_size / 2; ++count) { \
char temp = *start_byte; \
*start_byte = *end_byte; \
*end_byte = temp; \
++start_byte; \
--end_byte; \
} \
} \
} \
/* raw_data.c_str()/bytes is a byte array and may not be properly */ \
/* aligned for the underlying type */ \
/* We need to copy the raw_data.c_str()/bytes as byte instead of */ \
/* copying as the underlying type, otherwise we may hit memory */ \
/* misalignment issues on certain platforms, such as arm32-v7a */ \
const size_t raw_data_size = raw_data.size(); \
res.resize(raw_data_size / sizeof(type)); \
memcpy(reinterpret_cast<char*>(res.data()), bytes, raw_data_size); \
return res; \
}
DEFINE_PARSE_DATA(int32_t, int32s)
DEFINE_PARSE_DATA(int64_t, int64s)
DEFINE_PARSE_DATA(float, floats)
DEFINE_PARSE_DATA(double, doubles)
DEFINE_PARSE_DATA(uint64_t, uint64s)
#undef DEFINE_PARSE_DATA
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,16 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <vector>
#include "onnx/common/ir.h"
namespace ONNX_NAMESPACE {
template <typename T>
const std::vector<T> ParseData(const Tensor* tensor);
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,199 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
static const char* StringConcat_doc =
R"DOC(StringConcat concatenates string tensors elementwise (with NumPy-style broadcasting support))DOC";
ONNX_OPERATOR_SET_SCHEMA(
StringConcat,
20,
OpSchema()
.Input(
0,
"X",
"Tensor to prepend in concatenation",
"T",
OpSchema::Single,
true,
1,
OpSchema::NonDifferentiable)
.Input(1, "Y", "Tensor to append in concatenation", "T", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
.Output(0, "Z", "Concatenated string tensor", "T", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
.TypeConstraint("T", {"tensor(string)"}, "Inputs and outputs must be UTF-8 strings")
.SetDoc(StringConcat_doc)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (hasNInputShapes(ctx, 2))
bidirectionalBroadcastShapeInference(
ctx.getInputType(0)->tensor_type().shape(),
ctx.getInputType(1)->tensor_type().shape(),
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
}));
static const char* RegexFullMatch_doc =
R"DOC(RegexFullMatch performs a full regex match on each element of the input tensor. If an element fully matches the regex pattern specified as an attribute, the corresponding element in the output is True and it is False otherwise. [RE2](https://github.com/google/re2/wiki/Syntax) regex syntax is used.)DOC";
ONNX_OPERATOR_SET_SCHEMA(
RegexFullMatch,
20,
OpSchema()
.Input(0, "X", "Tensor with strings to match on.", "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
.Attr("pattern", "Regex pattern to match on. This must be valid RE2 syntax.", AttributeProto::STRING, false)
.Output(
0,
"Y",
"Tensor of bools indicating if each input string fully matches the regex pattern specified.",
"T2",
OpSchema::Single,
true,
1,
OpSchema::NonDifferentiable)
.TypeConstraint("T1", {"tensor(string)"}, "Inputs must be UTF-8 strings")
.TypeConstraint(
"T2",
{"tensor(bool)"},
"Outputs are bools and are True where there is a full regex match and False otherwise.")
.SetDoc(RegexFullMatch_doc)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
updateOutputElemType(ctx, 0, TensorProto::BOOL);
propagateShapeFromInputToOutput(ctx, 0, 0);
}));
static const char* StringSplit_doc =
R"DOC(StringSplit splits a string tensor's elements into substrings based on a delimiter attribute and a maxsplit attribute.
The first output of this operator is a tensor of strings representing the substrings from splitting each input string on the `delimiter` substring. This tensor has one additional rank compared to the input tensor in order to store the substrings for each input element (where the input tensor is not empty). Note that, in order to ensure the same number of elements are present in the final dimension, this tensor will pad empty strings as illustrated in the examples below. Consecutive delimiters are not grouped together and are deemed to delimit empty strings, except if the `delimiter` is unspecified or is the empty string (""). In the case where the `delimiter` is unspecified or the empty string, consecutive whitespace characters are regarded as a single separator and leading or trailing whitespace is removed in the output.
The second output tensor represents the number of substrings generated. `maxsplit` can be used to limit the number of splits performed - after the `maxsplit`th split if the string is not fully split, the trailing suffix of input string after the final split point is also added. For elements where fewer splits are possible than specified in `maxsplit`, it has no effect.)DOC";
ONNX_OPERATOR_SET_SCHEMA(
StringSplit,
20,
OpSchema()
.Input(0, "X", "Tensor of strings to split.", "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
.Attr(
"delimiter",
"Delimiter to split on. If left unset or set to the empty string (\"\"), the input is split on consecutive whitespace.",
AttributeProto::STRING,
false)
.Attr(
"maxsplit",
"Maximum number of splits (from left to right). If left unset (or if the number of possible splits are less than maxsplit), it will make as many splits as possible. Note that the maximum possible number of substrings returned with `maxsplit` specified is `maxsplit+1` since the remaining suffix after the `maxsplit`th split is included in the output.",
AttributeProto::INT,
false)
.Output(
0,
"Y",
"Tensor of substrings representing the outcome of splitting the strings in the input on the delimiter. Note that to ensure the same number of elements are present in the final rank, this tensor will pad any necessary empty strings.",
"T2",
OpSchema::Single,
true,
1,
OpSchema::NonDifferentiable)
.Output(
1,
"Z",
"The number of substrings generated for each input element.",
"T3",
OpSchema::Single,
true,
1,
OpSchema::NonDifferentiable)
.TypeConstraint("T1", {"tensor(string)"}, "The input must be a UTF-8 string tensor")
.TypeConstraint("T2", {"tensor(string)"}, "Tensor of substrings.")
.TypeConstraint("T3", {"tensor(int64)"}, "The number of substrings generated.")
.SetDoc(StringSplit_doc)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
if (!hasInputShape(ctx, 0)) {
return;
}
const TypeProto* input_type = ctx.getInputType(0);
if (input_type == nullptr || !input_type->has_tensor_type() ||
input_type->tensor_type().elem_type() != TensorProto::STRING) {
return;
}
// We produce a string tensor per input element. Therefore we have one additional rank with a runtime
// dependent number of elements. All except the final dimension of the output shape can be inferred directly
// from the input.
propagateElemTypeFromInputToOutput(ctx, 0, 0);
propagateShapeFromInputToOutput(ctx, 0, 0);
getOutputShape(ctx, 0)->add_dim();
// The output tensor containing the number of substrings has identical shape to the input but produces int32
// results.
ctx.getOutputType(1)->mutable_tensor_type()->set_elem_type(TensorProto::INT64);
propagateShapeFromInputToOutput(ctx, 0, 1);
}));
static const char* StringNormalizer_ver10_doc = R"DOC(
StringNormalization performs string operations for basic cleaning.
This operator has only one input (denoted by X) and only one output
(denoted by Y). This operator first examines the elements in the X,
and removes elements specified in "stopwords" attribute.
After removing stop words, the intermediate result can be further lowercased,
uppercased, or just returned depending the "case_change_action" attribute.
This operator only accepts [C]- and [1, C]-tensor.
If all elements in X are dropped, the output will be the empty value of string tensor with shape [1]
if input shape is [C] and shape [1, 1] if input shape is [1, C].
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
StringNormalizer,
10,
OpSchema()
.Input(0, "X", "UTF-8 strings to normalize", "tensor(string)")
.Output(0, "Y", "UTF-8 Normalized strings", "tensor(string)")
.Attr(
std::string("case_change_action"),
std::string("string enum that cases output to be lowercased/uppercases/unchanged."
" Valid values are \"LOWER\", \"UPPER\", \"NONE\". Default is \"NONE\""),
AttributeProto::STRING,
std::string("NONE"))
.Attr(
std::string("is_case_sensitive"),
std::string("Boolean. Whether the identification of stop words in X is case-sensitive. Default is false"),
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr(
"stopwords",
"List of stop words. If not set, no word would be removed from X.",
AttributeProto::STRINGS,
OPTIONAL_VALUE)
.Attr(
"locale",
"Environment dependent string that denotes the locale according to which output strings needs to be upper/lowercased."
"Default en_US or platform specific equivalent as decided by the implementation.",
AttributeProto::STRING,
OPTIONAL_VALUE)
.SetDoc(StringNormalizer_ver10_doc)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type();
output_elem_type->set_elem_type(TensorProto::STRING);
if (!hasInputShape(ctx, 0)) {
return;
}
TensorShapeProto output_shape;
auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
auto dim_size = input_shape.dim_size();
// Last axis dimension is unknown if we have stop-words since we do
// not know how many stop-words are dropped
if (dim_size == 1) {
// Unknown output dimension
output_shape.add_dim();
} else if (dim_size == 2) {
// Copy B-dim
auto& b_dim = input_shape.dim(0);
if (!b_dim.has_dim_value() || b_dim.dim_value() != 1) {
fail_shape_inference("Input shape must have either [C] or [1,C] dimensions where C > 0");
}
*output_shape.add_dim() = b_dim;
output_shape.add_dim();
} else {
fail_shape_inference("Input shape must have either [C] or [1,C] dimensions where C > 0");
}
updateOutputShape(ctx, 0, output_shape);
}));
} // namespace ONNX_NAMESPACE

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,657 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/schema.h"
#ifdef ONNX_ML
namespace ONNX_NAMESPACE {
static const char* LabelEncoder_ver1_doc = R"DOC(
Converts strings to integers and vice versa.<br>
If the string default value is set, it will convert integers to strings.
If the int default value is set, it will convert strings to integers.<br>
Each operator converts either integers to strings or strings to integers, depending
on which default value attribute is provided. Only one default value attribute
should be defined.<br>
When converting from integers to strings, the string is fetched from the
'classes_strings' list, by simple indexing.<br>
When converting from strings to integers, the string is looked up in the list
and the index at which it is found is used as the converted value.
)DOC";
ONNX_ML_OPERATOR_SET_SCHEMA(
LabelEncoder,
1,
OpSchema()
.SetDoc(LabelEncoder_ver1_doc)
.Input(0, "X", "Input data.", "T1")
.Output(0, "Y", "Output data. If strings are input, the output values are integers, and vice versa.", "T2")
.TypeConstraint(
"T1",
{"tensor(string)", "tensor(int64)"},
"The input type must be a tensor of integers or strings, of any shape.")
.TypeConstraint(
"T2",
{"tensor(string)", "tensor(int64)"},
"The output type will be a tensor of strings or integers, and will have the same shape as the input.")
.Attr("classes_strings", "A list of labels.", AttributeProto::STRINGS, OPTIONAL_VALUE)
.Attr(
"default_int64",
"An integer to use when an input string value is not found in the map.<br>One and only one of the "
"'default_*' attributes must be defined.",
AttributeProto::INT,
static_cast<int64_t>(-1))
.Attr(
"default_string",
"A string to use when an input integer value is not found in the map.<br>One and only one of the "
"'default_*' attributes must be defined.",
AttributeProto::STRING,
std::string("_Unused"))
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
auto input_elem_type = ctx.getInputType(0)->tensor_type().elem_type();
auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type();
if (TensorProto::STRING == input_elem_type) {
output_elem_type->set_elem_type(TensorProto::INT64);
} else if (TensorProto::INT64 == input_elem_type) {
output_elem_type->set_elem_type(TensorProto::STRING);
}
}));
static const char* TreeEnsembleClassifier_ver1_doc = R"DOC(
Tree Ensemble classifier. Returns the top class for each of N inputs.<br>
The attributes named 'nodes_X' form a sequence of tuples, associated by
index into the sequences, which must all be of equal length. These tuples
define the nodes.<br>
Similarly, all fields prefixed with 'class_' are tuples of votes at the leaves.
A leaf may have multiple votes, where each vote is weighted by
the associated class_weights index.<br>
One and only one of classlabels_strings or classlabels_int64s
will be defined. The class_ids are indices into this list.
)DOC";
ONNX_ML_OPERATOR_SET_SCHEMA(
TreeEnsembleClassifier,
1,
OpSchema()
.SetDoc(TreeEnsembleClassifier_ver1_doc)
.Input(0, "X", "Input of shape [N,F]", "T1")
.Output(0, "Y", "N, Top class for each point", "T2")
.Output(1, "Z", "The class score for each class, for each point, a tensor of shape [N,E].", "tensor(float)")
.TypeConstraint(
"T1",
{"tensor(float)", "tensor(double)", "tensor(int64)", "tensor(int32)"},
"The input type must be a tensor of a numeric type.")
.TypeConstraint(
"T2",
{"tensor(string)", "tensor(int64)"},
"The output type will be a tensor of strings or integers, depending on which of the classlabels_* "
"attributes is used.")
.Attr("nodes_treeids", "Tree id for each node.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr(
"nodes_nodeids",
"Node id for each node. Ids may restart at zero for each tree, but it not required to.",
AttributeProto::INTS,
OPTIONAL_VALUE)
.Attr("nodes_featureids", "Feature id for each node.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr(
"nodes_values",
"Thresholds to do the splitting on for each node.",
AttributeProto::FLOATS,
OPTIONAL_VALUE)
.Attr(
"nodes_hitrates",
"Popularity of each node, used for performance and may be omitted.",
AttributeProto::FLOATS,
OPTIONAL_VALUE)
.Attr(
"nodes_modes",
"The node kind, that is, the comparison to make at the node. There is no comparison to make at a leaf "
"node.<br>One of 'BRANCH_LEQ', 'BRANCH_LT', 'BRANCH_GTE', 'BRANCH_GT', 'BRANCH_EQ', 'BRANCH_NEQ', 'LEAF'",
AttributeProto::STRINGS,
OPTIONAL_VALUE)
.Attr("nodes_truenodeids", "Child node if expression is true.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("nodes_falsenodeids", "Child node if expression is false.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr(
"nodes_missing_value_tracks_true",
"For each node, define what to do in the presence of a missing value: if a value is missing (NaN), use the "
"'true' or 'false' branch based on the value in this array.<br>This attribute may be left undefined, and "
"the default value is false (0) for all nodes.",
AttributeProto::INTS,
OPTIONAL_VALUE)
.Attr("class_treeids", "The id of the tree that this node is in.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("class_nodeids", "node id that this weight is for.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("class_ids", "The index of the class list that each weight is for.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("class_weights", "The weight for the class in class_id.", AttributeProto::FLOATS, OPTIONAL_VALUE)
.Attr(
"classlabels_strings",
"Class labels if using string labels.<br>One and only one of the 'classlabels_*' attributes must be "
"defined.",
AttributeProto::STRINGS,
OPTIONAL_VALUE)
.Attr(
"classlabels_int64s",
"Class labels if using integer labels.<br>One and only one of the 'classlabels_*' attributes must be "
"defined.",
AttributeProto::INTS,
OPTIONAL_VALUE)
.Attr(
"post_transform",
"Indicates the transform to apply to the score. <br> One of 'NONE,' 'SOFTMAX,' 'LOGISTIC,' 'SOFTMAX_ZERO,' "
"or 'PROBIT.'",
AttributeProto::STRING,
std::string("NONE"))
.Attr(
"base_values",
"Base values for classification, added to final class score; the size must be the same as the classes or "
"can be left unassigned (assumed 0)",
AttributeProto::FLOATS,
OPTIONAL_VALUE)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
std::vector<std::string> label_strs;
auto result = getRepeatedAttribute(ctx, "classlabels_strings", label_strs);
bool using_strings = (result && !label_strs.empty());
auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type();
if (using_strings) {
output_elem_type->set_elem_type(TensorProto::STRING);
} else {
output_elem_type->set_elem_type(TensorProto::INT64);
}
}));
static const char* TreeEnsembleClassifier_ver3_doc = R"DOC(
Tree Ensemble classifier. Returns the top class for each of N inputs.<br>
The attributes named 'nodes_X' form a sequence of tuples, associated by
index into the sequences, which must all be of equal length. These tuples
define the nodes.<br>
Similarly, all fields prefixed with 'class_' are tuples of votes at the leaves.
A leaf may have multiple votes, where each vote is weighted by
the associated class_weights index.<br>
One and only one of classlabels_strings or classlabels_int64s
will be defined. The class_ids are indices into this list.
All fields ending with <i>_as_tensor</i> can be used instead of the
same parameter without the suffix if the element type is double and not float.
)DOC";
ONNX_ML_OPERATOR_SET_SCHEMA(
TreeEnsembleClassifier,
3,
OpSchema()
.SetDoc(TreeEnsembleClassifier_ver3_doc)
.Input(0, "X", "Input of shape [N,F]", "T1")
.Output(0, "Y", "N, Top class for each point", "T2")
.Output(1, "Z", "The class score for each class, for each point, a tensor of shape [N,E].", "tensor(float)")
.TypeConstraint(
"T1",
{"tensor(float)", "tensor(double)", "tensor(int64)", "tensor(int32)"},
"The input type must be a tensor of a numeric type.")
.TypeConstraint(
"T2",
{"tensor(string)", "tensor(int64)"},
"The output type will be a tensor of strings or integers, depending on which of the classlabels_* "
"attributes is used.")
.Attr("nodes_treeids", "Tree id for each node.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr(
"nodes_nodeids",
"Node id for each node. Ids may restart at zero for each tree, but it not required to.",
AttributeProto::INTS,
OPTIONAL_VALUE)
.Attr("nodes_featureids", "Feature id for each node.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr(
"nodes_values",
"Thresholds to do the splitting on for each node.",
AttributeProto::FLOATS,
OPTIONAL_VALUE)
.Attr(
"nodes_values_as_tensor",
"Thresholds to do the splitting on for each node.",
AttributeProto::TENSOR,
OPTIONAL_VALUE)
.Attr(
"nodes_hitrates",
"Popularity of each node, used for performance and may be omitted.",
AttributeProto::FLOATS,
OPTIONAL_VALUE)
.Attr(
"nodes_hitrates_as_tensor",
"Popularity of each node, used for performance and may be omitted.",
AttributeProto::TENSOR,
OPTIONAL_VALUE)
.Attr(
"nodes_modes",
"The node kind, that is, the comparison to make at the node. There is no comparison to make at a leaf "
"node.<br>One of 'BRANCH_LEQ', 'BRANCH_LT', 'BRANCH_GTE', 'BRANCH_GT', 'BRANCH_EQ', 'BRANCH_NEQ', 'LEAF'",
AttributeProto::STRINGS,
OPTIONAL_VALUE)
.Attr("nodes_truenodeids", "Child node if expression is true.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("nodes_falsenodeids", "Child node if expression is false.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr(
"nodes_missing_value_tracks_true",
"For each node, define what to do in the presence of a missing value: if a value is missing (NaN), use the "
"'true' or 'false' branch based on the value in this array.<br>This attribute may be left undefined, and "
"the default value is false (0) for all nodes.",
AttributeProto::INTS,
OPTIONAL_VALUE)
.Attr("class_treeids", "The id of the tree that this node is in.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("class_nodeids", "node id that this weight is for.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("class_ids", "The index of the class list that each weight is for.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("class_weights", "The weight for the class in class_id.", AttributeProto::FLOATS, OPTIONAL_VALUE)
.Attr(
"class_weights_as_tensor",
"The weight for the class in class_id.",
AttributeProto::TENSOR,
OPTIONAL_VALUE)
.Attr(
"classlabels_strings",
"Class labels if using string labels.<br>One and only one of the 'classlabels_*' attributes must be "
"defined.",
AttributeProto::STRINGS,
OPTIONAL_VALUE)
.Attr(
"classlabels_int64s",
"Class labels if using integer labels.<br>One and only one of the 'classlabels_*' attributes must be "
"defined.",
AttributeProto::INTS,
OPTIONAL_VALUE)
.Attr(
"post_transform",
"Indicates the transform to apply to the score. <br> One of 'NONE,' 'SOFTMAX,' 'LOGISTIC,' 'SOFTMAX_ZERO,' "
"or 'PROBIT.'",
AttributeProto::STRING,
std::string("NONE"))
.Attr(
"base_values",
"Base values for classification, added to final class score; the size must be the same as the classes or "
"can be left unassigned (assumed 0)",
AttributeProto::FLOATS,
OPTIONAL_VALUE)
.Attr(
"base_values_as_tensor",
"Base values for classification, added to final class score; the size must be the same as the classes or "
"can be left unassigned (assumed 0)",
AttributeProto::TENSOR,
OPTIONAL_VALUE)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
auto* nodes_values = ctx.getAttribute("nodes_values");
auto* nodes_values_as_tensor = ctx.getAttribute("nodes_values_as_tensor");
auto* nodes_hitrates = ctx.getAttribute("nodes_hitrates");
auto* nodes_hitrates_as_tensor = ctx.getAttribute("nodes_hitrates_as_tensor");
auto* class_weights = ctx.getAttribute("class_weights");
auto* class_weights_as_tensor = ctx.getAttribute("class_weights_as_tensor");
auto* base_values = ctx.getAttribute("base_values");
auto* base_values_as_tensor = ctx.getAttribute("base_values_as_tensor");
if (nullptr != nodes_values && nullptr != nodes_values_as_tensor) {
fail_shape_inference(
"Only one of the attributes 'nodes_values', 'nodes_values_as_tensor' should be specified.");
}
if (nullptr != nodes_hitrates && nullptr != nodes_hitrates_as_tensor) {
fail_shape_inference(
"Only one of the attributes 'nodes_hitrates', 'nodes_hitrates_as_tensor' should be specified.");
}
if (nullptr != class_weights && nullptr != class_weights_as_tensor) {
fail_shape_inference(
"Only one of the attributes 'class_weights', 'class_weights_as_tensor' should be specified.");
}
if (nullptr != base_values && nullptr != base_values_as_tensor) {
fail_shape_inference(
"Only one of the attributes 'base_values', 'base_values_as_tensor' should be specified.");
}
std::vector<std::string> classlabels_strings;
auto result = getRepeatedAttribute(ctx, "classlabels_strings", classlabels_strings);
bool using_strings = (result && !classlabels_strings.empty());
if (using_strings) {
updateOutputElemType(ctx, 0, TensorProto::STRING);
} else {
updateOutputElemType(ctx, 0, TensorProto::INT64);
}
updateOutputElemType(ctx, 1, TensorProto::FLOAT);
checkInputRank(ctx, 0, 2);
Dim N, E;
unifyInputDim(ctx, 0, 0, N);
if (using_strings) {
unifyDim(E, classlabels_strings.size());
} else {
std::vector<int64_t> classlabels_int64s;
result = getRepeatedAttribute(ctx, "classlabels_int64s", classlabels_int64s);
if (!result || classlabels_int64s.empty()) {
fail_shape_inference("Non of classlabels_int64s or classlabels_strings is set.");
}
unifyDim(E, classlabels_int64s.size());
}
updateOutputShape(ctx, 0, {N});
updateOutputShape(ctx, 1, {N, E});
}));
static const char* TreeEnsembleRegressor_ver1_doc = R"DOC(
Tree Ensemble regressor. Returns the regressed values for each input in N.<br>
All args with nodes_ are fields of a tuple of tree nodes, and
it is assumed they are the same length, and an index i will decode the
tuple across these inputs. Each node id can appear only once
for each tree id.<br>
All fields prefixed with target_ are tuples of votes at the leaves.<br>
A leaf may have multiple votes, where each vote is weighted by
the associated target_weights index.<br>
All trees must have their node ids start at 0 and increment by 1.<br>
Mode enum is BRANCH_LEQ, BRANCH_LT, BRANCH_GTE, BRANCH_GT, BRANCH_EQ, BRANCH_NEQ, LEAF
)DOC";
ONNX_ML_OPERATOR_SET_SCHEMA(
TreeEnsembleRegressor,
1,
OpSchema()
.SetDoc(TreeEnsembleRegressor_ver1_doc)
.Input(0, "X", "Input of shape [N,F]", "T")
.Output(0, "Y", "N classes", "tensor(float)")
.TypeConstraint(
"T",
{"tensor(float)", "tensor(double)", "tensor(int64)", "tensor(int32)"},
"The input type must be a tensor of a numeric type.")
.Attr("nodes_treeids", "Tree id for each node.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr(
"nodes_nodeids",
"Node id for each node. Node ids must restart at zero for each tree and increase sequentially.",
AttributeProto::INTS,
OPTIONAL_VALUE)
.Attr("nodes_featureids", "Feature id for each node.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr(
"nodes_values",
"Thresholds to do the splitting on for each node.",
AttributeProto::FLOATS,
OPTIONAL_VALUE)
.Attr(
"nodes_hitrates",
"Popularity of each node, used for performance and may be omitted.",
AttributeProto::FLOATS,
OPTIONAL_VALUE)
.Attr(
"nodes_modes",
"The node kind, that is, the comparison to make at the node. There is no comparison to make at a leaf "
"node.<br>One of 'BRANCH_LEQ', 'BRANCH_LT', 'BRANCH_GTE', 'BRANCH_GT', 'BRANCH_EQ', 'BRANCH_NEQ', 'LEAF'",
AttributeProto::STRINGS,
OPTIONAL_VALUE)
.Attr("nodes_truenodeids", "Child node if expression is true", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("nodes_falsenodeids", "Child node if expression is false", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr(
"nodes_missing_value_tracks_true",
"For each node, define what to do in the presence of a NaN: use the 'true' (if the attribute value is 1) "
"or 'false' (if the attribute value is 0) branch based on the value in this array.<br>This attribute may "
"be left undefined and the default value is false (0) for all nodes.",
AttributeProto::INTS,
OPTIONAL_VALUE)
.Attr("target_treeids", "The id of the tree that each node is in.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("target_nodeids", "The node id of each weight", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("target_ids", "The index of the target that each weight is for", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("target_weights", "The weight for each target", AttributeProto::FLOATS, OPTIONAL_VALUE)
.Attr("n_targets", "The total number of targets.", AttributeProto::INT, OPTIONAL_VALUE)
.Attr(
"post_transform",
"Indicates the transform to apply to the score. <br>One of 'NONE,' 'SOFTMAX,' 'LOGISTIC,' 'SOFTMAX_ZERO,' "
"or 'PROBIT'",
AttributeProto::STRING,
std::string("NONE"))
.Attr(
"aggregate_function",
"Defines how to aggregate leaf values within a target. <br>One of 'AVERAGE,' 'SUM,' 'MIN,' 'MAX.'",
AttributeProto::STRING,
std::string("SUM"))
.Attr(
"base_values",
"Base values for classification, added to final class score; the size must be the same as the classes or "
"can be left unassigned (assumed 0)",
AttributeProto::FLOATS,
OPTIONAL_VALUE));
static const char* TreeEnsembleRegressor_ver3_doc = R"DOC(
Tree Ensemble regressor. Returns the regressed values for each input in N.<br>
All args with nodes_ are fields of a tuple of tree nodes, and
it is assumed they are the same length, and an index i will decode the
tuple across these inputs. Each node id can appear only once
for each tree id.<br>
All fields prefixed with target_ are tuples of votes at the leaves.<br>
A leaf may have multiple votes, where each vote is weighted by
the associated target_weights index.<br>
All fields ending with <i>_as_tensor</i> can be used instead of the
same parameter without the suffix if the element type is double and not float.
All trees must have their node ids start at 0 and increment by 1.<br>
Mode enum is BRANCH_LEQ, BRANCH_LT, BRANCH_GTE, BRANCH_GT, BRANCH_EQ, BRANCH_NEQ, LEAF
)DOC";
ONNX_ML_OPERATOR_SET_SCHEMA(
TreeEnsembleRegressor,
3,
OpSchema()
.SetDoc(TreeEnsembleRegressor_ver3_doc)
.Input(0, "X", "Input of shape [N,F]", "T")
.Output(0, "Y", "N classes", "tensor(float)")
.TypeConstraint(
"T",
{"tensor(float)", "tensor(double)", "tensor(int64)", "tensor(int32)"},
"The input type must be a tensor of a numeric type.")
.Attr("nodes_treeids", "Tree id for each node.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr(
"nodes_nodeids",
"Node id for each node. Node ids must restart at zero for each tree and increase sequentially.",
AttributeProto::INTS,
OPTIONAL_VALUE)
.Attr("nodes_featureids", "Feature id for each node.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr(
"nodes_values",
"Thresholds to do the splitting on for each node.",
AttributeProto::FLOATS,
OPTIONAL_VALUE)
.Attr(
"nodes_values_as_tensor",
"Thresholds to do the splitting on for each node.",
AttributeProto::TENSOR,
OPTIONAL_VALUE)
.Attr(
"nodes_hitrates",
"Popularity of each node, used for performance and may be omitted.",
AttributeProto::FLOATS,
OPTIONAL_VALUE)
.Attr(
"nodes_hitrates_as_tensor",
"Popularity of each node, used for performance and may be omitted.",
AttributeProto::TENSOR,
OPTIONAL_VALUE)
.Attr(
"nodes_modes",
"The node kind, that is, the comparison to make at the node. There is no comparison to make at a leaf "
"node.<br>One of 'BRANCH_LEQ', 'BRANCH_LT', 'BRANCH_GTE', 'BRANCH_GT', 'BRANCH_EQ', 'BRANCH_NEQ', 'LEAF'",
AttributeProto::STRINGS,
OPTIONAL_VALUE)
.Attr("nodes_truenodeids", "Child node if expression is true", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("nodes_falsenodeids", "Child node if expression is false", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr(
"nodes_missing_value_tracks_true",
"For each node, define what to do in the presence of a NaN: use the 'true' (if the attribute value is 1) "
"or 'false' (if the attribute value is 0) branch based on the value in this array.<br>This attribute may "
"be left undefined and the default value is false (0) for all nodes.",
AttributeProto::INTS,
OPTIONAL_VALUE)
.Attr("target_treeids", "The id of the tree that each node is in.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("target_nodeids", "The node id of each weight", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("target_ids", "The index of the target that each weight is for", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("target_weights", "The weight for each target", AttributeProto::FLOATS, OPTIONAL_VALUE)
.Attr("target_weights_as_tensor", "The weight for each target", AttributeProto::TENSOR, OPTIONAL_VALUE)
.Attr("n_targets", "The total number of targets.", AttributeProto::INT, OPTIONAL_VALUE)
.Attr(
"post_transform",
"Indicates the transform to apply to the score. <br>One of 'NONE,' 'SOFTMAX,' 'LOGISTIC,' 'SOFTMAX_ZERO,' "
"or 'PROBIT'",
AttributeProto::STRING,
std::string("NONE"))
.Attr(
"aggregate_function",
"Defines how to aggregate leaf values within a target. <br>One of 'AVERAGE,' 'SUM,' 'MIN,' 'MAX.'",
AttributeProto::STRING,
std::string("SUM"))
.Attr(
"base_values",
"Base values for regression, added to final prediction after applying aggregate_function; the size must be "
"the same as the classes or can be left unassigned (assumed 0)",
AttributeProto::FLOATS,
OPTIONAL_VALUE)
.Attr(
"base_values_as_tensor",
"Base values for regression, added to final prediction after applying aggregate_function; the size must be "
"the same as the classes or can be left unassigned (assumed 0)",
AttributeProto::TENSOR,
OPTIONAL_VALUE)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
auto* nodes_values = ctx.getAttribute("nodes_values");
auto* nodes_values_as_tensor = ctx.getAttribute("nodes_values_as_tensor");
auto* nodes_hitrates = ctx.getAttribute("nodes_hitrates");
auto* nodes_hitrates_as_tensor = ctx.getAttribute("nodes_hitrates_as_tensor");
auto* target_weights = ctx.getAttribute("target_weights");
auto* target_weights_as_tensor = ctx.getAttribute("target_weights_as_tensor");
auto* base_values = ctx.getAttribute("base_values");
auto* base_values_as_tensor = ctx.getAttribute("base_values_as_tensor");
if (nullptr != nodes_values && nullptr != nodes_values_as_tensor) {
fail_shape_inference(
"Only one of the attributes 'nodes_values', 'nodes_values_as_tensor' should be specified.");
}
if (nullptr != nodes_hitrates && nullptr != nodes_hitrates_as_tensor) {
fail_shape_inference(
"Only one of the attributes 'nodes_hitrates', 'nodes_hitrates_as_tensor' should be specified.");
}
if (nullptr != target_weights && nullptr != target_weights_as_tensor) {
fail_shape_inference(
"Only one of the attributes 'target_weights', 'target_weights_as_tensor' should be specified.");
}
if (nullptr != base_values && nullptr != base_values_as_tensor) {
fail_shape_inference(
"Only one of the attributes 'base_values', 'base_values_as_tensor' should be specified.");
}
checkInputRank(ctx, 0, 2);
Dim N, E;
unifyInputDim(ctx, 0, 0, N);
if (nullptr != ctx.getAttribute("n_targets")) {
unifyDim(E, ctx.getAttribute("n_targets")->i());
}
updateOutputElemType(ctx, 0, TensorProto::FLOAT);
updateOutputShape(ctx, 0, {N, E});
}));
static const char* LabelEncoder_ver2_doc = R"DOC(
Maps each element in the input tensor to another value.<br>
The mapping is determined by the two parallel attributes, 'keys_*' and
'values_*' attribute. The i-th value in the specified 'keys_*' attribute
would be mapped to the i-th value in the specified 'values_*' attribute. It
implies that input's element type and the element type of the specified
'keys_*' should be identical while the output type is identical to the
specified 'values_*' attribute. If an input element can not be found in the
specified 'keys_*' attribute, the 'default_*' that matches the specified
'values_*' attribute may be used as its output value.<br>
Let's consider an example which maps a string tensor to an integer tensor.
Assume and 'keys_strings' is ["Amy", "Sally"], 'values_int64s' is [5, 6],
and 'default_int64' is '-1'. The input ["Dori", "Amy", "Amy", "Sally",
"Sally"] would be mapped to [-1, 5, 5, 6, 6].<br>
Since this operator is an one-to-one mapping, its input and output shapes
are the same. Notice that only one of 'keys_*'/'values_*' can be set.<br>
For key look-up, bit-wise comparison is used so even a float NaN can be
mapped to a value in 'values_*' attribute.<br>
)DOC";
ONNX_ML_OPERATOR_SET_SCHEMA(
LabelEncoder,
2,
OpSchema()
.SetDoc(LabelEncoder_ver2_doc)
.Input(0, "X", "Input data. It can be either tensor or scalar.", "T1")
.Output(0, "Y", "Output data.", "T2")
.TypeConstraint(
"T1",
{"tensor(string)", "tensor(int64)", "tensor(float)"},
"The input type is a tensor of any shape.")
.TypeConstraint(
"T2",
{"tensor(string)", "tensor(int64)", "tensor(float)"},
"Output type is determined by the specified 'values_*' attribute.")
.Attr(
"keys_strings",
"A list of strings. One and only one of 'keys_*'s should be set.",
AttributeProto::STRINGS,
OPTIONAL_VALUE)
.Attr("keys_int64s", "A list of ints.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("keys_floats", "A list of floats.", AttributeProto::FLOATS, OPTIONAL_VALUE)
.Attr(
"values_strings",
"A list of strings. One and only one of 'value_*'s should be set.",
AttributeProto::STRINGS,
OPTIONAL_VALUE)
.Attr("values_int64s", "A list of ints.", AttributeProto::INTS, OPTIONAL_VALUE)
.Attr("values_floats", "A list of floats.", AttributeProto::FLOATS, OPTIONAL_VALUE)
.Attr("default_string", "A string.", AttributeProto::STRING, std::string("_Unused"))
.Attr("default_int64", "An integer.", AttributeProto::INT, static_cast<int64_t>(-1))
.Attr("default_float", "A float.", AttributeProto::FLOAT, -0.f)
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// Label encoder is one-to-one mapping.
if (ctx.getNumInputs() != 1) {
fail_shape_inference("Label encoder has only one input.");
}
if (ctx.getNumOutputs() != 1) {
fail_shape_inference("Label encoder has only one output.");
}
// Load all key_* attributes.
std::vector<std::string> keys_strings;
bool keys_strings_result = getRepeatedAttribute(ctx, "keys_strings", keys_strings);
std::vector<int64_t> keys_int64s;
bool keys_int64s_result = getRepeatedAttribute(ctx, "keys_int64s", keys_int64s);
std::vector<float> keys_floats;
bool keys_floats_result = getRepeatedAttribute(ctx, "keys_floats", keys_floats);
// Check if only one keys_* attribute is set.
if (static_cast<int>(keys_strings_result) + static_cast<int>(keys_int64s_result) +
static_cast<int>(keys_floats_result) !=
1) {
fail_shape_inference("Only one of keys_*'s can be set in label encoder.");
}
// Check if the specified keys_* matches input type.
auto input_elem_type = ctx.getInputType(0)->tensor_type().elem_type();
if (keys_strings_result && input_elem_type != TensorProto::STRING) {
fail_shape_inference("Input type is not string tensor but key_strings is set");
}
if (keys_int64s_result && input_elem_type != TensorProto::INT64) {
fail_shape_inference("Input type is not int64 tensor but keys_int64s is set");
}
if (keys_floats_result && input_elem_type != TensorProto::FLOAT) {
fail_shape_inference("Input type is not float tensor but keys_floats is set");
}
// Load all values_* attributes.
std::vector<std::string> values_strings;
bool values_strings_result = getRepeatedAttribute(ctx, "values_strings", values_strings);
std::vector<int64_t> values_int64s;
bool values_int64s_result = getRepeatedAttribute(ctx, "values_int64s", values_int64s);
std::vector<float> values_floats;
bool values_floats_result = getRepeatedAttribute(ctx, "values_floats", values_floats);
// Check if only one values_* attribute is set.
if (static_cast<int>(values_strings_result) + static_cast<int>(values_int64s_result) +
static_cast<int>(values_floats_result) !=
1) {
fail_shape_inference("Only one of values_*'s can be set in label encoder.");
}
// Assign output type based on the specified values_*.
auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type();
if (values_strings_result)
output_elem_type->set_elem_type(TensorProto::STRING);
if (values_int64s_result)
output_elem_type->set_elem_type(TensorProto::INT64);
if (values_floats_result)
output_elem_type->set_elem_type(TensorProto::FLOAT);
// Input and output shapes are the same.
propagateShapeFromInputToOutput(ctx, 0, 0);
}));
} // namespace ONNX_NAMESPACE
#endif

View File

@ -0,0 +1,27 @@
#include "onnx/defs/schema.h"
#include "onnx/defs/shape_inference.h"
namespace ONNX_NAMESPACE {
void AssertAttributeProtoTypeAndLength(
const AttributeProto* attr_proto,
int expected_length,
TensorProto_DataType expected_type,
bool required) {
if (nullptr == attr_proto) {
if (required) {
fail_shape_inference("Unspecified required attribute.");
}
return;
}
const auto& [type, length] = getAttributeProtoElemTypeAndLength(attr_proto);
if (type != expected_type) {
fail_shape_inference(
"Attribute '", attr_proto->name(), "' must have type ", TensorProto_DataType_Name(expected_type), ".");
}
if (length != expected_length) {
fail_shape_inference("Attribute '", attr_proto->name(), "' must have ", expected_length, " elements.");
}
}
} // namespace ONNX_NAMESPACE

View File

@ -0,0 +1,624 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <algorithm>
#include <cmath>
#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
static const char* Gradient_ver1_doc = R"DOC(
Gradient operator computes the partial derivatives of a specific tensor w.r.t.
some other tensors. This operator is widely used in gradient-based training
algorithms. To illustrate its use, let's consider a computation graph,
```
X -----.
|
v
W --> Conv --> H --> Gemm --> Y
^
|
Z
```
, where W and Z are trainable tensors. Note that operators' attributes are
omitted for the sake of simplicity. Let dY/dW (dY/dZ) be the gradient of
Y with respect to W (Z). The user can compute gradient by inserting Gradient
operator to form another graph shown below.
```
W --> Conv --> H --> Gemm --> Y
| ^ ^
| | |
| X Z
| | |
| | .----------'
| | | (W/Z/X is the 1st/2nd/3rd input of Gradient as shown in
| | | "xs" followed by "zs")
| v v
'---> Gradient(xs=["W", "Z"], zs=["X"], y="Y")
| |
| '-----------------------------------> dY/dW (1st output of Gradient)
|
'---------------------------------------> dY/dZ (2nd output of Gradient)
```
By definition, the tensor "y" is a function of independent variables in "xs"
and "zs". Since we only compute the gradient of "y" w.r.t. the differentiable
variables in "xs", this Gradient only outputs dY/dW and dY/dZ. Note that "H"
cannot appear in "xs" and "zs". The reason is that "H" can be determined by
tensors "W" and "X" and therefore "H" is not an independent variable.
All outputs are optional. If needed, for example, user can assign an empty
string to the 1st output name of that Gradient to skip the generation of dY/dW.
Note that the concept of optional outputs can also be found in ONNX's RNN, GRU,
and LSTM.
Gradient operator can compute derivative against intermediate tensors. For
example, the gradient of Y with respect to H can be done via
```
W --> Conv --> H --> Gemm --> Y
^ | ^
| | |
X | Z
.-------' |
| .----------'
| | (H/Z is the 1st/2nd input of Gradient as shown in "xs")
v v
Gradient(xs=["H", "Z"], y="Y")
| |
| '-----------------------------------> dY/dH (1st output of Gradient)
|
'---------------------------------------> dY/dZ (2nd output of Gradient)
```
It is possible to represent high-order differentiation using Gradient operators.
For example, given the following linear model:
```
W --> Gemm --> Y --> Loss --> O
^ ^
| |
X L
```
To compute the 2nd order derivative of O with respect to W (denoted by
d^2O/dW^2), one can do
```
W --> Gemm --> Y --> Loss --> O
| ^ ^
| | |
| X .------------L
| | | |
| | | v
+------+-+> Gradient(xs=["X", "W"], zs=["L"], y="O") ---> dO/dX (1st output of Gradient)
| | | |
| | | '---> dO/dW (2nd output of Gradient)
| v v
'---> Gradient(xs=["X", "W"], zs=["L"], y="dO/dW") ---> d(dO/dW)dX (1st output of
| Gradient)
|
|
'---> d^2O/dW^2 (2nd output of Gradient)
```
The tensors named in attributes "xs", "zs", and "y" define the differentiated
computation graph, and the inputs to Gradient node define the values at
which the gradient is computed. We can feed different tensors to the identified
graph. For example, one can compute the gradient of Y with respect to H at
a specific value of H, H_1, by providing that value as an input to the Gradient
node.
```
W --> Conv --> H --> Gemm --> Y
^ ^
| |
X Z
Z_1 (2nd input of Gradient)
|
v
H_1 --> Gradient(xs=["H", "Z"], y="Y") ---> dY/dH when H = H_1 and Y = Y_1.
|
'------------------------------> dY/dZ (2nd output of Gradient)
```
When the inputs of Gradient are the tensors named in "xs" and "zs", the
computation can be optimized. More specifically, intermediate variables in
forward pass can be reused if the gradient is computed via reverse-mode
auto-differentiation.
)DOC";
ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(
Gradient,
1,
OpSchema()
.SetDoc(Gradient_ver1_doc)
.Input(
0,
"Inputs",
"The values fed into graph identified by the attributes. "
"The i-th input is the value of the i-th tensor specified in the "
"concatenated list of the attribute \"xs\" and the attribute "
" \"zs\". For example, if xs=[\"A\", \"B\"] and zs=[\"C\"], the "
"first input is used as the value of symbol \"A\" and the 3rd "
"input is substituted for all the occurrences of \"C\".",
"T1",
OpSchema::Variadic,
false)
.Output(
0,
"Outputs",
"The gradient of the tensor specified by the attribute \"y\" "
"with respect to each of tensors specified in the "
"attribute \"xs\". The i-th output is the gradient of \"y\" with "
"respect to the i-th tensor specified in the attribute \"xs\".",
"T2",
OpSchema::Variadic,
false)
.Attr(
"xs",
"Input tensor names of the differentiated sub-graph. It "
"contains only the necessary differentiated "
"inputs of a (sub-)graph. Variables (usually called "
"intermediate variables) that can be generated from inputs "
"cannot be included in this attribute.",
AttributeProto::STRINGS)
.Attr(
"zs",
"Input tensor names of the differentiated sub-graph. It "
"contains only the necessary non-differentiated "
"inputs of a (sub-)graph. Variables (usually called "
"intermediate variables) that can be generated from inputs "
"cannot be included in this attribute.",
AttributeProto::STRINGS,
OPTIONAL_VALUE)
.Attr(
"y",
"The targeted tensor. It can be viewed as the output of the "
"differentiated function. The attribute \"xs\" and attribute "
"\"zs\" are the minimal independent variable set that determines "
"the value of \"y\".",
AttributeProto::STRING)
.TypeConstraint("T1", OpSchema::all_tensor_types(), "Allow outputs to be any kind of tensor.")
.TypeConstraint(
"T2",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Allow inputs to be any kind of floating-point tensor."));
static const char* Adagrad_ver1_doc = R"DOC(
Compute one iteration of ADAGRAD, a stochastic gradient based optimization
algorithm. This operator can conduct the optimization of multiple tensor variables.
Let's define the behavior of this operator. As you can imagine, ADAGRAD requires
some parameters:
- The initial learning-rate "R".
- The update count "T". That is, the number of training iterations conducted.
- A L2-norm regularization coefficient "norm_coefficient".
- A learning-rate decay factor "decay_factor".
- A small constant "epsilon" to avoid dividing-by-zero.
At each ADAGRAD iteration, the optimized tensors are moved along a direction
computed based on their estimated gradient and accumulated squared gradient. Assume
that only a single tensor "X" is updated by this operator. We need the value of "X",
its gradient "G", and its accumulated squared gradient "H". Therefore, variables in
this operator's input list are sequentially "R", "T", "X", "G", and "H". Other
parameters are given as attributes because they are usually constants. Also, the
corresponding output tensors are the new value of "X" (called "X_new"), and then
the new accumulated squared gradient (called "H_new"). Those outputs are computed
from the given inputs following the pseudo code below.
Let "+", "-", "*", and "/" are all element-wise arithmetic operations with
numpy-style broadcasting support. The pseudo code to compute those outputs is:
// Compute a scalar learning-rate factor. At the first update of X, T is generally
// 0 (0-based update index) or 1 (1-based update index).
r = R / (1 + T * decay_factor);
// Add gradient of 0.5 * norm_coefficient * ||X||_2^2, where ||X||_2 is the 2-norm.
G_regularized = norm_coefficient * X + G;
// Compute new accumulated squared gradient.
H_new = H + G_regularized * G_regularized;
// Compute the adaptive part of per-coordinate learning rate. Note that Sqrt(...)
// computes element-wise square-root.
H_adaptive = Sqrt(H_new) + epsilon
// Compute the new value of "X".
X_new = X - r * G_regularized / H_adaptive;
If one assign this operators to optimize multiple inputs, for example, "X_1" and "X_2", the same
pseudo code may be extended to handle all tensors jointly. More specifically, we can view "X" as a
concatenation of "X_1" and "X_2" (of course, their gradient and accumulate gradient should
be concatenated too) and then just reuse the entire pseudo code.
Note that ADAGRAD was first proposed in http://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.
In that reference paper, this operator is a special case of the Figure 1's composite mirror
descent update.
)DOC";
ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(
Adagrad,
1,
OpSchema()
.SetDoc(Adagrad_ver1_doc)
.Input(0, "R", "The initial learning rate.", "T1")
.Input(1, "T", "The update count of \"X\". It should be a scalar.", "T2")
.Input(
2,
"inputs",
"The current values of optimized tensors, followed by their "
"respective gradients, followed by their respective accumulated squared gradients."
"For example, if two tensor \"X_1\" and \"X_2\" "
"are optimized, "
"The input list would be "
"[\"X_1\", \"X_2\", "
"gradient of \"X_1\", "
"gradient of \"X_2\", "
"accumulated squared gradient of \"X_1\", "
"accumulated squared gradient of \"X_2\"].",
"T3",
OpSchema::Variadic,
false)
.Output(
0,
"outputs",
"Updated values of optimized tensors, followed by their updated "
"values of accumulated squared gradients. For example, "
"if two tensor \"X_1\" and \"X_2\" are "
"optimized, the output list would be [new value of \"X_1,\" new value of \"X_2\" "
"new accumulated squared gradient of \"X_1\", new accumulated squared gradient of \"X_2\"].",
"T3",
OpSchema::Variadic,
false)
.Attr("epsilon", "Small scalar to avoid dividing by zero.", AttributeProto::FLOAT, 1e-6f)
.Attr(
"decay_factor",
"The decay factor of learning rate after one update."
"The effective learning rate is computed by r = R / (1 + T * decay_factor). "
"Default to 0 so that increasing update counts doesn't reduce the learning rate.",
AttributeProto::FLOAT,
0.0f)
.Attr(
"norm_coefficient",
"Regularization coefficient in 0.5 * norm_coefficient * ||X||_2^2. Default to 0, "
"which means no regularization.",
AttributeProto::FLOAT,
0.0f)
.TypeConstraint("T1", {"tensor(float)", "tensor(double)"}, "Constrain input types to float scalars.")
.TypeConstraint("T2", {"tensor(int64)"}, "Constrain input types to 64-bit integer scalars.")
.TypeConstraint("T3", {"tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// In comments below, we assume that the input list is
// [R, T, X1, X2, G1, G2, H1, H2] and the output list is
// [X1_new, X2_new, H1_new, H2_new].
// Compute the number of tuples (X, G, H).
auto num_optimized_tensors = (ctx.getNumInputs() - 2) / 3;
for (size_t i = 0; i < num_optimized_tensors; ++i) {
// Pass X1's and X2's shapes to X1_new and X2_new, respectively.
size_t i_in = 2 + i;
size_t i_out = i;
propagateElemTypeFromInputToOutput(ctx, i_in, i_out);
propagateShapeFromInputToOutput(ctx, i_in, i_out);
// Pass H1's and H2's shapes to H1_new and H2_new, respectively.
i_in = 2 + 2 * num_optimized_tensors + i;
i_out = i + num_optimized_tensors;
propagateElemTypeFromInputToOutput(ctx, i_in, i_out);
propagateShapeFromInputToOutput(ctx, i_in, i_out);
}
}));
static const char* Momentum_ver1_doc = R"DOC(
Compute one iteration of stochastic gradient update with momentum.
This operator can conduct the optimization of multiple tensor variables.
Let's define the behavior of this operator. As you can imagine, SG with momentum requires
several parameters:
- The learning-rate "R".
- The update count "T". That is, the number of conducted training iterations. It should
be zero in the first training iteration.
- A L2-norm regularization coefficient "norm_coefficient".
- A decay coefficient of previous accumulated gradient (i.e., momentum) "alpha".
- The scaling coefficient of current gradient "beta".
- An attribute to choose either standard momentum or Nesterov's momentum "mode" should
be used.
For the sake of simplicity, assume that there is only one tensor (called "X") to be optimized.
Other necessary inputs are "X"'s gradient (called "G") and "X"'s momentum (called "V"). This
Momentum operator maps all these inputs to the new value of "X" (called "X_new") and its new
momentum (called "V_new").
This operator supports two different momentum algorithms. Set the attribute "mode" to
"nesterov" if Nesterov's momentum is desired. Otherwise, set the attribute "model" to
"standard" to use standard momentum. Computation details are described subsequently.
Let "+", "-", "*", and "/" are all element-wise operations with numpy-style broadcasting.
Pseudo code for SG with standard momentum:
// Add gradient of 0.5 * norm_coefficient * ||X||^2, where ||X|| is the sum of squared
// values of all elements in X.
G_regularized = norm_coefficient * X + G
// In the first training iteration, beta should always be 1.
beta_adjusted = T > 0 ? beta : 1
// Compute the current momentum based on previous momentum and the current gradient.
V_new = alpha * V + beta_adjusted * G_regularized
// Update X.
X_new = X - R * V_new
Pseudo code for SG with Nesterov's momentum:
// Add gradient of 0.5 * norm_coefficient * ||X||^2, where ||X|| is the sum of squared
// values of all elements in X.
G_regularized = norm_coefficient * X + G;
// In the first training iteration, beta should always be 1.
beta_adjusted = T > 0 ? beta : 1
// Compute the current momentum based on previous momentum and the current gradient.
V_new = alpha * V + beta_adjusted * G_regularized;
// Compute final update direction and then update X.
X_new = X - R * (G_regularized + alpha * V_new)
If one assign this operators to optimize multiple inputs, for example, "X_1" and "X_2". The same
pseudo code would be extended to handle all tensors jointly. More specifically, we can view "X" as a
concatenation of "X_1" and "X_2" (of course, their gradient and accumulate gradient should
be concatenated too) and then our pseudo code becomes applicable.
)DOC";
ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(
Momentum,
1,
OpSchema()
.SetDoc(Momentum_ver1_doc)
.Input(0, "R", "The learning rate.", "T1")
.Input(1, "T", "Update count of \"X\". It should be a scalar.", "T2")
.Input(
2,
"inputs",
"It sequentially contains the current values of optimized tensors, then their "
"gradient tensors, and finally their momentum tensors. For example, if two tensors "
"\"X_1\" and \"X_2\" are optimized, The expected input list would be "
"[\"X_1\", \"X_2\", gradient of \"X_1\", gradient of \"X_2\", momentum of \"X_1\", momentum of \"X_2\"].",
"T3",
OpSchema::Variadic,
false)
.Output(
0,
"outputs",
"It sequentially contains the new values of optimized tensors and then the new "
"values of their momentum tensors. For example, if two tensors \"X_1\" and \"X_2\" are "
"optimized, the output list would be [new value of \"X_1,\" new value of \"X_2\" "
"new momentum of \"X_1\", new momentum of \"X_2\"].",
"T3",
OpSchema::Variadic,
false)
.Attr("alpha", "The decay factor of momentum. It should be a scalar.", AttributeProto::FLOAT)
.Attr(
"beta",
"The coefficient of gradient in computing new momentum. It should be a scalar.",
AttributeProto::FLOAT)
.Attr("norm_coefficient", "Coefficient of 0.5 * norm_coefficient * ||X||^2.", AttributeProto::FLOAT)
.Attr(
"mode",
"Its value should be either \"nesterov\" or \"standard\". The value \"nesterov\" leads "
"to the use of Nesterov's momentum while \"standard\" invokes stochastic gradient method "
"using standard momentum",
AttributeProto::STRING)
.TypeConstraint("T1", {"tensor(float)", "tensor(double)"}, "Constrain input types to float scalars.")
.TypeConstraint("T2", {"tensor(int64)"}, "Constrain input types to 64-bit integer scalars.")
.TypeConstraint("T3", {"tensor(float)", "tensor(double)"}, "Constrain input types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// Assume that the input list is [R, T, X1, X2, G1, G2, V1, V2] and
// output list is [X1_new, X2_new, V1_new, V2_new] for explaining
// the code below in a simpler way.
// The count of input tensors excluding "R" and "T".
auto num_adjustable_tensors = ctx.getNumInputs() - 2;
// Check number of (optimized tensor, gradient, momentum) tuples.
if (num_adjustable_tensors % 3 != 0) {
fail_shape_inference(
"The sum of optimized tensor count and momentum tensor count ",
"should be a multiple of 2 in the input list of Momentum operator");
}
// The count of "X1" and "X2".
auto num_optimized_tensors = num_adjustable_tensors / 3;
for (size_t i = 0; i < num_optimized_tensors; ++i) {
// Pass X1's/X2's shapes to X1_new/X2_new.
size_t i_in = 2 + i;
size_t i_out = i;
propagateElemTypeFromInputToOutput(ctx, i_in, i_out);
propagateShapeFromInputToOutput(ctx, i_in, i_out);
// Pass V1's/V2's shapes to V1_new/V2_new.
i_in = 2 + 2 * num_optimized_tensors + i;
i_out = i + num_optimized_tensors;
propagateElemTypeFromInputToOutput(ctx, i_in, i_out);
propagateShapeFromInputToOutput(ctx, i_in, i_out);
}
}));
static const char* Adam_ver1_doc = R"DOC(
Compute one iteration of Adam, a stochastic gradient based optimization
algorithm. This operator can conduct the optimization of multiple tensor variables.
Let's define the behavior of this operator. First of all, Adam requires
some parameters:
- The learning-rate "R".
- The update count "T". That is, the number of training iterations conducted.
- A L2-norm regularization coefficient "norm_coefficient".
- A small constant "epsilon" to avoid dividing-by-zero.
- Two coefficients, "alpha" and "beta".
At each Adam iteration, the optimized tensors are moved along a direction
computed based on their exponentially-averaged historical gradient and
exponentially-averaged historical squared gradient. Assume that only a tensor
"X" is being optimized. The rest of required information is
- the value of "X",
- "X"'s gradient (denoted by "G"),
- "X"'s exponentially-averaged historical gradient (denoted by "V"), and
- "X"'s exponentially-averaged historical squared gradient (denoted by "H").
Some of those parameters are passed into this operator as input tensors and others
are stored as this operator's attributes. Specifically, this operator's input tensor
list is ["R", "T", "X", "G", "V", "H"]. That is, "R" is the first input, "T" is
the second input, and so on. Other parameters are given as attributes because they
are constants. Moreover, the corresponding output tensors are
- the new value of "X" (called "X_new"),
- the new exponentially-averaged historical gradient (denoted by "V_new"), and
- the new exponentially-averaged historical squared gradient (denoted by "H_new").
Those outputs are computed following the pseudo code below.
Let "+", "-", "*", and "/" are all element-wise arithmetic operations with
numpy-style broadcasting support. The pseudo code to compute those outputs is:
// Add gradient of 0.5 * norm_coefficient * ||X||_2^2, where ||X||_2 is the 2-norm.
G_regularized = norm_coefficient * X + G
// Update exponentially-averaged historical gradient.
V_new = alpha * V + (1 - alpha) * G_regularized
// Update exponentially-averaged historical squared gradient.
H_new = beta * H + (1 - beta) * G_regularized * G_regularized
// Compute the element-wise square-root of H_new. V_new will be element-wisely
// divided by H_sqrt for a better update direction.
H_sqrt = Sqrt(H_new) + epsilon
// Compute learning-rate. Note that "alpha**T"/"beta**T" is alpha's/beta's T-th power.
R_adjusted = T > 0 ? R * Sqrt(1 - beta**T) / (1 - alpha**T) : R
// Compute new value of "X".
X_new = X - R_adjusted * V_new / H_sqrt
// Post-update regularization.
X_final = (1 - norm_coefficient_post) * X_new
If there are multiple inputs to be optimized, the pseudo code will be applied
independently to each of them.
)DOC";
ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(
Adam,
1,
OpSchema()
.SetDoc(Adam_ver1_doc)
.Input(0, "R", "The initial learning rate.", "T1")
.Input(1, "T", "The update count of \"X\". It should be a scalar.", "T2")
.Input(
2,
"inputs",
"The tensors to be optimized, followed by their respective gradients, "
"followed by their respective accumulated gradients (aka momentum), "
"followed by their respective accumulated squared gradients. For example, "
"to optimize tensors \"X_1\" and \"X_2,\", the input list would be "
"[\"X_1\", \"X_2\", "
"gradient of \"X_1\", gradient of \"X_2\", "
"accumulated gradient of \"X_1\", accumulated gradient of \"X_2\", "
"accumulated squared gradient of \"X_1\", accumulated squared gradient of \"X_2\"].",
"T3",
OpSchema::Variadic,
false)
.Output(
0,
"outputs",
"New values of optimized tensors, "
"followed by their respective new accumulated gradients, "
"followed by their respective new accumulated squared gradients. "
"For example, if two tensors \"X_1\" and \"X_2\" are optimized, "
"the outputs list would be "
"[new value of \"X_1\", new value of \"X_2\", "
"new accumulated gradient of \"X_1\", "
"new accumulated gradient of \"X_2\", "
"new accumulated squared gradient of \"X_1\", "
"new accumulated squared gradient of \"X_2\"].",
"T3",
OpSchema::Variadic,
false)
.Attr(
"alpha",
"Coefficient of previously accumulated gradient in running average. Default to 0.9.",
AttributeProto::FLOAT,
0.9f)
.Attr(
"beta",
"Coefficient of previously accumulated squared-gradient in running average. Default to 0.999.",
AttributeProto::FLOAT,
0.999f)
.Attr(
"norm_coefficient",
"Regularization coefficient of 0.5 * norm_coefficient * ||X||_2^2. Default to 0, "
"which means no regularization.",
AttributeProto::FLOAT,
0.0f)
.Attr(
"norm_coefficient_post",
"Regularization coefficient of 0.5 * norm_coefficient * ||X||_2^2. Default to 0, "
"which means no regularization.",
AttributeProto::FLOAT,
0.0f)
.Attr("epsilon", "Small scalar to avoid dividing by zero.", AttributeProto::FLOAT, 1e-6f)
.TypeConstraint("T1", {"tensor(float)", "tensor(double)"}, "Constrain input types to float scalars.")
.TypeConstraint("T2", {"tensor(int64)"}, "Constrain input types to 64-bit integer scalars.")
.TypeConstraint("T3", {"tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// Assume that the input list is [R, T, X1, X2, G1, G2, V1, V2, H1, H2] and
// output list is [X1_new, X2_new, V1_new, V2_new, H1_new, H2_new] for explaining
// the code below in a simpler way.
// The count of input tensors excluding "R" and "T".
auto num_adjustable_tensors = ctx.getNumInputs() - 2;
// Check number of (optimized tensor, gradient, momentum) tuples.
if (num_adjustable_tensors % 4 != 0) {
fail_shape_inference(
"The sum of optimized tensor count, gradient tensor count, momentum tensor count, ",
"accumulated squared-gradient tensor count should be a multiple of 4 in the ",
"\"inputs\" of Adam operator.");
}
// The count of "X1" and "X2".
auto num_optimized_tensors = num_adjustable_tensors / 4;
for (size_t i = 0; i < num_optimized_tensors; ++i) {
// Pass X1's/X2's shapes to X1_new/X2_new.
size_t i_in = 2 + i;
size_t i_out = i;
propagateElemTypeFromInputToOutput(ctx, i_in, i_out);
propagateShapeFromInputToOutput(ctx, i_in, i_out);
// Pass V1's/V2's shapes to V1_new/V2_new.
i_in = 2 + 2 * num_optimized_tensors + i;
i_out = num_optimized_tensors + i;
propagateElemTypeFromInputToOutput(ctx, i_in, i_out);
propagateShapeFromInputToOutput(ctx, i_in, i_out);
// Pass H1's/H2's shapes to H1_new/H2_new.
i_in = 2 + 3 * num_optimized_tensors + i;
i_out = 2 * num_optimized_tensors + i;
propagateElemTypeFromInputToOutput(ctx, i_in, i_out);
propagateShapeFromInputToOutput(ctx, i_in, i_out);
}
}));
} // namespace ONNX_NAMESPACE