I am done

commit 40e2a747cf
parent 720dc28c09
Date: 2024-10-30 22:14:35 +01:00

36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,553 @@
# mypy: allow-untyped-defs
from __future__ import annotations
__all__ = [
# Modules
"symbolic_helper",
"utils",
"errors",
# All opsets
"symbolic_caffe2",
"symbolic_opset7",
"symbolic_opset8",
"symbolic_opset9",
"symbolic_opset10",
"symbolic_opset11",
"symbolic_opset12",
"symbolic_opset13",
"symbolic_opset14",
"symbolic_opset15",
"symbolic_opset16",
"symbolic_opset17",
"symbolic_opset18",
"symbolic_opset19",
"symbolic_opset20",
# Enums
"ExportTypes",
"OperatorExportTypes",
"TrainingMode",
"TensorProtoDataType",
"JitScalarType",
# Public functions
"export",
"export_to_pretty_string",
"is_in_onnx_export",
"select_model_mode_for_export",
"register_custom_op_symbolic",
"unregister_custom_op_symbolic",
"disable_log",
"enable_log",
# Base error
"OnnxExporterError",
# Dynamo Exporter
"DiagnosticOptions",
"ExportOptions",
"ONNXProgram",
"ONNXRuntimeOptions",
"OnnxRegistry",
"dynamo_export",
"enable_fake_mode",
# DORT / torch.compile
"is_onnxrt_backend_supported",
]
from typing import Any, Callable, Collection, Mapping, Sequence, TYPE_CHECKING
import torch
from torch import _C
from torch._C import _onnx as _C_onnx
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
from ._exporter_states import ExportTypes
from ._internal.onnxruntime import (
is_onnxrt_backend_supported,
OrtBackend as _OrtBackend,
OrtBackendOptions as _OrtBackendOptions,
OrtExecutionProvider as _OrtExecutionProvider,
)
from ._type_utils import JitScalarType
from .errors import OnnxExporterError
from .utils import (
_optimize_graph,
_run_symbolic_function,
_run_symbolic_method,
export_to_pretty_string,
is_in_onnx_export,
register_custom_op_symbolic,
select_model_mode_for_export,
unregister_custom_op_symbolic,
)
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
errors,
symbolic_caffe2,
symbolic_helper,
symbolic_opset7,
symbolic_opset8,
symbolic_opset9,
symbolic_opset10,
symbolic_opset11,
symbolic_opset12,
symbolic_opset13,
symbolic_opset14,
symbolic_opset15,
symbolic_opset16,
symbolic_opset17,
symbolic_opset18,
symbolic_opset19,
symbolic_opset20,
utils,
)
from ._internal._exporter_legacy import ( # usort: skip. needs to be last to avoid circular import
DiagnosticOptions,
ExportOptions,
ONNXProgram,
ONNXRuntimeOptions,
OnnxRegistry,
enable_fake_mode,
)
if TYPE_CHECKING:
import os
# Set namespace for exposed private names
DiagnosticOptions.__module__ = "torch.onnx"
ExportOptions.__module__ = "torch.onnx"
ExportTypes.__module__ = "torch.onnx"
JitScalarType.__module__ = "torch.onnx"
ONNXProgram.__module__ = "torch.onnx"
ONNXRuntimeOptions.__module__ = "torch.onnx"
OnnxExporterError.__module__ = "torch.onnx"
OnnxRegistry.__module__ = "torch.onnx"
_OrtBackend.__module__ = "torch.onnx"
_OrtBackendOptions.__module__ = "torch.onnx"
_OrtExecutionProvider.__module__ = "torch.onnx"
enable_fake_mode.__module__ = "torch.onnx"
is_onnxrt_backend_supported.__module__ = "torch.onnx"
producer_name = "pytorch"
producer_version = _C_onnx.PRODUCER_VERSION
def export(
model: torch.nn.Module
| torch.export.ExportedProgram
| torch.jit.ScriptModule
| torch.jit.ScriptFunction,
args: tuple[Any, ...] = (),
f: str | os.PathLike | None = None,
*,
kwargs: dict[str, Any] | None = None,
export_params: bool = True,
verbose: bool | None = None,
input_names: Sequence[str] | None = None,
output_names: Sequence[str] | None = None,
opset_version: int | None = None,
dynamic_axes: Mapping[str, Mapping[int, str]]
| Mapping[str, Sequence[int]]
| None = None,
keep_initializers_as_inputs: bool = False,
dynamo: bool = False,
# Dynamo only options
external_data: bool = True,
dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None,
report: bool = False,
verify: bool = False,
profile: bool = False,
dump_exported_program: bool = False,
artifacts_dir: str | os.PathLike = ".",
fallback: bool = False,
# Deprecated options
training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,
do_constant_folding: bool = True,
custom_opsets: Mapping[str, int] | None = None,
export_modules_as_functions: bool | Collection[type[torch.nn.Module]] = False,
autograd_inlining: bool = True,
**_: Any, # ignored options
) -> Any | None:
r"""Exports a model into ONNX format.
Args:
model: The model to be exported.
args: Example positional inputs. Any non-Tensor arguments will be hard-coded into the
exported model; any Tensor arguments will become inputs of the exported model,
in the order they occur in the tuple.
f: Path to the output ONNX model file. E.g. "model.onnx".
kwargs: Optional example keyword inputs.
export_params: If false, parameters (weights) will not be exported.
verbose: Whether to enable verbose logging.
input_names: names to assign to the input nodes of the graph, in order.
output_names: names to assign to the output nodes of the graph, in order.
opset_version: The version of the
`default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
to target. Must be >= 7.
dynamic_axes:
By default the exported model will have the shapes of all input and output tensors
set to exactly match those given in ``args``. To specify axes of tensors as
dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema:
* KEY (str): an input or output name. Each name must also be provided in ``input_names`` or
``output_names``.
* VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a
list, each element is an axis index.
For example::
class SumModule(torch.nn.Module):
def forward(self, x):
return torch.sum(x, dim=1)
torch.onnx.export(
SumModule(),
(torch.ones(2, 2),),
"onnx.pb",
input_names=["x"],
output_names=["sum"],
)
Produces::
input {
name: "x"
...
shape {
dim {
dim_value: 2 # axis 0
}
dim {
dim_value: 2 # axis 1
...
output {
name: "sum"
...
shape {
dim {
dim_value: 2 # axis 0
...
While::
torch.onnx.export(
SumModule(),
(torch.ones(2, 2),),
"onnx.pb",
input_names=["x"],
output_names=["sum"],
dynamic_axes={
# dict value: manually named axes
"x": {0: "my_custom_axis_name"},
# list value: automatic names
"sum": [0],
},
)
Produces::
input {
name: "x"
...
shape {
dim {
dim_param: "my_custom_axis_name" # axis 0
}
dim {
dim_value: 2 # axis 1
...
output {
name: "sum"
...
shape {
dim {
dim_param: "sum_dynamic_axes_1" # axis 0
...
keep_initializers_as_inputs: If True, all the
initializers (typically corresponding to model weights) in the
exported graph will also be added as inputs to the graph. If False,
then initializers are not added as inputs to the graph, and only
the user inputs are added as inputs.
Set this to True if you intend to supply model weights at runtime.
Set it to False if the weights are static to allow for better optimizations
(e.g. constant folding) by backends/runtimes.
dynamo: Whether to export the model with ``torch.export`` ExportedProgram instead of TorchScript.
external_data: Whether to save the model weights as an external data file.
This is required for models with large weights that exceed the ONNX file size limit (2GB).
When False, the weights are saved in the ONNX file with the model architecture.
dynamic_shapes: A dictionary of dynamic shapes for the model inputs. Refer to
:func:`torch.export.export` for more details. This is only used (and preferred) when dynamo is True.
Only one of ``dynamic_axes`` and ``dynamic_shapes`` should be set
at a time.
report: Whether to generate a markdown report for the export process.
verify: Whether to verify the exported model using ONNX Runtime.
profile: Whether to profile the export process.
dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file.
This is useful for debugging the exporter.
artifacts_dir: The directory to save the debugging artifacts like the report and the serialized
exported program.
fallback: Whether to fall back to the TorchScript exporter if the dynamo exporter fails.
training: Deprecated option. Instead, set the training mode of the model before exporting.
operator_export_type: Deprecated option. Only ONNX is supported.
do_constant_folding: Deprecated option. The exported graph is always optimized.
custom_opsets: Deprecated.
A dictionary:
* KEY (str): opset domain name
* VALUE (int): opset version
If a custom opset is referenced by ``model`` but not mentioned in this dictionary,
the opset version is set to 1. Only custom opset domain name and version should be
indicated through this argument.
export_modules_as_functions: Deprecated option.
Flag to enable
exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the
particular types of modules to export as local functions in ONNX.
This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because
``opset_version`` < 15 implies IR version < 8, which means no local function support.
Module variables will be exported as function attributes. There are two categories of function
attributes.
1. Annotated attributes: class variables with
`PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_
type annotations will be exported as attributes.
Annotated attributes are not used inside the subgraph of ONNX local function because
they are not created by PyTorch JIT tracing, but they may be used by consumers
to determine whether or not to replace the function with a particular fused kernel.
2. Inferred attributes: variables that are used by operators inside the module. Attribute names
will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from
python module annotations. Inferred attributes are used inside the subgraph of ONNX local function.
* ``False`` (default): export ``nn.Module`` forward calls as fine-grained nodes.
* ``True``: export all ``nn.Module`` forward calls as local function nodes.
* Set of types of ``nn.Module``: export ``nn.Module`` forward calls as local function nodes,
only if the type of the ``nn.Module`` is found in the set.
autograd_inlining: Deprecated.
Flag used to control whether to inline autograd functions.
Refer to https://github.com/pytorch/pytorch/pull/74765 for more details.
"""
if dynamo is True or isinstance(model, torch.export.ExportedProgram):
from torch.onnx._internal import exporter
if isinstance(args, torch.Tensor):
args = (args,)
return exporter.export_compat(
model,
args,
f,
kwargs=kwargs,
export_params=export_params,
verbose=verbose,
input_names=input_names,
output_names=output_names,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
external_data=external_data,
dynamic_shapes=dynamic_shapes,
report=report,
verify=verify,
profile=profile,
dump_exported_program=dump_exported_program,
artifacts_dir=artifacts_dir,
fallback=fallback,
)
else:
from torch.onnx.utils import export
if dynamic_shapes:
raise ValueError(
"The exporter only supports dynamic shapes "
"through parameter dynamic_axes when dynamo=False."
)
export(
model,
args,
f, # type: ignore[arg-type]
kwargs=kwargs,
export_params=export_params,
verbose=verbose is True,
input_names=input_names,
output_names=output_names,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
training=training,
operator_export_type=operator_export_type,
do_constant_folding=do_constant_folding,
custom_opsets=custom_opsets,
export_modules_as_functions=export_modules_as_functions,
autograd_inlining=autograd_inlining,
)
return None
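# A minimal usage sketch of the function above (illustrative only; the model,
# file names, and dimension names are hypothetical):
#
#   class MLP(torch.nn.Module):
#       def __init__(self) -> None:
#           super().__init__()
#           self.fc = torch.nn.Linear(4, 2)
#
#       def forward(self, x):
#           return self.fc(x)
#
#   model, args = MLP(), (torch.randn(3, 4),)
#   # TorchScript path: mark the batch axis of input "x" as dynamic.
#   torch.onnx.export(
#       model, args, "mlp.onnx",
#       input_names=["x"], dynamic_axes={"x": {0: "batch"}},
#   )
#   # Dynamo path: dynamic_shapes follows torch.export.export semantics.
#   torch.onnx.export(
#       model, args, "mlp_dynamo.onnx", dynamo=True,
#       dynamic_shapes=({0: torch.export.Dim("batch")},),
#   )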
def dynamo_export(
model: torch.nn.Module | Callable | torch.export.ExportedProgram, # type: ignore[name-defined]
/,
*model_args,
export_options: ExportOptions | None = None,
**model_kwargs,
) -> ONNXProgram | Any:
"""Export a torch.nn.Module to an ONNX graph.
Args:
model: The PyTorch model to be exported to ONNX.
model_args: Positional inputs to ``model``.
model_kwargs: Keyword inputs to ``model``.
export_options: Options to influence the export to ONNX.
Returns:
An in-memory representation of the exported ONNX model.
**Example 1 - Simplest export**
::
class MyModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, x, bias=None):
out = self.linear(x)
out = out + bias
return out
model = MyModel()
kwargs = {"bias": 3.0}
args = (torch.randn(2, 2, 2),)
onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save(
"my_simple_model.onnx"
)
**Example 2 - Exporting with dynamic shapes**
::
# The previous model can be exported with dynamic shapes
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_program = torch.onnx.dynamo_export(
model, *args, **kwargs, export_options=export_options
)
onnx_program.save("my_dynamic_model.onnx")
"""
# NOTE: The new exporter is experimental and is not enabled by default.
import warnings
from torch.onnx import _flags
from torch.onnx._internal import exporter
from torch.utils import _pytree
if isinstance(model, torch.export.ExportedProgram):
return exporter.export_compat(
model, # type: ignore[arg-type]
model_args,
f=None,
kwargs=model_kwargs,
opset_version=18,
external_data=True,
export_params=True,
fallback=True,
)
elif _flags.USE_EXPERIMENTAL_LOGIC:
if export_options is not None:
warnings.warn(
"You are using an experimental ONNX export logic, which currently only supports dynamic shapes. "
"For a more comprehensive set of export options, including advanced features, please consider using "
"`torch.onnx.export(..., dynamo=True)`. ",
category=FutureWarning,
)
if export_options is not None and export_options.dynamic_shapes:
# Make all shapes dynamic
def _to_dynamic_shapes_mapper():
arg_order = 0
def _to_dynamic_shape(x):
nonlocal arg_order
if isinstance(x, torch.Tensor):
rank = len(x.shape)
dynamic_shape = {}
for i in range(rank):
dynamic_shape[i] = torch.export.Dim(
f"arg_{arg_order}_dim_{i}"
)
arg_order += 1
return dynamic_shape
else:
return None
return _to_dynamic_shape
# model_args could be nested
dynamic_shapes = _pytree.tree_map(
_to_dynamic_shapes_mapper(),
model_args,
)
else:
dynamic_shapes = None
return exporter.export_compat(
model, # type: ignore[arg-type]
model_args,
f=None,
kwargs=model_kwargs,
dynamic_shapes=dynamic_shapes,
opset_version=18,
external_data=True,
export_params=True,
fallback=True,
)
else:
from torch.onnx._internal._exporter_legacy import dynamo_export
return dynamo_export(
model, *model_args, export_options=export_options, **model_kwargs
)
# TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module.
# Returns True iff ONNX logging is turned on.
is_onnx_log_enabled = _C._jit_is_onnx_log_enabled
def enable_log() -> None:
r"""Enables ONNX logging."""
_C._jit_set_onnx_log_enabled(True)
def disable_log() -> None:
r"""Disables ONNX logging."""
_C._jit_set_onnx_log_enabled(False)
"""Sets output stream for ONNX logging.
Args:
stream_name (str, default "stdout"): Only 'stdout' and 'stderr' are supported
as ``stream_name``.
"""
set_log_stream = _C._jit_set_onnx_log_output_stream
"""A simple logging facility for ONNX exporter.
Args:
args: Arguments are converted to string, concatenated together with a newline
character appended to the end, and flushed to output stream.
"""
log = _C._jit_onnx_log
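# A usage sketch for the logging helpers above (illustrative; `set_log_stream`
# and `log` are the module-level bindings defined just above):
#
#   torch.onnx.enable_log()
#   torch.onnx.set_log_stream("stderr")   # only "stdout" and "stderr" are supported
#   torch.onnx.log("exporting graph for", "MyModel")  # stringified, newline appended
#   torch.onnx.disable_log()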


@@ -0,0 +1,25 @@
"""Constant values used in ONNX."""
ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
ONNX_BASE_OPSET = 9
ONNX_MIN_OPSET = 7
ONNX_MAX_OPSET = 20
ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 20
# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py
ONNX_DEFAULT_OPSET = 17
ONNX_CONSTANT_FOLDING_MIN_OPSET = 9
PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues"
INT64_MAX = 9223372036854775807
INT32_MAX = 2147483647
INT16_MAX = 32767
INT8_MAX = 127
UINT8_MAX = 255
INT64_MIN = -9223372036854775808
INT32_MIN = -2147483648
INT16_MIN = -32768
INT8_MIN = -128
UINT8_MIN = 0


@@ -0,0 +1,72 @@
"""Utility for deprecating functions."""
import functools
import textwrap
import warnings
from typing import Callable, TypeVar
from typing_extensions import ParamSpec
_T = TypeVar("_T")
_P = ParamSpec("_P")
def deprecated(
since: str, removed_in: str, instructions: str
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
"""Marks functions as deprecated.
It will result in a warning when the function is called and a note in the
docstring.
Args:
since: The version when the function was first deprecated.
removed_in: The version when the function will be removed.
instructions: The action users should take.
"""
def decorator(function: Callable[_P, _T]) -> Callable[_P, _T]:
@functools.wraps(function)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
warnings.warn(
f"'{function.__module__}.{function.__name__}' "
f"is deprecated in version {since} and will be "
f"removed in {removed_in}. Please {instructions}.",
category=FutureWarning,
stacklevel=2,
)
return function(*args, **kwargs)
# Add a deprecation note to the docstring.
docstring = function.__doc__ or ""
deprecation_note = textwrap.dedent(
f"""\
.. deprecated:: {since}
Deprecated and will be removed in version {removed_in}.
Please {instructions}.
"""
)
# Split the docstring at the first blank line to separate summary from body
summary_and_body = docstring.split("\n\n", 1)
if len(summary_and_body) > 1:
summary, body = summary_and_body
# Dedent the body. We cannot dedent the whole docstring because the
# summary line usually has no leading whitespace while the body does.
body = textwrap.dedent(body)
new_docstring_parts = [deprecation_note, "\n\n", summary, body]
else:
summary = summary_and_body[0]
new_docstring_parts = [deprecation_note, "\n\n", summary]
wrapper.__doc__ = "".join(new_docstring_parts)
return wrapper
return decorator
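# A usage sketch for the decorator above (illustrative; `old_helper` and the
# version strings are hypothetical):
#
#   @deprecated(since="2.4", removed_in="2.6", instructions="use `new_helper` instead")
#   def old_helper(x):
#       """Return x unchanged.
#
#       Longer description body.
#       """
#       return x
#
#   old_helper(1)  # emits a FutureWarning; old_helper.__doc__ now begins with
#                  # a ".. deprecated:: 2.4" note followed by the original docstring.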


@@ -0,0 +1,27 @@
"""Experimental classes and functions used by ONNX export."""
import dataclasses
from typing import Mapping, Optional, Sequence, Set, Type, Union
import torch
import torch._C._onnx as _C_onnx
@dataclasses.dataclass
class ExportOptions:
"""Arguments used by :func:`torch.onnx.export`."""
# TODO(justinchuby): Deprecate and remove this class.
export_params: bool = True
verbose: bool = False
training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL
input_names: Optional[Sequence[str]] = None
output_names: Optional[Sequence[str]] = None
operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX
opset_version: Optional[int] = None
do_constant_folding: bool = True
dynamic_axes: Optional[Mapping[str, Union[Mapping[int, str], Sequence[int]]]] = None
keep_initializers_as_inputs: Optional[bool] = None
custom_opsets: Optional[Mapping[str, int]] = None
export_modules_as_functions: Union[bool, Set[Type[torch.nn.Module]]] = False
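# A usage sketch for this legacy dataclass (illustrative):
#
#   options = ExportOptions(opset_version=17, dynamic_axes={"x": {0: "batch"}})
#   assert options.do_constant_folding is True  # dataclass defaults still apply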


@@ -0,0 +1,12 @@
from __future__ import annotations
class ExportTypes:
"""Specifies how the ONNX model is stored."""
# TODO(justinchuby): Deprecate and remove this class.
PROTOBUF_FILE = "Saves model in the specified protobuf file."
ZIP_ARCHIVE = "Saves model in the specified ZIP file (uncompressed)."
COMPRESSED_ZIP_ARCHIVE = "Saves model in the specified ZIP file (compressed)."
DIRECTORY = "Saves model in the specified folder."


@@ -0,0 +1,49 @@
"""Internal feature flags for torch.onnx.
NOTE: These flags are experimental only. Any flag here can be removed at any
time without notice.
"""
import logging
import os
logger = logging.getLogger(__name__)
def _load_boolean_flag(
name: str,
*,
this_will: str,
deprecated: bool = False,
default: bool = False,
) -> bool:
"""Load a boolean flag from environment variable.
Args:
name: The name of the environment variable.
this_will: A string that describes what this flag will do.
deprecated: Whether this flag is deprecated.
default: The default value if the environment variable is not defined.
"""
undefined = os.getenv(name) is None
state = os.getenv(name) == "1"
if state:
if deprecated:
logger.error(
"Experimental flag %s is deprecated. Please remove it from your environment.",
name,
)
else:
logger.warning(
"Experimental flag %s is enabled. This will %s.", name, this_will
)
if undefined:
state = default
return state
USE_EXPERIMENTAL_LOGIC: bool = _load_boolean_flag(
"TORCH_ONNX_USE_EXPERIMENTAL_LOGIC",
this_will="use ExportedProgram and the new torch.onnx export logic",
)
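# Usage sketch (illustrative): flags are evaluated once at import time, so the
# environment variable must be set before `torch.onnx` is first imported, e.g.
#
#   TORCH_ONNX_USE_EXPERIMENTAL_LOGIC=1 python my_export_script.py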


@@ -0,0 +1,87 @@
# mypy: allow-untyped-defs
"""Globals used internally by the ONNX exporter.
Do not use this module outside of `torch.onnx` and its tests.
Be very judicious when adding any new global variables. Do not create new global
variables unless they are absolutely necessary.
"""
import torch._C._onnx as _C_onnx
# This module should only depend on _constants and nothing else in torch.onnx to keep
# dependency direction clean.
from torch.onnx import _constants
class _InternalGlobals:
"""Globals used internally by ONNX exporter.
NOTE: Be very judicious when adding any new variables. Do not create new
global variables unless they are absolutely necessary.
"""
def __init__(self) -> None:
self._export_onnx_opset_version = _constants.ONNX_DEFAULT_OPSET
self._training_mode: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL
self._in_onnx_export: bool = False
# Whether the user's model is training during export
self.export_training: bool = False
self.operator_export_type: _C_onnx.OperatorExportTypes = (
_C_onnx.OperatorExportTypes.ONNX
)
self.onnx_shape_inference: bool = True
self._autograd_inlining: bool = True
@property
def training_mode(self):
"""The training mode for the exporter."""
return self._training_mode
@training_mode.setter
def training_mode(self, training_mode: _C_onnx.TrainingMode):
if not isinstance(training_mode, _C_onnx.TrainingMode):
raise TypeError(
"training_mode must be of type 'torch.onnx.TrainingMode'. This is "
"likely a bug in torch.onnx."
)
self._training_mode = training_mode
@property
def export_onnx_opset_version(self) -> int:
"""Opset version used during export."""
return self._export_onnx_opset_version
@export_onnx_opset_version.setter
def export_onnx_opset_version(self, value: int):
supported_versions = range(
_constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1
)
if value not in supported_versions:
raise ValueError(f"Unsupported ONNX opset version: {value}")
self._export_onnx_opset_version = value
@property
def in_onnx_export(self) -> bool:
"""Whether it is in the middle of ONNX export."""
return self._in_onnx_export
@in_onnx_export.setter
def in_onnx_export(self, value: bool):
if type(value) is not bool:
raise TypeError("in_onnx_export must be a boolean")
self._in_onnx_export = value
@property
def autograd_inlining(self) -> bool:
"""Whether Autograd must be inlined."""
return self._autograd_inlining
@autograd_inlining.setter
def autograd_inlining(self, value: bool):
if type(value) is not bool:
raise TypeError("autograd_inlining must be a boolean")
self._autograd_inlining = value
GLOBALS = _InternalGlobals()
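# Usage sketch (illustrative): the property setters above validate assignments.
#
#   from torch.onnx._globals import GLOBALS
#
#   GLOBALS.export_onnx_opset_version = 18  # ok: within [ONNX_MIN_OPSET, ONNX_MAX_OPSET]
#   GLOBALS.export_onnx_opset_version = 3   # raises ValueError (below ONNX_MIN_OPSET)
#   GLOBALS.in_onnx_export = "yes"          # raises TypeError (must be a bool)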

File diff suppressed because it is too large.


@@ -0,0 +1,41 @@
"""Utility to lazily import modules."""
# mypy: allow-untyped-defs
from __future__ import annotations
import importlib
from typing import Any, TYPE_CHECKING
class _LazyModule:
"""Lazily import a module."""
def __init__(self, module_name: str) -> None:
self._name = module_name
self._module: Any = None
def __repr__(self) -> str:
return f"<lazy module '{self._name}'>"
def __getattr__(self, attr):
if self._module is None:
self._module = importlib.import_module(".", self._name)
return getattr(self._module, attr)
# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code.
# NOTE: Add additional used imports here.
if TYPE_CHECKING:
import onnx
import onnxscript
import onnxscript._framework_apis.torch_2_5 as onnxscript_apis
onnxscript_ir = onnxscript.ir
else:
onnx = _LazyModule("onnx")
onnxscript = _LazyModule("onnxscript")
onnxscript_ir = _LazyModule("onnxscript.ir")
onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_5")
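# Usage sketch (illustrative): nothing is imported until the first attribute
# access on the proxy.
#
#   proxy = _LazyModule("onnx")
#   repr(proxy)       # "<lazy module 'onnx'>" (module not imported yet)
#   proxy.ModelProto  # first attribute access imports `onnx`, then delegates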


@@ -0,0 +1,22 @@
from ._diagnostic import (
create_export_diagnostic_context,
diagnose,
engine,
export_context,
ExportDiagnosticEngine,
TorchScriptOnnxExportDiagnostic,
)
from ._rules import rules
from .infra import levels
__all__ = [
"TorchScriptOnnxExportDiagnostic",
"ExportDiagnosticEngine",
"rules",
"levels",
"engine",
"export_context",
"create_export_diagnostic_context",
"diagnose",
]


@@ -0,0 +1,211 @@
# mypy: allow-untyped-defs
"""Diagnostic components for TorchScript based ONNX export, i.e. `torch.onnx.export`."""
from __future__ import annotations
import contextlib
import gzip
from typing import TYPE_CHECKING
import torch
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
from torch.utils import cpp_backtrace
if TYPE_CHECKING:
from collections.abc import Generator
def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.Stack:
"""Returns the current C++ call stack.
This function utilizes `torch.utils.cpp_backtrace` to get the current C++ call stack.
The returned C++ call stack is a concatenated string of the C++ call stack frames.
Frames are separated by newline characters, each matching
r"frame #[0-9]+: (?P<frame_info>.*)". More info at `c10/util/Backtrace.cpp`.
"""
frames = cpp_backtrace.get_cpp_backtrace(frames_to_skip, frames_to_log).split("\n")
frame_messages = []
for frame in frames:
segments = frame.split(":", 1)
if len(segments) == 2:
frame_messages.append(segments[1].strip())
else:
frame_messages.append("<unknown frame>")
return infra.Stack(
frames=[
infra.StackFrame(location=infra.Location(message=message))
for message in frame_messages
]
)
class TorchScriptOnnxExportDiagnostic(infra.Diagnostic):
"""Base class for all export diagnostics.
This class is used to represent all export diagnostics. It is a subclass of
infra.Diagnostic, and adds additional methods to add more information to the
diagnostic.
"""
python_call_stack: infra.Stack | None = None
cpp_call_stack: infra.Stack | None = None
def __init__(
self,
*args,
frames_to_skip: int = 1,
cpp_stack: bool = False,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.python_call_stack = self.record_python_call_stack(
frames_to_skip=frames_to_skip
)
if cpp_stack:
self.cpp_call_stack = self.record_cpp_call_stack(
frames_to_skip=frames_to_skip
)
def record_cpp_call_stack(self, frames_to_skip: int) -> infra.Stack:
"""Records the current C++ call stack in the diagnostic."""
stack = _cpp_call_stack(frames_to_skip=frames_to_skip)
stack.message = "C++ call stack"
self.with_stack(stack)
return stack
class ExportDiagnosticEngine:
"""PyTorch ONNX Export diagnostic engine.
The only purpose of creating this class instead of using `DiagnosticContext` directly
is to provide a background context for `diagnose` calls inside exporter.
By design, one `torch.onnx.export` call should initialize one diagnostic context.
All `diagnose` calls inside exporter should be made in the context of that export.
However, since diagnostic context is currently being accessed via a global variable,
there is no guarantee that the context is properly initialized. Therefore, we need
to provide a default background context to fallback to, otherwise any invocation of
exporter internals, e.g. unit tests, will fail due to missing diagnostic context.
This can be removed once the pipeline for context to flow through the exporter is
established.
"""
contexts: list[infra.DiagnosticContext]
_background_context: infra.DiagnosticContext
def __init__(self) -> None:
self.contexts = []
self._background_context = infra.DiagnosticContext(
name="torch.onnx",
version=torch.__version__,
)
@property
def background_context(self) -> infra.DiagnosticContext:
return self._background_context
def create_diagnostic_context(
self,
name: str,
version: str,
options: infra.DiagnosticOptions | None = None,
) -> infra.DiagnosticContext:
"""Creates a new diagnostic context.
Args:
name: The subject name for the diagnostic context.
version: The subject version for the diagnostic context.
options: The options for the diagnostic context.
Returns:
A new diagnostic context.
"""
if options is None:
options = infra.DiagnosticOptions()
context: infra.DiagnosticContext[infra.Diagnostic] = infra.DiagnosticContext(
name, version, options
)
self.contexts.append(context)
return context
def clear(self):
"""Clears all diagnostic contexts."""
self.contexts.clear()
self._background_context.diagnostics.clear()
def to_json(self) -> str:
return formatter.sarif_to_json(self.sarif_log())
def dump(self, file_path: str, compress: bool = False) -> None:
"""Dumps the SARIF log to a file."""
if compress:
with gzip.open(file_path, "wt") as f:
f.write(self.to_json())
else:
with open(file_path, "w") as f:
f.write(self.to_json())
def sarif_log(self):
log = sarif.SarifLog(
version=sarif_version.SARIF_VERSION,
schema_uri=sarif_version.SARIF_SCHEMA_LINK,
runs=[context.sarif() for context in self.contexts],
)
log.runs.append(self._background_context.sarif())
return log
engine = ExportDiagnosticEngine()
_context = engine.background_context
@contextlib.contextmanager
def create_export_diagnostic_context() -> (
Generator[infra.DiagnosticContext, None, None]
):
"""Create a diagnostic context for export.
This is a workaround for code robustness since diagnostic context is accessed by
export internals via global variable. See `ExportDiagnosticEngine` for more details.
"""
global _context
assert (
_context == engine.background_context
), "Export context is already set. Nested export is not supported."
_context = engine.create_diagnostic_context(
"torch.onnx.export",
torch.__version__,
)
try:
yield _context
finally:
_context = engine.background_context
def diagnose(
rule: infra.Rule,
level: infra.Level,
message: str | None = None,
frames_to_skip: int = 2,
**kwargs,
) -> TorchScriptOnnxExportDiagnostic:
"""Creates a diagnostic and record it in the global diagnostic context.
This is a wrapper around `context.log` that uses the global diagnostic
context.
"""
diagnostic = TorchScriptOnnxExportDiagnostic(
rule, level, message, frames_to_skip=frames_to_skip, **kwargs
)
export_context().log(diagnostic)
return diagnostic
def export_context() -> infra.DiagnosticContext:
global _context
return _context
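# Usage sketch (illustrative): emitting a diagnostic inside an export context.
# The rule and message below are examples, not a prescribed call site.
#
#   from torch.onnx._internal.diagnostics import rules
#
#   with create_export_diagnostic_context():
#       diagnose(
#           rules.node_missing_onnx_shape_inference,
#           infra.Level.WARNING,
#           message="shape inference missing for prim::CustomOp",
#       )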


@@ -0,0 +1,636 @@
# mypy: allow-untyped-defs
"""
GENERATED CODE - DO NOT EDIT DIRECTLY
This file is generated by gen_diagnostics.py.
See tools/onnx/gen_diagnostics.py for more information.
Diagnostic rules for PyTorch ONNX export.
"""
import dataclasses
from typing import Tuple
# flake8: noqa
from torch.onnx._internal.diagnostics import infra
"""
GENERATED CODE - DO NOT EDIT DIRECTLY
The purpose of generating a class for each rule is to override the `format_message`
method to provide more details in the signature about the format arguments.
"""
class _NodeMissingOnnxShapeInference(infra.Rule):
"""Node is missing ONNX shape inference."""
def format_message(self, op_name) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'The shape inference of {op_name} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.'
"""
return self.message_default_template.format(op_name=op_name)
def format( # type: ignore[override]
self, level: infra.Level, op_name
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'The shape inference of {op_name} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.'
"""
return self, level, self.format_message(op_name=op_name)
class _MissingCustomSymbolicFunction(infra.Rule):
"""Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX."""
def format_message(self, op_name) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'ONNX export failed on an operator with unrecognized namespace {op_name}. If you are trying to export a custom operator, make sure you registered it with the right domain and version.'
"""
return self.message_default_template.format(op_name=op_name)
def format( # type: ignore[override]
self, level: infra.Level, op_name
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'ONNX export failed on an operator with unrecognized namespace {op_name}. If you are trying to export a custom operator, make sure you registered it with the right domain and version.'
"""
return self, level, self.format_message(op_name=op_name)
class _MissingStandardSymbolicFunction(infra.Rule):
"""Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX."""
def format_message( # type: ignore[override]
self, op_name, opset_version, issue_url
) -> str:
"""Returns the formatted default message of this Rule.
Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: {issue_url}."
"""
return self.message_default_template.format(
op_name=op_name, opset_version=opset_version, issue_url=issue_url
)
def format( # type: ignore[override]
self, level: infra.Level, op_name, opset_version, issue_url
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: {issue_url}."
"""
return (
self,
level,
self.format_message(
op_name=op_name, opset_version=opset_version, issue_url=issue_url
),
)
class _OperatorSupportedInNewerOpsetVersion(infra.Rule):
"""Operator is supported in newer opset version."""
def format_message( # type: ignore[override]
self, op_name, opset_version, supported_opset_version
) -> str:
"""Returns the formatted default message of this Rule.
Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Support for this operator was added in version {supported_opset_version}, try exporting with this version."
"""
return self.message_default_template.format(
op_name=op_name,
opset_version=opset_version,
supported_opset_version=supported_opset_version,
)
def format( # type: ignore[override]
self, level: infra.Level, op_name, opset_version, supported_opset_version
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Support for this operator was added in version {supported_opset_version}, try exporting with this version."
"""
return (
self,
level,
self.format_message(
op_name=op_name,
opset_version=opset_version,
supported_opset_version=supported_opset_version,
),
)
class _FxGraphToOnnx(infra.Rule):
"""Transforms graph from FX IR to ONNX IR."""
def format_message(self, graph_name) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'Transforming FX graph {graph_name} to ONNX graph.'
"""
return self.message_default_template.format(graph_name=graph_name)
def format( # type: ignore[override]
self, level: infra.Level, graph_name
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'Transforming FX graph {graph_name} to ONNX graph.'
"""
return self, level, self.format_message(graph_name=graph_name)
class _FxNodeToOnnx(infra.Rule):
"""Transforms an FX node to an ONNX node."""
def format_message(self, node_repr) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'Transforming FX node {node_repr} to ONNX node.'
"""
return self.message_default_template.format(node_repr=node_repr)
def format( # type: ignore[override]
self, level: infra.Level, node_repr
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'Transforming FX node {node_repr} to ONNX node.'
"""
return self, level, self.format_message(node_repr=node_repr)
class _FxPass(infra.Rule):
"""FX graph transformation during ONNX export before converting from FX IR to ONNX IR."""
def format_message(self, pass_name) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'Running {pass_name} pass.'
"""
return self.message_default_template.format(pass_name=pass_name)
def format( # type: ignore[override]
self, level: infra.Level, pass_name
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'Running {pass_name} pass.'
"""
return self, level, self.format_message(pass_name=pass_name)
class _NoSymbolicFunctionForCallFunction(infra.Rule):
"""Cannot find symbolic function to convert the "call_function" FX node to ONNX."""
def format_message(self, target) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'No symbolic function to convert the "call_function" node {target} to ONNX. '
"""
return self.message_default_template.format(target=target)
def format( # type: ignore[override]
self, level: infra.Level, target
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'No symbolic function to convert the "call_function" node {target} to ONNX. '
"""
return self, level, self.format_message(target=target)
class _UnsupportedFxNodeAnalysis(infra.Rule):
"""Result from FX graph analysis to reveal unsupported FX nodes."""
def format_message( # type: ignore[override]
self, node_op_to_target_mapping
) -> str:
"""Returns the formatted default message of this Rule.
Message template: 'Unsupported FX nodes: {node_op_to_target_mapping}. '
"""
return self.message_default_template.format(
node_op_to_target_mapping=node_op_to_target_mapping
)
def format( # type: ignore[override]
self, level: infra.Level, node_op_to_target_mapping
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'Unsupported FX nodes: {node_op_to_target_mapping}. '
"""
return (
self,
level,
self.format_message(node_op_to_target_mapping=node_op_to_target_mapping),
)
class _OpLevelDebugging(infra.Rule):
"""Report any op level validation failure in warnings."""
def format_message(self, node, symbolic_fn) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'FX node: {node} and its onnx function: {symbolic_fn} fails on op level validation.'
"""
return self.message_default_template.format(node=node, symbolic_fn=symbolic_fn)
def format( # type: ignore[override]
self, level: infra.Level, node, symbolic_fn
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'FX node: {node} and its onnx function: {symbolic_fn} fails on op level validation.'
"""
return self, level, self.format_message(node=node, symbolic_fn=symbolic_fn)
class _FindOpschemaMatchedSymbolicFunction(infra.Rule):
"""Find the OnnxFunction that matches the input/attribute dtypes by comparing them with their opschemas."""
def format_message(self, symbolic_fn, node) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'The OnnxFunction: {symbolic_fn} is the nearest match of the node {node}.'
"""
return self.message_default_template.format(symbolic_fn=symbolic_fn, node=node)
def format( # type: ignore[override]
self, level: infra.Level, symbolic_fn, node
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'The OnnxFunction: {symbolic_fn} is the nearest match of the node {node}.'
"""
return self, level, self.format_message(symbolic_fn=symbolic_fn, node=node)
class _FxNodeInsertTypePromotion(infra.Rule):
"""Determine if type promotion is required for the FX node. Insert cast nodes if needed."""
def format_message(self, target) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'Performing explicit type promotion for node {target}. '
"""
return self.message_default_template.format(target=target)
def format( # type: ignore[override]
self, level: infra.Level, target
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'Performing explicit type promotion for node {target}. '
"""
return self, level, self.format_message(target=target)
class _FindOperatorOverloadsInOnnxRegistry(infra.Rule):
"""Find the list of OnnxFunction of the PyTorch operator in onnx registry."""
def format_message(self, node) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'Checking if the FX node: {node} is supported in onnx registry.'
"""
return self.message_default_template.format(node=node)
def format( # type: ignore[override]
self, level: infra.Level, node
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'Checking if the FX node: {node} is supported in onnx registry.'
"""
return self, level, self.format_message(node=node)
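# Usage sketch (illustrative): each field of the collection below pairs a SARIF
# rule definition with the typed `format` helpers generated above, e.g.:
#
#   rule, level, message = rules.node_missing_onnx_shape_inference.format(
#       infra.Level.WARNING, op_name="prim::CustomOp"
#   )
#   # `message` is the rule's default template with {op_name} substituted.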
@dataclasses.dataclass
class _POERules(infra.RuleCollection):
node_missing_onnx_shape_inference: _NodeMissingOnnxShapeInference = dataclasses.field(
default=_NodeMissingOnnxShapeInference.from_sarif(
**{
"id": "POE0001",
"name": "node-missing-onnx-shape-inference",
"short_description": {"text": "Node is missing ONNX shape inference."},
"full_description": {
"text": "Node is missing ONNX shape inference. This usually happens when the node is not valid under standard ONNX operator spec.",
"markdown": "Node is missing ONNX shape inference.\nThis usually happens when the node is not valid under standard ONNX operator spec.\n",
},
"message_strings": {
"default": {
"text": "The shape inference of {op_name} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Node is missing ONNX shape inference."""
missing_custom_symbolic_function: _MissingCustomSymbolicFunction = dataclasses.field(
default=_MissingCustomSymbolicFunction.from_sarif(
**{
"id": "POE0002",
"name": "missing-custom-symbolic-function",
"short_description": {
"text": "Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX."
},
"full_description": {
"text": "Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX.",
"markdown": "Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX.\n",
},
"message_strings": {
"default": {
"text": "ONNX export failed on an operator with unrecognized namespace {op_name}. If you are trying to export a custom operator, make sure you registered it with the right domain and version."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX."""
missing_standard_symbolic_function: _MissingStandardSymbolicFunction = dataclasses.field(
default=_MissingStandardSymbolicFunction.from_sarif(
**{
"id": "POE0003",
"name": "missing-standard-symbolic-function",
"short_description": {
"text": "Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX."
},
"full_description": {
"text": "Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX.",
"markdown": "Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX.\n",
},
"message_strings": {
"default": {
"text": "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: {issue_url}."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX."""
operator_supported_in_newer_opset_version: _OperatorSupportedInNewerOpsetVersion = dataclasses.field(
default=_OperatorSupportedInNewerOpsetVersion.from_sarif(
**{
"id": "POE0004",
"name": "operator-supported-in-newer-opset-version",
"short_description": {
"text": "Operator is supported in newer opset version."
},
"full_description": {
"text": "Operator is supported in newer opset version.",
"markdown": "Operator is supported in newer opset version.\n\nExample:\n```python\ntorch.onnx.export(model, args, ..., opset_version=9)\n```\n",
},
"message_strings": {
"default": {
"text": "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Support for this operator was added in version {supported_opset_version}, try exporting with this version."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Operator is supported in newer opset version."""
fx_graph_to_onnx: _FxGraphToOnnx = dataclasses.field(
default=_FxGraphToOnnx.from_sarif(
**{
"id": "FXE0007",
"name": "fx-graph-to-onnx",
"short_description": {
"text": "Transforms graph from FX IR to ONNX IR."
},
"full_description": {
"text": "Transforms graph from FX IR to ONNX IR.",
"markdown": "This diagnostic tracks the transformation process from an FX Graph (in FX IR) to an ONNX Graph (in ONNX IR).\n\n## Key Representations:\n\n- **FX Graph**: The graph in FX IR produced by dynamo or symbolic tracing.\n- **ONNX Graph**: The graph in ONNX IR and [operators](https://onnx.ai/onnx/operators/).\n\n## Additional Notes:\n\n- Prior to this transformation step, the FX graph undergoes preprocessing through multiple FX passes.\n To gain insight into these transformations, refer to diagnostic `FXE0010`.\n- To enable a detailed view of the graph transformation in progress within this diagnostic, switch to the DEBUG mode.\n\n - Set DiagnosticOptions.verbosity_level to logging.DEBUG.\n - Activate the environment variable TORCH_LOGS='onnx_diagnostics'.\n\n- For specific information related to node-level FX to ONNX transformations, explore the diagnostic `FXE0008`.\n",
},
"message_strings": {
"default": {
"text": "Transforming FX graph {graph_name} to ONNX graph."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Transforms graph from FX IR to ONNX IR."""
fx_node_to_onnx: _FxNodeToOnnx = dataclasses.field(
default=_FxNodeToOnnx.from_sarif(
**{
"id": "FXE0008",
"name": "fx-node-to-onnx",
"short_description": {"text": "Transforms an FX node to an ONNX node."},
"full_description": {
"text": "Transforms an FX node to an ONNX node.",
"markdown": "This diagnostic tracks the transformation process from an FX Node to ONNX [Operators](https://onnx.ai/onnx/operators/).\n\nThe process of converting FX Node to ONNX Node involves dealing with six distinct node types:\n 1. `placeholder`: Represents a module input, maps to an ONNX graph input.\n 2. `call_module`: Symbolizes a call to a submodule, maps to an ONNX\n 3. `call_method`: Symbolizes a method call. Not yet implemented.\n 4. `call_function`: Symbolizes a function call. [Core ATen](https://pytorch.org/docs/stable/ir.html#core-aten-ir) is expected\n as the function call target. The mapping from ATen to ONNX is implemented by [ONNXScript torchlib](https://github.com/microsoft/onnxscript/tree/main/onnxscript/function_libs/torch_lib/ops).\n This [guide](https://pytorch.org/docs/stable/onnx.html#onnx-script-functions) shows how to write and register a custom symbolic function for call_function FX node.\n 5. `get_attr`: Indicates an attribute access within the current module. Maps to an ONNX graph initializer.\n 6. `output`: Represents the module's output. Maps to an ONNX graph output.\n\nFor a granular understanding of how each node type is transformed, refer to the implementation details in `FxOnnxInterpreter`.\n",
},
"message_strings": {
"default": {
"text": "Transforming FX node {node_repr} to ONNX node."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Transforms an FX node to an ONNX node."""
fx_pass: _FxPass = dataclasses.field(
default=_FxPass.from_sarif(
**{
"id": "FXE0010",
"name": "fx-pass",
"short_description": {
"text": "FX graph transformation during ONNX export before converting from FX IR to ONNX IR."
},
"full_description": {
"text": "FX graph transformation during ONNX export before converting from FX IR to ONNX IR.",
"markdown": "This diagnostic tracks the FX passes executed during the ONNX export process prior\nto converting from FX IR (Intermediate Representation) to ONNX IR.\n\nUnder the scope of ONNX export, an FX pass refers to a specific transformation applied to the FX GraphModule.\nThe primary aim of these passes is to streamline the graph into a format that aligns more with the ONNX IR.\nMoreover, these passes work to substitute unsupported FX IR features with those recognized and endorsed by\nONNX IR. Common transformations include, but aren't limited to, decomposition, functionalization and\ntype promotion.\n\nFor those who are interested in a comprehensive log detailing the modifications made during these passes,\nthere are a couple of options:\n\n- Set DiagnosticOptions.verbosity_level to logging.DEBUG.\n- Activate the environment variable TORCH_LOGS='onnx_diagnostics'.\n\nHowever, it's noteworthy that by default, such detailed logging is turned off. The primary reason being\nits considerable impact on performance.\n\nFor an in-depth understanding of each specific pass, please refer to the directory: torch/onnx/_internal/fx/passes.\n",
},
"message_strings": {"default": {"text": "Running {pass_name} pass."}},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""FX graph transformation during ONNX export before converting from FX IR to ONNX IR."""
no_symbolic_function_for_call_function: _NoSymbolicFunctionForCallFunction = dataclasses.field(
default=_NoSymbolicFunctionForCallFunction.from_sarif(
**{
"id": "FXE0011",
"name": "no-symbolic-function-for-call-function",
"short_description": {
"text": 'Cannot find symbolic function to convert the "call_function" FX node to ONNX.'
},
"full_description": {
"text": 'Cannot find symbolic function to convert the "call_function" FX node to ONNX. ',
"markdown": 'This error occurs when the ONNX converter is unable to find a corresponding symbolic function\nto convert a "call_function" node in the input graph to its equivalence in ONNX. The "call_function"\nnode represents a normalized function call in PyTorch, such as "torch.aten.ops.add".\n\nTo resolve this error, you can try one of the following:\n\n- If exists, apply the auto-fix suggested by the diagnostic. TODO: this part is not available yet.\n- Rewrite the model using only supported PyTorch operators or functions.\n- Follow this [guide](https://pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html#overview) to write and\n register a custom symbolic function for the unsupported call_function FX node.\n',
},
"message_strings": {
"default": {
"text": 'No symbolic function to convert the "call_function" node {target} to ONNX. '
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Cannot find symbolic function to convert the "call_function" FX node to ONNX."""
unsupported_fx_node_analysis: _UnsupportedFxNodeAnalysis = dataclasses.field(
default=_UnsupportedFxNodeAnalysis.from_sarif(
**{
"id": "FXE0012",
"name": "unsupported-fx-node-analysis",
"short_description": {
"text": "Result from FX graph analysis to reveal unsupported FX nodes."
},
"full_description": {
"text": "Result from FX graph analysis to reveal unsupported FX nodes.",
"markdown": "This error indicates that an FX graph contains one or more unsupported nodes. The error message\nis typically accompanied by a list of the unsupported nodes found during analysis.\n\nTo resolve this error, you can try resolving each individual unsupported node error by following\nthe suggestions by its diagnostic. Typically, options include:\n\n- If exists, apply the auto-fix suggested by the diagnostic. TODO: this part is not available yet.\n- Rewrite the model using only supported PyTorch operators or functions.\n- Follow this [guide](https://pytorch.org/docs/stable/onnx.html#onnx-script-functions) to write and\n register a custom symbolic function for the unsupported call_function FX node.\n",
},
"message_strings": {
"default": {
"text": "Unsupported FX nodes: {node_op_to_target_mapping}. "
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Result from FX graph analysis to reveal unsupported FX nodes."""
op_level_debugging: _OpLevelDebugging = dataclasses.field(
default=_OpLevelDebugging.from_sarif(
**{
"id": "FXE0013",
"name": "op-level-debugging",
"short_description": {
"text": "Report any op level validation failure in warnings."
},
"full_description": {
"text": "Report any op level validation failure in warnings.",
"markdown": "This warning message indicates that during op level debugging, certain symbolic functions\nhave failed to match the results of torch ops when using real tensors generated from fake\ntensors. It is important to note that the symbolic functions may not necessarily be\nincorrect, as the validation process is non-deterministic and should only be used as a\nreference.\n\nThere are two categories of warnings that can be triggered:\n\n1. Non-validated operators:\n If the warnings are caused by the following errors, they can be disregarded by users,\n as these errors occur due to the non-deterministic nature of the validation. However,\n it is important to be aware that the operators have not been validated.\n\n - IndexError: Unsupported input arguments of randomized dimensions/indices(INT64).\n - RuntimeError: Unsupported input arguments for torch ops are generated.\n - ValueError: Arguments/keyword arguments do not match the signature of the symbolic function.\n\n2. Potentially wrong torchlib operators:\n If the warnings are triggered by the following error, users should be aware that the symbolic functions\n may be incorrect in dispatching or implementation. In such cases, it is recommended to report\n the issue to the PyTorch-ONNX team, or create/register a custom symbolic function to replace the default one.\n\n - AssertionError: The symbolic function is potentially wrong as the results do not match the results of torch ops.\n - TypeError: The symbolic function is potentially wrong as the opschema doesn't match inputs.\n",
},
"message_strings": {
"default": {
"text": "FX node: {node} and its onnx function: {symbolic_fn} fails on op level validation."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Report any op level validation failure in warnings."""
find_opschema_matched_symbolic_function: _FindOpschemaMatchedSymbolicFunction = dataclasses.field(
default=_FindOpschemaMatchedSymbolicFunction.from_sarif(
**{
"id": "FXE0014",
"name": "find-opschema-matched-symbolic-function",
"short_description": {
"text": "Find the OnnxFunction that matches the input/attribute dtypes by comparing them with their opschemas."
},
"full_description": {
"text": "Find the OnnxFunction that matches the input dtypes by comparing them with their opschemas. A warning will be issued if the matched OnnxFunction is not an exact match.",
"markdown": "When an ATen/Custom operator is registered and needs to be dispatched to an OnnxFunction, the input/attribute\ndtypes of the ATen/Custom operator are compared with the input/attribute dtypes of the OnnxFunction opschemas\nto find a match. However, if a perfect/exact match is not found, the dispatcher will attempt to find\nthe nearest match with the highest number of input/attribute dtypes matching the OnnxFunction opschemas, while\nissuing a warning.\n\nThere are two types of level that can be triggered in this rule:\n\n1. NOTE: A perfect match is found, and no warning is issued.\n2. WARNING: The matched OnnxFunction is not a perfect/exact match.\n\nHere are some suggestions based on the WARNING situation:\n\n1. If there are NO errors or mismatches in the results, it is safe to disregard this warning,\n as the definition of OnnxFunction schema is usually more stringent.\n2. If there are errors or mismatches in the results, it is recommended to:\n (a) Enable op_level_debugging to determine if the OnnxFunction might be incorrect.\n (b) Report the issue to the PyTorch-ONNX team.\n (c) Create/register a custom symbolic function to replace the default one.\n",
},
"message_strings": {
"default": {
"text": "The OnnxFunction: {symbolic_fn} is the nearest match of the node {node}."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Find the OnnxFunction that matches the input/attribute dtypes by comparing them with their opschemas."""
fx_node_insert_type_promotion: _FxNodeInsertTypePromotion = dataclasses.field(
default=_FxNodeInsertTypePromotion.from_sarif(
**{
"id": "FXE0015",
"name": "fx-node-insert-type-promotion",
"short_description": {
"text": "Determine if type promotion is required for the FX node. Insert cast nodes if needed."
},
"full_description": {
"text": "Determine if type promotion is required for the FX node. Insert cast nodes if needed.",
"markdown": "This diagnostic monitors the node-level type promotion insertion process. In PyTorch, there is an automatic process called implicit type promotion,\nwhere the input types of an operator are promoted to a common type. The determination of the common type is based on the type promotion rule specific to each operator.\nTo learn more about PyTorch's type promotion rules, refer to the [elementwise_dtypes doc](https://github.com/pytorch/pytorch/blob/f044613f78df713fb57f70c608483c9f10ad332e/torch/_prims_common/__init__.py#L1252-L1335)\nand [torch._refs ops](https://github.com/pytorch/pytorch/blob/a475ea4542dfe961c9d097e33ab5041f61c8c17f/torch/_refs/__init__.py#L484).\n\nHowever, implicit type promotion is not supported in ONNX. Therefore, to replicate the PyTorch behavior, we need to explicitly insert cast nodes.\nThis diagnostic tracks the process of node-level type promotion insertion.\n\nThe type promotion rules used by this process can be found in `torch/onnx/_internal/fx/passes/type_promotion.py.`\nTo update or add new type promotion rules, please refer to the [Note: Update type promotion rule] section.\n",
},
"message_strings": {
"default": {
"text": "Performing explicit type promotion for node {target}. "
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Determine if type promotion is required for the FX node. Insert cast nodes if needed."""
find_operator_overloads_in_onnx_registry: _FindOperatorOverloadsInOnnxRegistry = dataclasses.field(
default=_FindOperatorOverloadsInOnnxRegistry.from_sarif(
**{
"id": "FXE0016",
"name": "find-operator-overloads-in-onnx-registry",
"short_description": {
"text": "Find the list of OnnxFunction of the PyTorch operator in onnx registry."
},
"full_description": {
"text": "This rule involves finding the list of OnnxFunction for the PyTorch operator overload in the ONNX registry. If the operator overload is not supported but its default overload is, a warning will be issued. If both the operator overload and its default overload are not supported, an error will be issued.",
"markdown": "The operator overload name serves the purpose of verifying whether a PyTorch operator is registered in the ONNX registry.\nIf it's not found, the dispatcher takes a fallback approach and tries to locate the default overload of the PyTorch\noperator in the registry. If even the default overload is absent, it signifies that the operator is officially unsupported.\n\nThere are three types of level that can be triggered in this rule:\n\n1. NOTE: The op overload is supported.\n2. WARNING: The op overload is not supported, but it's default overload is supported.\n3. ERROR: The op overload is not supported, and it's default overload is also not supported.\n\nHere are some suggestions based on the WARNING situation:\n\n1. If there are NO errors or mismatches in the results, it is safe to disregard this warning.\n2. If there are errors or mismatches in the results, it is recommended to:\n (a) Enable op_level_debugging to determine if the OnnxFunction might be incorrect.\n (b) Report the unsupported overload to the PyTorch-ONNX team.\n (c) Create/register a custom symbolic function to replace the default one.\n\nHere are some suggestions based on the ERROR situation:\n\n1. Report the unsupported operator to the PyTorch-ONNX team.\n2. Create/register a custom symbolic function to replace the default one.\n",
},
"message_strings": {
"default": {
"text": "Checking if the FX node: {node} is supported in onnx registry."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Find the list of OnnxFunction of the PyTorch operator in onnx registry."""
rules = _POERules()

View File

@ -0,0 +1,34 @@
from ._infra import (
DiagnosticOptions,
Graph,
Invocation,
Level,
levels,
Location,
Rule,
RuleCollection,
Stack,
StackFrame,
Tag,
ThreadFlowLocation,
)
from .context import Diagnostic, DiagnosticContext, RuntimeErrorWithDiagnostic
__all__ = [
"Diagnostic",
"DiagnosticContext",
"DiagnosticOptions",
"Graph",
"Invocation",
"Level",
"levels",
"Location",
"Rule",
"RuleCollection",
"RuntimeErrorWithDiagnostic",
"Stack",
"StackFrame",
"Tag",
"ThreadFlowLocation",
]

View File

@ -0,0 +1,285 @@
# mypy: allow-untyped-defs
"""This file defines an additional layer of abstraction on top of the SARIF OM."""
from __future__ import annotations
import dataclasses
import enum
import logging
from typing import Mapping, Sequence
from torch.onnx._internal.diagnostics.infra import formatter, sarif
class Level(enum.IntEnum):
"""The level of a diagnostic.
This class is used to represent the level of a diagnostic. The levels are defined
by the SARIF specification, and are not modifiable. For alternative categories,
please use infra.Tag instead. When selecting a level, please consider the following
guidelines:
- NONE: Informational result that does not indicate the presence of a problem.
- NOTE: An opportunity for improvement was found.
- WARNING: A potential problem was found.
- ERROR: A serious problem was found.
    This class is a subclass of enum.IntEnum, and can be used as an integer. Its integer
value maps to the logging levels in Python's logging module. The mapping is as
follows:
Level.NONE = logging.DEBUG = 10
Level.NOTE = logging.INFO = 20
Level.WARNING = logging.WARNING = 30
Level.ERROR = logging.ERROR = 40
"""
NONE = 10
NOTE = 20
WARNING = 30
ERROR = 40
levels = Level
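# Illustrative usage (not part of the original file): because Level is an
# IntEnum aligned with the stdlib logging levels, its members compare equal to
# the corresponding logging constants and can be passed anywhere an int log
# level is expected, e.g.:
#
#   import logging
#   assert levels.WARNING == logging.WARNING == 30
#   logging.getLogger(__name__).log(levels.ERROR, "a serious problem was found")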
class Tag(enum.Enum):
"""The tag of a diagnostic. This class can be inherited to define custom tags."""
class PatchedPropertyBag(sarif.PropertyBag):
"""Key/value pairs that provide additional information about the object.
The definition of PropertyBag via SARIF spec is "A property bag is an object (section 3.6)
    containing an unordered set of properties with arbitrary names." However, this is not
    reflected in the JSON file, and is therefore not captured by the Python representation.
This patch adds additional **kwargs to the `__init__` method to allow recording
arbitrary key/value pairs.
"""
def __init__(self, tags: list[str] | None = None, **kwargs):
super().__init__(tags=tags)
self.__dict__.update(kwargs)
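# Illustrative usage (not part of the original file): extra keyword arguments
# are stored directly on the instance, so arbitrary properties can be recorded
# alongside the spec-defined `tags`, e.g.:
#
#   bag = PatchedPropertyBag(tags=["experimental"], name="my_pass")
#   assert bag.name == "my_pass"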
@dataclasses.dataclass(frozen=True)
class Rule:
id: str
name: str
message_default_template: str
short_description: str | None = None
full_description: str | None = None
full_description_markdown: str | None = None
help_uri: str | None = None
@classmethod
def from_sarif(cls, **kwargs):
"""Returns a rule from the SARIF reporting descriptor."""
short_description = kwargs.get("short_description", {}).get("text")
full_description = kwargs.get("full_description", {}).get("text")
full_description_markdown = kwargs.get("full_description", {}).get("markdown")
help_uri = kwargs.get("help_uri")
rule = cls(
id=kwargs["id"],
name=kwargs["name"],
message_default_template=kwargs["message_strings"]["default"]["text"],
short_description=short_description,
full_description=full_description,
full_description_markdown=full_description_markdown,
help_uri=help_uri,
)
return rule
def sarif(self) -> sarif.ReportingDescriptor:
"""Returns a SARIF reporting descriptor of this Rule."""
short_description = (
sarif.MultiformatMessageString(text=self.short_description)
if self.short_description is not None
else None
)
full_description = (
sarif.MultiformatMessageString(
text=self.full_description, markdown=self.full_description_markdown
)
if self.full_description is not None
else None
)
return sarif.ReportingDescriptor(
id=self.id,
name=self.name,
short_description=short_description,
full_description=full_description,
help_uri=self.help_uri,
)
def format(self, level: Level, *args, **kwargs) -> tuple[Rule, Level, str]:
"""Returns a tuple of (rule, level, message) for a diagnostic.
This method is used to format the message of a diagnostic. The message is
formatted using the default template of this rule, and the arguments passed in
as `*args` and `**kwargs`. The level is used to override the default level of
this rule.
"""
return (self, level, self.format_message(*args, **kwargs))
def format_message(self, *args, **kwargs) -> str:
"""Returns the formatted default message of this Rule.
This method should be overridden (with code generation) by subclasses to reflect
the exact arguments needed by the message template. This is a helper method to
create the default message for a diagnostic.
"""
return self.message_default_template.format(*args, **kwargs)
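# Illustrative usage (not part of the original file; the rule id/name below are
# made up). `format` produces the (rule, level, message) triple consumed by a
# Diagnostic, with the message rendered from the default template:
#
#   rule = Rule(
#       id="FXE9999",
#       name="example-rule",
#       message_default_template="Node {} failed.",
#   )
#   _, level, message = rule.format(Level.WARNING, "aten::add")
#   assert message == "Node aten::add failed."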
@dataclasses.dataclass
class Location:
uri: str | None = None
line: int | None = None
message: str | None = None
start_column: int | None = None
end_column: int | None = None
snippet: str | None = None
function: str | None = None
def sarif(self) -> sarif.Location:
"""Returns the SARIF representation of this location."""
return sarif.Location(
physical_location=sarif.PhysicalLocation(
artifact_location=sarif.ArtifactLocation(uri=self.uri),
region=sarif.Region(
start_line=self.line,
start_column=self.start_column,
end_column=self.end_column,
snippet=sarif.ArtifactContent(text=self.snippet),
),
),
message=sarif.Message(text=self.message)
if self.message is not None
else None,
)
@dataclasses.dataclass
class StackFrame:
location: Location
def sarif(self) -> sarif.StackFrame:
"""Returns the SARIF representation of this stack frame."""
return sarif.StackFrame(location=self.location.sarif())
@dataclasses.dataclass
class Stack:
"""Records a stack trace. The frames are in order from newest to oldest stack frame."""
frames: list[StackFrame] = dataclasses.field(default_factory=list)
message: str | None = None
def sarif(self) -> sarif.Stack:
"""Returns the SARIF representation of this stack."""
return sarif.Stack(
frames=[frame.sarif() for frame in self.frames],
message=sarif.Message(text=self.message)
if self.message is not None
else None,
)
@dataclasses.dataclass
class ThreadFlowLocation:
"""Records code location and the initial state."""
location: Location
state: Mapping[str, str]
index: int
stack: Stack | None = None
def sarif(self) -> sarif.ThreadFlowLocation:
"""Returns the SARIF representation of this thread flow location."""
return sarif.ThreadFlowLocation(
location=self.location.sarif(),
state=self.state,
stack=self.stack.sarif() if self.stack is not None else None,
)
@dataclasses.dataclass
class Graph:
"""A graph of diagnostics.
    This class stores the string representation of a model graph.
    The `nodes` and `edges` fields of the underlying SARIF graph are unused in the
    current implementation.
"""
graph: str
name: str
description: str | None = None
def sarif(self) -> sarif.Graph:
"""Returns the SARIF representation of this graph."""
return sarif.Graph(
description=sarif.Message(text=self.graph),
properties=PatchedPropertyBag(name=self.name, description=self.description),
)
@dataclasses.dataclass
class RuleCollection:
_rule_id_name_set: frozenset[tuple[str, str]] = dataclasses.field(init=False)
def __post_init__(self) -> None:
self._rule_id_name_set = frozenset(
{
(field.default.id, field.default.name)
for field in dataclasses.fields(self)
if isinstance(field.default, Rule)
}
)
def __contains__(self, rule: Rule) -> bool:
"""Checks if the rule is in the collection."""
return (rule.id, rule.name) in self._rule_id_name_set
@classmethod
def custom_collection_from_list(
cls, new_collection_class_name: str, rules: Sequence[Rule]
) -> RuleCollection:
"""Creates a custom class inherited from RuleCollection with the list of rules."""
return dataclasses.make_dataclass(
new_collection_class_name,
[
(
formatter.kebab_case_to_snake_case(rule.name),
type(rule),
dataclasses.field(default=rule),
)
for rule in rules
],
bases=(cls,),
)()
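# Illustrative usage (not part of the original file; the rule below is made
# up). A custom collection is a generated dataclass with one field per rule,
# and membership is checked by the (id, name) pair:
#
#   demo_rule = Rule("FXE9998", "demo-rule", "A demo message.")
#   collection = RuleCollection.custom_collection_from_list("DemoRules", [demo_rule])
#   assert demo_rule in collection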
class Invocation:
# TODO: Implement this.
# Tracks top level call arguments and diagnostic options.
def __init__(self) -> None:
raise NotImplementedError
@dataclasses.dataclass
class DiagnosticOptions:
"""Options for diagnostic context.
Attributes:
        verbosity_level: Set the amount of information logged for each diagnostic,
            equivalent to the 'level' in the Python logging module.
warnings_as_errors: When True, warning diagnostics are treated as error diagnostics.
"""
verbosity_level: int = dataclasses.field(default=logging.INFO)
"""Set the amount of information logged for each diagnostics, equivalent to the 'level' in Python logging module."""
warnings_as_errors: bool = dataclasses.field(default=False)
"""If True, warning diagnostics are treated as error diagnostics."""

View File

@ -0,0 +1,404 @@
# mypy: allow-untyped-defs
"""A diagnostic context based on SARIF."""
from __future__ import annotations
import contextlib
import dataclasses
import gzip
import logging
from typing import Callable, Generator, Generic, Literal, Mapping, TypeVar
from typing_extensions import Self
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
# This is a workaround for mypy not supporting Self from typing_extensions.
_Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic")
diagnostic_logger: logging.Logger = logging.getLogger(__name__)
@dataclasses.dataclass
class Diagnostic:
rule: infra.Rule
level: infra.Level
message: str | None = None
locations: list[infra.Location] = dataclasses.field(default_factory=list)
stacks: list[infra.Stack] = dataclasses.field(default_factory=list)
graphs: list[infra.Graph] = dataclasses.field(default_factory=list)
thread_flow_locations: list[infra.ThreadFlowLocation] = dataclasses.field(
default_factory=list
)
additional_messages: list[str] = dataclasses.field(default_factory=list)
tags: list[infra.Tag] = dataclasses.field(default_factory=list)
source_exception: Exception | None = None
"""The exception that caused this diagnostic to be created."""
logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
"""The logger for this diagnostic. Defaults to 'diagnostic_logger' which has the same
    log level setting as `DiagnosticOptions.verbosity_level`."""
_current_log_section_depth: int = 0
def __post_init__(self) -> None:
pass
def sarif(self) -> sarif.Result:
"""Returns the SARIF Result representation of this diagnostic."""
message = self.message or self.rule.message_default_template
if self.additional_messages:
additional_message = "\n".join(self.additional_messages)
message_markdown = (
f"{message}\n\n## Additional Message:\n\n{additional_message}"
)
else:
message_markdown = message
kind: Literal["informational", "fail"] = (
"informational" if self.level == infra.Level.NONE else "fail"
)
sarif_result = sarif.Result(
message=sarif.Message(text=message, markdown=message_markdown),
level=self.level.name.lower(), # type: ignore[arg-type]
rule_id=self.rule.id,
kind=kind,
)
sarif_result.locations = [location.sarif() for location in self.locations]
sarif_result.stacks = [stack.sarif() for stack in self.stacks]
sarif_result.graphs = [graph.sarif() for graph in self.graphs]
sarif_result.code_flows = [
sarif.CodeFlow(
thread_flows=[
sarif.ThreadFlow(
locations=[loc.sarif() for loc in self.thread_flow_locations]
)
]
)
]
sarif_result.properties = sarif.PropertyBag(
tags=[tag.value for tag in self.tags]
)
return sarif_result
def with_location(self: Self, location: infra.Location) -> Self:
"""Adds a location to the diagnostic."""
self.locations.append(location)
return self
def with_thread_flow_location(
self: Self, location: infra.ThreadFlowLocation
) -> Self:
"""Adds a thread flow location to the diagnostic."""
self.thread_flow_locations.append(location)
return self
def with_stack(self: Self, stack: infra.Stack) -> Self:
"""Adds a stack to the diagnostic."""
self.stacks.append(stack)
return self
def with_graph(self: Self, graph: infra.Graph) -> Self:
"""Adds a graph to the diagnostic."""
self.graphs.append(graph)
return self
@contextlib.contextmanager
def log_section(
self, level: int, message: str, *args, **kwargs
) -> Generator[None, None, None]:
"""
Context manager for a section of log messages, denoted by a title message and increased indentation.
        Same API as `logging.Logger.log`.
This context manager logs the given title at the specified log level, increases the current
section depth for subsequent log messages, and ensures that the section depth is decreased
again when exiting the context.
Args:
level: The log level.
message: The title message to log.
*args: The arguments to the message. Use `LazyString` to defer the
expensive evaluation of the arguments until the message is actually logged.
**kwargs: The keyword arguments for `logging.Logger.log`.
Yields:
None: This context manager does not yield any value.
Example:
>>> with DiagnosticContext("DummyContext", "1.0"):
... rule = infra.Rule("RuleID", "DummyRule", "Rule message")
... diagnostic = Diagnostic(rule, infra.Level.WARNING)
... with diagnostic.log_section(logging.INFO, "My Section"):
... diagnostic.log(logging.INFO, "My Message")
... with diagnostic.log_section(logging.INFO, "My Subsection"):
... diagnostic.log(logging.INFO, "My Submessage")
... diagnostic.additional_messages
['## My Section', 'My Message', '### My Subsection', 'My Submessage']
"""
if self.logger.isEnabledFor(level):
indented_format_message = (
f"##{'#' * self._current_log_section_depth } {message}"
)
self.log(
level,
indented_format_message,
*args,
**kwargs,
)
self._current_log_section_depth += 1
try:
yield
finally:
self._current_log_section_depth -= 1
def log(self, level: int, message: str, *args, **kwargs) -> None:
"""Logs a message within the diagnostic. Same api as `logging.Logger.log`.
If logger is not enabled for the given level, the message will not be logged.
Otherwise, the message will be logged and also added to the diagnostic's additional_messages.
The default setting for `DiagnosticOptions.verbosity_level` is `logging.INFO`. Based on this default,
the log level recommendations are as follows. If you've set a different default verbosity level in your
application, please adjust accordingly:
- logging.ERROR: Log any events leading to application failure.
- logging.WARNING: Log events that might result in application issues or failures, although not guaranteed.
- logging.INFO: Log general useful information, ensuring minimal performance overhead.
- logging.DEBUG: Log detailed debug information, which might affect performance when logged.
Args:
level: The log level.
message: The message to log.
*args: The arguments to the message. Use `LazyString` to defer the
expensive evaluation of the arguments until the message is actually logged.
**kwargs: The keyword arguments for `logging.Logger.log`.
"""
if self.logger.isEnabledFor(level):
formatted_message = message % args
self.logger.log(level, formatted_message, **kwargs)
self.additional_messages.append(formatted_message)
def debug(self, message: str, *args, **kwargs) -> None:
"""Logs a debug message within the diagnostic. Same api as logging.Logger.debug.
Checkout `log` for more details.
"""
self.log(logging.DEBUG, message, *args, **kwargs)
def info(self, message: str, *args, **kwargs) -> None:
"""Logs an info message within the diagnostic. Same api as logging.Logger.info.
Checkout `log` for more details.
"""
self.log(logging.INFO, message, *args, **kwargs)
def warning(self, message: str, *args, **kwargs) -> None:
"""Logs a warning message within the diagnostic. Same api as logging.Logger.warning.
Checkout `log` for more details.
"""
self.log(logging.WARNING, message, *args, **kwargs)
def error(self, message: str, *args, **kwargs) -> None:
"""Logs an error message within the diagnostic. Same api as logging.Logger.error.
Checkout `log` for more details.
"""
self.log(logging.ERROR, message, *args, **kwargs)
def log_source_exception(self, level: int, exception: Exception) -> None:
"""Logs a source exception within the diagnostic.
Invokes `log_section` and `log` to log the exception in markdown section format.
"""
self.source_exception = exception
with self.log_section(level, "Exception log"):
self.log(level, "%s", formatter.lazy_format_exception(exception))
def record_python_call_stack(self, frames_to_skip: int) -> infra.Stack:
"""Records the current Python call stack."""
frames_to_skip += 1 # Skip this function.
stack = utils.python_call_stack(frames_to_skip=frames_to_skip)
self.with_stack(stack)
if len(stack.frames) > 0:
self.with_location(stack.frames[0].location)
return stack
def record_python_call(
self,
fn: Callable,
state: Mapping[str, str],
message: str | None = None,
frames_to_skip: int = 0,
) -> infra.ThreadFlowLocation:
"""Records a python call as one thread flow step."""
frames_to_skip += 1 # Skip this function.
stack = utils.python_call_stack(frames_to_skip=frames_to_skip, frames_to_log=5)
location = utils.function_location(fn)
location.message = message
# Add function location to the top of the stack.
stack.frames.insert(0, infra.StackFrame(location=location))
thread_flow_location = infra.ThreadFlowLocation(
location=location,
state=state,
index=len(self.thread_flow_locations),
stack=stack,
)
self.with_thread_flow_location(thread_flow_location)
return thread_flow_location
class RuntimeErrorWithDiagnostic(RuntimeError):
"""Runtime error with enclosed diagnostic information."""
def __init__(self, diagnostic: Diagnostic):
super().__init__(diagnostic.message)
self.diagnostic = diagnostic
@dataclasses.dataclass
class DiagnosticContext(Generic[_Diagnostic]):
name: str
version: str
options: infra.DiagnosticOptions = dataclasses.field(
default_factory=infra.DiagnosticOptions
)
diagnostics: list[_Diagnostic] = dataclasses.field(init=False, default_factory=list)
# TODO(bowbao): Implement this.
# _invocation: infra.Invocation = dataclasses.field(init=False)
_inflight_diagnostics: list[_Diagnostic] = dataclasses.field(
init=False, default_factory=list
)
_previous_log_level: int = dataclasses.field(init=False, default=logging.WARNING)
logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
_bound_diagnostic_type: type = dataclasses.field(init=False, default=Diagnostic)
def __enter__(self):
self._previous_log_level = self.logger.level
self.logger.setLevel(self.options.verbosity_level)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.logger.setLevel(self._previous_log_level)
return None
def sarif(self) -> sarif.Run:
"""Returns the SARIF Run object."""
unique_rules = {diagnostic.rule for diagnostic in self.diagnostics}
return sarif.Run(
sarif.Tool(
driver=sarif.ToolComponent(
name=self.name,
version=self.version,
rules=[rule.sarif() for rule in unique_rules],
)
),
results=[diagnostic.sarif() for diagnostic in self.diagnostics],
)
def sarif_log(self) -> sarif.SarifLog: # type: ignore[name-defined]
"""Returns the SARIF Log object."""
return sarif.SarifLog(
version=sarif_version.SARIF_VERSION,
schema_uri=sarif_version.SARIF_SCHEMA_LINK,
runs=[self.sarif()],
)
def to_json(self) -> str:
return formatter.sarif_to_json(self.sarif_log())
def dump(self, file_path: str, compress: bool = False) -> None:
"""Dumps the SARIF log to a file."""
if compress:
with gzip.open(file_path, "wt") as f:
f.write(self.to_json())
else:
with open(file_path, "w") as f:
f.write(self.to_json())
def log(self, diagnostic: _Diagnostic) -> None:
"""Logs a diagnostic.
This method should be used only after all the necessary information for the diagnostic
has been collected.
Args:
diagnostic: The diagnostic to add.
"""
if not isinstance(diagnostic, self._bound_diagnostic_type):
raise TypeError(
f"Expected diagnostic of type {self._bound_diagnostic_type}, got {type(diagnostic)}"
)
if self.options.warnings_as_errors and diagnostic.level == infra.Level.WARNING: # type: ignore[attr-defined]
diagnostic.level = infra.Level.ERROR # type: ignore[attr-defined]
self.diagnostics.append(diagnostic) # type: ignore[arg-type]
def log_and_raise_if_error(self, diagnostic: _Diagnostic) -> None:
"""Logs a diagnostic and raises an exception if it is an error.
        Use this method for logging non-inflight diagnostics whose level is not known in
        advance or may be lower than ERROR. If the diagnostic is always expected to raise,
        use `log` and an explicit `raise` instead; otherwise there is no way to convey to
        Python intellisense and type-checking tools that it always raises.
This method should be used only after all the necessary information for the diagnostic
has been collected.
Args:
diagnostic: The diagnostic to add.
"""
self.log(diagnostic)
if diagnostic.level == infra.Level.ERROR:
if diagnostic.source_exception is not None:
raise diagnostic.source_exception
raise RuntimeErrorWithDiagnostic(diagnostic)
@contextlib.contextmanager
def add_inflight_diagnostic(
self, diagnostic: _Diagnostic
) -> Generator[_Diagnostic, None, None]:
"""Adds a diagnostic to the context.
Use this method to add diagnostics that are not created by the context.
Args:
diagnostic: The diagnostic to add.
"""
self._inflight_diagnostics.append(diagnostic)
try:
yield diagnostic
finally:
self._inflight_diagnostics.pop()
def push_inflight_diagnostic(self, diagnostic: _Diagnostic) -> None:
"""Pushes a diagnostic to the inflight diagnostics stack.
Args:
diagnostic: The diagnostic to push.
        """
self._inflight_diagnostics.append(diagnostic)
def pop_inflight_diagnostic(self) -> _Diagnostic:
"""Pops the last diagnostic from the inflight diagnostics stack.
Returns:
The popped diagnostic.
"""
return self._inflight_diagnostics.pop()
def inflight_diagnostic(self, rule: infra.Rule | None = None) -> _Diagnostic:
if rule is None:
# TODO(bowbao): Create builtin-rules and create diagnostic using that.
if len(self._inflight_diagnostics) <= 0:
raise AssertionError("No inflight diagnostics")
return self._inflight_diagnostics[-1]
else:
for diagnostic in reversed(self._inflight_diagnostics):
if diagnostic.rule == rule:
return diagnostic
raise AssertionError(f"No inflight diagnostic for rule {rule.name}")

View File

@ -0,0 +1,153 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import functools
import logging
import traceback
from typing import Any, Callable, Dict, Tuple
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, utils
MessageFormatterType = Callable[..., str]
def format_message_in_text(fn: Callable, *args: Any, **kwargs: Any) -> str:
return f"{formatter.display_name(fn)}. "
def format_exception_in_markdown(exception: Exception) -> str:
msg_list = ["### Exception log", "```"]
msg_list.extend(
traceback.format_exception(type(exception), exception, exception.__traceback__)
)
msg_list.append("```")
return "\n".join(msg_list)
def format_function_signature_in_markdown(
fn: Callable,
args: tuple[Any, ...],
kwargs: dict[str, Any],
format_argument: Callable[[Any], str] = formatter.format_argument,
) -> str:
msg_list = [f"### Function Signature {formatter.display_name(fn)}"]
state = utils.function_state(fn, args, kwargs)
for k, v in state.items():
msg_list.append(f"- {k}: {format_argument(v)}")
return "\n".join(msg_list)
def format_return_values_in_markdown(
return_values: Any,
format_argument: Callable[[Any], str] = formatter.format_argument,
) -> str:
return f"{format_argument(return_values)}"
ModifierCallableType = Callable[
[infra.Diagnostic, Callable, Tuple[Any, ...], Dict[str, Any], Any], None
]
def diagnose_call(
rule: infra.Rule,
*,
level: infra.Level = infra.Level.NONE,
diagnostic_type: type[infra.Diagnostic] = infra.Diagnostic,
format_argument: Callable[[Any], str] = formatter.format_argument,
diagnostic_message_formatter: MessageFormatterType = format_message_in_text,
) -> Callable:
def decorator(fn):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
common_error_message = "diagnose_call can only be applied to callables"
if not callable(fn):
raise AssertionError(
f"{common_error_message}. Got {type(fn)} instead of callable."
)
arg0 = args[0] if len(args) > 0 else None
if isinstance(ctx := arg0, infra.DiagnosticContext):
pass
elif isinstance(
ctx := getattr(arg0, "diagnostic_context", None),
infra.DiagnosticContext,
):
pass
else:
                # NOTE: At decoration time, we can't tell whether a callable is a function
                # or a method; technically both are regarded as functions at that point.
raise AssertionError(
f"{common_error_message}. For {fn}, "
f"If it is a function, a DiagnosticContext instance must be present as "
f"the first argument. "
f"If it is a method, a DiagnosticContext instance must be present as "
f"the attribute 'diagnostic_context' of the 'self' argument."
)
diag = diagnostic_type(
rule,
level,
diagnostic_message_formatter(fn, *args, **kwargs),
)
            # Pop the decorator frame.
            # TODO(bowbao): By default a diagnostic doesn't have a stack, so we need to
            # check before doing this. Make the code cleaner.
            # Option: do not capture the stack by default in diagnostic initialization.
stack: infra.Stack | None = None
if len(diag.stacks) > 0:
stack = diag.stacks[0]
stack.frames.pop(0)
# set function location
fn_location = utils.function_location(fn)
diag.locations.insert(0, fn_location)
# Add function location to the top of the stack.
if stack is not None:
stack.frames.insert(0, infra.StackFrame(location=fn_location))
with diag.log_section(logging.INFO, "Function Signature"):
diag.log(
logging.INFO,
"%s",
formatter.LazyString(
format_function_signature_in_markdown,
fn,
args,
kwargs,
format_argument,
),
)
return_values: Any = None
with ctx.add_inflight_diagnostic(diag) as diag:
try:
return_values = fn(*args, **kwargs)
with diag.log_section(logging.INFO, "Return values"):
diag.log(
logging.INFO,
"%s",
formatter.LazyString(
format_return_values_in_markdown,
return_values,
format_argument,
),
)
return return_values
except Exception as e:
diag.log_source_exception(logging.ERROR, e)
diag.level = infra.Level.ERROR
finally:
ctx.log_and_raise_if_error(diag)
return wrapper
return decorator
# TODO(bowbao): decorator to report only when failed.
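# Illustrative usage (not part of the original file; the rule and function are
# made up). The decorated callable must receive a DiagnosticContext as its
# first argument, or be a method whose `self` exposes `diagnostic_context`:
#
#   _demo_rule = infra.Rule("FXE9996", "demo-call", "Calling {}.")
#
#   @diagnose_call(_demo_rule)
#   def lower_graph(diagnostic_context: infra.DiagnosticContext, graph: str) -> str:
#       return graph.upper()
#
#   with infra.DiagnosticContext("demo", "0.1") as context:
#       lower_graph(context, "aten::relu")  # records a diagnostic on `context`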

View File

@ -0,0 +1,106 @@
from __future__ import annotations
import dataclasses
import json
import re
import traceback
from typing import Any, Callable, Union
from torch._logging import LazyString
from torch.onnx._internal.diagnostics.infra import sarif
# A list of types in the SARIF module to support pretty printing.
# This is solely for type annotation for the functions below.
_SarifClass = Union[
sarif.SarifLog,
sarif.Run,
sarif.ReportingDescriptor,
sarif.Result,
]
def lazy_format_exception(exception: Exception) -> LazyString:
return LazyString(
lambda: "\n".join(
(
"```",
*traceback.format_exception(
type(exception), exception, exception.__traceback__
),
"```",
)
),
)
def snake_case_to_camel_case(s: str) -> str:
splits = s.split("_")
if len(splits) <= 1:
return s
return "".join([splits[0], *map(str.capitalize, splits[1:])])
def camel_case_to_snake_case(s: str) -> str:
return re.sub(r"([A-Z])", r"_\1", s).lower()
def kebab_case_to_snake_case(s: str) -> str:
return s.replace("-", "_")
def _convert_key(
    obj: dict[str, Any] | Any, convert: Callable[[str], str]
) -> dict[str, Any] | Any:
    """Convert and update keys in a dictionary with "convert".
    Any value that is a dictionary will be recursively updated.
    Any value that is a list will be recursively searched.
    Args:
        obj: The object to update.
        convert: The function to convert the keys, e.g. `kebab_case_to_snake_case`.
    Returns:
        The updated object.
    """
    if not isinstance(obj, dict):
        return obj
    new_dict = {}
    for k, v in obj.items():
new_k = convert(k)
if isinstance(v, dict):
new_v = _convert_key(v, convert)
elif isinstance(v, list):
new_v = [_convert_key(elem, convert) for elem in v]
else:
new_v = v
if new_v is None:
            # Skip None values; otherwise the sarif log is unnecessarily bloated with "null"s.
continue
if new_v == -1:
            # Workaround: -1 is used as a default value and shouldn't be logged into sarif.
continue
new_dict[new_k] = new_v
return new_dict
def sarif_to_json(attr_cls_obj: _SarifClass, indent: str | None = " ") -> str:
    sarif_dict = dataclasses.asdict(attr_cls_obj)
    sarif_dict = _convert_key(sarif_dict, snake_case_to_camel_case)
    return json.dumps(sarif_dict, indent=indent, separators=(",", ":"))
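# Illustrative usage (not part of the original file): any of the SARIF
# dataclasses can be serialized this way; keys come out camelCased, and the
# None/-1 placeholder values are dropped by _convert_key, e.g.:
#
#   print(sarif_to_json(sarif.Message(text="hello")))  # prints {"text": "hello"}, pretty-printed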
def format_argument(obj: Any) -> str:
return f"{type(obj)}"
def display_name(fn: Callable) -> str:
if hasattr(fn, "__qualname__"):
return fn.__qualname__
elif hasattr(fn, "__name__"):
return fn.__name__
else:
return str(fn)

View File

@ -0,0 +1,101 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from torch.onnx._internal.diagnostics.infra.sarif._address import Address
from torch.onnx._internal.diagnostics.infra.sarif._artifact import Artifact
from torch.onnx._internal.diagnostics.infra.sarif._artifact_change import ArtifactChange
from torch.onnx._internal.diagnostics.infra.sarif._artifact_content import (
ArtifactContent,
)
from torch.onnx._internal.diagnostics.infra.sarif._artifact_location import (
ArtifactLocation,
)
from torch.onnx._internal.diagnostics.infra.sarif._attachment import Attachment
from torch.onnx._internal.diagnostics.infra.sarif._code_flow import CodeFlow
from torch.onnx._internal.diagnostics.infra.sarif._configuration_override import (
ConfigurationOverride,
)
from torch.onnx._internal.diagnostics.infra.sarif._conversion import Conversion
from torch.onnx._internal.diagnostics.infra.sarif._edge import Edge
from torch.onnx._internal.diagnostics.infra.sarif._edge_traversal import EdgeTraversal
from torch.onnx._internal.diagnostics.infra.sarif._exception import Exception
from torch.onnx._internal.diagnostics.infra.sarif._external_properties import (
ExternalProperties,
)
from torch.onnx._internal.diagnostics.infra.sarif._external_property_file_reference import (
ExternalPropertyFileReference,
)
from torch.onnx._internal.diagnostics.infra.sarif._external_property_file_references import (
ExternalPropertyFileReferences,
)
from torch.onnx._internal.diagnostics.infra.sarif._fix import Fix
from torch.onnx._internal.diagnostics.infra.sarif._graph import Graph
from torch.onnx._internal.diagnostics.infra.sarif._graph_traversal import GraphTraversal
from torch.onnx._internal.diagnostics.infra.sarif._invocation import Invocation
from torch.onnx._internal.diagnostics.infra.sarif._location import Location
from torch.onnx._internal.diagnostics.infra.sarif._location_relationship import (
LocationRelationship,
)
from torch.onnx._internal.diagnostics.infra.sarif._logical_location import (
LogicalLocation,
)
from torch.onnx._internal.diagnostics.infra.sarif._message import Message
from torch.onnx._internal.diagnostics.infra.sarif._multiformat_message_string import (
MultiformatMessageString,
)
from torch.onnx._internal.diagnostics.infra.sarif._node import Node
from torch.onnx._internal.diagnostics.infra.sarif._notification import Notification
from torch.onnx._internal.diagnostics.infra.sarif._physical_location import (
PhysicalLocation,
)
from torch.onnx._internal.diagnostics.infra.sarif._property_bag import PropertyBag
from torch.onnx._internal.diagnostics.infra.sarif._rectangle import Rectangle
from torch.onnx._internal.diagnostics.infra.sarif._region import Region
from torch.onnx._internal.diagnostics.infra.sarif._replacement import Replacement
from torch.onnx._internal.diagnostics.infra.sarif._reporting_configuration import (
ReportingConfiguration,
)
from torch.onnx._internal.diagnostics.infra.sarif._reporting_descriptor import (
ReportingDescriptor,
)
from torch.onnx._internal.diagnostics.infra.sarif._reporting_descriptor_reference import (
ReportingDescriptorReference,
)
from torch.onnx._internal.diagnostics.infra.sarif._reporting_descriptor_relationship import (
ReportingDescriptorRelationship,
)
from torch.onnx._internal.diagnostics.infra.sarif._result import Result
from torch.onnx._internal.diagnostics.infra.sarif._result_provenance import (
ResultProvenance,
)
from torch.onnx._internal.diagnostics.infra.sarif._run import Run
from torch.onnx._internal.diagnostics.infra.sarif._run_automation_details import (
RunAutomationDetails,
)
from torch.onnx._internal.diagnostics.infra.sarif._sarif_log import SarifLog
from torch.onnx._internal.diagnostics.infra.sarif._special_locations import (
SpecialLocations,
)
from torch.onnx._internal.diagnostics.infra.sarif._stack import Stack
from torch.onnx._internal.diagnostics.infra.sarif._stack_frame import StackFrame
from torch.onnx._internal.diagnostics.infra.sarif._suppression import Suppression
from torch.onnx._internal.diagnostics.infra.sarif._thread_flow import ThreadFlow
from torch.onnx._internal.diagnostics.infra.sarif._thread_flow_location import (
ThreadFlowLocation,
)
from torch.onnx._internal.diagnostics.infra.sarif._tool import Tool
from torch.onnx._internal.diagnostics.infra.sarif._tool_component import ToolComponent
from torch.onnx._internal.diagnostics.infra.sarif._tool_component_reference import (
ToolComponentReference,
)
from torch.onnx._internal.diagnostics.infra.sarif._translation_metadata import (
TranslationMetadata,
)
from torch.onnx._internal.diagnostics.infra.sarif._version_control_details import (
VersionControlDetails,
)
from torch.onnx._internal.diagnostics.infra.sarif._web_request import WebRequest
from torch.onnx._internal.diagnostics.infra.sarif._web_response import WebResponse
# flake8: noqa

Some files were not shown because too many files have changed in this diff.