I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

File diff suppressed because it is too large


@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from ._IR import Pipe, pipe_split, pipeline, SplitPoint
from .schedules import (
_ScheduleForwardOnly,
Schedule1F1B,
ScheduleFlexibleInterleaved1F1B,
ScheduleGPipe,
ScheduleInterleaved1F1B,
ScheduleInterleavedZeroBubble,
ScheduleLoopedBFS,
)
from .stage import build_stage, PipelineStage
__all__ = [
"Pipe",
"pipe_split",
"SplitPoint",
"pipeline",
"PipelineStage",
"build_stage",
"Schedule1F1B",
"ScheduleFlexibleInterleaved1F1B",
"ScheduleGPipe",
"ScheduleInterleaved1F1B",
"ScheduleLoopedBFS",
"ScheduleInterleavedZeroBubble",
]
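
A minimal usage sketch of the API exported above (the toy model, the "layers.4" split point, and the two-rank GPipe setup are illustrative assumptions, not part of this commit):

import torch
import torch.distributed as dist
from torch.distributed.pipelining import pipeline, ScheduleGPipe, SplitPoint

# Assumes torchrun launched two ranks and a model with a submodule "layers.4".
dist.init_process_group(backend="gloo")
rank, device = dist.get_rank(), torch.device("cpu")
model = MyTwoStageModel()               # hypothetical nn.Module
mb_x = torch.randn(8, 512)              # one example micro-batch

# Trace the model and split it into two stages at the chosen submodule.
pipe = pipeline(model, mb_args=(mb_x,), split_spec={"layers.4": SplitPoint.BEGINNING})

# Materialize this rank's stage and drive it with a GPipe schedule.
stage = pipe.build_stage(rank, device)
schedule = ScheduleGPipe(stage, n_microbatches=4)
out = schedule.step(torch.randn(32, 512)) if rank == 0 else schedule.step()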


@@ -0,0 +1,370 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import collections
import logging
import weakref
from typing import Any, cast, Deque, Dict, Iterator, List, Optional, Set, Tuple, Union
import torch
from torch.autograd.graph import GradientEdge, Node
from torch.nn import Parameter
from ._debug import map_debug_info
logger = logging.getLogger(__name__)
def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Union[Node, None]:
"""
Get the grad function or grad accumulator for a tensor.
    Accumulate grad nodes are lazily created, so we need to apply a
    dummy view in order to trigger their creation.
"""
if t.requires_grad and t.grad_fn is None:
# if no grad function (leaf tensors) we use view
viewed_t = t.view_as(t)
grad_fn = viewed_t.grad_fn
if grad_fn is not None:
return grad_fn.next_functions[0][0]
else:
            raise RuntimeError(
                "Attempted to get grad_fn, but got None. "
                "Is this being created in a no-grad context?"
            )
else:
return t.grad_fn
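# Example (illustrative, relying on standard autograd behavior): for a leaf
# tensor, grad_fn is None until a view forces the AccumulateGrad node to exist.
#   >>> leaf = torch.randn(3, requires_grad=True)
#   >>> leaf.grad_fn is None
#   True
#   >>> type(_get_grad_fn_or_grad_acc(leaf)).__name__
#   'AccumulateGrad'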
def reverse_closure(
roots: List[Node], target_nodes: Set[Node]
) -> Tuple[Set[Node], Set[Node]]:
"""
    This function returns the reverse closure of the given roots, i.e. the set
    of nodes reachable from the roots by following the reverse edges of the
    graph. Traversal stops at target_nodes; the target nodes actually reached
    are returned separately as the second element of the result.
"""
# Recurse until we reach a target node
closure: Set[Node] = set()
visited_target_nodes = set()
q: Deque[Node] = collections.deque()
for node in roots:
if node is not None and node not in closure:
closure.add(node)
q.append(node)
while q:
node = q.popleft()
metadata = cast(Dict[str, List], node.metadata)
reverse_edges = metadata.get("reverse_edges", [])
for holder_ref, idx in reverse_edges:
ref = holder_ref()
if ref is None:
# this reverse graph is no longer alive
# raise RuntimeError("Reverse graph is no longer alive")
continue
fn = ref.node
if fn in closure or fn is None:
continue
if fn in target_nodes:
visited_target_nodes.add(fn)
continue
closure.add(fn)
q.append(fn)
return closure, visited_target_nodes
# Wrapper around a Node so we can take weak references to reverse-graph entries
class Holder:
def __init__(self, node: Node):
self.node = node
def construct_reverse_graph(roots: List[Node]) -> List[Holder]:
q: Deque[Node] = collections.deque()
root_seen: Set[Node] = set()
reverse_graph_refs: List[Holder] = []
for node in roots:
if node is not None and node not in root_seen:
q.append(node)
root_seen.add(node)
while q:
node = q.popleft()
for fn, idx in node.next_functions:
if fn is not None:
# Don't necessarily need to store on the graph
metadata = cast(Dict[str, List], fn.metadata)
reverse_edges = metadata.get("reverse_edges", [])
if len(reverse_edges) == 0:
q.append(fn)
holder = Holder(node)
holder_ref = weakref.ref(holder)
reverse_graph_refs.append(holder)
reverse_edges.append((holder_ref, idx))
metadata["reverse_edges"] = reverse_edges
return reverse_graph_refs
def get_param_groups(inputs: List[Node], params: List[Node]) -> List[Dict[str, Any]]:
"""
Given a list of inputs and a list of parameters, return a list of parameter
groups, where each group contains the parameters and the intermediates that
are connected to the parameters.
    Each parameter group is a dictionary with the following keys:
    - "params": a set of parameters
    - "intermediates": a set of intermediates
"""
# reverse graph that starts with inputs, and goes up to the dOutput or the loss,
# but omits weights and any subgraphs connecting weights to this closure
inputs_closure, _ = reverse_closure(inputs, set())
param_groups: Dict[Node, Dict[str, Set]] = dict() # keyed on intermediates
for i, param in enumerate(params):
closure, intersected = reverse_closure([param], inputs_closure)
param_group: Dict[str, Set] = {
"params": {param},
"intermediates": intersected,
}
for input_node in intersected:
existing = param_groups.get(input_node, None)
if existing is not None:
existing["params"] = existing["params"].union(param_group["params"])
existing["intermediates"] = existing["intermediates"].union(
param_group["intermediates"]
)
param_group = existing
else:
param_groups[input_node] = param_group
# Sanity check: union of all param_groups params should be equal to all params
union_params: Set[Node] = set()
seen_ids: Set[int] = set()
unique_param_groups = []
for param_group in param_groups.values():
if id(param_group) not in seen_ids:
seen_ids.add(id(param_group))
unique_param_groups.append(param_group)
union_params = union_params.union(param_group["params"])
# The assert will only be true if the input tensor requires gradients,
# otherwise the autograd graph will miss the first layer of inputs
# assert union_params == set(params)
return unique_param_groups
def stage_backward_input(
stage_outputs: List[torch.Tensor],
output_grads: Optional[List[torch.Tensor]],
input_values: List[torch.Tensor],
weights: Iterator[Parameter],
):
"""
    Compute gradients only for the stage inputs with respect to the stage outputs.
"""
stage_output_grad_fns: List[Node] = list(
filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs))
)
stage_input_grad_fns: List[Node] = list(
filter(None, map(_get_grad_fn_or_grad_acc, input_values))
)
weight_grad_fns: List[Node] = list(
filter(None, map(_get_grad_fn_or_grad_acc, weights))
)
reverse_graph_refs = construct_reverse_graph(stage_output_grad_fns)
param_groups = get_param_groups(stage_input_grad_fns, weight_grad_fns)
del reverse_graph_refs
for param_group in param_groups:
for i, intermediate in enumerate(param_group["intermediates"]):
def get_hook(param_group, i):
def hook(grad_inputs):
if param_group.get("grads", None) is None:
param_group["grads"] = [None] * len(
param_group["intermediates"]
)
param_group["grads"][i] = grad_inputs
return hook
# These are always "split" nodes that we need to recompute, so
# save their inputs.
intermediate.register_prehook(get_hook(param_group, i))
# Stage 0 inputs do not require grads? Should we skip in that case?
if all(tensor.requires_grad for tensor in input_values):
if output_grads is None:
# In case this is the loss and there are no output_grads, then we just use 1s
output_grads = [
torch.ones_like(stage_output) for stage_output in stage_outputs
]
dinputs = torch.autograd.grad(
stage_outputs,
inputs=input_values,
grad_outputs=output_grads,
retain_graph=True,
)
# update the gradients for inputs
for i, inp in enumerate(input_values):
if inp.grad is None:
inp.grad = dinputs[i]
else:
inp.grad += dinputs[i]
else:
dinputs = None
return dinputs, param_groups
def stage_backward_weight(
weights: Iterator[Parameter], param_groups: List[Dict[str, Any]]
):
# map weights to param_group_weights
grad_acc_to_weight = {}
weight_grads = []
for index, weight in enumerate(weights):
grad_acc = _get_grad_fn_or_grad_acc(weight)
grad_acc_to_weight[grad_acc] = weight, index
weight_grads.append(weight.grad)
for param_group in param_groups:
# TODO: Handle case where intermediate can have multiple outputs
intermediate_edges = tuple(
GradientEdge(i, 0) for i in param_group["intermediates"]
)
weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])
assert all(len(g) == 1 for g in param_group["grads"])
# [NEW!] Able to pass a GradientEdge to autograd.grad as output
# We do not need to retain_graph because... guarantee no overlap?
# print("trying to execute: ", intermediate_edges, weights_edges)
dweights = torch.autograd.grad(
intermediate_edges,
weights_edges,
grad_outputs=sum(param_group["grads"], tuple()),
)
for grad_acc, dw in zip(param_group["params"], dweights):
weight, index = grad_acc_to_weight[grad_acc]
if weight.grad is None:
weight.grad = dw
else:
weight.grad += dw
# return grads in the original order weights were provided in
return weight_grads
def stage_backward(
stage_output,
output_grads,
input_values,
outputs_with_grads_idxs: Optional[List[int]] = None, # deprecated, not used
):
"""
This is a helper function to:
1. compute the gradients for the stage inputs, and
2. accumulate gradients for the stage module's parameters.
Given the input value(s) and the corresponding gradient for the output
value(s), compute and accumulate gradients for all parameter values (leaves
in the autograd trace) as well as return a list of the gradients for the
input values
"""
if outputs_with_grads_idxs is not None:
# Deprecated, not used in runtime calls, only exists in compiler
stage_output = [stage_output[i] for i in outputs_with_grads_idxs]
output_grads = [output_grads[i] for i in outputs_with_grads_idxs]
try:
# stage_output may be a composite datatype like dict. Extract all individual
# tensor values here
stage_output_tensors = []
output_grad_tensors = []
def extract_tensors_with_grads(output_val, grad_val):
if isinstance(output_val, torch.Tensor):
if not output_val.requires_grad and output_val.grad_fn is None:
return
assert isinstance(
grad_val, (torch.Tensor, type(None))
), f"Expected Tensor or None gradient but got {type(grad_val)}"
stage_output_tensors.append(output_val)
output_grad_tensors.append(grad_val)
elif isinstance(output_val, (tuple, list)):
if grad_val is None:
return
assert isinstance(
grad_val, (tuple, list)
), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
assert len(output_val) == len(grad_val)
for ov, gv in zip(output_val, grad_val):
extract_tensors_with_grads(ov, gv)
elif isinstance(output_val, dict):
if grad_val is None:
return
assert isinstance(grad_val, dict)
assert set(output_val.keys()) == set(grad_val.keys())
for k in output_val.keys():
extract_tensors_with_grads(output_val[k], grad_val[k])
else:
# Output is a non-tensor type; just ignore it
pass
extract_tensors_with_grads(stage_output, output_grads)
torch.autograd.backward(
stage_output_tensors, grad_tensors=output_grad_tensors # type: ignore[arg-type]
)
# Extract gradients wrt the input values
grad_inputs = []
for val in input_values:
if isinstance(val, torch.Tensor):
grad_inputs.append(val.grad)
else:
grad_inputs.append(None)
# Alternative impl: `torch.autograd.grad`.
# Note that `torch.autograd.grad` will not accumulate gradients into the
# model's parameters.
"""
inputs_with_grad = []
for val in input_values:
if isinstance(val, torch.Tensor) and val.requires_grad:
inputs_with_grad.append(val)
grad_inputs = torch.autograd.grad(
stage_output_tensors, inputs_with_grad, output_grad_tensors, # type: ignore[arg-type]
)
"""
except Exception as e:
exc_msg = f"""
Failed to run stage backward:
Stage output: {map_debug_info(stage_output)}
Output gradient: {map_debug_info(output_grads)}
Input: {map_debug_info(input_values)}
"""
raise RuntimeError(exc_msg) from e
return grad_inputs
# TODO: handling requires_grad=False dynamically. Can we analyze this during initial
# IR emission?
def _null_coalesce_accumulate(lhs, rhs):
"""
Coalesce two values, even if one of them is null, returning the non-null
value.
"""
if lhs is None:
return rhs
elif rhs is None:
return lhs
else:
return torch.add(lhs, rhs)
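
A small sketch of stage_backward on a toy single-layer stage (the layer, shapes, and ones-like output gradients are illustrative; the import path assumes this file lands at torch/distributed/pipelining/_backward.py):

import torch
from torch.distributed.pipelining._backward import stage_backward

layer = torch.nn.Linear(4, 2)
x = torch.randn(3, 4, requires_grad=True)   # stage input received from upstream
out = layer(x)                              # stage forward

# Propagate a ones-like output gradient; parameter grads accumulate in place
# and the input gradient is returned for sending to the previous stage.
(grad_x,) = stage_backward(
    stage_output=(out,),
    output_grads=(torch.ones_like(out),),
    input_values=[x],
)
assert grad_x.shape == x.shape
assert layer.weight.grad is not None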


@@ -0,0 +1,21 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
def friendly_debug_info(v):
"""
Helper function to print out debug info in a friendly way.
"""
if isinstance(v, torch.Tensor):
return f"Tensor({v.shape}, grad={v.requires_grad}, dtype={v.dtype})"
else:
return str(v)
def map_debug_info(a):
"""
Helper function to apply `friendly_debug_info` to items in `a`.
`a` may be a list, tuple, or dict.
"""
return torch.fx.node.map_aggregate(a, friendly_debug_info)
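
A quick sketch of the helpers above on illustrative values (import path assumed to be torch/distributed/pipelining/_debug.py):

import torch
from torch.distributed.pipelining._debug import map_debug_info

print(map_debug_info({"ids": torch.zeros(2, 3), "step": 7}))
# roughly: {'ids': 'Tensor(torch.Size([2, 3]), grad=False, dtype=torch.float32)', 'step': '7'}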


@@ -0,0 +1,27 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict
import torch
from torch.export.unflatten import _ModuleFrame
def _outline_submodules(orig_graph: torch.fx.Graph):
# Create an empty GraphModule to hold the outlined modules
new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
seen_nodes: Dict[str, torch.fx.Node] = {}
seen_modules: Dict[int, torch.nn.Module] = {}
_ModuleFrame(
orig_graph,
tuple(orig_graph.nodes),
seen_nodes,
seen_modules,
None,
[""],
"",
{},
module=new_module,
).run_outer()
new_module.graph.lint()
new_module.recompile()
return new_module


@@ -0,0 +1,99 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from dataclasses import dataclass
from typing import List, Tuple, Union
import torch
from torch import fx
logger = logging.getLogger(__name__)
def flatten_args_detach(args):
"""
    Flatten the args into a list form and detach the tensors from the computational graph.
"""
flat_detached_args = []
def extract_tensor_args(a):
nonlocal flat_detached_args
if isinstance(a, torch.Tensor):
val = a.detach().requires_grad_(a.requires_grad)
flat_detached_args.append(val)
return val
else:
flat_detached_args.append(a)
return a
new_args = fx.node.map_aggregate(
args,
extract_tensor_args,
)
return new_args, flat_detached_args
def flatten_args(args):
"""
Flatten the args into a list form.
"""
flat_args = []
def extract_tensor_args(a):
nonlocal flat_args
flat_args.append(a)
return a
fx.node.map_aggregate(
args,
extract_tensor_args,
)
return flat_args
class PipeliningShapeError(RuntimeError):
"""Shape mismatch between configured and runtime values."""
def validate_tensor_metadata(desc, expected, given):
if not expected.shape == given.shape:
raise PipeliningShapeError(
f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}"
)
if not expected.dtype == given.dtype:
raise PipeliningShapeError(
f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}"
)
if not expected.stride() == given.stride():
raise PipeliningShapeError(
f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}"
)
def validate_tensors_metadata(
desc,
expected_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]],
actual_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]],
):
if len(expected_tensors) != len(actual_tensors):
raise PipeliningShapeError(
f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})"
)
for i in range(len(expected_tensors)):
validate_tensor_metadata(
f"{desc}: value {i}", expected_tensors[i], actual_tensors[i]
)
@dataclass
class PipeInfo:
"""
Captures information for a pipeline (`Pipe` object).
"""
graph: fx.Graph
num_stages: int
has_loss_and_backward: bool


@@ -0,0 +1,469 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch.fx.node import map_aggregate
from torch.utils._pytree import tree_flatten, tree_unflatten
__all__ = [
"TensorChunkSpec",
"split_args_kwargs_into_chunks",
"merge_chunks",
]
logger = logging.getLogger(__name__)
"""
_debug_mask_minibatches specifies whether to send masked versions of the full
mini-batch through instead of micro-batch slices; this can be used for more
stable numerical testing (see [A Note About Correctness Testing])
"""
_debug_mask_minibatches = False
class _CustomReducer:
"""
Custom reducer class that can be used to specify a custom operation that
reduces losses of multiple microbatches into one value.
Example:
>>> # xdoctest: +SKIP
>>> sum_reducer = _CustomReducer(
>>> torch.tensor(0.0),
>>> lambda a, b: a + b
>>> )
"""
def __init__(self, init_value, reduce_fn):
self.init_value = init_value
self.reduce_fn = reduce_fn
class _LossReducer(_CustomReducer):
pass
sum_reducer = _LossReducer(torch.tensor(0.0), lambda a, b: a + b)
# Default chunking dimension is 0. This is used for the case where the user did
# not specify a chunking dimension.
DEFAULT_CHUNK_DIM = 0
class TensorChunkSpec:
"""
Class used to specify chunking of inputs
"""
def __init__(self, split_dim):
self.split_dim = split_dim
split_dim: int
def __repr__(self):
return (
f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})"
)
def __str__(self):
return f"TensorChunkSpec({self.split_dim})"
@staticmethod
def from_tuple(
chunk_dims: Tuple[int, ...],
):
"""
A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk
dimensions (int's).
Example:
>>> # xdoctest: +SKIP
>>> # There are three positional arguments to the model, and
>>> # we are chunking them along dimension 0, 0 and 1, respectively
>>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
"""
args_chunk_spec = map_aggregate(
chunk_dims,
lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value]
)
return args_chunk_spec
@staticmethod
def from_dict(
chunk_dims: Dict[str, int],
):
"""
A helper for creating a dictionary of `TensorChunkSpec` from a
dictionary of chunk dimensions (int's).
Example:
>>> # xdoctest: +SKIP
>>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
>>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
"""
kwargs_chunk_spec = map_aggregate(
chunk_dims,
lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value]
)
return kwargs_chunk_spec
# Class used to specify replication of inputs
class _Replicate:
pass
def _shard_dict_of_args(
args_dict,
args_chunk_spec,
num_chunks,
):
"""
Given a dictionary of args, and a dictionary of chunking specs, shard the
args according to the chunking specs.
Args:
args_dict: Dictionary of args
args_chunk_spec: Dictionary of chunking specs
num_chunks: Number of chunks to shard the args into
Returns:
args_split: List of sharded args
"""
# Stage 1+2: flatten and shard/replicate
# args_sharded_replicated : [num args, num flat values, num chunks]
args_sharded_replicated = {}
arg_specs = []
real_num_chunks = num_chunks
first_tensor = True
assert len(args_dict) == len(
args_chunk_spec
), f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
for arg_key, arg in args_dict.items():
flat, spec = tree_flatten(arg)
arg_specs.append(spec)
chunk_spec = args_chunk_spec[arg_key]
assert chunk_spec is not None # Should have been set by caller
chunk_spec_flat, _ = tree_flatten(chunk_spec)
if len(flat) != len(chunk_spec_flat):
            raise ValueError(
                f"Argument value {arg} did not have the same number of "
                f"values as chunk spec {chunk_spec}"
            )
sharded_arg_flat = []
for v, chunk_v in zip(flat, chunk_spec_flat):
if chunk_v is _Replicate or not isinstance(v, torch.Tensor):
sharded_arg_flat.append([v] * real_num_chunks)
elif isinstance(chunk_v, TensorChunkSpec):
# TODO: check type of v. If it's a tensor, use chunk (or debug mask).
# If it's a collection type, split it as you would expect. Otherwise,
# Throw an error
assert isinstance(v, torch.Tensor), f"{v} is not a tensor"
v_split_dim_size = v.size(chunk_v.split_dim)
if v_split_dim_size < real_num_chunks:
if first_tensor:
# We can only adjust number of chunks when we hit this
# issue at the first tensor encountered
logger.warning(
f"Tensor size on chunking dimension is {v_split_dim_size}, " # noqa: G004
f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}."
)
real_num_chunks = v_split_dim_size
else:
raise RuntimeError(
f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, "
f"smaller than the number of chunks {num_chunks}. "
"PiPPy cannot reduce the number of chunks because "
"other arguments have bigger chunk-dimension sizes. "
"Please adjust your num_chunks setting."
)
chunk_tensors = torch.tensor_split(
v, real_num_chunks, chunk_v.split_dim
)
if _debug_mask_minibatches:
expanded_chunks = []
split_dim_idx = 0
for chunk_tensor in chunk_tensors:
new_val = torch.zeros_like(v)
upper_idx = split_dim_idx + chunk_tensor.size(chunk_v.split_dim)
slice_indices = [slice(None, None, None)] * new_val.ndim
slice_indices[chunk_v.split_dim] = slice(
split_dim_idx, upper_idx
)
new_val[slice_indices] = chunk_tensor
expanded_chunks.append(new_val)
split_dim_idx += chunk_tensor.size(chunk_v.split_dim)
sharded_arg_flat.append(expanded_chunks)
else:
sharded_arg_flat.append(chunk_tensors) # type: ignore[arg-type]
first_tensor = False
else:
raise TypeError(f"Unrecognized chunk spec: {chunk_v}")
args_sharded_replicated[arg_key] = sharded_arg_flat
# chunks_flat : [num chunks, num args, num flat values]
chunks_flat = []
for chunk_idx in range(real_num_chunks):
chunk_args = {}
for key, arg in args_sharded_replicated.items():
arg_single_chunk = []
for v_flat in arg:
arg_single_chunk.append(v_flat[chunk_idx])
chunk_args[key] = arg_single_chunk
chunks_flat.append(chunk_args)
# args_split : [num chunks, num args]
args_split = []
for chunk in chunks_flat:
per_chunk_args = {}
assert len(arg_specs) == len(chunk)
for (key, arg), arg_spec in zip(chunk.items(), arg_specs):
per_chunk_args[key] = tree_unflatten(arg, arg_spec)
args_split.append(per_chunk_args)
return args_split
def split_args_kwargs_into_chunks(
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]],
chunks: int,
args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
) -> Tuple[List[Tuple], List[Dict]]:
"""
Given a sequence of args and kwargs, split them into a number of chunks
according to their respective chunking specs.
Args:
args: Tuple of args
kwargs: Dict of kwargs
chunks: Number of chunks to split the args and kwargs into
args_chunk_spec: chunking specs for args, in same shape as args
kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs
Returns:
args_split: List of sharded args
kwargs_split: List of sharded kwargs
"""
# Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that
# the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec`
# and `kwargs_chunk_spec` specifications. The steps are as follows:
#
    # 1. Use pytree.tree_flatten to flatten each arg and its spec into a 1d array of values.
# To use a running example: suppose our inputs look like
#
# args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None)
# (kwargs not shown but it's a similar process)
#
# Then for this step we would end up with
#
# args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None)
#
# 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2
#
# args = ([[A, A], [B, B], [C_1, C_2]], [D, D])
#
# 3. Rotate the nesting order such that chunks are the outer dimension
#
# args_chunks = [
# ([A, B, C_1], D),
# ([A, B, C_2], D),
# ]
#
# 4. Unflatten each chunk according to the spec
#
# args_chunks = [
# ([A, [B, C_1]], D),
# ([A, [B, C_2]], D),
# ]
# TODO: _debug_mask_minibatches
# Handle the case where kwargs is None
if kwargs is None:
kwargs = {}
# If user did not provide args_chunk_spec or kwargs_chunk_spec, we extend
# their format and use default chunking along dim 0
if args_chunk_spec is None:
args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args)
if kwargs_chunk_spec is None:
kwargs_chunk_spec = dict.fromkeys(kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM))
args_split_dict = _shard_dict_of_args(
dict(enumerate(args)),
dict(enumerate(args_chunk_spec)),
chunks,
)
real_num_chunks = len(args_split_dict)
kwargs_split = _shard_dict_of_args(
kwargs,
kwargs_chunk_spec,
real_num_chunks,
)
if len(kwargs_split) < real_num_chunks:
# In case kwargs are sharded into less chunks
# e.g. when `args` has no tensor, just values
real_num_chunks = len(kwargs_split)
# Re-shard args
args_split_dict = _shard_dict_of_args(
dict(enumerate(args)),
dict(enumerate(args_chunk_spec)),
real_num_chunks,
)
if len(args_split_dict) != len(kwargs_split):
raise RuntimeError(
"args and kwargs are split into different number of chunks: "
f"{len(args_split_dict)}, {len(kwargs_split)}"
)
args_split = []
for chunk_args in args_split_dict:
args_split.append(tuple(chunk_args[i] for i in range(len(chunk_args))))
return args_split, kwargs_split
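# Example (illustrative shapes): splitting a mini-batch of 8 along dim 0 into
# 4 micro-batches, with a non-tensor kwarg replicated to every chunk.
#   >>> args_split, kwargs_split = split_args_kwargs_into_chunks(
#   ...     (torch.randn(8, 16),), {"scale": 2.0}, chunks=4
#   ... )
#   >>> len(args_split), args_split[0][0].shape
#   (4, torch.Size([2, 16]))
#   >>> kwargs_split[0]
#   {'scale': 2.0}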
def merge_chunks(
chunks: List[Any],
chunk_spec,
):
"""
Given a list of chunks, merge them into a single value according to
the chunk spec.
Args:
chunks: list of chunks
chunk_spec: Chunking spec for the chunks
Returns:
value: Merged value
"""
# This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
# steps are similar to the steps in that function but in reverse. Given the
# input values:
#
# chunks = [
# ([A, [B, C_1]], D),
# ([A, [B, C_2]], D),
# ]
# args_spec = ([None, [None, TensorChunkSpec]], None)
#
# 1. Flatten the chunks according to the chunk_spec
#
# chunks_flat = [
# ([A, B, C_1], D),
# ([A, B, C_2], D),
# ]
#
# 2. Rotate the nesting order such that chunks are the inner dimension
#
# value_inner = ([A, B, [C_1, C_2]], D)
#
# 3. Concatenate sharded arguments
#
# value_combined = ([A, B, C], D)
#
# 4. Unflatten the combined args given the spec
#
# value = ([A, [B, C]], D)
# Preliminary: flatten the chunk spec
if chunk_spec is not None:
spec_flattened, flatten_spec = tree_flatten(chunk_spec)
else:
# If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields
# We obtain the output structure by flattening chunk 0 and generate the chunk_spec
chunk0_flat, flatten_spec = tree_flatten(chunks[0])
spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)
# Stage 1: flatten chunks
# chunks_flattened : [num chunks, num args]
chunks_flattened = []
for chunk in chunks:
chunk_flattened, _ = tree_flatten(chunk)
if len(chunk_flattened) != len(spec_flattened):
raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")
chunks_flattened.append(chunk_flattened)
# Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and
# concatenate sharded operands
# args_flattened : [num args]
args_flattened = []
for arg_idx, arg in enumerate(spec_flattened):
if isinstance(arg, TensorChunkSpec):
partial_values = [
chunks_flattened[chunk_idx][arg_idx]
for chunk_idx in range(len(chunks_flattened))
]
if _debug_mask_minibatches:
# Infer size of individual chunks by running `tensor_split` again
overall_shape = partial_values[0].shape
for val in partial_values[1:]:
assert val.shape == overall_shape
meta_chunks = torch.tensor_split(
torch.empty(*overall_shape, device="meta"),
sections=len(partial_values),
dim=arg.split_dim,
)
values_to_cat = []
chunk_start_idx = 0
assert len(partial_values) == len(meta_chunks)
for partial_value, meta_chunk in zip(partial_values, meta_chunks):
chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)
slice_indices = [slice(None, None, None)] * partial_value.ndim
slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx)
sliced = partial_value[slice_indices]
values_to_cat.append(sliced)
chunk_start_idx = chunk_end_idx
else:
values_to_cat = partial_values
args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
elif isinstance(arg, _CustomReducer):
reduced_val = arg.init_value
for chunk_idx in range(len(chunks_flattened)):
reduced_val = arg.reduce_fn(
reduced_val, chunks_flattened[chunk_idx][arg_idx]
)
args_flattened.append(reduced_val)
else:
value = chunks_flattened[0][arg_idx]
for chunk_idx in range(1, len(chunks_flattened)):
assert chunks_flattened[chunk_idx][arg_idx] == value
args_flattened.append(value)
# Stage 4: Unflatten combined args
return tree_unflatten(args_flattened, flatten_spec)
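
A round-trip sketch of the two functions above (illustrative shapes; the import path assumes this file lands at torch/distributed/pipelining/microbatch.py):

import torch
from torch.distributed.pipelining.microbatch import (
    TensorChunkSpec,
    merge_chunks,
    split_args_kwargs_into_chunks,
)

full = torch.arange(8.0).reshape(8, 1)
args_split, _ = split_args_kwargs_into_chunks((full,), None, chunks=4)

# Each chunk is a 1-tuple holding a (2, 1) slice; merging the per-chunk values
# back along dim 0 reconstructs the original tensor.
chunk_outputs = [chunk[0] for chunk in args_split]
merged = merge_chunks(chunk_outputs, TensorChunkSpec(0))
assert torch.equal(merged, full)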

File diff suppressed because it is too large

File diff suppressed because it is too large