I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

Binary file not shown.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,20 @@
from ctypes import c_void_p
from torch import Tensor
# Defined in torch/csrc/inductor/aoti_runner/pybind.cpp

# Tensor -> AtenTensorHandle: hand tensors to the AOTInductor C ABI as raw
# void pointers. "unsafe" because the caller becomes responsible for lifetime.
def unsafe_alloc_void_ptrs_from_tensors(tensors: list[Tensor]) -> list[c_void_p]: ...
def unsafe_alloc_void_ptr_from_tensor(tensor: Tensor) -> c_void_p: ...

# AtenTensorHandle to Tensor: rebuild tensors, taking ownership of
# ("stealing") the underlying handles.
def alloc_tensors_by_stealing_from_void_ptrs(
    handles: list[c_void_p],
) -> list[Tensor]: ...
def alloc_tensor_by_stealing_from_void_ptr(
    handle: c_void_p,
) -> Tensor: ...

# Opaque native runner objects for AOT-compiled model containers (CPU / CUDA).
class AOTIModelContainerRunnerCpu: ...
class AOTIModelContainerRunnerCuda: ...

View File

@ -0,0 +1,135 @@
# mypy: allow-untyped-defs
from enum import Enum
from typing import Any, Callable
import torch
from torch._C._profiler import (
_ProfilerEvent,
ActiveProfilerType,
ProfilerActivity,
ProfilerConfig,
)
# Defined in torch/csrc/autograd/init.cpp
class DeviceType(Enum):
    """Python mirror of the C++ ``c10::DeviceType`` enum.

    Member values are filled in by the native extension at import time.
    Note the mixed casing (e.g. ``Meta``, ``Vulkan``) matches the C++ names.
    """

    CPU = ...
    CUDA = ...
    XPU = ...
    MKLDNN = ...
    OPENGL = ...
    OPENCL = ...
    IDEEP = ...
    HIP = ...
    FPGA = ...
    MAIA = ...
    XLA = ...
    MTIA = ...
    MPS = ...
    HPU = ...
    Meta = ...
    Vulkan = ...
    Metal = ...
    PrivateUse1 = ...
class ProfilerEvent:
    """One event record from the legacy autograd profiler.

    Instances appear in the nested lists returned by
    ``_disable_profiler_legacy()`` / ``_ProfilerResult.legacy_events()``.
    """

    # Elapsed time (microseconds) between this event and ``other``.
    def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
    def cpu_memory_usage(self) -> int: ...
    def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ...
    def privateuse1_elapsed_us(self, other: ProfilerEvent) -> float: ...
    def cuda_memory_usage(self) -> int: ...
    def device(self) -> int: ...
    def handle(self) -> int: ...
    def has_cuda(self) -> bool: ...
    def is_remote(self) -> bool: ...
    def kind(self) -> int: ...
    def name(self) -> str: ...
    def node_id(self) -> int: ...
    def sequence_nr(self) -> int: ...
    # Input shapes recorded for the op, one list of sizes per input.
    def shapes(self) -> list[list[int]]: ...
    def thread_id(self) -> int: ...
    def flops(self) -> float: ...
    def is_async(self) -> bool: ...
class _KinetoEvent:
    """Event record from the Kineto-backed profiler.

    Returned by ``_ProfilerResult.events()``; timestamps are in nanoseconds
    (``*_ns``) while elapsed helpers are in microseconds (``*_us``).
    """

    def name(self) -> str: ...
    def device_index(self) -> int: ...
    def device_resource_id(self) -> int: ...
    def start_ns(self) -> int: ...
    def end_ns(self) -> int: ...
    def duration_ns(self) -> int: ...
    def is_async(self) -> bool: ...
    def linked_correlation_id(self) -> int: ...
    def shapes(self) -> list[list[int]]: ...
    def dtypes(self) -> list[str]: ...
    def concrete_inputs(self) -> list[Any]: ...
    def kwinputs(self) -> dict[str, Any]: ...
    def device_type(self) -> DeviceType: ...
    def start_thread_id(self) -> int: ...
    def end_thread_id(self) -> int: ...
    def correlation_id(self) -> int: ...
    def fwd_thread_id(self) -> int: ...
    def stack(self) -> list[str]: ...
    def scope(self) -> int: ...
    def sequence_nr(self) -> int: ...
    def flops(self) -> int: ...
    def cuda_elapsed_us(self) -> int: ...
    def privateuse1_elapsed_us(self) -> int: ...
    def is_user_annotation(self) -> bool: ...
class _ProfilerResult:
    """Aggregated output of a profiling session (see ``_disable_profiler``)."""

    def events(self) -> list[_KinetoEvent]: ...
    # Legacy-format events, grouped per thread.
    def legacy_events(self) -> list[list[ProfilerEvent]]: ...
    # Serialize the trace to ``path`` (presumably a Chrome-trace file —
    # format not visible from this stub).
    def save(self, path: str) -> None: ...
    def experimental_event_tree(self) -> list[_ProfilerEvent]: ...
    def trace_start_ns(self) -> int: ...

# Opaque handle used by the saved-tensors hooks machinery.
class SavedTensor: ...
# --- Profiler lifecycle -----------------------------------------------------
def _enable_profiler(
    config: ProfilerConfig,
    activities: set[ProfilerActivity],
) -> None: ...
def _prepare_profiler(
    config: ProfilerConfig,
    activities: set[ProfilerActivity],
) -> None: ...
# Toggle collection of the given activities while a profile is running.
def _toggle_collection_dynamic(
    enable: bool,
    activities: set[ProfilerActivity],
) -> None: ...
def _disable_profiler() -> _ProfilerResult: ...
def _profiler_enabled() -> bool: ...
def _add_metadata_json(key: str, value: str) -> None: ...
def _kineto_step() -> None: ...
def _get_current_graph_task_keep_graph() -> bool: ...
def _get_sequence_nr() -> int: ...
def kineto_available() -> bool: ...

# --- record_function support ------------------------------------------------
# The returned/consumed Tensor is an opaque handle for the open record scope.
def _record_function_with_args_enter(name: str, *args) -> torch.Tensor: ...
def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
def _supported_activities() -> set[ProfilerActivity]: ...
def _enable_record_function(enable: bool) -> None: ...
def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...

# --- Saved-tensors hooks ----------------------------------------------------
def _push_saved_tensors_default_hooks(
    pack_hook: Callable[[torch.Tensor], Any],
    unpack_hook: Callable[[Any], torch.Tensor],
) -> None: ...
def _pop_saved_tensors_default_hooks() -> None: ...
def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...

# --- Legacy profiler --------------------------------------------------------
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ...
def _profiler_type() -> ActiveProfilerType: ...
def _saved_tensors_hooks_enable() -> None: ...
# ``message`` becomes the error shown if hooks are used while disabled.
def _saved_tensors_hooks_disable(message: str) -> None: ...
def _saved_tensors_hooks_get_disabled_error_message() -> str | None: ...
def _saved_tensors_hooks_set_tracing(is_tracing: bool) -> bool: ...
class CreationMeta(Enum):
    """How a view tensor was created (mirrors C++ ``CreationMeta``)."""

    DEFAULT = ...
    IN_CUSTOM_FUNCTION = ...
    MULTI_OUTPUT_NODE = ...
    NO_GRAD_MODE = ...
    INFERENCE_MODE = ...

def _set_creation_meta(t: torch.Tensor, creation_meta: CreationMeta) -> None: ...
def _get_creation_meta(t: torch.Tensor) -> CreationMeta: ...

View File

@ -0,0 +1,12 @@
from torch.types import _bool, _int
# Defined in torch/csrc/cpu/Module.cpp

# CPU ISA capability probes.
def _is_avx2_supported() -> _bool: ...
def _is_avx512_supported() -> _bool: ...
def _is_avx512_vnni_supported() -> _bool: ...
def _is_avx512_bf16_supported() -> _bool: ...
def _is_amx_tile_supported() -> _bool: ...
# Initialize AMX tile support; returns whether initialization succeeded.
def _init_amx() -> _bool: ...
# Cache sizes (units not visible from this stub — presumably bytes or KiB;
# confirm against torch/csrc/cpu/Module.cpp).
def _L1d_cache_size() -> _int: ...
def _L2_cache_size() -> _int: ...

View File

@ -0,0 +1,17 @@
from enum import Enum
from torch.types import _bool, Tuple
# Defined in torch/csrc/cuda/shared/cudnn.cpp

# True when this build's cuDNN bindings target CUDA (as opposed to ROCm).
is_cuda: _bool

def getRuntimeVersion() -> Tuple[int, int, int]: ...
def getCompileVersion() -> Tuple[int, int, int]: ...
def getVersionInt() -> int: ...

class RNNMode(int, Enum):
    """cuDNN RNN cell types (mirrors ``cudnnRNNMode_t``)."""

    value: int
    rnn_relu = ...
    rnn_tanh = ...
    lstm = ...
    gru = ...

View File

# Single-function stub: library version as a packed integer.
def getVersionInt() -> int: ...

View File

@ -0,0 +1,27 @@
# mypy: allow-untyped-defs
from typing import Any
import torch
# This module is defined in torch/csrc/distributed/autograd/init.cpp
class DistAutogradContext:
    """Handle to one distributed-autograd context (identified by context id)."""

    def _context_id(self) -> int: ...
    # Recv/send autograd functions recorded in this context, keyed by id.
    def _recv_functions(self) -> dict[int, Any]: ...
    def _send_functions(self) -> dict[int, Any]: ...
    def _known_worker_ids(self) -> set[int]: ...

def _new_context() -> DistAutogradContext: ...
def _release_context(context_id: int) -> None: ...
def _get_max_id() -> int: ...
def _is_valid_context(worker_id: int) -> bool: ...
def _retrieve_context(context_id: int) -> DistAutogradContext: ...
def _current_context() -> DistAutogradContext: ...
def _init(worker_id: int) -> None: ...
def _get_debug_info() -> dict[str, str]: ...

# Run the distributed backward pass for ``context_id`` starting from ``roots``.
def backward(
    context_id: int,
    roots: list[torch.Tensor],
    retain_graph=False,
) -> None: ...
def get_gradients(context_id: int) -> dict[torch.Tensor, torch.Tensor]: ...

View File

@ -0,0 +1,699 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from datetime import timedelta
from enum import Enum
from typing import Any, Optional, overload
import torch
from torch import Tensor
from torch._C import ScriptObject
from torch.futures import Future
# This module is defined in torch/csrc/distributed/c10d/init.cpp

# Module-level timeout/bucket-size defaults exposed by the extension.
_DEFAULT_FIRST_BUCKET_BYTES: int
_DEFAULT_NO_TIMEOUT: timedelta
_DEFAULT_PG_TIMEOUT: timedelta
_DEFAULT_PG_NCCL_TIMEOUT: timedelta

class BuiltinCommHookType(Enum):
    """Built-in DDP gradient communication hooks."""

    ALLREDUCE = ...
    FP16_COMPRESS = ...

def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ...
def _register_builtin_comm_hook(
    reducer: Reducer,
    comm_hook_type: BuiltinCommHookType,
): ...
def _set_global_rank(rank: int) -> None: ...
def _hash_tensors(tensors: list[Tensor]) -> int: ...
class GradBucket:
    """One bucket of gradients handed to a DDP communication hook."""

    def index(self) -> int: ...
    # Flattened buffer holding the bucket's gradients.
    def buffer(self) -> Tensor: ...
    def gradients(self) -> list[Tensor]: ...
    def is_last(self) -> bool: ...
    def set_buffer(self, tensor: Tensor) -> None: ...
    def parameters(self) -> list[Tensor]: ...
class Reducer:
    """Native gradient reducer used by DistributedDataParallel."""

    def __init__(
        self,
        params: list[Tensor],
        bucket_indices: list[list[int]],
        per_bucket_size_limits: list[int],
        process_group: ProcessGroup,
        expect_sparse_gradients: list[bool] = ...,
        bucket_bytes_cap: int = ...,  # kDefaultBucketBytesCap in reducer.hpp
        find_unused_parameters: bool = ...,
        gradient_as_bucket_view: bool = ...,
        param_to_name_mapping: dict[int, str] = ...,
        first_bucket_types_cap: int = ...,  # kDefaultFirstBucketBytes in reducer.hpp
    ) -> None: ...
    def prepare_for_forward(self) -> None: ...
    def prepare_for_backward(self, output: list[Tensor]) -> None: ...
    def get_backward_stats(self) -> list[int]: ...
    def _install_post_backward_futures(self, futures: list[Future]) -> None: ...
    # Returns True when buckets were rebuilt (see reducer.cpp for criteria).
    def _rebuild_buckets(self) -> bool: ...
    def _get_zeros_like_grad_buckets(self) -> list[GradBucket]: ...
    def _push_all_rebuilt_params(self) -> None: ...
    def _set_forward_pass_work_handle(
        self,
        work: Work,
        use_static_world_size: bool,
    ): ...
    def _get_local_used_map(self) -> Tensor: ...
    def _set_ddp_runtime_logging_sample_rate(self, sample_rate: int) -> None: ...
    def _set_static_graph(self) -> None: ...
    def _run_comm_hook(self, bucket: GradBucket) -> Future: ...
    def set_logger(self, logger: Logger) -> None: ...
    def _remove_autograd_hooks(self) -> None: ...
    def _check_reducer_finalized(self) -> None: ...
    def _set_sparse_metadata(self, global_unique_ids: dict[str, Tensor]) -> None: ...
    def _reset_state(self) -> None: ...
    def _update_process_group(self, new_process_group: ProcessGroup) -> None: ...
class DDPLoggingData:
    """Bag of DDP runtime statistics, split by value type."""

    strs_map: dict[str, str]
    ints_map: dict[str, int]

class Logger:
    """Collects and reports DDP construction/runtime logging data."""

    def __init__(self, reducer: Reducer) -> None: ...
    def set_construction_data_and_log(
        self,
        module_name: str,
        device_ids: list[int],
        output_device: int,
        broadcast_buffers: bool,
        has_sync_bn: bool,
        static_graph: bool,
    ): ...
    def set_runtime_stats_and_log(self) -> None: ...
    def set_error_and_log(self, error: str) -> None: ...
    def _get_ddp_logging_data(self) -> DDPLoggingData: ...
    def _set_comm_hook_name(self, comm_hook: str) -> None: ...
    def _set_uneven_input_join(self) -> None: ...
    def _set_static_graph(self) -> None: ...
class _WorkerServer:
    """Debug/control server listening on a Unix socket path."""

    def __init__(self, socket_path: str) -> None: ...
    def shutdown(self) -> None: ...

# NOTE(review): signatures are untyped upstream — parameters/returns not
# captured by this stub.
def get_debug_level(): ...
def set_debug_level(): ...
def set_debug_level_from_env(): ...

class DebugLevel(Enum):
    """Distributed-debug verbosity (OFF < INFO < DETAIL)."""

    OFF = ...
    INFO = ...
    DETAIL = ...
class ReduceOp:
    """Reduction operator wrapper; class attributes are the usual ops."""

    def __init__(self, op: RedOpType) -> None: ...

    SUM: RedOpType = ...
    AVG: RedOpType = ...
    PRODUCT: RedOpType = ...
    MIN: RedOpType = ...
    MAX: RedOpType = ...
    BAND: RedOpType = ...
    BOR: RedOpType = ...
    BXOR: RedOpType = ...
    PREMUL_SUM: RedOpType = ...
    UNUSED: RedOpType = ...

    class RedOpType(Enum): ...
# --- Per-collective option structs (field-only stubs) ------------------------
class BroadcastOptions:
    rootRank: int
    rootTensor: int
    timeout: timedelta
    asyncOp: bool

class AllreduceOptions:
    reduceOp: ReduceOp
    timeout: timedelta

class AllreduceCoalescedOptions(AllreduceOptions): ...

class ReduceOptions:
    reduceOp: ReduceOp
    rootRank: int
    rootTensor: int
    timeout: timedelta

class AllgatherOptions:
    timeout: timedelta
    asyncOp: bool

class GatherOptions:
    rootRank: int
    timeout: timedelta

class ScatterOptions:
    rootRank: int
    timeout: timedelta
    asyncOp: bool

class ReduceScatterOptions:
    reduceOp: ReduceOp
    timeout: timedelta
    asyncOp: bool

class BarrierOptions:
    device_ids: list[int]
    device: torch.device
    timeout: timedelta

class AllToAllOptions:
    timeout: timedelta
class Store:
    """Key-value rendezvous store shared by all process-group backends."""

    def set(self, key: str, value: str): ...
    def get(self, key: str) -> bytes: ...
    # Atomically add ``value`` to the integer at ``key``; returns the new value.
    def add(self, key: str, value: int) -> int: ...
    def compare_set(
        self,
        key: str,
        expected_value: str,
        desired_value: str,
    ) -> bytes: ...
    def delete_key(self, key: str) -> bool: ...
    def num_keys(self) -> int: ...
    def set_timeout(self, timeout: timedelta): ...
    # Block until all keys exist (optionally bounded by ``timeout``).
    @overload
    def wait(self, keys: list[str]): ...
    @overload
    def wait(self, keys: list[str], timeout: timedelta): ...

class FileStore(Store):
    def __init__(self, path: str, numWorkers: int = ...) -> None: ...

class HashStore(Store):
    # In-memory store; useful for single-process tests.
    def __init__(self) -> None: ...
class TCPStore(Store):
    """TCP-backed store; one master process listens, others connect."""

    def __init__(
        self,
        host_name: str,
        port: int,
        world_size: int | None = ...,
        is_master: bool = ...,
        timeout: timedelta = ...,
        wait_for_workers: bool = ...,
        multi_tenant: bool = ...,
        master_listen_fd: int | None = ...,
        use_libuv: bool | None = ...,
    ) -> None: ...
    @property
    def host(self) -> str: ...
    @property
    def port(self) -> int: ...

class PrefixStore(Store):
    """View of ``store`` with every key prefixed by ``prefix``."""

    def __init__(self, prefix: str, store: Store) -> None: ...
    @property
    def underlying_store(self) -> Store: ...
class _ControlCollectives:
    """String-payload control-plane collectives keyed by a caller-chosen key."""

    def barrier(self, key: str, timeout: timedelta, blocking: bool) -> None: ...
    def broadcast_send(self, key: str, data: str, timeout: timedelta) -> None: ...
    def broadcast_recv(self, key: str, timeout: timedelta) -> str: ...
    def gather_send(self, key: str, data: str, timeout: timedelta) -> None: ...
    def gather_recv(self, key: str, timeout: timedelta) -> str: ...
    def scatter_send(self, key: str, data: str, timeout: timedelta) -> None: ...
    def scatter_recv(self, key: str, timeout: timedelta) -> str: ...
    def all_gather(self, key: str, data: str, timeout: timedelta) -> str: ...
    def all_sum(self, key: str, data: int, timeout: timedelta) -> int: ...

class _StoreCollectives(_ControlCollectives):
    # Implementation backed by a Store instance.
    def __init__(self, store: Store, rank: int, world_size: int) -> None: ...
class _DistributedBackendOptions:
    """Options bundle passed when registering a custom distributed backend.

    All fields are exposed as read/write properties.
    """

    def __init__(self) -> None: ...
    @property
    def store(self) -> Store: ...
    @store.setter
    def store(self, store: Store) -> None: ...
    @property
    def group_rank(self) -> int: ...
    @group_rank.setter
    def group_rank(self, rank: int) -> None: ...
    @property
    def group_size(self) -> int: ...
    @group_size.setter
    def group_size(self, size: int) -> None: ...
    @property
    def timeout(self) -> timedelta: ...
    @timeout.setter
    def timeout(self, timeout: timedelta) -> None: ...
    @property
    def group_id(self) -> str: ...
    @group_id.setter
    def group_id(self, group_id: str) -> None: ...
    @property
    def global_ranks_in_group(self) -> list[int]: ...
    @global_ranks_in_group.setter
    def global_ranks_in_group(self, ranks: list[int]) -> None: ...
class Work:
    """Handle for an in-flight collective operation."""

    def is_completed(self) -> bool: ...
    def is_success(self) -> bool: ...
    def exception(self) -> Any: ...
    # Block until done; returns completion status.
    def wait(self, timeout: timedelta = ...) -> bool: ...
    def get_future(self) -> Future: ...
    def source_rank(self) -> int: ...
    def _source_rank(self) -> int: ...
    def result(self) -> list[Tensor]: ...
    def synchronize(self): ...
    # TorchScript (de)serialization round-trip.
    def boxed(self) -> ScriptObject: ...
    @staticmethod
    def unbox(obj: ScriptObject) -> Work: ...
class Backend:
    """Base class for device-specific process-group backends (Gloo/NCCL/...)."""

    class Options:
        def __init__(self, backend: str, timeout: timedelta = ...) -> None: ...
        @property
        def backend(self) -> str: ...
        @property
        def _timeout(self) -> timedelta: ...
        @_timeout.setter
        def _timeout(self, val: timedelta) -> None: ...

    def __init__(
        self,
        rank: int,
        size: int,
    ) -> None: ...
    @property
    def supports_splitting(self) -> bool: ...
    @property
    def options(self) -> Options: ...
    def rank(self) -> int: ...
    def size(self) -> int: ...
    def eager_connect_single_device(self, device: torch.device | None) -> None: ...
    def _set_sequence_number_for_group(self) -> None: ...
    def _set_default_timeout(self, timeout: timedelta) -> None: ...
