93 lines
2.4 KiB
Python
93 lines
2.4 KiB
Python
# mypy: allow-untyped-defs
|
|
"""This file exports ONNX ops for opset 20.

Note [ONNX Operators that are added/updated in opset 20]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-20-of-the-default-onnx-operator-set

New or updated operators:
    AffineGrid
    ConstantOfShape
    DFT
    Gelu
    GridSample
    ImageDecoder
    IsInf
    IsNaN
    ReduceMax
    ReduceMin
    RegexFullMatch
    StringConcat
    StringSplit
"""
|
|
|
|
import functools
|
|
|
|
import torch.nn.functional as F
|
|
from torch import _C
|
|
from torch.onnx import symbolic_helper
|
|
from torch.onnx._internal import jit_utils, registration
|
|
|
|
|
|
# EDITING THIS FILE? READ THIS FIRST!
|
|
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
|
|
|
# Names exported by a star-import of this module.
__all__ = ["_grid_sampler", "_affine_grid_generator", "gelu"]
|
|
|
|
|
|
def convert_grid_sample_mode(mode_s: str) -> str:
    """Rename a grid-sample interpolation mode string.

    ``"bilinear"`` is mapped to ``"linear"`` and ``"bicubic"`` to ``"cubic"``;
    any other value is returned unchanged.

    Args:
        mode_s: Interpolation mode name to translate.

    Returns:
        The translated mode name, or ``mode_s`` itself when no renaming applies.
    """
    if mode_s == "bilinear":
        return "linear"
    if mode_s == "bicubic":
        return "cubic"
    return mode_s
|
|
|
|
|
|
# `registration.onnx_symbolic` with `opset=20` pre-bound; used below as the
# decorator that registers each symbolic function under its aten op name.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=20)
|
|
|
|
|
|
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
def _grid_sampler(
    g: jit_utils.GraphContext,
    input: _C.Value,
    grid: _C.Value,
    mode_enum: int,
    padding_mode_enum: int,
    align_corners: bool,
):
    """Emit a ``GridSample`` node for ``aten::grid_sampler``.

    The integer mode and padding-mode enums are translated back to their
    string names by inverting ``F.GRID_SAMPLE_INTERPOLATION_MODES`` and
    ``F.GRID_SAMPLE_PADDING_MODES``. The interpolation mode name is then
    passed through ``convert_grid_sample_mode`` before being attached to
    the node.

    Args:
        g: Graph context used to create the node.
        input: Value passed as the first operand of ``GridSample``.
        grid: Value passed as the second operand of ``GridSample``.
        mode_enum: Integer interpolation mode; must be a value of
            ``F.GRID_SAMPLE_INTERPOLATION_MODES``.
        padding_mode_enum: Integer padding mode; must be a value of
            ``F.GRID_SAMPLE_PADDING_MODES``.
        align_corners: Forwarded as the integer attribute ``align_corners``.

    Returns:
        The result of ``g.op("GridSample", ...)``.
    """
    # Invert name -> enum into enum -> name to recover the mode string.
    mode_s = {v: k for k, v in F.GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg, index]
    # mode string changes at https://onnx.ai/onnx/operators/text_diff_GridSample_16_20.html
    mode_s = convert_grid_sample_mode(mode_s)
    padding_mode_s = {v: k for k, v in F.GRID_SAMPLE_PADDING_MODES.items()}[  # type: ignore[call-arg, index]
        padding_mode_enum  # type: ignore[index]
    ]
    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=mode_s,
        padding_mode_s=padding_mode_s,
    )
|
|
|
|
|
|
@_onnx_symbolic("aten::affine_grid_generator")
@symbolic_helper.parse_args("v", "v", "b")
def _affine_grid_generator(
    g: jit_utils.GraphContext,
    theta: _C.Value,
    size: _C.Value,
    align_corners: bool,
):
    """Emit an ``AffineGrid`` node for ``aten::affine_grid_generator``.

    Args:
        g: Graph context used to create the node.
        theta: Value passed as the first operand of ``AffineGrid``.
        size: Value passed as the second operand of ``AffineGrid``.
        align_corners: Forwarded as the integer attribute ``align_corners``.

    Returns:
        The result of ``g.op("AffineGrid", ...)``.
    """
    return g.op(
        "AffineGrid",
        theta,
        size,
        align_corners_i=int(align_corners),
    )
|
|
|
|
|
|
@_onnx_symbolic("aten::gelu")
@symbolic_helper.parse_args("v", "s")
def gelu(g: jit_utils.GraphContext, self: _C.Value, approximate: str = "none"):
    """Emit a ``Gelu`` node for ``aten::gelu``.

    Args:
        g: Graph context used to create the node.
        self: Value passed as the operand of ``Gelu``.
        approximate: Forwarded unchanged as the string attribute
            ``approximate``; defaults to ``"none"``.

    Returns:
        The result of ``g.op("Gelu", ...)``.
    """
    return g.op("Gelu", self, approximate_s=approximate)
|