I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,41 @@
"""Utility to lazily import modules."""
# mypy: allow-untyped-defs
from __future__ import annotations
import importlib
from typing import Any, TYPE_CHECKING
class _LazyModule:
"""Lazily import a module."""
def __init__(self, module_name: str) -> None:
self._name = module_name
self._module: Any = None
def __repr__(self) -> str:
return f"<lazy module '{self._name}'>"
def __getattr__(self, attr):
if self._module is None:
self._module = importlib.import_module(".", self._name)
return getattr(self._module, attr)
# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code.
# NOTE: Add additional used imports here.
if TYPE_CHECKING:
import onnx
import onnxscript
import onnxscript._framework_apis.torch_2_5 as onnxscript_apis
onnxscript_ir = onnxscript.ir
else:
onnx = _LazyModule("onnx")
onnxscript = _LazyModule("onnxscript")
onnxscript_ir = _LazyModule("onnxscript.ir")
onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_5")

View File

@ -0,0 +1,22 @@
from ._diagnostic import (
create_export_diagnostic_context,
diagnose,
engine,
export_context,
ExportDiagnosticEngine,
TorchScriptOnnxExportDiagnostic,
)
from ._rules import rules
from .infra import levels
__all__ = [
"TorchScriptOnnxExportDiagnostic",
"ExportDiagnosticEngine",
"rules",
"levels",
"engine",
"export_context",
"create_export_diagnostic_context",
"diagnose",
]

View File

