/*
 * SPDX-License-Identifier: Apache-2.0
 */

#pragma once

#include <cstdint>
#include <string>

#include "onnx/defs/shape_inference.h"
#include "onnx/defs/tensor_proto_util.h"
#include "onnx/onnx_pb.h"

namespace ONNX_NAMESPACE {
namespace defs {
namespace math {
namespace utils {

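// Reads the first element of an initializer tensor and casts it to T. Returns
// a default-constructed T when the tensor is absent; fails shape inference for
// element types other than float, double, int32, and int64.
//
// Illustrative use during shape inference (assumes the op's second input is a
// scalar initializer available via InferenceContext::getInputData):
//   const auto* scalar = ctx.getInputData(1);
//   int64_t value = GetScalarValueFromTensor<int64_t>(scalar);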
template <typename T>
T GetScalarValueFromTensor(const ONNX_NAMESPACE::TensorProto* t) {
  if (t == nullptr) {
    return T{};
  }

  auto data_type = t->data_type();
  switch (data_type) {
    case ONNX_NAMESPACE::TensorProto::FLOAT:
      return static_cast<T>(ONNX_NAMESPACE::ParseData<float>(t).at(0));
    case ONNX_NAMESPACE::TensorProto::DOUBLE:
      return static_cast<T>(ONNX_NAMESPACE::ParseData<double>(t).at(0));
    case ONNX_NAMESPACE::TensorProto::INT32:
      return static_cast<T>(ONNX_NAMESPACE::ParseData<int32_t>(t).at(0));
    case ONNX_NAMESPACE::TensorProto::INT64:
      return static_cast<T>(ONNX_NAMESPACE::ParseData<int64_t>(t).at(0));
    default:
      fail_shape_inference("Unsupported input data type of ", data_type);
  }
}

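// Shape inference for MatMul-style ops: infers the output shape from the
// inputs at input1Idx and input2Idx, following numpy.matmul broadcasting
// semantics for the batch dimensions.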
void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx);

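// Shape inference for the QLinearMatMul operator.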
void QLinearMatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx);

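// Returns the documentation string used by the QLinearMatMul operator schema.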
const char* QLinearMatMulDoc();

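// Evaluates the integer binary operation identified by op_type on a and b and
// returns the result.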
int MathOpTwoIntegers(std::string op_type, int a, int b);

} // namespace utils
} // namespace math
} // namespace defs
} // namespace ONNX_NAMESPACE