I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions


@@ -0,0 +1,55 @@
# mypy: allow-untyped-defs
import torch._C._lazy
from torch.utils._pytree import tree_flatten, tree_unflatten

from .closure import add_step_closure, run_step_closures


def mark_step(device: str = "", wait=False):
    """Triggers a mark step, which amounts to
    - collecting a group of 'live' lazy tensors to index into the compilation cache
      (lowering/compiling their IR graphs if not cached)
    - kicking off execution of the compiled function
    - (optionally, wait=True) waiting for cpu-side execution to complete (does not sync the accelerator)
    """
    # TODO(whc) expand this to include backend hooks and align with XLA backend needs
    torch._C._lazy._mark_step(device, [], wait=wait)

    run_step_closures()


def wait_device_ops(devices=None):
    """Waits for all the async operations on the given devices to complete.

    Args:
        devices (string..., optional): The devices whose async ops need to be waited
            for. If empty, all the local devices will be waited for.
    """
    if devices is None:
        devices = []
    torch._C._lazy._wait_device_ops(devices=devices)


def sync_multi(tensors, devices):
    """
    Sync the list of lazy tensors so their IR gets lowered for the active backend
    and the compiled computation graphs get cached.
    """
    torch._C._lazy._sync_multi(tensors, devices)


def get_tensor_id(tensor):
    """Return the unique id of the lazy tensor maintained by LTC."""
    return torch._C._lazy._get_tensor_id(tensor)


def to_cpu(tensors, devices=None):
    devices = devices or ["lazy"]

    flattened, spec = tree_flatten(tensors)
    sync_multi(flattened, devices)
    return tree_unflatten([t.to("cpu") for t in flattened], spec)


def save(tensors, *args, **kwargs):
    torch.save(to_cpu(tensors), *args, **kwargs)
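
A minimal usage sketch of the workflow above. The module path torch._lazy.ts_backend is an assumption (it is the module defining init() later in this commit), and the shapes are purely illustrative:

import torch
import torch._lazy as lazy
import torch._lazy.ts_backend as ts_backend  # assumed module path; defines init() (see the last file in this commit)

ts_backend.init()  # register the TorchScript lazy backend

x = torch.randn(4, 4, device="lazy")  # lazy tensors only record IR; nothing executes yet
y = (x @ x).relu()

lazy.mark_step()             # compile the captured IR (or hit the cache) and kick off execution
print(lazy.to_cpu([y])[0])   # sync and copy the result back to the CPU for inspection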


@@ -0,0 +1,135 @@
# mypy: allow-untyped-defs
import os
import threading
from queue import Empty as EmptyQueue, Queue

from torch._lazy.device_context import get_device_context


class ClosureHandler:
    def __init__(self) -> None:
        pass

    def run(self, closure):
        """Run closure function

        Args:
            closure: callable function to run
        """
        closure()

    def __call__(self, closures):
        for closure in closures:
            self.run(closure)


class AsyncClosureHandler(ClosureHandler):
    """Handler for Asynchronous Step Closures

    Args:
        max_queue_size: The maximum length of the closure queue after which
            the training loop will block until closures are evaluated.
            By default the queue is capped at a reasonable limit of 100.
            This value can be set using the `LTC_MAX_ASYNC_QUEUE` environment
            variable.
    """

    def __init__(self, max_queue_size=100):
        super().__init__()
        self._closure_queue: Queue = Queue(
            int(os.environ.get("LTC_MAX_ASYNC_QUEUE", max_queue_size))
        )
        self._closure_exception: Queue = Queue()
        self._closure_lock = threading.Lock()
        self._closure_event_loop_finished = threading.Event()
        self._closure_event_loop = None

    def start_event_loop(self):
        """Start the closure event loop if it is not already running."""
        if self._closure_event_loop is None:

            def event_loop():
                # Run the loop until the closure queue is empty and the finished event is set
                while True:
                    try:
                        closure = self._closure_queue.get(block=True, timeout=3)
                        closure()
                        self._closure_queue.task_done()
                    except EmptyQueue:
                        with self._closure_lock:
                            if self._closure_queue.empty():
                                self._closure_event_loop_finished.set()
                                return
                    except Exception as e:
                        self._closure_exception.put(e)
                        return

            self._closure_event_loop = threading.Thread(target=event_loop)
            self._closure_event_loop.start()

    def run(self, closure):
        with self._closure_lock:
            self._closure_queue.put(closure, block=True)
            if (
                self._closure_event_loop is None
                or not self._closure_event_loop.is_alive()
            ):
                try:
                    e = self._closure_exception.get(block=False)
                    raise RuntimeError(
                        "Cannot run asynchronous closure due to previously raised exception"
                    ) from e
                except EmptyQueue:
                    self._closure_event_loop = None
                    self.start_event_loop()


def add_step_closure(closure, args=(), run_async=False):
    """Adds a closure to the list of the ones to be run at the end of the step.

    Many times during model training there is the need to print/report (print to
    console, post to tensorboard, etc...) information which requires the content of
    intermediary tensors to be inspected.
    Inspecting the content of different tensors at different points of the model code
    requires many executions and typically causes performance issues.
    Adding a step closure will ensure that it will be run after the barrier, when
    all the live tensors will already be materialized to device data. Live tensors
    include the ones captured by the closure arguments.
    So using `add_step_closure()` will ensure a single execution will be
    performed, even when multiple closures are queued, requiring multiple tensors
    to be inspected.
    Step closures will be run sequentially in the order they have been queued.
    Note that even though using this API the execution will be optimized, it is
    advised to throttle the printing/reporting events to once every N steps.

    Args:
        closure (callable): The function to be called.
        args (tuple): The arguments to be passed to the closure.
        run_async: If True, run the closure asynchronously.
    """
    devctx = get_device_context()
    closures_type = "async_step_closures" if run_async else "step_closures"
    step_closures = getattr(devctx, closures_type, None)
    if step_closures is None:
        step_closures = []
        setattr(devctx, closures_type, step_closures)
    step_closures.append(lambda a=args: closure(*a))


def run_step_closures():
    devctx = get_device_context()
    async_step_closures = getattr(devctx, "async_step_closures", None)
    if async_step_closures is not None:
        devctx.async_step_closures = []
        async_closure_handler = getattr(devctx, "async_closure_handler", None)
        if async_closure_handler is None:
            async_closure_handler = AsyncClosureHandler()
            devctx.async_closure_handler = async_closure_handler
        async_closure_handler(async_step_closures)

    step_closures = getattr(devctx, "step_closures", None)
    if step_closures is not None:
        devctx.step_closures = []
        closure_handler = getattr(devctx, "closure_handler", None)
        if closure_handler is None:
            closure_handler = ClosureHandler()
            devctx.closure_handler = closure_handler
        closure_handler(step_closures)
    return devctx
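
A hedged sketch of how add_step_closure might be wired into a training loop; `loader`, `model`, and `loss_fn` are placeholders, not part of this commit:

import torch._lazy as lazy
from torch._lazy.closure import add_step_closure


def report_step(step, loss):
    # Runs after the step barrier, once `loss` has been materialized on the device.
    print(f"step {step}: loss={loss.item():.4f}")


for step, (inp, target) in enumerate(loader):  # placeholder data loader
    loss = loss_fn(model(inp.to("lazy")), target.to("lazy"))  # placeholder model/loss
    add_step_closure(report_step, args=(step, loss))
    lazy.mark_step()  # mark_step() invokes run_step_closures() after triggering execution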


@@ -0,0 +1,27 @@
# mypy: allow-untyped-defs
import torch._C._lazy
import torch._C._lazy_ts_backend


def get_tensors_ts_device_data_node(tensors):
    """Return tensor ids and eager tensors for DeviceData nodes in the
    IR of the passed-in lazy tensors.

    TODO: This API is currently TS backend specific. We are working on
    generalizing it to all backends including XLA.
    """
    return torch._C._lazy_ts_backend._get_tensors_ts_device_data_node(tensors)


def get_graph_hash(tensors):
    """Return the graph hash for the passed-in lazy tensors."""
    return torch._C._lazy._get_graph_hash(tensors)


def run_cached_graph(hash_str, graph_inputs):
    """Run the cached computation graph with the given inputs.

    TODO: This API is currently TS backend specific. We are working on
    generalizing it to all backends including XLA.
    """
    return torch._C._lazy_ts_backend._run_cached_graph(hash_str, graph_inputs)


@@ -0,0 +1,17 @@
# mypy: allow-untyped-defs
import torch._C._lazy


def get_force_fallback():
    """Get the config used to force LTC fallback"""
    return torch._C._lazy._get_force_fallback()


def set_force_fallback(configval):
    """Set the config used to force LTC fallback"""
    torch._C._lazy._set_force_fallback(configval)


def set_reuse_ir(val: bool):
    """Set the config to reuse IR nodes for faster tracing"""
    torch._C._lazy._set_reuse_ir(val)


@@ -0,0 +1,22 @@
# mypy: allow-untyped-defs
import torch._C._lazy


def render_ir_graph(tensors):
    """Return a text dump of the LTC IR graph for the tensors, in dot format.

    The text can be processed by tools like dot to render it as a pdf, png, etc.
    """
    return torch._C._lazy._get_tensors_dot(tensors)


def dump_ir(tensors, ir_format):
    """Return a dump of the tensors in the specified format.

    Valid formats are
    - text: for LTC IR
    - backend: for the active backend IR
    """
    if ir_format == "text":
        return torch._C._lazy._get_tensors_text(tensors)
    elif ir_format == "backend":
        return torch._C._lazy._get_tensors_backend(tensors)
    else:
        raise RuntimeError(f"Unrecognized IR format: {ir_format}")
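
A small sketch of the debug helpers above, assuming the lazy backend has already been initialized and a lazy tensor is in scope; rendering the dot output with Graphviz is optional:

import torch
import torch._lazy.debug as lazy_debug  # imported as `from torch._lazy import debug as lazy_debug` elsewhere in this commit

t = torch.ones(2, 2, device="lazy") * 3  # assumes the lazy backend is initialized

print(lazy_debug.dump_ir([t], "text"))        # LTC IR as text
with open("graph.dot", "w") as f:
    f.write(lazy_debug.render_ir_graph([t]))  # dot source; render with e.g. `dot -Tpng graph.dot`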


@@ -0,0 +1,26 @@
# mypy: allow-untyped-defs
import threading
from typing import Any, Dict

import torch._C._lazy


class DeviceContext:
    _CONTEXTS: Dict[str, Any] = {}
    _CONTEXTS_LOCK = threading.Lock()

    def __init__(self, device):
        self.device = device


def get_device_context(device=None):
    if device is None:
        device = torch._C._lazy._get_default_device_type()
    else:
        device = str(device)

    with DeviceContext._CONTEXTS_LOCK:
        devctx = DeviceContext._CONTEXTS.get(device, None)
        if devctx is None:
            devctx = DeviceContext(device)
            DeviceContext._CONTEXTS[device] = devctx
        return devctx


@@ -0,0 +1,225 @@
# mypy: allow-untyped-defs
import copy
import dataclasses
import itertools
import os
from typing import Any, Callable, Dict, List

import torch
import torch._lazy as lazy
import torch._lazy.metrics as metrics
from torch import fx
from torch._lazy import computation, debug as lazy_debug
from torch._lazy.tensor_factory_functions import tensor_factory_functions

debug = os.environ.get("debug_extract_compiled_graph") is not None


@dataclasses.dataclass
class GraphInputMatcher:
    """
    The GraphInputMatcher class sets up the graph inputs for future calls after lazy tracing.

    Specifically, graph inputs corresponding to method parameters should be replaced with the
    arguments for the current call.

    tensor_id_to_arg_idx maps the tensor id to the parameter index.
    graph_input_tensor_ids, graph_input_ivalues list the tensor_id and ivalue for each of the
    TS/XLA graph inputs.
    """

    tensor_id_to_arg_idx: Dict[int, int]
    graph_input_tensor_ids: List[int]
    # There are 2 categories of graph_input_tensors.
    # Category 1: those whose ids are not found in tensor_id_to_arg_idx. These are
    # most likely const tensors and we can get their content from graph_input_ivalues.
    # Category 2: those whose ids are found in tensor_id_to_arg_idx. We should get
    # the tensors from the method arguments.
    graph_input_ivalues: List[Any]

    # Get the real graph input tensors.
    def __call__(self, args):
        real_input = []
        for tensor_id, traced_ivalue in zip(
            self.graph_input_tensor_ids, self.graph_input_ivalues
        ):
            arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None)
            if arg_idx is None:
                inp = traced_ivalue
            else:
                inp = args[arg_idx]
            real_input.append(inp)
        return real_input


class ReturnValueHandler:
    r"""
    When ltc_sync_multi is called on multiple tensors, the compiled graph
    will contain output only for unique tensors - if a tensor appears multiple
    times in the input to _ltc_sync_multi, only the first occurrence matters.

    However, at the Python level we still expect multiple tensors to be returned, with
    duplication, even if the TS graph dedups the output. E.g. for the method:

        def forward(self, a):
            return a, a

    the TS graph captured by LTC will return a single tensor, but the Python method expects 2.

    This class dedups the lazy tensors first to get the indices that will be used
    to duplicate the eager tensors later.
    """

    def __init__(self, lazy_out_list):
        self.index: List[List[int]] = []
        self.total_count = len(lazy_out_list)

        tensor_id_to_idx: Dict[int, int] = {}
        for dup_idx, lazy_tensor in enumerate(lazy_out_list):
            uniq_idx = tensor_id_to_idx.get(id(lazy_tensor), None)
            if uniq_idx is not None:
                self.index[uniq_idx].append(dup_idx)
            else:
                uniq_idx = len(self.index)
                self.index.append([dup_idx])
                tensor_id_to_idx[id(lazy_tensor)] = uniq_idx

    def duplicate_eager_tensors(self, eager_tensor_list):
        duplicated_list = [None] * self.total_count
        assert len(eager_tensor_list) == len(self.index)

        for uniq_idx, eager_tensor in enumerate(eager_tensor_list):
            for dup_idx in self.index[uniq_idx]:
                duplicated_list[dup_idx] = eager_tensor
        return duplicated_list


def force_lazy_device(model: fx.GraphModule):
    """
    Factory methods in an Fx graph may create tensors for specific eager devices.
    If we take no action, those eager tensors will be mixed with lazy tensors and
    cause a crash. This method overwrites those eager devices with the lazy device.
    """

    def tolazydevice(dev):
        if isinstance(dev, torch.device):
            return torch.device("lazy", index=dev.index)
        return dev

    def hasDeviceArg(args, kwargs):
        return any(
            isinstance(arg, torch.device)
            for arg in itertools.chain(args, kwargs.values())
        )

    for nd in model.graph.nodes:
        nd.args = tuple(tolazydevice(arg) for arg in nd.args)
        nd.kwargs = {k: tolazydevice(v) for k, v in nd.kwargs.items()}

        # For torchbench models like yolov3 and hf_Bart, dynamo generates Fx graphs that
        # return eager tensors on the default device
        # (check https://gist.github.com/shunting314/eabdf6c769c59bc384469717b8f9bb7f for yolov3,
        # and https://gist.github.com/shunting314/8d5e2d9348a3258959d3954186c48814 for hf_Bart).
        # To force those tensors onto the lazy device, we cannot simply override
        # the device argument since there is no explicit device argument.
        # What we do here is, for the list of covered tensor factory methods,
        # add a lazy device argument explicitly.
        #
        # TODO: This solution is not ideal since we may miss some factory methods. In the
        # future, when we support lazy mode, this method can be replaced by that.
        if nd.target in tensor_factory_functions and not hasDeviceArg(
            nd.args, nd.kwargs
        ):
            kwargs = dict(nd.kwargs)  # nd.kwargs is immutable. make a mutable copy.
            kwargs["device"] = torch.device("lazy")
            nd.kwargs = kwargs

    model.recompile()


def get_fallback_ops():
    fallback_ops = []
    for opname in metrics.counter_names():
        if "aten::" not in opname:
            continue
        val = int(metrics.counter_value(opname))
        if val > 0:
            fallback_ops.append(f"{opname}={val}")

    return fallback_ops


def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable:
    """
    Optimize an eager model with LTC and return a wrapper that executes the
    compiled graph directly without retracing. It depends on other mechanisms
    like TorchDynamo guards to guarantee the returned wrapper is only called
    when it's safe.
    """
    lazy_args = [arg.to(device="lazy") for arg in example_inputs]
    args_tensor_ids = [lazy.get_tensor_id(lazy_arg) for lazy_arg in lazy_args]
    tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)}
    lazy_model = copy.deepcopy(model).to(device=torch.device("lazy"))
    force_lazy_device(lazy_model)

    # This line executes lazy tracing and enables us to extract the compiled graph later.
    metrics.reset()
    lazy_out = lazy_model(*lazy_args)
    fallback_ops = get_fallback_ops()
    metrics.reset()

    if len(fallback_ops) > 0:
        raise RuntimeError(
            f"Failed to extract the compiled graph because of fallback: {','.join(fallback_ops)}"
        )

    if not isinstance(lazy_out, (tuple, list)):
        lazy_out = (lazy_out,)

    args_and_out = tuple(lazy_args) + tuple(lazy_out)
    return_value_handler = ReturnValueHandler(args_and_out)
    if debug:
        print("Fx code:\n", model.code)
        print("LTC IR:", lazy_debug.dump_ir(args_and_out, "text"))

    # TODO: this part is TS backend specific for now and will be generalized to
    # support XLA
    (
        graph_input_tensor_ids,
        graph_input_ivalues,
    ) = computation.get_tensors_ts_device_data_node(args_and_out)
    assert len(graph_input_tensor_ids) == len(graph_input_ivalues)
    graph_input_matcher = GraphInputMatcher(
        tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues
    )

    graph_hash = computation.get_graph_hash(args_and_out)

    if debug:
        print("graph_hash", graph_hash)
        print(f"args_tensor_ids {args_tensor_ids}")
        print("tensor ids from device data:", graph_input_tensor_ids)

    # Sync the list of output tensors so the computation graph for these
    # tensors will be cached. Those computation graphs can be retrieved
    # by graph hash later.
    lazy.sync_multi(args_and_out, [])

    def optimized_mod(*args):
        if len(args_and_out) == 0:
            return ()
        graph_input = graph_input_matcher(args)
        res = return_value_handler.duplicate_eager_tensors(
            computation.run_cached_graph(graph_hash, graph_input)
        )

        assert len(res) == len(args_and_out)
        for i, arg in enumerate(args):
            # only copy those tensors that get inplace updated
            if arg is not res[i]:
                arg.copy_(res[i])

        # skip the args
        return res[len(args) :]

    return optimized_mod
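
A hedged end-to-end sketch of extract_compiled_graph. The module path torch._lazy.extract_compiled_graph, the ts_backend path, and the toy module are assumptions for illustration; in practice this entry point is normally driven by TorchDynamo rather than called by hand:

import torch
from torch import fx
import torch._lazy.ts_backend as ts_backend  # assumed module path; defines init()
from torch._lazy.extract_compiled_graph import extract_compiled_graph  # assumed module path


class MatMulRelu(torch.nn.Module):  # toy module, for illustration only
    def forward(self, x, y):
        return (x @ y).relu()


ts_backend.init()
example_inputs = [torch.randn(4, 4), torch.randn(4, 4)]
gm = fx.symbolic_trace(MatMulRelu())

optimized = extract_compiled_graph(gm, example_inputs)  # traces once and caches the TS graph
print(optimized(*example_inputs))  # replays the cached graph without retracing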


@@ -0,0 +1,14 @@
# mypy: allow-untyped-defs
import torch._C._lazy


def dump(dot_file_name: str):
    """Dump TrieCache in the dot format"""
    return torch._C._lazy._dump_ir_cache(dot_file_name)


def reset():
    """Clear TrieCache. This is needed in testing to avoid
    node reusing between different tests.
    """
    return torch._C._lazy._clear_ir_cache()


@@ -0,0 +1,22 @@
# mypy: allow-untyped-defs
import torch._C._lazy


def reset():
    """Resets all metric counters."""
    torch._C._lazy._reset_metrics()


def counter_names():
    """Retrieves all the currently active counter names."""
    return torch._C._lazy._counter_names()


def counter_value(name: str):
    """Return the value of the counter with the specified name."""
    return torch._C._lazy._counter_value(name)


def metrics_report():
    """Return the combined (lazy core and backend) metric report."""
    return torch._C._lazy._metrics_report()
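
A short sketch using the counters above to spot eager fallbacks, mirroring get_fallback_ops from the earlier file in this commit (the aten:: prefix convention is taken from that code):

import torch._lazy.metrics as metrics

metrics.reset()
# ... run some lazy-tensor code here ...
for name in metrics.counter_names():
    if "aten::" in name and int(metrics.counter_value(name)) > 0:
        print(f"fallback {name}={metrics.counter_value(name)}")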


@@ -0,0 +1,49 @@
import torch

"""
tensor_factory_functions defines the list of torch functions that create tensors.
The list is grabbed by searching through native_functions.yaml with the following
regular expression:

  cat native_functions.yaml | grep 'func:' | grep -v "Tensor.*->" | grep "[-]>.*Tensor"

It's possible that new tensor factory functions are added, making this list stale.
Use at your own risk or regenerate the list.
"""
tensor_factory_functions = (
    torch._cudnn_init_dropout_state,
    torch.arange,
    torch.bartlett_window,
    torch.blackman_window,
    torch._empty_affine_quantized,
    torch.empty_strided,
    torch.eye,
    torch.full,
    torch.from_file,
    torch.hann_window,
    torch.hamming_window,
    torch.kaiser_window,
    torch.linspace,
    torch.logspace,
    torch.ones,
    torch.scalar_tensor,
    torch.rand,
    torch.randint,
    torch.randn,
    torch.randperm,
    torch.range,
    torch._efficientzerotensor,
    torch.zeros,
    torch.tril_indices,
    torch.triu_indices,
    # Note: the following functions match the regular expression search above but
    # they are not available in the torch module. Comment out.
    # torch._sparse_coo_tensor_with_dims,
    # torch.fft_fftfreq,
    # torch.fft_rfftfreq,
) + (
    # torch.tensor is special since it's not in native_functions.yaml
    # add it separately
    torch.tensor,
)


@@ -0,0 +1,7 @@
# mypy: allow-untyped-defs
import torch._C._lazy_ts_backend


def init():
    """Initializes the lazy Torchscript backend"""
    torch._C._lazy_ts_backend._init()