234 lines
9.1 KiB
Python
234 lines
9.1 KiB
Python
import contextlib
|
|
import threading
|
|
from typing import Callable, Generator, Iterable, Optional, Union
|
|
|
|
from .custom_ops import custom_op
|
|
from .infer_schema import infer_schema
|
|
|
|
|
|
def triton_op(
|
|
name: str,
|
|
fn: Optional[Callable] = None,
|
|
/,
|
|
*,
|
|
mutates_args: Union[str, Iterable[str]],
|
|
schema: Optional[str] = None,
|
|
) -> Callable:
|
|
"""Create a custom operator whose implementation is backed by 1+ triton kernels.
|
|
|
|
Use this instead of :func:`torch.library.custom_op` when the implementation
|
|
consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
|
|
custom operators as opaque (:func:`torch.compile` and
|
|
:func:`torch.export.export` will never trace into them), but ``triton_op``
|
|
makes the implementation visible to these subsystems, allowing them
|
|
to optimize the triton kernel(s).
|
|
|
|
Note that ``fn`` must only consist of calls to PyTorch-understood
|
|
operators and triton kernels. Any triton kernels called inside ``fn``
|
|
must be wrapped in a call to :func:`torch._library.capture_triton``.
|
|
|
|
Args:
|
|
name (str): A name for the custom op that looks like "{namespace}::{name}",
|
|
e.g. "mylib::my_linear". The name is used as the op's stable identifier
|
|
in PyTorch subsystems (e.g. torch.export, FX graphs).
|
|
To avoid name collisions, please use your project name as the namespace;
|
|
e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
|
|
mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
|
|
This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
|
|
it pessimistically assumes that all inputs to the operator are being mutated.
|
|
schema (None | str): A schema string for the operator. If None
|
|
(recommended) we'll infer a schema for the operator from its type
|
|
annotations. We recommend letting us infer a schema unless you
|
|
have a specific reason not to.
|
|
Example: "(Tensor x, int y) -> (Tensor, Tensor)".
|
|
|
|
Example::
|
|
|
|
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
|
>>> import torch
|
|
>>> from torch._library import triton_op, capture_triton
|
|
>>>
|
|
>>> import triton
|
|
>>> from triton import language as tl
|
|
>>>
|
|
>>> @triton.jit
|
|
>>> def add_kernel(
|
|
>>> in_ptr0,
|
|
>>> in_ptr1,
|
|
>>> out_ptr,
|
|
>>> n_elements,
|
|
>>> BLOCK_SIZE: "tl.constexpr",
|
|
>>> ):
|
|
>>> pid = tl.program_id(axis=0)
|
|
>>> block_start = pid * BLOCK_SIZE
|
|
>>> offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
|
>>> mask = offsets < n_elements
|
|
>>> x = tl.load(in_ptr0 + offsets, mask=mask)
|
|
>>> y = tl.load(in_ptr1 + offsets, mask=mask)
|
|
>>> output = x + y
|
|
>>> tl.store(out_ptr + offsets, output, mask=mask)
|
|
>>>
|
|
>>> @triton_op("mylib::add", mutates_args={})
|
|
>>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
>>> output = torch.empty_like(x)
|
|
>>> n_elements = output.numel()
|
|
>>>
|
|
>>> def grid(meta):
|
|
>>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
|
|
>>>
|
|
>>> # NB: we need to wrap the triton kernel in a call to capture_triton
|
|
>>> capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
|
|
>>> return output
|
|
>>>
|
|
>>> @torch.compile
|
|
>>> def f(x, y):
|
|
>>> return add(x, y)
|
|
>>>
|
|
>>> x = torch.randn(3, device="cuda")
|
|
>>> y = torch.randn(3, device="cuda")
|
|
>>>
|
|
>>> z = f(x, y)
|
|
>>> assert torch.allclose(z, x + y)
|
|
|
|
"""
|
|
|
|
def dec(fn: Callable) -> Callable:
|
|
def backend_fn(*args, **kwargs): # type: ignore[no-untyped-def]
|
|
# Optimization: we're passing regular Tensors into the triton kernel, so
|
|
# no need to go through HOP dispatch
|
|
with set_capture_triton_enabled(False):
|
|
return fn(*args, **kwargs)
|
|
|
|
result = custom_op(
|
|
name,
|
|
backend_fn,
|
|
mutates_args=mutates_args,
|
|
schema=infer_schema(fn, mutates_args=mutates_args),
|
|
)
|
|
from .._subclasses.functional_tensor import FunctionalTensorMode
|
|
|
|
# We require that the user pass us a function that is make_fx traceable,
|
|
# so we can just register it as the Fake/meta kernel.
|
|
result.register_fake(fn)
|
|
|
|
# We decompose the operator when FunctionalTensorMode is active.
|
|
# The goal is to decompose the operator in AOTDispatcher.
|
|
# - With torch.compile, this means that the backend (usually Inductor)
|
|
# can see a call to the triton kernel(s) and so it can directly optimize
|
|
# them by inlining them into the lowering process.
|
|
# - With post-dispatch torch.export, this means that there will
|
|
# be a call(s) to the triton_kernel_wrapper_functional HOP in the
|
|
# graph (that we have yet to figure out how to serialize).
|
|
def functional_decomp( # type: ignore[no-untyped-def]
|
|
mode, _, types, args, kwargs
|
|
):
|
|
with mode:
|
|
return fn(*args, **kwargs)
|
|
|
|
result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
|
|
return result
|
|
|
|
if fn is None:
|
|
return dec
|
|
else:
|
|
return dec(fn)
|
|
|
|
|
|
capture_triton_enabled = threading.local()
|
|
capture_triton_enabled_default = True
|
|
|
|
|
|
@contextlib.contextmanager
|
|
def set_capture_triton_enabled(enabled: bool) -> Generator[None, None, None]:
|
|
"""If triton kernels annotated with @capture_triton should dispatch via HOP
|
|
or go straight to the triton kernel execution.
|
|
|
|
We have this switch because eager-mode performance of HOP dispatch is slow
|
|
enough to matter (~1ms) and we know that capture_triton isn't necessary in
|
|
some situations (eager-mode with regular Tensors)
|
|
"""
|
|
try:
|
|
prev = is_capture_triton_enabled()
|
|
capture_triton_enabled.value = enabled
|
|
yield
|
|
finally:
|
|
capture_triton_enabled.value = prev
|
|
|
|
|
|
def is_capture_triton_enabled() -> bool:
|
|
return getattr(capture_triton_enabled, "value", capture_triton_enabled_default)
|
|
|
|
|
|
def capture_triton(triton_kernel: Callable, /) -> Callable:
|
|
"""Allows capture of a triton kernel into a graph via make_fx or
|
|
non-strict export (coming soon).
|
|
|
|
These technologies perform Dispatcher-based tracing (via
|
|
``__torch_dispatch__``) and cannot see calls to raw triton kernels.
|
|
The ``capture_triton`` API returns a new callable that can actually
|
|
be traced into a graph.
|
|
|
|
Examples:
|
|
|
|
>>> # xdoctest: +SKIP
|
|
>>> import torch
|
|
>>> import triton
|
|
>>> from triton import language as tl
|
|
>>> from torch.fx.experimental.proxy_tensor import make_fx
|
|
>>> from torch._higher_order_ops.triton_kernel_wrap import capture_triton
|
|
>>>
|
|
>>> @triton.jit
|
|
>>> def add_kernel(
|
|
>>> in_ptr0,
|
|
>>> in_ptr1,
|
|
>>> out_ptr,
|
|
>>> n_elements,
|
|
>>> BLOCK_SIZE: "tl.constexpr",
|
|
>>> ):
|
|
>>> pid = tl.program_id(axis=0)
|
|
>>> block_start = pid * BLOCK_SIZE
|
|
>>> offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
|
>>> mask = offsets < n_elements
|
|
>>> x = tl.load(in_ptr0 + offsets, mask=mask)
|
|
>>> y = tl.load(in_ptr1 + offsets, mask=mask)
|
|
>>> output = x + y
|
|
>>> tl.store(out_ptr + offsets, output, mask=mask)
|
|
>>>
|
|
>>> def add(x, y):
|
|
>>> output = torch.empty_like(x)
|
|
>>> n_elements = output.numel()
|
|
>>>
|
|
>>> def grid_fn(meta):
|
|
>>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
|
|
>>>
|
|
>>> capture_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
|
|
>>> return output
|
|
>>>
|
|
>>> x = torch.randn(3, device="cuda")
|
|
>>> y = torch.randn(3, device="cuda")
|
|
>>> gm = make_fx(add)(x, y)
|
|
>>> print(gm.code)
|
|
>>> # def forward(self, x_1, y_1):
|
|
>>> # empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
|
|
>>> # triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
|
|
>>> # kernel_idx = 0, constant_args_idx = 0,
|
|
>>> # grid = [(1, 1, 1)], kwargs = {
|
|
>>> # 'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
|
|
>>> # 'n_elements': 3, 'BLOCK_SIZE': 16
|
|
>>> # })
|
|
>>> # return empty_like
|
|
|
|
"""
|
|
from triton.runtime.autotuner import Autotuner
|
|
from triton.runtime.jit import JITFunction
|
|
|
|
from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper
|
|
|
|
if not isinstance(triton_kernel, (JITFunction, Autotuner)):
|
|
raise RuntimeError(
|
|
"capture_triton only works on functions annotated with triton.jit or triton.autotune"
|
|
)
|
|
if not is_capture_triton_enabled():
|
|
return triton_kernel
|
|
return TraceableTritonKernelWrapper(triton_kernel, None, None)
|