130 lines
3.0 KiB
C++
130 lines
3.0 KiB
C++
// Copyright (c) ONNX Project Contributors
|
|
|
|
/*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*/
|
|
|
|
#pragma once
|
|
#include "onnx/common/common.h"
|
|
#include "onnx/onnx_pb.h"
|
|
|
|
namespace ONNX_NAMESPACE {
|
|
namespace internal {
|
|
|
|
// Visitor: A readonly visitor class for ONNX Proto objects.
|
|
// This class is restricted to Nodes, Graphs, Attributes, and Functions.
|
|
// The VisitX methods invoke ProcessX, and if that returns true, will
|
|
// continue to visit all children of the X.
|
|
|
|
struct Visitor {
|
|
virtual void VisitGraph(const GraphProto& graph) {
|
|
if (ProcessGraph(graph))
|
|
for (auto& node : graph.node())
|
|
VisitNode(node);
|
|
}
|
|
|
|
virtual void VisitFunction(const FunctionProto& function) {
|
|
if (ProcessFunction(function))
|
|
for (auto& node : function.node())
|
|
VisitNode(node);
|
|
}
|
|
|
|
virtual void VisitNode(const NodeProto& node) {
|
|
if (ProcessNode(node)) {
|
|
for (auto& attr : node.attribute()) {
|
|
VisitAttribute(attr);
|
|
}
|
|
}
|
|
}
|
|
|
|
virtual void VisitAttribute(const AttributeProto& attr) {
|
|
if (ProcessAttribute(attr)) {
|
|
if (attr.has_g()) {
|
|
VisitGraph(attr.g());
|
|
}
|
|
for (auto& graph : attr.graphs())
|
|
VisitGraph(graph);
|
|
}
|
|
}
|
|
|
|
virtual bool ProcessGraph(const GraphProto& graph) {
|
|
ONNX_UNUSED_PARAMETER(graph);
|
|
return true;
|
|
}
|
|
|
|
virtual bool ProcessFunction(const FunctionProto& function) {
|
|
ONNX_UNUSED_PARAMETER(function);
|
|
return true;
|
|
}
|
|
|
|
virtual bool ProcessNode(const NodeProto& node) {
|
|
ONNX_UNUSED_PARAMETER(node);
|
|
return true;
|
|
}
|
|
|
|
virtual bool ProcessAttribute(const AttributeProto& attr) {
|
|
ONNX_UNUSED_PARAMETER(attr);
|
|
return true;
|
|
}
|
|
|
|
virtual ~Visitor() {}
|
|
};
|
|
|
|
// MutableVisitor: A version of Visitor that allows mutation of the visited objects.
|
|
struct MutableVisitor {
|
|
virtual void VisitGraph(GraphProto* graph) {
|
|
if (ProcessGraph(graph))
|
|
for (auto& node : *(graph->mutable_node()))
|
|
VisitNode(&node);
|
|
}
|
|
|
|
virtual void VisitFunction(FunctionProto* function) {
|
|
if (ProcessFunction(function))
|
|
for (auto& node : *(function->mutable_node()))
|
|
VisitNode(&node);
|
|
}
|
|
|
|
virtual void VisitNode(NodeProto* node) {
|
|
if (ProcessNode(node)) {
|
|
for (auto& attr : *(node->mutable_attribute())) {
|
|
VisitAttribute(&attr);
|
|
}
|
|
}
|
|
}
|
|
|
|
virtual void VisitAttribute(AttributeProto* attr) {
|
|
if (ProcessAttribute(attr)) {
|
|
if (attr->has_g()) {
|
|
VisitGraph(attr->mutable_g());
|
|
}
|
|
for (auto& graph : *(attr->mutable_graphs()))
|
|
VisitGraph(&graph);
|
|
}
|
|
}
|
|
|
|
virtual bool ProcessGraph(GraphProto* graph) {
|
|
ONNX_UNUSED_PARAMETER(graph);
|
|
return true;
|
|
}
|
|
|
|
virtual bool ProcessFunction(FunctionProto* function) {
|
|
ONNX_UNUSED_PARAMETER(function);
|
|
return true;
|
|
}
|
|
|
|
virtual bool ProcessNode(NodeProto* node) {
|
|
ONNX_UNUSED_PARAMETER(node);
|
|
return true;
|
|
}
|
|
|
|
virtual bool ProcessAttribute(AttributeProto* attr) {
|
|
ONNX_UNUSED_PARAMETER(attr);
|
|
return true;
|
|
}
|
|
|
|
virtual ~MutableVisitor() {}
|
|
};
|
|
|
|
} // namespace internal
|
|
} // namespace ONNX_NAMESPACE
|