class ProcessGroup:
    """Frontend process-group object dispatching collectives to a Backend.

    Most collectives are overloaded: an options-struct form and a convenience
    form taking individual arguments. All return a ``Work`` handle.
    """

    class Options:
        def __init__(self, backend: str, timeout: timedelta = ...) -> None: ...
        @property
        def backend(self) -> str: ...
        @property
        def _timeout(self) -> timedelta: ...
        @_timeout.setter
        def _timeout(self, val: timedelta) -> None: ...

    class BackendType(Enum):
        UNDEFINED = ...
        GLOO = ...
        NCCL = ...
        UCC = ...
        MPI = ...
        CUSTOM = ...

    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        options: Options,
    ) -> None: ...
    def rank(self) -> int: ...
    def size(self) -> int: ...
    @overload
    def broadcast(
        self,
        tensors: list[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def broadcast(
        self,
        tensor: Tensor,
        root: int,
    ) -> Work: ...
    @overload
    def allreduce(
        self,
        tensors: list[Tensor],
        opts: AllreduceOptions = ...,
    ) -> Work: ...
    @overload
    def allreduce(
        self,
        tensors: list[Tensor],
        op=...,
    ) -> Work: ...
    @overload
    def allreduce(
        self,
        tensor: Tensor,
        op=...,
    ) -> Work: ...
    def allreduce_coalesced(
        self,
        tensors: list[Tensor],
        opts=...,
    ) -> Work: ...
    def reduce_scatter_tensor_coalesced(
        self,
        outputTensors: list[Tensor],
        inputTensors: list[Tensor],
        opts: ReduceScatterOptions | None = None,
    ) -> Work: ...
    @overload
    def reduce(
        self,
        tensors: list[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def reduce(
        self,
        tensor: Tensor,
        root: int,
        op=...,
    ) -> Work: ...
    @overload
    def allgather(
        self,
        output_tensors: list[list[Tensor]],
        input_tensors: list[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def allgather(
        self,
        output_tensors: list[Tensor],
        input_tensor: Tensor,
    ) -> Work: ...
    # Single flattened output buffer variant of allgather.
    def _allgather_base(
        self,
        output: Tensor,
        input: Tensor,
        opts=...,
    ) -> Work: ...
    def allgather_coalesced(
        self,
        output_lists: list[list[Tensor]],
        input_list: list[Tensor],
        opts=...,
    ) -> Work: ...
    def allgather_into_tensor_coalesced(
        self,
        output_lists: list[Tensor],
        input_list: list[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def gather(
        self,
        output_tensors: list[list[Tensor]],
        input_tensors: list[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def gather(
        self,
        output_tensors: list[Tensor],
        input_tensor: Tensor,
        root: int,
    ) -> Work: ...
    @overload
    def scatter(
        self,
        output_tensors: list[Tensor],
        input_tensors: list[list[Tensor]],
        opts=...,
    ) -> Work: ...
    @overload
    def scatter(
        self,
        output_tensor: Tensor,
        input_tensors: list[Tensor],
        root: int,
    ) -> Work: ...
    @overload
    def reduce_scatter(
        self,
        output_tensors: list[Tensor],
        input_tensors: list[list[Tensor]],
        opts=...,
    ) -> Work: ...
    @overload
    def reduce_scatter(
        self,
        output_tensors: Tensor,
        input_tensor: list[Tensor],
    ) -> Work: ...
    def _reduce_scatter_base(
        self,
        outputTensor: Tensor,
        inputTensor: Tensor,
        opts: ReduceScatterOptions | None,
    ) -> Work: ...
    @overload
    def alltoall_base(
        self,
        output_tensor: Tensor,
        input_tensor: Tensor,
        output_split_sizes: list[int],
        input_split_sizes: list[int],
        opts=...,
    ) -> Work: ...
    @overload
    def alltoall_base(
        self,
        output: Tensor,
        input: Tensor,
        output_split_sizes: list[int],
        input_split_sizes: list[int],
    ) -> Work: ...
    @overload
    def alltoall(
        self,
        output_tensor: list[Tensor],
        input_tensor: list[Tensor],
        opts=...,
    ) -> Work: ...
    @overload
    def alltoall(
        self,
        output: list[Tensor],
        input: list[Tensor],
    ) -> Work: ...
    # Point-to-point.
    def send(
        self,
        tensors: list[Tensor],
        dstRank: int,
        tag: int,
    ) -> Work: ...
    def recv(
        self,
        tensors: list[Tensor],
        srcRank: int,
        tag: int,
    ) -> Work: ...
    def recv_anysource(self, tensors: list[Tensor], tag: int) -> Work: ...
    def barrier(self, opts=...) -> Work: ...
    # TorchScript (de)serialization round-trip.
    def boxed(self) -> ScriptObject: ...
    @staticmethod
    def unbox(obj: ScriptObject) -> ProcessGroup: ...
    def _start_coalescing(self, device: torch.device) -> None: ...
    def _end_coalescing(self, device: torch.device) -> Work: ...
    def _get_backend_name(self) -> str: ...
    def _backend_id(self, backend_type: BackendType) -> int: ...
    @property
    def _device_types(self) -> list[torch.device]: ...
    def _get_backend(self, device: torch.device) -> Backend: ...
    def _register_backend(
        self,
        device: torch.device,
        backend_type: BackendType,
        backend: Backend | None,
    ) -> None: ...
    def _set_group_name(self, name: str) -> None: ...
    def _set_group_desc(self, desc: str) -> None: ...
    def name(self) -> str: ...
    def _has_hooks(self) -> bool: ...
    def _wait_for_pending_works(self) -> None: ...
    def _set_sequence_number_for_group(self) -> None: ...
    @property
    def bound_device_id(self) -> torch.device | None: ...
    @bound_device_id.setter
    def bound_device_id(self, device: torch.device | None) -> None: ...
    @property
    def group_name(self) -> str: ...
    @property
    def group_desc(self) -> str: ...
class ProcessGroupGloo(Backend):
    """Gloo-backed process group (CPU-oriented collectives)."""

    class Device: ...

    class Options(ProcessGroup.Options):
        devices: list[ProcessGroupGloo.Device]
        threads: int

        def __init__(self): ...

    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        timeout: timedelta,
    ) -> None: ...
    @staticmethod
    def create_device(hostname="", interface="") -> Device: ...
    @staticmethod
    def create_default_device() -> Device: ...
    def _set_default_timeout(self, timeout) -> None: ...

class _ProcessGroupWrapper(Backend):
    """Wraps a backend with a Gloo group (used for debug-mode checking)."""

    def __init__(self, pg: Backend, gloo_pg: ProcessGroupGloo) -> None: ...
    wrapped_pg: Backend
class ProcessGroupNCCL(Backend):
    """NCCL-backed process group for CUDA collectives."""

    class NCCLConfig:
        blocking: int
        cga_cluster_size: int
        min_ctas: int
        max_ctas: int

    class Options(ProcessGroup.Options):
        config: ProcessGroupNCCL.NCCLConfig
        is_high_priority_stream: bool
        # Communicator-splitting configuration.
        split_from: ProcessGroupNCCL
        split_color: int
        global_ranks_in_group: list[int]
        group_name: str

        def __init__(self, is_high_priority_stream: bool = False): ...

    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        options: Options,
    ) -> None: ...
    def _group_start(self) -> None: ...
    def _group_end(self) -> None: ...
    def _set_default_timeout(self, timeout) -> None: ...
    def _shutdown(self) -> None: ...
    def perform_nocolor_split(self, device: torch.device) -> None: ...
    def comm_split_count(self) -> int: ...
    def _add_ephemeral_timeout(self, timeout: timedelta) -> None: ...
    @property
    def uid(self) -> int: ...
    @property
    def options(self) -> Options: ...  # type: ignore[override]
class ProcessGroupUCC(Backend):
    """UCC-backed process group."""

    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        timeout: timedelta,
    ) -> None: ...

class ProcessGroupMPI(Backend):
    """MPI-backed process group; construct via ``create``."""

    def __init__(
        self,
        rank: int,
        size: int,
        pgComm: int,
    ) -> None: ...
    @staticmethod
    def create(ranks: list[int]) -> ProcessGroupMPI: ...
# Partition tensors into DDP gradient buckets; returns (bucket index lists,
# per-bucket size limits).
def _compute_bucket_assignment_by_size(
    tensors: list[Tensor],
    bucket_size_limits: list[int],
    expect_sparse_gradient: list[bool] = ...,
    tensor_indices: list[int] = ...,
) -> tuple[list[list[int]], list[int]]: ...
def _broadcast_coalesced(
    process_group: ProcessGroup,
    tensors: list[Tensor],
    buffer_size: int,
    src: int,
): ...
def _test_python_store(store: Store): ...
def _verify_params_across_processes(
    process_group: ProcessGroup,
    params: list[Tensor],
    logger: Logger | None,
): ...
# Build a PREMUL_SUM reduce op with the given scaling factor(s).
def _make_nccl_premul_sum(factor: float | list[Tensor]) -> ReduceOp: ...

# --- Named process-group registry --------------------------------------------
def _register_process_group(
    group_name: str,
    process_group: ProcessGroup,
) -> None: ...
def _resolve_process_group(group_name: str) -> ProcessGroup: ...
def _register_work(tensor: torch.Tensor, work: Work) -> ProcessGroup: ...
def _unregister_all_process_groups() -> None: ...
def _unregister_process_group(group_name: str) -> None: ...
class _SymmetricMemory:
    """Handle to a symmetric (peer-to-peer shared) memory allocation."""

    @staticmethod
    def set_group_info(
        group_name: str,
        rank: int,
        world_size: int,
        store: Store,
    ) -> None: ...
    # Allocate a P2P-shareable strided tensor for the named group.
    @staticmethod
    def empty_strided_p2p(
        size: torch.types._size,
        stride: torch.types._size,
        dtype: torch.dtype,
        device: torch.device,
        group_name: str,
    ) -> torch.Tensor: ...
    @property
    def rank(self) -> int: ...
    @property
    def world_size(self) -> int: ...
    @staticmethod
    def rendezvous(tensor: torch.Tensor) -> _SymmetricMemory: ...
    # View a peer rank's buffer as a tensor.
    def get_buffer(
        self,
        rank: int,
        sizes: torch.types._size,
        dtype: torch.dtype,
        storage_offset: int | None = 0,
    ) -> torch.Tensor: ...
    def barrier(self, channel: int = 0) -> None: ...
    def put_signal(self, dst_rank: int, channel: int = 0) -> None: ...
    def wait_signal(self, src_rank: int, channel: int = 0) -> None: ...
class ProcessGroupCudaP2P(Backend):
    """Hybrid backend combining NCCL with an intra-node P2P buffer."""

    class Options:
        nccl_options: Optional[ProcessGroupNCCL.Options]
        buffer_size: Optional[int]

        def __init__(self) -> None: ...

    def __init__(
        self,
        store: Store,
        rank: int,
        size: int,
        options: ProcessGroupCudaP2P.Options,
    ) -> None: ...
    def is_p2p_available(self) -> bool: ...
    def get_buffer_size(self) -> int: ...
    def stream(self) -> torch.cuda.Stream: ...
    def intra_node_barrier(self) -> Work: ...
    def get_p2p_buffer(
        self,
        rank: int,
        sizes: torch.Size,
        dtype: torch.dtype,
        storage_offset: Optional[int] = 0,
    ) -> torch.Tensor: ...
    def _shutdown(self) -> None: ...

View File

@ -0,0 +1,188 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from datetime import timedelta
from typing import Any, Generic, overload, TypeVar
import torch
from torch._C import Future
from torch._C._autograd import ProfilerEvent
from torch._C._distributed_c10d import Store
from torch._C._profiler import ProfilerConfig
# This module is defined in torch/csrc/distributed/rpc/init.cpp
# Module-level RPC defaults exposed by the extension.
_DEFAULT_INIT_METHOD: str
_DEFAULT_NUM_WORKER_THREADS: int
_UNSET_RPC_TIMEOUT: float
_DEFAULT_RPC_TIMEOUT_SEC: float

_T = TypeVar("_T")

class RpcBackendOptions:
    """Base options shared by all RPC agent backends."""

    rpc_timeout: float
    init_method: str

    def __init__(
        self,
        rpc_timeout: float = ...,
        init_method: str = ...,
    ) -> None: ...

class WorkerInfo:
    """Immutable (name, id) identity of an RPC worker."""

    def __init__(self, name: str, worker_id: int) -> None: ...
    @property
    def name(self) -> str: ...
    @property
    def id(self) -> int: ...
    def __eq__(self, other: object) -> bool: ...
class RpcAgent:
    """Base class for RPC transport agents."""

    def join(self, shutdown: bool = False, timeout: float = 0): ...
    def sync(self): ...
    def shutdown(self): ...
    # No-arg form returns this agent's own WorkerInfo.
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    def get_worker_infos(self) -> list[WorkerInfo]: ...
    def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
    def get_debug_info(self) -> dict[str, str]: ...
    def get_metrics(self) -> dict[str, str]: ...
class PyRRef(Generic[_T]):
    """Remote reference to a value of type ``_T`` owned by some worker."""

    def __init__(self, value: _T, type_hint: Any = None) -> None: ...
    def is_owner(self) -> bool: ...
    def confirmed_by_owner(self) -> bool: ...
    def owner(self) -> WorkerInfo: ...
    def owner_name(self) -> str: ...
    # Fetch the referenced value to the caller (blocking, with timeout).
    def to_here(self, timeout: float = ...) -> _T: ...
    def local_value(self) -> Any: ...
    def rpc_sync(self, timeout: float = ...) -> Any: ...
    def rpc_async(self, timeout: float = ...) -> Any: ...
    def remote(self, timeout: float = ...) -> Any: ...
    # Pickling support (tuple form round-trips through _deserialize).
    def _serialize(self) -> tuple: ...
    @staticmethod
    def _deserialize(tp: tuple) -> PyRRef: ...
    def _get_type(self) -> type[_T]: ...
    def _get_future(self) -> Future[_T]: ...
    def _get_profiling_future(self) -> Future[_T]: ...
    def _set_profiling_future(self, profilingFuture: Future[_T]): ...
class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
    """Options for the TensorPipe RPC backend (threads, device maps)."""

    num_worker_threads: int
    device_maps: dict[str, dict[torch.device, torch.device]]
    devices: list[torch.device]

    def __init__(
        self,
        num_worker_threads: int,
        _transports: list | None,
        _channels: list | None,
        rpc_timeout: float = ...,
        init_method: str = ...,
        device_maps: dict[str, dict[torch.device, torch.device]] = {},  # noqa: B006
        devices: list[torch.device] = [],  # noqa: B006
    ) -> None: ...
    # Map local devices to remote devices for calls to worker ``to``.
    def _set_device_map(
        self,
        to: str,
        device_map: dict[torch.device, torch.device],
    ): ...
class TensorPipeAgent(RpcAgent):
    """TensorPipe-based RPC agent implementation."""

    def __init__(
        self,
        store: Store,
        name: str,
        worker_id: int,
        world_size: int | None,
        opts: _TensorPipeRpcBackendOptionsBase,
        reverse_device_maps: dict[str, dict[torch.device, torch.device]],
        devices: list[torch.device],
    ) -> None: ...
    def join(self, shutdown: bool = False, timeout: float = 0): ...
    def shutdown(self): ...
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, id: int) -> WorkerInfo: ...
    def get_worker_infos(self) -> list[WorkerInfo]: ...
    def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
    # Dynamic-membership update: add/remove ``worker_info`` (``is_join``).
    def _update_group_membership(
        self,
        worker_info: WorkerInfo,
        my_devices: list[torch.device],
        reverse_device_map: dict[str, dict[torch.device, torch.device]],
        is_join: bool,
    ): ...
    def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...
    @property
    def is_static_group(self) -> bool: ...
    @property
    def store(self) -> Store: ...
# --- Agent lifecycle ----------------------------------------------------------
def _is_current_rpc_agent_set() -> bool: ...
def _get_current_rpc_agent() -> RpcAgent: ...
def _set_and_start_rpc_agent(agent: RpcAgent): ...
def _reset_current_rpc_agent(): ...
def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
def _destroy_rref_context(ignoreRRefLeak: bool): ...
def _rref_context_get_debug_info() -> dict[str, str]: ...
def _cleanup_python_rpc_handler(): ...

# --- Synchronous/async RPC invocation (builtin op, Python UDF, TorchScript) --
def _invoke_rpc_builtin(
    dst: WorkerInfo,
    opName: str,
    rpcTimeoutSeconds: float,
    *args: Any,
    **kwargs: Any,
): ...
def _invoke_rpc_python_udf(
    dst: WorkerInfo,
    pickledPythonUDF: str,
    tensors: list[torch.Tensor],
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
): ...
def _invoke_rpc_torchscript(
    dstWorkerName: str,
    qualifiedNameStr: str,
    argsTuple: tuple,
    kwargsDict: dict,
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
): ...

# --- Remote (RRef-creating) invocation ----------------------------------------
def _invoke_remote_builtin(
    dst: WorkerInfo,
    opName: str,
    rpcTimeoutSeconds: float,
    *args: Any,
    **kwargs: Any,
): ...
def _invoke_remote_python_udf(
    dst: WorkerInfo,
    pickledPythonUDF: str,
    tensors: list[torch.Tensor],
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
): ...
def _invoke_remote_torchscript(
    dstWorkerName: WorkerInfo,
    qualifiedNameStr: str,
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
    *args: Any,
    **kwargs: Any,
): ...

# --- Misc configuration -------------------------------------------------------
def get_rpc_timeout() -> float: ...
def enable_gil_profiling(flag: bool): ...
def _set_rpc_timeout(rpcTimeoutSeconds: float): ...

class RemoteProfilerManager:
    """Static holder for the profiling key used by remote profiling."""

    @staticmethod
    def set_current_profiling_key(key: str): ...

def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
def _disable_server_process_global_profiler() -> list[list[list[ProfilerEvent]]]: ...
def _set_profiler_node_id(default_node_id: int): ...
def _enable_jit_rref_pickle(): ...
def _disable_jit_rref_pickle(): ...

View File

@ -0,0 +1,32 @@
import torch
from torch._C._distributed_c10d import Store
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase, TensorPipeAgent
# This module is defined in torch/csrc/distributed/rpc/testing/init.cpp
class FaultyTensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
    """Test-only options that inject message failures/delays into TensorPipe."""

    def __init__(
        self,
        num_worker_threads: int,
        rpc_timeout: float,
        init_method: str,
        messages_to_fail: list[str],
        messages_to_delay: dict[str, float],
        num_fail_sends: int,
    ) -> None: ...

    num_send_recv_threads: int
    messages_to_fail: list[str]
    # Map of message type -> artificial delay in seconds.
    messages_to_delay: dict[str, float]
    num_fail_sends: int

class FaultyTensorPipeAgent(TensorPipeAgent):
    """Test-only agent honoring FaultyTensorPipeRpcBackendOptions."""

    def __init__(
        self,
        store: Store,
        name: str,
        rank: int,
        world_size: int,
        options: FaultyTensorPipeRpcBackendOptions,
        reverse_device_maps: dict[str, dict[torch.device, torch.device]],
        devices: list[torch.device],
    ) -> None: ...

View File

@ -0,0 +1,11 @@
from typing import AnyStr
from torch import Tensor
class UndefinedGrad:
    """Callable autograd node producing undefined gradients for its inputs."""

    def __init__(self) -> None: ...
    def __call__(self, *inputs: Tensor) -> list[Tensor]: ...

class DelayedError:
    """Autograd node that defers raising ``msg`` until backward runs."""

    def __init__(self, msg: AnyStr, num_inputs: int) -> None: ...
    def __call__(self, inputs: list[Tensor]) -> list[Tensor]: ...

View File

@ -0,0 +1,83 @@
# mypy: allow-untyped-defs
from enum import Enum
from torch import Tensor
# Defined in torch/csrc/functorch/init.cpp

def _set_dynamic_layer_keys_included(included: bool) -> None: ...
def get_unwrapped(tensor: Tensor) -> Tensor: ...

# --- Wrapper-type predicates --------------------------------------------------
def is_batchedtensor(tensor: Tensor) -> bool: ...
def is_functionaltensor(tensor: Tensor) -> bool: ...
def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ...
def is_gradtrackingtensor(tensor: Tensor) -> bool: ...
def is_legacy_batchedtensor(tensor: Tensor) -> bool: ...

def maybe_get_bdim(tensor: Tensor) -> int: ...
def maybe_get_level(tensor: Tensor) -> int: ...
def maybe_current_level() -> int | None: ...
def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
def _unwrap_batched(tensor: Tensor, level: int) -> tuple[Tensor, int | None]: ...
def current_level() -> int: ...
def count_jvp_interpreters() -> int: ...
def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
def get_single_level_autograd_function_allowed() -> bool: ...
def _unwrap_functional_tensor(tensor: Tensor, reapply_views: bool) -> Tensor: ...
def _wrap_functional_tensor(tensor: Tensor, level: int) -> Tensor: ...

# --- Transform nesting counters (return the new nesting level) -----------------
def _vmap_increment_nesting(batch_size: int, randomness: str) -> int: ...
def _vmap_decrement_nesting() -> int: ...
def _grad_increment_nesting() -> int: ...
def _grad_decrement_nesting() -> int: ...
def _jvp_increment_nesting() -> int: ...
def _jvp_decrement_nesting() -> int: ...
# Defined in aten/src/ATen/functorch/Interpreter.h
class TransformType(Enum):
    """Kind of functorch transform an interpreter implements (Interpreter.h)."""

    Torch: TransformType = ...
    Vmap: TransformType = ...
    Grad: TransformType = ...
    Jvp: TransformType = ...
    Functionalize: TransformType = ...
class RandomnessType(Enum):
    """Randomness policy values for vmap (cf. the ``randomness`` argument
    of ``_vmap_increment_nesting``).

    The members were previously annotated as ``TransformType`` — a
    copy/paste error from the enum above; plain assignments also match
    the other enum stubs in this file.
    """

    Error = ...
    Same = ...
    Different = ...
class CInterpreter:
    """Base stub for one entry of the functorch interpreter stack."""

    def key(self) -> TransformType: ...
    def level(self) -> int: ...
class CGradInterpreterPtr:
    """Stub for the grad-transform interpreter pointer."""

    def __init__(self, interpreter: CInterpreter) -> None: ...
    # The parameter was previously named `Tensor`, shadowing the imported
    # type; the binding takes a single tensor positionally.
    def lift(self, tensor: Tensor) -> Tensor: ...
    def prevGradMode(self) -> bool: ...
class CJvpInterpreterPtr:
    """Stub for the jvp-transform interpreter pointer."""

    def __init__(self, interpreter: CInterpreter) -> None: ...
    # The parameter was previously named `Tensor`, shadowing the imported
    # type; the binding takes a single tensor positionally.
    def lift(self, tensor: Tensor) -> Tensor: ...
    def prevFwdGradMode(self) -> bool: ...
class CFunctionalizeInterpreterPtr:
    """Stub for the functionalize-transform interpreter pointer."""

    def __init__(self, interpreter: CInterpreter) -> None: ...
    def key(self) -> TransformType: ...
    def level(self) -> int: ...
    def functionalizeAddBackViews(self) -> bool: ...

class CVmapInterpreterPtr:
    """Stub for the vmap-transform interpreter pointer."""

    def __init__(self, interpreter: CInterpreter) -> None: ...
    def key(self) -> TransformType: ...
    def level(self) -> int: ...
    def batchSize(self) -> int: ...
    def randomness(self) -> RandomnessType: ...
class DynamicLayer: ...

# Introspection/manipulation of the global dynamic-layer (interpreter) stack.
def get_dynamic_layer_stack_depth() -> int: ...
def get_interpreter_stack() -> list[CInterpreter]: ...
def peek_interpreter_stack() -> CInterpreter: ...
def pop_dynamic_layer_stack() -> DynamicLayer: ...
# Previously declared as `pop_dynamic_layer_stack_and_undo_to_depth(int)`,
# i.e. a parameter literally named `int`; give it a real name/annotation.
def pop_dynamic_layer_stack_and_undo_to_depth(depth: int) -> None: ...
def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ...

View File

@ -0,0 +1,4 @@
# Defined in torch/csrc/instruction_counter/Module.cpp
# Begin an instruction-count measurement; returns an id to pass to `end`.
def start() -> int: ...
# NOTE(review): `id` shadows the builtin, but it is kept because renaming
# it would break keyword calls against the binding — confirm before changing.
def end(id: int) -> int: ...

View File

@ -0,0 +1,5 @@
# Defined in torch/csrc/itt.cpp
# Intel ITT (Instrumentation and Tracing Technology) bindings.
# NOTE(review): `-> None` looks wrong for an availability probe — it
# presumably returns bool; confirm against the C++ binding.
def is_available() -> None: ...
def rangePush(message: str) -> None: ...
def rangePop() -> None: ...
def mark(message: str) -> None: ...

View File

@ -0,0 +1,27 @@
# mypy: allow-untyped-defs
from torch import Tensor
# defined in torch/csrc/lazy/python/init.cpp
# Barrier for the lazy graph executor — presumably flushes pending work on
# `device`/`devices`, optionally waiting for completion (see init.cpp).
def _mark_step(device: str, devices: list[str], wait: bool): ...
def _wait_device_ops(devices: list[str]): ...

# Debug metrics/counters maintained by the lazy backend.
def _reset_metrics(): ...
def _counter_names() -> list[str]: ...
def _counter_value(name: str) -> int: ...
def _metrics_report() -> str: ...

def _get_graph_hash(tensors: list[Tensor]) -> str: ...
def _sync_multi(
    tensors: list[Tensor],
    devices: list[str],
    wait: bool = True,
    sync_ltc_data: bool = True,
): ...
def _get_tensor_id(tensor: Tensor) -> int: ...

# Debug dumps of the IR backing the given lazy tensors.
def _get_tensors_text(tensors: list[Tensor]) -> str: ...
def _get_tensors_dot(tensors: list[Tensor]) -> str: ...
def _get_tensors_backend(tensors: list[Tensor]) -> str: ...

# Eager-fallback control (value semantics defined by the C++ side).
def _get_force_fallback() -> str: ...
def _set_force_fallback(newval: str): ...
def _clear_ir_cache(): ...
def _dump_ir_cache(filename: str): ...
def _set_reuse_ir(val: bool): ...
def _get_default_device_type(): ...

View File

@ -0,0 +1,12 @@
# mypy: allow-untyped-defs
# defined in torch/csrc/lazy/python/init.cpp
from typing import Any
from torch import Tensor
# One-time initialization of the TorchScript lazy backend.
def _init(): ...
# Per-tensor ids plus the values of the TS device-data nodes backing
# `tensors` — presumably IValues; confirm against init.cpp.
def _get_tensors_ts_device_data_node(
    tensors: list[Tensor],
) -> tuple[list[int], list[Any]]: ...
# Execute a previously compiled graph identified by `hash_str`.
def _run_cached_graph(hash_str: str, graph_inputs: list[Any]) -> list[Tensor]: ...

View File

@ -0,0 +1,44 @@
# Defined in torch/csrc/monitor/python_init.cpp
import datetime
from enum import Enum
from typing import Callable
class Aggregation(Enum):
    """Reduction a ``Stat`` applies to the samples in its window."""

    VALUE = ...
    MEAN = ...
    COUNT = ...
    SUM = ...
    MAX = ...
    MIN = ...

class Stat:
    """Windowed statistic: values added via ``add`` are reduced per the
    requested ``Aggregation``s and read back via ``get``."""

    # Stat name and number of samples seen (exposed by the C++ object).
    name: str
    count: int
    def __init__(
        self,
        name: str,
        aggregations: list[Aggregation],
        window_size: int,
        max_samples: int = -1,
    ) -> None: ...
    def add(self, v: float) -> None: ...
    def get(self) -> dict[Aggregation, float]: ...

class Event:
    """A named, timestamped event carrying a flat dict of scalar data."""

    name: str
    timestamp: datetime.datetime
    data: dict[str, int | float | bool | str]
    def __init__(
        self,
        name: str,
        timestamp: datetime.datetime,
        data: dict[str, int | float | bool | str],
    ) -> None: ...

# Deliver `e` to every registered event handler.
def log_event(e: Event) -> None: ...

class EventHandlerHandle: ...

# Register/unregister a Python callback invoked for each logged Event.
def register_event_handler(handler: Callable[[Event], None]) -> EventHandlerHandle: ...
def unregister_event_handler(handle: EventHandlerHandle) -> None: ...

View File

@ -0,0 +1,89 @@
# @generated by tools/pyi/gen_pyi.py from torch/_C/_nn.pyi.in
# mypy: disable-error-code="type-arg"
from typing import List, Literal, Optional, overload, Sequence, Tuple, Union
from torch import memory_format, Tensor
from torch.types import _bool, _device, _dtype, _int, _size
# Defined in tools/autograd/templates/python_nn_functions.cpp
# Pooling primitives; the max variants return (values, indices).
def adaptive_max_pool2d(input: Tensor, output_size: Union[_int, _size]) -> Tuple[Tensor, Tensor]: ...
def adaptive_max_pool3d(input: Tensor, output_size: Union[_int, _size]) -> Tuple[Tensor, Tensor]: ...
def avg_pool2d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> Tensor: ...
def avg_pool3d(input: Tensor, kernel_size: Union[_int, _size], stride: Optional[Union[_int, _size]] = None, padding: Union[_int, _size] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> Tensor: ...
# Activations (trailing underscore = in-place variant, per torch convention).
def elu_(input: Tensor, alpha: float = ...) -> Tensor: ...
def fractional_max_pool2d(input: Tensor, kernel_size: Union[_int, _size], output_size: Union[_int, _size], _random_samples: Tensor) -> Tuple[Tensor, Tensor]: ...
def fractional_max_pool3d(input: Tensor, kernel_size: Union[_int, _size], output_size: Union[_int, _size], _random_samples: Tensor) -> Tuple[Tensor, Tensor]: ...
def gelu(input: Tensor, approximate: str = ...) -> Tensor: ...
def hardsigmoid(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ...
def hardtanh(input: Tensor, min_val: float = ..., max_val: float = ..., *, out: Optional[Tensor] = None) -> Tensor: ...
def hardtanh_(input: Tensor, min_val: float = ..., max_val: float = ...) -> Tensor: ...
def leaky_relu(input: Tensor, negative_slope: float = ..., *, out: Optional[Tensor] = None) -> Tensor: ...
def leaky_relu_(input: Tensor, negative_slope: float = ...) -> Tensor: ...
def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: ...
def log_sigmoid(input: Tensor) -> Tensor: ...
def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ...
def pad(input: Tensor, pad: Sequence[int], mode: str = ..., value: Optional[float] = None) -> Tensor: ...
def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False) -> Tensor: ...
def softplus(input: Tensor, beta: float = ..., threshold: float = ...) -> Tensor: ...
def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ...
# Defined in aten/src/ATen/native/mkldnn/Linear.cpp
def mkldnn_linear(input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: ...
# Defined at aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
# Reorder conv weights into MKLDNN's preferred blocked layout.
# NOTE(review): `dilatation` spelling is kept as-is — it must match the
# binding's parameter name for keyword calls; confirm before renaming.
def mkldnn_reorder_conv2d_weight(
    self: Tensor,
    padding: List,
    stride: List,
    dilatation: List,
    groups: int,
) -> Tensor: ...
def mkldnn_reorder_conv3d_weight(
    self: Tensor,
    padding: List,
    stride: List,
    dilatation: List,
    groups: int,
) -> Tensor: ...
# Defined in aten/src/ATen/native/mkldnn/Prelu.cpp
def mkldnn_prelu(input: Tensor, weight: Tensor) -> Tensor: ...
# Defined at tools/autograd/templates/python_nn_functions.cpp
# `_parse_to` normalizes the argument forms accepted by `Tensor.to` /
# `Module.to` into a (device, dtype, non_blocking, memory_format) tuple.
# Three overloads: explicit device+dtype, dtype only, or "like this tensor".
@overload
def _parse_to(
    device: _device,
    dtype: _dtype,
    non_blocking: _bool,
    copy: _bool,
    *,
    memory_format: memory_format,
) -> Tuple[_device, _dtype, _bool, memory_format]: ...
@overload
def _parse_to(
    dtype: _dtype,
    non_blocking: _bool,
    copy: _bool,
    *,
    memory_format: memory_format,
) -> Tuple[_device, _dtype, _bool, memory_format]: ...
@overload
def _parse_to(
    tensor: Tensor,
    non_blocking: _bool,
    copy: _bool,
    *,
    memory_format: memory_format,
) -> Tuple[_device, _dtype, _bool, memory_format]: ...

# Defined in aten/src/ATen/native/PackedSequence.cpp
def pad_sequence(
    sequences: Union[List[Tensor], Tuple[Tensor, ...]],
    batch_first: bool = False,
    padding_value: float = 0.0,
    padding_side: Union[Literal["left", "right"], str] = "right",
) -> Tensor: ...

# Pack a list of tensors into one flat buffer / split it back apart.
def flatten_dense_tensors(tensors: List[Tensor]) -> Tensor: ...
def unflatten_dense_tensors(flat: Tensor, tensors: List[Tensor]) -> List[Tensor]: ...

View File

@ -0,0 +1,7 @@
# mypy: allow-untyped-defs
# Defined in torch/csrc/cuda/shared/nvtx.cpp
# NVTX range/marker bindings (nvtx.cpp). Push/Pop manage a nested range on
# the current thread; Start/End correlate via the returned range id.
def rangePushA(message: str) -> int: ...
def rangePop() -> int: ...
def rangeStartA(message: str) -> int: ...
# Previously declared as `rangeEnd(int)` — a parameter literally named
# `int`; the argument is the id returned by `rangeStartA`.
def rangeEnd(range_id: int) -> None: ...
def markA(message: str) -> None: ...

View File

@ -0,0 +1,39 @@
# Defined in torch/csrc/onnx/init.cpp
from enum import Enum
# Producer version string used by the ONNX exporter.
PRODUCER_VERSION: str

class TensorProtoDataType(Enum):
    """Mirror of ONNX ``TensorProto.DataType`` scalar-type codes."""

    UNDEFINED = ...
    FLOAT = ...
    UINT8 = ...
    INT8 = ...
    UINT16 = ...
    INT16 = ...
    INT32 = ...
    INT64 = ...
    STRING = ...
    BOOL = ...
    FLOAT16 = ...
    DOUBLE = ...
    UINT32 = ...
    UINT64 = ...
    COMPLEX64 = ...
    COMPLEX128 = ...
    BFLOAT16 = ...
    FLOAT8E5M2 = ...
    FLOAT8E4M3FN = ...
    FLOAT8E5M2FNUZ = ...
    FLOAT8E4M3FNUZ = ...

class OperatorExportTypes(Enum):
    """How operators are emitted during export (pure ONNX vs ATen fallbacks)."""

    ONNX = ...
    ONNX_ATEN = ...
    ONNX_ATEN_FALLBACK = ...
    ONNX_FALLTHROUGH = ...

class TrainingMode(Enum):
    """Training/eval mode the model is exported in."""

    EVAL = ...
    PRESERVE = ...
    TRAINING = ...

View File

@ -0,0 +1,244 @@
from enum import Enum
from typing import Any, Literal
from typing_extensions import TypeAlias
from torch._C import device, dtype, layout
# defined in torch/csrc/profiler/python/init.cpp
class RecordScope(Enum):
    """Category of the scope that produced a RECORD_FUNCTION event."""

    FUNCTION = ...
    BACKWARD_FUNCTION = ...
    TORCHSCRIPT_FUNCTION = ...
    KERNEL_FUNCTION_DTYPE = ...
    CUSTOM_CLASS = ...
    BUILD_FEATURE = ...
    LITE_INTERPRETER = ...
    USER_SCOPE = ...
    STATIC_RUNTIME_OP = ...
    STATIC_RUNTIME_MODEL = ...

class ProfilerState(Enum):
    """Which profiler backend/mode is active."""

    Disable = ...
    CPU = ...
    CUDA = ...
    NVTX = ...
    ITT = ...
    KINETO = ...
    KINETO_GPU_FALLBACK = ...
    KINETO_PRIVATEUSE1_FALLBACK = ...
    KINETO_PRIVATEUSE1 = ...

class ActiveProfilerType(Enum):
    """Profiler implementation currently installed, if any."""

    NONE = ...
    LEGACY = ...
    KINETO = ...
    NVTX = ...
    ITT = ...

class ProfilerActivity(Enum):
    """Device activity types a profiler run can record."""

    CPU = ...
    CUDA = ...
    XPU = ...
    MTIA = ...
    PrivateUse1 = ...

class _EventType(Enum):
    """Discriminator for ``_ProfilerEvent.extra_fields`` / ``.typed``."""

    TorchOp = ...
    Backend = ...
    Allocation = ...
    OutOfMemory = ...
    PyCall = ...
    PyCCall = ...
    Kineto = ...
class _ExperimentalConfig:
    """Opt-in profiler features that are not yet stable API."""

    def __init__(
        self,
        profiler_metrics: list[str] = ...,
        profiler_measure_per_kernel: bool = ...,
        verbose: bool = ...,
        performance_events: list[str] = ...,
        enable_cuda_sync_events: bool = ...,
    ) -> None: ...

class ProfilerConfig:
    """Bundle of options configuring one profiler run."""

    def __init__(
        self,
        state: ProfilerState,
        report_input_shapes: bool,
        profile_memory: bool,
        with_stack: bool,
        with_flops: bool,
        with_modules: bool,
        experimental_config: _ExperimentalConfig,
    ) -> None: ...
class _ProfilerEvent:
    """One node in the profiler's event tree (torch/csrc/profiler)."""

    start_tid: int
    start_time_ns: int
    children: list[_ProfilerEvent]
    # TODO(robieta): remove in favor of `self.typed`
    extra_fields: (
        _ExtraFields_TorchOp
        | _ExtraFields_Backend
        | _ExtraFields_Allocation
        | _ExtraFields_OutOfMemory
        | _ExtraFields_PyCall
        | _ExtraFields_PyCCall
        | _ExtraFields_Kineto
    )
    # Tag + payload as a discriminated pair so type checkers can narrow
    # the extra-fields type from the tag.
    @property
    def typed(
        self,
    ) -> (
        tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp]
        | tuple[Literal[_EventType.Backend], _ExtraFields_Backend]
        | tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation]
        | tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory]
        | tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall]
        | tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall]
        | tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto]
    ): ...
    @property
    def name(self) -> str: ...
    @property
    def tag(self) -> _EventType: ...
    @property
    def id(self) -> int: ...
    @property
    def parent(self) -> _ProfilerEvent | None: ...
    @property
    def correlation_id(self) -> int: ...
    @property
    def end_time_ns(self) -> int: ...
    @property
    def duration_time_ns(self) -> int: ...
class _TensorMetadata:
    """Lightweight record of a tensor observed by the profiler (no data)."""

    impl_ptr: int | None
    storage_data_ptr: int | None
    id: int | None
    @property
    def allocation_id(self) -> int | None: ...
    @property
    def layout(self) -> layout: ...
    @property
    def device(self) -> device: ...
    @property
    def dtype(self) -> dtype: ...
    @property
    def sizes(self) -> list[int]: ...
    @property
    def strides(self) -> list[int]: ...

# Python scalars the profiler can record as op inputs.
Scalar: TypeAlias = int | float | bool | complex
# One recorded op input: tensor metadata, a list of them, a scalar, or None.
Input: TypeAlias = _TensorMetadata | list[_TensorMetadata] | Scalar | None

class _ExtraFields_TorchOp:
    """Extra payload for ``_EventType.TorchOp`` events."""

    name: str
    sequence_number: int
    allow_tf32_cublas: bool
    @property
    def inputs(self) -> list[Input]: ...
    @property
    def scope(self) -> RecordScope: ...
class _ExtraFields_Backend: ...

class _ExtraFields_Allocation:
    """Extra payload for ``_EventType.Allocation`` events."""

    ptr: int
    id: int | None
    alloc_size: int
    total_allocated: int
    total_reserved: int
    @property
    def allocation_id(self) -> int | None: ...
    @property
    def device(self) -> device: ...

class _ExtraFields_OutOfMemory: ...

class _PyFrameState:
    """Location of a Python frame: file, function and line."""

    line_number: int
    function_name: str
    @property
    def file_name(self) -> str: ...
class _NNModuleInfo:
    """Identity and parameters of the ``nn.Module`` tied to a PyCall event."""

    @property
    def self_ptr(self) -> int: ...
    @property
    def cls_ptr(self) -> int: ...
    @property
    def cls_name(self) -> str: ...
    @property
    def parameters(
        self,
    ) -> list[tuple[str, _TensorMetadata, _TensorMetadata | None]]: ...

class _OptimizerInfo:
    """Snapshot of an optimizer's parameters recorded by the profiler."""

    @property
    def parameters(
        self,
    ) -> list[
        tuple[
            # Parameter
            _TensorMetadata,
            #
            # Gradient (if present during optimizer.step())
            _TensorMetadata | None,
            #
            # Optimizer state for Parameter as (name, tensor) pairs
            list[tuple[str, _TensorMetadata]],
        ]
    ]: ...
class _ExtraFields_PyCCall:
    """Extra payload for ``_EventType.PyCCall`` (C function call) events."""

    @property
    def caller(self) -> _PyFrameState: ...

class _ExtraFields_PyCall:
    """Extra payload for ``_EventType.PyCall`` events."""

    @property
    def callsite(self) -> _PyFrameState: ...
    @property
    def caller(self) -> _PyFrameState: ...
    @property
    def module(self) -> _NNModuleInfo | None: ...
    @property
    def optimizer(self) -> _OptimizerInfo | None: ...

class _ExtraFields_Kineto: ...

# Execution-trace observer: registration returns success; the observer is
# then enabled/disabled around the region of interest.
def _add_execution_trace_observer(output_file_path: str) -> bool: ...
def _remove_execution_trace_observer() -> None: ...
def _enable_execution_trace_observer() -> None: ...
def _disable_execution_trace_observer() -> None: ...
# Global toggles for what the profiler records.
def _set_record_concrete_inputs_enabled_val(val: bool) -> None: ...
def _set_fwd_bwd_enabled_val(val: bool) -> None: ...
def _set_cuda_sync_enabled_val(val: bool) -> None: ...
class CapturedTraceback: ...

# Capture the current Python/TorchScript/C++ stacks (each opt-in via flags).
def gather_traceback(python: bool, script: bool, cpp: bool) -> CapturedTraceback: ...
# Resolve captured tracebacks into frame dicts with keys: name, filename, line.
def symbolize_tracebacks(
    to_symbolize: list[CapturedTraceback],
) -> list[list[dict[str, str]]]: ...

class _RecordFunctionFast:
    """Context manager that emits a profiler record-function event
    (fast variant)."""

    def __init__(
        self,
        name: str,
        input_values: list | tuple | None = None,
        keyword_values: dict | None = None,
    ) -> None: ...
    def __enter__(self) -> None: ...
    def __exit__(self, *args: Any) -> None: ...

View File

@ -0,0 +1,3 @@
# Defined in torch/csrc/utils/verbose.cpp
# Toggle oneMKL / oneDNN (MKLDNN) verbose kernel logging; returns an int status.
def mkl_set_verbose(enable: int) -> int: ...
def mkldnn_set_verbose(level: int) -> int: ...

View File

@ -0,0 +1,31 @@
"""
This makes the functions in torch._C._VariableFunctions available as
torch._VF.<funcname>
without mypy being able to find them.
A subset of those functions are mapped to ATen functions in
torch/jit/_builtins.py
See https://github.com/pytorch/pytorch/issues/21478 for the reason for
introducing torch._VF
"""
import sys
import types
import torch
class VFModule(types.ModuleType):
    """Proxy module forwarding attribute access to torch._C._VariableFunctions."""

    # The C extension namespace that actually implements the functions.
    vf: types.ModuleType

    def __init__(self, name: str):
        super().__init__(name)
        self.vf = torch._C._VariableFunctions

    def __getattr__(self, name: str) -> object:
        # Only invoked for names not found on the module instance itself.
        return getattr(self.vf, name)

# Install the proxy in place of this module so `torch._VF.<fn>` resolves.
sys.modules[__name__] = VFModule(__name__)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,23 @@
# mypy: allow-untyped-defs
import torch
def show():
    """Return a human-readable description of PyTorch's build configuration."""
    config_text = torch._C._show_config()
    return config_text
# TODO: In principle, we could provide more structured version/config
# information here. For now only CXX_FLAGS is exposed, as Timer
# uses them.
def _cxx_flags():
"""Returns the CXX_FLAGS used when building PyTorch."""
return torch._C._cxx_flags()
def parallel_info():
    r"""Return a detailed string describing PyTorch's parallelization settings."""
    info = torch._C._parallel_info()
    return info

View File

@ -0,0 +1,75 @@
_overwrite_module_params_on_conversion: bool = False
_swap_module_params_on_conversion: bool = False
def set_overwrite_module_params_on_conversion(value: bool) -> None:
    """Choose whether ``nn.Module`` conversion assigns brand-new parameters.

    When enabled, ``module.{device}()`` (e.g. :meth:`nn.Module.cuda`),
    ``module.{dtype}()`` (e.g. :meth:`nn.Module.float`),
    :meth:`nn.Module.to` and :meth:`nn.Module.to_empty` replace the
    module's parameters with freshly created tensors instead of mutating
    the existing ones in place.

    Args:
        value (bool): ``True`` to assign new tensors on conversion.
    """
    global _overwrite_module_params_on_conversion
    _overwrite_module_params_on_conversion = value


def get_overwrite_module_params_on_conversion() -> bool:
    """Report the flag set by
    :func:`~torch.__future__.set_overwrite_module_params_on_conversion`.

    Defaults to ``False``.
    """
    return _overwrite_module_params_on_conversion
def set_swap_module_params_on_conversion(value: bool) -> None:
    """Choose whether :func:`~torch.utils.swap_tensors` replaces ``.data``
    assignment when converting an ``nn.Module`` and replaces
    ``param.copy_(state_dict[key])`` when loading a state dict.

    .. note::
        This setting takes precedence over
        :func:`~torch.__future__.get_overwrite_module_params_on_conversion`.

    When enabled, in-place swapping is used by ``module.{device}()``
    (e.g. :meth:`nn.Module.cuda`), ``module.{dtype}()``
    (e.g. :meth:`nn.Module.float`), :meth:`nn.Module.to`,
    :meth:`nn.Module.to_empty` and :meth:`nn.Module.load_state_dict`.

    With this set, :meth:`~nn.Module.load_state_dict` behaves as follows:
    each ``state_dict['key']`` is transformed via
    :meth:`~torch.Tensor.module_load`
    (``res = param.module_load(state_dict['key'])``); ``res`` is wrapped
    in an :class:`~nn.Parameter` when necessary; and the module's
    parameter/buffer is exchanged for ``res`` via
    :func:`~torch.utils.swap_tensors`.

    Args:
        value (bool): ``True`` to use :func:`~torch.utils.swap_tensors`.
    """
    global _swap_module_params_on_conversion
    _swap_module_params_on_conversion = value


def get_swap_module_params_on_conversion() -> bool:
    """Report the flag set by
    :func:`~torch.__future__.set_swap_module_params_on_conversion`.

    Defaults to ``False``.
    """
    return _swap_module_params_on_conversion

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,667 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) 2005-2010 ActiveState Software Inc.
# Copyright (c) 2013 Eddy Petrișor
# flake8: noqa
"""
This file is directly from
https://github.com/ActiveState/appdirs/blob/3fe6a83776843a46f20c2e5587afcffe05e03b39/appdirs.py
The license of https://github.com/ActiveState/appdirs copied below:
# This is the MIT license
Copyright (c) 2010 ActiveState Software Inc.
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
"""Utilities for determining application-specific dirs.
See <https://github.com/ActiveState/appdirs> for details and usage.
"""
# Dev Notes:
# - MSDN on where to store app data files:
# http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120
# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
# - XDG spec for Un*x: https://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
__version__ = "1.4.4"
__version_info__ = tuple(int(segment) for segment in __version__.split("."))
import os
import sys
# Python 2 compatibility shim retained from upstream appdirs.
unicode = str

# Normalize the platform to a *sys.platform*-style string; Jython reports
# "java..." so the underlying OS must be recovered from java_ver().
if sys.platform.startswith("java"):
    import platform

    os_name = platform.java_ver()[3][0]
    if os_name.startswith("Windows"):  # "Windows XP", "Windows 7", etc.
        system = "win32"
    elif os_name.startswith("Mac"):  # "Mac OS X", etc.
        system = "darwin"
    else:  # "Linux", "SunOS", "FreeBSD", etc.
        # Setting this to "linux2" is not ideal, but only Windows or Mac
        # are actually checked for and the rest of the module expects
        # *sys.platform* style strings.
        system = "linux2"
else:
    system = sys.platform
def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
    r"""Return the user-specific data directory for *appname*.

    appname: application name; if None the bare system directory is returned.
    appauthor: Windows only — author/company path segment; defaults to
        appname, pass False to omit it.
    version: optional version path segment, applied only when appname is set.
    roaming: Windows only — use the roaming (login-synced) AppData directory.

    Typical values:
        Mac OS X:  ~/Library/Application Support/<AppName>
        Unix:      $XDG_DATA_HOME/<AppName>, default ~/.local/share/<AppName>
        Windows:   C:\Users\<user>\AppData\{Local,Roaming}\<AppAuthor>\<AppName>
    """
    if system == "win32":
        if appauthor is None:
            appauthor = appname
        folder = "CSIDL_APPDATA" if roaming else "CSIDL_LOCAL_APPDATA"
        path = os.path.normpath(_get_win_folder(folder))
        if appname:
            if appauthor is not False:
                path = os.path.join(path, appauthor, appname)
            else:
                path = os.path.join(path, appname)
    elif system == "darwin":
        path = os.path.expanduser("~/Library/Application Support/")
        if appname:
            path = os.path.join(path, appname)
    else:
        # XDG spec: $XDG_DATA_HOME, falling back to ~/.local/share.
        path = os.getenv("XDG_DATA_HOME", os.path.expanduser("~/.local/share"))
        if appname:
            path = os.path.join(path, appname)
    if appname and version:
        path = os.path.join(path, version)
    return path
def site_data_dir(appname=None, appauthor=None, version=None, multipath=False):
    r"""Return the shared (site-wide) data directory for *appname*.

    appname/appauthor/version: as in ``user_data_dir``.
    multipath: *nix only — return all ``$XDG_DATA_DIRS`` entries joined
        with ``os.pathsep`` instead of just the first.

    Typical values:
        Mac OS X:  /Library/Application Support/<AppName>
        Unix:      /usr/local/share/<AppName> (first $XDG_DATA_DIRS entry)
        Windows:   C:\ProgramData\<AppAuthor>\<AppName>

    WARNING: do not use this on Windows — "C:\ProgramData" is a hidden
    *system* directory on Vista and later.
    """
    if system == "win32":
        if appauthor is None:
            appauthor = appname
        path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA"))
        if appname:
            if appauthor is not False:
                path = os.path.join(path, appauthor, appname)
            else:
                path = os.path.join(path, appname)
    elif system == "darwin":
        path = os.path.expanduser("/Library/Application Support")
        if appname:
            path = os.path.join(path, appname)
    else:
        # XDG default for $XDG_DATA_DIRS; version is folded into appname
        # here, so this branch returns without the trailing join below.
        raw = os.getenv(
            "XDG_DATA_DIRS", os.pathsep.join(["/usr/local/share", "/usr/share"])
        )
        dirs = [os.path.expanduser(d.rstrip(os.sep)) for d in raw.split(os.pathsep)]
        if appname:
            if version:
                appname = os.path.join(appname, version)
            dirs = [os.sep.join([d, appname]) for d in dirs]
        return os.pathsep.join(dirs) if multipath else dirs[0]
    if appname and version:
        path = os.path.join(path, version)
    return path
def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
    r"""Return the user-specific config directory for *appname*.

    appname/appauthor/version/roaming: as in ``user_data_dir``.

    Typical values:
        Mac OS X:  ~/Library/Preferences/<AppName>
        Unix:      $XDG_CONFIG_HOME/<AppName>, default ~/.config/<AppName>
        Windows:   same as ``user_data_dir``
    """
    if system == "win32":
        # Windows keeps configuration alongside user data.
        path = user_data_dir(appname, appauthor, None, roaming)
    elif system == "darwin":
        path = os.path.expanduser("~/Library/Preferences/")
        if appname:
            path = os.path.join(path, appname)
    else:
        path = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
        if appname:
            path = os.path.join(path, appname)
    if appname and version:
        path = os.path.join(path, version)
    return path
def site_config_dir(appname=None, appauthor=None, version=None, multipath=False):
    r"""Return the site-wide config directory for *appname*.

    appname/appauthor/version: as in ``user_data_dir``.
    multipath: *nix only — return every ``$XDG_CONFIG_DIRS`` entry joined
        with ``os.pathsep`` instead of just the first.

    Typical values:
        Mac OS X:  /Library/Preferences/<AppName>
        Unix:      /etc/xdg/<AppName> (or each $XDG_CONFIG_DIRS entry)
        Windows:   same as ``site_data_dir``

    WARNING: do not use this on Windows — see the note on ``site_data_dir``.
    """
    if system == "win32":
        path = site_data_dir(appname, appauthor)
        if appname and version:
            path = os.path.join(path, version)
    elif system == "darwin":
        # NOTE(review): upstream appdirs never appends `version` on macOS
        # here; preserved as-is.
        path = os.path.expanduser("/Library/Preferences")
        if appname:
            path = os.path.join(path, appname)
    else:
        # XDG default for $XDG_CONFIG_DIRS; version is folded into appname.
        raw = os.getenv("XDG_CONFIG_DIRS", "/etc/xdg")
        dirs = [os.path.expanduser(d.rstrip(os.sep)) for d in raw.split(os.pathsep)]
        if appname:
            if version:
                appname = os.path.join(appname, version)
            dirs = [os.sep.join([d, appname]) for d in dirs]
        path = os.pathsep.join(dirs) if multipath else dirs[0]
    return path
def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
    r"""Return the user-specific cache directory for *appname*.

    appname/appauthor/version: as in ``user_data_dir``.
    opinion: Windows only — when True, append "Cache" to the app data
        directory, since Windows has no dedicated cache location and apps
        conventionally cache *under* their data dir.

    Typical values:
        Mac OS X:  ~/Library/Caches/<AppName>
        Unix:      $XDG_CACHE_HOME/<AppName>, default ~/.cache/<AppName>
        Windows:   C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
    """
    if system == "win32":
        if appauthor is None:
            appauthor = appname
        path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
        if appname:
            if appauthor is False:
                path = os.path.join(path, appname)
            else:
                path = os.path.join(path, appauthor, appname)
            if opinion:
                path = os.path.join(path, "Cache")
    elif system == "darwin":
        path = os.path.expanduser("~/Library/Caches")
        if appname:
            path = os.path.join(path, appname)
    else:
        path = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
        if appname:
            path = os.path.join(path, appname)
    if appname and version:
        path = os.path.join(path, version)
    return path
def user_state_dir(appname=None, appauthor=None, version=None, roaming=False):
    r"""Return full path to the user-specific state dir for this application.

    "appname" is the name of application.
        If None, just the system directory is returned.
    "appauthor" (only used on Windows) is the name of the
        appauthor or distributing body for this application. Typically
        it is the owning company name. This falls back to appname. You may
        pass False to disable it.
    "version" is an optional version path element to append to the
        path. You might want to use this if you want multiple versions
        of your app to be able to run independently. If used, this
        would typically be "<major>.<minor>".
        Only applied when appname is present.
    "roaming" (boolean, default False) can be set True to use the Windows
        roaming appdata directory, so the data is synced on login for users
        on a roaming-profile network. See
        <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
        for a discussion of issues.

    Typical user state directories are:
        Mac OS X:  same as user_data_dir
        Unix:      ~/.local/state/<AppName>   # or in $XDG_STATE_HOME, if defined
        Win *:     same as user_data_dir

    For Unix, we follow this Debian proposal
    <https://wiki.debian.org/XDGBaseDirectorySpecification#state>
    to extend the XDG spec and support $XDG_STATE_HOME.
    That means, by default "~/.local/state/<AppName>".
    """
    # Windows and macOS have no distinct "state" location; reuse the data dir.
    if system == "win32" or system == "darwin":
        path = user_data_dir(appname, appauthor, None, roaming)
    else:
        # XDG state dir, falling back to the spec's default location.
        base = os.getenv("XDG_STATE_HOME", os.path.expanduser("~/.local/state"))
        path = os.path.join(base, appname) if appname else base
    if version and appname:
        path = os.path.join(path, version)
    return path
def user_log_dir(appname=None, appauthor=None, version=None, opinion=True):
    r"""Return full path to the user-specific log dir for this application.

    "appname" is the name of application.
        If None, just the system directory is returned.
    "appauthor" (only used on Windows) is the name of the
        appauthor or distributing body for this application. Typically
        it is the owning company name. This falls back to appname. You may
        pass False to disable it.
    "version" is an optional version path element to append to the
        path. You might want to use this if you want multiple versions
        of your app to be able to run independently. If used, this
        would typically be "<major>.<minor>".
        Only applied when appname is present.
    "opinion" (boolean) can be False to disable the appending of
        "Logs" to the base app data dir for Windows, and "log" to the
        base cache dir for Unix. See discussion below.

    Typical user log directories are:
        Mac OS X:   ~/Library/Logs/<AppName>
        Unix:       ~/.cache/<AppName>/log  # or under $XDG_CACHE_HOME if defined
        Win XP:     C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Logs
        Vista:      C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Logs

    On Windows the only suggestion in the MSDN docs is that local settings
    go in the `CSIDL_LOCAL_APPDATA` directory.

    OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA`
    value for Windows and appends "log" to the user cache dir for Unix.
    This can be disabled with the `opinion=False` option.
    """
    if system == "darwin":
        # Bug fix: the original unconditionally joined `appname`, which
        # raised TypeError when appname is None and contradicted the
        # documented "If None, just the system directory is returned".
        path = os.path.expanduser("~/Library/Logs")
        if appname:
            path = os.path.join(path, appname)
    elif system == "win32":
        path = user_data_dir(appname, appauthor, version)
        version = False  # version (if any) is already embedded in `path`
        if opinion:
            path = os.path.join(path, "Logs")
    else:
        path = user_cache_dir(appname, appauthor, version)
        version = False  # version (if any) is already embedded in `path`
        if opinion:
            path = os.path.join(path, "log")
    if appname and version:
        path = os.path.join(path, version)
    return path
class AppDirs(object):
    """Convenience wrapper bundling the module-level ``*_dir`` functions.

    Stores appname/appauthor/version/roaming/multipath once and exposes each
    directory as a read-only property.
    """

    def __init__(self, appname=None, appauthor=None, version=None,
                 roaming=False, multipath=False):
        self.appname = appname
        self.appauthor = appauthor
        self.version = version
        self.roaming = roaming
        self.multipath = multipath

    @property
    def user_data_dir(self):
        """Per-user data directory."""
        return user_data_dir(self.appname, self.appauthor,
                             version=self.version, roaming=self.roaming)

    @property
    def site_data_dir(self):
        """Shared (site-wide) data directory."""
        return site_data_dir(self.appname, self.appauthor,
                             version=self.version, multipath=self.multipath)

    @property
    def user_config_dir(self):
        """Per-user config directory."""
        return user_config_dir(self.appname, self.appauthor,
                               version=self.version, roaming=self.roaming)

    @property
    def site_config_dir(self):
        """Shared (site-wide) config directory."""
        return site_config_dir(self.appname, self.appauthor,
                               version=self.version, multipath=self.multipath)

    @property
    def user_cache_dir(self):
        """Per-user cache directory."""
        return user_cache_dir(self.appname, self.appauthor,
                              version=self.version)

    @property
    def user_state_dir(self):
        """Per-user state directory."""
        return user_state_dir(self.appname, self.appauthor,
                              version=self.version)

    @property
    def user_log_dir(self):
        """Per-user log directory."""
        return user_log_dir(self.appname, self.appauthor,
                            version=self.version)
# ---- internal support stuff
def _get_win_folder_from_registry(csidl_name):
    """Fallback: read the shell-folder path from the Windows registry.

    This is a fallback technique at best: using the registry is not
    guaranteed to give the correct answer for all CSIDL_* names.
    """
    import winreg as _winreg

    # Map CSIDL constant name to the registry value name (KeyError for
    # unsupported constants, same as before).
    shell_folder_name = {
        "CSIDL_APPDATA": "AppData",
        "CSIDL_COMMON_APPDATA": "Common AppData",
        "CSIDL_LOCAL_APPDATA": "Local AppData",
    }[csidl_name]

    key = _winreg.OpenKey(
        _winreg.HKEY_CURRENT_USER,
        r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
    )
    folder, _value_type = _winreg.QueryValueEx(key, shell_folder_name)
    return folder
def _get_win_folder_with_pywin32(csidl_name):
    """Resolve a CSIDL folder path via the pywin32 shell bindings."""
    from win32com.shell import shell, shellcon
    dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0)
    # Try to make this a unicode path because SHGetFolderPath does
    # not return unicode strings when there is unicode data in the
    # path.
    try:
        # Bug fix: `unicode` is the Python 2 builtin; on Python 3 (this file
        # uses f-strings elsewhere) it raised NameError, which the
        # `except UnicodeError` below did not catch. `str` is the Python 3
        # equivalent and is identical where `unicode` was aliased to `str`.
        dir = str(dir)

        # Downgrade to short path name if have highbit chars. See
        # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
        has_high_char = False
        for c in dir:
            if ord(c) > 255:
                has_high_char = True
                break
        if has_high_char:
            try:
                import win32api
                dir = win32api.GetShortPathName(dir)
            except ImportError:
                pass
    except UnicodeError:
        pass
    return dir
def _get_win_folder_with_ctypes(csidl_name):
    """Resolve a CSIDL folder path via ctypes + SHGetFolderPathW."""
    import ctypes

    # CSIDL numeric constants (KeyError for unsupported names, as before).
    csidl_const = {
        "CSIDL_APPDATA": 26,
        "CSIDL_COMMON_APPDATA": 35,
        "CSIDL_LOCAL_APPDATA": 28,
    }[csidl_name]

    buf = ctypes.create_unicode_buffer(1024)
    ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf)

    # Downgrade to the short path name if the result contains high-bit
    # characters. See <http://bugs.activestate.com/show_bug.cgi?id=85099>.
    if any(ord(ch) > 255 for ch in buf):
        short_buf = ctypes.create_unicode_buffer(1024)
        if ctypes.windll.kernel32.GetShortPathNameW(buf.value, short_buf, 1024):
            buf = short_buf

    return buf.value
def _get_win_folder_with_jna(csidl_name):
    """Resolve a CSIDL folder path via JNA (Jython running on the JVM).

    Uses the com.sun.jna.platform.win32 bindings; ``array.zeros`` is the
    Jython-specific constructor for a zero-filled Java array.
    """
    import array
    from com.sun import jna
    from com.sun.jna.platform import win32

    # MAX_PATH UTF-16 code units, 2 bytes each.
    buf_size = win32.WinDef.MAX_PATH * 2
    buf = array.zeros("c", buf_size)
    shell = win32.Shell32.INSTANCE
    shell.SHGetFolderPath(
        None,
        getattr(win32.ShlObj, csidl_name),
        None,
        win32.ShlObj.SHGFP_TYPE_CURRENT,
        buf,
    )
    # Convert the raw Java char buffer to a Python string, dropping NUL padding.
    dir = jna.Native.toString(buf.tostring()).rstrip("\0")

    # Downgrade to short path name if have highbit chars. See
    # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
    has_high_char = False
    for c in dir:
        if ord(c) > 255:
            has_high_char = True
            break
    if has_high_char:
        buf = array.zeros("c", buf_size)
        kernel = win32.Kernel32.INSTANCE
        if kernel.GetShortPathName(dir, buf, buf_size):
            dir = jna.Native.toString(buf.tostring()).rstrip("\0")

    return dir
# Pick the best available CSIDL resolver for this platform, preferring
# pywin32, then ctypes, then JNA (Jython), then the registry fallback.
if system == "win32":
    try:
        import win32com.shell  # pywin32 available?
        _get_win_folder = _get_win_folder_with_pywin32
    except ImportError:
        try:
            from ctypes import windll  # ctypes with windll (CPython on Windows)?
            _get_win_folder = _get_win_folder_with_ctypes
        except ImportError:
            try:
                import com.sun.jna  # only importable under Jython
                _get_win_folder = _get_win_folder_with_jna
            except ImportError:
                _get_win_folder = _get_win_folder_from_registry
# ---- self test code
if __name__ == "__main__":
    # Manual smoke test: print every directory this module computes under a
    # few representative configurations (with/without version/appauthor).
    appname = "MyApp"
    appauthor = "MyCompany"

    props = (
        "user_data_dir",
        "user_config_dir",
        "user_cache_dir",
        "user_state_dir",
        "user_log_dir",
        "site_data_dir",
        "site_config_dir",
    )

    print(f"-- app dirs {__version__} --")

    print("-- app dirs (with optional 'version')")
    dirs = AppDirs(appname, appauthor, version="1.0")
    for prop in props:
        print(f"{prop}: {getattr(dirs, prop)}")

    print("\n-- app dirs (without optional 'version')")
    dirs = AppDirs(appname, appauthor)
    for prop in props:
        print(f"{prop}: {getattr(dirs, prop)}")

    print("\n-- app dirs (without optional 'appauthor')")
    dirs = AppDirs(appname)
    for prop in props:
        print(f"{prop}: {getattr(dirs, prop)}")

    print("\n-- app dirs (with disabled 'appauthor')")
    dirs = AppDirs(appname, appauthor=False)
    for prop in props:
        print(f"{prop}: {getattr(dirs, prop)}")

View File

@ -0,0 +1,53 @@
from __future__ import annotations
from typing import Generic, TypeVar
import torch
__all__ = ['Await']
# Type variable for the value an Await eventually resolves to.
W = TypeVar("W")


class _PyAwaitMeta(type(torch._C._Await), type(Generic)):  # type: ignore[misc, no-redef]
    # Combined metaclass: lets _Await inherit from both the C-implemented
    # torch._C._Await and Generic[W] without a metaclass conflict.
    pass
class _Await(torch._C._Await, Generic[W], metaclass=_PyAwaitMeta):
    r"""
    Wrapper around a ``torch._C.Await`` which encapsulates delayed execution
    of a callable. All manipulations happen with functions ``torch.jit._awaitable``,
    ``torch.jit._awaitable_wait``, ``torch.jit._awaitable_nowait``.

    Torch scriptable manipulations:
    ``torch.jit._awaitable(func, *args)``
        Creates an ``Await[W]`` object, where W is the return type of func.

    ``torch.jit._awaitable_wait(Await[W])``
        Returns the result of the function, specified at ``_awaitable``, with
        the specified arguments.

        Returns:
            The result of type ``W`` of the function call. The result is owned
            by ``Await[W]`` and returned on all following ``_awaitable_wait``
            calls.

    ``torch.jit._awaitable_nowait(W)``
        Returns:
            A trivial ``Await[W]`` with the specified result.

    Only in eager mode:
    ``fn() -> Callable[Tuple[Any], W]``
        Returns:
            The python function ``func`` specified at ``_awaitable``.

    ``args() -> Tuple[Any]``
        Returns:
            The python args specified at ``_awaitable``.

    ``is_nowait() -> _bool``
        Returns:
            ``True`` if this object was created via ``_awaitable_nowait`` call
            (trivial ``Await[W]``).

    In eager mode ``Await[W]`` can be used as ``W``, i.e. attributes of W can
    be called on ``Await[W]``; a ``_awaitable_wait()`` call will be
    transparently added.
    """

View File

@ -0,0 +1,56 @@
# mypy: allow-untyped-defs
import types
import torch._C
class _ClassNamespace(types.ModuleType):
def __init__(self, name):
super().__init__("torch.classes" + name)
self.name = name
def __getattr__(self, attr):
proxy = torch._C._get_custom_class_python_wrapper(self.name, attr)
if proxy is None:
raise RuntimeError(f"Class {self.name}.{attr} not registered!")
return proxy
class _Classes(types.ModuleType):
__file__ = "_classes.py"
def __init__(self) -> None:
super().__init__("torch.classes")
def __getattr__(self, name):
namespace = _ClassNamespace(name)
setattr(self, name, namespace)
return namespace
@property
def loaded_libraries(self):
return torch.ops.loaded_libraries
def load_library(self, path):
"""
Loads a shared library from the given path into the current process.
The library being loaded may run global initialization code to register
custom classes with the PyTorch JIT runtime. This allows dynamically
loading custom classes. For this, you should compile your class
and the static registration code into a shared library object, and then
call ``torch.classes.load_library('path/to/libcustom.so')`` to load the
shared object.
After the library is loaded, it is added to the
``torch.classes.loaded_libraries`` attribute, a set that may be inspected
for the paths of all libraries loaded using this function.
Args:
path (str): A path to a shared library to load.
"""
torch.ops.load_library(path)
# The classes "namespace"
classes = _Classes()

View File

@ -0,0 +1,38 @@
# mypy: allow-untyped-defs
"""
APIs related to torch.compile which lazily import torch._dynamo to avoid
circular dependencies.
"""
import functools
def _disable_dynamo(fn=None, recursive=True):
"""
This API should be only used inside torch, external users should still use
torch._dynamo.disable. The main goal of this API is to avoid circular
imports issues that is common while using _dynamo.disable inside torch
itself.
This API avoids it by lazily importing torch._dynamo from the import time to
the invocation of the decorated function.
"""
if fn is not None:
@functools.wraps(fn)
def inner(*args, **kwargs):
# cache this on the first invocation to avoid adding too much overhead.
disable_fn = getattr(fn, "__dynamo_disable", None)
if disable_fn is None:
import torch._dynamo
disable_fn = torch._dynamo.disable(fn, recursive)
fn.__dynamo_disable = disable_fn
return disable_fn(*args, **kwargs)
return inner
else:
# decorator usage like @_disable_dynamo(recursive=False). The resulting
# object expects the original decorated function as the arg.
return functools.partial(_disable_dynamo, recursive=recursive)

View File

@ -0,0 +1,275 @@
# mypy: allow-untyped-defs
import torch
import torch.utils._pytree as pytree
from collections import namedtuple
import functools
# NOTE [CustomOp autograd kernel indirection]
# We register `inner` as the autograd kernel for this custom_op.
# `inner` either calls the autograd formula registered by the user,
# or goes into an `autograd_not_implemented` kernel.
#
# The reason why this indirection exists is
# so that we can swap out the autograd kernel (the PyTorch dispatcher
# doesn't actually allow us to do this). By default, we want
# the `autograd_not_implemented` behavior, but then the user may come
# and register something that is actually a backward formula
def autograd_kernel_indirection(custom_op):
    """Return the Autograd-key kernel for ``custom_op``.

    The returned ``inner`` dispatches to the user-registered 'autograd' impl
    when one exists; otherwise it falls through to an
    ``autograd_not_implemented`` kernel. See
    NOTE [CustomOp autograd kernel indirection] for why this indirection
    exists (the dispatcher does not let us swap kernels in place).
    """
    autograd_fallback = autograd_not_implemented(custom_op)

    def inner(*args, **kwargs):
        if custom_op._has_impl('autograd'):
            kernel = custom_op._get_impl('autograd').func
            return kernel(*args, **kwargs)
        # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
        # the "autograd" impl is generated once the user supplies BOTH
        # "backward" and "save_for_backward". If only one of the pair was
        # provided, raise a helpful error instead of silently falling back.
        has_backward = custom_op._has_impl('backward')
        if has_backward or custom_op._has_impl('save_for_backward'):
            missing = 'save_for_backward' if has_backward else 'backward'
            found = 'save_for_backward' if missing == 'backward' else 'backward'
            loc = custom_op._get_impl(found).location
            raise RuntimeError(
                f"We found a '{found}' registration for {custom_op} at "
                f"{loc} but were unable to find a '{missing}' registration. "
                f"To use the CustomOp API to register a backward formula, "
                f"please provide us both a backward function and a "
                f"'save for backward' function via `impl_backward` and "
                f"`impl_save_for_backward` respectively.")
        return autograd_fallback(*args, **kwargs)
    return inner
# TODO(#101191): Use the actual C++ autograd not implemented fallback,
# or change the default autograd fallback to the autograd not implemented fallback.
def autograd_not_implemented(custom_op):
def kernel(*args, **kwargs):
if torch.is_grad_enabled() and pytree.tree_any(
lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
):
raise RuntimeError("Autograd has not been implemented for operator")
with torch._C._AutoDispatchBelowAutograd():
return custom_op(*args, **kwargs)
return kernel
def mark_non_differentiable(ctx, output, output_differentiability):
    """Mark outputs as non-differentiable on ``ctx`` per ``output_differentiability``.

    Args:
        ctx: autograd.Function context (must expose ``mark_non_differentiable``).
        output: the op's output — a single value or a tuple of values.
        output_differentiability: optional list of bools, one per output;
            False marks the corresponding output (Tensor or list of Tensors)
            as non-differentiable. ``None`` means "do nothing".

    Raises:
        RuntimeError: if a non-Tensor (and non-list) output is declared
            differentiable.
    """
    # Output types are restricted to be:
    # - Tensor
    # - Tensor[]
    # - int, bool, Scalar, float
    # See _check_can_register_backward
    if output_differentiability is None:
        return
    tuple_output = output if isinstance(output, tuple) else (output,)  # type: ignore[assignment]
    assert len(output_differentiability) == len(tuple_output)
    non_differentiable_tensors = []
    for idx, (differentiable, out) in enumerate(zip(output_differentiability, tuple_output)):
        if isinstance(out, torch.Tensor):
            if not differentiable:
                non_differentiable_tensors.append(out)
            continue
        if isinstance(out, list):
            if not differentiable:
                non_differentiable_tensors.extend(out)
            continue
        if differentiable:
            # Bug fix: the original message read "cannot have be marked".
            raise RuntimeError(
                f"With output_differentiability={output_differentiability}. "
                f"At idx {idx}, we received an object of type {type(out)} that "
                f"is not a Tensor, so it cannot be marked as differentiable in "
                f"output_differentiability.")
    if non_differentiable_tensors:
        ctx.mark_non_differentiable(*non_differentiable_tensors)
def construct_autograd_kernel(
        schema,
        output_differentiability,
        custom_op,
        op_overload,
        save_for_backward_fn,
        backward_fn):
    """Generate the 'autograd' kernel for ``custom_op`` from the user's
    save_for_backward/backward functions.

    Returns ``apply(*args)``, which flattens the args, runs a dynamically
    generated autograd.Function (one per call), and unflattens the result.
    """
    def apply(*args):
        flat_args, spec = pytree.tree_flatten(args)
        # Filled in by forward() below; backward() and the final unflatten
        # both rely on it, hence the nonlocal.
        out_spec = None

        def forward(ctx, *flat_args):
            ctx.set_materialize_grads(True)
            args = pytree.tree_unflatten(list(flat_args), spec)
            # Run the real op below the Autograd key so this kernel does not
            # recurse into itself.
            with torch._C._AutoDispatchBelowAutograd():
                output = op_overload(*args)

            # We use the info about args to give better error messages in backward
            args_info = namedtuple_args(
                schema, pytree.tree_map(type, args))

            save_for_backward_fn_inputs = namedtuple_args(schema, args)
            to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)

            save_pytree_for_backward(ctx, (to_save, args_info))
            mark_non_differentiable(ctx, output, output_differentiability)

            nonlocal out_spec
            flat_output, out_spec = pytree.tree_flatten(output)
            return tuple(flat_output)

        def backward(ctx, *flat_grad_output):
            assert out_spec is not None
            grads = pytree.tree_unflatten(list(flat_grad_output), out_spec)
            saved, args_info = unpack_saved(ctx)
            # There is nothing on the ctx object for now, it is just there so
            # that we can add additional things in the future.
            inner_ctx = object()

            if not isinstance(grads, tuple):
                grads = (grads,)
            grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)

            # Massage the grad_inputs_dict to a form acceptable by
            # autograd.Function.
            validate_grad_inputs_dict(grad_inputs_dict, custom_op, args_info)
            return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)

        generated_cls = gen_autograd_function(
            custom_op._opname + '_customop', forward, backward)

        flat_output = generated_cls.apply(*flat_args)
        assert out_spec is not None
        return pytree.tree_unflatten(list(flat_output), out_spec)
    return apply
def gen_autograd_function(name, forward, backward):
    """Dynamically create an ``autograd.Function`` subclass named ``name``
    whose forward/backward are the given callables (as staticmethods)."""
    namespace = {
        'forward': staticmethod(forward),
        'backward': staticmethod(backward),
    }
    return type(name, (torch.autograd.Function,), namespace)
@functools.lru_cache
def namedtuple_args_cls(schema):
    """Build (and cache per-schema) a namedtuple class whose fields mirror
    the flat argument names of ``schema``."""
    field_names = [arg.name for arg in schema.arguments.flat_all]
    cls_name = str(schema.name) + "_args"
    # mypy doesn't support dynamic namedtuple name
    return namedtuple(cls_name, field_names)  # type: ignore[misc]
def namedtuple_args(schema, args):
    """Pack the positional ``args`` tuple into the namedtuple class derived
    from ``schema`` (see namedtuple_args_cls)."""
    assert isinstance(args, tuple)
    cls = namedtuple_args_cls(schema)
    return cls(*args)
def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
    """Validate the dict returned by a user-provided backward function.

    ``grad_inputs_dict`` must map exactly the tensor-like argument names of
    ``forward_op`` to a gradient (Tensor, list of Tensor/None, or None).
    ``args_info`` is the namedtuple of per-argument *types* captured at
    forward time, used for list-length checks and error reporting.
    Any mismatch raises RuntimeError via the local ``error`` helper.
    """
    def error(what):
        # Point the user at their own backward registration site.
        backward = forward_op._get_impl('backward')
        raise RuntimeError(
            f"In the backward function defined for {forward_op} at "
            f"{backward.location} using the CustomOp API, {what}")

    if not isinstance(grad_inputs_dict, dict):
        error(f"expected the output of the backward function to be a dict but "
              f"got {type(grad_inputs_dict)}")

    # The keys must be exactly the tensor-like args of the forward schema.
    expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
                     if arg.type.is_tensor_like()}
    actual_keys = grad_inputs_dict.keys()
    if expected_keys != actual_keys:
        error(f"expected the returned grad_input dict to have keys "
              f"{expected_keys} but got {actual_keys}. The backward "
              f"function must return a gradient (can be None) for each arg "
              f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
              f"Args declared to be non-Tensor-like types should not appear "
              f"in the grad_input dict")

    for name, grad in grad_inputs_dict.items():
        arg_info = getattr(args_info, name)

        # Case 1: the forward arg was a Tensor list — grad must be a
        # same-length sequence of None/Tensor entries.
        if isinstance(arg_info, list):
            if not isinstance(grad, (tuple, list)):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of gradients but got object of type "
                      f"{type(grad)}.")
            if not len(grad) == len(arg_info):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of {len(arg_info)} gradients but got "
                      f"{len(grad)}")
            for idx, (g, info) in enumerate(zip(grad, arg_info)):
                if g is None:
                    continue
                if not isinstance(g, torch.Tensor):
                    error(f"for input '{name}' expected the grad_input dict to "
                          f"hold a list of None or Tensor gradients but got "
                          f"object of {type(g)} at index {idx}")
                if not issubclass(info, torch.Tensor):
                    error(f"for input '{name}', got a Tensor as the gradient "
                          f"for the {idx}-th value but expected None because "
                          f"the {idx}-th value was not a Tensor (it was "
                          f"type {arg_info}")
            continue

        # Case 2: single arg — grad must be None or a Tensor, and a Tensor
        # is only valid when the forward arg itself was a Tensor.
        if grad is None:
            continue
        if not isinstance(grad, torch.Tensor):
            error(f"got object of type {type(grad)} as the gradient for input "
                  f"'{name}', "
                  f"but expected the gradient to be either None or a Tensor")
        if not issubclass(arg_info, torch.Tensor):
            error(f"got a Tensor as the gradient for input '{name}' but "
                  f"expected None as the gradient because input '{name}' "
                  f"was not a Tensor (it was type {arg_info}).")
def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
result = []
for name, arg_info in args_info._asdict().items():
if name not in grad_inputs_dict:
result.append(pytree.tree_map(lambda x: None, arg_info))
continue
result.append(grad_inputs_dict[name])
return tuple(pytree.tree_leaves(result))
# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
# autograd.Function prefers that users use ctx.save_for_backward to
# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
# ctx object.
def save_pytree_for_backward(ctx, stuff):
    """Flatten ``stuff`` and stash it on ``ctx``: Tensor leaves go through
    ``ctx.save_for_backward`` (avoiding reference cycles), everything else is
    stored directly, with original indices recorded so ``unpack_saved`` can
    reassemble the pytree."""
    leaves, spec = pytree.tree_flatten(stuff)
    tensor_idxs, non_tensor_idxs = [], []
    tensors, non_tensors = [], []
    for idx, leaf in enumerate(leaves):
        if isinstance(leaf, torch.Tensor):
            tensor_idxs.append(idx)
            tensors.append(leaf)
        else:
            non_tensor_idxs.append(idx)
            non_tensors.append(leaf)

    ctx.spec = spec
    ctx.num_elts = len(leaves)
    ctx.save_for_backward(*tensors)
    ctx.tensor_idxs = tensor_idxs
    ctx.saved_non_tensors = non_tensors
    ctx.non_tensor_idxs = non_tensor_idxs
# Inverse operation to save_pytree_for_backward
def unpack_saved(ctx):
    """Rebuild the pytree that save_pytree_for_backward stored on ``ctx``,
    merging ctx.saved_tensors and the directly-stored non-tensors back into
    their original leaf positions."""
    leaves = [None] * ctx.num_elts
    for idx, tensor in zip(ctx.tensor_idxs, ctx.saved_tensors):
        leaves[idx] = tensor
    for idx, value in zip(ctx.non_tensor_idxs, ctx.saved_non_tensors):
        leaves[idx] = value
    return pytree.tree_unflatten(leaves, ctx.spec)

View File

@ -0,0 +1,188 @@
# mypy: allow-untyped-defs
import weakref
import torch
import torch.utils._pytree as pytree
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._ops import OpOverload
from torch.library import Library
from torchgen.model import (
BaseTy,
BaseType,
FunctionSchema,
OperatorName,
OptionalType,
SchemaKind,
)
from .autograd import autograd_not_implemented
def register_functional_op(
    lib: Library,
    new_op_name: str,
    mutable_op: OpOverload,
) -> None:
    """Given a mutable operator, registers the functional variant.

    This API also correctly links the functional variant with the mutable
    operator for the purposes of functionalization.
    All of the new registrations are performed on the ``lib`` passed in.

    Arguments:
        lib (Library): Should be a torch.library.Library object that has
            the same namespace as ``mutable_op``'s namespace.
            lib will be used to register the new functional op as well
            as a functionalization kernel for the ``mutable_op``.
            If you don't have a library handy, use
            ``torch.library.Library(ns, 'FRAGMENT')`` to construct one.
        new_op_name (str): The name of the functional operator (without the
            namespace). If no namespace, the new functional variant will be
            accessible under ``torch.ops.{lib.ns}.new_op_name``.
        mutable_op (OpOverload): The mutable custom operator. Note
            that you may need to add a `.default` to it, like
            `torch.ops.aten.abs_.default`.
    """
    validate(mutable_op)
    schema_str = functional_schema(new_op_name, mutable_op)
    lib.define(schema_str)

    lib.impl(new_op_name, construct_functional_impl(mutable_op),
             'CompositeExplicitAutograd')

    namespace = getattr(torch.ops, lib.ns)
    functional_op = getattr(namespace, new_op_name).default

    # There's no easy way for us to generate the autograd kernel, so we use
    # autograd_not_implemented. This also makes it so the user is unable to
    # register an autograd formula themselves — fine as long as they don't
    # use the functional op directly in their program, but we may need to
    # revisit this in the future.
    lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd')

    functionalize_kernel = construct_functionalization_kernel(
        weakref.proxy(mutable_op), functional_op)
    lib.impl(mutable_op, functionalize_kernel, 'Functionalize')
def construct_functional_impl(mutable_op):
    """Build the functional kernel for ``mutable_op``: clone the
    would-be-mutated args, run the mutable op on the clones, and append the
    clones as extra outputs."""
    def functional_impl(*args):
        call_args = []
        cloned_outputs = []
        for is_write, arg in zip(mutable_args(mutable_op), args):
            if not is_write:
                call_args.append(arg)
                continue
            clone = None if arg is None else arg.clone()
            call_args.append(clone)
            cloned_outputs.append(clone)

        result = mutable_op(*call_args)

        # The functional variant returns the original outputs (if any)
        # followed by the post-mutation clones.
        if result is None:
            return tuple(cloned_outputs)
        if isinstance(result, tuple):
            return (*result, *cloned_outputs)
        return (result, *cloned_outputs)
    return functional_impl
def construct_functionalization_kernel(mutable_op, functional_op):
    """Build the Functionalize-key kernel for ``mutable_op``.

    The kernel unwraps FunctionalTensorWrapper args, calls ``functional_op``
    below the Functionalize key, and writes the functional op's extra outputs
    back into the mutated inputs.
    """
    def kernel(*args):
        # There's nothing to be functionalized!
        # We can still end up here because DispatchKey::Functionalize is a mode key
        if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args):
            with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
                return mutable_op(*args)

        # NB: This differs from the codegen -- codegen handles cases where there
        # are mixed FunctionalTensorWrapper and non-FunctionalTensorWrapper.
        # This only really matters for XLA (mixed CPU-XLA tensors) and
        # running functionalization without the PT2 stack (which guarantees to us that
        # all tensors are FunctionalTensorWrapper).
        if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args):
            # Bug fix: the original literal was missing the f-prefix, so the
            # message printed the placeholder text "{mutable_op}" verbatim.
            raise RuntimeError(f"{mutable_op}: expected all args to be FunctionalTensorWrapper")

        unwrapped_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg):
                torch._sync(arg)
                unwrapped = torch._from_functional_tensor(arg)
                unwrapped_args.append(unwrapped)
            else:
                unwrapped_args.append(arg)

        with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
            output = functional_op(*unwrapped_args)

        # The first len(returns) entries are the op's declared outputs; the
        # remainder are the new values for the mutated inputs.
        num_actual_output = len(mutable_op._schema.returns)
        actual_output = pytree.tree_map(
            torch._to_functional_tensor, output[:num_actual_output])

        new_values_to_propagate = output[num_actual_output:]
        inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args)
                             if is_write]
        assert len(new_values_to_propagate) == len(inputs_to_replace)
        for new_value, arg in zip(new_values_to_propagate, inputs_to_replace):
            # NOTE(review): this `continue` skips propagation both when both
            # sides are None AND when both are non-None, so _replace_ only
            # runs when exactly one side is None — that looks inverted.
            # Preserved as-is; TODO confirm intent against upstream before
            # changing.
            if (arg is None and new_value is None) or (arg is not None and new_value is not None):
                continue
            torch._C._propagate_xla_data(arg, new_value)
            torch._C._replace_(arg, new_value)
            torch._C._commit_update(arg)
            torch._sync(arg)

        # Mirror the op's declared return arity (scalar, None, or tuple).
        if len(actual_output) == 1:
            return actual_output[0]
        elif len(actual_output) == 0:
            return None
        return actual_output
    return kernel
def validate(mutable_op: OpOverload):
    """Check that ``mutable_op`` is a mutable-style OpOverload that
    register_functional_op can handle; raise otherwise."""
    if not isinstance(mutable_op, OpOverload):
        raise TypeError(
            f"register_functional_op(mutable_op): expected mutable_op to be instance of "
            f"OpOverload but got {type(mutable_op)}")

    # There are generally three kinds of "in-place"/"mutable" ops, each with
    # its own convention:
    # - inplace (first input modified in-place and returned as only output)
    # - out=    (some args modified in-place and returned as outputs)
    # - mutable (some args modified in-place but none of those returned)
    # In theory all three could be supported, but for simplicity only the
    # last is supported right now.
    schema = FunctionSchema.parse(str(mutable_op._schema))
    if schema.kind() != SchemaKind.mutable:
        raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)")

    for ret in schema.returns:
        # construct_functionalization_kernel assumes this for simplicity
        if ret.annotation is not None:
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op returns a mutated or aliased value. "
                "Please file an issue (and as a workaround, modify your operator to "
                "not return the mutated value or aliases)")

    allowed_types = (BaseType(BaseTy.Tensor), OptionalType(BaseType(BaseTy.Tensor)))
    for arg in schema.arguments.flat_all:
        # construct_functionalization_kernel assumes this for simplicity
        if arg.type.is_tensor_like() and arg.type not in allowed_types:
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op has a List[Tensor] input."
                "Please file an issue.")
def functional_schema(new_op_name, op: OpOverload):
    """Return ``op``'s schema string, functionalized (signature only) and
    renamed to ``new_op_name``."""
    parsed = FunctionSchema.parse(str(op._schema))
    renamed = parsed.signature().with_name(OperatorName.parse(new_op_name))
    return str(renamed)
def mutable_args(op: OpOverload):
    """Tuple of flags, one per schema argument: truthy iff that argument is
    written (mutated) by ``op``."""
    flags = []
    for arg in op._schema.arguments:
        alias_info = arg.alias_info
        flags.append(alias_info is not None and alias_info.is_write)
    return tuple(flags)

View File

@ -0,0 +1,670 @@
# mypy: allow-untyped-defs
import dataclasses
import functools
import inspect
import sys
import typing
import weakref
import warnings
from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy
import torch
import torch._C as _C
import torch.library as library
from torch.library import get_ctx
from .autograd import autograd_kernel_indirection, construct_autograd_kernel
import torch._library.infer_schema
from torch._library.infer_schema import infer_schema
"""
torch._custom_op is deprecated. We shipped a production-ready version of it into torch.library.
Please use those APIs instead.
"""
__all__ = ["custom_op", "CustomOp", "get_ctx"]
# Device types the deprecated CustomOp API accepts for impl registration,
# mapped to their corresponding dispatch keys.
SUPPORTED_DEVICE_TYPE_TO_KEY = {
    "cpu": "CPU",
    "cuda": "CUDA",
}

# We will not let users register CustomOps with anything that could look like
# PyTorch internals to avoid confusion.
RESERVED_NS = {
    "prim",
    "prims",
    "aten",
    "at",
    "torch",
    "pytorch",
}
def warn_deprecated():
    """Emit the DeprecationWarning for the legacy torch._custom_op API."""
    message = (
        "torch._custom_op is deprecated and will be removed in PyTorch 2.6, please "
        "use the equivalent torch.library API instead."
    )
    warnings.warn(message, DeprecationWarning)
def custom_op(
    qualname: str, manual_schema: typing.Optional[str] = None
) -> typing.Callable:
    r"""
    This API is deprecated, please use torch.library.custom_op instead.

    Decorator that registers the decorated Python function as a CustomOp under
    ``qualname`` (parsed into a namespace and op name by ``parse_qualname``).
    ``manual_schema``, when given, overrides schema inference from the
    function's annotations.
    """
    warn_deprecated()

    def inner(func):
        # Only plain Python functions are accepted: the op name is derived
        # from func.__name__ and checked against the qualname below.
        if not inspect.isfunction(func):
            raise ValueError(
                f"custom_op(...)(func): Expected `func` to be a Python "
                f"function, got: {type(func)}"
            )

        ns, name = parse_qualname(qualname)
        validate_namespace(ns)
        if func.__name__ != name:
            raise ValueError(
                f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
                f"to have name '{name}' but got '{func.__name__}'. "
                f"Please either change the name of `func` or the qualname that "
                f"is passed to `custom_op`"
            )

        # Infer the schema from annotations unless one was supplied manually;
        # a manual schema is additionally cross-checked against the function.
        schema = infer_schema(func, mutates_args=()) if manual_schema is None else manual_schema
        schema_str = f"{name}{schema}"
        function_schema = FunctionSchema.parse(schema_str)
        validate_schema(function_schema)
        if manual_schema is not None:
            validate_function_matches_schema(function_schema, func)

        lib = library.Library(ns, "FRAGMENT")
        lib.define(schema_str)
        ophandle = find_ophandle_or_throw(ns, function_schema.name)
        result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)

        # Make the CustomOp look like the wrapped function for introspection.
        result.__name__ = func.__name__
        result.__module__ = func.__module__
        result.__doc__ = func.__doc__

        # weakref.proxy presumably avoids the registration keeping the
        # CustomOp alive — see NOTE [CustomOp lifetime]; confirm upstream.
        library.impl(lib, result._opname, "Autograd")(
            autograd_kernel_indirection(weakref.proxy(result))
        )

        torch._C._dispatch_set_report_error_callback(
            ophandle, functools.partial(report_error_callback, weakref.proxy(result))
        )

        return result

    return inner
# Global dictionary holding references to all CustomOp objects
# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
# Used to query the CustomOp associated with a specific C++ dispatcher operator.
# An example usage is FakeTensor: FakeTensor checks if a specific operator
# has an implementation registered via the CustomOp API.
# Indexed by qualname (e.g. aten::foo)
# Populated by CustomOp.__init__.
global_registry: typing.Dict[str, "CustomOp"] = {}
class CustomOp:
    r"""
    This API is deprecated, please use torch.library.custom_op instead

    Wraps a single dispatcher operator (created via ``custom_op``) and exposes
    methods to register device-specific, abstract (meta), and autograd
    implementations for it. Instances are kept alive forever in
    ``global_registry`` (see NOTE [CustomOp lifetime]).
    """
    def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
        super().__init__()
        warn_deprecated()
        if not _private_access:
            raise RuntimeError(
                "The CustomOp constructor is private and we do not guarantee "
                "BC for it. Please use custom_op(...) to create a CustomOp object"
            )
        name = f"{cpp_ns}::{operator_name}"
        self._schema = schema
        self._cpp_ns = cpp_ns
        self._lib: library.Library = lib
        self._ophandle: _C._DispatchOperatorHandle = ophandle
        # Has the name of the op, e.g. "foo". We cache here for convenience.
        self._opname: str = operator_name
        # this is _opname but with namespace. e.g. "custom::foo"
        self._qualname: str = name
        self.__name__ = None  # mypy requires this
        # NB: Some of these impls are registered as kernels to DispatchKeys.
        # Modifying the _impls dict directly won't do anything in that case.
        self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
        # See NOTE [CustomOp autograd kernel indirection]
        self._registered_autograd_kernel_indirection = False
        global_registry[self._qualname] = self
    def _register_autograd_kernel_indirection(self):
        # Installs a trampoline on the Autograd key that dispatches into the
        # user-registered autograd functions (see NOTE [CustomOp autograd
        # kernel indirection]). Must only be called once per CustomOp.
        assert not self._registered_autograd_kernel_indirection
        self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
        self._registered_autograd_kernel_indirection = True
    # Records the impl and the source location in self._impls
    # Note that this doesn't cause torch.library to use the impl, that
    # needs to be done in a separate self._lib.impl call.
    def _register_impl(self, kind, func, stacklevel=2):
        if self._has_impl(kind):
            func_and_location = self._impls[kind]
            assert func_and_location is not None  # Pacify mypy
            location = func_and_location.location
            raise RuntimeError(
                f"Attempting to register a {kind} impl for operator {self._qualname} "
                f"that already has a {kind} impl registered from Python at "
                f"{location}. This is not supported."
            )
        # Remember the caller's file:line so a later duplicate registration
        # can point at the original registration site.
        frame = inspect.getframeinfo(sys._getframe(stacklevel))
        location = f"{frame.filename}:{frame.lineno}"
        self._impls[kind] = FuncAndLocation(func, location)
    def _get_impl(self, kind):
        # Raises KeyError if no impl of this kind was registered.
        return self._impls[kind]
    def _has_impl(self, kind) -> bool:
        return kind in self._impls
    def _destroy(self):
        # NOTE: [CustomOp lifetime]
        # A CustomOp, once created, lives forever. The mechanism is that the
        # global registry holds a reference to it. However, to make testing
        # easier, we want to be able to destroy CustomOp objects.
        # CustomOp._destroy does the job, though it leaves the CustomOp
        # in a garbage state.
        del self._lib
        opnamespace = getattr(torch.ops, self._cpp_ns)
        if hasattr(opnamespace, self._opname):
            delattr(opnamespace, self._opname)
        del global_registry[self._qualname]
    def __repr__(self) -> str:
        return f'<CustomOp(op="{self._qualname}")>'
    def __call__(self, *args, **kwargs):
        # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
        # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
        # issues from caching operators that make testing CustomOp difficult).
        result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
        return result
    def impl(
        self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2,
    ) -> typing.Callable:
        r"""
        This API is deprecated, please use torch.library.custom_op instead

        Returns a decorator that registers the decorated function as this
        operator's implementation for the given device type(s).
        """
        if isinstance(device_types, str):
            device_types = [device_types]
        for device_type in device_types:
            validate_device_type(device_type)
        def inner(f):
            for device_type in set(device_types):
                self._check_doesnt_have_library_impl(device_type)
                self._register_impl(device_type, f, stacklevel=_stacklevel)
                dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
                library.impl(self._lib, self._opname, dispatch_key)(f)
            return f
        return inner
    def _check_doesnt_have_library_impl(self, device_type):
        # If a Python-level impl already exists, fall through: the subsequent
        # _register_impl call raises the more precise duplicate error.
        if self._has_impl(device_type):
            return
        key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
        if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
            raise RuntimeError(
                f"impl(..., device_types={device_type}): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration.")
    def impl_factory(self) -> typing.Callable:
        r"""Register an implementation for a factory function.

        Factory functions take no Tensors, so they are registered to the
        BackendSelect key rather than a device-specific key.
        """
        def inner(f):
            self._register_impl("factory", f)
            library.impl(self._lib, self._opname, "BackendSelect")(f)
            return f
        return inner
    def impl_abstract(self, _stacklevel=2) -> typing.Callable:
        r"""
        This API is deprecated, please use torch.library.custom_op instead

        Returns a decorator that registers the decorated function as the
        abstract (meta/fake) implementation of this operator.
        """
        def inner(f):
            self._check_doesnt_have_library_meta_impl()
            self._register_impl("abstract", f, stacklevel=_stacklevel)
            location = self._get_impl("abstract").location
            qualname = self._qualname
            # Handle DispatchKey.Meta registration
            @functools.wraps(f)
            def f_with_ctx(*args, **kwargs):
                def error_on_ctx():
                    raise RuntimeError(
                        f"Attempted to call get_ctx() for the meta implementation "
                        f"for {qualname}."
                        f"You have presumably called get_ctx() because the operator "
                        f"has a data-dependent output shape; if so, there is no "
                        f"such meta implementation and this error is the correct "
                        f"behavior. Otherwise, please remove the call to get_ctx() "
                        f"in the implementation registered with impl_abstract "
                        f"at {location}"
                    )
                # Meta kernels have no FakeTensor ctx; surface a helpful error
                # if the user's abstract impl calls get_ctx() under Meta.
                with torch._library.fake_impl.set_ctx_getter(error_on_ctx):
                    return f(*args, **kwargs)
            self._lib.impl(self._opname, f_with_ctx, "Meta")
            return f
        return inner
    def _check_can_register_backward(self):
        # Backward formulas are only supported for functional operators whose
        # return types are in a small allowlist.
        def error(detail):
            raise RuntimeError(
                f"Cannot use torch._custom_ops APIs to register backward "
                f"formula for {detail}. Got operator "
                f"{self._qualname} with schema: {schema}"
            )
        schema = self._schema
        if schema.kind() != SchemaKind.functional:
            error("non-functional operator")
        rets = schema.returns
        if not schema.returns:
            error("operator with no returns")
        assert len(rets) > 0
        is_non_mutating_view = any(
            r.annotation is not None and not r.annotation.is_write for r in rets
        )
        if is_non_mutating_view:
            error("operator that returns views")
        # We make assumptions about the schema's return types.
        allowed_return_types = {
            BaseType(BaseTy.int): "int",
            BaseType(BaseTy.SymInt): "SymInt",
            BaseType(BaseTy.bool): "bool",
            BaseType(BaseTy.float): "float",
            BaseType(BaseTy.Tensor): "Tensor",
            ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
        }
        for ret in schema.returns:
            if ret.type in allowed_return_types:
                continue
            error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")
    def _check_doesnt_have_library_autograd_impl(self):
        # Skip the check if we already installed our own autograd trampoline.
        if self._registered_autograd_kernel_indirection:
            return
        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
                f"CompositeImplicitAutograd operators do not need an autograd formula; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have autograd formulas defined on them.")
        # We can improve this by adding "all Autograd<BACKEND> keys", but
        # realistically people will just be using this API for CPU/CUDA for now.
        for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
            if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
                raise RuntimeError(
                    f"impl_backward/impl_save_for_backward: "
                    f"the operator {self._qualname} already has an Autograd kernel "
                    f"registered to DispatchKey::{key} vi a pre-existing "
                    f"torch.library or TORCH_LIBRARY registration. Please either "
                    f"remove those registrations or don't use the torch._custom_ops APIs")
    def _check_doesnt_have_library_meta_impl(self):
        # A Python-level abstract impl wins; _register_impl will raise the
        # duplicate-registration error instead.
        if self._has_impl("abstract"):
            return
        # If the user's operator is CompositeExplicitAutograd,
        # allow them to impl_abstract. This is being pragmatic
        # (existing custom ops may have CompositeExplicitAutograd
        # registration that don't work with Meta kernels, so this
        # gives them an escape hatch).
        if (
            _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd")
            and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta")
        ):
            return
        # Otherwise, if the user's already has a Meta kernel or their
        # op is CompositeImplicitAutograd or some other alias dispatch key,
        # raise.
        # Special case for CompositeImplicitAutograd
        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd."
                f"CompositeImplicitAutograd operators do not need an abstract impl; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have abstract impls defined on them.")
        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has an DispatchKey::Meta implementation via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration. "
                f"Please either remove that registration or don't call impl_abstract.")
    # NOTE ["backward", "save_for_backward", and "autograd"]
    # As a part of the explicit autograd API, a user must provide us
    # a "save_for_backward" function and a "backward" function.
    # When both of these have been provided, then we automatically
    # construct the "autograd" kernel.
    def _register_autograd_kernel(self):
        assert self._has_impl("backward")
        assert self._has_impl("save_for_backward")
        kernel = construct_autograd_kernel(
            self._schema,
            self._output_differentiability,
            self,
            get_op(self._qualname),
            self._get_impl("save_for_backward").func,
            self._get_impl("backward").func)
        self._register_impl("autograd", kernel)
    def impl_save_for_backward(self, _stacklevel=2):
        r"""Register a function that tells us what to save for backward.
        Please see impl_backward for more details.
        """
        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
            # Once both halves are present we can construct the autograd kernel.
            if self._has_impl("backward"):
                self._register_autograd_kernel()
        # NOTE(review): `inner` returns None, so the decorated name is rebound
        # to None; the torch._custom_ops wrappers return `f` themselves.
        return inner
    def impl_backward(self, output_differentiability=None, _stacklevel=2):
        r"""
        This API is deprecated, please use torch.library.custom_op instead

        Returns a decorator that registers the decorated function as this
        operator's backward formula.
        """
        if output_differentiability is not None:
            # Validate eagerly, before the decorator is ever applied.
            def yell():
                raise RuntimeError(
                    f"impl_backward(output_differentiability): expected "
                    f"output_differentiability to be a list of bools with "
                    f"length equal to the number of outputs of this CustomOp "
                    f"got: {output_differentiability}")
            if not isinstance(output_differentiability, list):
                yell()
            for diff in output_differentiability:
                if not isinstance(diff, bool):
                    yell()
            if len(self._schema.returns) != len(output_differentiability):
                yell()
        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("backward", f, stacklevel=_stacklevel)
            self._output_differentiability = output_differentiability
            if self._has_impl("save_for_backward"):
                self._register_autograd_kernel()
        return inner
@dataclasses.dataclass
class FuncAndLocation:
    """Pairs a registered impl with the "file:lineno" it was registered from,
    so duplicate-registration errors can point at the original site."""
    func: typing.Callable
    location: str
def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
    """Look up the dispatcher handle for ``cpp_ns::operator_name``.

    Raises if no matching operator schema has been registered.
    """
    overload = operator_name.overload_name
    if overload is None:
        overload = ""
    qualified = f"{cpp_ns}::{str(operator_name.name)}"
    return _C._dispatch_find_schema_or_throw(qualified, overload)
def validate_namespace(ns: str) -> None:
    """Reject namespaces containing '.' or colliding with reserved names."""
    if "." in ns:
        msg = (
            f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
            f"valid variable name)"
        )
        raise ValueError(msg)
    if ns in RESERVED_NS:
        msg = (
            f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
            f"please choose something else. "
        )
        raise ValueError(msg)
def validate_schema(schema: FunctionSchema) -> None:
    """Ensure the schema is functional and has no argument named 'self'."""
    if not torch._library.utils.is_functional_schema(schema):
        msg = (
            f"custom_op only supports functional operators "
            f"(ops that do not mutate any inputs, do not return "
            f"views of the inputs, and has at least one return). "
            f"Got the following non-functional schema: {schema}"
        )
        raise ValueError(msg)
    # For simplicity: don't allow self arguments
    if schema.arguments.self_arg is not None:
        msg = (
            f"custom_op does not support arguments named 'self'. Please "
            f"rename your argument. Got: {schema}"
        )
        raise ValueError(msg)
def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
    """Split "ns::opname" into (ns, opname); overload names are rejected."""
    ns, sep, name = qualname.partition("::")
    if not sep:
        raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The "
                         f"operator name should look something like ns::foo")
    if '.' in name:
        raise ValueError(f"The torch.custom_ops APIs do not handle overloads, "
                         f"i.e. operator names with '.' in them. "
                         f"Please name your operator something like ns::foo. "
                         f"Got: {qualname}")
    return ns, name
def validate_device_type(device_type: str) -> None:
    """Reject device types we cannot map to a dispatch key."""
    if device_type in SUPPORTED_DEVICE_TYPE_TO_KEY:
        return
    raise ValueError(
        f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
        f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
    )
def supported_param(param: inspect.Parameter) -> bool:
    """True iff the parameter kind is expressible in an op schema
    (no *args, **kwargs, or positional-only parameters)."""
    allowed_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    return param.kind in allowed_kinds
def validate_function_matches_schema(
    schema: FunctionSchema, func: typing.Callable
) -> None:
    # Checks that ``func``'s Python signature is consistent with a manually
    # provided schema: supported parameter kinds only, no type annotations,
    # matching names and positional/kwarg-only split, and no defaults on
    # either side. Raises ValueError on any mismatch.
    sig = inspect.signature(func)
    if not all(supported_param(p) for _, p in sig.parameters.items()):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): positional-only args, "
            f"varargs, and kwargs are not supported. Please rewrite `func` "
            f"to not have them. Got `func` with signature: {sig}"
        )
    # Annotations are forbidden: the manual schema is the single source of
    # truth for types, so annotations could only contradict it.
    if (
        any(
            p.annotation is not inspect.Parameter.empty
            for _, p in sig.parameters.items()
        )
        or sig.return_annotation is not inspect.Signature.empty
    ):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func` to have no type annotations to avoid "
            f"ambiguity. Got `func` with signature: {sig}"
        )
    # Split the Python parameters the same way the schema splits them.
    positional = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
    ]
    kwargonly = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
    ]
    def error():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func`'s signature to match `manual_schema` "
            f"(aside from type annotations). "
            f"func's signature: {sig}, manual_schema: {schema}"
        )
    def error_default_args():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): "
            f"neither func nor manual_schema should have default "
            f"arguments. Got "
            f"func's signature: {sig}, manual_schema: {schema}"
        )
    def compare(sig_args, schema_args):
        # Pairwise comparison of Python parameters vs schema arguments.
        if len(sig_args) != len(schema_args):
            error()
        for (name, param), arg in zip(sig_args, schema_args):
            if name != arg.name:
                error()
            if param.default is not inspect.Parameter.empty or arg.default is not None:
                error_default_args()
    compare(positional, schema.arguments.flat_positional)
    compare(kwargonly, schema.arguments.flat_kwarg_only)
def report_error_callback(custom_op: typing.Any, key: str) -> None:
    """Raise a descriptive NotImplementedError for a dispatch miss on
    ``key``, tailored to the most common user mistakes."""
    if key == "Undefined":
        message = (
            f"{custom_op}: There were no Tensor inputs to this operator "
            f"(e.g. you passed an empty list of Tensors). If your operator is a "
            f"factory function (that is, it takes no Tensors and constructs "
            f"a new one), then please use CustomOp.impl_factory to register "
            f"an implementation for it"
        )
    elif key == "Meta":
        message = (
            f"{custom_op}: when running with device='Meta' tensors: there is no "
            f"abstract impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
        )
    elif key in ("CPU", "CUDA"):
        device = key.lower()
        message = (
            f"{custom_op}: when running with device='{device}' tensors: there is no "
            f"{device} impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl(device_type='{device}')"
        )
    else:
        message = (
            f"{custom_op}: No implementation for dispatch key {key}. It is likely "
            f"that we have not added this functionality yet, please either open an "
            f"issue or if you're feeling adventurous, use the low-level "
            f"torch.library API"
        )
    raise NotImplementedError(message)
def custom_op_from_existing(op):
    """Wrap an already-registered dispatcher overload in a CustomOp."""
    ns = op.namespace
    lib = torch.library.Library(ns, "FRAGMENT")
    name = op.name().split("::")[-1]
    # CustomOp expects the schema string without the "ns::" prefix.
    bare_schema = str(op._schema).split("::")[-1]
    schema = FunctionSchema.parse(bare_schema)
    return CustomOp(lib, ns, schema, name, op, _private_access=True)
def get_op(qualname):
    """Return the ``.default`` overload of ``qualname`` from torch.ops,
    raising ValueError if any step of the lookup fails."""
    def not_found():
        raise ValueError(
            f"Could not find the operator {qualname}. Please make sure you have "
            f"already registered the operator and (if registered from C++) "
            f"loaded it via torch.ops.load_library.")
    ns, name = parse_qualname(qualname)
    # Walk torch.ops.<ns>.<name>.default one attribute at a time.
    target = torch.ops
    for attr in (ns, name, 'default'):
        if not hasattr(target, attr):
            not_found()
        target = getattr(target, attr)
    return target
def _find_custom_op(qualname, also_check_torch_library=False):
    """Fetch the CustomOp for ``qualname`` from the registry, optionally
    falling back to wrapping a pre-existing torch.library operator."""
    try:
        return global_registry[qualname]
    except KeyError:
        pass
    if not also_check_torch_library:
        raise RuntimeError(
            f'Could not find custom op "{qualname}". Did you register it via '
            f"the torch._custom_ops API?")
    return custom_op_from_existing(get_op(qualname))
def get_abstract_impl(qualname):
if qualname not in torch._custom_op.impl.global_registry:
return None
custom_op = torch._custom_op.impl.global_registry[qualname]
if custom_op is None:
return None
if not custom_op._has_impl("abstract"):
return None
return custom_op._get_impl("abstract").func
def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
    """Define ``qualname`` with the given schema string, wire up autograd
    indirection and error reporting, and return its ``.default`` overload."""
    ns, name = qualname.split("::")
    schema_str = f"{name}{schema}"
    function_schema = FunctionSchema.parse(schema_str)
    validate_schema(function_schema)
    if needs_fixed_stride_order:
        tags = [torch._C.Tag.needs_fixed_stride_order]
    else:
        tags = []
    lib = library.Library(ns, "FRAGMENT")
    lib.define(schema_str, tags=tags)
    ophandle = find_ophandle_or_throw(ns, function_schema.name)
    result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
    result._register_autograd_kernel_indirection()
    torch._C._dispatch_set_report_error_callback(
        ophandle, functools.partial(report_error_callback, weakref.proxy(result))
    )
    return get_op(qualname)

View File

@ -0,0 +1,324 @@
# mypy: allow-untyped-defs
import inspect
from torch._custom_op.impl import (
_custom_op_with_schema,
_find_custom_op,
infer_schema,
parse_qualname,
validate_namespace,
)
from torch.library import get_ctx
__all__ = [
"custom_op",
"impl",
"impl_abstract",
"get_ctx",
"impl_save_for_backward",
"impl_backward",
]
def custom_op(qualname, func_or_schema=None):
    r"""Register a new custom operator.

    In PyTorch, defining an op (short for "operator") is a two step-process:
    - we need to define the op (by providing an operator name and schema)
    - we need to implement behavior for how the operator interacts with
      various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.

    This entrypoint defines the custom operator (the first step);
    you must then perform the second step by calling various
    ``impl_*`` APIs. It may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    Arguments:
        qualname (str): Should be a string that looks like
            "namespace::operator_name". Operators in PyTorch need a namespace to
            avoid name collisions; a given operator may only be created once.
            If you are writing a Python library, we recommend the namespace to
            be the name of your top-level module.
        func_or_schema (Union[Callable, str]): Each PyTorch operator needs a
            schema that tells PyTorch the types of the inputs/outputs.
            If this is a Callable, we will automatically infer the schema from
            the type annotations on the function (see examples). Otherwise,
            if you don't want to use type annotations, you may provide us the
            schema string.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Step 1: define the custom op.
        >>> # We need to provide the API a "prototype function"
        >>> # (a function that returns NotImplementedError), from which
        >>> # we will infer the types of the inputs and outputs.
        >>> @torch._custom_ops.custom_op("mylibrary::numpy_sin")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     raise NotImplementedError
        >>>
        >>> # The custom op is now accessible via the torch.ops module:
        >>> torch.ops.mylibrary.numpy_sin
        >>>
        >>> # Step 2: Register an implementation for various PyTorch subsystems
        >>>
        >>> # Register an implementation for CPU tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cpu")
        >>> def numpy_sin_impl_cpu(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Register an implementation for CUDA tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cuda")
        >>> def numpy_sin_impl_cuda(x):
        >>>     return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
        >>>
        >>> x = torch.randn(3)
        >>> torch.ops.mylibrary.numpy_sin(x)  # calls numpy_sin_impl_cpu
        >>>
        >>> x_cuda = x.cuda()
        >>> torch.ops.mylibrary.numpy_sin(x)  # calls numpy_sin_impl_cuda
    """
    # Validate the qualified name eagerly, even in decorator usage.
    ns, name = parse_qualname(qualname)
    validate_namespace(ns)

    def inner(func):
        if not inspect.isfunction(func):
            raise ValueError(
                f"custom_op(...)(func): Expected `func` to be a Python "
                f"function, got: {type(func)}"
            )
        if func.__name__ != name:
            raise ValueError(
                f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
                f"to have name '{name}' but got '{func.__name__}'. "
                f"Please either change the name of `func` or the qualname that "
                f"is passed to `custom_op`"
            )
        schema = infer_schema(func, mutates_args=())
        _custom_op_with_schema(qualname, schema)
        return func

    if func_or_schema is None:
        # Decorator usage: @custom_op("ns::name")
        return inner
    if isinstance(func_or_schema, str):
        # Schema-string usage: define the op directly.
        _custom_op_with_schema(qualname, func_or_schema)
        return None
    return inner(func_or_schema)
def impl(qualname, *, device_types=("cpu", "cuda"), func=None):
    r"""Register an implementation for a device type for this custom op.

    If the op is passed multiple Tensor inputs with different device
    types, it will dispatch to the registered implementation for the highest
    priority device type among those present.
    The supported device types, in order of priority, are {'cuda', 'cpu'}.
    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    Arguments:
        device_types (str or Iterable[str]): the device type(s) to register the function for.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Step 1: define the custom op.
        >>> # We need to provide the API a "prototype function"
        >>> # (a function that returns NotImplementedError), from which
        >>> # we will infer the types of the inputs and outputs.
        >>> @torch._custom_ops.custom_op("mylibrary::numpy_cos")
        >>> def numpy_cos(x: Tensor) -> Tensor:
        >>>     raise NotImplementedError
        >>>
        >>> # The custom op is now accessible via the torch.ops module:
        >>> torch.ops.mylibrary.numpy_cos
        >>>
        >>> # Step 2: Register an implementation for various PyTorch subsystems
        >>>
        >>> # Register an implementation for CPU tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cpu")
        >>> def numpy_cos_impl_cpu(x):
        >>>     return torch.from_numpy(np.cos(x.numpy()))
        >>>
        >>> # Register an implementation for CUDA tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cuda")
        >>> def numpy_cos_impl_cuda(x):
        >>>     return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
        >>>
        >>> x = torch.randn(3)
        >>> torch.ops.mylibrary.numpy_cos(x)  # calls numpy_cos_impl_cpu
        >>>
        >>> x_cuda = x.cuda()
        >>> torch.ops.mylibrary.numpy_cos(x)  # calls numpy_cos_impl_cuda
    """
    def decorator(f):
        # Delegate to the underlying CustomOp object; _stacklevel=3 keeps
        # the recorded registration location pointing at the user's code.
        target = _find_custom_op(qualname, also_check_torch_library=True)
        target.impl(device_types, _stacklevel=3)(f)
        return f

    return decorator if func is None else decorator(func)
def impl_abstract(qualname, *, func=None):
    r"""Register an abstract implementation for this operator.
    An "abstract implementation" specifies the behavior of this operator on
    Tensors that carry no data. Given some input Tensors with certain properties
    (sizes/strides/storage_offset/device), it specifies what the properties of
    the output Tensors are.
    The abstract implementation has the same signature as the operator.
    It is run for both FakeTensors and meta tensors. To write an abstract
    implementation, assume that all Tensor inputs to the operator are
    regular CPU/CUDA/Meta tensors, but they do not have storage, and
    you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
    The abstract implementation must consist of only PyTorch operations
    (and may not directly access the storage or data of any input or
    intermediate Tensors).
    This API may be used as a decorator (see examples).
    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
    Examples::
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Example 1: an operator without data-dependent output shape
        >>> @torch._custom_ops.custom_op("mylibrary::custom_linear")
        >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
        >>>     raise NotImplementedError
        >>>
        >>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear")
        >>> def custom_linear_abstract(x, weight, bias):
        >>>     assert x.dim() == 2
        >>>     assert weight.dim() == 2
        >>>     assert bias.dim() == 1
        >>>     assert x.shape[1] == weight.shape[1]
        >>>     assert weight.shape[0] == bias.shape[0]
        >>>     assert x.device == weight.device
        >>>
        >>>     return (x @ weight.t()) + bias
        >>>
        >>> # Example 2: an operator with data-dependent output shape
        >>> @torch._custom_ops.custom_op('mylibrary::custom_nonzero')
        >>> def custom_nonzero(x: Tensor) -> Tensor:
        >>>     ...
        >>>
        >>> @torch._custom_ops.impl_abstract("mylibrary::custom_nonzero")
        >>> def custom_nonzero_abstract(x):
        >>>     # Number of nonzero-elements is data-dependent.
        >>>     # Since we cannot peek at the data in an abstract impl,
        >>>     # we use the ctx object to construct a new symint that
        >>>     # represents the data-dependent size.
        >>>     ctx = torch._custom_ops.get_ctx()
        >>>     nnz = ctx.create_unbacked_symint()
        >>>     shape = [x.dim(), nnz]
        >>>     result = x.new_empty(shape, dtype=torch.long)
        >>>     return result
        >>>
        >>> @torch._custom_ops.impl("mylibrary::custom_nonzero")
        >>> def custom_nonzero_impl(x):
        >>>     x_np = to_numpy(x)
        >>>     res = np.stack(np.nonzero(x_np), axis=1)
        >>>     # unbacked symbolic ints in PyTorch must be >= 2, so we
        >>>     # constrain the range to at least 2
        >>>     if res.shape[0] <= 1:
        >>>         raise RuntimeError("not supported")
        >>>     return torch.tensor(res, device=x.device)
    """
    import torch.library

    # This API is a thin alias for the modern torch.library.register_fake.
    return torch.library.register_fake(qualname, func, _stacklevel=2)
def impl_save_for_backward(qualname, *, func=None):
    r"""Register a function that tells us what to save for backward.

    Please see :func:`impl_backward` for more details.
    """
    def decorator(f):
        op = _find_custom_op(qualname, also_check_torch_library=True)
        op.impl_save_for_backward(_stacklevel=3)(f)
        return f

    return decorator if func is None else decorator(func)
def impl_backward(qualname, output_differentiability=None, *, func=None):
    r"""Registers a backward formula for an operator.

    In order for an operator to work with autograd, you need to register
    a backward formula. There are two pieces to this:
    1. You must give us a function to specify what to save for backward.
       Call this the "save for backward" function.
    2. You must give us a function that computes gradients. Call this the
       "backward" function.

    Use `impl_save_for_backward` to define a "save for backward" function
    that specifies what gets saved for backward. The function should accept
    two arguments ``(inputs, output)`` and return the quantities to be saved
    for backward.
    During runtime, when you call the operator in a forwards pass, PyTorch
    will invoke the "save for backward" function with the inputs and output
    of the operator.

    Use `impl_backward` to define the "backward" function. The backward
    function must accept ``(ctx, saved, *grads)``:
    - ``ctx`` is a context object where we may provide information
    - ``saved`` is exactly what gets returned from the "save for backward"
      function
    - ``grads`` is one or more gradients. The number of gradients matches
      the number of outputs of the operator.

    The backward function must return a dict that maps the name of
    an input to the operator to its corresponding gradient. All inputs that
    were declared to be Tensors in the operator definition must be accounted
    for in the dict. The gradient may be a Tensor or None.

    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
    """
    def decorator(f):
        op = _find_custom_op(qualname, also_check_torch_library=True)
        op.impl_backward(output_differentiability, _stacklevel=3)(f)
        return f

    return decorator if func is None else decorator(func)
def _destroy(qualname):
    """De-registers a custom op. For testing purposes only"""
    _find_custom_op(qualname)._destroy()

View File

@ -0,0 +1,484 @@
# mypy: allow-untyped-defs
import inspect
from collections import defaultdict
from functools import wraps
from itertools import chain
from typing import Callable, Dict, List, Sequence, TypeVar, Union
from typing_extensions import ParamSpec
import torch
import torch.library
from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
from torch._prims_common import CustomOutParamAnnotation
from torch.utils import _pytree as pytree
__all__ = [
    "decomposition_table",
    "pre_autograd_decomposition_table",
    "meta_table",
    "register_decomposition",
    "get_decompositions",
    "core_aten_decompositions",
]

# Type variables used by decorator signatures below.
_T = TypeVar("_T")
_P = ParamSpec("_P")

# TODO: relax key type here; torch registrations should be possible to; but
# right now this type is accurate
# Maps decomposition "phase" name -> {operator -> decomposition fn}.
# defaultdict(dict) means indexing below materializes the per-phase dicts.
global_decomposition_table: Dict[
    str, Dict[torch._ops.OperatorBase, Callable]
] = defaultdict(dict)

# Convenience aliases into the global table; these share the same dict
# objects, so mutations through either name are visible everywhere.
decomposition_table = global_decomposition_table["post_autograd"]
pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
meta_table = global_decomposition_table["meta"]
def _add_op_to_registry(registry, op, fn):
"""
This is an internal API for adding an op to the decomposition table.
If op is OpOverload, it will be added to the registry directly.
If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry.
"""
overloads: List[Union[torch._ops.OperatorBase]] = []
if isinstance(op, HigherOrderOperator):
# There's no concept of overloads for HigherOrderOperator
registry[op] = fn
return
elif isinstance(op, OpOverload):
overloads.append(op)
else:
assert isinstance(op, OpOverloadPacket)
for ol in op.overloads():
overloads.append(getattr(op, ol))
for op_overload in overloads:
if op_overload in registry:
raise RuntimeError(f"duplicate registrations for {op_overload}")
# TorchScript dumps a bunch of extra nonsense overloads
# which don't have corresponding dispatcher entries, we need
# to filter those out, e.g aten.add.float_int
if torch._C._dispatch_has_kernel(op_overload.name()):
registry[op_overload] = fn
def _convert_out_params(f):
    """Rewrap ``f`` so its ``out`` parameter matches native_functions.yaml.

    Decompositions are written with a single ``out`` argument, but ops with
    multiple outputs expose one keyword-only parameter per output.  When the
    ``out`` annotation is a tuple, this builds a wrapper taking one kwarg per
    output and forwards them as a tuple; when a custom single out-param name
    is annotated, it builds a wrapper exposing that name.  Functions without
    an ``out`` annotation are returned unchanged.
    """
    out_annotation = f.__annotations__.get("out")
    # If there are no out params, do not wrap the function.
    if not out_annotation:
        return f
    # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
    if getattr(out_annotation, "__origin__", None) is tuple:
        sig = inspect.signature(f)
        # The return annotation is a NamedTuple; its field names become the
        # names of the individual out kwargs.
        out_names = sig.return_annotation._fields
        # If out is a tuple, we need to register a function that unpacks all the out
        # elements as this is what native_functions.yaml expects
        @wraps(f)
        def _fn(*args, **kwargs):
            out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
            # Either all of the out kwargs are set or none of them
            is_none = out_kwargs[0] is None
            assert all((o is None) == is_none for o in out_kwargs)
            return f(*args, **kwargs, out=None if is_none else out_kwargs)
        out_params = [
            inspect.Parameter(
                o,
                kind=inspect.Parameter.KEYWORD_ONLY,
                default=None,
                annotation=t,
            )
            for o, t in zip(out_names, out_annotation.__args__)
        ]
        # Drop the out parameter and concatenate the new kwargs in the signature
        params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params)
        _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
            parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type]
        )
        # Drop the out parameter and concatenate the new kwargs in the annotations
        _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
        for o in out_params:
            _fn.__annotations__[o.name] = o.annotation
        # Propagate that this function is wrapped by `out_wrapper`
        _fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper # type: ignore[attr-defined]
        return _fn
    # Alternatively, there may be a single tensor out parameter with a name
    # other than "out". This will need special treatment and is indicated by an
    # annotation, which we will remove here so it is not exposed after wrapping.
    custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None)
    if custom_out_param_name:
        @wraps(f)
        def _fn(*args, **kwargs):
            out_kwarg = kwargs.pop(custom_out_param_name, None)
            return f(*args, **kwargs, out=out_kwarg)
        out_param = inspect.Parameter(
            custom_out_param_name,
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=None,
            annotation=out_annotation,
        )
        # Drop the out parameter and concatenate the new kwarg in the signature
        sig = inspect.signature(f)
        params = chain(
            (v for k, v in sig.parameters.items() if k != "out"), (out_param,)
        )
        _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
            parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type]
        )
        # Drop the out parameter and concatenate the new kwargs in the annotations
        _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
        _fn.__annotations__[out_param.name] = out_param.annotation
        return _fn
    return f
def register_decomposition(
    aten_op, registry=None, *, type="post_autograd", unsafe=False
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """
    A decorator to register a function as a decomposition to the Python
    decomposition table. Use it like this::

        @register_decomposition(torch.ops.aten.clamp_min)
        def clamp_min(x, min):
            return torch.clamp(x, min=min)

    If you are writing a new decomposition, consider contributing it
    directly to PyTorch in torch._decomp.decompositions.

    This API is experimental; we are almost certainly going to extend
    the API when we make decompositions eligible for use in transforms (e.g.,
    autograd) and not just backend tracing, where we then need to know if a
    decomposition can be used to simulate a transform.

    By default, we also will register it to the Meta key of dispatcher,
    and replace the c++ Meta implementation if there is already one.

    Args:
        aten_op: an OpOverload / OpOverloadPacket / HigherOrderOperator, or a
            pytree of them, to register the decomposition for.
        registry: target table; defaults to ``global_decomposition_table[type]``.
        type: one of "post_autograd", "pre_autograd", "meta".
        unsafe: skip the out-parameter signature conversion; this kwarg is for
            reuse of this function for registering non-function things.
    """
    assert type in {"post_autograd", "pre_autograd", "meta"}
    def decomposition_decorator(fn: Callable[_P, _T]) -> Callable[_P, _T]:
        orig_fn = fn
        if not unsafe:
            # Expand tuple/custom "out" annotations into real keyword-only
            # out parameters (matches native_functions.yaml conventions).
            fn = _convert_out_params(fn)
        nonlocal registry
        if registry is None:
            registry = global_decomposition_table[type]
        def register(op):
            _add_op_to_registry(registry, op, fn)
        # To handle allowing multiple aten_ops at once
        pytree.tree_map_(register, aten_op)
        # Return the undecorated function so stacked decorators and callers
        # see the original signature.
        return orig_fn
    return decomposition_decorator
def get_decompositions(
aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]],
type: str = "post_autograd",
) -> Dict[torch._ops.OperatorBase, Callable]:
"""
Retrieve a dictionary of decompositions corresponding to the list of
operator overloads and overload packets passed as input. Overload
packets will include all decomposed overloads in the packet. If there is
no decomposition for a requested operator, it is silently ignored.
This API is experimental; we are almost certainly going to give an alternate,
more recommended formulation, where a user provides the set of operators
they know how to implement, and we provide decompositions for everything
not in this set.
"""
assert type in {"post_autograd", "pre_autograd", "meta"}
registry = global_decomposition_table[type]
packets_to_overloads = defaultdict(list)
for opo in registry:
if isinstance(opo, (OpOverload, OpOverloadPacket)):
packets_to_overloads[opo.overloadpacket].append(opo)
decompositions: Dict[torch._ops.OperatorBase, Callable] = {}
for op in aten_ops:
if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
for op_overload in packets_to_overloads[op]:
decompositions[op_overload] = registry[op_overload]
elif isinstance(op, (torch._ops.OperatorBase)) and op in registry:
decompositions[op] = registry[op]
return decompositions
def remove_decompositions(
decompositions: Dict[torch._ops.OperatorBase, Callable],
aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
) -> None:
"""
Given a dictionary of decompositions obtained from get_decompositions(), removes
operators associated with a list of operator overloads and overload packets passed
as input. If the decomposition dictionary does not contain a decomposition that is
specified to be removed, it is silently ignored.
"""
for op in aten_ops:
if isinstance(op, OpOverloadPacket):
for overload_name in op.overloads():
opo = getattr(op, overload_name)
decompositions.pop(opo, None)
elif isinstance(op, OpOverload):
decompositions.pop(op, None)
# populate the table
import torch._decomp.decompositions
import torch._refs
# See NOTE [Core ATen Ops]
#
# list was copied from torch/_inductor/decomposition.py
# excluding decompositions that results in prim ops
# Resulting opset of decomposition is core aten ops
def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
aten = torch.ops.aten
return get_decompositions(
[
aten.addcdiv,
aten.addcdiv_,
aten.addcmul,
aten.addcmul_,
aten.addr,
aten.affine_grid_generator,
aten.alias_copy,
aten.all,
aten.aminmax,
aten.arange.default,
aten.arange.start,
aten.avg_pool2d_backward,
aten.baddbmm,
aten.binary_cross_entropy,
aten.binary_cross_entropy_backward,
aten.binary_cross_entropy_with_logits,
aten.block_diag,
aten.celu,
aten.celu_,
aten.channel_shuffle,
aten.clamp_max,
aten.clamp_min,
aten.col2im,
aten.count_nonzero,
aten.linalg_cross,
aten.cudnn_batch_norm,
aten.cudnn_batch_norm_backward,
aten.miopen_batch_norm_backward,
aten.deg2rad,
aten.deg2rad_,
aten.detach,
aten.diag_embed,
aten.diagonal_backward,
aten.dot,
aten.vdot,
aten.elu,
aten.elu_,
aten.elu_backward,
aten._embedding_bag,
aten.embedding_dense_backward,
aten.empty_like,
aten._euclidean_dist.default,
aten.expand_as,
aten.expand_copy,
aten.eye,
aten.fill,
aten.fill_,
aten.floor_divide,
aten.frac,
aten.frac_,
aten._fused_moving_avg_obs_fq_helper,
aten.gelu_,
aten.gelu_backward,
aten.glu,
aten.glu_backward,
aten.hardshrink,
aten.hardsigmoid,
aten.hardsigmoid_,
aten.hardsigmoid_backward,
aten.hardswish,
aten.hardswish_,
aten.hardswish_backward,
aten.hardtanh_,
aten.hardtanh_backward,
aten.heaviside,
aten.heaviside_,
aten.huber_loss,
aten.huber_loss_backward,
aten.im2col,
aten.index_add,
aten.index_add_,
aten.index_copy,
aten.index_copy_,
aten.index_fill,
aten.index_fill_,
aten.isin,
aten.isneginf,
aten.isposinf,
aten.l1_loss,
aten._lazy_clone,
aten._test_parallel_materialize,
aten.leaky_relu_,
aten.leaky_relu_backward,
aten.lerp,
aten.lerp_,
aten.linspace,
aten.logaddexp,
aten.logaddexp2,
aten.logit,
aten.logit_,
aten.logit_backward,
aten.log_sigmoid_backward,
aten.log_sigmoid_forward,
aten._log_softmax_backward_data,
aten.logspace,
aten.logsumexp.default,
aten.masked_fill,
aten.masked_fill_,
aten.mish,
aten.mish_,
aten.mse_loss,
aten.mse_loss_backward,
aten.multi_margin_loss,
aten.multilabel_margin_loss_forward,
aten.mv,
aten.mvlgamma,
aten.mvlgamma_,
aten.nansum,
aten.nan_to_num,
aten.nan_to_num_,
aten.narrow,
aten.native_batch_norm_backward,
aten.native_dropout_backward,
aten.native_group_norm_backward,
aten.native_layer_norm_backward,
aten.new_empty,
aten.new_full,
aten.new_ones,
aten.new_zeros,
aten.nll_loss2d_forward,
aten.nll_loss2d_backward,
aten.nll_loss_backward,
aten.nll_loss_forward,
aten.norm,
aten.ones,
aten.ones_like,
aten.pixel_shuffle,
aten.pixel_unshuffle,
aten._prelu_kernel,
aten._prelu_kernel_backward,
aten._reshape_alias,
aten.rad2deg,
aten.rad2deg_,
aten.reflection_pad1d,
aten.reflection_pad1d_backward,
aten.reflection_pad2d,
aten.reflection_pad2d_backward,
aten.reflection_pad3d,
aten.reflection_pad3d_backward,
aten.replication_pad1d,
aten.replication_pad2d,
aten.replication_pad3d,
aten.renorm,
aten.renorm_,
aten.replication_pad2d,
aten.resize_as,
aten.roll,
aten.rot90,
aten.rrelu_with_noise,
aten.rrelu_with_noise_,
aten.rsub,
aten._safe_softmax,
aten._scaled_dot_product_flash_attention_for_cpu.default,
aten.select_backward,
aten.select_scatter,
aten.sgn,
aten.sgn_,
aten.sigmoid_backward,
aten.silu,
aten.silu_,
aten.silu_backward,
aten.sinc,
aten.sinc_,
aten.slice_backward,
aten.smooth_l1_loss,
aten.smooth_l1_loss_backward,
aten.soft_margin_loss,
aten.soft_margin_loss_backward,
aten._softmax_backward_data,
aten.softplus,
aten.softplus_backward,
aten.softshrink,
aten.special_entr,
aten.special_log_ndtr,
aten.special_xlog1py,
aten.split.Tensor,
aten.split_with_sizes_copy,
aten.squeeze.default,
aten.squeeze.dim,
aten.std,
aten.std_mean,
aten.stack,
aten.sum.default,
aten.sum.out,
aten.t,
aten.t_copy,
aten.take,
aten.tanh_backward,
aten.threshold,
aten.threshold_,
aten.threshold_backward,
aten.trace,
aten.transpose.int,
aten.tril,
aten.tril_,
aten.triu,
aten.triu_,
aten.unbind,
aten.unfold_backward,
aten.unfold_copy,
aten._unsafe_index,
aten._unsafe_index_put,
aten._unsafe_masked_index,
aten._unsafe_masked_index_put_accumulate,
aten.unsafe_split.Tensor,
aten.unsafe_split_with_sizes,
aten.unsqueeze_copy,
aten._unsafe_view,
aten.upsample_linear1d,
aten.upsample_bilinear2d,
aten.upsample_trilinear3d,
aten.upsample_nearest2d_backward,
aten.view_as_complex,
aten.xlogy,
aten.xlogy_,
aten.zero,
aten.zero_,
aten.zeros,
aten.zeros_like,
aten._chunk_cat,
aten._weight_norm_interface,
]
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,335 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import inspect
from typing import Callable, Dict, List, Optional, Tuple
import torch
import torch._decomp
from torch import Tensor
from torch._prims_common.wrappers import _maybe_remove_out_wrapper
# Shared post-autograd decomposition table from torch._decomp.
decomposition_table = torch._decomp.decomposition_table
# jvp-specific overrides; NOTE: reassigned to a plain {} further below.
decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {}
register_decomposition = torch._decomp.register_decomposition
aten = torch.ops.aten
# NOTE: [forward-mode AD decompositions mechanism]
#
# The mechanism is in VariableType,
# IF any inputs have forward grad
# AND there is no forward AD formula implemented
# AND the functions is actually differentiable
# run the decomposition
# See run_jit_decomposition_with_args_for_jvp
# We currently use python decompositions that we torchscript.
#
# Note that we would be building the backward graph at the decomposed level
# too, but that is OK, because we would've errored out otherwise anyway.
#
# TODO: The mechanism we are using to register decompositions doesn't
# seem to be exclusively used for jvp. So open question here is whether
# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things.
# If that is the case, we may go down the decomposition path unexpectedly
# (and possibly produce an unintelligible error) vs erroring out earlier and
# printing that the forward AD formula is not implemented.
#
# The solution to this may be to have a explicitly white list control when
# to enable the decomposition.
def maybe_register_decomposition(op):
    """Best-effort version of ``register_decomposition``.

    If registering ``f`` as a decomposition for ``op`` fails for any reason
    (e.g. a decomposition is already registered), the function is returned
    unmodified instead of raising.
    """
    def decorator(f):
        try:
            registered = register_decomposition(op)(f)
        except Exception:
            return f
        return registered
    return decorator
# Functions where we need a special decomposition for jvp but there's another version that
# should be used more generally (ex. for jvp we need to recompute the mean and variance for
# the backwards of a normalization function. Without jvp, it should use the saved value)
decomposition_table_for_jvp = {}
def register_decomposition_for_jvp(fn):
    """Register ``fn``'s decomposition into the jvp-specific table.

    Used for ops that need a special decomposition under forward-mode AD
    (e.g. normalization backwards that must recompute mean/variance).
    """
    jvp_registry = decomposition_table_for_jvp
    return register_decomposition(fn, registry=jvp_registry)
def _register_jit_decomposition_for_jvp(decomp, use_python=False):
    """Register the TorchScript graph of ``decomp``'s Python decomposition
    with the jit decomposition registry used by the jvp fallback.

    Looks ``decomp`` up in the jvp-specific table first, then the general
    decomposition table, and raises if it is in neither.  With
    ``use_python=True`` the Python function is kept opaque to the scripter
    (``torch.jit.ignore``) and wrapped in a generated forwarding function.
    """
    if decomp in decomposition_table_for_jvp:
        decomposition_table_used = decomposition_table_for_jvp
    elif decomp in decomposition_table:
        decomposition_table_used = decomposition_table
    else:
        raise RuntimeError(f"could not find decomposition for {decomp}")
    decomp_fn = decomposition_table_used[decomp]
    # `out_wrapper` extends a decompositions signature with
    # an `out` parameter. However jit will use the unwrapped function's
    # signature instead so we need to unwrap here to prevent an error
    decomp_fn = _maybe_remove_out_wrapper(decomp_fn)
    if use_python:
        decomp_fn = torch.jit.ignore(decomp_fn)
        sig = inspect.signature(decomp_fn)
        # Create a string wrapping the function from the signature
        # example output:
        # def wrapped_decomp(x: torch.Tensor, y: int, z: int):
        #   return decomp_fn(x, y, z)
        # Thanks copilot!
        def get_function_def(sig):
            param_def = [f"{param_str}" for param_str in sig.parameters.values()]
            param_use = [f"{param_str}" for param_str in sig.parameters.keys()]
            return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n"
        f_str = get_function_def(sig)
        graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph
    else:
        graph = torch.jit.script(decomp_fn).graph
    torch.jit._register_decomposition(decomp, graph)
# The only decompositions here are temporary or hacks for the purposes of jvp
# TODO: do these also belong here?
@maybe_register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
    """Sum of the main diagonal, decomposed as diag-then-sum for jvp."""
    diagonal = torch.diag(self)
    return diagonal.sum()
@maybe_register_decomposition(aten.log_sigmoid_forward.default)
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
    """Numerically stable log-sigmoid returning (result, buffer).

    The CUDA kernel recomputes its intermediate, so an empty placeholder
    buffer is returned on CUDA; CPU expects the intermediate ``z``.
    """
    zero = self.new_zeros(())
    min_part = torch.minimum(zero, self)
    z = torch.exp(-torch.abs(self))
    buffer = self.new_zeros((0,)) if self.is_cuda else z
    return min_part - torch.log1p(z), buffer
def recompute_mean_var(
    input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool
):
    """Recompute mean/rstd so they are differentiable w.r.t. ``input``.

    Most norm backward decompositions match the core versions except here:
    the saved statistics do not track gradients through ``input``, so for
    forward-mode AD they are rebuilt, recovering eps from the saved ``rstd``.
    """
    mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim)
    var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim)
    # rstd == 1/sqrt(var + eps)  =>  eps == (1/rstd)^2 - var.  Detach so eps
    # itself carries no gradient.  (this makes me so sad inside)
    eps = (torch.pow(1 / rstd, 2) - var).detach()
    rebuilt_rstd = 1 / torch.sqrt(var + eps)
    return mean, rebuilt_rstd
@register_decomposition_for_jvp(aten.native_layer_norm_backward)
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    """jvp-aware layer_norm backward.

    Same as the core decomposition except mean/rstd are recomputed (via
    recompute_mean_var) so they carry gradients through ``input``, which
    forward-mode AD needs.  Masked-out gradients are returned as zero
    tensors rather than None because vjp cannot handle None.
    """
    input_shape = input.shape
    input_ndim = input.dim()
    # Dims covered by normalized_shape are the "inner" (reduced) dims.
    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices = list(range(axis, input_ndim))
    outer_dim_indices = list(range(0, axis))
    # N = number of normalized elements, M = number of rows.
    N = 1
    for i in inner_dims:
        N *= i
    M = 1
    for i in outer_dims:
        M *= i
    # Degenerate (empty) input: all gradients are zero.
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape),
            input.new_zeros(input_shape[axis:]),
            input.new_zeros(input_shape[axis:]),
        )
    mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True)
    x_hat = (input - mean_) * rstd_
    if weight is not None:
        grad_x_hat = grad_out * weight
    else:
        grad_x_hat = grad_out
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)
    inner = a - b - c3
    if output_mask[0]:
        d_input: Optional[Tensor] = (rstd_ / N) * inner
    else:
        d_input = torch.zeros_like(input) # should be None but doesn't work with vjp
    if output_mask[1] and weight is not None:
        if len(outer_dim_indices) > 0:
            d_weight: Optional[Tensor] = torch.sum(
                grad_out * x_hat, outer_dim_indices, False
            )
        else:
            d_weight = grad_out * x_hat
    elif weight is not None:
        d_weight = torch.zeros_like(weight) # should be None but doesn't work with vjp
    else:
        d_weight = torch.zeros(()) # should be None but doesn't work with vjp
    if output_mask[2] and bias is not None:
        if len(outer_dim_indices) > 0:
            d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
        else:
            d_bias = grad_out.clone()
    elif bias is not None:
        d_bias = torch.zeros_like(bias) # should be None but doesn't work with vjp
    else:
        d_bias = torch.zeros(()) # should be None but doesn't work with vjp
    return (d_input, d_weight, d_bias)
def prod(x: List[int]):
    """Product of the entries of ``x`` (1 for an empty list)."""
    result = 1
    for value in x:
        result *= value
    return result
@register_decomposition_for_jvp(aten.native_batch_norm_backward)
def native_batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_invstd: Optional[Tensor],
    train: bool,
    eps: float,
    output_mask: List[bool],
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    """jvp-aware batch_norm backward.

    Matches the core decomposition except that in training mode the saved
    mean/invstd are recomputed (via recompute_mean_var) so they carry
    gradients through ``input``.  Masked-out gradients come back as zero
    tensors, not None, because vjp cannot handle None.
    """
    input_shape = input.shape
    input_rank = input.dim()
    assert input_rank >= 2, "rank of the input must be at least 2"
    # Channel dim; all other dims are reduced over.
    axis = 1
    num_features = prod(input_shape) / input_shape[axis] # type: ignore[arg-type]
    mean = save_mean
    invstd = save_invstd
    if train:
        assert (
            save_mean is not None and save_invstd is not None
        ), "when train=True, save_mean and save_invstd are required"
        # (sic: "reduciton_dims" spelling kept; local name only)
        reduciton_dims = [0] + list(range(2, input.dim()))
        assert invstd is not None # for typing
        mean, invstd = recompute_mean_var(input, invstd, reduciton_dims, keepdim=False)
    else:
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = torch.rsqrt(running_var + eps)
    assert invstd is not None and mean is not None
    # Shape used to broadcast per-channel stats over the input layout.
    broadcast_mask = [1] * input_rank
    broadcast_mask[axis] = input_shape[axis]
    reduction_axes: List[int] = []
    for i in range(input_rank):
        if i != axis:
            reduction_axes.append(i)
    mean = torch.reshape(mean, broadcast_mask)
    norm = 1.0 / num_features
    grad_output_sum = torch.sum(grad_out, reduction_axes)
    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)
    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)
    if weight is None:
        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
    else:
        grad_scale = torch.reshape(invstd * weight, broadcast_mask)
    if train:
        proj = (input - mean) * proj_scale
        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
    else:
        grad_input = grad_out * grad_scale
    if output_mask[1]:
        grad_weight = dot_p * invstd
    elif weight is not None:
        grad_weight = torch.zeros_like(
            weight
        ) # should be None but doesn't work with vjp
    else:
        grad_weight = torch.zeros(()) # should be None but doesn't work with vjp
    if output_mask[2]:
        grad_bias = grad_output_sum
    else:
        grad_bias = torch.zeros_like(
            grad_output_sum
        ) # should be None but doesn't work with vjp
    return (grad_input, grad_weight, grad_bias)
@register_decomposition_for_jvp(aten.batch_norm_backward)
def batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_var: Optional[Tensor],
    update: bool,
    eps: float,
    output_mask: List[bool],
    reserve: Tensor,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    """Forward to the jvp-aware native_batch_norm_backward.

    ``reserve`` is accepted for schema compatibility but not used.
    """
    forwarded = (
        grad_out,
        input,
        weight,
        running_mean,
        running_var,
        save_mean,
        save_var,
        update,
        eps,
        output_mask,
    )
    return native_batch_norm_backward(*forwarded)
# Eagerly script and register with the jit registry the decompositions the
# jvp fallback in VariableType needs (see NOTE [forward-mode AD
# decompositions mechanism] above).  trace takes the Python/CompilationUnit
# path (use_python=True) — presumably because direct torch.jit.script of its
# decomposition does not work; TODO confirm.
_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default)
_register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.batch_norm_backward.default)
_register_jit_decomposition_for_jvp(torch.ops.aten.miopen_batch_norm_backward.default)

View File

@ -0,0 +1,266 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from collections import defaultdict
from typing import Callable, Dict
import torch
import torch._decomp as decomp
from torch._decomp import get_decompositions
from torch._ops import OpOverload
aten = torch.ops.aten
# Registry of RNG-op decompositions used to functionalize rand ops during
# AOT tracing (populated via register_rng_decomposition below).
rng_decompositions: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict)
def register_rng_decomposition(aten_op):
return decomp.register_decomposition(aten_op, rng_decompositions)
def throw_on_non_cuda(device):
    """Raise: only CUDA's Philox/counter-based RNG can be functionalized."""
    dev = device.type
    message = (
        f"You are trying to functionalize a {dev} RNG operator but {dev} does not "
        f"use Philox/counter-based RNG. Therefore, functionalizing a {dev} RNG operator is "
        "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
    )
    raise RuntimeError(message)
# TODO - We have to register many more distributions here, and also higher level
# ops like dropout which have fused implementation and can hide the rand inside.
@register_rng_decomposition(aten.rand)
def rand(shape, dtype=None, layout=torch.strided, device=None, pin_memory=False):
    """Functional aten.rand: draw from the tracked Philox state (CUDA only)."""
    if device and device.type != "cuda":
        throw_on_non_cuda(device)
    dtype = dtype or torch.float32
    seed, offset = PhiloxStateTracker.get_state_as_tuple()
    out, consumed = torch.ops.rngprims.philox_rand(
        shape, seed, offset, None, device, dtype
    )
    # Account for the randomness this call consumed.
    PhiloxStateTracker.advance_offset(consumed)
    return out
@register_rng_decomposition(aten.rand_like)
def rand_like(
    x: torch.Tensor,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=torch.preserve_format,
):
    """Functional aten.rand_like: draw from the tracked Philox state (CUDA only)."""
    target_device = device or x.device
    if target_device.type != "cuda":
        throw_on_non_cuda(target_device)
    target_dtype = dtype or x.dtype
    seed, offset = PhiloxStateTracker.get_state_as_tuple()
    out, consumed = torch.ops.rngprims.philox_rand(
        x.shape, seed, offset, None, target_device, target_dtype
    )
    # Account for the randomness this call consumed.
    PhiloxStateTracker.advance_offset(consumed)
    return out
class PhiloxState:
    """
    A Philox RNG state as (seed, offset) with offset = base_offset +
    relative_offset.  seed and base_offset snapshot the state just before
    tracing starts; relative_offset counts the offset consumed at trace time.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self):
        # Empty tensors mark "unset"; validate_state() rejects them.
        self.seed = torch.tensor(())
        self.base_offset = torch.tensor(())
        self.relative_offset = 0
        self.offset_advanced_alteast_once = False

    def validate_state(self):
        assert self.seed.numel() != 0 and self.base_offset.numel() != 0

    def advance_offset(self, consumed_offset):
        self.offset_advanced_alteast_once = True
        self.relative_offset += consumed_offset

    def set_state(self, seed, base_offset, relative_offset=0):
        self.seed = seed
        self.base_offset = base_offset
        self.relative_offset = relative_offset

    def get_state_as_tuple(self):
        self.validate_state()
        total_offset = self.base_offset + self.relative_offset
        return (self.seed, total_offset)

    def get_state_as_tensor(self):
        # Only needed because get_rng_state is overridden during tracing and
        # must hand back a tensor.
        self.validate_state()
        total_offset = self.base_offset + self.relative_offset
        return torch.stack([self.seed, total_offset])

    def set_state_from_tensor(self, state):
        # Counterpart of get_state_as_tensor for the set_rng_state override.
        self.seed, self.base_offset = torch.unbind(state)
        self.relative_offset = 0
class PhiloxStateTracker:
"""
Singleton class to track the philox rng state during AOT Autograd tracing.
For each aot tracing instance, AOT Autograd resets this tracker and keeps
track of both forward and backward offsets. At runtime, we only care about
the total consumed forward and backward offsets. For dynamic shapes, these
offsets are a function of input shapes. Therefore, the AOT generated graphs
have additional outputs that compute total consumed forward and backward
offsets.
"""
running_state: PhiloxState
fwd_state: PhiloxState
bwd_state: PhiloxState
def __enter__(self):
PhiloxStateTracker.reset()
return self
def __exit__(self, exc_type, exc_cal, exc_tb):
PhiloxStateTracker.reset()
@classmethod
def reset(cls):
cls.running_state = PhiloxState()
cls.fwd_state = PhiloxState()
cls.bwd_state = PhiloxState()
@classmethod
def mark_beginning_of_forward(cls):
# Tells the tracker to use fwd_state as the running state
cls.running_state = cls.fwd_state
@classmethod
def mark_beginning_of_backward(cls):
# Tells the tracker to use bwd_state as the running state
cls.running_state = cls.bwd_state
@classmethod
def record_state(cls, seed, offset, mode):
# Records the seed and offset tensors. These tensors are used to invoke
# the philox_rand functional primitives.
if mode == "forward":
cls.fwd_state.set_state(seed, offset)
cls.mark_beginning_of_forward()
else:
assert mode == "backward"
cls.bwd_state.set_state(seed, offset)
@classmethod
def get_state_as_tensor(cls):
# The only reason this exists is because we override get_rng_state and
# set_rng_state during tracing. get_rng_state expects a tensor output,
# so return (seed, offset) tuple upset other parts of the program like
# ctx.saved_tensors.
# A bad consequence is that if user saves and restores rng state, we
# have little bit of ugliness in the generated code, where we first
# concat the (seed, offset) to create a tensor for get_rng_state, and
# then split it back to get (seed, offset) tuple in set_rng_state.
# TODO: Investigate if there is be a better way to wrap the tuple in a
# false Tensor object, and then desugar it later on.
return cls.running_state.get_state_as_tensor()
@classmethod
def get_state_as_tuple(cls):
return cls.running_state.get_state_as_tuple()
@classmethod
def set_state_from_tensor(cls, x):
# This is only needed because we override set_rng_state. Look at the
# comment in get_state_from_tensor method.
cls.running_state.set_state_from_tensor(x)
@classmethod
def advance_offset(cls, consumed_offset):
cls.running_state.advance_offset(consumed_offset)
@classmethod
def get_current_relative_offset(cls):
return cls.running_state.relative_offset
@staticmethod
def multiple_of_4(offset):
# torch cuda rng state offset must be a multiple of 4. For inductor, as
# we sum up all the numel, the result might not be a multiple of 4. This
# method achieves that.
return (offset + 3) // 4 * 4
@classmethod
def get_updated_fwd_offset(cls):
# Short circuit if no rand ops were observed
if not cls.fwd_state.offset_advanced_alteast_once:
return cls.fwd_state.base_offset
return cls.multiple_of_4(
cls.fwd_state.base_offset + cls.fwd_state.relative_offset
)
@classmethod
def get_updated_bwd_offset(cls):
# Short circuit if no rand ops were observed
if not cls.bwd_state.offset_advanced_alteast_once:
return cls.bwd_state.base_offset
return cls.multiple_of_4(
cls.bwd_state.base_offset + cls.bwd_state.relative_offset
)
# Adding more decompositions which eventually use rand_like inside decomps.
# Adding these in rng_decompositions ensures the functionalization of rand_like
# ops used in these decomps. The list is copied from inductor codebase, which
# uses it for similar purpose.
#
# Caution - These decomps do not have same accuracy as that of eager. However,
# we can't just disable them with a config flag like fallback_random, because
# for functionalization of rng ops, we have to decompose these ops.
extra_random_decomps = get_decompositions(
    [
        aten.cauchy,
        aten.cauchy_,
        aten.exponential,
        aten.exponential_,
        aten.geometric,
        aten.geometric_,
        aten.native_dropout,
        aten.normal,
        aten.normal_,
        aten.normal_functional,
        aten.log_normal,
        aten.log_normal_,
        aten.rrelu_with_noise,
        aten.rrelu_with_noise_,
        aten.uniform_,
    ]
)
# Decorator that registers directly into extra_random_decomps; the whole
# table is merged into rng_decompositions at the bottom of this file.
register_extra_random_decomp = functools.partial(
    decomp.register_decomposition, registry=extra_random_decomps
)
@register_extra_random_decomp([aten.bernoulli_])
def bernoulli_(self, p=0.5):
    """In-place Bernoulli sampling via functionalizable rand_like (non-CPU only)."""
    if self.device == torch.device("cpu"):
        return NotImplemented
    draws = torch.rand_like(self, dtype=torch.float32)
    return self.copy_(draws < p)
@register_extra_random_decomp([aten.bernoulli.p])
def bernoulli_p(self, p=0.5, *, generator=None):
    """Out-of-place Bernoulli sampling via functionalizable rand_like (non-CPU only)."""
    if self.device == torch.device("cpu"):
        return NotImplemented
    assert generator is None
    draws = torch.rand_like(self, dtype=torch.float32)
    return draws < p
rng_decompositions.update(extra_random_decomps) # type: ignore[arg-type]

View File

@ -0,0 +1,104 @@
# mypy: allow-untyped-defs
import io
import torch
from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
from torch.package._package_pickler import create_pickler
from torch.package._package_unpickler import PackageUnpickler
from torch.serialization import _maybe_decode_ascii
def _save_storages(importer, obj):
    """Pickle ``obj`` for torch::deploy, externalizing tensor storages.

    Returns ``(pickle_bytes, storages, dtypes, zip_reader)``: storages and
    their dtypes are collected into side tables referenced from the pickle
    stream by index, and zip_reader identifies the source package (or None
    when ``importer`` is not a PackageImporter).
    """
    storages = []
    dtypes = []
    # Only a PackageImporter carries a zip_reader; anything else means
    # "no source package".
    pkg_importer = (
        importer if isinstance(importer, torch.package.PackageImporter) else None
    )
    importers: Importer
    if pkg_importer is None:
        importers = sys_importer
    else:
        importers = OrderedImporter(pkg_importer, sys_importer)

    def persistent_id(obj):
        if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                dtype = obj.dtype
            else:
                dtype = torch.uint8
            storages.append(obj)
            dtypes.append(dtype)
            return ("storage", len(storages) - 1)
        if hasattr(obj, "__reduce_deploy__"):
            # Cache the reduce result per object identity.
            if _serialized_reduces.get(id(obj)) is None:
                _serialized_reduces[id(obj)] = (
                    "reduce_deploy",
                    id(obj),
                    *obj.__reduce_deploy__(importers),
                )
            return _serialized_reduces[id(obj)]
        return None

    # Write the pickle data for `obj`
    buffer = io.BytesIO()
    pickler = create_pickler(buffer, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    return (
        buffer.getvalue(),
        storages,
        dtypes,
        pkg_importer.zip_reader if pkg_importer else None,
    )
def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
    """Inverse of _save_storages: unpickle ``obj_bytes`` for torch::deploy.

    Externalized storages are resolved from the given side tables and
    __reduce_deploy__ objects are re-created against the package's importer.
    The result is cached in _deploy_objects under ``id``.
    """
    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]
        if typename == "storage":
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with TypedStorage
            storage = serialized_storages[data[0]]
            dtype = serialized_dtypes[data[0]]
            return torch.storage.TypedStorage(
                wrap_storage=storage.untyped(), dtype=dtype
            )
        if typename == "reduce_deploy":
            reduce_id, func, args = data
            if reduce_id not in _loaded_reduces:
                # Re-run the reduce function against the package's importer.
                _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
            return _loaded_reduces[reduce_id]
        return None
    importer: Importer
    if zip_reader is not None:
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)
    else:
        importer = sys_importer
    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load # type: ignore[method-assign]
    result = _deploy_objects[id] = unpickler.load()
    return result
def _get_package(zip_reader):
    """Return (creating and caching on first use) the PackageImporter for ``zip_reader``."""
    try:
        return _raw_packages[zip_reader]
    except KeyError:
        package = PackageImporter(zip_reader)
        _raw_packages[zip_reader] = package
        return package
# Module-level caches so repeated save/load calls share state:
# zip_reader -> PackageImporter, deploy id -> loaded object,
# id(obj) -> serialized reduce tuple, reduce id -> loaded reduce result.
_raw_packages: dict = {}
_deploy_objects: dict = {}
_serialized_reduces: dict = {}
_loaded_reduces: dict = {}

View File

@ -0,0 +1,180 @@
# mypy: allow-untyped-defs
import itertools
import unittest.mock
from contextlib import contextmanager
from typing import Iterator
import torch
import torch._C
import torch._ops
import torch.utils._python_dispatch
import torch.utils._pytree as pytree
# Public surface of this module.
__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]

# Aliases for the torch._C guard types that toggle the Python dispatcher /
# pre-dispatch TLS state; used as context managers (e.g. in
# enable_crossref_functionalize below).
no_python_dispatcher = torch._C._DisablePythonDispatcher
enable_python_dispatcher = torch._C._EnablePythonDispatcher
enable_pre_dispatch = torch._C._EnablePreDispatch

# Flipped to True via unittest.mock.patch inside enable_crossref_functionalize;
# presumably consulted by dispatcher machinery outside this file — confirm.
CROSSREF_FUNCTIONALIZE = False
def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
    """
    Warning: the set of overloads this will report is very subtle. It is precisely
    the set of torch.ops functions that have actually been accessed from Python
    (e.g., we actually called torch.ops.aten.blah at some point. This is DIFFERENT
    from the set of registered operators, which will in general be a larger set,
    as this would include all operators which we ran C++ static initializers or
    Python operator registration on. This does not eagerly populate the list on
    torch.ops.aten; this list is lazy!
    In other words, this is good for traversing over everything that has an
    OpOverload object allocated in Python. We use it for cache invalidation, but
    don't rely on this list being complete.
    Note that even if we did report all C++ registered overloads, this isn't guaranteed
    to be complete either, as a subsequent lazy load of a library which triggers more
    registrations could add more things to the set.
    """
    for ns_name in torch.ops:
        namespace = getattr(torch.ops, ns_name)
        # Walk every packet accessed so far in this namespace, then every
        # overload hanging off each packet.
        for packet in (getattr(namespace, name) for name in namespace):
            yield from (getattr(packet, overload) for overload in packet)
@contextmanager
def suspend_functionalization():
    """Temporarily disable functionalization for the enclosed region.

    On exit, re-enables functionalization (if it was on) with the previously
    active reapply-views TLS setting.
    """
    was_included = torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.Functionalize
    )
    reapply_views = torch._C._functionalization_reapply_views_tls()
    if was_included:
        torch._disable_functionalization()
    try:
        yield
    finally:
        if was_included:
            torch._enable_functionalization(reapply_views=reapply_views)
def check_tensor_metadata_matches(nv, rv, desc):
    """Assert that tensors *nv* and *rv* agree on size, dtype and significant strides.

    ``desc`` is a zero-argument callable producing a human-readable context
    string; it is only invoked lazily when building a failure message.
    """
    assert callable(desc)
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    strides_ok, bad_idx = torch._prims_common.check_significant_strides(
        nv, rv, only_cuda=False
    )
    assert (
        strides_ok
    ), f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {bad_idx})"
def check_metadata_matches(n, r, desc):
    """Flatten the pytrees *n* and *r* and compare metadata of tensor leaves.

    Non-tensor reference leaves are skipped.  ``desc`` is a zero-argument
    callable producing a context string for failure messages.
    """
    assert callable(desc)
    n_vals, n_spec = pytree.tree_flatten(n)
    r_vals, r_spec = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, (nv, rv) in enumerate(zip(n_vals, r_vals)):
        if isinstance(rv, torch.Tensor):
            check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")
class Lit:
    """Wrapper whose repr is the wrapped string verbatim (no surrounding quotes).

    Used so that formatted tensor constructor expressions print as code rather
    than as quoted strings inside repr'd containers.
    """

    def __init__(self, s):
        self.s = s

    def __repr__(self) -> str:
        return self.s
def _fmt(a: object) -> object:
if isinstance(a, torch.Tensor):
return Lit(
f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})"
)
else:
return a
def make_crossref_functionalize(op, final_key):
    """Build a cross-checking handler for *op*.

    The returned handler runs the op twice: once under functionalization with
    fake tensors, and once normally via ``final_key``; it then asserts the two
    results agree on tensor metadata.  For ``aten.lift_fresh.default`` the
    original dispatch key is returned unchanged (no cross-check).
    """
    from torch._subclasses.fake_tensor import FakeTensorMode

    # This case is pretty weird, suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args, **kwargs):
        fake_mode = FakeTensorMode()

        def fakeify_defun(t):
            # Convert a tensor into a fake tensor, unwrapping functional
            # tensors first; non-tensors pass through unchanged.
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides. This doesn't necessarily have to
                    # be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        def maybe_detach(t):
            # Detach tensors so the saved copies used for error messages do
            # not participate in autograd.
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        # TODO: This probably does the wrong thing if you're running other
        # substantive modes with the normal op outside here
        with torch.utils._python_dispatch._disable_current_modes(), suspend_functionalization():
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            orig_f_args, orig_f_kwargs = pytree.tree_map(
                maybe_detach, (f_args, f_kwargs)
            )
            with fake_mode:
                # Reference run: the op on fake inputs under fake mode.
                f_r = op(*f_args, **f_kwargs)
        # Real run: dispatch directly to the original key with real inputs.
        r = op._op_dk(final_key, *args, **kwargs)

        def desc():
            # Lazily build a call-site description like "op(arg0, k=v)" for
            # assertion messages, formatting tensors via _fmt.
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (
                        f"{k}={pytree.tree_map(_fmt, v)}"
                        for k, v in orig_f_kwargs.items()
                    ),
                )
            )
            return f"{op}({fmt_args})"

        check_metadata_matches(f_r, r, desc)
        return r

    return handler
# NB: enabling this is slow, don't do it in a hot loop. This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
    """Context manager that cross-checks ops against a functionalized re-run.

    Invalidates the Functionalize dispatch cache of every Python-loaded
    overload on entry and again on exit, and patches CROSSREF_FUNCTIONALIZE
    to True while the Python dispatcher is enabled.
    """

    def _flush_functionalize_caches():
        key = torch._C.DispatchKey.Functionalize
        for overload in all_py_loaded_overloads():
            overload._uncache_dispatch(key)

    _flush_functionalize_caches()
    try:
        with enable_python_dispatcher(), unittest.mock.patch(
            "torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True
        ):
            yield
    finally:
        _flush_functionalize_caches()

View File

@ -0,0 +1,109 @@
import torch
from . import convert_frame, eval_frame, resume_execution
from .backends.registry import list_backends, lookup_backend, register_backend
from .callback import callback_handler, on_compile_end, on_compile_start
from .code_context import code_context
from .convert_frame import replay
from .decorators import (
allow_in_graph,
assume_constant_result,
disable,
disallow_in_graph,
forbid_in_graph,
graph_break,
mark_dynamic,
mark_static,
mark_static_address,
maybe_mark_dynamic,
run,
substitute_in_graph,
)
from .eval_frame import (
_reset_guarded_backend_cache,
explain,
export,
is_dynamo_supported,
is_inductor_supported,
optimize,
optimize_assert,
OptimizedModule,
reset_code,
)
from .external_utils import is_compiling
from .mutation_guard import GenerationTracker
from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count
# Register polyfill functions
from .polyfills import loader as _ # usort: skip # noqa: F401
# Public API of this module.  Most names are re-exported from the imports
# above; "reset" is defined later in this file.
__all__ = [
    "allow_in_graph",
    "assume_constant_result",
    "disallow_in_graph",
    "forbid_in_graph",
    "substitute_in_graph",
    "graph_break",
    "mark_dynamic",
    "maybe_mark_dynamic",
    "mark_static",
    "mark_static_address",
    "optimize",
    "optimize_assert",
    "export",
    "explain",
    "run",
    "replay",
    "disable",
    "reset",
    "OptimizedModule",
    "is_compiling",
    "register_backend",
    "list_backends",
    "lookup_backend",
]
# The identity check makes this patch idempotent: after the assignment below,
# torch.manual_seed is no longer torch.random.manual_seed, so re-running this
# module will not wrap it a second time.
if torch.manual_seed is torch.random.manual_seed:
    import torch.jit._builtins

    # Wrap manual_seed with the disable decorator.
    # Can't do it at its implementation due to dependency issues.
    torch.manual_seed = torch._disable_dynamo(torch.manual_seed)

    # Add the new manual_seed to the builtin registry.
    torch.jit._builtins._register_builtin(torch.manual_seed, "aten::manual_seed")
def reset() -> None:
    """Clear all compile caches and restore initial state"""
    with convert_frame.compile_lock:
        # Per-code-object caches (also clears code_context).
        reset_code_caches()
        # Frame-tracking sets and diagnostic maps.
        convert_frame.input_codes.clear()
        convert_frame.output_codes.clear()
        orig_code_map.clear()
        guard_failures.clear()
        graph_break_reasons.clear()
        resume_execution.ContinueExecutionCache.cache.clear()
        _reset_guarded_backend_cache()
        reset_frame_count()
        # Compiled-autograd cache lives in C++.
        torch._C._dynamo.compiled_autograd.clear_cache()
        # Frame counters back to their startup values.
        convert_frame.FRAME_COUNTER = 0
        convert_frame.FRAME_COMPILE_COUNTER.clear()
        callback_handler.clear()
        GenerationTracker.clear()
        torch._dynamo.utils.warn_once_cache.clear()
        torch._dynamo.utils.user_obj_id_to_weakref.clear()
        # Turn saved-tensors-hooks tracing back off.
        torch._C._autograd._saved_tensors_hooks_set_tracing(False)
def reset_code_caches() -> None:
    """Clear compile caches that are keyed by code objects."""
    with convert_frame.compile_lock:
        tracked = convert_frame.input_codes.seen + convert_frame.output_codes.seen
        for weak_code in tracked:
            # Weak references may have gone stale; only reset live code objects.
            if code := weak_code():
                reset_code(code)
        code_context.clear()

Some files were not shown because too many files have changed in this diff Show More