I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

File diff suppressed because it is too large


@@ -0,0 +1,75 @@
from typing import Callable
from torch._utils import CallbackRegistry
EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event creation"
)
EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event deletion"
)
EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
"CUDA event record"
)
EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
"CUDA event wait"
)
MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA memory allocation"
)
MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA memory deallocation"
)
StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA stream creation"
)
DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
"CUDA device synchronization"
)
StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA stream synchronization"
)
EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event synchronization"
)
def register_callback_for_event_creation(cb: Callable[[int], None]) -> None:
EventCreationCallbacks.add_callback(cb)
def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None:
EventDeletionCallbacks.add_callback(cb)
def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None:
EventRecordCallbacks.add_callback(cb)
def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None:
EventWaitCallbacks.add_callback(cb)
def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None:
MemoryAllocationCallbacks.add_callback(cb)
def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None:
MemoryDeallocationCallbacks.add_callback(cb)
def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None:
StreamCreationCallbacks.add_callback(cb)
def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None:
DeviceSynchronizationCallbacks.add_callback(cb)
def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None:
StreamSynchronizationCallbacks.add_callback(cb)
def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None:
EventSynchronizationCallbacks.add_callback(cb)
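A minimal usage sketch of the hooks above (editorial example, not part of the diff): registering a logging callback for stream creation. The callback body is an assumption; callbacks receive the raw CUDA handles as integers, and they only fire once GPU tracing has been activated (for example by the CUDA Sanitizer further down in this commit).

import torch.cuda._gpu_trace as gpu_trace

def log_stream_creation(stream_id: int) -> None:
    # Hypothetical callback: report the new stream's raw handle.
    print(f"CUDA stream created: {stream_id}")

gpu_trace.register_callback_for_stream_creation(log_stream_creation)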


@@ -0,0 +1,632 @@
# mypy: allow-untyped-defs
import pickle
import sys
import os
import io
import subprocess
import json
from functools import lru_cache
from typing import Any
from itertools import groupby
import base64
import warnings
import operator
cache = lru_cache(None)
__all__ = ["format_flamegraph", "segments", "memory", "compare"]
def _frame_fmt(f, full_filename=False):
i = f['line']
fname = f['filename']
if not full_filename:
fname = fname.split('/')[-1]
func = f['name']
return f'{fname}:{i}:{func}'
@cache
def _frame_filter(name, filename):
omit_functions = [
"unwind::unwind",
"CapturedTraceback::gather",
"gather_with_cpp",
"_start",
"__libc_start_main",
"PyEval_",
"PyObject_",
"PyFunction_",
]
omit_filenames = [
"core/boxing",
"/Register",
"/Redispatch",
"pythonrun.c",
"Modules/main.c",
"Objects/call.c",
"Objects/methodobject.c",
"pycore_ceval.h",
"ceval.c",
"cpython/abstract.h",
]
for of in omit_functions:
if of in name:
return False
for of in omit_filenames:
if of in filename:
return False
return True
def _frames_fmt(frames, full_filename=False, reverse=False):
if reverse:
frames = reversed(frames)
return [_frame_fmt(f, full_filename) for f in frames if _frame_filter(f['name'], f['filename'])]
def _block_extra_legacy(b):
if 'history' in b:
frames = b['history'][0].get('frames', [])
real_size = b['history'][0]['real_size']
else:
real_size = b.get('requested_size', b['size'])
frames = []
return frames, real_size
def _block_extra(b):
if 'frames' not in b:
# old snapshot format made it more complicated to get frames/allocated size
return _block_extra_legacy(b)
return b['frames'], b['requested_size']
def format_flamegraph(flamegraph_lines, flamegraph_script=None):
if flamegraph_script is None:
flamegraph_script = f'/tmp/{os.getuid()}_flamegraph.pl'
if not os.path.exists(flamegraph_script):
import urllib.request
print(f"Downloading flamegraph.pl to: {flamegraph_script}")
urllib.request.urlretrieve(
'https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl', flamegraph_script)
subprocess.check_call(['chmod', '+x', flamegraph_script])
args = [flamegraph_script, '--countname', 'bytes']
p = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='utf-8')
assert p.stdin is not None
assert p.stdout is not None
p.stdin.write(flamegraph_lines)
p.stdin.close()
result = p.stdout.read()
p.stdout.close()
p.wait()
assert p.wait() == 0
return result
def _write_blocks(f, prefix, blocks):
def frames_fragment(frames):
if not frames:
return "<non-python>"
return ';'.join(_frames_fmt(frames, reverse=True))
for b in blocks:
if 'history' not in b:
frames, accounted_for_size = _block_extra(b)
f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {accounted_for_size}\n')
else:
accounted_for_size = 0
for h in b['history']:
sz = h['real_size']
accounted_for_size += sz
if 'frames' in h:
frames = h['frames']
f.write(f'{prefix};{b["state"]};{frames_fragment(frames)} {sz}\n')
else:
f.write(f'{prefix};{b["state"]};<no-context> {sz}\n')
gaps = b['size'] - accounted_for_size
if gaps:
f.write(f'{prefix};{b["state"]};<gaps> {gaps}\n')
def segments(snapshot, format_flamegraph=format_flamegraph):
f = io.StringIO()
for seg in snapshot['segments']:
prefix = f'stream_{seg["stream"]};seg_{seg["address"]}'
_write_blocks(f, prefix, seg['blocks'])
return format_flamegraph(f.getvalue())
def memory(snapshot, format_flamegraph=format_flamegraph):
f = io.StringIO()
for seg in snapshot['segments']:
prefix = f'stream_{seg["stream"]}'
_write_blocks(f, prefix, seg['blocks'])
return format_flamegraph(f.getvalue())
def compare(before, after, format_flamegraph=format_flamegraph):
def _seg_key(seg):
return (seg['address'], seg['total_size'])
def _seg_info(seg):
return f'stream_{seg["stream"]};seg_{seg["address"]}'
f = io.StringIO()
before_segs = {_seg_key(seg) for seg in before}
after_segs = {_seg_key(seg) for seg in after}
print(f'only_before = {[a for a, _ in (before_segs - after_segs)]}')
print(f'only_after = {[a for a, _ in (after_segs - before_segs)]}')
for seg in before:
if _seg_key(seg) not in after_segs:
_write_blocks(f, f'only_before;{_seg_info(seg)}', seg['blocks'])
for seg in after:
if _seg_key(seg) not in before_segs:
_write_blocks(f, f'only_after;{_seg_info(seg)}', seg['blocks'])
return format_flamegraph(f.getvalue())
def _format_size(num):
# https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size
for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
if abs(num) < 1024.0:
return f"{num:3.1f}{unit}B"
num /= 1024.0
return f"{num:.1f}YiB"
class Bytes:
def __init__(self, value):
self.value = value
def __add__(self, rhs):
return Bytes(self.value + rhs)
def __repr__(self):
return _format_size(self.value)
def calc_active(seg):
return sum(b['size'] for b in seg['blocks'] if b['state'] == 'active_allocated')
def _report_free(free_external, free_internal):
total = free_external + free_internal
suffix = ''
if total != 0:
pct = (free_internal / total) * 100
suffix = f' ({pct:.1f}% internal)'
return f'{Bytes(total)}{suffix}'
PAGE_SIZE = 1024 * 1024 * 20
legend = f"""\
Legend:
[a ] - a segment in the allocator
^-- a page {Bytes(PAGE_SIZE)} of memory in the segment
a-z: pages filled with a single block's content
' ': page is completely free
*: page is completely full with multiple blocks
0-9: page is partially full with tensors of multiple blocks (9 == 90% full)
(X% internal) - of the free memory, X% is free because we rounded the size of the allocation.
"""
def segsum(data):
r"""Visually reports how the allocator has filled its segments.
This printout can help debug fragmentation issues since free fragments
will appear as gaps in this printout. The amount of free space is reported
for each segment.
We distinguish between internal free memory which occurs because the
allocator rounds the allocation size, and external free memory, which are
the gaps between allocations in a segment.
Args:
data: snapshot dictionary created from _snapshot()
"""
segments = []
out = io.StringIO()
out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n")
total_reserved = 0
total_allocated = 0
free_external = 0
free_internal = 0
for seg in sorted(data['segments'], key=lambda x: (x['total_size'], calc_active(x))):
total_reserved += seg['total_size']
seg_free_external = 0
seg_free_internal = 0
seg_allocated = 0
all_ranges = []
boffset = 0
for b in seg['blocks']:
active = b['state'] == 'active_allocated'
if active:
_, allocated_size = _block_extra(b)
all_ranges.append((boffset, allocated_size, True))
seg_allocated += allocated_size
seg_free_internal += b['size'] - allocated_size
else:
seg_free_external += b['size']
boffset += b['size']
total_allocated += seg_allocated
free_external += seg_free_external
free_internal += seg_free_internal
nseg = (seg['total_size'] - 1) // PAGE_SIZE + 1
occupied = [' ' for _ in range(nseg)]
frac = [0.0 for _ in range(nseg)]
active_size = 0
for i, (start_, size, active) in enumerate(all_ranges):
active_size += size
finish_ = (start_ + size)
start = start_ // PAGE_SIZE
finish = (finish_ - 1) // PAGE_SIZE + 1
m = chr(ord('a' if active else 'A') + (i % 26))
for j in range(start, finish):
s = max(start_, j * PAGE_SIZE)
e = min(finish_, (j + 1) * PAGE_SIZE)
frac[j] += (e - s) / PAGE_SIZE
if occupied[j] != ' ':
occupied[j] = '0123456789*'[int(frac[j] * 10)]
else:
occupied[j] = m
stream = '' if seg['stream'] == 0 else f', stream_{seg["stream"]}'
body = ''.join(occupied)
assert seg_free_external + seg_free_internal + seg_allocated == seg['total_size']
stream = f' stream_{seg["stream"]}' if seg['stream'] != 0 else ''
if seg['total_size'] >= PAGE_SIZE:
out.write(f'[{body}] {Bytes(seg["total_size"])} allocated, '
f'{_report_free(seg_free_external, seg_free_internal)} free{stream}\n')
out.write(f'segments: {len(data["segments"])}\n')
out.write(f'total_reserved: {Bytes(total_reserved)}\n')
out.write(f'total_allocated: {Bytes(total_allocated)}\n')
internal_external = f' ({Bytes(free_internal)} internal + {Bytes(free_external)} external)' if free_internal else ''
out.write(f'total_free: {_report_free(free_external, free_internal)}\n')
out.write(legend)
assert free_internal + free_external + total_allocated == total_reserved
return out.getvalue()
def trace(data):
out = io.StringIO()
def format(entries):
segment_intervals : list = []
segment_addr_to_name = {}
allocation_addr_to_name = {}
free_names : list = []
next_name = 0
def _name():
nonlocal next_name
if free_names:
return free_names.pop()
r, m = next_name // 26, next_name % 26
next_name += 1
return f'{chr(ord("a") + m)}{"" if r == 0 else r}'
def find_segment(addr):
for name, saddr, size in segment_intervals:
if addr >= saddr and addr < saddr + size:
return name, saddr
for i, seg in enumerate(data['segments']):
saddr = seg['address']
size = seg['allocated_size']
if addr >= saddr and addr < saddr + size:
return f'seg_{i}', saddr
return None, None
count = 0
out.write(f'{len(entries)} entries\n')
total_reserved = 0
for seg in data['segments']:
total_reserved += seg['total_size']
for count, e in enumerate(entries):
if e['action'] == 'alloc':
addr, size = e['addr'], e['size']
n = _name()
seg_name, seg_addr = find_segment(addr)
if seg_name is None:
seg_name = "MEM"
offset = addr
else:
offset = addr - seg_addr
out.write(f'{n} = {seg_name}[{offset}:{Bytes(size)}]\n')
allocation_addr_to_name[addr] = (n, size, count)
count += size
elif e['action'] == 'free_requested':
addr, size = e['addr'], e['size']
name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
out.write(f'del {name} # {Bytes(size)}\n')
elif e['action'] == 'free_completed':
addr, size = e['addr'], e['size']
count -= size
name, _, _ = allocation_addr_to_name.get(addr, (addr, None, None))
out.write(f'# free completed for {name} {Bytes(size)}\n')
if name in allocation_addr_to_name:
free_names.append(name)
del allocation_addr_to_name[name]
elif e['action'] == 'segment_alloc':
addr, size = e['addr'], e['size']
name = _name()
out.write(f'{name} = cudaMalloc({addr}, {Bytes(size)})\n')
segment_intervals.append((name, addr, size))
segment_addr_to_name[addr] = name
elif e['action'] == 'segment_free':
addr, size = e['addr'], e['size']
name = segment_addr_to_name.get(addr, addr)
out.write(f'cudaFree({name}) # {Bytes(size)}\n')
if name in segment_addr_to_name:
free_names.append(name)
del segment_addr_to_name[name]
elif e['action'] == 'oom':
size = e['size']
free = e['device_free']
out.write(f'raise OutOfMemoryError # {Bytes(size)} requested, {Bytes(free)} free in CUDA\n')
else:
out.write(f'{e}\n')
out.write(f"TOTAL MEM: {Bytes(count)}")
for i, d in enumerate(data['device_traces']):
if d:
out.write(f'Device {i} ----------------\n')
format(d)
return out.getvalue()
_memory_viz_template = r"""
<!DOCTYPE html>
<html>
<head>
</head>
<body>
<script type="module">
import {add_local_files} from "https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/utils/viz/MemoryViz.js"
const local_files = $SNAPSHOT
add_local_files(local_files, $VIZ_KIND)
</script>
</body>
"""
def _format_viz(data, viz_kind, device):
if device is not None:
warnings.warn(
'device argument is deprecated, plots now contain all devices',
FutureWarning,
stacklevel=3,
)
buffer = pickle.dumps(data)
buffer += b'\x00' * (3 - len(buffer) % 3)
# Encode the buffer with base64
encoded_buffer = base64.b64encode(buffer).decode('utf-8')
json_format = json.dumps([{"name": 'snapshot.pickle', "base64": encoded_buffer}])
return _memory_viz_template.replace('$VIZ_KIND', repr(viz_kind)) \
.replace('$SNAPSHOT', json_format)
def trace_plot(data, device=None, plot_segments=False):
"""Generate a visualization over time of the memory usage recorded by the trace as an html file.
Args:
data: Memory snapshot as generated from torch.cuda.memory._snapshot()
device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
Defaults to False.
Returns:
str: HTML of visualization
"""
return _format_viz(data, 'Active Memory Timeline' if not plot_segments else 'Active Cached Memory Timeline', device)
def _profile_to_snapshot(profile):
import torch
from torch.profiler._memory_profiler import Action, TensorKey
from torch._C._profiler import _EventType
memory_profile = profile._memory_profile()
allocation_stacks = {}
for event in memory_profile._op_tree.sorted_nodes:
if event.tag == _EventType.Allocation:
parent = event.parent
python_parents = []
while parent:
if parent.tag in (_EventType.PyCall, _EventType.PyCCall):
python_parents.append(parent)
parent = parent.parent
key = TensorKey.from_allocation(event.extra_fields)
# Corner case: If allocation doesn't have an ID (can't prove it was used as a Tensor)
# key will be None. I should add some way to identify these, I just haven't yet.
if key and event.extra_fields.alloc_size > 0:
allocation_stacks[key] = python_parents
device_count = torch.cuda.device_count()
snapshot = {
'device_traces': [[] for _ in range(device_count + 1)],
'segments': [{'device': device,
'address': None,
'total_size': 0,
'stream': 0,
'blocks': []} for device in range(device_count + 1)]
}
def to_device(device):
if device.type == 'cuda':
return device.index
else:
return device_count
def allocate(size, tensor_key, version, during_trace=True):
device = to_device(tensor_key.device)
addr = tensor_key.storage.ptr
seg = snapshot['segments'][device] # type: ignore[index]
if seg['address'] is None or seg['address'] > addr:
seg['address'] = addr
seg['total_size'] = max(seg['total_size'], addr + size) # record max addr for now, we will make it the size later
category = memory_profile._categories.get(tensor_key, version)
category = category.name.lower() if category is not None else "unknown"
stack = allocation_stacks.get(tensor_key, ())
stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack]
r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category}
if during_trace:
snapshot['device_traces'][device].append(r) # type: ignore[index]
return r
def free(alloc, device):
for e in ('free_requested', 'free_completed'):
snapshot['device_traces'][device].append({'action': e, # type: ignore[index]
'addr': alloc['addr'],
'size': alloc['size'],
'stream': 0,
'frames': alloc['frames']})
kv_to_elem = {}
# create the device trace
for time, action, (tensor_key, version), size in memory_profile.timeline:
if not isinstance(tensor_key, TensorKey):
continue
if action == Action.CREATE:
kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version)
elif action == Action.DESTROY:
free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
elif action == Action.INCREMENT_VERSION:
free(kv_to_elem.pop((tensor_key, version)), to_device(tensor_key.device))
kv_to_elem[(tensor_key, version + 1)] = allocate(size, tensor_key, version + 1)
elif action == Action.PREEXISTING:
kv_to_elem[(tensor_key, version)] = allocate(size, tensor_key, version, during_trace=False)
# create the final snapshot state
blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames'])
for (tensor_key, version), event in kv_to_elem.items()]
for device, blocks in groupby(sorted(blocks_at_end), key=operator.itemgetter(0)):
seg = snapshot['segments'][device] # type: ignore[index]
last_addr = seg['address']
for _, addr, size, frames in blocks:
if last_addr < addr:
seg['blocks'].append({'size': addr - last_addr, 'state': 'inactive'})
seg['blocks'].append({'size': size, 'state': 'active_allocated', 'requested_size': size, 'frames': frames})
last_addr = addr + size
if last_addr < seg['total_size']:
seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'})
snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']] # type: ignore[attr-defined]
for seg in snapshot['segments']: # type: ignore[attr-defined, name-defined, no-redef]
seg['total_size'] -= seg['address']
if not seg['blocks']:
seg['blocks'].append({'size': seg['total_size'], 'state': 'inactive'})
return snapshot
def profile_plot(profile, device=None):
"""Generate a visualization over time of the memory usage recorded by kineto memory profiling as an html file.
Args:
profile: profile as generated by `torch.profiler.profile(profile_memory=True)`
device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
Returns:
str: HTML of visualization
"""
snapshot = _profile_to_snapshot(profile)
return _format_viz(snapshot, 'Active Memory Timeline', device)
def segment_plot(data: Any, device=None):
return _format_viz(data, 'Allocator State History', device)
if __name__ == "__main__":
import os.path
thedir = os.path.realpath(os.path.dirname(__file__))
if thedir in sys.path:
# otherwise we find cuda/random.py as random...
sys.path.remove(thedir)
import argparse
fn_name = 'torch.cuda.memory._snapshot()'
pickled = f'pickled memory statistics from {fn_name}'
parser = argparse.ArgumentParser(description=f'Visualize memory dumps produced by {fn_name}')
subparsers = parser.add_subparsers(dest='action')
def _output(p):
p.add_argument('-o', '--output', default='output.svg', help='flamegraph svg (default: output.svg)')
description = 'Prints overall allocation statistics and a visualization of how the allocator\'s segments are currently filled.'
stats_a = subparsers.add_parser('stats', description=description)
stats_a.add_argument('input', help=pickled)
description = 'Prints the buffer of the most recent allocation events embedded in the snapshot in a Pythonic style.'
trace_a = subparsers.add_parser('trace', description=description)
trace_a.add_argument('input', help=pickled)
description = 'Generate a flamegraph that visualizes what memory is stored in each allocator segment (aka block)'
segments_a = subparsers.add_parser('segments', description=description)
segments_a.add_argument('input', help=pickled)
_output(segments_a)
description = "Generate a flamegraph the program locations contributing to CUDA memory usage."
memory_a = subparsers.add_parser('memory', description=description)
memory_a.add_argument('input', help=pickled)
_output(memory_a)
description = 'Generate a flamegraph that shows segments (aka blocks) that have been added ' \
'or removed between two different memory snapshots.'
compare_a = subparsers.add_parser('compare', description=description)
compare_a.add_argument('before', help=pickled)
compare_a.add_argument('after', help=pickled)
_output(compare_a)
plots = (
("trace_plot", "Generate a visualization over time of the memory usage recorded by the trace as an html file."),
("segment_plot", "Visualize how allocations are packed into allocator segments at each point in a trace as an html file.")
)
for cmd, description in plots:
trace_plot_a = subparsers.add_parser(cmd, description=description)
trace_plot_a.add_argument('input', help=pickled)
help = 'visualize trace from this device (default: chooses the only device with trace info or errors)'
trace_plot_a.add_argument('-d', '--device', type=int, default=None, help=help)
help = 'path to save the visualization (default: output.html)'
trace_plot_a.add_argument('-o', '--output', default='output.html', help=help)
if cmd == "trace_plot":
help = 'visualize change to segments rather than individual allocations'
trace_plot_a.add_argument('-s', '--segments', action='store_true', help=help)
args = parser.parse_args()
def _read(name):
if name == '-':
f = sys.stdin.buffer
else:
f = open(name, 'rb')
data = pickle.load(f)
if isinstance(data, list): # segments only...
data = {'segments': data, 'traces': []}
return data
def _write(name, data):
with open(name, 'w') as f:
f.write(data)
if args.action == 'segments':
data = _read(args.input)
_write(args.output, segments(data))
elif args.action == 'memory':
data = _read(args.input)
_write(args.output, memory(data))
elif args.action == 'stats':
data = _read(args.input)
print(segsum(data))
elif args.action == 'trace':
data = _read(args.input)
print(trace(data))
elif args.action == 'compare':
before = _read(args.before)
after = _read(args.after)
_write(args.output, compare(before, after))
elif args.action == 'trace_plot':
data = _read(args.input)
_write(args.output, trace_plot(data, device=args.device, plot_segments=args.segments))
elif args.action == 'segment_plot':
data = _read(args.input)
_write(args.output, segment_plot(data, device=args.device))
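Beyond the CLI above, the same functions can be driven from Python. A hedged sketch follows; it assumes torch.cuda.memory._record_memory_history() is available to enable allocator history recording (that call is not defined in this file), and the workload is a placeholder.

import torch
from torch.cuda import _memory_viz

torch.cuda.memory._record_memory_history()   # assumed helper: turn on allocator history
x = torch.randn(1024, 1024, device="cuda")
del x
snapshot = torch.cuda.memory._snapshot()     # the function named in the argparse help above

print(_memory_viz.segsum(snapshot))          # text summary, same as the `stats` subcommand
with open("output.html", "w") as f:
    f.write(_memory_viz.trace_plot(snapshot))  # same output as the `trace_plot` subcommand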


@@ -0,0 +1,621 @@
# mypy: allow-untyped-defs
r"""
This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels run on different streams.
It stores information on accesses to tensors to determine if they are synchronized
or not. When enabled in a Python program and a possible data race is detected, a
detailed warning will be printed and the program will exit.
It can be enabled either by importing this module and calling
:func:`enable_cuda_sanitizer()` or by exporting the ``TORCH_CUDA_SANITIZER``
environment variable.
"""
import enum
import functools
import inspect
import io
import logging
import sys
import textwrap
import traceback
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar
import torch
import torch.cuda._gpu_trace as gpu_trace
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
DEFAULT_STREAM_ID = 0
TK = TypeVar("TK")
TVa = TypeVar("TVa")
TVb = TypeVar("TVb")
DataPtr = int
StreamId = int
EventId = int
SeqNum = int
logger = logging.getLogger(__name__)
class AccessType(enum.Enum):
READ = enum.auto()
WRITE = enum.auto()
def __str__(self):
return "reading from" if self is AccessType.READ else "writing to"
@dataclass
class Access:
r"""Stores information about a single access to a tensor by a kernel.
Args:
type: either AccessType.READ or AccessType.WRITE.
seq_num: the sequential number of the kernel performing the access.
stream: the stream id of the stream executing the kernel.
operator: the schema of the launched kernel, which lists the
arguments and return type.
aliases: the arguments in the schema this access corresponds to.
is_output: Whether the tensor was an output of the kernel.
stack_trace: the stack summary object captured during access.
"""
type: AccessType
seq_num: SeqNum
stream: StreamId
operator: str
aliases: List[str]
is_output: bool
stack_trace: traceback.StackSummary
class SynchronizationError(Exception):
"""Base class for errors detected by CUDA Sanitizer."""
class UnsynchronizedAccessError(SynchronizationError):
"""Stores information about two unsynchronized accesses to one data pointer."""
def __init__(
self,
data_ptr: DataPtr,
allocation_stack_trace: Optional[traceback.StackSummary],
current_access: Access,
previous_access: Access,
):
self.data_ptr = data_ptr
self.allocation_stack_trace = allocation_stack_trace
self.current_access = current_access
self.previous_access = previous_access
def __str__(self):
def format_access(access: Access):
message.write(f"{access.operator}\n{access.type}")
if access.aliases:
message.write(" argument(s) " + ", ".join(access.aliases))
if access.is_output:
message.write(", and to")
if access.is_output:
message.write(" the output")
message.write(
f"\nWith stack trace:\n{''.join(access.stack_trace.format())}\n"
)
with io.StringIO() as message:
message.write(
textwrap.dedent(
f"""\
============================
CSAN detected a possible data race on tensor with data pointer {self.data_ptr}
Access by stream {self.current_access.stream} during kernel:
"""
)
)
format_access(self.current_access)
message.write(
f"Previous access by stream {self.previous_access.stream} during kernel:\n"
)
format_access(self.previous_access)
if self.allocation_stack_trace:
message.write(
"Tensor was allocated with stack trace:\n"
f"{''.join(self.allocation_stack_trace.format())}"
)
else:
message.write("Trace for tensor allocation not found.")
return message.getvalue()
class CUDASanitizerErrors(Exception):
"""Wrapper class for errors reported by CUDA Sanitizer."""
def __init__(self, errors: List[SynchronizationError]):
self.errors = errors
def __str__(self):
return f"detected {len(self.errors)} errors"
@dataclass
class TensorInfo:
r"""Stores information about a single tensor and recent accesses to it.
Args:
allocation_stack_trace: the stack summary object captured during tensor
allocation. Can be ``None`` if the allocation wasn't caught by CSAN.
reads: list of read accesses to the tensor that were performed since
the last write.
write: the last write access to the tensor.
"""
allocation_stack_trace: Optional[traceback.StackSummary]
reads: List[Access] = field(default_factory=list)
write: Optional[Access] = None
class _TensorsAccessed:
def __init__(self) -> None:
self.accesses: Dict[DataPtr, TensorInfo] = {}
def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
if data_ptr not in self.accesses:
logger.info(
"Found tensor with pointer: %s, but no matching tensor "
"allocation in the trace. Backfilling the trace now. "
"Perhaps the sanitizer was enabled after some torch operations?",
data_ptr,
)
self.create_tensor(data_ptr, None)
def ensure_tensor_does_not_exist(self, data_ptr: DataPtr) -> None:
if data_ptr in self.accesses:
logger.info(
"Found duplicate tensor allocation in the trace for tensor with "
"pointer: %s. Assuming the trace for tensor deallocation "
"wasn't caught and backfilling it now. "
"Perhaps the sanitizer was enabled after some torch operations?",
data_ptr,
)
self.delete_tensor(data_ptr)
def create_tensor(
self, data_ptr: DataPtr, stack_trace: Optional[traceback.StackSummary]
) -> None:
self.accesses[data_ptr] = TensorInfo(stack_trace)
def delete_tensor(self, data_ptr: DataPtr) -> None:
del self.accesses[data_ptr]
def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool:
return True if self.accesses[data_ptr].reads else False
def get_allocation_stack_trace(
self, data_ptr: DataPtr
) -> Optional[traceback.StackSummary]:
return self.accesses[data_ptr].allocation_stack_trace
def get_write(self, data_ptr: DataPtr) -> Optional[Access]:
return self.accesses[data_ptr].write
def get_reads(self, data_ptr: DataPtr) -> List[Access]:
return self.accesses[data_ptr].reads
def add_read(self, data_ptr: DataPtr, access: Access) -> None:
self.accesses[data_ptr].reads.append(access)
def set_write(self, data_ptr: DataPtr, access: Access) -> None:
self.accesses[data_ptr].write = access
self.accesses[data_ptr].reads = []
class StreamSynchronizations:
def __init__(self) -> None:
self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
self.host_sync_state: Dict[StreamId, SeqNum] = {}
self.create_stream(DEFAULT_STREAM_ID)
def _ensure_stream_exists(self, stream: StreamId) -> None:
if stream not in self.current_sync_states:
logger.info(
"Found Stream with id: %s, but no matching stream "
"creation in the trace. Backfilling the trace now. "
"Perhaps the sanitizer was enabled after some torch operations?",
stream,
)
self.create_stream(stream)
def _ensure_event_exists(self, event: EventId) -> None:
if event not in self.recorded_sync_states:
logger.info(
"Found Event with id: %s, but no matching event "
"creation in the trace. Backfilling the trace now. "
"Perhaps the sanitizer was enabled after some torch operations?",
event,
)
self.create_event(event)
def _ensure_event_does_not_exist(self, event: EventId) -> None:
if event in self.recorded_sync_states:
logger.info(
"Found duplicate event creation in the trace for event with "
"id: %s. Assuming the trace for event deletion wasn't caught "
"and backfilling it now. "
"Perhaps the sanitizer was enabled after some torch operations?",
event,
)
self.delete_event(event)
def create_stream(self, stream: StreamId) -> None:
if stream in self.current_sync_states:
logger.info(
"Found duplicate Stream creation in the trace for Stream with "
"id: %s. PyTorch Streams are only created once, so this "
"trace entry is ignored.",
stream,
)
else:
self.host_sync_state[stream] = 0
self.current_sync_states[stream] = self.host_sync_state.copy()
def create_event(self, event: EventId) -> None:
self._ensure_event_does_not_exist(event)
self.recorded_sync_states[event] = {}
def delete_event(self, event: EventId) -> None:
self._ensure_event_exists(event)
del self.recorded_sync_states[event]
def update_seq_num(self, stream: StreamId, seq_num: SeqNum) -> None:
self._ensure_stream_exists(stream)
self.current_sync_states[stream][stream] = seq_num
def record_state(self, event: EventId, stream: StreamId) -> None:
self._ensure_event_exists(event)
self._ensure_stream_exists(stream)
self.recorded_sync_states[event] = self.current_sync_states[stream].copy()
def _state_wait_for_other(
self, state: Dict[StreamId, SeqNum], other: Dict[StreamId, SeqNum]
) -> None:
for stream, seq_num in other.items():
state[stream] = max(state.get(stream, -1), seq_num)
def stream_wait_for_event(self, stream: StreamId, event: EventId) -> None:
self._ensure_stream_exists(stream)
self._ensure_event_exists(event)
self._state_wait_for_other(
self.current_sync_states[stream], self.recorded_sync_states[event]
)
def all_streams_wait_for_event(self, event: EventId) -> None:
self._ensure_event_exists(event)
for stream in self.current_sync_states.keys():
self.stream_wait_for_event(stream, event)
self._state_wait_for_other(
self.host_sync_state, self.recorded_sync_states[event]
)
def all_streams_wait_for_stream(self, stream: StreamId) -> None:
self._ensure_stream_exists(stream)
for state in self.current_sync_states.values():
self._state_wait_for_other(state, self.current_sync_states[stream])
self._state_wait_for_other(
self.host_sync_state, self.current_sync_states[stream]
)
def sync_all_streams(self) -> None:
for stream, state in self.current_sync_states.items():
self.host_sync_state[stream] = state[stream]
for state in self.current_sync_states.values():
self._state_wait_for_other(state, self.host_sync_state)
def is_ordered_after(
self, current_stream: StreamId, seq_num: SeqNum, other_stream: StreamId
) -> bool:
self._ensure_stream_exists(current_stream)
self._ensure_stream_exists(other_stream)
return seq_num <= self.current_sync_states[current_stream].get(other_stream, -1)
class EventHandler:
"""Analyzes CSAN trace for synchronization errors.
Stores information on each stream's synchronizations with other streams as well
as tensor accesses to determine whether a given kernel launch might cause a
data race.
"""
def __init__(self) -> None:
self.tensors_accessed = _TensorsAccessed()
self.syncs = StreamSynchronizations()
self.seq_num: SeqNum = 0
def _handle_kernel_launch(
self,
stream: StreamId,
read_only: Set[DataPtr],
read_write: Set[DataPtr],
outputs: Set[DataPtr],
operator: str,
tensor_aliases: Dict[int, List[str]],
) -> List[SynchronizationError]:
def check_conflict(
data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access]
) -> None:
if previous_access is None:
return
if not self.syncs.is_ordered_after(
current_access.stream, previous_access.seq_num, previous_access.stream
):
error_list.append(
UnsynchronizedAccessError(
data_ptr,
self.tensors_accessed.get_allocation_stack_trace(data_ptr),
current_access,
previous_access,
)
)
error_list: List[SynchronizationError] = []
self.seq_num += 1
self.syncs.update_seq_num(stream, self.seq_num)
stack_trace = traceback.StackSummary.extract(
traceback.walk_stack(inspect.currentframe()), lookup_lines=False
)
# The stack trace generated in this way is in the inverse order, so it must be
# reversed.
stack_trace.reverse()
for data_ptr in read_only:
self.tensors_accessed.ensure_tensor_exists(data_ptr)
current_access = Access(
AccessType.READ,
self.seq_num,
stream,
operator,
tensor_aliases[data_ptr],
data_ptr in outputs,
stack_trace,
)
check_conflict(
data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
)
self.tensors_accessed.add_read(data_ptr, current_access)
for data_ptr in read_write:
self.tensors_accessed.ensure_tensor_exists(data_ptr)
current_access = Access(
AccessType.WRITE,
self.seq_num,
stream,
operator,
tensor_aliases[data_ptr],
data_ptr in outputs,
stack_trace,
)
if self.tensors_accessed.were_there_reads_since_last_write(data_ptr):
for previous_access in self.tensors_accessed.get_reads(data_ptr):
check_conflict(data_ptr, current_access, previous_access)
else:
check_conflict(
data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
)
self.tensors_accessed.set_write(data_ptr, current_access)
return error_list
def _handle_event_creation(self, event: EventId) -> None:
self.syncs.create_event(event)
def _handle_event_deletion(self, event: EventId) -> None:
self.syncs.delete_event(event)
def _handle_event_record(self, event: EventId, stream: StreamId) -> None:
self.syncs.record_state(event, stream)
def _handle_event_wait(self, event: EventId, stream: StreamId) -> None:
self.syncs.stream_wait_for_event(stream, event)
def _handle_memory_allocation(self, data_ptr: DataPtr) -> None:
self.tensors_accessed.ensure_tensor_does_not_exist(data_ptr)
stack_trace = traceback.StackSummary.extract(
traceback.walk_stack(inspect.currentframe()), lookup_lines=False
)
# The stack trace generated in this way is in the inverse order, so it must be
# reversed.
stack_trace.reverse()
self.tensors_accessed.create_tensor(
data_ptr,
stack_trace,
)
def _handle_memory_deallocation(self, data_ptr: DataPtr) -> None:
self.tensors_accessed.ensure_tensor_exists(data_ptr)
self.tensors_accessed.delete_tensor(data_ptr)
def _handle_stream_creation(self, stream: StreamId) -> None:
self.syncs.create_stream(stream)
def _handle_device_synchronization(self) -> None:
self.syncs.sync_all_streams()
def _handle_stream_synchronization(self, stream: StreamId) -> None:
self.syncs.all_streams_wait_for_stream(stream)
def _handle_event_synchronization(self, event: EventId) -> None:
self.syncs.all_streams_wait_for_event(event)
def zip_by_key(a: Dict[TK, TVa], b: Dict[TK, TVb]) -> Iterator[Tuple[TK, TVa, TVb]]:
for arg, value in a.items():
if arg in b:
yield arg, value, b[arg]
def zip_arguments(
schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterator[Tuple[torch.Argument, Any]]:
schema_args = schema.arguments[: len(args)]
schema_kwargs = {arg.name: arg for arg in schema.arguments[len(args) :]}
yield from zip(schema_args, args)
for _, argument, value in zip_by_key(schema_kwargs, kwargs):
yield (argument, value)
class ArgumentHandler:
def __init__(self) -> None:
self.dataptrs_read: Set[DataPtr] = set()
self.dataptrs_written: Set[DataPtr] = set()
self.tensor_aliases: Dict[DataPtr, List[str]] = {}
self.outputs: Set[DataPtr] = set()
def _handle_argument(
self,
value: Any,
is_write: bool,
name: Optional[str] = None,
is_output: bool = False,
) -> None:
if isinstance(value, torch.Tensor) and value.is_cuda:
data_ptr = value.data_ptr()
if is_write:
self.dataptrs_written.add(data_ptr)
else:
self.dataptrs_read.add(data_ptr)
self.tensor_aliases.setdefault(data_ptr, [])
if name is not None:
self.tensor_aliases[data_ptr].append(name)
if is_output:
self.outputs.add(data_ptr)
def parse_inputs(
self,
schema: torch.FunctionSchema,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> None:
for argument, value in zip_arguments(schema, args, kwargs):
is_write = argument.alias_info is not None and argument.alias_info.is_write
pytree.tree_map_(
functools.partial(
self._handle_argument, is_write=is_write, name=argument.name
),
value,
)
def parse_outputs(self, outputs: Any) -> None:
pytree.tree_map_(
functools.partial(self._handle_argument, is_write=True, is_output=True),
outputs,
)
class CUDASanitizerDispatchMode(TorchDispatchMode):
def __init__(self) -> None:
self.event_handler = EventHandler()
torch._C._activate_gpu_trace()
gpu_trace.register_callback_for_event_creation(
self.event_handler._handle_event_creation
)
gpu_trace.register_callback_for_event_deletion(
self.event_handler._handle_event_deletion
)
gpu_trace.register_callback_for_event_record(
self.event_handler._handle_event_record
)
gpu_trace.register_callback_for_event_wait(
self.event_handler._handle_event_wait
)
gpu_trace.register_callback_for_memory_allocation(
self.event_handler._handle_memory_allocation
)
gpu_trace.register_callback_for_memory_deallocation(
self.event_handler._handle_memory_deallocation
)
gpu_trace.register_callback_for_stream_creation(
self.event_handler._handle_stream_creation
)
gpu_trace.register_callback_for_device_synchronization(
self.event_handler._handle_device_synchronization
)
gpu_trace.register_callback_for_stream_synchronization(
self.event_handler._handle_stream_synchronization
)
gpu_trace.register_callback_for_event_synchronization(
self.event_handler._handle_event_synchronization
)
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
argument_handler = ArgumentHandler()
argument_handler.parse_inputs(func._schema, args, kwargs)
outputs = func(*args, **kwargs)
argument_handler.parse_outputs(outputs)
errors = self.event_handler._handle_kernel_launch(
torch.cuda.current_stream().cuda_stream,
argument_handler.dataptrs_read - argument_handler.dataptrs_written,
argument_handler.dataptrs_written,
argument_handler.outputs,
func._schema,
argument_handler.tensor_aliases,
)
if errors:
for error in errors:
print(error, file=sys.stderr)
raise CUDASanitizerErrors(errors)
return outputs
class CUDASanitizer:
"""Manages the lifetime of a CUDASanitizer dispatch mode object.
The CUDASanitizer class wraps the entering/exiting functions of the dispatch mode
context manager in the enable function/destructor, respectively. This is to
explicitly set the lifetime of the dispatch mode object to that of the application.
This approach was deemed more elegant than using the atexit module.
"""
def __init__(self) -> None:
self.dispatch = CUDASanitizerDispatchMode()
self.enabled = False
def enable(self):
self.dispatch.__enter__()
self.enabled = True
def __del__(self):
if self.enabled:
self.dispatch.__exit__(None, None, None)
def enable_cuda_sanitizer():
"""Enable CUDA Sanitizer.
The sanitizer will begin to analyze low-level CUDA calls invoked by torch functions
for synchronization errors. All data races found will be printed to the standard
error output along with stack traces of suspected causes. For best results, the
sanitizer should be enabled at the very beginning of the program.
"""
cuda_sanitizer.enable()
cuda_sanitizer = CUDASanitizer()
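A short, hedged example of the workflow the module docstring describes: enable the sanitizer early, then launch unsynchronized work on two streams so that an UnsynchronizedAccessError is reported and CUDASanitizerErrors is raised. The tensor operations are illustrative placeholders.

import torch
from torch.cuda._sanitizer import enable_cuda_sanitizer

enable_cuda_sanitizer()

a = torch.ones(2**20, device="cuda")   # written on the default stream
s = torch.cuda.Stream()
with torch.cuda.stream(s):
    # No event/stream synchronization with the default stream before this
    # in-place write, so CSAN should flag a possible data race here.
    a.add_(1)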


@@ -0,0 +1,38 @@
from typing import Any
import torch
# The _get_device_index has been moved to torch.utils._get_device_index
from torch._utils import _get_device_index as _torch_get_device_index
def _get_device_index(
device: Any, optional: bool = False, allow_cpu: bool = False
) -> int:
r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``.
If :attr:`device` is a torch.device object, returns the device index if it
is a CUDA device. Note that for a CUDA device without a specified index,
i.e., ``torch.device('cuda')``, this will return the current default CUDA
device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
CPU devices will be accepted and ``-1`` will be returned in this case.
If :attr:`device` is a Python integer, it is returned as is.
If :attr:`device` is ``None``, this will return the current default CUDA
device if :attr:`optional` is ``True``.
"""
if isinstance(device, int):
return device
if isinstance(device, str):
device = torch.device(device)
if isinstance(device, torch.device):
if allow_cpu:
if device.type not in ["cuda", "cpu"]:
raise ValueError(f"Expected a cuda or cpu device, but got: {device}")
elif device.type != "cuda":
raise ValueError(f"Expected a cuda device, but got: {device}")
if not torch.jit.is_scripting():
if isinstance(device, torch.cuda.device):
return device.idx
return _torch_get_device_index(device, optional, allow_cpu)
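For illustration, a few hedged examples of what the helper above resolves different inputs to; the `None` case depends on the current default device, and the comments paraphrase the docstring rather than guaranteeing exact behavior.

import torch
from torch.cuda._utils import _get_device_index

_get_device_index(torch.device("cuda:1"))               # 1
_get_device_index("cuda:0")                              # 0, strings are parsed into torch.device first
_get_device_index(3)                                     # 3, ints pass through unchanged
_get_device_index(None, optional=True)                   # index of the current default CUDA device
_get_device_index(torch.device("cpu"), allow_cpu=True)   # -1, per the docstring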


@@ -0,0 +1,12 @@
from .autocast_mode import autocast, custom_bwd, custom_fwd
from .common import amp_definitely_not_available
from .grad_scaler import GradScaler
__all__ = [
"amp_definitely_not_available",
"autocast",
"custom_bwd",
"custom_fwd",
"GradScaler",
]


@@ -0,0 +1,90 @@
# mypy: allow-untyped-defs
import functools
from typing import Any
from typing_extensions import deprecated
import torch
__all__ = ["autocast", "custom_fwd", "custom_bwd"]
class autocast(torch.amp.autocast_mode.autocast):
r"""See :class:`torch.autocast`.
``torch.cuda.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cuda", args...)`` instead.
"""
@deprecated(
"`torch.cuda.amp.autocast(args...)` is deprecated. "
"Please use `torch.amp.autocast('cuda', args...)` instead.",
category=FutureWarning,
)
def __init__(
self,
enabled: bool = True,
dtype: torch.dtype = torch.float16,
cache_enabled: bool = True,
):
if torch._jit_internal.is_scripting():
self._enabled = enabled
self.device = "cuda"
self.fast_dtype = dtype
return
super().__init__(
"cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
)
def __enter__(self):
if torch._jit_internal.is_scripting():
return self
return super().__enter__()
# TODO: discuss a unified TorchScript-friendly API for autocast
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override]
if torch._jit_internal.is_scripting():
return
return super().__exit__(exc_type, exc_val, exc_tb)
def __call__(self, func):
if torch._jit_internal.is_scripting():
return func
return super().__call__(func)
# Preserved only for BC reasons
@deprecated(
"`torch.cuda.amp.autocast_mode._cast(value, dtype)` is deprecated. "
"Please use `torch.amp.autocast_mode._cast(value, 'cuda', dtype)` instead.",
category=FutureWarning,
)
def _cast(value, dtype):
return torch.amp.autocast_mode._cast(value, "cuda", dtype)
@deprecated(
"`torch.cuda.amp.custom_fwd(args...)` is deprecated. "
"Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.",
category=FutureWarning,
)
def custom_fwd(fwd=None, *, cast_inputs=None):
"""
``torch.cuda.amp.custom_fwd(args...)`` is deprecated. Please use
``torch.amp.custom_fwd(args..., device_type='cuda')`` instead.
"""
return functools.partial(torch.amp.custom_fwd, device_type="cuda")(
fwd=fwd, cast_inputs=cast_inputs
)
@deprecated(
"`torch.cuda.amp.custom_bwd(args...)` is deprecated. "
"Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.",
category=FutureWarning,
)
def custom_bwd(bwd):
"""
``torch.cuda.amp.custom_bwd(args...)`` is deprecated. Please use
``torch.amp.custom_bwd(args..., device_type='cuda')`` instead.
"""
return functools.partial(torch.amp.custom_bwd, device_type="cuda")(bwd)
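A brief, hedged sketch of the migration these deprecation shims point to, showing the old and new spellings side by side. The autograd Function is a placeholder, not part of this module.

import torch

# Deprecated: emits a FutureWarning via the shim above.
with torch.cuda.amp.autocast(dtype=torch.float16):
    pass

# Preferred replacement.
with torch.amp.autocast("cuda", dtype=torch.float16):
    pass

class ScaledMM(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a.mm(b)

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad):
        a, b = ctx.saved_tensors
        return grad.mm(b.t()), a.t().mm(grad)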


@@ -0,0 +1,11 @@
# mypy: allow-untyped-defs
from importlib.util import find_spec
import torch
__all__ = ["amp_definitely_not_available"]
def amp_definitely_not_available():
return not (torch.cuda.is_available() or find_spec("torch_xla"))
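For context, a hedged sketch of how this check is typically consumed: gating whether mixed-precision gradient scaling should be enabled at all when neither CUDA nor torch_xla is present.

import torch
from torch.cuda.amp.common import amp_definitely_not_available

# Fall back to a disabled scaler if AMP definitely cannot work on this machine.
scaler = torch.amp.GradScaler("cuda", enabled=not amp_definitely_not_available())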


@@ -0,0 +1,38 @@
from typing_extensions import deprecated
import torch
# We need to keep this unused import for BC reasons
from torch.amp.grad_scaler import OptState # noqa: F401
__all__ = ["GradScaler"]
class GradScaler(torch.amp.GradScaler):
r"""
See :class:`torch.amp.GradScaler`.
``torch.cuda.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` instead.
"""
@deprecated(
"`torch.cuda.amp.GradScaler(args...)` is deprecated. "
"Please use `torch.amp.GradScaler('cuda', args...)` instead.",
category=FutureWarning,
)
def __init__(
self,
init_scale: float = 2.0**16,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000,
enabled: bool = True,
) -> None:
super().__init__(
"cuda",
init_scale=init_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
enabled=enabled,
)
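A hedged sketch of the replacement the deprecation message recommends: the standard scaler loop written against torch.amp.GradScaler("cuda", ...). Model, data, and optimizer are placeholders.

import torch

model = torch.nn.Linear(16, 4).cuda()                  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler("cuda")                  # instead of torch.cuda.amp.GradScaler()

for _ in range(10):
    optimizer.zero_grad(set_to_none=True)
    with torch.amp.autocast("cuda"):
        loss = model(torch.randn(8, 16, device="cuda")).sum()
    scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)          # unscales gradients, skips the step on inf/NaN
    scaler.update()                 # adjusts the scale factor for the next iteration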


@@ -0,0 +1,19 @@
# The functions here have been moved to torch.nn.parallel.comm
from torch.nn.parallel.comm import (
broadcast,
broadcast_coalesced,
gather,
reduce_add,
reduce_add_coalesced,
scatter,
)
__all__ = [
"broadcast",
"broadcast_coalesced",
"reduce_add",
"reduce_add_coalesced",
"scatter",
"gather",
]
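These are straight re-exports; a minimal hedged example of the underlying torch.nn.parallel.comm API, assuming at least two CUDA devices are visible.

import torch
from torch.cuda.comm import broadcast, reduce_add

t = torch.arange(4.0, device="cuda:0")
copies = broadcast(t, devices=[0, 1])       # one copy of `t` on each listed device
total = reduce_add(copies, destination=0)   # elementwise sum, gathered onto device 0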



@@ -0,0 +1,129 @@
import os
import sys
from typing import Callable, List, Optional
import torch
from torch.types import Storage
__all__: List[str] = []
def _dummy_fn(name: str) -> Callable:
def fn(*args, **kwargs): # type: ignore[no-untyped-def]
raise RuntimeError(f"torch._C.{name} is not supported on this platform")
return fn
if not hasattr(torch._C, "_gds_register_buffer"):
assert not hasattr(torch._C, "_gds_deregister_buffer")
assert not hasattr(torch._C, "_gds_register_handle")
assert not hasattr(torch._C, "_gds_deregister_handle")
assert not hasattr(torch._C, "_gds_load_storage")
assert not hasattr(torch._C, "_gds_save_storage")
# Define functions
torch._C.__dict__["_gds_register_buffer"] = _dummy_fn("_gds_register_buffer")
torch._C.__dict__["_gds_deregister_buffer"] = _dummy_fn("_gds_deregister_buffer")
torch._C.__dict__["_gds_register_handle"] = _dummy_fn("_gds_register_handle")
torch._C.__dict__["_gds_deregister_handle"] = _dummy_fn("_gds_deregister_handle")
torch._C.__dict__["_gds_load_storage"] = _dummy_fn("_gds_load_storage")
torch._C.__dict__["_gds_save_storage"] = _dummy_fn("_gds_save_storage")
def _gds_register_buffer(s: Storage) -> None:
"""Registers a buffer.
Args:
s (Storage): Buffer to register.
"""
torch._C._gds_register_buffer(s)
def _gds_deregister_buffer(s: Storage) -> None:
"""Registers a buffer.
Args:
s (Storage): Buffer to register.
"""
torch._C._gds_deregister_buffer(s)
class _GdsFile:
r"""Wrapper around cuFile.
cuFile is a file-like interface to the GPUDirect Storage (GDS) API.
Args:
filename (str): Name of the file to open.
flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will
be added automatically.
.. _CUDA GPUDirect Storage Documentation:
https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api
"""
def __init__(self, filename: str, flags: int):
if sys.platform == "win32":
raise RuntimeError("GdsFile is not supported on this platform.")
self.filename = filename
self.flags = flags
self.fd = os.open(filename, flags | os.O_DIRECT)
self.handle: Optional[int] = None
self.register_handle()
def __del__(self) -> None:
if self.handle is not None:
self.deregister_handle()
os.close(self.fd)
def register_handle(self) -> None:
"""Registers file descriptor to cuFile Driver.
This is a wrapper around ``cuFileHandleRegister``.
"""
assert (
self.handle is None
), "Cannot register a handle that is already registered."
self.handle = torch._C._gds_register_handle(self.fd)
def deregister_handle(self) -> None:
"""Deregisters file descriptor from cuFile Driver.
This is a wrapper around ``cuFileHandleDeregister``.
"""
assert (
self.handle is not None
), "Cannot deregister a handle that is not registered."
torch._C._gds_deregister_handle(self.handle)
self.handle = None
def load_storage(self, storage: Storage, offset: int = 0) -> None:
"""Loads data from the file into the storage.
This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
will be loaded from the file at ``offset`` into the storage.
Args:
storage (Storage): Storage to load data into.
offset (int, optional): Offset into the file to start loading from. (Default: 0)
"""
assert (
self.handle is not None
), "Cannot load data from a file that is not registered."
torch._C._gds_load_storage(self.handle, storage, offset)
def save_storage(self, storage: Storage, offset: int = 0) -> None:
"""Saves data from the storage into the file.
This is a wrapper around ``cuFileWrite``. All bytes of the storage
will be written to the file at ``offset``.
Args:
storage (Storage): Storage to save data from.
offset (int, optional): Offset into the file to start saving to. (Default: 0)
"""
assert (
self.handle is not None
), "Cannot save data to a file that is not registered."
torch._C._gds_save_storage(self.handle, storage, offset)
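A hedged end-to-end sketch of the private wrapper above: write an untyped CUDA storage to a file and read it back. This assumes a GDS-capable build and filesystem, and it glosses over the O_DIRECT alignment requirements that real cuFile I/O imposes.

import os
import torch
from torch.cuda.gds import _GdsFile

src = torch.randn(1024, device="cuda").untyped_storage()
dst = torch.empty(1024, device="cuda").untyped_storage()

f = _GdsFile("/tmp/gds_demo.bin", os.O_CREAT | os.O_RDWR)
f.save_storage(src)   # cuFileWrite: all bytes of `src` at offset 0
f.load_storage(dst)   # cuFileRead: fill `dst` from the same offset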


@@ -0,0 +1,491 @@
# mypy: allow-untyped-defs
import gc
import typing
import torch
from .._utils import _dummy_type
if not hasattr(torch._C, "_CudaStreamBase"):
# Define dummy base classes
torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph")
torch._C.__dict__["_graph_pool_handle"] = _dummy_type("_graph_pool_handle")
torch._C.__dict__["_cuda_isCurrentStreamCapturing"] = _dummy_type(
"_cuda_isCurrentStreamCapturing"
)
from torch._C import ( # noqa: F401
_cuda_isCurrentStreamCapturing,
_CUDAGraph,
_graph_pool_handle,
)
def is_current_stream_capturing():
r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise.
If a CUDA context does not exist on the current device, returns False without initializing the context.
"""
return _cuda_isCurrentStreamCapturing()
# Python shim helps Sphinx process docstrings more reliably.
def graph_pool_handle():
r"""Return an opaque token representing the id of a graph memory pool.
See :ref:`Graph memory management<graph-memory-management>`.
.. warning::
This API is in beta and may change in future releases.
"""
return _graph_pool_handle()
# Python shim helps Sphinx process docstrings more reliably.
class CUDAGraph(torch._C._CUDAGraph):
r"""Wrapper around a CUDA graph.
.. warning::
This API is in beta and may change in future releases.
"""
def __new__(cls):
return super().__new__(cls)
def capture_begin(self, pool=None, capture_error_mode="global"):
r"""Begin capturing CUDA work on the current stream.
Typically, you shouldn't call ``capture_begin`` yourself.
Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
which call ``capture_begin`` internally.
Arguments:
pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
:meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
""" # noqa: B950
super().capture_begin(pool=pool, capture_error_mode=capture_error_mode)
def capture_end(self):
r"""End CUDA graph capture on the current stream.
After ``capture_end``, ``replay`` may be called on this instance.
Typically, you shouldn't call ``capture_end`` yourself.
Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,
which call ``capture_end`` internally.
"""
super().capture_end()
def replay(self):
r"""Replay the CUDA work captured by this graph."""
super().replay()
def reset(self):
r"""Delete the graph currently held by this instance."""
super().reset()
def pool(self):
r"""Return an opaque token representing the id of this graph's memory pool.
This id can optionally be passed to another graph's ``capture_begin``,
which hints the other graph may share the same memory pool.
"""
return super().pool()
def enable_debug_mode(self):
r"""Enable debugging mode for CUDAGraph.debug_dump."""
return super().enable_debug_mode()
def debug_dump(self, debug_path):
r"""
Arguments:
debug_path (required): Path to dump the graph to.
Calls a debugging function to dump the graph if the debugging is
enabled via CUDAGraph.enable_debug_mode()
"""
return super().debug_dump(debug_path)
class graph:
r"""Context-manager that captures CUDA work into a :class:`torch.cuda.CUDAGraph` object for later replay.
See :ref:`CUDA Graphs <cuda-graph-semantics>` for a general introduction,
detailed use, and constraints.
Arguments:
cuda_graph (torch.cuda.CUDAGraph): Graph object used for capture.
pool (optional): Opaque token (returned by a call to :func:`~torch.cuda.graph_pool_handle()` or
:meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) hinting this graph's capture
may share memory from the specified pool. See :ref:`Graph memory management<graph-memory-management>`.
stream (torch.cuda.Stream, optional): If supplied, will be set as the current stream in the context.
If not supplied, ``graph`` sets its own internal side stream as the current stream in the context.
capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting
unless you're familiar with `cudaStreamCaptureMode <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85>`_
.. note::
For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture
used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture.
.. warning::
This API is in beta and may change in future releases.
.. _cudaStreamCaptureMode:
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
""" # noqa: B950
default_capture_stream: typing.Optional["torch.cuda.Stream"] = None
def __init__(
self,
cuda_graph,
pool=None,
stream=None,
capture_error_mode: str = "global",
):
# Lazy-init of default_capture_stream helps avoid circular-import errors.
# Not thread safe, but graphs already have the general (explicitly documented)
# restriction that only one capture may be underway at a time in the process.
if self.__class__.default_capture_stream is None:
self.__class__.default_capture_stream = torch.cuda.Stream()
self.pool = () if pool is None else (pool,)
self.capture_stream = (
stream if stream is not None else self.__class__.default_capture_stream
)
assert self.capture_stream is not None
self.stream_ctx = torch.cuda.stream(self.capture_stream)
self.cuda_graph = cuda_graph
self.capture_error_mode = capture_error_mode
def __enter__(self):
# Free as much memory as we can for the graph
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
# Stackoverflow seems comfortable with this pattern
# https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487
self.stream_ctx.__enter__()
self.cuda_graph.capture_begin(
*self.pool, capture_error_mode=self.capture_error_mode
)
def __exit__(self, exc_type, exc_value, traceback):
self.cuda_graph.capture_end()
self.stream_ctx.__exit__(exc_type, exc_value, traceback)
# returning None should propagate exceptions from either capture_end or stream_ctx.__exit__()
def make_graphed_callables(
callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None
):
r"""Accept callables (functions or :class:`nn.Module<torch.nn.Module>`\ s) and returns graphed versions.
Each graphed callable's forward pass runs its source callable's
forward CUDA work as a CUDA graph inside a single autograd node.
The graphed callable's forward pass also appends
a backward node to the autograd graph. During backward, this node runs the
callable's backward work as a CUDA graph.
Therefore, each graphed callable should be a drop-in replacement for its source callable
in an autograd-enabled training loop.
See :ref:`Partial-network capture<partial-network-capture>` for detailed use and constraints.
If you pass a tuple of several callables, their captures will use the same memory pool.
See :ref:`Graph memory management<graph-memory-management>` for when this is appropriate.
Arguments:
callables (torch.nn.Module or Python function, or tuple of these): Callable or callables to graph.
See :ref:`Graph memory management<graph-memory-management>` for when passing a tuple of callables
is appropriate. If you pass a tuple of callables, their order in the tuple must be the same order
they'll run in the live workload.
sample_args (tuple of Tensors, or tuple of tuples of Tensors): Sample args for each callable.
If a single callable was passed, ``sample_args`` must be a single tuple of argument Tensors.
If a tuple of callables was passed, ``sample_args`` must be a tuple of tuples of argument Tensors.
num_warmup_iters (int): The number of warmup iterations. Currently, ``DistributedDataParallel`` needs
11 iterations for warmup. Default: ``3``.
allow_unused_input (bool): If False, specifying inputs that were not used when computing outputs
(and therefore their grad is always zero) is an error. Defaults to False.
pool (optional): Token (returned by :func:`~torch.cuda.graph_pool_handle` or
:meth:`other_Graph_instance.pool()<torch.cuda.CUDAGraph.pool>`) that hints this graph may share memory
with the indicated pool. See :ref:`Graph memory management<graph-memory-management>`.
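Example (a minimal sketch; ``module1``, ``module2``, and the sample tensors are illustrative placeholders)::
    module1 = torch.nn.Linear(16, 16).cuda()
    module2 = torch.nn.Linear(16, 16).cuda()
    x = torch.randn(8, 16, device="cuda", requires_grad=True)
    h = torch.randn(8, 16, device="cuda", requires_grad=True)
    module1, module2 = torch.cuda.make_graphed_callables(
        (module1, module2), ((x,), (h,))
    )
    # The graphed modules are drop-in replacements in the training loop.
    out = module2(module1(x))
    out.sum().backward()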
.. note::
The ``requires_grad`` state of each Tensor in ``sample_args`` must match the state
that's expected for the corresponding real input in the training loop.
.. warning::
This API is in beta and may change in future releases.
.. warning::
``sample_args`` for each callable must contain only Tensors. Other types are not allowed.
.. warning::
Returned callables do not support higher order differentiation (e.g., double backward).
.. warning::
In any :class:`~torch.nn.Module` passed to :func:`~make_graphed_callables`, only parameters
may be trainable. Buffers must have ``requires_grad=False``.
.. warning::
After you pass a :class:`torch.nn.Module` through :func:`~make_graphed_callables`,
you may not add or remove any of that Module's parameters or buffers.
.. warning::
:class:`torch.nn.Module`\s passed to :func:`~torch.cuda.make_graphed_callables` must not have module hooks
registered on them at the time they are passed. However, registering hooks on modules *after* passing them
through :func:`~torch.cuda.make_graphed_callables` is allowed.
.. warning::
When running a graphed callable, you must pass its arguments in the same order and format
they appeared in that callable's ``sample_args``.
.. warning::
Automatic mixed precision is supported in :func:`~torch.cuda.make_graphed_callables` only with caching
disabled. The context manager `torch.cuda.amp.autocast()` must be used with `cache_enabled=False`.
"""
if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
raise RuntimeError(
"make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
)
just_one_callable = False
if not isinstance(callables, tuple):
just_one_callable = True
callables = (callables,)
sample_args = (sample_args,)
flatten_sample_args = []
for c, args in zip(callables, sample_args):
if isinstance(c, torch.nn.Module):
assert (
len(c._backward_hooks) == 0
and len(c._forward_hooks) == 0
and len(c._forward_pre_hooks) == 0
), (
"Modules must not have hooks registered at the time they are passed. However, registering hooks "
+ "on modules after passing them through make_graphed_callables is allowed."
)
assert all(b.requires_grad is False for b in c.buffers()), (
"In any :class:`~torch.nn.Module` passed to "
+ ":func:`~make_graphed_callables`, only parameters may be trainable. All buffers must have "
+ "``requires_grad=False``."
)
flatten_arg = torch.utils._pytree.arg_tree_leaves(*args)
flatten_sample_args.append(tuple(flatten_arg))
assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
"In the beta API, sample_args "
+ "for each callable must contain only Tensors. Other types are not allowed."
)
# If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
# passes to forward (ie, its sample_args) AND the module's parameter attributes.
per_callable_len_user_args = [len(args) for args in flatten_sample_args]
per_callable_module_params = [
tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
for c in callables
]
per_callable_static_input_surfaces = [
flatten_sample_args[i] + per_callable_module_params[i]
for i in range(len(callables))
]
fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))]
mempool = graph_pool_handle() if pool is None else pool
# Warmup
# Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
# from ending up in any captures.
torch.cuda.synchronize()
with torch.cuda.stream(torch.cuda.Stream()):
for func, args, static_input_surface in zip(
callables, sample_args, per_callable_static_input_surfaces
):
grad_inputs, outputs, outputs_grad = None, None, None
for _ in range(num_warmup_iters):
outputs = torch.utils._pytree.tree_leaves(func(*args))
outputs_grad = tuple(o for o in outputs if o.requires_grad)
if len(outputs_grad) > 0:
grad_inputs = torch.autograd.grad(
outputs=outputs_grad,
inputs=tuple(
i for i in static_input_surface if i.requires_grad
),
grad_outputs=tuple(
torch.empty_like(o) for o in outputs if o.requires_grad
),
only_inputs=True,
allow_unused=allow_unused_input,
)
# Drop the warmup outputs and grads so their memory is freed before capture.
del outputs, outputs_grad, grad_inputs
torch.cuda.synchronize()
# All captures here share a mempool. To avoid replays corrupting each other's memory,
# the safest approach is to capture all passes in the same order they'll run:
# fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
# Capture forward graphs
per_callable_static_outputs = []
per_callable_output_unflatten_spec = []
for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
with torch.cuda.graph(fwd_graph, pool=mempool):
outputs = func(*args)
flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs)
per_callable_static_outputs.append(tuple(flatten_outputs))
per_callable_output_unflatten_spec.append(spec)
# Capture backward graphs in reverse order
per_callable_static_grad_outputs = []
per_callable_static_grad_inputs = []
for static_input_surface, static_outputs, bwd_graph, module_params in zip(
reversed(per_callable_static_input_surfaces),
reversed(per_callable_static_outputs),
reversed(bwd_graphs),
reversed(per_callable_module_params),
):
# For now, assumes all static_outputs require grad
# assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad."
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
outputs_grad = tuple(o for o in static_outputs if o.requires_grad)
grad_inputs = None
if len(outputs_grad) > 0:
with torch.cuda.graph(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=outputs_grad,
inputs=tuple(i for i in static_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
only_inputs=True,
allow_unused=allow_unused_input,
)
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
# I couldn't think of a slick one-liner for this pattern.
static_grad_inputs = []
grad_idx = 0
for arg in static_input_surface:
if arg.requires_grad and grad_inputs is not None:
static_grad_inputs.append(grad_inputs[grad_idx])
grad_idx += 1
else:
static_grad_inputs.append(None) # type: ignore[arg-type]
static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment]
per_callable_static_grad_outputs.append(static_grad_outputs)
per_callable_static_grad_inputs.append(static_grad_inputs)
# Reverses the most recent two lists
per_callable_static_grad_outputs.reverse()
per_callable_static_grad_inputs.reverse()
# Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
def make_graphed_autograd_function(
fwd_graph,
bwd_graph,
module_params,
len_user_args,
output_unflatten_spec,
static_input_surface,
static_outputs,
static_grad_outputs,
static_grad_inputs,
):
class Graphed(torch.autograd.Function):
@staticmethod
def forward(ctx, *inputs):
# At this stage, only the user args may (potentially) be new tensors.
for i in range(len_user_args):
if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
static_input_surface[i].copy_(inputs[i])
fwd_graph.replay()
assert isinstance(static_outputs, tuple)
return tuple(o.detach() for o in static_outputs)
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, *grads):
assert len(grads) == len(static_grad_outputs)
for g, grad in zip(static_grad_outputs, grads):
if g is not None:
# don't copy if autograd gods have been kind and the
# incoming grad is already in the right place
if g.data_ptr() != grad.data_ptr():
g.copy_(grad)
bwd_graph.replay()
# Input args that didn't require grad expect a None gradient.
assert isinstance(static_grad_inputs, tuple)
return tuple(
b.detach() if b is not None else b for b in static_grad_inputs
)
def functionalized(*user_args):
# Runs the autograd function with inputs == all inputs to the graph that might require grad
# (explicit user args + module parameters)
# Assumes module params didn't change since capture.
flatten_user_args = torch.utils._pytree.arg_tree_leaves(*user_args)
out = Graphed.apply(*(tuple(flatten_user_args) + module_params))
return torch.utils._pytree.tree_unflatten(out, output_unflatten_spec)
return functionalized
# Put together the final graphed callables
ret = []
for i, func in enumerate(callables):
graphed = make_graphed_autograd_function(
fwd_graphs[i],
bwd_graphs[i],
per_callable_module_params[i],
per_callable_len_user_args[i],
per_callable_output_unflatten_spec[i],
per_callable_static_input_surfaces[i],
per_callable_static_outputs[i],
per_callable_static_grad_outputs[i],
per_callable_static_grad_inputs[i],
)
if isinstance(func, torch.nn.Module):
def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
def new_fwd(*user_args):
# If the module's training-or-eval state matches what we graphed,
# run the graph, otherwise run the original forward method
if func.training == graph_training_state:
return graphed(*user_args)
else:
return orig_fwd(*user_args)
return new_fwd
func.forward = make_graphed_forward(func, func.training, graphed, func.forward) # type: ignore[assignment]
ret.append(func)
else:
ret.append(graphed)
if just_one_callable:
return ret[0]
return tuple(ret)

View File

@ -0,0 +1,187 @@
# mypy: allow-untyped-defs
import re
from typing import Callable, List
import torch
from torch import Tensor
__all__: List[str] = []
class _CodeParser:
def __init__(self, code_string: str):
optional_ws = r"\s*"
required_ws = r"\s+"
template_params = r"(?P<template_params>\<.+\>)"
return_type = r"(?P<return_type>\w+)"
function_name = r"(?P<function_name>\w+)"
function_params = r"(?P<function_params>\(.+\))"
function_body = r"(?P<function_body>\{.+\})"
pattern = (
optional_ws
+ "template"
+ optional_ws
+ template_params
+ optional_ws
+ return_type
+ required_ws
+ function_name
+ optional_ws
+ function_params
+ optional_ws
+ function_body
+ optional_ws
)
result = re.match(
pattern, code_string, re.DOTALL
) # DOTALL for matching multiline
if result is None:
raise Exception( # noqa: TRY002
f"Couldn't parse code, please check correctness:\n {code_string}"
)
self.template_params = result["template_params"]
self.return_type = result["return_type"]
self.function_name = result["function_name"]
self.function_params = result["function_params"]
self.function_body = result["function_body"]
class _JittedFunction:
def __init__(
self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
):
self.code_string = code_string
assert (
return_by_ref or num_outputs == 1
), "Return by value only works for single output. "
self.return_by_ref = return_by_ref
self.num_outputs = num_outputs
parsed_code = _CodeParser(code_string)
self.kernel_name = parsed_code.function_name
self.kwargs_dict = kwargs
self.is_cuda_available = torch.cuda.is_available()
def __call__(self, *tensors: Tensor, **kwargs):
# Jiterator follows torch.cuda's lazy-initialization behavior:
# defer checking CUDA availability until the function is invoked.
assert (
self.is_cuda_available
), "Jiterator is only supported on CUDA and ROCm GPUs, none are available."
assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."
expanded_kwargs = self.kwargs_dict.copy()
for key, value in kwargs.items():
if key in self.kwargs_dict:
expanded_kwargs[key] = value
else:
raise KeyError(f"{key} is not declared in function definition")
return torch._C._cuda_jiterator_compile_and_launch_kernel(
self.code_string,
self.kernel_name,
self.return_by_ref,
self.num_outputs,
tensors,
expanded_kwargs,
)
def _create_jit_fn(code_string: str, **kwargs) -> Callable:
"""
Create a jiterator-generated cuda kernel for an elementwise op.
The code string has to be a valid CUDA function that describes the computation for a single element. The code
string has to follow the C++ template pattern, as shown in the example below. This function will be inlined
into the elementwise kernel template and compiled on the fly. The compiled kernel will be cached in memory, as
well as in a local temp directory.
Jiterator-generated kernels accept noncontiguous tensors, and support broadcasting and type promotion.
Args:
code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value.
kwargs (Dict, optional): Keyword arguments for generated function
Example::
code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1.0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
# invoke jitted function like a regular python function
result = jitted_fn(a, b, alpha=3.14)
code_string also allows multiple function definitions, and the last function will be treated as the entry function.
Example::
code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
jitted_fn = create_jit_fn(code_string, val=0.0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
# invoke jitted function like a regular python function
result = jitted_fn(a, b) # using default val=0.0
Jiterator can be used together with Python registration to override an operator's CUDA kernel.
The following example overrides gelu's CUDA kernel with relu.
Example::
code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
my_gelu = create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::gelu', my_gelu, "CUDA")
# torch.nn.GELU and torch.nn.functional.gelu are now overridden
a = torch.rand(3, device='cuda')
torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))
.. warning::
This API is in beta and may change in future releases.
.. warning::
This API only supports up to 8 inputs and 1 output
.. warning::
All input tensors must be on a CUDA device
"""
return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs)
def _create_multi_output_jit_fn(
code_string: str, num_outputs: int, **kwargs
) -> Callable:
"""
Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs.
Args:
code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return its value(s) by reference.
num_outputs(int): number of outputs return by the kernel
kwargs (Dict, optional): Keyword arguments for generated function
Example::
code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
jitted_fn = create_multi_output_jit_fn(code_string, num_outputs=1, alpha=1.0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
# invoke jitted function like a regular python function
result = jitted_fn(a, b, alpha=3.14)
.. warning::
This API is in beta and may change in future releases.
.. warning::
This API only supports up to 8 inputs and 8 outputs
"""
return _JittedFunction(
code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,151 @@
# mypy: allow-untyped-defs
import collections
import warnings
from typing import Optional, Sequence, Union
import torch.cuda
__all__ = ["all_reduce", "reduce", "broadcast", "all_gather", "reduce_scatter"]
SUM = 0 # ncclRedOp_t
def is_available(tensors):
if not hasattr(torch._C, "_nccl_all_reduce"):
warnings.warn("PyTorch is not compiled with NCCL support")
return False
devices = set()
for tensor in tensors:
if tensor.is_sparse:
return False
if not tensor.is_contiguous():
return False
if not tensor.is_cuda:
return False
device = tensor.get_device()
if device in devices:
return False
devices.add(device)
return True
def version():
"""
Returns the version of NCCL.
This function returns a tuple containing the major, minor, and patch version numbers of NCCL.
The suffix is also included in the tuple if a version suffix exists.
Returns:
tuple: The version information of the NCCL.
"""
ver = torch._C._nccl_version()
major = ver >> 32
minor = (ver >> 16) & 65535
patch = ver & 65535
suffix = torch._C._nccl_version_suffix().decode("utf-8")
if suffix == "":
return (major, minor, patch)
else:
return (major, minor, patch, suffix)
def unique_id():
return torch._C._nccl_unique_id()
def init_rank(num_ranks, uid, rank):
return torch._C._nccl_init_rank(num_ranks, uid, rank)
def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
if not isinstance(inputs, collections.abc.Container) or isinstance(
inputs, torch.Tensor
):
raise TypeError("Inputs should be a collection of tensors")
def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
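r"""All-reduce ``inputs`` (one tensor per device) across devices.
If ``outputs`` is omitted, the reduction is performed in place on ``inputs``.
A minimal usage sketch (assumes at least two visible CUDA devices; ``op`` defaults to SUM)::
    tensors = [torch.full((4,), float(i + 1), device=f"cuda:{i}") for i in range(2)]
    torch.cuda.nccl.all_reduce(tensors)  # each tensor now holds the elementwise sum 3.0
"""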
_check_sequence_type(inputs)
if outputs is None:
outputs = inputs
_check_sequence_type(outputs)
torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)
# `output` used to be `outputs`, taking in a list of tensors. So we have two
# arguments for BC reasons.
def reduce(
inputs: Sequence[torch.Tensor],
output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
root: int = 0,
op: int = SUM,
streams: Optional[Sequence[torch.cuda.Stream]] = None,
comms=None,
*,
outputs: Optional[Sequence[torch.Tensor]] = None,
) -> None:
_check_sequence_type(inputs)
_output: torch.Tensor
if outputs is not None:
if output is not None:
raise ValueError(
"'output' and 'outputs' can not be both specified. 'outputs' is deprecated in "
"favor of 'output', taking in a single output tensor. The signature of reduce is: "
"reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None)."
)
else:
warnings.warn(
"`nccl.reduce` with an output tensor list is deprecated. "
"Please specify a single output tensor with argument 'output' instead.",
FutureWarning,
stacklevel=2,
)
_output = outputs[root]
elif not isinstance(output, torch.Tensor) and isinstance(
output, collections.abc.Sequence
):
# User called old API with positional arguments of list of output tensors.
warnings.warn(
"nccl.reduce with an output tensor list is deprecated. "
"Please specify a single output tensor.",
FutureWarning,
stacklevel=2,
)
_output = output[root]
else:
_output = inputs[root] if output is None else output
torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)
def broadcast(
inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None
) -> None:
_check_sequence_type(inputs)
torch._C._nccl_broadcast(inputs, root, streams, comms)
def all_gather(
inputs: Sequence[torch.Tensor],
outputs: Sequence[torch.Tensor],
streams=None,
comms=None,
) -> None:
_check_sequence_type(inputs)
_check_sequence_type(outputs)
torch._C._nccl_all_gather(inputs, outputs, streams, comms)
def reduce_scatter(
inputs: Sequence[torch.Tensor],
outputs: Sequence[torch.Tensor],
op: int = SUM,
streams=None,
comms=None,
) -> None:
_check_sequence_type(inputs)
_check_sequence_type(outputs)
torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)

View File

@ -0,0 +1,93 @@
# mypy: allow-untyped-defs
r"""This package adds support for NVIDIA Tools Extension (NVTX) used in profiling."""
from contextlib import contextmanager
try:
from torch._C import _nvtx
except ImportError:
class _NVTXStub:
@staticmethod
def _fail(*args, **kwargs):
raise RuntimeError(
"NVTX functions not installed. Are you sure you have a CUDA build?"
)
rangePushA = _fail
rangePop = _fail
markA = _fail
# The stub also needs the entry points used by range_start/range_end below.
rangeStartA = _fail
rangeEnd = _fail
_nvtx = _NVTXStub() # type: ignore[assignment]
__all__ = ["range_push", "range_pop", "range_start", "range_end", "mark", "range"]
def range_push(msg):
"""
Push a range onto a stack of nested range spans. Returns the zero-based depth of the range that is started.
Args:
msg (str): ASCII message to associate with range
"""
return _nvtx.rangePushA(msg)
def range_pop():
"""Pop a range off of a stack of nested range spans. Returns the zero-based depth of the range that is ended."""
return _nvtx.rangePop()
def range_start(msg) -> int:
"""
Mark the start of a range with a string message. It returns a unique handle
for this range to pass to the corresponding call to range_end().
A key difference between this and range_push/range_pop is that the
range_start/range_end version supports range across threads (start on one
thread and end on another thread).
Returns: A range handle (uint64_t) that can be passed to range_end().
Args:
msg (str): ASCII message to associate with the range.
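Example (a minimal sketch)::
    handle = torch.cuda.nvtx.range_start("data loading")
    # ... work, possibly finishing on a different thread ...
    torch.cuda.nvtx.range_end(handle)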
"""
return _nvtx.rangeStartA(msg)
def range_end(range_id) -> None:
"""
Mark the end of a range for a given range_id.
Args:
range_id (int): a unique handle for the start range.
"""
_nvtx.rangeEnd(range_id)
def mark(msg):
"""
Describe an instantaneous event that occurred at some point.
Args:
msg (str): ASCII message to associate with the event.
"""
return _nvtx.markA(msg)
@contextmanager
def range(msg, *args, **kwargs):
"""
Context manager / decorator that pushes an NVTX range at the beginning
of its scope, and pops it at the end. If extra arguments are given,
they are passed as arguments to msg.format().
Args:
msg (str): message to associate with the range
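Example (a minimal sketch; ``model``, ``inputs``, and ``step`` are placeholders)::
    with torch.cuda.nvtx.range("iteration {}", step):
        output = model(inputs)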
"""
range_push(msg.format(*args, **kwargs))
try:
yield
finally:
range_pop()

View File

@ -0,0 +1,86 @@
# mypy: allow-untyped-defs
import contextlib
import tempfile
import torch
from . import check_error, cudart
__all__ = ["init", "start", "stop", "profile"]
DEFAULT_FLAGS = [
"gpustarttimestamp",
"gpuendtimestamp",
"gridsize3d",
"threadblocksize",
"streamid",
"enableonstart 0",
"conckerneltrace",
]
def init(output_file, flags=None, output_mode="key_value"):
rt = cudart()
if not hasattr(rt, "cudaOutputMode"):
raise AssertionError("HIP does not support profiler initialization!")
if (
hasattr(torch.version, "cuda")
and torch.version.cuda is not None
and int(torch.version.cuda.split(".")[0]) >= 12
):
# Check https://github.com/pytorch/pytorch/pull/91118
# cudaProfilerInitialize is no longer needed after CUDA 12
raise AssertionError("CUDA12+ does not need profiler initialization!")
flags = DEFAULT_FLAGS if flags is None else flags
if output_mode == "key_value":
output_mode_enum = rt.cudaOutputMode.KeyValuePair
elif output_mode == "csv":
output_mode_enum = rt.cudaOutputMode.CSV
else:
raise RuntimeError(
"supported CUDA profiler output modes are: key_value and csv"
)
with tempfile.NamedTemporaryFile(delete=True) as f:
f.write(b"\n".join(flag.encode("ascii") for flag in flags))
f.flush()
check_error(rt.cudaProfilerInitialize(f.name, output_file, output_mode_enum))
def start():
r"""Starts cuda profiler data collection.
.. warning::
Raises CudaError if it is unable to start the profiler.
"""
check_error(cudart().cudaProfilerStart())
def stop():
r"""Stops cuda profiler data collection.
.. warning::
Raises CudaError if it is unable to stop the profiler.
"""
check_error(cudart().cudaProfilerStop())
@contextlib.contextmanager
def profile():
"""
Enable profiling.
Context manager that enables profile collection by the active profiling tool from the CUDA backend.
Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> import torch
>>> model = torch.nn.Linear(20, 30).cuda()
>>> inputs = torch.randn(128, 20).cuda()
>>> with torch.cuda.profiler.profile() as prof:
... model(inputs)
"""
try:
start()
yield
finally:
stop()

View File

@ -0,0 +1,182 @@
# mypy: allow-untyped-defs
from typing import Iterable, List, Union
import torch
from torch import Tensor
from . import _lazy_call, _lazy_init, current_device, device_count
__all__ = [
"get_rng_state",
"get_rng_state_all",
"set_rng_state",
"set_rng_state_all",
"manual_seed",
"manual_seed_all",
"seed",
"seed_all",
"initial_seed",
]
def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
r"""Return the random number generator state of the specified GPU as a ByteTensor.
Args:
device (torch.device or int, optional): The device to return the RNG state of.
Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
.. warning::
This function eagerly initializes CUDA.
"""
_lazy_init()
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("cuda", device)
idx = device.index
if idx is None:
idx = current_device()
default_generator = torch.cuda.default_generators[idx]
return default_generator.get_state()
def get_rng_state_all() -> List[Tensor]:
r"""Return a list of ByteTensor representing the random number states of all devices."""
results = []
for i in range(device_count()):
results.append(get_rng_state(i))
return results
def set_rng_state(
new_state: Tensor, device: Union[int, str, torch.device] = "cuda"
) -> None:
r"""Set the random number generator state of the specified GPU.
Args:
new_state (torch.ByteTensor): The desired state
device (torch.device or int, optional): The device to set the RNG state.
Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
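Example (round-tripping the generator state; a minimal sketch)::
    saved = torch.cuda.get_rng_state()
    torch.rand(4, device="cuda")       # advances the generator
    torch.cuda.set_rng_state(saved)    # restores the previous state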
"""
with torch._C._DisableFuncTorch():
new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("cuda", device)
def cb():
idx = device.index
if idx is None:
idx = current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state_copy)
_lazy_call(cb)
def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
r"""Set the random number generator state of all devices.
Args:
new_states (Iterable of torch.ByteTensor): The desired state for each device.
"""
for i, state in enumerate(new_states):
set_rng_state(state, i)
def manual_seed(seed: int) -> None:
r"""Set the seed for generating random numbers for the current GPU.
It's safe to call this function if CUDA is not available; in that
case, it is silently ignored.
Args:
seed (int): The desired seed.
.. warning::
If you are working with a multi-GPU model, this function is insufficient
to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
"""
seed = int(seed)
def cb():
idx = current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.manual_seed(seed)
_lazy_call(cb, seed=True)
def manual_seed_all(seed: int) -> None:
r"""Set the seed for generating random numbers on all GPUs.
It's safe to call this function if CUDA is not available; in that
case, it is silently ignored.
Args:
seed (int): The desired seed.
"""
seed = int(seed)
def cb():
for i in range(device_count()):
default_generator = torch.cuda.default_generators[i]
default_generator.manual_seed(seed)
_lazy_call(cb, seed_all=True)
def seed() -> None:
r"""Set the seed for generating random numbers to a random number for the current GPU.
It's safe to call this function if CUDA is not available; in that
case, it is silently ignored.
.. warning::
If you are working with a multi-GPU model, this function will only initialize
the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
"""
def cb():
idx = current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.seed()
_lazy_call(cb)
def seed_all() -> None:
r"""Set the seed for generating random numbers to a random number on all GPUs.
It's safe to call this function if CUDA is not available; in that
case, it is silently ignored.
"""
def cb():
random_seed = 0
seeded = False
for i in range(device_count()):
default_generator = torch.cuda.default_generators[i]
if not seeded:
default_generator.seed()
random_seed = default_generator.initial_seed()
seeded = True
else:
default_generator.manual_seed(random_seed)
_lazy_call(cb)
def initial_seed() -> int:
r"""Return the current random seed of the current GPU.
.. warning::
This function eagerly initializes CUDA.
"""
_lazy_init()
idx = current_device()
default_generator = torch.cuda.default_generators[idx]
return default_generator.initial_seed()

View File

@ -0,0 +1 @@
# The Tensor classes are added to this module by python_tensor.cpp

View File

@ -0,0 +1,242 @@
# mypy: allow-untyped-defs
import ctypes
import torch
from torch._streambase import _EventBase, _StreamBase
from torch._utils import _dummy_type
if not hasattr(torch._C, "_CudaStreamBase"):
# Define dummy base classes
torch._C.__dict__["_CudaStreamBase"] = _dummy_type("_CudaStreamBase")
torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")
class Stream(torch._C._CudaStreamBase, _StreamBase):
r"""Wrapper around a CUDA stream.
A CUDA stream is a linear sequence of execution that belongs to a specific
device, independent from other streams. See :ref:`cuda-semantics` for
details.
Args:
device(torch.device or int, optional): a device on which to allocate
the stream. If :attr:`device` is ``None`` (default) or a negative
integer, this will use the current device.
priority(int, optional): priority of the stream, should be 0 or
negative, where negative numbers indicate higher priority. By default,
streams have priority 0.
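Example (a minimal sketch; ``a`` and ``b`` are placeholder CUDA tensors)::
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())  # order after prior work
    with torch.cuda.stream(s):
        c = torch.mm(a, b)                      # runs on stream ``s``
    torch.cuda.current_stream().wait_stream(s)  # order later work after ``s``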
"""
def __new__(cls, device=None, priority=0, **kwargs):
# setting device manager is expensive, so we avoid it unless necessary
if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
return super().__new__(cls, priority=priority, **kwargs)
else:
with torch.cuda.device(device):
return super().__new__(cls, priority=priority, **kwargs)
def wait_event(self, event) -> None:
r"""Make all future work submitted to the stream wait for an event.
Args:
event (torch.cuda.Event): an event to wait for.
.. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
`CUDA Stream documentation`_ for more info.
This function returns without waiting for :attr:`event`: only future
operations are affected.
.. _CUDA Stream documentation:
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
"""
event.wait(self)
def wait_stream(self, stream) -> None:
r"""Synchronize with another stream.
All future work submitted to this stream will wait until all kernels
submitted to a given stream at the time of call complete.
Args:
stream (Stream): a stream to synchronize.
.. note:: This function returns without waiting for currently enqueued
kernels in :attr:`stream`: only future operations are affected.
"""
self.wait_event(stream.record_event())
def record_event(self, event=None):
r"""Record an event.
Args:
event (torch.cuda.Event, optional): event to record. If not given, a new one
will be allocated.
Returns:
Recorded event.
"""
if event is None:
event = Event()
event.record(self)
return event
def query(self) -> bool:
r"""Check if all the work submitted has been completed.
Returns:
A boolean indicating if all kernels in this stream are completed.
"""
return super().query()
def synchronize(self) -> None:
r"""Wait for all the kernels in this stream to complete.
.. note:: This is a wrapper around ``cudaStreamSynchronize()``: see
`CUDA Stream documentation`_ for more info.
"""
super().synchronize()
@property
def _as_parameter_(self):
return ctypes.c_void_p(self.cuda_stream)
def __eq__(self, o) -> bool:
if isinstance(o, Stream):
return super().__eq__(o)
return False
def __hash__(self):
return hash((self.cuda_stream, self.device))
def __repr__(self):
return f"<torch.cuda.Stream device={self.device} cuda_stream={self.cuda_stream:#x}>"
class ExternalStream(Stream):
r"""Wrapper around an externally allocated CUDA stream.
This class is used to wrap streams allocated in other libraries in order
to facilitate data exchange and multi-library interactions.
.. note:: This class doesn't manage the stream life-cycle; it is the user's
responsibility to keep the referenced stream alive while this class is
being used.
Args:
stream_ptr(int): Integer representation of the `cudaStream_t` value
allocated externally.
device(torch.device or int, optional): the device where the stream
was originally allocated. If device is specified incorrectly,
subsequent launches using this stream may fail.
"""
def __new__(cls, stream_ptr, device=None, **kwargs):
with torch.cuda.device(device):
return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)
class Event(torch._C._CudaEventBase, _EventBase):
r"""Wrapper around a CUDA event.
CUDA events are synchronization markers that can be used to monitor the
device's progress, to accurately measure timing, and to synchronize CUDA
streams.
The underlying CUDA events are lazily initialized when the event is first
recorded or exported to another process. After creation, only streams on the
same device may record the event. However, streams on any device can wait on
the event.
Args:
enable_timing (bool, optional): indicates if the event should measure time
(default: ``False``)
blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
interprocess (bool): if ``True``, the event can be shared between processes
(default: ``False``)
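Example (timing a region; ``work()`` is a placeholder for real CUDA work)::
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    work()
    end.record()
    torch.cuda.synchronize()                 # ensure ``end`` has completed
    elapsed_ms = start.elapsed_time(end)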
.. _CUDA Event Documentation:
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
"""
def __new__(cls, enable_timing=False, blocking=False, interprocess=False):
return super().__new__(
cls,
enable_timing=enable_timing,
blocking=blocking,
interprocess=interprocess,
)
@classmethod
def from_ipc_handle(cls, device, handle):
r"""Reconstruct an event from an IPC handle on the given device."""
return super().from_ipc_handle(device, handle)
def record(self, stream=None):
r"""Record the event in a given stream.
Uses ``torch.cuda.current_stream()`` if no stream is specified. The
stream's device must match the event's device.
"""
if stream is None:
stream = torch.cuda.current_stream()
super().record(stream)
def wait(self, stream=None) -> None:
r"""Make all future work submitted to the given stream wait for this event.
Uses ``torch.cuda.current_stream()`` if no stream is specified.
.. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
`CUDA Event documentation`_ for more info.
"""
if stream is None:
stream = torch.cuda.current_stream()
super().wait(stream)
def query(self):
r"""Check if all work currently captured by event has completed.
Returns:
A boolean indicating if all work currently captured by event has
completed.
"""
return super().query()
def elapsed_time(self, end_event):
r"""Return the time elapsed.
Time reported in milliseconds after the event was recorded and
before the end_event was recorded.
"""
return super().elapsed_time(end_event)
def synchronize(self) -> None:
r"""Wait for the event to complete.
Waits until the completion of all work currently captured in this event.
This prevents the CPU thread from proceeding until the event completes.
.. note:: This is a wrapper around ``cudaEventSynchronize()``: see
`CUDA Event documentation`_ for more info.
"""
super().synchronize()
def ipc_handle(self):
r"""Return an IPC handle of this event.
If not recorded yet, the event will use the current device.
"""
return super().ipc_handle()
@property
def _as_parameter_(self):
return ctypes.c_void_p(self.cuda_event)
def __repr__(self) -> str:
if self.cuda_event:
return f"<torch.cuda.Event {self._as_parameter_.value:#x}>"
else:
return "<torch.cuda.Event uninitialized>"

View File

@ -0,0 +1,242 @@
r"""
This module exposes a TunableOp interface.
Some operations, such as GEMMs, could be implemented using more than one library
or more than one technique. For example, a GEMM could be implemented for CUDA or
ROCm using either the blas or blasLt libraries. Further, ROCm's rocblas and
hipblaslt libraries allow the user to query for all possible algorithms and then
choose one. How does one know which implementation is the fastest and should be
chosen? That's what TunableOp provides.
Enabling TunableOp and Tuning Separately
========================================
The TunableOp feature is enabled separately from enabling the tuning phase
itself. Enabling TunableOp means that PyTorch will replace any standard
operators with their Tunable implementations. Any call to a TunableOp first
checks whether it has already been tuned for the given operator inputs. If so,
it will immediately call the tuned operation; no further tuning will take place
even when the tuning setting is enabled. If instead no tuning result is found,
and tuning is enabled, the TunableOp will benchmark every registered
implementation of that operator for the given set of inputs and select the
fastest.
File Input and Output
=====================
The first time any TunableOp is invoked, the internal database of tuned
operations will be prepared by attempting to read the results from the given
file. The default filename is 'tunableop_results.csv'. To support tuning when
multiple GPUs are used across multiple processes, the GPU device ordinal is
automatically inserted into the filename to avoid multiple processes overwriting
the same file.
If tuning is enabled and new tunings are discovered during the course of your
workload, it will also write out to this same filename with all tunings, both
the ones it read in at startup as well as the new ones found at runtime. This
can be used, for example, to build up a tunings file across many workloads by
reusing the same file. The output file is automatically created when the
application terminates. This behavior can be controlled by the C++ and Python
APIs but not the environment variables.
Assuming you specified a filename, you'll end up with a CSV file with contents
like so::
Validator,PT_VERSION,2.2.0
Validator,ROCM_VERSION,6.0.0.0-12969-1544e39
Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7
Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty
GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262
GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033
Note the "Validator" lines. If you change a library version, or ROCm version, or
PyTorch version, TunableOp will detect this and reject the tunings file because
the prior tunings are likely affected by other software changes.
The remaining lines are the tuned solutions for each TunableOp encountered
during your execution. Each line consists of 4 comma-separated fields: operator
name, operator parameters, solution name, and average execution time. The
execution time is an optional field. The CSV file can be edited, but with
caution. For example, the solution name (field 3) can be changed to "Default"
and it will fall back to the original PyTorch untuned implementation. Or, in the
case of ROCm's hipBLAS or hipBLASLt libraries, if you know the specific solution
index you can override the solution that TunableOp selected by replacing the
value. The operator name and parameters (fields 1 and 2) are internally named
and should not be modified. In the case of GemmTunableOp, field 1 indicates the
datatype and whether the inputs are transposed (T) or not (N) and field 2
indicates the M, N, K input shapes.
There is an option to enable verbose output but it is only recommended for
debugging purposes. This will produce a lot of diagnostic messages but may be
useful to see if TunableOp is being used at all. Otherwise, TunableOp is
completely silent, besides file output, unless there is a warning or error
during its use. The verbose option is only available by setting the environment
variable PYTORCH_TUNABLEOP_VERBOSE=1.
A Note on Tuning Behavior
=========================
Tuning an operator consists of iterating through the list of registered
implementations and profiling each one. The profile is established by running a
single implementation in a loop multiple times and taking the average execution
time.
By default, each possible solution for a given operator will be run for either
100 iterations or as many iterations as can be run within 30ms, whichever is
smaller, and its average execution time will be calculated. The fastest solution
among all that were successfully profiled will be chosen. A profile might fail
if the given solution doesn't achieve the same accuracy as the default
implementation or if the solution returns an error code.
Current Tunable Operators
=========================
TunableGemm for ROCm
--------------------
Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of
PyTorch will function correctly when using TunableOp but the only solution
available to CUDA builds is the 'Default' implementation i.e. the original
cuBLAS default, now called through TunableOp. Any call to at::cuda::blas::gemm()
or ::bgemm() will be routed through TunableOp when enabled. Calling gemm() for a
given set of input arguments (transa, transb, m, n, k) will attempt to use the
fastest available implementation across both rocblas and hipblaslt.
Tuning Context
==============
The behavior of TunableOp is currently manipulated through environment
variables, the C++ interface of at::cuda::tunable::getTuningContext(), or the
torch.cuda.tunable python interfaces that wrap the C++ TuningContext. The
environment variables take precedence over any setting you manipulate using the
C++ or Python APIs.
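Example of the Python API (a minimal sketch using the wrappers defined in this
module; the filename is illustrative)::
    import torch.cuda.tunable as tunable
    tunable.enable(True)                    # turn TunableOp on
    tunable.tuning_enable(True)             # allow new tunings to be recorded
    tunable.set_max_tuning_iterations(10)
    tunable.set_filename("my_tunings.csv", insert_device_ordinal=True)
    # ... run the workload to be tuned ...
    tunable.write_file()                    # flush results explicitly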
"""
from typing import Optional, Tuple
import torch
__all__ = [
"enable",
"is_enabled",
"tuning_enable",
"tuning_is_enabled",
"set_max_tuning_duration",
"get_max_tuning_duration",
"set_max_tuning_iterations",
"get_max_tuning_iterations",
"set_filename",
"get_filename",
"get_results",
"get_validators",
"write_file_on_exit",
"write_file",
"read_file",
]
def enable(val: bool = True) -> None:
r"""This is the big on/off switch for all TunableOp implementations."""
torch._C._cuda_tunableop_enable(val) # type: ignore[attr-defined]
def is_enabled() -> bool:
r"""Returns whether the TunableOp feature is enabled."""
return torch._C._cuda_tunableop_is_enabled() # type: ignore[attr-defined]
def tuning_enable(val: bool = True) -> None:
r"""Enable tuning of TunableOp implementations.
When enabled, if a tuned entry isn't found, run the tuning step and record
the entry.
"""
torch._C._cuda_tunableop_tuning_enable(val) # type: ignore[attr-defined]
def tuning_is_enabled() -> bool:
r"""Returns whether TunableOp implementations can be tuned."""
return torch._C._cuda_tunableop_tuning_is_enabled() # type: ignore[attr-defined]
def set_max_tuning_duration(duration: int) -> None:
r"""Set max time in milliseconds to spend tuning a given solution.
If both max tuning duration and iterations are set, the smaller of the two
will be honored. At minimum 1 tuning iteration will always be run.
"""
torch._C._cuda_tunableop_set_max_tuning_duration(duration) # type: ignore[attr-defined]
def get_max_tuning_duration() -> int:
r"""Get max time to spend tuning a given solution."""
return torch._C._cuda_tunableop_get_max_tuning_duration() # type: ignore[attr-defined]
def set_max_tuning_iterations(iterations: int) -> None:
r"""Set max number of iterations to spend tuning a given solution.
If both max tuning duration and iterations are set, the smaller of the two
will be honored. At minimum 1 tuning iteration will always be run.
"""
torch._C._cuda_tunableop_set_max_tuning_iterations(iterations) # type: ignore[attr-defined]
def get_max_tuning_iterations() -> int:
r"""Get max iterations to spend tuning a given solution."""
return torch._C._cuda_tunableop_get_max_tuning_iterations() # type: ignore[attr-defined]
def set_filename(filename: str, insert_device_ordinal: bool = False) -> None:
r"""Set the filename to use for input/output of tuning results.
If :attr:`insert_device_ordinal` is ``True`` then the current device ordinal
will be added to the given filename automatically. This can be used in a
1-process-per-GPU scenario to ensure all processes write to a separate file.
"""
torch._C._cuda_tunableop_set_filename(filename, insert_device_ordinal) # type: ignore[attr-defined]
def get_filename() -> str:
r"""Get the results filename."""
return torch._C._cuda_tunableop_get_filename() # type: ignore[attr-defined]
def get_results() -> Tuple[str, str, str, float]:
r"""Return all TunableOp results."""
return torch._C._cuda_tunableop_get_results() # type: ignore[attr-defined]
def get_validators() -> Tuple[str, str]:
r"""Return the TunableOp validators."""
return torch._C._cuda_tunableop_get_validators() # type: ignore[attr-defined]
def write_file_on_exit(val: bool) -> None:
r"""During Tuning Context destruction, write file to disk.
This is useful as a final flush of your results to disk if your application
terminates as a result of normal operation or an error. Results can also be
flushed manually by calling ``write_file()``."""
torch._C._cuda_tunableop_write_file_on_exit(val) # type: ignore[attr-defined]
def write_file(filename: Optional[str] = None) -> bool:
r"""Write results to a CSV file.
If :attr:`filename` is not given, ``get_filename()`` is called.
"""
if filename is None:
filename = get_filename()
return torch._C._cuda_tunableop_write_file(filename) # type: ignore[attr-defined]
def read_file(filename: Optional[str] = None) -> bool:
r"""Read results from a TunableOp CSV file.
If :attr:`filename` is not given, ``get_filename()`` is called.
"""
if filename is None:
filename = get_filename()
return torch._C._cuda_tunableop_read_file(filename) # type: ignore[attr-defined]