/*
 * SPDX-License-Identifier: Apache-2.0
 */
|
|
|
|
#include "onnx/defs/controlflow/utils.h"
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
namespace ONNX_NAMESPACE {
|
|
|
|
// Removes the shape from the tensor type found at the bottom of `input_type`.
// Sequence and optional wrappers are unwrapped through their element type for
// as long as one is set; any other kind of type is left untouched.
void ClearShape(TypeProto& input_type) {
  TypeProto* current = &input_type;

  while (current != nullptr) {
    if (current->has_tensor_type()) {
      current->mutable_tensor_type()->clear_shape();
      return;
    }

    if (current->has_sequence_type()) {
      auto& sequence = *current->mutable_sequence_type();
      current = sequence.has_elem_type() ? sequence.mutable_elem_type() : nullptr;
    } else if (current->has_optional_type()) {
      auto& optional = *current->mutable_optional_type();
      current = optional.has_elem_type() ? optional.mutable_elem_type() : nullptr;
    } else {
      // neither a tensor nor a wrapper we know how to look inside
      return;
    }
  }
}
|
|
|
|
// Type/shape inference for the If operator.
//
// Neither branch subgraph receives explicit inputs, so each one is inferred
// with empty type/data lists. Both branches must yield the same number of
// outputs as the If node itself; every node output is then initialised from
// the 'then_branch' result and combined with the 'else_branch' result via
// UnionTypeInfo.
void IfInferenceFunction(InferenceContext& ctx) {
  std::vector<const TypeProto*> no_input_types; // the branches take no inputs
  std::vector<const TensorProto*> no_input_data; // ... and therefore no input data

  // Infers the output types of one branch; yields an empty list when no
  // inferencer is available for that attribute.
  auto infer_branch = [&](const char* attribute_name) {
    std::vector<const TypeProto*> branch_output_types;
    GraphInferencer* inferencer = ctx.getGraphAttributeInferencer(attribute_name);
    if (inferencer) {
      branch_output_types = inferencer->doInferencing(no_input_types, no_input_data);
    }
    return branch_output_types;
  };

  const std::vector<const TypeProto*> then_types = infer_branch("then_branch");
  const std::vector<const TypeProto*> else_types = infer_branch("else_branch");

  const auto node_output_count = ctx.getNumOutputs();
  const auto then_count = then_types.size();
  const auto else_count = else_types.size();

  // both branches have to agree on how many values they produce
  if (then_count != else_count) {
    fail_type_inference(
        "then_branch and else_branch produce different number of outputs. ", then_count, " != ", else_count);
  }

  // ... and that number has to match the outputs declared on the node
  if (then_count != node_output_count) {
    fail_type_inference("If node has ", node_output_count, " but subgraphs produce ", then_count);
  }

  for (size_t idx = 0; idx < then_count; ++idx) {
    auto* node_output = ctx.getOutputType(idx);

    // start from the 'then' result, then fold in the 'else' result
    *node_output = *then_types[idx];
    UnionTypeInfo(*else_types[idx], *node_output);
  }
}
|
|
|
|
// Type/shape inference for the Loop operator.
//
// Node inputs are laid out as ['M', 'cond', loop state variables...]. The
// 'body' subgraph is inferred with the inputs
// [iteration number (INT64 tensor), 'cond', loop state variables with their
// shapes cleared]. The subgraph returns the condition first, followed by one
// entry per Loop output: the loop state variables come first and the remaining
// entries are per-iteration (scan) outputs.
//
// Fails type inference when the node has fewer than two inputs, when the
// subgraph returns an unexpected number of outputs, or when a subgraph output
// has an unsupported kind of type.
void LoopInferenceFunction(InferenceContext& ctx) {
  auto num_inputs = ctx.getNumInputs();
  // Checked explicitly instead of via assert(): assert is compiled out under
  // NDEBUG, and the unsigned 'num_inputs - 2' below would then wrap around.
  if (num_inputs < 2) {
    fail_type_inference("Loop node expects at least 2 inputs ('M' and 'cond') but has ", num_inputs);
  }
  auto num_loop_state_vars = num_inputs - 2; // skip 'M' and 'cond'

  std::vector<const TypeProto*> subgraph_input_types;
  subgraph_input_types.reserve(num_inputs);

  // Owns the shape-less copies of the loop state variable types.
  // subgraph_input_types stores pointers into this vector, so the capacity is
  // reserved up front to keep those pointers valid while elements are added.
  std::vector<TypeProto> temporary_type_protos;
  temporary_type_protos.reserve(num_loop_state_vars);

  // create TypeProto to validate iteration number type is the same as the
  // optional 'M' input for max iterations.
  TypeProto iter_num_type;
  iter_num_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
  subgraph_input_types.push_back(&iter_num_type);

  // 'cond'
  subgraph_input_types.push_back(ctx.getInputType(1));

  // loop state value types get propagated to outputs, but shape may change
  // across iterations so don't propagate it to the outputs and don't pass it
  // into the subgraph inferencing
  for (size_t i = 2; i < num_inputs; ++i) {
    propagateElemTypeFromInputToOutput(ctx, i, i - 2);

    // copy so we can remove the shape before passing to the subgraph
    // inferencing
    temporary_type_protos.push_back(*ctx.getInputType(i));
    auto& input_type = temporary_type_protos.back();

    ClearShape(input_type);
    subgraph_input_types.push_back(&input_type);
  }

  // Run inferencing on the subgraph
  std::vector<const TypeProto*> subgraph_output_types;

  GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
  if (graphInferencer) {
    std::vector<const TensorProto*> input_data;
    input_data.reserve(num_inputs);
    input_data.push_back(nullptr); // iteration number
    // 'cond' and the loop state variables, in the same order as the types
    for (size_t i = 1; i < num_inputs; ++i) {
      input_data.push_back(ctx.getInputData(i));
    }

    subgraph_output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
  }

  // if empty(), assume inferencing was skipped
  if (!subgraph_output_types.empty()) {
    auto num_outputs = ctx.getNumOutputs();

    // subgraph outputs the condition value first but that is only used
    // internally and not returned by Loop.
    if (subgraph_output_types.size() != num_outputs + 1) {
      fail_type_inference(
          "Graph attribute inferencing returned type information for ",
          subgraph_output_types.size(),
          " outputs. Expected ",
          num_outputs + 1);
    }

    // check loop state values match. we should already have type/shape info
    for (size_t i = 0; i < num_outputs; ++i) {
      auto* subgraph_output_type = subgraph_output_types[i + 1]; // skip 'cond'
      auto* loop_output_type = ctx.getOutputType(i);

      // outputs beyond the loop state variables are per-iteration scan outputs
      const bool is_loop_state_var = i < num_loop_state_vars;

      if (!subgraph_output_type->has_tensor_type() && !subgraph_output_type->has_sequence_type() &&
          !subgraph_output_type->has_optional_type()) {
        fail_type_inference(
            "Loop 'body' subgraph outputs should all be tensors or sequences or optionals, but output ",
            i,
            " was ",
            subgraph_output_type->value_case());
      }

      if (!is_loop_state_var && !subgraph_output_type->has_tensor_type()) {
        fail_type_inference(
            "Loop 'body' subgraph scan outputs should all be tensors but output ",
            i,
            " was ",
            subgraph_output_type->value_case());
      }

      // if there's an existing type check it matches. otherwise propagate
      propagateElemTypeWithValidation(subgraph_output_type, loop_output_type);

      if (is_loop_state_var) {
        // shape may change across iterations so ignore.
      } else {
        // propagate shape
        if (subgraph_output_type->tensor_type().has_shape()) {
          // per iteration output. first dimension will be number of iterations
          // but we don't know that value yet
          TypeProto inferred_type(*subgraph_output_type);
          auto* mutable_inferred_tensor_type = inferred_type.mutable_tensor_type();
          auto* mutable_inferred_shape = mutable_inferred_tensor_type->mutable_shape();

          mutable_inferred_shape->clear_dim();

          // add empty dimension for number of iterations
          mutable_inferred_shape->add_dim();

          // add dimensions from subgraph output shape
          for (const auto& dim : subgraph_output_type->tensor_type().shape().dim()) {
            (*mutable_inferred_shape->add_dim()) = dim;
          }

          mergeInShapeInfo(*mutable_inferred_tensor_type, *loop_output_type->mutable_tensor_type());
        }
      }
    }
  }
}
|
|
|
|
int handle_negative_axis_validate(const std::string& attrib, int axis, int rank) {
|
|
if (!(-rank <= axis && axis < rank)) {
|
|
fail_shape_inference(attrib, " axis value ", axis, " is invalid for a tensor of rank ", rank);
|
|
}
|
|
return (axis >= 0 ? axis : axis + rank);
|
|
}
|
|
|
|
// Type/shape inference for the Scan operator.
//
// Node inputs are laid out as [loop state variables..., scan inputs...], with
// the number of scan inputs given by the 'num_scan_inputs' attribute. Node
// outputs are laid out as [loop state variables..., scan outputs...].
//
// Loop state variable types/shapes are propagated 1:1 to the matching outputs
// and handed to the 'body' subgraph unchanged. Scan inputs are handed to the
// subgraph with the dimension selected by 'scan_input_axes' (default 0)
// removed; that dimension is merged into a shared sequence-length dimension,
// which is re-inserted into each scan output shape at the position given by
// 'scan_output_axes' (default 0).
//
// Fails inference when 'num_scan_inputs' is missing or inconsistent with the
// node's input/output counts, when an axes attribute has the wrong length,
// when an input or subgraph output is not a tensor, or when the subgraph
// returns an unexpected number of outputs.
void ScanInferenceFunction(InferenceContext& ctx) {
  auto num_inputs = ctx.getNumInputs();

  // Don't dereference a missing attribute.
  auto num_scan_inputs_attr = ctx.getAttribute("num_scan_inputs");
  if (num_scan_inputs_attr == nullptr) {
    fail_shape_inference("Scan node is missing the required attribute 'num_scan_inputs'.");
  }
  auto num_scan_inputs = narrow_cast<size_t>(num_scan_inputs_attr->i());

  // The counts below are unsigned; reject inconsistent values up front rather
  // than letting the subtractions wrap around.
  if (num_scan_inputs > num_inputs) {
    fail_shape_inference(
        "Number of scan inputs (", num_scan_inputs, ") exceeds the number of Scan node inputs (", num_inputs, ").");
  }
  auto num_loop_state_vars = num_inputs - num_scan_inputs;
  auto num_outputs = ctx.getNumOutputs();
  if (num_loop_state_vars > num_outputs) {
    fail_shape_inference(
        "Number of loop state variables (",
        num_loop_state_vars,
        ") exceeds the number of Scan node outputs (",
        num_outputs,
        ").");
  }
  auto num_scan_outputs = num_outputs - num_loop_state_vars;

  std::vector<int64_t> axes, output_axes;
  if (getRepeatedAttribute(ctx, "scan_input_axes", axes)) {
    if (axes.size() != num_scan_inputs) {
      fail_shape_inference(
          "Number of scan input axes specified (",
          axes.size(),
          ") is not equal to number of scan inputs (",
          num_scan_inputs,
          ").");
    }
  } else {
    // attribute not provided: scan every input along axis 0
    axes.insert(axes.end(), num_scan_inputs, 0);
  }

  if (getRepeatedAttribute(ctx, "scan_output_axes", output_axes)) {
    if (output_axes.size() != num_scan_outputs) {
      fail_shape_inference(
          "Number of scan output axes specified (",
          output_axes.size(),
          ") is not equal to number of scan outputs (",
          num_scan_outputs,
          ").");
    }
  } else {
    // attribute not provided: stack every scan output along axis 0
    output_axes.insert(output_axes.end(), num_scan_outputs, 0);
  }

  // Owns the scan input types that had their sequence dimension removed.
  // subgraph_input_types stores pointers into this vector, so the capacity is
  // reserved up front to keep those pointers valid while elements are added.
  std::vector<TypeProto> temporary_type_protos;
  temporary_type_protos.reserve(num_inputs);

  std::vector<const TypeProto*> subgraph_input_types;
  subgraph_input_types.reserve(num_inputs);

  // accumulates what is known about the sequence length across scan inputs
  TensorShapeProto_Dimension sequence_len_dim;

  for (size_t i = 0; i < num_inputs; ++i) {
    bool is_loop_state_var = i < num_loop_state_vars;
    bool has_shape = hasInputShape(ctx, i);
    const auto* input_type = ctx.getInputType(i);

    // Enforce type constraint for inputs
    if (!input_type || !input_type->has_tensor_type()) {
      fail_type_inference("Scan input ", i, " was not a tensor.");
    }

    if (is_loop_state_var) {
      // If it's a loop state variable we can propagate type and shape 1:1 to
      // the matching Scan output.
      // The type (including any shape) is passed to the subgraph unchanged.
      propagateElemTypeFromInputToOutput(ctx, i, i);
      if (has_shape)
        propagateShapeFromInputToOutput(ctx, i, i);

      subgraph_input_types.push_back(input_type);
    } else {
      // For other inputs there is no fixed relationships to the Scan outputs,
      // so we don't propagate type/shape information.
      // We can pass through the type and shape to the subgraph inputs but
      // need to remove the sequence length dimensions from the shape.
      if (has_shape) {
        const auto& shape = input_type->tensor_type().shape();

        // remove sequence length dimensions and add to subgraph_input_types
        int axis = static_cast<int>(axes[i - num_loop_state_vars]);
        axis = handle_negative_axis_validate("scan_input_axes", axis, shape.dim_size());

        // update sequence_len if a value is available

        const auto& dims = shape.dim();
        mergeInDimensionInfo(dims.Get(axis), sequence_len_dim, 1);

        temporary_type_protos.push_back(RemoveIthDimensionFromShape(*input_type, axis));
        subgraph_input_types.push_back(&temporary_type_protos.back());

      } else {
        subgraph_input_types.push_back(input_type);
      }
    }
  }

  // Run inferencing on the subgraph
  std::vector<const TypeProto*> output_types;

  GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
  if (graphInferencer) {
    std::vector<const TensorProto*> input_data;
    input_data.reserve(num_inputs);
    for (size_t i = 0; i < num_inputs; ++i) {
      // ctx.getInputData(i), the input to scan, does not represent the input to
      // scan body. So, we pass in null, to represent an unknown value.
      input_data.push_back(nullptr);
    }

    output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
  }

  // if empty(), assume inferencing was skipped
  if (!output_types.empty()) {
    if (output_types.size() != num_outputs) {
      fail_type_inference(
          "Graph attribute inferencing returned type information for ",
          output_types.size(),
          " outputs. Expected ",
          num_outputs);
    }

    // propagate type/shape information for loop state variables and outputs
    for (size_t i = 0; i < num_outputs; ++i) {
      const bool is_loop_state_var = i < num_loop_state_vars;
      auto* subgraph_output_type = output_types[i];
      auto* scan_output_type = ctx.getOutputType(i);
      auto* mutable_scan_output_tensor_type = scan_output_type->mutable_tensor_type();

      if (!subgraph_output_type->has_tensor_type()) {
        fail_type_inference("Scan 'body' subgraph outputs should all be tensors but output ", i, " was not");
      }
      auto& subgraph_output_tensor_type = subgraph_output_type->tensor_type();

      if (is_loop_state_var) {
        // merge shape; type already propagated
        mergeInShapeInfo(subgraph_output_tensor_type, *mutable_scan_output_tensor_type);
      } else {
        scan_output_type->mutable_tensor_type()->set_elem_type(subgraph_output_tensor_type.elem_type());

        // propagate shape
        if (subgraph_output_tensor_type.has_shape()) {
          // infer shape of scan-output from the shape of scan-output-element
          // by adding sequence-length at the correct axis position
          const TensorShapeProto& subgraph_output_shape = subgraph_output_tensor_type.shape();
          TensorShapeProto inferred_shape;

          auto subgraph_output_rank = subgraph_output_shape.dim_size();
          auto output_rank = subgraph_output_rank + 1;
          int output_axis = static_cast<int>(output_axes[i - num_loop_state_vars]);
          output_axis = handle_negative_axis_validate("scan_output_axes", output_axis, output_rank);

          // dimensions before the axis, then the sequence length, then the rest
          for (int j = 0; j < output_axis; ++j)
            *(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);
          *(inferred_shape.add_dim()) = sequence_len_dim;
          for (int j = output_axis; j < subgraph_output_rank; ++j)
            *(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);

          // Merge inferred shape with existing shape information
          mergeInShapeInfo(inferred_shape, *mutable_scan_output_tensor_type);
        }
      }
    }
  }
}
|
|
|
|
} // namespace ONNX_NAMESPACE
|