@ -0,0 +1,211 @@
# mypy: allow-untyped-defs
"""Diagnostic components for TorchScript based ONNX export, i.e. `torch.onnx.export`."""
from __future__ import annotations
import contextlib
import gzip
from typing import TYPE_CHECKING
import torch
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
from torch.utils import cpp_backtrace
if TYPE_CHECKING:
from collections.abc import Generator
def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> infra.Stack:
"""Returns the current C++ call stack.
This function utilizes `torch.utils.cpp_backtrace` to get the current C++ call stack.
The returned C++ call stack is a concatenated string of the C++ call stack frames.
Each frame is separated by a newline character, in the same format of
r"frame #[0-9]+: (?P<frame_info>.*)". More info at `c10/util/Backtrace.cpp`.
"""
frames = cpp_backtrace.get_cpp_backtrace(frames_to_skip, frames_to_log).split("\n")
frame_messages = []
for frame in frames:
segments = frame.split(":", 1)
if len(segments) == 2:
frame_messages.append(segments[1].strip())
else:
frame_messages.append("<unknown frame>")
return infra.Stack(
frames=[
infra.StackFrame(location=infra.Location(message=message))
for message in frame_messages
]
)
class TorchScriptOnnxExportDiagnostic(infra.Diagnostic):
"""Base class for all export diagnostics.
This class is used to represent all export diagnostics. It is a subclass of
infra.Diagnostic, and adds additional methods to add more information to the
diagnostic.
"""
python_call_stack: infra.Stack | None = None
cpp_call_stack: infra.Stack | None = None
def __init__(
self,
*args,
frames_to_skip: int = 1,
cpp_stack: bool = False,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.python_call_stack = self.record_python_call_stack(
frames_to_skip=frames_to_skip
)
if cpp_stack:
self.cpp_call_stack = self.record_cpp_call_stack(
frames_to_skip=frames_to_skip
)
def record_cpp_call_stack(self, frames_to_skip: int) -> infra.Stack:
"""Records the current C++ call stack in the diagnostic."""
stack = _cpp_call_stack(frames_to_skip=frames_to_skip)
stack.message = "C++ call stack"
self.with_stack(stack)
return stack
class ExportDiagnosticEngine:
"""PyTorch ONNX Export diagnostic engine.
The only purpose of creating this class instead of using `DiagnosticContext` directly
is to provide a background context for `diagnose` calls inside exporter.
By design, one `torch.onnx.export` call should initialize one diagnostic context.
All `diagnose` calls inside exporter should be made in the context of that export.
However, since diagnostic context is currently being accessed via a global variable,
there is no guarantee that the context is properly initialized. Therefore, we need
to provide a default background context to fallback to, otherwise any invocation of
exporter internals, e.g. unit tests, will fail due to missing diagnostic context.
This can be removed once the pipeline for context to flow through the exporter is
established.
"""
contexts: list[infra.DiagnosticContext]
_background_context: infra.DiagnosticContext
def __init__(self) -> None:
self.contexts = []
self._background_context = infra.DiagnosticContext(
name="torch.onnx",
version=torch.__version__,
)
@property
def background_context(self) -> infra.DiagnosticContext:
return self._background_context
def create_diagnostic_context(
self,
name: str,
version: str,
options: infra.DiagnosticOptions | None = None,
) -> infra.DiagnosticContext:
"""Creates a new diagnostic context.
Args:
name: The subject name for the diagnostic context.
version: The subject version for the diagnostic context.
options: The options for the diagnostic context.
Returns:
A new diagnostic context.
"""
if options is None:
options = infra.DiagnosticOptions()
context: infra.DiagnosticContext[infra.Diagnostic] = infra.DiagnosticContext(
name, version, options
)
self.contexts.append(context)
return context
def clear(self):
"""Clears all diagnostic contexts."""
self.contexts.clear()
self._background_context.diagnostics.clear()
def to_json(self) -> str:
return formatter.sarif_to_json(self.sarif_log())
def dump(self, file_path: str, compress: bool = False) -> None:
"""Dumps the SARIF log to a file."""
if compress:
with gzip.open(file_path, "wt") as f:
f.write(self.to_json())
else:
with open(file_path, "w") as f:
f.write(self.to_json())
def sarif_log(self):
log = sarif.SarifLog(
version=sarif_version.SARIF_VERSION,
schema_uri=sarif_version.SARIF_SCHEMA_LINK,
runs=[context.sarif() for context in self.contexts],
)
log.runs.append(self._background_context.sarif())
return log
engine = ExportDiagnosticEngine()
_context = engine.background_context
@contextlib.contextmanager
def create_export_diagnostic_context() -> (
Generator[infra.DiagnosticContext, None, None]
):
"""Create a diagnostic context for export.
This is a workaround for code robustness since diagnostic context is accessed by
export internals via global variable. See `ExportDiagnosticEngine` for more details.
"""
global _context
assert (
_context == engine.background_context
), "Export context is already set. Nested export is not supported."
_context = engine.create_diagnostic_context(
"torch.onnx.export",
torch.__version__,
)
try:
yield _context
finally:
_context = engine.background_context
def diagnose(
rule: infra.Rule,
level: infra.Level,
message: str | None = None,
frames_to_skip: int = 2,
**kwargs,
) -> TorchScriptOnnxExportDiagnostic:
"""Creates a diagnostic and record it in the global diagnostic context.
This is a wrapper around `context.log` that uses the global diagnostic
context.
"""
diagnostic = TorchScriptOnnxExportDiagnostic(
rule, level, message, frames_to_skip=frames_to_skip, **kwargs
)
export_context().log(diagnostic)
return diagnostic
def export_context() -> infra.DiagnosticContext:
global _context
return _context

View File

@ -0,0 +1,636 @@
# mypy: allow-untyped-defs
"""
GENERATED CODE - DO NOT EDIT DIRECTLY
This file is generated by gen_diagnostics.py.
See tools/onnx/gen_diagnostics.py for more information.
Diagnostic rules for PyTorch ONNX export.
"""
import dataclasses
from typing import Tuple
# flake8: noqa
from torch.onnx._internal.diagnostics import infra
"""
GENERATED CODE - DO NOT EDIT DIRECTLY
The purpose of generating a class for each rule is to override the `format_message`
method to provide more details in the signature about the format arguments.
"""
class _NodeMissingOnnxShapeInference(infra.Rule):
"""Node is missing ONNX shape inference."""
def format_message(self, op_name) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'The shape inference of {op_name} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.'
"""
return self.message_default_template.format(op_name=op_name)
def format( # type: ignore[override]
self, level: infra.Level, op_name
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'The shape inference of {op_name} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.'
"""
return self, level, self.format_message(op_name=op_name)
class _MissingCustomSymbolicFunction(infra.Rule):
"""Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX."""
def format_message(self, op_name) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'ONNX export failed on an operator with unrecognized namespace {op_name}. If you are trying to export a custom operator, make sure you registered it with the right domain and version.'
"""
return self.message_default_template.format(op_name=op_name)
def format( # type: ignore[override]
self, level: infra.Level, op_name
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'ONNX export failed on an operator with unrecognized namespace {op_name}. If you are trying to export a custom operator, make sure you registered it with the right domain and version.'
"""
return self, level, self.format_message(op_name=op_name)
class _MissingStandardSymbolicFunction(infra.Rule):
"""Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX."""
def format_message( # type: ignore[override]
self, op_name, opset_version, issue_url
) -> str:
"""Returns the formatted default message of this Rule.
Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: {issue_url}."
"""
return self.message_default_template.format(
op_name=op_name, opset_version=opset_version, issue_url=issue_url
)
def format( # type: ignore[override]
self, level: infra.Level, op_name, opset_version, issue_url
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: {issue_url}."
"""
return (
self,
level,
self.format_message(
op_name=op_name, opset_version=opset_version, issue_url=issue_url
),
)
class _OperatorSupportedInNewerOpsetVersion(infra.Rule):
"""Operator is supported in newer opset version."""
def format_message( # type: ignore[override]
self, op_name, opset_version, supported_opset_version
) -> str:
"""Returns the formatted default message of this Rule.
Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Support for this operator was added in version {supported_opset_version}, try exporting with this version."
"""
return self.message_default_template.format(
op_name=op_name,
opset_version=opset_version,
supported_opset_version=supported_opset_version,
)
def format( # type: ignore[override]
self, level: infra.Level, op_name, opset_version, supported_opset_version
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Support for this operator was added in version {supported_opset_version}, try exporting with this version."
"""
return (
self,
level,
self.format_message(
op_name=op_name,
opset_version=opset_version,
supported_opset_version=supported_opset_version,
),
)
class _FxGraphToOnnx(infra.Rule):
"""Transforms graph from FX IR to ONNX IR."""
def format_message(self, graph_name) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'Transforming FX graph {graph_name} to ONNX graph.'
"""
return self.message_default_template.format(graph_name=graph_name)
def format( # type: ignore[override]
self, level: infra.Level, graph_name
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'Transforming FX graph {graph_name} to ONNX graph.'
"""
return self, level, self.format_message(graph_name=graph_name)
class _FxNodeToOnnx(infra.Rule):
"""Transforms an FX node to an ONNX node."""
def format_message(self, node_repr) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'Transforming FX node {node_repr} to ONNX node.'
"""
return self.message_default_template.format(node_repr=node_repr)
def format( # type: ignore[override]
self, level: infra.Level, node_repr
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'Transforming FX node {node_repr} to ONNX node.'
"""
return self, level, self.format_message(node_repr=node_repr)
class _FxPass(infra.Rule):
"""FX graph transformation during ONNX export before converting from FX IR to ONNX IR."""
def format_message(self, pass_name) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'Running {pass_name} pass.'
"""
return self.message_default_template.format(pass_name=pass_name)
def format( # type: ignore[override]
self, level: infra.Level, pass_name
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'Running {pass_name} pass.'
"""
return self, level, self.format_message(pass_name=pass_name)
class _NoSymbolicFunctionForCallFunction(infra.Rule):
"""Cannot find symbolic function to convert the "call_function" FX node to ONNX."""
def format_message(self, target) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'No symbolic function to convert the "call_function" node {target} to ONNX. '
"""
return self.message_default_template.format(target=target)
def format( # type: ignore[override]
self, level: infra.Level, target
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'No symbolic function to convert the "call_function" node {target} to ONNX. '
"""
return self, level, self.format_message(target=target)
class _UnsupportedFxNodeAnalysis(infra.Rule):
"""Result from FX graph analysis to reveal unsupported FX nodes."""
def format_message( # type: ignore[override]
self, node_op_to_target_mapping
) -> str:
"""Returns the formatted default message of this Rule.
Message template: 'Unsupported FX nodes: {node_op_to_target_mapping}. '
"""
return self.message_default_template.format(
node_op_to_target_mapping=node_op_to_target_mapping
)
def format( # type: ignore[override]
self, level: infra.Level, node_op_to_target_mapping
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'Unsupported FX nodes: {node_op_to_target_mapping}. '
"""
return (
self,
level,
self.format_message(node_op_to_target_mapping=node_op_to_target_mapping),
)
class _OpLevelDebugging(infra.Rule):
"""Report any op level validation failure in warnings."""
def format_message(self, node, symbolic_fn) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'FX node: {node} and its onnx function: {symbolic_fn} fails on op level validation.'
"""
return self.message_default_template.format(node=node, symbolic_fn=symbolic_fn)
def format( # type: ignore[override]
self, level: infra.Level, node, symbolic_fn
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'FX node: {node} and its onnx function: {symbolic_fn} fails on op level validation.'
"""
return self, level, self.format_message(node=node, symbolic_fn=symbolic_fn)
class _FindOpschemaMatchedSymbolicFunction(infra.Rule):
"""Find the OnnxFunction that matches the input/attribute dtypes by comparing them with their opschemas."""
def format_message(self, symbolic_fn, node) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'The OnnxFunction: {symbolic_fn} is the nearest match of the node {node}.'
"""
return self.message_default_template.format(symbolic_fn=symbolic_fn, node=node)
def format( # type: ignore[override]
self, level: infra.Level, symbolic_fn, node
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'The OnnxFunction: {symbolic_fn} is the nearest match of the node {node}.'
"""
return self, level, self.format_message(symbolic_fn=symbolic_fn, node=node)
class _FxNodeInsertTypePromotion(infra.Rule):
"""Determine if type promotion is required for the FX node. Insert cast nodes if needed."""
def format_message(self, target) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'Performing explicit type promotion for node {target}. '
"""
return self.message_default_template.format(target=target)
def format( # type: ignore[override]
self, level: infra.Level, target
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'Performing explicit type promotion for node {target}. '
"""
return self, level, self.format_message(target=target)
class _FindOperatorOverloadsInOnnxRegistry(infra.Rule):
"""Find the list of OnnxFunction of the PyTorch operator in onnx registry."""
def format_message(self, node) -> str: # type: ignore[override]
"""Returns the formatted default message of this Rule.
Message template: 'Checking if the FX node: {node} is supported in onnx registry.'
"""
return self.message_default_template.format(node=node)
def format( # type: ignore[override]
self, level: infra.Level, node
) -> Tuple[infra.Rule, infra.Level, str]:
"""Returns a tuple of (Rule, Level, message) for this Rule.
Message template: 'Checking if the FX node: {node} is supported in onnx registry.'
"""
return self, level, self.format_message(node=node)
@dataclasses.dataclass
class _POERules(infra.RuleCollection):
node_missing_onnx_shape_inference: _NodeMissingOnnxShapeInference = dataclasses.field(
default=_NodeMissingOnnxShapeInference.from_sarif(
**{
"id": "POE0001",
"name": "node-missing-onnx-shape-inference",
"short_description": {"text": "Node is missing ONNX shape inference."},
"full_description": {
"text": "Node is missing ONNX shape inference. This usually happens when the node is not valid under standard ONNX operator spec.",
"markdown": "Node is missing ONNX shape inference.\nThis usually happens when the node is not valid under standard ONNX operator spec.\n",
},
"message_strings": {
"default": {
"text": "The shape inference of {op_name} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Node is missing ONNX shape inference."""
missing_custom_symbolic_function: _MissingCustomSymbolicFunction = dataclasses.field(
default=_MissingCustomSymbolicFunction.from_sarif(
**{
"id": "POE0002",
"name": "missing-custom-symbolic-function",
"short_description": {
"text": "Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX."
},
"full_description": {
"text": "Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX.",
"markdown": "Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX.\n",
},
"message_strings": {
"default": {
"text": "ONNX export failed on an operator with unrecognized namespace {op_name}. If you are trying to export a custom operator, make sure you registered it with the right domain and version."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX."""
missing_standard_symbolic_function: _MissingStandardSymbolicFunction = dataclasses.field(
default=_MissingStandardSymbolicFunction.from_sarif(
**{
"id": "POE0003",
"name": "missing-standard-symbolic-function",
"short_description": {
"text": "Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX."
},
"full_description": {
"text": "Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX.",
"markdown": "Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX.\n",
},
"message_strings": {
"default": {
"text": "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: {issue_url}."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX."""
operator_supported_in_newer_opset_version: _OperatorSupportedInNewerOpsetVersion = dataclasses.field(
default=_OperatorSupportedInNewerOpsetVersion.from_sarif(
**{
"id": "POE0004",
"name": "operator-supported-in-newer-opset-version",
"short_description": {
"text": "Operator is supported in newer opset version."
},
"full_description": {
"text": "Operator is supported in newer opset version.",
"markdown": "Operator is supported in newer opset version.\n\nExample:\n```python\ntorch.onnx.export(model, args, ..., opset_version=9)\n```\n",
},
"message_strings": {
"default": {
"text": "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Support for this operator was added in version {supported_opset_version}, try exporting with this version."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Operator is supported in newer opset version."""
fx_graph_to_onnx: _FxGraphToOnnx = dataclasses.field(
default=_FxGraphToOnnx.from_sarif(
**{
"id": "FXE0007",
"name": "fx-graph-to-onnx",
"short_description": {
"text": "Transforms graph from FX IR to ONNX IR."
},
"full_description": {
"text": "Transforms graph from FX IR to ONNX IR.",
"markdown": "This diagnostic tracks the transformation process from an FX Graph (in FX IR) to an ONNX Graph (in ONNX IR).\n\n## Key Representations:\n\n- **FX Graph**: The graph in FX IR produced by dynamo or symbolic tracing.\n- **ONNX Graph**: The graph in ONNX IR and [operators](https://onnx.ai/onnx/operators/).\n\n## Additional Notes:\n\n- Prior to this transformation step, the FX graph undergoes preprocessing through multiple FX passes.\n To gain insight into these transformations, refer to diagnostic `FXE0010`.\n- To enable a detailed view of the graph transformation in progress within this diagnostic, switch to the DEBUG mode.\n\n - Set DiagnosticOptions.verbosity_level to logging.DEBUG.\n - Activate the environment variable TORCH_LOGS='onnx_diagnostics'.\n\n- For specific information related to node-level FX to ONNX transformations, explore the diagnostic `FXE0008`.\n",
},
"message_strings": {
"default": {
"text": "Transforming FX graph {graph_name} to ONNX graph."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Transforms graph from FX IR to ONNX IR."""
fx_node_to_onnx: _FxNodeToOnnx = dataclasses.field(
default=_FxNodeToOnnx.from_sarif(
**{
"id": "FXE0008",
"name": "fx-node-to-onnx",
"short_description": {"text": "Transforms an FX node to an ONNX node."},
"full_description": {
"text": "Transforms an FX node to an ONNX node.",
"markdown": "This diagnostic tracks the transformation process from an FX Node to ONNX [Operators](https://onnx.ai/onnx/operators/).\n\nThe process of converting FX Node to ONNX Node involves dealing with six distinct node types:\n 1. `placeholder`: Represents a module input, maps to an ONNX graph input.\n 2. `call_module`: Symbolizes a call to a submodule, maps to an ONNX\n 3. `call_method`: Symbolizes a method call. Not yet implemented.\n 4. `call_function`: Symbolizes a function call. [Core ATen](https://pytorch.org/docs/stable/ir.html#core-aten-ir) is expected\n as the function call target. The mapping from ATen to ONNX is implemented by [ONNXScript torchlib](https://github.com/microsoft/onnxscript/tree/main/onnxscript/function_libs/torch_lib/ops).\n This [guide](https://pytorch.org/docs/stable/onnx.html#onnx-script-functions) shows how to write and register a custom symbolic function for call_function FX node.\n 5. `get_attr`: Indicates an attribute access within the current module. Maps to an ONNX graph initializer.\n 6. `output`: Represents the module's output. Maps to an ONNX graph output.\n\nFor a granular understanding of how each node type is transformed, refer to the implementation details in `FxOnnxInterpreter`.\n",
},
"message_strings": {
"default": {
"text": "Transforming FX node {node_repr} to ONNX node."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Transforms an FX node to an ONNX node."""
fx_pass: _FxPass = dataclasses.field(
default=_FxPass.from_sarif(
**{
"id": "FXE0010",
"name": "fx-pass",
"short_description": {
"text": "FX graph transformation during ONNX export before converting from FX IR to ONNX IR."
},
"full_description": {
"text": "FX graph transformation during ONNX export before converting from FX IR to ONNX IR.",
"markdown": "This diagnostic tracks the FX passes executed during the ONNX export process prior\nto converting from FX IR (Intermediate Representation) to ONNX IR.\n\nUnder the scope of ONNX export, an FX pass refers to a specific transformation applied to the FX GraphModule.\nThe primary aim of these passes is to streamline the graph into a format that aligns more with the ONNX IR.\nMoreover, these passes work to substitute unsupported FX IR features with those recognized and endorsed by\nONNX IR. Common transformations include, but aren't limited to, decomposition, functionalization and\ntype promotion.\n\nFor those who are interested in a comprehensive log detailing the modifications made during these passes,\nthere are a couple of options:\n\n- Set DiagnosticOptions.verbosity_level to logging.DEBUG.\n- Activate the environment variable TORCH_LOGS='onnx_diagnostics'.\n\nHowever, it's noteworthy that by default, such detailed logging is turned off. The primary reason being\nits considerable impact on performance.\n\nFor an in-depth understanding of each specific pass, please refer to the directory: torch/onnx/_internal/fx/passes.\n",
},
"message_strings": {"default": {"text": "Running {pass_name} pass."}},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""FX graph transformation during ONNX export before converting from FX IR to ONNX IR."""
no_symbolic_function_for_call_function: _NoSymbolicFunctionForCallFunction = dataclasses.field(
default=_NoSymbolicFunctionForCallFunction.from_sarif(
**{
"id": "FXE0011",
"name": "no-symbolic-function-for-call-function",
"short_description": {
"text": 'Cannot find symbolic function to convert the "call_function" FX node to ONNX.'
},
"full_description": {
"text": 'Cannot find symbolic function to convert the "call_function" FX node to ONNX. ',
"markdown": 'This error occurs when the ONNX converter is unable to find a corresponding symbolic function\nto convert a "call_function" node in the input graph to its equivalence in ONNX. The "call_function"\nnode represents a normalized function call in PyTorch, such as "torch.aten.ops.add".\n\nTo resolve this error, you can try one of the following:\n\n- If exists, apply the auto-fix suggested by the diagnostic. TODO: this part is not available yet.\n- Rewrite the model using only supported PyTorch operators or functions.\n- Follow this [guide](https://pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html#overview) to write and\n register a custom symbolic function for the unsupported call_function FX node.\n',
},
"message_strings": {
"default": {
"text": 'No symbolic function to convert the "call_function" node {target} to ONNX. '
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Cannot find symbolic function to convert the "call_function" FX node to ONNX."""
unsupported_fx_node_analysis: _UnsupportedFxNodeAnalysis = dataclasses.field(
default=_UnsupportedFxNodeAnalysis.from_sarif(
**{
"id": "FXE0012",
"name": "unsupported-fx-node-analysis",
"short_description": {
"text": "Result from FX graph analysis to reveal unsupported FX nodes."
},
"full_description": {
"text": "Result from FX graph analysis to reveal unsupported FX nodes.",
"markdown": "This error indicates that an FX graph contains one or more unsupported nodes. The error message\nis typically accompanied by a list of the unsupported nodes found during analysis.\n\nTo resolve this error, you can try resolving each individual unsupported node error by following\nthe suggestions by its diagnostic. Typically, options include:\n\n- If exists, apply the auto-fix suggested by the diagnostic. TODO: this part is not available yet.\n- Rewrite the model using only supported PyTorch operators or functions.\n- Follow this [guide](https://pytorch.org/docs/stable/onnx.html#onnx-script-functions) to write and\n register a custom symbolic function for the unsupported call_function FX node.\n",
},
"message_strings": {
"default": {
"text": "Unsupported FX nodes: {node_op_to_target_mapping}. "
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Result from FX graph analysis to reveal unsupported FX nodes."""
op_level_debugging: _OpLevelDebugging = dataclasses.field(
default=_OpLevelDebugging.from_sarif(
**{
"id": "FXE0013",
"name": "op-level-debugging",
"short_description": {
"text": "Report any op level validation failure in warnings."
},
"full_description": {
"text": "Report any op level validation failure in warnings.",
"markdown": "This warning message indicates that during op level debugging, certain symbolic functions\nhave failed to match the results of torch ops when using real tensors generated from fake\ntensors. It is important to note that the symbolic functions may not necessarily be\nincorrect, as the validation process is non-deterministic and should only be used as a\nreference.\n\nThere are two categories of warnings that can be triggered:\n\n1. Non-validated operators:\n If the warnings are caused by the following errors, they can be disregarded by users,\n as these errors occur due to the non-deterministic nature of the validation. However,\n it is important to be aware that the operators have not been validated.\n\n - IndexError: Unsupported input arguments of randomized dimensions/indices(INT64).\n - RuntimeError: Unsupported input arguments for torch ops are generated.\n - ValueError: Arguments/keyword arguments do not match the signature of the symbolic function.\n\n2. Potentially wrong torchlib operators:\n If the warnings are triggered by the following error, users should be aware that the symbolic functions\n may be incorrect in dispatching or implementation. In such cases, it is recommended to report\n the issue to the PyTorch-ONNX team, or create/register a custom symbolic function to replace the default one.\n\n - AssertionError: The symbolic function is potentially wrong as the results do not match the results of torch ops.\n - TypeError: The symbolic function is potentially wrong as the opschema doesn't match inputs.\n",
},
"message_strings": {
"default": {
"text": "FX node: {node} and its onnx function: {symbolic_fn} fails on op level validation."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Report any op level validation failure in warnings."""
find_opschema_matched_symbolic_function: _FindOpschemaMatchedSymbolicFunction = dataclasses.field(
default=_FindOpschemaMatchedSymbolicFunction.from_sarif(
**{
"id": "FXE0014",
"name": "find-opschema-matched-symbolic-function",
"short_description": {
"text": "Find the OnnxFunction that matches the input/attribute dtypes by comparing them with their opschemas."
},
"full_description": {
"text": "Find the OnnxFunction that matches the input dtypes by comparing them with their opschemas. A warning will be issued if the matched OnnxFunction is not an exact match.",
"markdown": "When an ATen/Custom operator is registered and needs to be dispatched to an OnnxFunction, the input/attribute\ndtypes of the ATen/Custom operator are compared with the input/attribute dtypes of the OnnxFunction opschemas\nto find a match. However, if a perfect/exact match is not found, the dispatcher will attempt to find\nthe nearest match with the highest number of input/attribute dtypes matching the OnnxFunction opschemas, while\nissuing a warning.\n\nThere are two types of level that can be triggered in this rule:\n\n1. NOTE: A perfect match is found, and no warning is issued.\n2. WARNING: The matched OnnxFunction is not a perfect/exact match.\n\nHere are some suggestions based on the WARNING situation:\n\n1. If there are NO errors or mismatches in the results, it is safe to disregard this warning,\n as the definition of OnnxFunction schema is usually more stringent.\n2. If there are errors or mismatches in the results, it is recommended to:\n (a) Enable op_level_debugging to determine if the OnnxFunction might be incorrect.\n (b) Report the issue to the PyTorch-ONNX team.\n (c) Create/register a custom symbolic function to replace the default one.\n",
},
"message_strings": {
"default": {
"text": "The OnnxFunction: {symbolic_fn} is the nearest match of the node {node}."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Find the OnnxFunction that matches the input/attribute dtypes by comparing them with their opschemas."""
fx_node_insert_type_promotion: _FxNodeInsertTypePromotion = dataclasses.field(
default=_FxNodeInsertTypePromotion.from_sarif(
**{
"id": "FXE0015",
"name": "fx-node-insert-type-promotion",
"short_description": {
"text": "Determine if type promotion is required for the FX node. Insert cast nodes if needed."
},
"full_description": {
"text": "Determine if type promotion is required for the FX node. Insert cast nodes if needed.",
"markdown": "This diagnostic monitors the node-level type promotion insertion process. In PyTorch, there is an automatic process called implicit type promotion,\nwhere the input types of an operator are promoted to a common type. The determination of the common type is based on the type promotion rule specific to each operator.\nTo learn more about PyTorch's type promotion rules, refer to the [elementwise_dtypes doc](https://github.com/pytorch/pytorch/blob/f044613f78df713fb57f70c608483c9f10ad332e/torch/_prims_common/__init__.py#L1252-L1335)\nand [torch._refs ops](https://github.com/pytorch/pytorch/blob/a475ea4542dfe961c9d097e33ab5041f61c8c17f/torch/_refs/__init__.py#L484).\n\nHowever, implicit type promotion is not supported in ONNX. Therefore, to replicate the PyTorch behavior, we need to explicitly insert cast nodes.\nThis diagnostic tracks the process of node-level type promotion insertion.\n\nThe type promotion rules used by this process can be found in `torch/onnx/_internal/fx/passes/type_promotion.py.`\nTo update or add new type promotion rules, please refer to the [Note: Update type promotion rule] section.\n",
},
"message_strings": {
"default": {
"text": "Performing explicit type promotion for node {target}. "
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Determine if type promotion is required for the FX node. Insert cast nodes if needed."""
find_operator_overloads_in_onnx_registry: _FindOperatorOverloadsInOnnxRegistry = dataclasses.field(
default=_FindOperatorOverloadsInOnnxRegistry.from_sarif(
**{
"id": "FXE0016",
"name": "find-operator-overloads-in-onnx-registry",
"short_description": {
"text": "Find the list of OnnxFunction of the PyTorch operator in onnx registry."
},
"full_description": {
"text": "This rule involves finding the list of OnnxFunction for the PyTorch operator overload in the ONNX registry. If the operator overload is not supported but its default overload is, a warning will be issued. If both the operator overload and its default overload are not supported, an error will be issued.",
"markdown": "The operator overload name serves the purpose of verifying whether a PyTorch operator is registered in the ONNX registry.\nIf it's not found, the dispatcher takes a fallback approach and tries to locate the default overload of the PyTorch\noperator in the registry. If even the default overload is absent, it signifies that the operator is officially unsupported.\n\nThere are three types of level that can be triggered in this rule:\n\n1. NOTE: The op overload is supported.\n2. WARNING: The op overload is not supported, but it's default overload is supported.\n3. ERROR: The op overload is not supported, and it's default overload is also not supported.\n\nHere are some suggestions based on the WARNING situation:\n\n1. If there are NO errors or mismatches in the results, it is safe to disregard this warning.\n2. If there are errors or mismatches in the results, it is recommended to:\n (a) Enable op_level_debugging to determine if the OnnxFunction might be incorrect.\n (b) Report the unsupported overload to the PyTorch-ONNX team.\n (c) Create/register a custom symbolic function to replace the default one.\n\nHere are some suggestions based on the ERROR situation:\n\n1. Report the unsupported operator to the PyTorch-ONNX team.\n2. Create/register a custom symbolic function to replace the default one.\n",
},
"message_strings": {
"default": {
"text": "Checking if the FX node: {node} is supported in onnx registry."
}
},
"help_uri": None,
"properties": {"deprecated": False, "tags": []},
}
),
init=False,
)
"""Find the list of OnnxFunction of the PyTorch operator in onnx registry."""
rules = _POERules()

View File

@ -0,0 +1,34 @@
from ._infra import (
DiagnosticOptions,
Graph,
Invocation,
Level,
levels,
Location,
Rule,
RuleCollection,
Stack,
StackFrame,
Tag,
ThreadFlowLocation,
)
from .context import Diagnostic, DiagnosticContext, RuntimeErrorWithDiagnostic
__all__ = [
"Diagnostic",
"DiagnosticContext",
"DiagnosticOptions",
"Graph",
"Invocation",
"Level",
"levels",
"Location",
"Rule",
"RuleCollection",
"RuntimeErrorWithDiagnostic",
"Stack",
"StackFrame",
"Tag",
"ThreadFlowLocation",
]

View File

@ -0,0 +1,285 @@
# mypy: allow-untyped-defs
"""This file defines an additional layer of abstraction on top of the SARIF OM."""
from __future__ import annotations
import dataclasses
import enum
import logging
from typing import Mapping, Sequence
from torch.onnx._internal.diagnostics.infra import formatter, sarif
class Level(enum.IntEnum):
"""The level of a diagnostic.
This class is used to represent the level of a diagnostic. The levels are defined
by the SARIF specification, and are not modifiable. For alternative categories,
please use infra.Tag instead. When selecting a level, please consider the following
guidelines:
- NONE: Informational result that does not indicate the presence of a problem.
- NOTE: An opportunity for improvement was found.
- WARNING: A potential problem was found.
- ERROR: A serious problem was found.
This level is a subclass of enum.IntEnum, and can be used as an integer. Its integer
value maps to the logging levels in Python's logging module. The mapping is as
follows:
Level.NONE = logging.DEBUG = 10
Level.NOTE = logging.INFO = 20
Level.WARNING = logging.WARNING = 30
Level.ERROR = logging.ERROR = 40
"""
NONE = 10
NOTE = 20
WARNING = 30
ERROR = 40
levels = Level
class Tag(enum.Enum):
"""The tag of a diagnostic. This class can be inherited to define custom tags."""
class PatchedPropertyBag(sarif.PropertyBag):
"""Key/value pairs that provide additional information about the object.
The definition of PropertyBag via SARIF spec is "A property bag is an object (section 3.6)
containing an unordered set of properties with arbitrary names." However it is not
reflected in the json file, and therefore not captured by the python representation.
This patch adds additional **kwargs to the `__init__` method to allow recording
arbitrary key/value pairs.
"""
def __init__(self, tags: list[str] | None = None, **kwargs):
super().__init__(tags=tags)
self.__dict__.update(kwargs)
@dataclasses.dataclass(frozen=True)
class Rule:
id: str
name: str
message_default_template: str
short_description: str | None = None
full_description: str | None = None
full_description_markdown: str | None = None
help_uri: str | None = None
@classmethod
def from_sarif(cls, **kwargs):
"""Returns a rule from the SARIF reporting descriptor."""
short_description = kwargs.get("short_description", {}).get("text")
full_description = kwargs.get("full_description", {}).get("text")
full_description_markdown = kwargs.get("full_description", {}).get("markdown")
help_uri = kwargs.get("help_uri")
rule = cls(
id=kwargs["id"],
name=kwargs["name"],
message_default_template=kwargs["message_strings"]["default"]["text"],
short_description=short_description,
full_description=full_description,
full_description_markdown=full_description_markdown,
help_uri=help_uri,
)
return rule
def sarif(self) -> sarif.ReportingDescriptor:
"""Returns a SARIF reporting descriptor of this Rule."""
short_description = (
sarif.MultiformatMessageString(text=self.short_description)
if self.short_description is not None
else None
)
full_description = (
sarif.MultiformatMessageString(
text=self.full_description, markdown=self.full_description_markdown
)
if self.full_description is not None
else None
)
return sarif.ReportingDescriptor(
id=self.id,
name=self.name,
short_description=short_description,
full_description=full_description,
help_uri=self.help_uri,
)
def format(self, level: Level, *args, **kwargs) -> tuple[Rule, Level, str]:
"""Returns a tuple of (rule, level, message) for a diagnostic.
This method is used to format the message of a diagnostic. The message is
formatted using the default template of this rule, and the arguments passed in
as `*args` and `**kwargs`. The level is used to override the default level of
this rule.
"""
return (self, level, self.format_message(*args, **kwargs))
def format_message(self, *args, **kwargs) -> str:
"""Returns the formatted default message of this Rule.
This method should be overridden (with code generation) by subclasses to reflect
the exact arguments needed by the message template. This is a helper method to
create the default message for a diagnostic.
"""
return self.message_default_template.format(*args, **kwargs)
@dataclasses.dataclass
class Location:
uri: str | None = None
line: int | None = None
message: str | None = None
start_column: int | None = None
end_column: int | None = None
snippet: str | None = None
function: str | None = None
def sarif(self) -> sarif.Location:
"""Returns the SARIF representation of this location."""
return sarif.Location(
physical_location=sarif.PhysicalLocation(
artifact_location=sarif.ArtifactLocation(uri=self.uri),
region=sarif.Region(
start_line=self.line,
start_column=self.start_column,
end_column=self.end_column,
snippet=sarif.ArtifactContent(text=self.snippet),
),
),
message=sarif.Message(text=self.message)
if self.message is not None
else None,
)
@dataclasses.dataclass
class StackFrame:
location: Location
def sarif(self) -> sarif.StackFrame:
"""Returns the SARIF representation of this stack frame."""
return sarif.StackFrame(location=self.location.sarif())
@dataclasses.dataclass
class Stack:
"""Records a stack trace. The frames are in order from newest to oldest stack frame."""
frames: list[StackFrame] = dataclasses.field(default_factory=list)
message: str | None = None
def sarif(self) -> sarif.Stack:
"""Returns the SARIF representation of this stack."""
return sarif.Stack(
frames=[frame.sarif() for frame in self.frames],
message=sarif.Message(text=self.message)
if self.message is not None
else None,
)
@dataclasses.dataclass
class ThreadFlowLocation:
"""Records code location and the initial state."""
location: Location
state: Mapping[str, str]
index: int
stack: Stack | None = None
def sarif(self) -> sarif.ThreadFlowLocation:
"""Returns the SARIF representation of this thread flow location."""
return sarif.ThreadFlowLocation(
location=self.location.sarif(),
state=self.state,
stack=self.stack.sarif() if self.stack is not None else None,
)
@dataclasses.dataclass
class Graph:
"""A graph of diagnostics.
This class stores the string representation of a model graph.
The `nodes` and `edges` fields are unused in the current implementation.
"""
graph: str
name: str
description: str | None = None
def sarif(self) -> sarif.Graph:
"""Returns the SARIF representation of this graph."""
return sarif.Graph(
description=sarif.Message(text=self.graph),
properties=PatchedPropertyBag(name=self.name, description=self.description),
)
@dataclasses.dataclass
class RuleCollection:
_rule_id_name_set: frozenset[tuple[str, str]] = dataclasses.field(init=False)
def __post_init__(self) -> None:
self._rule_id_name_set = frozenset(
{
(field.default.id, field.default.name)
for field in dataclasses.fields(self)
if isinstance(field.default, Rule)
}
)
def __contains__(self, rule: Rule) -> bool:
"""Checks if the rule is in the collection."""
return (rule.id, rule.name) in self._rule_id_name_set
@classmethod
def custom_collection_from_list(
cls, new_collection_class_name: str, rules: Sequence[Rule]
) -> RuleCollection:
"""Creates a custom class inherited from RuleCollection with the list of rules."""
return dataclasses.make_dataclass(
new_collection_class_name,
[
(
formatter.kebab_case_to_snake_case(rule.name),
type(rule),
dataclasses.field(default=rule),
)
for rule in rules
],
bases=(cls,),
)()
class Invocation:
# TODO: Implement this.
# Tracks top level call arguments and diagnostic options.
def __init__(self) -> None:
raise NotImplementedError
@dataclasses.dataclass
class DiagnosticOptions:
"""Options for diagnostic context.
Attributes:
verbosity_level: Set the amount of information logged for each diagnostics,
equivalent to the 'level' in Python logging module.
warnings_as_errors: When True, warning diagnostics are treated as error diagnostics.
"""
verbosity_level: int = dataclasses.field(default=logging.INFO)
"""Set the amount of information logged for each diagnostics, equivalent to the 'level' in Python logging module."""
warnings_as_errors: bool = dataclasses.field(default=False)
"""If True, warning diagnostics are treated as error diagnostics."""

View File

@ -0,0 +1,404 @@
# mypy: allow-untyped-defs
"""A diagnostic context based on SARIF."""
from __future__ import annotations
import contextlib
import dataclasses
import gzip
import logging
from typing import Callable, Generator, Generic, Literal, Mapping, TypeVar
from typing_extensions import Self
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
# This is a workaround for mypy not supporting Self from typing_extensions.
_Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic")
diagnostic_logger: logging.Logger = logging.getLogger(__name__)
@dataclasses.dataclass
class Diagnostic:
rule: infra.Rule
level: infra.Level
message: str | None = None
locations: list[infra.Location] = dataclasses.field(default_factory=list)
stacks: list[infra.Stack] = dataclasses.field(default_factory=list)
graphs: list[infra.Graph] = dataclasses.field(default_factory=list)
thread_flow_locations: list[infra.ThreadFlowLocation] = dataclasses.field(
default_factory=list
)
additional_messages: list[str] = dataclasses.field(default_factory=list)
tags: list[infra.Tag] = dataclasses.field(default_factory=list)
source_exception: Exception | None = None
"""The exception that caused this diagnostic to be created."""
logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
"""The logger for this diagnostic. Defaults to 'diagnostic_logger' which has the same
log level setting with `DiagnosticOptions.verbosity_level`."""
_current_log_section_depth: int = 0
def __post_init__(self) -> None:
pass
def sarif(self) -> sarif.Result:
"""Returns the SARIF Result representation of this diagnostic."""
message = self.message or self.rule.message_default_template
if self.additional_messages:
additional_message = "\n".join(self.additional_messages)
message_markdown = (
f"{message}\n\n## Additional Message:\n\n{additional_message}"
)
else:
message_markdown = message
kind: Literal["informational", "fail"] = (
"informational" if self.level == infra.Level.NONE else "fail"
)
sarif_result = sarif.Result(
message=sarif.Message(text=message, markdown=message_markdown),
level=self.level.name.lower(), # type: ignore[arg-type]
rule_id=self.rule.id,
kind=kind,
)
sarif_result.locations = [location.sarif() for location in self.locations]
sarif_result.stacks = [stack.sarif() for stack in self.stacks]
sarif_result.graphs = [graph.sarif() for graph in self.graphs]
sarif_result.code_flows = [
sarif.CodeFlow(
thread_flows=[
sarif.ThreadFlow(
locations=[loc.sarif() for loc in self.thread_flow_locations]
)
]
)
]
sarif_result.properties = sarif.PropertyBag(
tags=[tag.value for tag in self.tags]
)
return sarif_result
def with_location(self: Self, location: infra.Location) -> Self:
"""Adds a location to the diagnostic."""
self.locations.append(location)
return self
def with_thread_flow_location(
self: Self, location: infra.ThreadFlowLocation
) -> Self:
"""Adds a thread flow location to the diagnostic."""
self.thread_flow_locations.append(location)
return self
def with_stack(self: Self, stack: infra.Stack) -> Self:
"""Adds a stack to the diagnostic."""
self.stacks.append(stack)
return self
def with_graph(self: Self, graph: infra.Graph) -> Self:
"""Adds a graph to the diagnostic."""
self.graphs.append(graph)
return self
@contextlib.contextmanager
def log_section(
self, level: int, message: str, *args, **kwargs
) -> Generator[None, None, None]:
"""
Context manager for a section of log messages, denoted by a title message and increased indentation.
Same api as `logging.Logger.log`.
This context manager logs the given title at the specified log level, increases the current
section depth for subsequent log messages, and ensures that the section depth is decreased
again when exiting the context.
Args:
level: The log level.
message: The title message to log.
*args: The arguments to the message. Use `LazyString` to defer the
expensive evaluation of the arguments until the message is actually logged.
**kwargs: The keyword arguments for `logging.Logger.log`.
Yields:
None: This context manager does not yield any value.
Example:
>>> with DiagnosticContext("DummyContext", "1.0"):
... rule = infra.Rule("RuleID", "DummyRule", "Rule message")
... diagnostic = Diagnostic(rule, infra.Level.WARNING)
... with diagnostic.log_section(logging.INFO, "My Section"):
... diagnostic.log(logging.INFO, "My Message")
... with diagnostic.log_section(logging.INFO, "My Subsection"):
... diagnostic.log(logging.INFO, "My Submessage")
... diagnostic.additional_messages
['## My Section', 'My Message', '### My Subsection', 'My Submessage']
"""
if self.logger.isEnabledFor(level):
indented_format_message = (
f"##{'#' * self._current_log_section_depth } {message}"
)
self.log(
level,
indented_format_message,
*args,
**kwargs,
)
self._current_log_section_depth += 1
try:
yield
finally:
self._current_log_section_depth -= 1
def log(self, level: int, message: str, *args, **kwargs) -> None:
"""Logs a message within the diagnostic. Same api as `logging.Logger.log`.
If logger is not enabled for the given level, the message will not be logged.
Otherwise, the message will be logged and also added to the diagnostic's additional_messages.
The default setting for `DiagnosticOptions.verbosity_level` is `logging.INFO`. Based on this default,
the log level recommendations are as follows. If you've set a different default verbosity level in your
application, please adjust accordingly:
- logging.ERROR: Log any events leading to application failure.
- logging.WARNING: Log events that might result in application issues or failures, although not guaranteed.
- logging.INFO: Log general useful information, ensuring minimal performance overhead.
- logging.DEBUG: Log detailed debug information, which might affect performance when logged.
Args:
level: The log level.
message: The message to log.
*args: The arguments to the message. Use `LazyString` to defer the
expensive evaluation of the arguments until the message is actually logged.
**kwargs: The keyword arguments for `logging.Logger.log`.
"""
if self.logger.isEnabledFor(level):
formatted_message = message % args
self.logger.log(level, formatted_message, **kwargs)
self.additional_messages.append(formatted_message)
def debug(self, message: str, *args, **kwargs) -> None:
"""Logs a debug message within the diagnostic. Same api as logging.Logger.debug.
Checkout `log` for more details.
"""
self.log(logging.DEBUG, message, *args, **kwargs)
def info(self, message: str, *args, **kwargs) -> None:
"""Logs an info message within the diagnostic. Same api as logging.Logger.info.
Checkout `log` for more details.
"""
self.log(logging.INFO, message, *args, **kwargs)
def warning(self, message: str, *args, **kwargs) -> None:
"""Logs a warning message within the diagnostic. Same api as logging.Logger.warning.
Checkout `log` for more details.
"""
self.log(logging.WARNING, message, *args, **kwargs)
def error(self, message: str, *args, **kwargs) -> None:
"""Logs an error message within the diagnostic. Same api as logging.Logger.error.
Checkout `log` for more details.
"""
self.log(logging.ERROR, message, *args, **kwargs)
def log_source_exception(self, level: int, exception: Exception) -> None:
"""Logs a source exception within the diagnostic.
Invokes `log_section` and `log` to log the exception in markdown section format.
"""
self.source_exception = exception
with self.log_section(level, "Exception log"):
self.log(level, "%s", formatter.lazy_format_exception(exception))
def record_python_call_stack(self, frames_to_skip: int) -> infra.Stack:
"""Records the current Python call stack."""
frames_to_skip += 1 # Skip this function.
stack = utils.python_call_stack(frames_to_skip=frames_to_skip)
self.with_stack(stack)
if len(stack.frames) > 0:
self.with_location(stack.frames[0].location)
return stack
def record_python_call(
self,
fn: Callable,
state: Mapping[str, str],
message: str | None = None,
frames_to_skip: int = 0,
) -> infra.ThreadFlowLocation:
"""Records a python call as one thread flow step."""
frames_to_skip += 1 # Skip this function.
stack = utils.python_call_stack(frames_to_skip=frames_to_skip, frames_to_log=5)
location = utils.function_location(fn)
location.message = message
# Add function location to the top of the stack.
stack.frames.insert(0, infra.StackFrame(location=location))
thread_flow_location = infra.ThreadFlowLocation(
location=location,
state=state,
index=len(self.thread_flow_locations),
stack=stack,
)
self.with_thread_flow_location(thread_flow_location)
return thread_flow_location
class RuntimeErrorWithDiagnostic(RuntimeError):
"""Runtime error with enclosed diagnostic information."""
def __init__(self, diagnostic: Diagnostic):
super().__init__(diagnostic.message)
self.diagnostic = diagnostic
@dataclasses.dataclass
class DiagnosticContext(Generic[_Diagnostic]):
name: str
version: str
options: infra.DiagnosticOptions = dataclasses.field(
default_factory=infra.DiagnosticOptions
)
diagnostics: list[_Diagnostic] = dataclasses.field(init=False, default_factory=list)
# TODO(bowbao): Implement this.
# _invocation: infra.Invocation = dataclasses.field(init=False)
_inflight_diagnostics: list[_Diagnostic] = dataclasses.field(
init=False, default_factory=list
)
_previous_log_level: int = dataclasses.field(init=False, default=logging.WARNING)
logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
_bound_diagnostic_type: type = dataclasses.field(init=False, default=Diagnostic)
def __enter__(self):
self._previous_log_level = self.logger.level
self.logger.setLevel(self.options.verbosity_level)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.logger.setLevel(self._previous_log_level)
return None
def sarif(self) -> sarif.Run:
"""Returns the SARIF Run object."""
unique_rules = {diagnostic.rule for diagnostic in self.diagnostics}
return sarif.Run(
sarif.Tool(
driver=sarif.ToolComponent(
name=self.name,
version=self.version,
rules=[rule.sarif() for rule in unique_rules],
)
),
results=[diagnostic.sarif() for diagnostic in self.diagnostics],
)
def sarif_log(self) -> sarif.SarifLog: # type: ignore[name-defined]
"""Returns the SARIF Log object."""
return sarif.SarifLog(
version=sarif_version.SARIF_VERSION,
schema_uri=sarif_version.SARIF_SCHEMA_LINK,
runs=[self.sarif()],
)
def to_json(self) -> str:
return formatter.sarif_to_json(self.sarif_log())
def dump(self, file_path: str, compress: bool = False) -> None:
"""Dumps the SARIF log to a file."""
if compress:
with gzip.open(file_path, "wt") as f:
f.write(self.to_json())
else:
with open(file_path, "w") as f:
f.write(self.to_json())
def log(self, diagnostic: _Diagnostic) -> None:
"""Logs a diagnostic.
This method should be used only after all the necessary information for the diagnostic
has been collected.
Args:
diagnostic: The diagnostic to add.
"""
if not isinstance(diagnostic, self._bound_diagnostic_type):
raise TypeError(
f"Expected diagnostic of type {self._bound_diagnostic_type}, got {type(diagnostic)}"
)
if self.options.warnings_as_errors and diagnostic.level == infra.Level.WARNING: # type: ignore[attr-defined]
diagnostic.level = infra.Level.ERROR # type: ignore[attr-defined]
self.diagnostics.append(diagnostic) # type: ignore[arg-type]
def log_and_raise_if_error(self, diagnostic: _Diagnostic) -> None:
"""Logs a diagnostic and raises an exception if it is an error.
Use this method for logging non inflight diagnostics where diagnostic level is not known or
lower than ERROR. If it is always expected raise, use `log` and explicit
`raise` instead. Otherwise there is no way to convey the message that it always
raises to Python intellisense and type checking tools.
This method should be used only after all the necessary information for the diagnostic
has been collected.
Args:
diagnostic: The diagnostic to add.
"""
self.log(diagnostic)
if diagnostic.level == infra.Level.ERROR:
if diagnostic.source_exception is not None:
raise diagnostic.source_exception
raise RuntimeErrorWithDiagnostic(diagnostic)
@contextlib.contextmanager
def add_inflight_diagnostic(
self, diagnostic: _Diagnostic
) -> Generator[_Diagnostic, None, None]:
"""Adds a diagnostic to the context.
Use this method to add diagnostics that are not created by the context.
Args:
diagnostic: The diagnostic to add.
"""
self._inflight_diagnostics.append(diagnostic)
try:
yield diagnostic
finally:
self._inflight_diagnostics.pop()
def push_inflight_diagnostic(self, diagnostic: _Diagnostic) -> None:
"""Pushes a diagnostic to the inflight diagnostics stack.
Args:
diagnostic: The diagnostic to push.
Raises:
ValueError: If the rule is not supported by the tool.
"""
self._inflight_diagnostics.append(diagnostic)
def pop_inflight_diagnostic(self) -> _Diagnostic:
"""Pops the last diagnostic from the inflight diagnostics stack.
Returns:
The popped diagnostic.
"""
return self._inflight_diagnostics.pop()
def inflight_diagnostic(self, rule: infra.Rule | None = None) -> _Diagnostic:
if rule is None:
# TODO(bowbao): Create builtin-rules and create diagnostic using that.
if len(self._inflight_diagnostics) <= 0:
raise AssertionError("No inflight diagnostics")
return self._inflight_diagnostics[-1]
else:
for diagnostic in reversed(self._inflight_diagnostics):
if diagnostic.rule == rule:
return diagnostic
raise AssertionError(f"No inflight diagnostic for rule {rule.name}")

View File

@ -0,0 +1,153 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import functools
import logging
import traceback
from typing import Any, Callable, Dict, Tuple
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, utils
MessageFormatterType = Callable[..., str]
def format_message_in_text(fn: Callable, *args: Any, **kwargs: Any) -> str:
return f"{formatter.display_name(fn)}. "
def format_exception_in_markdown(exception: Exception) -> str:
msg_list = ["### Exception log", "```"]
msg_list.extend(
traceback.format_exception(type(exception), exception, exception.__traceback__)
)
msg_list.append("```")
return "\n".join(msg_list)
def format_function_signature_in_markdown(
fn: Callable,
args: tuple[Any, ...],
kwargs: dict[str, Any],
format_argument: Callable[[Any], str] = formatter.format_argument,
) -> str:
msg_list = [f"### Function Signature {formatter.display_name(fn)}"]
state = utils.function_state(fn, args, kwargs)
for k, v in state.items():
msg_list.append(f"- {k}: {format_argument(v)}")
return "\n".join(msg_list)
def format_return_values_in_markdown(
return_values: Any,
format_argument: Callable[[Any], str] = formatter.format_argument,
) -> str:
return f"{format_argument(return_values)}"
ModifierCallableType = Callable[
[infra.Diagnostic, Callable, Tuple[Any, ...], Dict[str, Any], Any], None
]
def diagnose_call(
rule: infra.Rule,
*,
level: infra.Level = infra.Level.NONE,
diagnostic_type: type[infra.Diagnostic] = infra.Diagnostic,
format_argument: Callable[[Any], str] = formatter.format_argument,
diagnostic_message_formatter: MessageFormatterType = format_message_in_text,
) -> Callable:
def decorator(fn):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
common_error_message = "diagnose_call can only be applied to callables"
if not callable(fn):
raise AssertionError(
f"{common_error_message}. Got {type(fn)} instead of callable."
)
arg0 = args[0] if len(args) > 0 else None
if isinstance(ctx := arg0, infra.DiagnosticContext):
pass
elif isinstance(
ctx := getattr(arg0, "diagnostic_context", None),
infra.DiagnosticContext,
):
pass
else:
# NOTE: At decorate time, it can't tell if a callable is function or method.
# Technically both are regarded as function at that time.
raise AssertionError(
f"{common_error_message}. For {fn}, "
f"If it is a function, a DiagnosticContext instance must be present as "
f"the first argument. "
f"If it is a method, a DiagnosticContext instance must be present as "
f"the attribute 'diagnostic_context' of the 'self' argument."
)
diag = diagnostic_type(
rule,
level,
diagnostic_message_formatter(fn, *args, **kwargs),
)
# pop the decorator frame
# TODO(bowbao): by default diagnostic doesn't have stack.
# So need to check before doing this. Make the code cleaner.
# Option: do not capture stack by default in diagnostic initialization.
stack: infra.Stack | None = None
if len(diag.stacks) > 0:
stack = diag.stacks[0]
stack.frames.pop(0)
# set function location
fn_location = utils.function_location(fn)
diag.locations.insert(0, fn_location)
# Add function location to the top of the stack.
if stack is not None:
stack.frames.insert(0, infra.StackFrame(location=fn_location))
with diag.log_section(logging.INFO, "Function Signature"):
diag.log(
logging.INFO,
"%s",
formatter.LazyString(
format_function_signature_in_markdown,
fn,
args,
kwargs,
format_argument,
),
)
return_values: Any = None
with ctx.add_inflight_diagnostic(diag) as diag:
try:
return_values = fn(*args, **kwargs)
with diag.log_section(logging.INFO, "Return values"):
diag.log(
logging.INFO,
"%s",
formatter.LazyString(
format_return_values_in_markdown,
return_values,
format_argument,
),
)
return return_values
except Exception as e:
diag.log_source_exception(logging.ERROR, e)
diag.level = infra.Level.ERROR
finally:
ctx.log_and_raise_if_error(diag)
return wrapper
return decorator
# TODO(bowbao): decorator to report only when failed.

View File

@ -0,0 +1,106 @@
from __future__ import annotations
import dataclasses
import json
import re
import traceback
from typing import Any, Callable, Union
from torch._logging import LazyString
from torch.onnx._internal.diagnostics.infra import sarif
# A list of types in the SARIF module to support pretty printing.
# This is solely for type annotation for the functions below.
_SarifClass = Union[
sarif.SarifLog,
sarif.Run,
sarif.ReportingDescriptor,
sarif.Result,
]
def lazy_format_exception(exception: Exception) -> LazyString:
return LazyString(
lambda: "\n".join(
(
"```",
*traceback.format_exception(
type(exception), exception, exception.__traceback__
),
"```",
)
),
)
def snake_case_to_camel_case(s: str) -> str:
splits = s.split("_")
if len(splits) <= 1:
return s
return "".join([splits[0], *map(str.capitalize, splits[1:])])
def camel_case_to_snake_case(s: str) -> str:
return re.sub(r"([A-Z])", r"_\1", s).lower()
def kebab_case_to_snake_case(s: str) -> str:
return s.replace("-", "_")
def _convert_key(
object: dict[str, Any] | Any, convert: Callable[[str], str]
) -> dict[str, Any] | Any:
"""Convert and update keys in a dictionary with "convert".
Any value that is a dictionary will be recursively updated.
Any value that is a list will be recursively searched.
Args:
object: The object to update.
convert: The function to convert the keys, e.g. `kebab_case_to_snake_case`.
Returns:
The updated object.
"""
if not isinstance(object, dict):
return object
new_dict = {}
for k, v in object.items():
new_k = convert(k)
if isinstance(v, dict):
new_v = _convert_key(v, convert)
elif isinstance(v, list):
new_v = [_convert_key(elem, convert) for elem in v]
else:
new_v = v
if new_v is None:
# Otherwise unnecessarily bloated sarif log with "null"s.
continue
if new_v == -1:
# WAR: -1 as default value shouldn't be logged into sarif.
continue
new_dict[new_k] = new_v
return new_dict
def sarif_to_json(attr_cls_obj: _SarifClass, indent: str | None = " ") -> str:
dict = dataclasses.asdict(attr_cls_obj)
dict = _convert_key(dict, snake_case_to_camel_case)
return json.dumps(dict, indent=indent, separators=(",", ":"))
def format_argument(obj: Any) -> str:
return f"{type(obj)}"
def display_name(fn: Callable) -> str:
if hasattr(fn, "__qualname__"):
return fn.__qualname__
elif hasattr(fn, "__name__"):
return fn.__name__
else:
return str(fn)

View File

@ -0,0 +1,101 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from torch.onnx._internal.diagnostics.infra.sarif._address import Address
from torch.onnx._internal.diagnostics.infra.sarif._artifact import Artifact
from torch.onnx._internal.diagnostics.infra.sarif._artifact_change import ArtifactChange
from torch.onnx._internal.diagnostics.infra.sarif._artifact_content import (
ArtifactContent,
)
from torch.onnx._internal.diagnostics.infra.sarif._artifact_location import (
ArtifactLocation,
)
from torch.onnx._internal.diagnostics.infra.sarif._attachment import Attachment
from torch.onnx._internal.diagnostics.infra.sarif._code_flow import CodeFlow
from torch.onnx._internal.diagnostics.infra.sarif._configuration_override import (
ConfigurationOverride,
)
from torch.onnx._internal.diagnostics.infra.sarif._conversion import Conversion
from torch.onnx._internal.diagnostics.infra.sarif._edge import Edge
from torch.onnx._internal.diagnostics.infra.sarif._edge_traversal import EdgeTraversal
from torch.onnx._internal.diagnostics.infra.sarif._exception import Exception
from torch.onnx._internal.diagnostics.infra.sarif._external_properties import (
ExternalProperties,
)
from torch.onnx._internal.diagnostics.infra.sarif._external_property_file_reference import (
ExternalPropertyFileReference,
)
from torch.onnx._internal.diagnostics.infra.sarif._external_property_file_references import (
ExternalPropertyFileReferences,
)
from torch.onnx._internal.diagnostics.infra.sarif._fix import Fix
from torch.onnx._internal.diagnostics.infra.sarif._graph import Graph
from torch.onnx._internal.diagnostics.infra.sarif._graph_traversal import GraphTraversal
from torch.onnx._internal.diagnostics.infra.sarif._invocation import Invocation
from torch.onnx._internal.diagnostics.infra.sarif._location import Location
from torch.onnx._internal.diagnostics.infra.sarif._location_relationship import (
LocationRelationship,
)
from torch.onnx._internal.diagnostics.infra.sarif._logical_location import (
LogicalLocation,
)
from torch.onnx._internal.diagnostics.infra.sarif._message import Message
from torch.onnx._internal.diagnostics.infra.sarif._multiformat_message_string import (
MultiformatMessageString,
)
from torch.onnx._internal.diagnostics.infra.sarif._node import Node
from torch.onnx._internal.diagnostics.infra.sarif._notification import Notification
from torch.onnx._internal.diagnostics.infra.sarif._physical_location import (
PhysicalLocation,
)
from torch.onnx._internal.diagnostics.infra.sarif._property_bag import PropertyBag
from torch.onnx._internal.diagnostics.infra.sarif._rectangle import Rectangle
from torch.onnx._internal.diagnostics.infra.sarif._region import Region
from torch.onnx._internal.diagnostics.infra.sarif._replacement import Replacement
from torch.onnx._internal.diagnostics.infra.sarif._reporting_configuration import (
ReportingConfiguration,
)
from torch.onnx._internal.diagnostics.infra.sarif._reporting_descriptor import (
ReportingDescriptor,
)
from torch.onnx._internal.diagnostics.infra.sarif._reporting_descriptor_reference import (
ReportingDescriptorReference,
)
from torch.onnx._internal.diagnostics.infra.sarif._reporting_descriptor_relationship import (
ReportingDescriptorRelationship,
)
from torch.onnx._internal.diagnostics.infra.sarif._result import Result
from torch.onnx._internal.diagnostics.infra.sarif._result_provenance import (
ResultProvenance,
)
from torch.onnx._internal.diagnostics.infra.sarif._run import Run
from torch.onnx._internal.diagnostics.infra.sarif._run_automation_details import (
RunAutomationDetails,
)
from torch.onnx._internal.diagnostics.infra.sarif._sarif_log import SarifLog
from torch.onnx._internal.diagnostics.infra.sarif._special_locations import (
SpecialLocations,
)
from torch.onnx._internal.diagnostics.infra.sarif._stack import Stack
from torch.onnx._internal.diagnostics.infra.sarif._stack_frame import StackFrame
from torch.onnx._internal.diagnostics.infra.sarif._suppression import Suppression
from torch.onnx._internal.diagnostics.infra.sarif._thread_flow import ThreadFlow
from torch.onnx._internal.diagnostics.infra.sarif._thread_flow_location import (
ThreadFlowLocation,
)
from torch.onnx._internal.diagnostics.infra.sarif._tool import Tool
from torch.onnx._internal.diagnostics.infra.sarif._tool_component import ToolComponent
from torch.onnx._internal.diagnostics.infra.sarif._tool_component_reference import (
ToolComponentReference,
)
from torch.onnx._internal.diagnostics.infra.sarif._translation_metadata import (
TranslationMetadata,
)
from torch.onnx._internal.diagnostics.infra.sarif._version_control_details import (
VersionControlDetails,
)
from torch.onnx._internal.diagnostics.infra.sarif._web_request import WebRequest
from torch.onnx._internal.diagnostics.infra.sarif._web_response import WebResponse
# flake8: noqa

View File

@ -0,0 +1,48 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import Optional
from torch.onnx._internal.diagnostics.infra.sarif import _property_bag
@dataclasses.dataclass
class Address(object):
"""A physical or virtual address, or a range of addresses, in an 'addressable region' (memory or a binary file)."""
absolute_address: int = dataclasses.field(
default=-1, metadata={"schema_property_name": "absoluteAddress"}
)
fully_qualified_name: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "fullyQualifiedName"}
)
index: int = dataclasses.field(
default=-1, metadata={"schema_property_name": "index"}
)
kind: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "kind"}
)
length: Optional[int] = dataclasses.field(
default=None, metadata={"schema_property_name": "length"}
)
name: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "name"}
)
offset_from_parent: Optional[int] = dataclasses.field(
default=None, metadata={"schema_property_name": "offsetFromParent"}
)
parent_index: int = dataclasses.field(
default=-1, metadata={"schema_property_name": "parentIndex"}
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
relative_address: Optional[int] = dataclasses.field(
default=None, metadata={"schema_property_name": "relativeAddress"}
)
# flake8: noqa

View File

@ -0,0 +1,88 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import Any, List, Literal, Optional
from torch.onnx._internal.diagnostics.infra.sarif import (
_artifact_content,
_artifact_location,
_message,
_property_bag,
)
@dataclasses.dataclass
class Artifact(object):
"""A single artifact. In some cases, this artifact might be nested within another artifact."""
contents: Optional[_artifact_content.ArtifactContent] = dataclasses.field(
default=None, metadata={"schema_property_name": "contents"}
)
description: Optional[_message.Message] = dataclasses.field(
default=None, metadata={"schema_property_name": "description"}
)
encoding: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "encoding"}
)
hashes: Any = dataclasses.field(
default=None, metadata={"schema_property_name": "hashes"}
)
last_modified_time_utc: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "lastModifiedTimeUtc"}
)
length: int = dataclasses.field(
default=-1, metadata={"schema_property_name": "length"}
)
location: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
default=None, metadata={"schema_property_name": "location"}
)
mime_type: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "mimeType"}
)
offset: Optional[int] = dataclasses.field(
default=None, metadata={"schema_property_name": "offset"}
)
parent_index: int = dataclasses.field(
default=-1, metadata={"schema_property_name": "parentIndex"}
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
roles: Optional[
List[
Literal[
"analysisTarget",
"attachment",
"responseFile",
"resultFile",
"standardStream",
"tracedFile",
"unmodified",
"modified",
"added",
"deleted",
"renamed",
"uncontrolled",
"driver",
"extension",
"translation",
"taxonomy",
"policy",
"referencedOnCommandLine",
"memoryContents",
"directory",
"userSpecifiedConfiguration",
"toolSpecifiedConfiguration",
"debugOutputFile",
]
]
] = dataclasses.field(default=None, metadata={"schema_property_name": "roles"})
source_language: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "sourceLanguage"}
)
# flake8: noqa

View File

@ -0,0 +1,31 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import List, Optional
from torch.onnx._internal.diagnostics.infra.sarif import (
_artifact_location,
_property_bag,
_replacement,
)
@dataclasses.dataclass
class ArtifactChange(object):
"""A change to a single artifact."""
artifact_location: _artifact_location.ArtifactLocation = dataclasses.field(
metadata={"schema_property_name": "artifactLocation"}
)
replacements: List[_replacement.Replacement] = dataclasses.field(
metadata={"schema_property_name": "replacements"}
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
# flake8: noqa

View File

@ -0,0 +1,33 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import Optional
from torch.onnx._internal.diagnostics.infra.sarif import (
_multiformat_message_string,
_property_bag,
)
@dataclasses.dataclass
class ArtifactContent(object):
"""Represents the contents of an artifact."""
binary: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "binary"}
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
rendered: Optional[_multiformat_message_string.MultiformatMessageString] = (
dataclasses.field(default=None, metadata={"schema_property_name": "rendered"})
)
text: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "text"}
)
# flake8: noqa

View File

@ -0,0 +1,33 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import Optional
from torch.onnx._internal.diagnostics.infra.sarif import _message, _property_bag
@dataclasses.dataclass
class ArtifactLocation(object):
"""Specifies the location of an artifact."""
description: Optional[_message.Message] = dataclasses.field(
default=None, metadata={"schema_property_name": "description"}
)
index: int = dataclasses.field(
default=-1, metadata={"schema_property_name": "index"}
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
uri: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "uri"}
)
uri_base_id: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "uriBaseId"}
)
# flake8: noqa

View File

@ -0,0 +1,39 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import List, Optional
from torch.onnx._internal.diagnostics.infra.sarif import (
_artifact_location,
_message,
_property_bag,
_rectangle,
_region,
)
@dataclasses.dataclass
class Attachment(object):
"""An artifact relevant to a result."""
artifact_location: _artifact_location.ArtifactLocation = dataclasses.field(
metadata={"schema_property_name": "artifactLocation"}
)
description: Optional[_message.Message] = dataclasses.field(
default=None, metadata={"schema_property_name": "description"}
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
rectangles: Optional[List[_rectangle.Rectangle]] = dataclasses.field(
default=None, metadata={"schema_property_name": "rectangles"}
)
regions: Optional[List[_region.Region]] = dataclasses.field(
default=None, metadata={"schema_property_name": "regions"}
)
# flake8: noqa

View File

@ -0,0 +1,31 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import List, Optional
from torch.onnx._internal.diagnostics.infra.sarif import (
_message,
_property_bag,
_thread_flow,
)
@dataclasses.dataclass
class CodeFlow(object):
"""A set of threadFlows which together describe a pattern of code execution relevant to detecting a result."""
thread_flows: List[_thread_flow.ThreadFlow] = dataclasses.field(
metadata={"schema_property_name": "threadFlows"}
)
message: Optional[_message.Message] = dataclasses.field(
default=None, metadata={"schema_property_name": "message"}
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
# flake8: noqa

View File

@ -0,0 +1,31 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import Optional
from torch.onnx._internal.diagnostics.infra.sarif import (
_property_bag,
_reporting_configuration,
_reporting_descriptor_reference,
)
@dataclasses.dataclass
class ConfigurationOverride(object):
"""Information about how a specific rule or notification was reconfigured at runtime."""
configuration: _reporting_configuration.ReportingConfiguration = dataclasses.field(
metadata={"schema_property_name": "configuration"}
)
descriptor: _reporting_descriptor_reference.ReportingDescriptorReference = (
dataclasses.field(metadata={"schema_property_name": "descriptor"})
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
# flake8: noqa

View File

@ -0,0 +1,35 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import List, Optional
from torch.onnx._internal.diagnostics.infra.sarif import (
_artifact_location,
_invocation,
_property_bag,
_tool,
)
@dataclasses.dataclass
class Conversion(object):
"""Describes how a converter transformed the output of a static analysis tool from the analysis tool's native output format into the SARIF format."""
tool: _tool.Tool = dataclasses.field(metadata={"schema_property_name": "tool"})
analysis_tool_log_files: Optional[List[_artifact_location.ArtifactLocation]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "analysisToolLogFiles"}
)
)
invocation: Optional[_invocation.Invocation] = dataclasses.field(
default=None, metadata={"schema_property_name": "invocation"}
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
# flake8: noqa

View File

@ -0,0 +1,31 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import Optional
from torch.onnx._internal.diagnostics.infra.sarif import _message, _property_bag
@dataclasses.dataclass
class Edge(object):
"""Represents a directed edge in a graph."""
id: str = dataclasses.field(metadata={"schema_property_name": "id"})
source_node_id: str = dataclasses.field(
metadata={"schema_property_name": "sourceNodeId"}
)
target_node_id: str = dataclasses.field(
metadata={"schema_property_name": "targetNodeId"}
)
label: Optional[_message.Message] = dataclasses.field(
default=None, metadata={"schema_property_name": "label"}
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
# flake8: noqa

View File

@ -0,0 +1,31 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import Any, Optional
from torch.onnx._internal.diagnostics.infra.sarif import _message, _property_bag
@dataclasses.dataclass
class EdgeTraversal(object):
"""Represents the traversal of a single edge during a graph traversal."""
edge_id: str = dataclasses.field(metadata={"schema_property_name": "edgeId"})
final_state: Any = dataclasses.field(
default=None, metadata={"schema_property_name": "finalState"}
)
message: Optional[_message.Message] = dataclasses.field(
default=None, metadata={"schema_property_name": "message"}
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
step_over_edge_count: Optional[int] = dataclasses.field(
default=None, metadata={"schema_property_name": "stepOverEdgeCount"}
)
# flake8: noqa

View File

@ -0,0 +1,37 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import List, Optional
from torch.onnx._internal.diagnostics.infra.sarif import (
_exception,
_property_bag,
_stack,
)
@dataclasses.dataclass
class Exception(object):
"""Describes a runtime exception encountered during the execution of an analysis tool."""
inner_exceptions: Optional[List[_exception.Exception]] = dataclasses.field(
default=None, metadata={"schema_property_name": "innerExceptions"}
)
kind: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "kind"}
)
message: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "message"}
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
stack: Optional[_stack.Stack] = dataclasses.field(
default=None, metadata={"schema_property_name": "stack"}
)
# flake8: noqa

View File

@ -0,0 +1,98 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import List, Literal, Optional
from torch.onnx._internal.diagnostics.infra.sarif import (
_address,
_artifact,
_conversion,
_graph,
_invocation,
_logical_location,
_property_bag,
_result,
_thread_flow_location,
_tool_component,
_web_request,
_web_response,
)
@dataclasses.dataclass
class ExternalProperties(object):
"""The top-level element of an external property file."""
addresses: Optional[List[_address.Address]] = dataclasses.field(
default=None, metadata={"schema_property_name": "addresses"}
)
artifacts: Optional[List[_artifact.Artifact]] = dataclasses.field(
default=None, metadata={"schema_property_name": "artifacts"}
)
conversion: Optional[_conversion.Conversion] = dataclasses.field(
default=None, metadata={"schema_property_name": "conversion"}
)
driver: Optional[_tool_component.ToolComponent] = dataclasses.field(
default=None, metadata={"schema_property_name": "driver"}
)
extensions: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
default=None, metadata={"schema_property_name": "extensions"}
)
externalized_properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "externalizedProperties"}
)
graphs: Optional[List[_graph.Graph]] = dataclasses.field(
default=None, metadata={"schema_property_name": "graphs"}
)
guid: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "guid"}
)
invocations: Optional[List[_invocation.Invocation]] = dataclasses.field(
default=None, metadata={"schema_property_name": "invocations"}
)
logical_locations: Optional[List[_logical_location.LogicalLocation]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "logicalLocations"}
)
)
policies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
default=None, metadata={"schema_property_name": "policies"}
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
results: Optional[List[_result.Result]] = dataclasses.field(
default=None, metadata={"schema_property_name": "results"}
)
run_guid: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "runGuid"}
)
schema: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "schema"}
)
taxonomies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
default=None, metadata={"schema_property_name": "taxonomies"}
)
thread_flow_locations: Optional[List[_thread_flow_location.ThreadFlowLocation]] = (
dataclasses.field(
default=None, metadata={"schema_property_name": "threadFlowLocations"}
)
)
translations: Optional[List[_tool_component.ToolComponent]] = dataclasses.field(
default=None, metadata={"schema_property_name": "translations"}
)
version: Optional[Literal["2.1.0"]] = dataclasses.field(
default=None, metadata={"schema_property_name": "version"}
)
web_requests: Optional[List[_web_request.WebRequest]] = dataclasses.field(
default=None, metadata={"schema_property_name": "webRequests"}
)
web_responses: Optional[List[_web_response.WebResponse]] = dataclasses.field(
default=None, metadata={"schema_property_name": "webResponses"}
)
# flake8: noqa

View File

@ -0,0 +1,33 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import Optional
from torch.onnx._internal.diagnostics.infra.sarif import (
_artifact_location,
_property_bag,
)
@dataclasses.dataclass
class ExternalPropertyFileReference(object):
"""Contains information that enables a SARIF consumer to locate the external property file that contains the value of an externalized property associated with the run."""
guid: Optional[str] = dataclasses.field(
default=None, metadata={"schema_property_name": "guid"}
)
item_count: int = dataclasses.field(
default=-1, metadata={"schema_property_name": "itemCount"}
)
location: Optional[_artifact_location.ArtifactLocation] = dataclasses.field(
default=None, metadata={"schema_property_name": "location"}
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
# flake8: noqa

View File

@ -0,0 +1,86 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import List, Optional
from torch.onnx._internal.diagnostics.infra.sarif import (
_external_property_file_reference,
_property_bag,
)
@dataclasses.dataclass
class ExternalPropertyFileReferences(object):
"""References to external property files that should be inlined with the content of a root log file."""
addresses: Optional[
List[_external_property_file_reference.ExternalPropertyFileReference]
] = dataclasses.field(default=None, metadata={"schema_property_name": "addresses"})
artifacts: Optional[
List[_external_property_file_reference.ExternalPropertyFileReference]
] = dataclasses.field(default=None, metadata={"schema_property_name": "artifacts"})
conversion: Optional[
_external_property_file_reference.ExternalPropertyFileReference
] = dataclasses.field(default=None, metadata={"schema_property_name": "conversion"})
driver: Optional[
_external_property_file_reference.ExternalPropertyFileReference
] = dataclasses.field(default=None, metadata={"schema_property_name": "driver"})
extensions: Optional[
List[_external_property_file_reference.ExternalPropertyFileReference]
] = dataclasses.field(default=None, metadata={"schema_property_name": "extensions"})
externalized_properties: Optional[
_external_property_file_reference.ExternalPropertyFileReference
] = dataclasses.field(
default=None, metadata={"schema_property_name": "externalizedProperties"}
)
graphs: Optional[
List[_external_property_file_reference.ExternalPropertyFileReference]
] = dataclasses.field(default=None, metadata={"schema_property_name": "graphs"})
invocations: Optional[
List[_external_property_file_reference.ExternalPropertyFileReference]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "invocations"}
)
logical_locations: Optional[
List[_external_property_file_reference.ExternalPropertyFileReference]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "logicalLocations"}
)
policies: Optional[
List[_external_property_file_reference.ExternalPropertyFileReference]
] = dataclasses.field(default=None, metadata={"schema_property_name": "policies"})
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
results: Optional[
List[_external_property_file_reference.ExternalPropertyFileReference]
] = dataclasses.field(default=None, metadata={"schema_property_name": "results"})
taxonomies: Optional[
List[_external_property_file_reference.ExternalPropertyFileReference]
] = dataclasses.field(default=None, metadata={"schema_property_name": "taxonomies"})
thread_flow_locations: Optional[
List[_external_property_file_reference.ExternalPropertyFileReference]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "threadFlowLocations"}
)
translations: Optional[
List[_external_property_file_reference.ExternalPropertyFileReference]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "translations"}
)
web_requests: Optional[
List[_external_property_file_reference.ExternalPropertyFileReference]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "webRequests"}
)
web_responses: Optional[
List[_external_property_file_reference.ExternalPropertyFileReference]
] = dataclasses.field(
default=None, metadata={"schema_property_name": "webResponses"}
)
# flake8: noqa

View File

@ -0,0 +1,31 @@
# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29,
# with extension for dataclasses and type annotation.
from __future__ import annotations
import dataclasses
from typing import List, Optional
from torch.onnx._internal.diagnostics.infra.sarif import (
_artifact_change,
_message,
_property_bag,
)
@dataclasses.dataclass
class Fix(object):
"""A proposed fix for the problem represented by a result object. A fix specifies a set of artifacts to modify. For each artifact, it specifies a set of bytes to remove, and provides a set of new bytes to replace them."""
artifact_changes: List[_artifact_change.ArtifactChange] = dataclasses.field(
metadata={"schema_property_name": "artifactChanges"}
)
description: Optional[_message.Message] = dataclasses.field(
default=None, metadata={"schema_property_name": "description"}
)
properties: Optional[_property_bag.PropertyBag] = dataclasses.field(
default=None, metadata={"schema_property_name": "properties"}
)
# flake8: noqa

Some files were not shown because too many files have changed in this diff Show More