// Copyright (c) ONNX Project Contributors

/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include <assert.h>

#include "onnx/defs/controlflow/utils.h"
#include "onnx/defs/schema.h"

namespace ONNX_NAMESPACE {
using SupportType = OpSchema::SupportType;

static std::vector<std::string> control_flow_types_ir10() {
  auto t = OpSchema::all_tensor_types_ir10();
  auto s = OpSchema::all_tensor_sequence_types_ir10();
  auto o = OpSchema::all_optional_types_ir10();
  t.insert(t.end(), s.begin(), s.end());
  t.insert(t.end(), o.begin(), o.end());
  return t;
}
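
// Illustrative note: the list built by control_flow_types_ir10() above is the
// union used for the "V" type constraint of the If and Loop schemas below; it
// contains entries such as "tensor(float)", "seq(tensor(float))", and
// "optional(tensor(float))" for the element types defined up to IR version 10.
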
ONNX_OPERATOR_SET_SCHEMA(
    If,
    21,
    OpSchema()
        .SetDoc("If conditional")
        .Input(0, "cond", "Condition for the if. The tensor must contain a single element.", "B")
        .Output(
            0,
            "outputs",
            "Values that are live-out to the enclosing scope. The return values in "
            "the `then_branch` and `else_branch` must be of the same data type. "
            "The `then_branch` and `else_branch` may produce tensors with the same "
            "element type and different shapes. "
            "If corresponding outputs from the then-branch and the else-branch have "
            "static shapes S1 and S2, then the shape of the corresponding output "
            "variable of the if-node (if present) must be compatible with both S1 "
            "and S2 as it represents the union of both possible shapes. "
            "For example, if in a model file, the first "
            "output of `then_branch` is a float tensor with shape [2] and the "
            "first output of `else_branch` is another float tensor with shape [3], "
            "If's first output should have (a) no shape set, or (b) "
            "a shape of rank 1 with neither `dim_value` nor `dim_param` set, or (c) "
            "a shape of rank 1 with a unique `dim_param`. "
            "In contrast, the first output cannot have the shape [2] since [2] and "
            "[3] are not compatible.",
            "V",
            OpSchema::Variadic,
            false)
        .Attr(
            "then_branch",
            "Graph to run if condition is true. Has N outputs: values you wish to "
            "be live-out to the enclosing scope. The number of outputs must match"
            " the number of outputs in the else_branch.",
            AttributeProto::GRAPH)
        .Attr(
            "else_branch",
            "Graph to run if condition is false. Has N outputs: values you wish to"
            " be live-out to the enclosing scope. The number of outputs must match"
            " the number of outputs in the then_branch.",
            AttributeProto::GRAPH)
        .TypeConstraint(
            "V",
            control_flow_types_ir10(),
            "All Tensor, Sequence(Tensor), Optional(Tensor), and Optional(Sequence(Tensor)) types up to IRv10.")
        .TypeConstraint("B", {"tensor(bool)"}, "Only bool")
        .TypeAndShapeInferenceFunction(IfInferenceFunction));
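
// Illustrative sketch (hypothetical graph/value names): an If node in the same
// textual notation used by the Loop documentation below. Both branches must
// produce the same number of outputs with matching data types:
//
//   %y = If[then_branch = <graph then-net>, else_branch = <graph else-net>](%cond)
//
//   graph then-net () { %t = ... /* float tensor of shape [2] */ return %t }
//   graph else-net () { %e = ... /* float tensor of shape [3] */ return %e }
//
// Per the Output documentation above, %y should then be declared with no shape,
// or with a rank-1 shape whose dimension is unset or a unique dim_param, since
// [2] and [3] agree only on rank.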

static const char* Loop_ver21_doc = R"DOC(
Generic Looping construct. This loop has multiple termination conditions:

1) Trip count. Iteration count specified at runtime. Set by
   specifying the input M. Optional. Set to empty string to omit.
   Note that a static trip count (specified at graph construction time) can be
   specified by passing in a constant node for input M.
2) Loop termination condition. This is an input to the op that determines
   whether to run the first iteration and also a loop-carried dependency for
   the body graph. The body graph must yield a value for the condition variable,
   whether this input is provided or not.

This table summarizes the operating modes of this operator with equivalent
C-style code:

Operator inputs defined as (max_trip_count, condition_var).

* input ("", ""):
    for (int i=0; ; ++i) {
      cond = ... // Note this value is ignored, but is required in the body
    }

* input ("", cond) // Note this is analogous to a while loop
    bool cond = ...;
    for (int i=0; cond; ++i) {
      cond = ...;
    }

* input ("", 1) // Note this is analogous to a do-while loop
    bool cond = true;
    for (int i=0; cond; ++i) {
      cond = ...;
    }

* input (trip_count, "") // Note this is analogous to a for loop
    int trip_count = ...;
    for (int i=0; i < trip_count; ++i) {
      cond = ...; // ignored
    }

* input (trip_count, cond)
    int trip_count = ...;
    bool cond = ...;
    for (int i=0; i < trip_count && cond; ++i) {
      cond = ...;
    }


*Sample usage - cond as well as trip count*

    graph predict-net {
      %a = Constant[value = <Scalar Tensor [3]>]()
      %b = Constant[value = <Scalar Tensor [6]>]()
      %keepgoing = Constant[value = <Scalar Tensor [1]>]()
      %max_trip_count = Constant[value = <Scalar Tensor [10]>]()
      %keepgoing_out, %b_out, %user_defined_vals = Loop[body = <graph body-net>](%max_trip_count, %keepgoing, %b)
      return
    }

    graph body-net (
      %i[INT32, scalar]           // iteration number
      %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used
      %b_in[INT32, scalar]        // incoming value of loop-carried-dependency b
    ) {
      %my_local = Add(%a, %b_in)
      %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b
      %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition
      %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated
      return %keepgoing_out, %b_out, %user_defined_val
    }

*Sample equivalent C code*

    {
      /* User-defined code (enclosing scope) */
      int a = 3, b = 6;
      bool keepgoing = true; // Analogous to input cond
      /* End user-defined code */

      /* Implicitly-defined code */
      const int max_trip_count = 10; // Analogous to input M
      int user_defined_vals[]; // Imagine this is resizable
      /* End implicitly-defined code */
      /* initialize loop-carried variables and scan-output variables */
      bool keepgoing_out = keepgoing;
      int b_out = b;

      for (int i=0; i < max_trip_count && keepgoing_out; ++i) {
        /* Implicitly-defined code: bind actual parameter values
           to formal parameter variables of loop-body */
        bool keepgoing_in = keepgoing_out;
        int b_in = b_out;

        /* User-defined code (loop body) */
        int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine
        b_out = a - b_in;
        keepgoing_out = my_local > b_out;
        int user_defined_val = b_in + b_in; // b_in and b_out are different variables
        /* End user-defined code */

        /* Implicitly-defined code */
        user_defined_vals[i] = user_defined_val; // accumulate scan-output values
      }
      // int t = my_local; // Can't do this. my_local is not accessible here.

      // The values below are bound to the output variables of the loop and therefore accessible
      // b_out; user_defined_vals; keepgoing_out;
    }

There are several things of note in this code snippet:

1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can
   be referenced in the inputs of the loop.
2) Any values computed in the loop body that need to be used in a subsequent
   iteration or after the loop are modelled using a pair of variables in the loop-body,
   consisting of an input variable (e.g., b_in) and an output variable (e.g., b_out).
   These are referred to as loop-carried dependencies. The loop operation node
   supplies the input value of the input variable for the first iteration, and
   returns the output value of the output variable produced by the final
   iteration.
3) Scan_output variables are used to implicitly concatenate values computed across
   all the iterations. In the above example, the values of user_defined_val computed
   over all iterations are concatenated and returned as the value of user_defined_vals
   after the loop.
4) Values created in the body cannot be accessed in the enclosing scope,
   except using the mechanism described above.

Note that the semantics of this op support "diagonal" or "wavefront" execution.
(See Step 3 here for an example:
https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/).
Frontends should emit multi-layer RNNs as a series of While operators (with
time being the inner looping dimension), with each successive layer consuming
the scan_outputs from the previous layer, possibly going through several
point-wise operators (e.g. dropout, residual connections, linear layer).

Matching of the subgraph inputs/outputs (produced by the Loop node) is based on
order rather than name; the implementation determines the names from this order.
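
For example (illustrative): the body graph's first formal input is always bound
to the iteration number and its second to the condition, even if they are named,
say, %iter and %cont; only their positions matter.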
)DOC";

ONNX_OPERATOR_SET_SCHEMA(
    Loop,
    21,
    OpSchema()
        .SetDoc(Loop_ver21_doc)
        .Input(
            0,
            "M",
            "A maximum trip-count for the loop specified at runtime. Optional."
            " Pass empty string to skip.",
            "I",
            OpSchema::Optional)
        .Input(
            1,
            "cond",
            "A boolean termination condition. Optional. Pass empty string to skip.",
            "B",
            OpSchema::Optional)
        .Input(
            2,
            "v_initial",
            "The initial values of any loop-carried dependencies (values that "
            "change across loop iterations)",
            "V",
            OpSchema::Variadic,
            false,
            0)
        .Output(
            0,
            "v_final_and_scan_outputs",
            "Final N loop carried dependency values then K scan_outputs. "
            "Scan outputs must be Tensors.",
            "V",
            OpSchema::Variadic,
            false)
        .Attr(
            "body",
            "The graph run each iteration. It has 2+N inputs: (iteration_num, "
            "condition, loop carried dependencies...). It has 1+N+K outputs: "
            "(condition, loop carried dependencies..., scan_outputs...). Each "
            "scan_output is created by concatenating the value of the specified "
            "output value at the end of each iteration of the loop. It is an error"
            " if the dimensions or data type of these scan_outputs change across loop"
            " iterations.",
            AttributeProto::GRAPH)
        .TypeConstraint(
            "V",
            control_flow_types_ir10(),
            "All Tensor, Sequence(Tensor), Optional(Tensor), and Optional(Sequence(Tensor)) types up to IRv10.")
        .TypeConstraint("I", {"tensor(int64)"}, "tensor of int64, which should be a scalar.")
        .TypeConstraint("B", {"tensor(bool)"}, "tensor of bool, which should be a scalar.")
        .TypeAndShapeInferenceFunction(LoopInferenceFunction));
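
// Illustrative sketch (names are hypothetical; not used by the registration
// above): assembling a Loop node with the generated protobuf API. Only the
// op_type, the positional input/output convention, and the "body" GRAPH
// attribute are prescribed by the schema.
//
//   NodeProto loop_node;
//   loop_node.set_op_type("Loop");
//   loop_node.add_input("max_trip_count"); // M: tensor(int64) scalar; "" to omit
//   loop_node.add_input("keepgoing");      // cond: tensor(bool) scalar; "" to omit
//   loop_node.add_input("b");              // v_initial: first loop-carried value
//   loop_node.add_output("b_final");       // N final loop-carried values first...
//   loop_node.add_output("scan_out");      // ...then K scan-outputs
//   AttributeProto* body = loop_node.add_attribute();
//   body->set_name("body");
//   body->set_type(AttributeProto::GRAPH);
//   *body->mutable_g() = body_graph; // a GraphProto built elsewhere (hypothetical)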

static const char* scan_21_doc = R"DOC(
Scan can be used to iterate over one or more scan_input tensors,
constructing zero or more scan_output tensors. It combines ideas from general recurrences,
functional programming constructs such as scan, fold, map, and zip, and is intended to enable
generalizations of RNN-like constructs for sequence-to-sequence processing.
Other tensors (referred to as state_variables here) can be used to carry a state
when iterating from one element to another (similar to hidden-state in RNNs, also referred
to as loop-carried dependencies in the context of loops).
Many common usages involve a single scan_input tensor (where functionality
similar to scan, fold and map can be obtained). When more than one scan_input is used,
a behavior similar to zip is obtained.

The attribute body must be a graph, specifying the computation to be performed in
every iteration. It takes as input the current values of the state_variables and
the current iterated element of the scan_inputs. It must return the (updated) values
of the state_variables and zero or more scan_output_element tensors. The values of the
scan_output_element tensors are concatenated over all the iterations to produce the
scan_output values of the scan construct (similar to the concatenated intermediate
hidden-state values of RNN-like constructs). All the output tensors (state_variables as
well as scan_output_element tensors) are required to have the same shape in each iteration
of the loop (a restriction imposed to enable efficient memory allocation).

Note that the iterated element passed to the body subgraph does not have a sequence
axis. It will have a rank one less than the rank of the corresponding scan_input.

The scan operation returns the final values of the state_variables as well as the
scan_outputs.

The optional attribute scan_input_directions specifies the direction (forward or backward)
for each scan input. If this attribute is omitted, all sequences are scanned in the forward
direction. A bidirectional scan may be performed by specifying the same tensor input twice
in the scan_inputs, once with a forward direction, and once with a backward direction.
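
For example (illustrative): with num_scan_inputs = 2 and
scan_input_directions = [0, 1], passing the same tensor twice as the two
scan_inputs lets the body see its elements in the forward direction through the
first scan-input and in the backward direction through the second.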

The scan_output of the operation is produced by concatenating the scan_output_element
values produced by the body in each iteration. The optional attribute scan_output_directions
specifies the direction in which scan_output is constructed (by appending or prepending the
scan_output_element to scan_output in each iteration) for each scan_output. If this attribute
is omitted, the scan_output_element is appended to the scan_output in each iteration.
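
For example (illustrative): with scan_output_directions = [1], the element
produced in iteration t is prepended, so the finished scan_output holds the
iteration results in reverse order.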

The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input.
If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the
batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1.
Note that scanning a non-zero axis may be less efficient than scanning axis zero.
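
As an illustration: a scan_input of shape [B, T, D] scanned with
scan_input_axes = [1] yields T iterations, and in each iteration the body
receives a scan-input element of shape [B, D] (rank one less, with the scanned
axis removed).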

The optional attribute scan_output_axes specifies the axis along which the scan_outputs
are accumulated for each scan_output. For example, if axis 1 is the time axis (to be
scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis
value of 1.

Note that because of the ONNX restriction that only the last parameter of an operator can
be variadic, the initial-states and scan-inputs are listed together as one input parameter.
Similarly, the final-states and scan-outputs are listed together as one output parameter.
The attribute num_scan_inputs indicates the number M of scan-inputs.
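
For example (illustrative): with N = 2 state variables and M = 1 scan-input,
the single variadic input parameter carries (init_1, init_2, scan_1) and
num_scan_inputs is set to 1; the output parameter likewise carries the two final
states followed by the K scan-outputs.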

The behavior of

    Scan <
        num_scan_inputs = m,
        body = loop-body,
        scan_input_axes = [axis_1, ..., axis_m]
    > (init_1, ..., init_n, scan_1, ..., scan_m)

is equivalent to the following pseudo-code:

    // scan_i.shape[axis_i] denotes the (max) sequence-length of scan_i
    // scan_i.shape[axis_i] is required to be equal to scan_j.shape[axis_j] for all i,j.
    sequence_length = scan_1.shape[axis_1];

    // initialize state-variables
    st_1 = init_1; ... st_n = init_n;
    // initialize scan-output variables: [] denotes an empty tensor
    scan_out_1 = []; ...; scan_out_k = [];

    // execute loop
    for (int t = 0; t < sequence_length; ++t) {
        // generate the scan-input elements: the notation T<axis=k>[t] indicates the sub-tensor
        // of rank one less than T obtained by indexing T at position t along axis k.
        si_1 = scan_1<axis=axis_1>[t];
        ... ;
        si_m = scan_m<axis=axis_m>[t];
        // execute loop-body
        st_1, ..., st_n, so_1, ..., so_k = loop-body(st_1, ..., st_n, si_1, ..., si_m)
        // accumulate the scan-output elements
        scan_out_1 = Concat<axis=0>(scan_out_1, so_1); ... ; scan_out_k = Concat<axis=0>(scan_out_k, so_k);
    }

    return st_1, ..., st_n, scan_out_1, ..., scan_out_k;

*Sample usage: Encoding RNN using a Scan*

The following example shows how a simple RNN over an input tensor %X, with weight tensor %Wi,
recurrence weight tensor %Ri, bias tensors %Wbi and %Rbi, and initial hidden-state %H_0 can
be encoded as a Scan loop. Note that the loop-body is a nested graph, and it directly computes
%Wi, %Ri, %Wbi, and %Rbi (typically constants or initializers in the body graph). If these
values are computed in the outer graph, they need to be passed in as extra state_variables.

    graph rnn-encoding {
      %H_0 = ...
      %X = ...
      %Y_h, %Y = Scan[body = <graph rnn-cell-1>, num_scan_inputs=1](%H_0, %X)
      return %Y, %Y_h
    }

    graph rnn-cell-1 (
      %H_tminus1[FLOAT, tensor]
      %X_t[FLOAT, tensor]
    ) {
      %Wi = ...
      %Ri = ...
      %Wbi = ...
      %Rbi = ...
      %t1 = %X_t * (%Wi^T)
      %t2 = %H_tminus1 * (%Ri^T)
      %t3 = Add(%t1, %t2)
      %t4 = Add(%t3, %Wbi)
      %t5 = Add(%t4, %Rbi)
      %Ht = Tanh(%t5)
      %Accumulate = Identity(%Ht)
      return %Ht, %Accumulate
    }

)DOC";

ONNX_OPERATOR_SET_SCHEMA(
    Scan,
    21,
    OpSchema()
        .SetDoc(scan_21_doc)
        .Input(
            0,
            "initial_state_and_scan_inputs",
            "Initial values of the loop's N state variables followed by M scan_inputs",
            "V",
            OpSchema::Variadic,
            false)
        .Output(
            0,
            "final_state_and_scan_outputs",
            "Final values of the loop's N state variables followed by K scan_outputs",
            "V",
            OpSchema::Variadic,
            false)
        .Attr(
            "body",
            "The graph run each iteration. It has N+M inputs: "
            "(loop state variables..., scan_input_elts...). It has N+K outputs: "
            "(loop state variables..., scan_output_elts...). Each "
            "scan_output is created by concatenating the value of the specified "
            "scan_output_elt value at the end of each iteration of the loop. It is an error"
            " if the dimensions of these values change across loop iterations.",
            AttributeProto::GRAPH,
            true)
        .Attr("num_scan_inputs", "An attribute specifying the number of scan_inputs M.", AttributeProto::INT, true)
        .Attr(
            "scan_input_directions",
            "An optional list of M flags. The i-th element of the list specifies the direction "
            "to be scanned for the i-th scan_input tensor: 0 indicates forward direction and 1 "
            "indicates reverse direction. "
            "If omitted, all scan_input tensors will be scanned in the forward direction.",
            AttributeProto::INTS,
            false)
        .Attr(
            "scan_output_directions",
            "An optional list of K flags, one for each scan_output. The i-th element of the list "
            "specifies whether the i-th scan_output should be constructed by appending or "
            "prepending a new value in each iteration: 0 indicates appending and 1 "
            "indicates prepending. "
            "If omitted, all scan_output tensors will be produced by appending a value "
            "in each iteration.",
            AttributeProto::INTS,
            false)
        .Attr(
            "scan_input_axes",
            "An optional list of M axes. The i-th element of the list specifies the axis "
            "to be scanned (the sequence axis) for the i-th scan_input. If omitted, 0 will "
            "be used as the scan axis for every scan_input. Negative value for an axis means "
            "counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(input).",
            AttributeProto::INTS,
            false)
        .Attr(
            "scan_output_axes",
            "An optional list of K axes. The i-th element of the list specifies the axis "
            "for the i-th scan_output. The scan outputs are accumulated along the specified "
            "axis. If omitted, 0 will be used as the scan axis for every scan_output. "
            "Negative value for an axis means counting dimensions from the back. Accepted "
            "range is [-r, r-1].",
            AttributeProto::INTS,
            false)
        .TypeConstraint("V", OpSchema::all_tensor_types_ir10(), "All Tensor types up to IRv10.")
        .TypeAndShapeInferenceFunction(ScanInferenceFunction)); // Shares same shape inference as opset 11

} // namespace ONNX_NAMESPACE