1116 lines
42 KiB
C++
1116 lines
42 KiB
C++
// Copyright (c) ONNX Project Contributors
|
|
//
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
#include "onnx/shape_inference/implementation.h"
|
|
|
|
#include <algorithm>
|
|
#include <fstream>
|
|
#include <list>
|
|
#include <unordered_map>
|
|
#include <unordered_set>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "onnx/checker.h"
|
|
#include "onnx/common/common.h"
|
|
#include "onnx/common/file_utils.h"
|
|
#include "onnx/defs/data_type_utils.h"
|
|
#include "onnx/proto_utils.h"
|
|
#include "onnx/shape_inference/attribute_binder.h"
|
|
#include "onnx/string_utils.h"
|
|
|
|
namespace ONNX_NAMESPACE {
|
|
namespace shape_inference {
|
|
namespace {
|
|
|
|
std::string GetValueCaseString(const TypeProto& type) {
|
|
switch (type.value_case()) {
|
|
case TypeProto::ValueCase::kTensorType:
|
|
return "tensor_type";
|
|
case TypeProto::ValueCase::kSequenceType:
|
|
return "sequence_type";
|
|
case TypeProto::ValueCase::kMapType:
|
|
return "map_type";
|
|
case TypeProto::ValueCase::kOptionalType:
|
|
return "optional_type";
|
|
#ifdef ONNX_ML
|
|
case TypeProto::ValueCase::kOpaqueType:
|
|
return "opaque_type";
|
|
#endif
|
|
case TypeProto::ValueCase::kSparseTensorType:
|
|
return "sparse_tensor_type";
|
|
case TypeProto::ValueCase::VALUE_NOT_SET:
|
|
return "NOT_SET";
|
|
}
|
|
return ONNX_NAMESPACE::to_string(type.value_case());
|
|
}
|
|
|
|
std::string GetElemTypeString(const TypeProto_Tensor& type) {
|
|
#ifndef ONNX_USE_LITE_PROTO
|
|
std::string type_str = TensorProto::DataType_Name(static_cast<TensorProto_DataType>(type.elem_type()));
|
|
if (!type_str.empty()) {
|
|
return type_str;
|
|
}
|
|
#endif
|
|
return ONNX_NAMESPACE::to_string(type.elem_type());
|
|
}
|
|
|
|
std::string GetElemTypeString(const TypeProto_SparseTensor& type) {
|
|
#ifndef ONNX_USE_LITE_PROTO
|
|
std::string type_str = TensorProto::DataType_Name(static_cast<TensorProto_DataType>(type.elem_type()));
|
|
if (!type_str.empty()) {
|
|
return type_str;
|
|
}
|
|
#endif
|
|
return ONNX_NAMESPACE::to_string(type.elem_type());
|
|
}
|
|
|
|
inline bool IsOnnxDomainOp(const NodeProto& node, const std::string& op_type) {
|
|
return (IsOnnxDomain(node.domain()) && (node.op_type() == op_type));
|
|
}
|
|
|
|
} // namespace
|
|
|
|
template <class T>
|
|
void CheckTensorShapesAndTypes(const T& inferred_type, const T& existing_type) {
|
|
if (inferred_type.elem_type() != TensorProto::UNDEFINED && existing_type.elem_type() != TensorProto::UNDEFINED &&
|
|
existing_type.elem_type() != inferred_type.elem_type()) {
|
|
std::stringstream ss;
|
|
ss << "Inferred elem type differs from existing elem type: (" << GetElemTypeString(inferred_type) << ") vs ("
|
|
<< GetElemTypeString(existing_type) << ")";
|
|
fail_type_inference(ss.str());
|
|
}
|
|
|
|
if (!inferred_type.has_shape() || !existing_type.has_shape()) {
|
|
return;
|
|
}
|
|
|
|
if (inferred_type.shape().dim_size() != existing_type.shape().dim_size()) {
|
|
std::stringstream ss;
|
|
ss << "Inferred shape and existing shape differ in rank: (" << inferred_type.shape().dim_size() << ") vs ("
|
|
<< existing_type.shape().dim_size() << ")";
|
|
fail_shape_inference(ss.str());
|
|
}
|
|
|
|
for (int i = 0; i < inferred_type.shape().dim_size(); ++i) {
|
|
const auto& inferred_dim = inferred_type.shape().dim(i);
|
|
const auto& existing_dim = existing_type.shape().dim(i);
|
|
if (inferred_dim.has_dim_value() && existing_dim.has_dim_value() &&
|
|
inferred_dim.dim_value() != existing_dim.dim_value()) {
|
|
std::stringstream ss;
|
|
ss << "Inferred shape and existing shape differ in dimension " << i << ": (" << inferred_dim.dim_value()
|
|
<< ") vs (" << existing_dim.dim_value() << ")";
|
|
fail_shape_inference(ss.str());
|
|
}
|
|
}
|
|
}
|
|
|
|
void checkShapesAndTypes(const TypeProto& inferred_type, const TypeProto& existing_type) {
|
|
const auto inferred_value_case = inferred_type.value_case();
|
|
const auto existing_value_case = existing_type.value_case();
|
|
if (inferred_value_case == TypeProto::ValueCase::VALUE_NOT_SET ||
|
|
existing_value_case == TypeProto::ValueCase::VALUE_NOT_SET) {
|
|
// nothing to check; will assign inferredType to undefined existingType
|
|
return;
|
|
}
|
|
if (inferred_value_case != existing_value_case) {
|
|
fail_type_inference(
|
|
"type case mismatch. existing=",
|
|
GetValueCaseString(existing_type),
|
|
" inferred=",
|
|
GetValueCaseString(inferred_type));
|
|
}
|
|
|
|
if (inferred_value_case == TypeProto::kTensorType && existing_value_case == TypeProto::kTensorType) {
|
|
CheckTensorShapesAndTypes(inferred_type.tensor_type(), existing_type.tensor_type());
|
|
} else if (
|
|
inferred_value_case == TypeProto::kSparseTensorType && existing_value_case == TypeProto::kSparseTensorType) {
|
|
CheckTensorShapesAndTypes(inferred_type.sparse_tensor_type(), existing_type.sparse_tensor_type());
|
|
} else if (inferred_value_case == TypeProto::kSequenceType && existing_value_case == TypeProto::kSequenceType) {
|
|
checkShapesAndTypes(inferred_type.sequence_type().elem_type(), existing_type.sequence_type().elem_type());
|
|
} else if (inferred_value_case == TypeProto::kOptionalType && existing_value_case == TypeProto::kOptionalType) {
|
|
checkShapesAndTypes(inferred_type.optional_type().elem_type(), existing_type.optional_type().elem_type());
|
|
} else if (
|
|
inferred_value_case == TypeProto::TypeProto::kMapType && existing_value_case == TypeProto::TypeProto::kMapType) {
|
|
if (inferred_type.map_type().key_type() != existing_type.map_type().key_type()) {
|
|
fail_type_inference(
|
|
"key type mismatch from MapProto. existing=",
|
|
Utils::DataTypeUtils::ToDataTypeString(existing_type.map_type().key_type()),
|
|
" inferred=",
|
|
Utils::DataTypeUtils::ToDataTypeString(inferred_type.map_type().key_type()));
|
|
}
|
|
checkShapesAndTypes(inferred_type.map_type().value_type(), existing_type.map_type().value_type());
|
|
} else {
|
|
fail_type_inference("type case unsupported. existing=", existing_value_case, " inferred=", inferred_value_case);
|
|
}
|
|
}
|
|
|
|
void mergeShapesAndTypes(const TypeProto_Tensor& inferred_type, TypeProto_Tensor* existing_type) {
|
|
if (existing_type->elem_type() == TensorProto::UNDEFINED) {
|
|
existing_type->set_elem_type(inferred_type.elem_type());
|
|
}
|
|
|
|
if (!inferred_type.has_shape()) {
|
|
return;
|
|
}
|
|
|
|
if (!existing_type->has_shape()) {
|
|
*existing_type->mutable_shape() = inferred_type.shape();
|
|
return;
|
|
}
|
|
|
|
for (int i = 0; i < inferred_type.shape().dim_size(); ++i) {
|
|
const auto& inferred_dim = inferred_type.shape().dim(i);
|
|
auto* existing_dim = existing_type->mutable_shape()->mutable_dim(i);
|
|
if ((!existing_dim->has_dim_value() && !existing_dim->has_dim_param()) || inferred_dim.has_dim_value()) {
|
|
*existing_dim = inferred_dim;
|
|
}
|
|
}
|
|
}
|
|
|
|
void mergeShapesAndTypes(const TypeProto_SparseTensor& inferred_type, TypeProto_SparseTensor* existing_type) {
|
|
if (existing_type->elem_type() == TensorProto::UNDEFINED) {
|
|
existing_type->set_elem_type(inferred_type.elem_type());
|
|
}
|
|
|
|
if (!inferred_type.has_shape()) {
|
|
return;
|
|
}
|
|
|
|
if (!existing_type->has_shape()) {
|
|
*existing_type->mutable_shape() = inferred_type.shape();
|
|
return;
|
|
}
|
|
|
|
for (int i = 0; i < inferred_type.shape().dim_size(); ++i) {
|
|
const auto& inferred_dim = inferred_type.shape().dim(i);
|
|
auto* existing_dim = existing_type->mutable_shape()->mutable_dim(i);
|
|
if ((!existing_dim->has_dim_value() && !existing_dim->has_dim_param()) || inferred_dim.has_dim_value()) {
|
|
*existing_dim = inferred_dim;
|
|
}
|
|
}
|
|
}
|
|
|
|
void mergeShapesAndTypes(const TypeProto& inferred_type, TypeProto* existing_type) {
|
|
// Check before merge
|
|
checkShapesAndTypes(inferred_type, *existing_type);
|
|
const auto inferred_val_case = inferred_type.value_case();
|
|
if (inferred_val_case == TypeProto::kTensorType) {
|
|
mergeShapesAndTypes(inferred_type.tensor_type(), existing_type->mutable_tensor_type());
|
|
} else if (inferred_val_case == TypeProto::kSparseTensorType) {
|
|
mergeShapesAndTypes(inferred_type.sparse_tensor_type(), existing_type->mutable_sparse_tensor_type());
|
|
} else if (inferred_val_case == TypeProto::kSequenceType) {
|
|
mergeShapesAndTypes(
|
|
inferred_type.sequence_type().elem_type(), existing_type->mutable_sequence_type()->mutable_elem_type());
|
|
} else if (inferred_val_case == TypeProto::kOptionalType) {
|
|
mergeShapesAndTypes(
|
|
inferred_type.optional_type().elem_type(), existing_type->mutable_optional_type()->mutable_elem_type());
|
|
} else if (inferred_val_case == TypeProto::kMapType) {
|
|
if (existing_type->map_type().key_type() == TensorProto::UNDEFINED) {
|
|
existing_type->mutable_map_type()->set_key_type(inferred_type.map_type().key_type());
|
|
}
|
|
mergeShapesAndTypes(inferred_type.map_type().value_type(), existing_type->mutable_map_type()->mutable_value_type());
|
|
}
|
|
}
|
|
|
|
// TypeProto_Tensor or TypeProto_SparseTensor
|
|
template <typename TensorTypeProto>
|
|
void GenerateSymbolicShape(TensorTypeProto* inferred_type, SymbolTable& symbol_table) {
|
|
if (!inferred_type->has_shape()) {
|
|
return;
|
|
}
|
|
for (int i = 0; i < inferred_type->shape().dim_size(); ++i) {
|
|
// set a symbol if it doesn't have dim_value and dim_param
|
|
auto* dim = inferred_type->mutable_shape()->mutable_dim(i);
|
|
if (!dim->has_dim_value() && !dim->has_dim_param()) {
|
|
dim->set_dim_param(symbol_table.createNew());
|
|
}
|
|
}
|
|
}
|
|
|
|
// Walks a TypeProto and gives every unknown tensor dimension a new symbolic name,
// recursing through sequence, optional and map (value) element types. A type with no
// value case set is ignored; any other unhandled value case fails shape inference.
void MaterializeSymbolicShape(TypeProto* inferred_type, SymbolTable& symbol_table) {
  const auto active_case = inferred_type->value_case();
  switch (active_case) {
    case TypeProto::ValueCase::VALUE_NOT_SET:
      return;
    case TypeProto::kTensorType:
      GenerateSymbolicShape(inferred_type->mutable_tensor_type(), symbol_table);
      break;
    case TypeProto::kSparseTensorType:
      GenerateSymbolicShape(inferred_type->mutable_sparse_tensor_type(), symbol_table);
      break;
    case TypeProto::kSequenceType:
      MaterializeSymbolicShape(inferred_type->mutable_sequence_type()->mutable_elem_type(), symbol_table);
      break;
    case TypeProto::kOptionalType:
      MaterializeSymbolicShape(inferred_type->mutable_optional_type()->mutable_elem_type(), symbol_table);
      break;
    case TypeProto::kMapType:
      MaterializeSymbolicShape(inferred_type->mutable_map_type()->mutable_value_type(), symbol_table);
      break;
    default:
      fail_shape_inference("type case unsupported for symbolic shape inference. inferred=", active_case);
  }
}
|
|
|
|
std::string GetFunctionIdentifier(const FunctionProto& function) {
|
|
// Note: Models with IR version < 10 do not have the overload attribute.
|
|
// However, that will be mapped to an empty identifier.
|
|
std::string overload = function.overload();
|
|
if (overload.empty()) {
|
|
return function.domain() + ":" + function.name();
|
|
}
|
|
return function.domain() + ":" + function.name() + ":" + overload;
|
|
}
|
|
|
|
std::string GetFunctionIdentifier(const NodeProto& node) {
|
|
// Note: Models with IR version < 10 do not have the overload attribute.
|
|
// However, that will be mapped to an empty identifier.
|
|
std::string overload = node.overload();
|
|
if (overload.empty()) {
|
|
return node.domain() + ":" + node.op_type();
|
|
}
|
|
return node.domain() + ":" + node.op_type() + ":" + overload;
|
|
}
|
|
|
|
// InferredTypes: abstracts the differences between FunctionProto and GraphProto
|
|
// for inference. For GraphProto, inferred types are stored in the GraphProto
|
|
// but FunctionProto does not have a place to store inferred types. So, we
|
|
// use a temporary vector (for the duration of inference) to store these.
|
|
class InferredTypes {
|
|
public:
|
|
explicit InferredTypes(GraphProto* graph = nullptr) : graph_ptr(graph) {}
|
|
|
|
TypeProto* Add(const std::string& var_name, const TypeProto& type) {
|
|
if (graph_ptr != nullptr) {
|
|
auto* p = graph_ptr->add_value_info();
|
|
p->set_name(var_name);
|
|
*p->mutable_type() = type;
|
|
return p->mutable_type();
|
|
} else {
|
|
auto* p = new TypeProto(type);
|
|
types.push_back(p);
|
|
return p;
|
|
}
|
|
}
|
|
|
|
~InferredTypes() {
|
|
for (auto* p : types) {
|
|
delete p;
|
|
}
|
|
}
|
|
|
|
private:
|
|
std::vector<TypeProto*> types;
|
|
GraphProto* graph_ptr;
|
|
ONNX_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferredTypes);
|
|
};
|
|
|
|
// Initialize a DataValueMap for a called function from the DataValueMap of the caller
|
|
void BindValuesOnCall(
|
|
const DataValueMap& caller_map,
|
|
const NodeProto& caller,
|
|
DataValueMap& callee_map,
|
|
const FunctionProto& callee) {
|
|
auto num_inputs = (std::min)(caller.input_size(), callee.input_size());
|
|
for (int i = 0; i < num_inputs; ++i) {
|
|
const std::string& actual = caller.input(i);
|
|
const std::string& formal = callee.input(i);
|
|
if (!actual.empty()) {
|
|
auto it = caller_map.find(actual);
|
|
if (it != caller_map.end()) {
|
|
callee_map[formal] = it->second;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Update a DataValueMap for a calling function from the DataValueMap of the callee
|
|
void BindValuesOnReturn(
|
|
const DataValueMap& callee_map,
|
|
const FunctionProto& callee,
|
|
DataValueMap& caller_map,
|
|
const NodeProto& caller) {
|
|
auto num_outputs = (std::min)(caller.output_size(), callee.output_size());
|
|
for (int i = 0; i < num_outputs; ++i) {
|
|
const std::string& actual = caller.output(i);
|
|
const std::string& formal = callee.output(i);
|
|
if (!actual.empty()) {
|
|
auto it = callee_map.find(formal);
|
|
if (it != callee_map.end()) {
|
|
caller_map[actual] = it->second;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
class ShapeInferenceImplBase {
|
|
public:
|
|
void UpdateType(const std::string& name, TypeProto* inferred_type) {
|
|
if (inferred_type->value_case() == TypeProto::ValueCase::VALUE_NOT_SET) {
|
|
return;
|
|
}
|
|
|
|
if (symbol_table) {
|
|
MaterializeSymbolicShape(inferred_type, *symbol_table);
|
|
}
|
|
|
|
// Find any pre-existing type and shape info. If there is such,
|
|
// then check for compatibility with the inferred
|
|
// information. Otherwise, initialize it in an empty state.
|
|
auto iter = value_types_by_name.find(name);
|
|
if (iter != value_types_by_name.end()) {
|
|
mergeShapesAndTypes(*inferred_type, iter->second);
|
|
} else {
|
|
value_types_by_name[name] = inferred_types.Add(name, *inferred_type);
|
|
// For undefined output type, update both value_info and output for now
|
|
// Update existing output with undefined type: assign inferred type to it
|
|
iter = undefined_value_types_by_name.find(name);
|
|
if (iter != undefined_value_types_by_name.end()) {
|
|
*iter->second = *inferred_type;
|
|
}
|
|
}
|
|
}
|
|
|
|
void UpdateType(ValueInfoProto& valueInfo) {
|
|
if (valueInfo.has_type()) {
|
|
value_types_by_name[valueInfo.name()] = valueInfo.mutable_type();
|
|
} else {
|
|
undefined_value_types_by_name[valueInfo.name()] = valueInfo.mutable_type();
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
void AddTemporaryConstant(const std::string& name, const T& vector) {
|
|
input_data_by_name_holder[name] = ToTensor(vector);
|
|
input_data_by_name[name] = &input_data_by_name_holder[name];
|
|
}
|
|
|
|
void ProcessConstant(const NodeProto& n) {
|
|
if (IsOnnxDomainOp(n, "Constant") && n.output().size() == 1) {
|
|
const std::string& output_name = n.output(0);
|
|
for (const auto& attr : n.attribute()) {
|
|
if (attr.name() == "value") {
|
|
if (attr.type() == AttributeProto::TENSOR && attr.has_t()) {
|
|
if (reuse_constant_tensors) {
|
|
input_data_by_name[output_name] = &attr.t();
|
|
} else {
|
|
input_data_by_name_holder[output_name] = attr.t();
|
|
input_data_by_name[output_name] = &input_data_by_name_holder[output_name];
|
|
}
|
|
} else if (attr.type() == AttributeProto::SPARSE_TENSOR && attr.has_sparse_tensor()) {
|
|
if (reuse_constant_tensors) {
|
|
input_sparse_data_by_name[output_name] = &attr.sparse_tensor();
|
|
}
|
|
}
|
|
} else {
|
|
switch (attr.type()) {
|
|
case AttributeProto::INTS: {
|
|
std::vector<int64_t> ints{attr.ints().begin(), attr.ints().end()};
|
|
AddTemporaryConstant(output_name, ints);
|
|
break;
|
|
}
|
|
case AttributeProto::INT: {
|
|
std::vector<int64_t> ints({attr.i()});
|
|
AddTemporaryConstant(output_name, ints);
|
|
break;
|
|
}
|
|
case AttributeProto::FLOATS: {
|
|
std::vector<float> floats{attr.floats().begin(), attr.floats().end()};
|
|
AddTemporaryConstant(output_name, floats);
|
|
break;
|
|
}
|
|
case AttributeProto::FLOAT: {
|
|
std::vector<float> floats({attr.f()});
|
|
AddTemporaryConstant(output_name, floats);
|
|
break;
|
|
}
|
|
default:
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void ProcessCall(const NodeProto& caller, const FunctionProto& callee, InferenceContext& ctx) {
|
|
DataValueMap callee_value_map;
|
|
if (generated_shape_data_by_name != nullptr) {
|
|
BindValuesOnCall(*generated_shape_data_by_name, caller, callee_value_map, callee);
|
|
}
|
|
InferShapeForFunctionNode(
|
|
callee, schema_registry, ctx, options, model_local_functions_map, symbol_table, &callee_value_map);
|
|
if (generated_shape_data_by_name != nullptr) {
|
|
BindValuesOnReturn(callee_value_map, callee, *generated_shape_data_by_name, caller);
|
|
}
|
|
}
|
|
|
|
void Process(NodeProto& n) {
|
|
// Resolve domain for node
|
|
auto dit = opset_imports.find(n.domain());
|
|
if (dit == opset_imports.end()) {
|
|
// Both "" (ONNX_DOMAIN) and "ai.onnx" (AI_ONNX_DOMAIN) refer to the default ONNX domain
|
|
if (n.domain() == ONNX_DOMAIN) {
|
|
dit = opset_imports.find(AI_ONNX_DOMAIN);
|
|
}
|
|
if (dit == opset_imports.end()) {
|
|
fail_type_inference(
|
|
"Cannot infer type and shape for node name ",
|
|
n.name(),
|
|
". No opset import for domain ",
|
|
n.domain(),
|
|
" optype ",
|
|
n.op_type());
|
|
}
|
|
}
|
|
auto domain_version = dit->second;
|
|
const auto schema = schema_registry->GetSchema(n.op_type(), domain_version, n.domain());
|
|
InferenceContextImpl ctx(
|
|
n,
|
|
value_types_by_name,
|
|
input_data_by_name,
|
|
input_sparse_data_by_name,
|
|
options,
|
|
generated_shape_data_by_name,
|
|
&graph_inference_context);
|
|
|
|
ONNX_TRY {
|
|
if (schema) {
|
|
if (schema->has_type_and_shape_inference_function()) {
|
|
schema->GetTypeAndShapeInferenceFunction()(ctx);
|
|
} else if (schema->HasFunction()) {
|
|
ProcessCall(n, *(schema->GetFunction()), ctx);
|
|
} // else: rely on schema->CheckInputOutputType() down below.
|
|
// check type-constraints specified via type variables
|
|
if (options.check_type) {
|
|
schema->CheckInputOutputType(ctx);
|
|
}
|
|
} else if (model_local_functions_map.size() > 0) {
|
|
auto iter = model_local_functions_map.find(GetFunctionIdentifier(n));
|
|
if (iter != model_local_functions_map.end()) {
|
|
ProcessCall(n, *(iter->second), ctx);
|
|
} else {
|
|
has_unsupported_op = true;
|
|
return;
|
|
}
|
|
} else {
|
|
has_unsupported_op = true;
|
|
return;
|
|
}
|
|
for (int i = 0; i < n.output_size(); ++i) {
|
|
// skip type and shape propagation for missing optional outputs.
|
|
if (!n.output(i).empty())
|
|
UpdateType(n.output(i), ctx.getOutputType(i));
|
|
}
|
|
// Constant values are tracked to improve inference/checking for subsequent nodes.
|
|
ProcessConstant(n);
|
|
// If data-propagation is enabled, partial-evaluation (aka data-propagation) is performed
|
|
// to improve inference/checking for subsequent nodes.
|
|
if (options.enable_data_propagation && schema && schema->has_data_propagation_function()) {
|
|
if (generated_shape_data_by_name == nullptr) {
|
|
fail_shape_inference(
|
|
"Container for generated shape data cannot be nullptr when enable_data_propagation option is set.");
|
|
}
|
|
DataPropagationContextImpl data_propagation_ctx(
|
|
n, value_types_by_name, input_data_by_name, *generated_shape_data_by_name);
|
|
schema->GetDataPropagationFunction()(data_propagation_ctx);
|
|
}
|
|
}
|
|
ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& ex) {
|
|
ONNX_HANDLE_EXCEPTION([&]() {
|
|
// Note: The following special handling is to accommodate custom-ops. Ideally, custom-ops
|
|
// should be registered with a schema in the schema registry, allowing inference to handle
|
|
// them. As things stand, this special handling is somewhat fragile and is not fully
|
|
// general either. Eg., a custom-op suppresses error-messages for subsequent nodes, but
|
|
// this does not work across graphs. If special handling is required, a user-option may
|
|
// be a better way to do it. The fragility comes from the fact that the types of the
|
|
// returned-values of the custom-op are unknown, and subsequent node-level inference
|
|
// may fail because of this.
|
|
if (!has_unsupported_op) {
|
|
inference_errors.push_back(GetErrorWithNodeInfo(n, ex));
|
|
}
|
|
});
|
|
}
|
|
ONNX_CATCH(const std::runtime_error& err) {
|
|
// TODO: Fix this. Unclear if this should be remapped to a shape inference error.
|
|
// Need to rationalize the different types of exceptions that can be thrown.
|
|
// See: https://github.com/onnx/onnx/pull/5519
|
|
ONNX_HANDLE_EXCEPTION([&]() { fail_shape_inference(GetErrorWithNodeInfo(n, err)); });
|
|
}
|
|
}
|
|
|
|
// TypeProto_Tensor or TypeProto_SparseTensor
|
|
template <typename T>
|
|
void ProcessInitializer(
|
|
const std::string& name,
|
|
const T& tensorValue,
|
|
TypeProto& initializer_type,
|
|
std::unordered_map<std::string, const T*>& map) {
|
|
map[name] = &tensorValue;
|
|
auto iter = value_types_by_name.find(name);
|
|
// If it already exists in input, check input and initializer is sync
|
|
// use shape info from input (input has priority over initializer)
|
|
if (iter != value_types_by_name.end()) {
|
|
checkShapesAndTypes(initializer_type, *iter->second);
|
|
// CheckTensorShapesAndTypes(*initializer_tensor_type, *iter->second->mutable_tensor_type());
|
|
}
|
|
// Support IR>=4: some tensors can only exist in initializer and not in input
|
|
// So shape_inference should make use of initializer shapes
|
|
// Store initializer shape info in value_info as well
|
|
else if (ir_version >= 4) {
|
|
initializer_type_list.push_back(std::move(initializer_type));
|
|
value_types_by_name[name] = &initializer_type_list.back();
|
|
}
|
|
}
|
|
|
|
void Process(GraphProto& graph) {
|
|
if (symbol_table) {
|
|
TraverseGraphsToAddExistingSymbols(graph, *symbol_table);
|
|
}
|
|
for (auto& vi : *graph.mutable_value_info()) {
|
|
UpdateType(vi);
|
|
}
|
|
for (auto& vi : *graph.mutable_input()) {
|
|
UpdateType(vi);
|
|
}
|
|
for (auto& vi : *graph.mutable_output()) {
|
|
UpdateType(vi);
|
|
}
|
|
for (const auto& tp : graph.initializer()) {
|
|
TypeProto initializer_type;
|
|
TypeProto_Tensor* initializer_tensor_type = initializer_type.mutable_tensor_type();
|
|
initializer_tensor_type->set_elem_type(tp.data_type());
|
|
// set the shape according to the initializer shape info
|
|
auto* shape = initializer_tensor_type->mutable_shape();
|
|
for (int i = 0; i < tp.dims_size(); ++i) {
|
|
shape->add_dim()->set_dim_value(tp.dims(i));
|
|
}
|
|
ProcessInitializer(tp.name(), tp, initializer_type, input_data_by_name);
|
|
}
|
|
for (const auto& tp : graph.sparse_initializer()) {
|
|
TypeProto initializer_type;
|
|
auto* initializer_sparse_tensor_type = initializer_type.mutable_sparse_tensor_type();
|
|
initializer_sparse_tensor_type->set_elem_type(tp.values().data_type());
|
|
// set the shape according to the initializer shape info
|
|
auto* shape = initializer_sparse_tensor_type->mutable_shape();
|
|
for (int i = 0; i < tp.dims_size(); ++i) {
|
|
shape->add_dim()->set_dim_value(tp.dims(i));
|
|
}
|
|
ProcessInitializer(tp.values().name(), tp, initializer_type, input_sparse_data_by_name);
|
|
}
|
|
for (auto& n : *graph.mutable_node()) {
|
|
Process(n);
|
|
}
|
|
}
|
|
|
|
void Process(const NodeProto& n, internal::AttributeBinder& attribute_binder) {
|
|
NodeProto copy_n(n);
|
|
attribute_binder.VisitNode(©_n);
|
|
Process(copy_n);
|
|
}
|
|
|
|
void Process(const FunctionProto& func_proto, InferenceContext& ctx) {
|
|
// Ensure Constant node tensor-attributes are copied
|
|
bool old_reuse_constant_tensors = reuse_constant_tensors;
|
|
reuse_constant_tensors = false;
|
|
|
|
// Get a temporary tensor-shape map
|
|
const int num_actual_inputs = static_cast<int>(ctx.getNumInputs());
|
|
const auto num_func_inputs = func_proto.input_size();
|
|
std::vector<TypeProto> types_cache(num_func_inputs);
|
|
for (int i = 0; i < num_func_inputs; ++i) {
|
|
auto& parameter_name = func_proto.input().Get(i);
|
|
auto* type_ptr = (i < num_actual_inputs) ? ctx.getInputType(i) : nullptr;
|
|
// nullptr is valid, and indicates a missing optional input
|
|
if (type_ptr != nullptr) {
|
|
// Use a temporary copy of original type.
|
|
// TODO: investigate whether we can eliminate use of temporary copy
|
|
types_cache[i] = *type_ptr;
|
|
value_types_by_name[parameter_name] = &types_cache[i];
|
|
} else
|
|
value_types_by_name[parameter_name] = nullptr;
|
|
}
|
|
|
|
// Create a temporary initializer value map
|
|
for (int i = 0; i < num_actual_inputs && i < num_func_inputs; ++i) {
|
|
const TypeProto* type = ctx.getInputType(i);
|
|
if (type != nullptr) {
|
|
if (type->value_case() == TypeProto::kTensorType && ctx.getInputData(i) != nullptr) {
|
|
input_data_by_name[func_proto.input().Get(i)] = ctx.getInputData(i);
|
|
} else if (type->value_case() == TypeProto::kSparseTensorType && ctx.getInputSparseData(i) != nullptr) {
|
|
input_sparse_data_by_name[func_proto.input().Get(i)] = ctx.getInputSparseData(i);
|
|
}
|
|
}
|
|
}
|
|
|
|
std::unordered_map<std::string, const AttributeProto*> attr_map;
|
|
for (auto& attr : func_proto.attribute()) {
|
|
if (ctx.getAttribute(attr) != nullptr) {
|
|
attr_map[attr] = ctx.getAttribute(attr);
|
|
}
|
|
}
|
|
|
|
for (auto& default_value : func_proto.attribute_proto()) {
|
|
const std::string& name = default_value.name();
|
|
const AttributeProto* value = ctx.getAttribute(name);
|
|
attr_map[name] = (value != nullptr) ? value : &default_value;
|
|
}
|
|
|
|
internal::AttributeBinder attribute_binder(attr_map);
|
|
for (auto& n : func_proto.node()) {
|
|
Process(n, attribute_binder);
|
|
}
|
|
|
|
for (int i = 0; i < func_proto.output_size(); ++i) {
|
|
const std::string& output_name = func_proto.output().Get(i);
|
|
// Skip if no type inferred for the tensor
|
|
auto iter = value_types_by_name.find(output_name);
|
|
if (iter != value_types_by_name.cend()) {
|
|
// Copy the type info to ctx
|
|
// to pass back to main graph
|
|
auto type_proto = ctx.getOutputType(i);
|
|
type_proto->CopyFrom(*(iter->second));
|
|
}
|
|
}
|
|
|
|
reuse_constant_tensors = old_reuse_constant_tensors;
|
|
}
|
|
|
|
public:
|
|
ShapeInferenceImplBase(
|
|
GraphProto* graph, // nullptr for FunctionProto inference
|
|
const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name_in,
|
|
const std::unordered_map<std::string, int>& opset_imports_in,
|
|
const ShapeInferenceOptions& options_in,
|
|
SymbolTable* symbol_table_in,
|
|
const ModelLocalFunctionsMap& model_local_functions_map_in,
|
|
const ISchemaRegistry* schema_registry_in = OpSchemaRegistry::Instance(),
|
|
DataValueMap* generated_shape_data_by_name_in = nullptr,
|
|
const int ir_version_in = IR_VERSION // default the latest one
|
|
)
|
|
: inferred_types(graph),
|
|
value_types_by_name(outer_scope_value_types_by_name_in),
|
|
opset_imports(opset_imports_in),
|
|
options(options_in),
|
|
symbol_table(symbol_table_in),
|
|
model_local_functions_map(model_local_functions_map_in),
|
|
schema_registry(schema_registry_in),
|
|
generated_shape_data_by_name(generated_shape_data_by_name_in),
|
|
ir_version(ir_version_in),
|
|
graph_inference_context{
|
|
value_types_by_name,
|
|
opset_imports,
|
|
symbol_table,
|
|
model_local_functions_map,
|
|
schema_registry,
|
|
generated_shape_data_by_name,
|
|
ir_version} {
|
|
if (options.enable_data_propagation && generated_shape_data_by_name == nullptr) {
|
|
fail_shape_inference(
|
|
"Container for generated shape data cannot be nullptr when enable_data_propagation option is set.");
|
|
}
|
|
}
|
|
|
|
void FinalizeShapeInference() {
|
|
auto& errors = getErrors();
|
|
// Throw shape inference error if any. Error mode right now only supports 0 and 1.
|
|
// When set to 0, any node level shape inference errors are not thrown. This is to support backward compatiblity
|
|
// with 1.7 and earlier releases. When set to 1 it will throw all exceptions.
|
|
// TODO: Add a more granular way for exception handling.
|
|
if (!errors.empty() && options.error_mode > 0) {
|
|
std::string full_errors = "Inference error(s): ";
|
|
for (const std::string& error : inference_errors) {
|
|
full_errors += error + "\n";
|
|
}
|
|
fail_shape_inference(full_errors);
|
|
}
|
|
}
|
|
|
|
const std::vector<std::string>& getErrors() const {
|
|
return inference_errors;
|
|
}
|
|
|
|
private:
|
|
InferredTypes inferred_types;
|
|
std::unordered_map<std::string, TypeProto*> value_types_by_name;
|
|
const std::unordered_map<std::string, int>& opset_imports;
|
|
|
|
const ShapeInferenceOptions& options;
|
|
SymbolTable* symbol_table;
|
|
const ModelLocalFunctionsMap& model_local_functions_map;
|
|
const ISchemaRegistry* schema_registry;
|
|
DataValueMap* generated_shape_data_by_name;
|
|
int ir_version;
|
|
GraphInferenceContext graph_inference_context;
|
|
|
|
std::unordered_map<std::string, TypeProto*> undefined_value_types_by_name;
|
|
std::unordered_map<std::string, const TensorProto*> input_data_by_name;
|
|
std::unordered_map<std::string, TensorProto> input_data_by_name_holder;
|
|
std::unordered_map<std::string, const SparseTensorProto*> input_sparse_data_by_name;
|
|
|
|
bool has_unsupported_op = false;
|
|
|
|
std::vector<std::string> inference_errors;
|
|
|
|
std::list<TypeProto> initializer_type_list;
|
|
|
|
// reuse_constant_tensors: controls whether we need to copy tensors occurring as attributes
|
|
// in Constant nodes. We avoid it for inference for graphs, but must make a copy for functions.
|
|
bool reuse_constant_tensors = true;
|
|
};
|
|
|
|
// Runs shape inference over graph `g` and reports the collected errors according to
// `options`. If the caller does not supply a container for generated shape data, a local
// one that lives for the duration of this call is used, so the inference implementation
// always has a non-null container.
static void InferShapesImpl(
    GraphProto* g,
    const std::unordered_map<std::string, TypeProto*>& outer_scope_value_types_by_name,
    const std::unordered_map<std::string, int>& opset_imports,
    const ShapeInferenceOptions& options,
    SymbolTable* symbol_table,
    const ModelLocalFunctionsMap& model_local_functions_map,
    const ISchemaRegistry* schema_registry = OpSchemaRegistry::Instance(),
    DataValueMap* generated_shape_data_by_name = nullptr,
    const int ir_version = IR_VERSION // default the latest one
) {
  DataValueMap empty;
  if (generated_shape_data_by_name == nullptr) {
    generated_shape_data_by_name = &empty;
  }
  ShapeInferenceImplBase base(
      g,
      outer_scope_value_types_by_name,
      opset_imports,
      options,
      symbol_table,
      model_local_functions_map,
      schema_registry,
      generated_shape_data_by_name,
      ir_version);
  base.Process(*g);
  base.FinalizeShapeInference();
}
|
|
|
|
// T is either ModelProto or FunctionProto.
// Collects the proto's opset imports into a domain -> version map. Versions are narrowed
// to int; if a domain is listed more than once, the last entry wins.
template <class T>
std::unordered_map<std::string, int> GetOpsetImportsFromProto(const T& proto) {
  std::unordered_map<std::string, int> version_by_domain;
  const auto& imports = proto.opset_import();
  for (auto it = imports.begin(); it != imports.end(); ++it) {
    version_by_domain[it->domain()] = static_cast<int>(it->version());
  }
  return version_by_domain;
}
|
|
|
|
// Runs shape inference on a standalone graph, with a fresh symbol table and no
// outer-scope value types.
void InferShapes(
    GraphProto* g,
    const std::unordered_map<std::string, int>& opset_imports,
    const ISchemaRegistry* schema_registry,
    const ShapeInferenceOptions& options,
    const std::unordered_map<std::string, const FunctionProto*>& model_local_functions) {
  const std::unordered_map<std::string, TypeProto*> no_outer_scope_types(0);
  SymbolTableImpl symbols;
  InferShapesImpl(
      g, no_outer_scope_types, opset_imports, options, &symbols, model_local_functions, schema_registry);
}
|
|
|
|
// Runs shape inference on the main graph of model `m`, in place.
// Opset imports, model-local functions and the IR version are taken from the
// model itself; a fresh symbol table is created for the call.
void InferShapes(
    ModelProto& m,
    const ISchemaRegistry* schema_registry,
    const ShapeInferenceOptions& options,
    DataValueMap* generated_shape_data_by_name) {
  const auto model_opsets = GetOpsetImportsFromProto(m);

  // Index the model-local functions by identifier; the first occurrence of an identifier is kept.
  ModelLocalFunctionsMap functions_by_id;
  for (const auto& fn : m.functions()) {
    functions_by_id.emplace(GetFunctionIdentifier(fn), &fn);
  }

  // The main graph has no enclosing scope.
  const std::unordered_map<std::string, TypeProto*> no_outer_scope_types(0);
  SymbolTableImpl symbols;
  InferShapesImpl(
      m.mutable_graph(),
      no_outer_scope_types,
      model_opsets,
      options,
      &symbols,
      functions_by_id,
      schema_registry,
      generated_shape_data_by_name,
      m.ir_version());
}
|
|
|
|
void InferShapes(
|
|
const std::string& model_path,
|
|
const std::string& save_path,
|
|
const ISchemaRegistry* schema_registry,
|
|
const ShapeInferenceOptions& options,
|
|
DataValueMap* generated_shape_data_by_name) {
|
|
ModelProto model;
|
|
LoadProtoFromPath(model_path, model);
|
|
InferShapes(model, schema_registry, options, generated_shape_data_by_name);
|
|
// Save the inferred model to the original model path
|
|
// Use SerializeToString instead of SerializeToOstream due to LITE_PROTO
|
|
std::fstream output(save_path, std::ios::out | std::ios::trunc | std::ios::binary);
|
|
std::string model_string;
|
|
ONNX_TRY {
|
|
model.SerializeToString(&model_string);
|
|
output << model_string;
|
|
}
|
|
ONNX_CATCH(...) {
|
|
fail_check("Unable to save inferred model to the target path:", save_path);
|
|
}
|
|
}
|
|
|
|
// Infer shape for functions
|
|
void InferShapeForFunctionNode(
|
|
const FunctionProto& func_proto,
|
|
const std::unordered_map<std::string, int>& func_opset_imports,
|
|
const ISchemaRegistry* schema_registry,
|
|
InferenceContext& ctx,
|
|
const ShapeInferenceOptions& options,
|
|
const std::unordered_map<std::string, const FunctionProto*>& model_local_functions_map,
|
|
SymbolTable* symbol_table,
|
|
DataValueMap* generated_shape_data_by_name) {
|
|
ShapeInferenceImplBase base(
|
|
nullptr, // no graph
|
|
{}, // outer_scope_value_types_by_name
|
|
func_opset_imports,
|
|
options,
|
|
symbol_table,
|
|
model_local_functions_map,
|
|
schema_registry,
|
|
generated_shape_data_by_name);
|
|
base.Process(func_proto, ctx);
|
|
base.FinalizeShapeInference();
|
|
}
|
|
|
|
void InferShapeForFunctionNode(
|
|
const FunctionProto& function_proto,
|
|
const ISchemaRegistry* schema_registry,
|
|
InferenceContext& ctx,
|
|
const ShapeInferenceOptions& options,
|
|
const std::unordered_map<std::string, const FunctionProto*>& model_local_functions_map,
|
|
SymbolTable* symbol_table,
|
|
DataValueMap* generated_shape_data_by_name) {
|
|
auto opset_imports = GetOpsetImportsFromProto(function_proto);
|
|
InferShapeForFunctionNode(
|
|
function_proto,
|
|
opset_imports,
|
|
schema_registry,
|
|
ctx,
|
|
options,
|
|
model_local_functions_map,
|
|
symbol_table,
|
|
generated_shape_data_by_name);
|
|
}
|
|
|
|
struct FunctionInferenceContext : public InferenceContext {
|
|
FunctionInferenceContext(
|
|
const FunctionProto& func_proto,
|
|
const std::vector<TypeProto>& input_types,
|
|
const std::vector<AttributeProto>& attributes,
|
|
const ShapeInferenceOptions& options)
|
|
: input_types_(input_types), options_(options), func_proto_(&func_proto) {
|
|
for (const auto& attr : attributes) {
|
|
attributesByName_[attr.name()] = &attr;
|
|
}
|
|
auto num_outputs = func_proto.output_size();
|
|
for (int i = 0; i < num_outputs; i++) {
|
|
output_types_.push_back(TypeProto());
|
|
}
|
|
}
|
|
|
|
const AttributeProto* getAttribute(const std::string& name) const override {
|
|
auto iter = attributesByName_.find(name);
|
|
if (iter == attributesByName_.end()) {
|
|
return nullptr;
|
|
} else {
|
|
return iter->second;
|
|
}
|
|
}
|
|
size_t getNumInputs() const override {
|
|
return input_types_.size();
|
|
}
|
|
|
|
size_t getNumOutputs() const override {
|
|
return output_types_.size();
|
|
}
|
|
|
|
const TypeProto* getInputType(size_t index) const override {
|
|
// We should return nullptr for missing optional parameters.
|
|
// An uninitialized TypeProto() is used for missing optional parameters, and
|
|
// is mapped to a nullptr here.
|
|
if (index >= input_types_.size())
|
|
return nullptr;
|
|
if (input_types_[index].value_case() == TypeProto::ValueCase::VALUE_NOT_SET)
|
|
return nullptr;
|
|
return &input_types_[index];
|
|
}
|
|
|
|
TypeProto* getOutputType(size_t index) override {
|
|
return (index < output_types_.size()) ? &output_types_[index] : nullptr;
|
|
}
|
|
|
|
GraphInferencer* getGraphAttributeInferencer(const std::string& attribute_name) override {
|
|
ONNX_UNUSED_PARAMETER(attribute_name); // This method is unused for function-type-inference.
|
|
return nullptr;
|
|
}
|
|
|
|
const TensorProto* getInputData(size_t index) const override {
|
|
ONNX_UNUSED_PARAMETER(index); // This inference doesn't take advantage of statically known input values.
|
|
return nullptr;
|
|
}
|
|
|
|
const SparseTensorProto* getInputSparseData(size_t index) const override {
|
|
ONNX_UNUSED_PARAMETER(index); // This inference doesn't take advantage of statically known input values.
|
|
return nullptr;
|
|
}
|
|
|
|
const TensorShapeProto* getSymbolicInput(size_t index) const override {
|
|
ONNX_UNUSED_PARAMETER(index); // This inference doesn't take advantage of data-propagation.
|
|
return nullptr;
|
|
}
|
|
|
|
std::vector<TypeProto> popOutputTypes() {
|
|
return std::move(output_types_);
|
|
}
|
|
|
|
std::string getDisplayName() const override {
|
|
if (func_proto_ == nullptr)
|
|
return "";
|
|
if (func_proto_->domain().empty()) {
|
|
if (func_proto_->name().empty())
|
|
return "";
|
|
return MakeString("function ", func_proto_->name());
|
|
}
|
|
if (func_proto_->name().empty())
|
|
return MakeString("function [", func_proto_->domain(), "]");
|
|
return MakeString("function ", func_proto_->name(), "[", func_proto_->domain(), "]");
|
|
}
|
|
|
|
private:
|
|
const std::vector<TypeProto>& input_types_;
|
|
std::vector<TypeProto> output_types_;
|
|
std::unordered_map<std::string, const AttributeProto*> attributesByName_;
|
|
ShapeInferenceOptions options_;
|
|
const FunctionProto* func_proto_;
|
|
};
|
|
|
|
std::vector<TypeProto> InferFunctionOutputTypes(
|
|
const FunctionProto& function_proto,
|
|
const std::vector<TypeProto>& input_types,
|
|
const std::vector<AttributeProto>& attributes) {
|
|
// TODO: if it is desirable for infer_function_output_types to provide check_type, strict_mode, data_prop,
|
|
// we can add them to the Python API. For now we just assume the default options.
|
|
ShapeInferenceOptions options{true, 1, false};
|
|
FunctionInferenceContext ctx(function_proto, input_types, attributes, options);
|
|
auto opset_imports = GetOpsetImportsFromProto(function_proto);
|
|
ShapeInferenceImplBase base(
|
|
nullptr, // no graph
|
|
{}, // outer_scope_value_types_by_name
|
|
opset_imports,
|
|
options,
|
|
/*symbol_table*/ nullptr,
|
|
/*model_local_functions_map*/ {},
|
|
/*schema_registry*/ OpSchemaRegistry::Instance(),
|
|
/*generated_shape_data_by_name*/ nullptr);
|
|
base.Process(function_proto, ctx);
|
|
base.FinalizeShapeInference();
|
|
return ctx.popOutputTypes();
|
|
}
|
|
|
|
std::vector<const TypeProto*> GraphInferencerImpl::doInferencing(
|
|
const std::vector<const TypeProto*>& input_types,
|
|
const std::vector<const TensorProto*>& input_data) {
|
|
SymbolTable* symbol_table = context_->symbol_table;
|
|
int num_inputs = int(input_types.size());
|
|
std::unordered_set<std::string> initializer_name_set;
|
|
for (const auto& tp : g_->initializer()) {
|
|
initializer_name_set.insert(tp.name());
|
|
}
|
|
|
|
if (context_->ir_version >= 4) {
|
|
if (g_->input_size() != num_inputs) {
|
|
fail_shape_inference("Graph has ", g_->input_size(), " inputs but ", num_inputs, " were provided");
|
|
}
|
|
for (int i = 0; i < g_->input_size(); ++i) {
|
|
if (initializer_name_set.count(g_->input(i).name()) > 0) {
|
|
fail_shape_inference(
|
|
"Cannot use the same name as both a subgraph initializer and subgraph input: ", g_->input(i).name());
|
|
}
|
|
}
|
|
} else {
|
|
// IR < 4 requires all initializers to be optional inputs
|
|
// So the number of graph input can be larger than the number of node input
|
|
if (num_inputs > g_->input_size()) {
|
|
fail_shape_inference(
|
|
"Graph has ",
|
|
g_->input_size(),
|
|
" inputs but ",
|
|
num_inputs,
|
|
" were provided.",
|
|
"The number of graph input cannot be smaller than the number of node input");
|
|
} else if (num_inputs < g_->input_size()) {
|
|
for (int i = 0; i < g_->input_size(); ++i) {
|
|
if (i < num_inputs && initializer_name_set.count(g_->input(i).name()) > 0) {
|
|
fail_shape_inference("Graph initializer names must appear after the actual inputs: ", g_->input(i).name());
|
|
} else if (i >= num_inputs && initializer_name_set.count(g_->input(i).name()) == 0) {
|
|
// Further check whether the additional input is in initializers
|
|
fail_shape_inference("Cannot find missing input: ", g_->input(i).name(), "in initializers. ");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
for (int i = 0, end = num_inputs; i < end; ++i) {
|
|
const TypeProto* inferred_input = input_types[i];
|
|
|
|
if (!inferred_input)
|
|
continue;
|
|
|
|
TypeProto* graph_input = g_->mutable_input(i)->mutable_type();
|
|
// Even if graphInput doesn't have defined type, it will assign inferredType to it
|
|
mergeShapesAndTypes(*inferred_input, graph_input);
|
|
|
|
if (symbol_table) {
|
|
MaterializeSymbolicShape(graph_input, *symbol_table);
|
|
}
|
|
}
|
|
|
|
// future: pass inputData into InferShapes either directly, or indirectly by
|
|
// updating initializers that match subgraph inputs.
|
|
(void)input_data;
|
|
InferShapesImpl(
|
|
g_,
|
|
*context_->outer_scope_value_types_by_name, // never null
|
|
context_->opset_imports,
|
|
options_,
|
|
symbol_table,
|
|
context_->model_local_functions,
|
|
context_->schema_registry,
|
|
context_->generated_shape_data_by_name);
|
|
|
|
std::vector<const TypeProto*> graph_output_types;
|
|
graph_output_types.reserve(g_->output().size());
|
|
for (const ValueInfoProto& output : g_->output()) {
|
|
graph_output_types.push_back(&output.type());
|
|
}
|
|
|
|
return graph_output_types;
|
|
}
|
|
|
|
std::string GetErrorWithNodeInfo(const NodeProto& n, const std::runtime_error& err) {
|
|
std::string op_name = n.has_name() ? (", node name: " + n.name()) : "";
|
|
return "(op_type:" + n.op_type() + op_name + "): " + err.what();
|
|
}
|
|
|
|
void TraverseGraphsToAddExistingSymbols(const GraphProto& g, SymbolTable& symbol_table) {
|
|
symbol_table.addFromGraph(g);
|
|
for (const auto& n : g.node()) {
|
|
for (auto& attr : n.attribute()) {
|
|
if (attr.has_g()) {
|
|
TraverseGraphsToAddExistingSymbols(attr.g(), symbol_table);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace shape_inference
|
|
} // namespace ONNX_NAMESPACE
|