Files
2024-10-30 22:14:35 +01:00

58 lines
1.9 KiB
Python

# Copyright (c) ONNX Project Contributors
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from typing import Any
from onnx.reference.op_run import OpRun, _split_class_name
def build_registered_operators_any_domain(
module_context: dict[str, Any],
) -> dict[str, dict[int | None, OpRun]]:
reg_ops: dict[str, dict[int | None, OpRun]] = {}
for class_name, class_type in module_context.items():
if class_name.startswith("_") or class_name in {
"Any",
"Dict",
"List",
"TOptional",
"Union",
"annotations",
"cl",
"class_name",
"get_schema",
"module_context",
"textwrap",
}:
continue
if isinstance(class_type, type(build_registered_operators_any_domain)):
continue
try:
issub = issubclass(class_type, OpRun)
except TypeError as e:
raise TypeError(
f"Unexpected variable type {class_type!r} and class_name={class_name!r}."
) from e
if issub:
op_type, op_version = _split_class_name(class_name)
if op_type not in reg_ops:
reg_ops[op_type] = {}
reg_ops[op_type][op_version] = class_type
if not reg_ops:
raise RuntimeError(
"No registered operator. This error happens when no implementation "
"of type 'OpRun' was detected. It may be due to an error during installation. Please try reinstalling onnx."
)
# Set default implementation to the latest one.
for impl in reg_ops.values():
if None in impl:
# default already exists
continue
max_version = max(impl) # type: ignore[type-var]
impl[None] = impl[max_version]
return reg_ops