# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

import csv
import itertools
import logging
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from typing import (
    Any,
    Callable,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp.fully_shard import FSDPModule, UnshardHandle
from torch.profiler import record_function

from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
from .stage import _PipelineStageBase


if TYPE_CHECKING:
    from torch.distributed import Work

__all__ = [
    "get_schedule_class",
    "PipelineScheduleSingle",
    "PipelineScheduleMulti",
    "Schedule1F1B",
    "ScheduleFlexibleInterleaved1F1B",
    "ScheduleGPipe",
    "ScheduleInterleaved1F1B",
    "ScheduleLoopedBFS",
    "ScheduleInterleavedZeroBubble",
]

logger = logging.getLogger(__name__)


class _ComputationType(Enum):
    # TODO(whc) rename to _ActType?
    FORWARD = 1
    BACKWARD = 2
    WEIGHT = 3
    UNSHARD = 4
    RESHARD = 5
    SEND_F = 6
    RECV_F = 7
    SEND_B = 8
    RECV_B = 9

    def __str__(self):
        str_map = {
            _ComputationType.FORWARD: "F",
            _ComputationType.BACKWARD: "B",
            _ComputationType.WEIGHT: "W",
            _ComputationType.UNSHARD: "UNSHARD",
            _ComputationType.RESHARD: "RESHARD",
            _ComputationType.SEND_F: "SEND_F",
            _ComputationType.RECV_F: "RECV_F",
            _ComputationType.SEND_B: "SEND_B",
            _ComputationType.RECV_B: "RECV_B",
        }
        return str_map[self]

    @staticmethod
    def from_str(action):
        if action == "F":
            return _ComputationType.FORWARD
        elif action == "B":
            return _ComputationType.BACKWARD
        elif action == "W":
            return _ComputationType.WEIGHT
        elif action == "UNSHARD":
            return _ComputationType.UNSHARD
        elif action == "RESHARD":
            return _ComputationType.RESHARD
        elif action == "SEND_F":
            return _ComputationType.SEND_F
        elif action == "RECV_F":
            return _ComputationType.RECV_F
        elif action == "SEND_B":
            return _ComputationType.SEND_B
        elif action == "RECV_B":
            return _ComputationType.RECV_B
        else:
            raise RuntimeError(f"Invalid computation type {action}")
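
    # `__str__` and `from_str` are inverses for the supported tokens. For example:
    #   str(_ComputationType.FORWARD) == "F"
    #   _ComputationType.from_str("SEND_B") is _ComputationType.SEND_B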


FORWARD = _ComputationType.FORWARD
BACKWARD = _ComputationType.BACKWARD
WEIGHT = _ComputationType.WEIGHT
UNSHARD = _ComputationType.UNSHARD
RESHARD = _ComputationType.RESHARD
SEND_F = _ComputationType.SEND_F
RECV_F = _ComputationType.RECV_F
SEND_B = _ComputationType.SEND_B
RECV_B = _ComputationType.RECV_B

# Convenience shorthand for compute actions only since they are used in 'simple schedule format'
F = FORWARD
B = BACKWARD
W = WEIGHT

# Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index)
_action_regex = re.compile(
    r"(\d+)(F|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)"
)


class _Action(NamedTuple):
    stage_index: int
    computation_type: _ComputationType
    microbatch_index: Optional[int] = None

    def __repr__(self):
        repr = str(self.stage_index)
        repr += str(self.computation_type)
        if self.microbatch_index is not None:
            repr += str(self.microbatch_index)
        return repr

    @staticmethod
    def from_str(str):
        """
        Reverse of __repr__

        String should be formatted as [stage][action type][(microbatch)]
        e.g. `2F0`, `1UNSHARD`, `3SEND_F1`
        """
        if match := _action_regex.match(str):
            stage_index, computation_type, microbatch_index = match.groups()
            return _Action(
                int(stage_index),
                _ComputationType.from_str(computation_type),
                int(microbatch_index) if len(microbatch_index) else None,
            )
        elif str == "" or str.isspace():
            return None
        raise RuntimeError(
            f"Invalid action string: {str}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0"
        )
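
# For example, parsing and printing round-trip:
#   repr(_Action(2, FORWARD, 0)) == "2F0"
#   _Action.from_str("2F0") == _Action(2, FORWARD, 0)
#   _Action.from_str("1UNSHARD") == _Action(1, UNSHARD, None)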


def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -> str:
    """
    Formats the pipeline order in a timestep (row) x rank (column) grid of actions
    and returns the formatted string
    """
    # Calculate the maximum number of steps across all ranks
    num_steps = max(len(actions) for actions in pipeline_order.values())
    step_labels = [
        "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
    ]
    # Sorting the dictionary by keys and retrieving values in that order
    rank_actions = [
        pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
    ]
    # Transpose the list of lists (rows to columns)
    transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
    # Generate column labels for ranks
    num_ranks = len(pipeline_order)
    rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
    # Calculate the maximum length of each column, considering labels
    max_lengths = [
        max(len(str(item)) if item is not None else 0 for item in col)
        for col in zip(step_labels, *transposed_actions)
    ]
    # Format the header row with rank labels
    header_row = " " * (len(step_labels[0]) + 2) + " ".join(
        f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
    )
    # Format each row with its corresponding label
    formatted_rows = [
        f"{label}: "
        + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
        for label, row in zip(step_labels, transposed_actions)
    ]
    # Join the rows into a single string
    formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n"
    return formatted_table
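
# For a 2-rank schedule the returned string looks roughly like (column widths
# depend on the longest action repr in each column):
#            Rank 0  Rank 1
#   Step 0:  0F0
#   Step 1:  0F1     1F0
#   ...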


def _validate_pipeline_order(
    pipeline_order: Dict[int, List[Optional[_Action]]],
    num_microbatches: int,
    num_stages: int,
    enable_zero_bubble: bool = False,
):
    """
    pipeline_order[rank] = [(stage_index, computation_type, microbatch_index), ...]
    Validating that the pipeline order follows the rules:
    1. Forward action for a microbatch must be before the Backward action for that microbatch
    2. Recv for a microbatch must be before the send for that microbatch
    3. Microbatch index is handled in sequential order for each stage
    4. A later stage cannot operate on a microbatch before any of the previous stages have operated on it
    5. Same microbatch cannot be handled in the same time step across ranks
    """
    # microbatch_index: (current computation type, current stage)
    microbatch_process_info: Dict[int, Tuple[_ComputationType, int]] = {}
    max_timestep = max(len(rank_list) for rank_list in pipeline_order.values())
    for timestep in range(max_timestep):
        error_msg: List[str] = []
        current_timestep_actions = []
        for rank in range(len(pipeline_order)):
            action = (
                pipeline_order[rank][timestep]
                if timestep < len(pipeline_order[rank])
                else None
            )

            if action is not None:
                computation_type = action.computation_type
                if computation_type != _ComputationType.WEIGHT:
                    current_timestep_actions.append(action)

        # TODO: enable this
        # if len(current_timestep_actions) == 0:
        #     error_msg.append(
        #         "All actions were None, there is an unnecessary gap in the schedule"
        #     )

        # Ensure that no microbatch is operated on twice in current_timestep_actions
        unique_microbatch_indices = {
            action.microbatch_index for action in current_timestep_actions
        }
        if len(unique_microbatch_indices) != len(current_timestep_actions):
            error_msg.append(
                "Duplicate microbatch index found in current_timestep_actions"
            )

        for action in current_timestep_actions:
            stage_index = action.stage_index
            computation_type = action.computation_type
            mb_index = action.microbatch_index
            assert (
                mb_index is not None
            ), "All currently supported action types require valid microbatch_index"
            if mb_index >= num_microbatches:
                error_msg.append(f"Microbatch index {mb_index} out of range")

            # first microbatch
            if mb_index not in microbatch_process_info:
                if computation_type != _ComputationType.FORWARD or stage_index != 0:
                    error_msg.append(f"Incorrect start for microbatch {mb_index}")
                microbatch_process_info[mb_index] = (computation_type, stage_index)
            else:
                # if the microbatch is included, check that the current stage is right after prev
                prev_computation, prev_stage = microbatch_process_info[mb_index]

                if prev_computation == _ComputationType.FORWARD:
                    if prev_stage == num_stages - 1:
                        expected_stage = num_stages - 1
                        expected_computation = _ComputationType.BACKWARD
                    else:
                        expected_stage = prev_stage + 1
                        expected_computation = _ComputationType.FORWARD
                elif prev_computation == _ComputationType.BACKWARD:
                    if prev_stage == 0:
                        error_msg.append(
                            f"[{mb_index=}] already finished backward computation"
                        )
                        break
                    else:
                        expected_stage = prev_stage - 1
                        expected_computation = _ComputationType.BACKWARD
                else:
                    raise ValueError(
                        f"Computation type {prev_computation} not supported"
                    )

                if expected_computation is not None:
                    if expected_computation != computation_type:
                        error_msg.append(
                            f"[{mb_index=}] {expected_computation=} VS. actual {computation_type=}"
                        )

                if expected_stage != stage_index:
                    error_msg.append(
                        f"[{mb_index=}] {expected_stage=} VS. actual {stage_index=}"
                    )

                microbatch_process_info[mb_index] = (
                    expected_computation,
                    expected_stage,
                )

        if not enable_zero_bubble:
            if len(error_msg) != 0:
                raise RuntimeError(
                    f"Error at timestep {timestep}: " + ",".join(error_msg)
                )
            return

        for rank in range(len(pipeline_order)):
            backward_steps: Set[Tuple[int, int]] = set()
            weight_steps: Set[Tuple[int, int]] = set()

            for action in pipeline_order[rank]:
                if action is None:
                    continue

                stage_index = action.stage_index
                computation_type = action.computation_type
                mb_index = action.microbatch_index
                if computation_type == _ComputationType.BACKWARD:
                    if mb_index is not None:
                        backward_steps.add((mb_index, stage_index))
                elif computation_type == _ComputationType.WEIGHT:
                    if (mb_index, stage_index) not in backward_steps:
                        error_msg.append(
                            f"{mb_index=}, {stage_index=} Weight happened before bwd"
                        )
                    if (mb_index, stage_index) in weight_steps:
                        error_msg.append(
                            f"{mb_index=}, {stage_index=} Duplicated weight step"
                        )
                    if mb_index is not None:
                        weight_steps.add((mb_index, stage_index))

            if len(backward_steps) != len(weight_steps):
                error_msg.append("Length weight steps != Length bwd steps")

        if len(error_msg) != 0:
            raise RuntimeError(f"Error at timestep {timestep}: " + ",".join(error_msg))
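
# As a minimal sketch, with num_stages=2 and num_microbatches=1 the following
# order satisfies the rules above (F starts at stage 0, stages advance one at a
# time, and no microbatch appears twice in the same timestep):
#   pipeline_order = {
#       0: [_Action(0, F, 0), None, None, _Action(0, B, 0)],
#       1: [None, _Action(1, F, 0), _Action(1, B, 0), None],
#   }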


class _PipelineSchedule(ABC):
    def __init__(
        self,
        n_microbatches: int,
        loss_fn: Optional[Callable[..., torch.Tensor]] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        # From arguments
        self._n_microbatches = n_microbatches
        self._loss_fn = loss_fn
        # Chunking specification for positional inputs. (default: `None`)
        self._args_chunk_spec = args_chunk_spec
        # Chunking specification for keyword inputs. (default: `None`)
        self._kwargs_chunk_spec = kwargs_chunk_spec
        self._output_merge_spec = output_merge_spec
        # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs.
        # They are used to convert batch to microbatches in `step(x)`. See
        # `TensorChunkSpec` for helper methods for creating them.

        # Derived
        self._has_backward = self._loss_fn is not None

        # Holds the losses for each microbatch.
        self._internal_losses: List[torch.Tensor] = []
        logger.info("Using %s", self.__class__.__name__)

    def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
        if stage.is_last and self._has_backward:
            loss = self._compute_loss(output, target_mbs[mb_index])  # type: ignore[index]
            self._internal_losses.append(loss)

    def _maybe_get_loss(self, stage, mb_index):
        valid_index = 0 <= mb_index < len(self._internal_losses)
        if stage.is_last and self._has_backward and valid_index:
            return self._internal_losses[mb_index]
        elif len(self._internal_losses) != 0 and not valid_index:
            raise RuntimeError(
                f"Loss for microbatch {mb_index} is not available. "
                f"Available losses for microbatches: {self._internal_losses}"
            )
        else:
            return None

    def _update_losses(self, stages, losses):
        """
        Update the losses to those in the internal state
        """
        # If `stages` is not a list, turn it into a list
        if not isinstance(stages, list):
            stages = [stages]
        contains_last_stage = any(stage.is_last for stage in stages)

        # Return losses if there is a container passed in
        if contains_last_stage and losses is not None:
            if len(self._internal_losses) != self._n_microbatches:
                raise RuntimeError(
                    f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}"
                )

            # Clean external container first
            losses.clear()
            # Copy internal losses to external container
            losses.extend(self._internal_losses)

        self._internal_losses.clear()

    @abstractmethod
    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the schedule
        implementation.

        Args:
            microbatches: list of microbatch args.
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """
        raise NotImplementedError

    def _check_inputs(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Pre-process/check inputs
        """

        def check_type_and_len(mbs, name: str):
            if not isinstance(mbs, list):
                raise TypeError(f"{name} must be a list but got a {type(mbs)}")
            if len(mbs) != self._n_microbatches:
                raise ValueError(
                    f"Expecting {self._n_microbatches} {name} but got {len(mbs)}"
                )

        if arg_mbs is not None:
            check_type_and_len(arg_mbs, "arg_mbs")
        else:
            arg_mbs = [()] * self._n_microbatches

        if kwarg_mbs is not None:
            check_type_and_len(kwarg_mbs, "kwarg_mbs")
        else:
            kwarg_mbs = [{}] * self._n_microbatches

        if target_mbs is not None:
            check_type_and_len(target_mbs, "target_mbs")

        if losses is not None:
            if not isinstance(losses, list):
                raise TypeError(f"losses must be a list but got a {type(losses)}")

        return arg_mbs, kwarg_mbs

    def _compute_loss(self, output, target):
        return self._loss_fn(output, target)  # type: ignore[misc]

    def _split_inputs(
        self,
        args: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Splits a full-batch input into chunks (i.e. microbatches) and returns
        the chunks
        """
        if args or kwargs:
            args_split, kwargs_split = split_args_kwargs_into_chunks(
                args,
                kwargs,
                self._n_microbatches,
                self._args_chunk_spec,
                self._kwargs_chunk_spec,
            )
            return args_split, kwargs_split
        else:
            # Empty inputs (e.g. when called on middle stages)
            # Return a list of empty tuples/dicts with matching length as chunks
            return [()] * self._n_microbatches, [{}] * self._n_microbatches

    def _merge_outputs(self, output_chunks: List[Any]) -> Any:
        """
        Merge output chunks back to a batch state.
        If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim).
        """
        return merge_chunks(
            output_chunks,
            self._output_merge_spec,
        )


def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None):
    """
    Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top.
    """
    if len(p2p_ops) == 0:
        return None
    desc_str = f"{desc}, " if desc else ""
    logger.debug("batch_p2p %s%s", desc_str, p2p_ops)
    return dist.batch_isend_irecv(p2p_ops).pop()


def _sorted_batch_p2p(
    p2p_ops: List[dist.P2POp], desc: Optional[str] = None
) -> Dict[int, dist.Work]:
    """
    Sorts the list of P2P ops by the peer rank, and then calls
    batch_isend_irecv. Return a dictionary of works by peer rank. This function
    helps us avoid hangs in case of skip connections.
    """
    # Arrange p2p_ops by peer rank:
    # int is the peer rank;
    # List is the list of ops towards the peer
    ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list)
    work_by_peer: Dict[int, dist.Work] = {}
    if len(p2p_ops) == 0:
        return work_by_peer

    # Classify the ops by peer rank
    for op in p2p_ops:
        ops_by_peer[op.peer].append(op)

    # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs)
    for peer, ops in sorted(ops_by_peer.items()):
        work_by_peer[peer] = _batch_p2p(ops, desc=desc)

    return work_by_peer
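
# For example, given ops targeting peers 3 and 1, the result is roughly
# {1: <Work for ops to peer 1>, 3: <Work for ops to peer 3>}, and the batches
# are issued in ascending peer order so that any pair of ranks agrees on the
# order of their mutual sends/recvs.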


class PipelineScheduleSingle(_PipelineSchedule):
    """
    Base class for single-stage schedules.
    Implements the `step` method.
    Derived classes should implement `_step_microbatches`.
    """

    def __init__(
        self,
        stage: _PipelineStageBase,
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        # Init parent
        super().__init__(
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        # Self attributes
        self._stage = stage
        self._num_stages = stage.num_stages
        # Set the same has_backward flag for stage object
        self._stage.has_backward = self._has_backward

        # TODO: later replace this with lazy shape inference during forward
        # Prepare forward send/recv infrastructure for stage
        stage._prepare_forward_infra(n_microbatches)
        if self._has_backward:
            stage._prepare_backward_infra(n_microbatches)

    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """

        # Clean per iteration
        self._stage.clear_runtime_states()

        # Split inputs into microbatches
        args_split, kwargs_split = self._split_inputs(args, kwargs)

        # Split target into microbatches
        if target is not None:
            targets_split = list(torch.tensor_split(target, self._n_microbatches))
        else:
            targets_split = None

        # Run microbatches
        self._step_microbatches(args_split, kwargs_split, targets_split, losses)

        # Return merged results per original format
        if self._stage.is_last:
            return self._merge_outputs(self._stage.output_chunks)
        else:
            return None


class _ScheduleForwardOnly(PipelineScheduleSingle):
    """
    The forward-only schedule.
    Will go through all the microbatches and perform only the forward pass
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule
        """
        if target_mbs is not None or losses is not None:
            raise RuntimeError(
                "Forward-only schedule does not support loss computation"
            )

        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Delay send waits
        fwd_sends_to_wait: List[dist.Work] = []

        # Run microbatches
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                ops = self._stage.get_fwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_recv")
                for work in works.values():
                    work.wait()

                self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]

                ops = self._stage.get_fwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_send")
                fwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)

        # Wait for all forward sends to finish
        # This should not have performance impact because by the time the first
        # backward arrives all the forward sends should have been finished.
        for work in fwd_sends_to_wait:
            work.wait()


class ScheduleGPipe(PipelineScheduleSingle):
    """
    The GPipe schedule.
    Will go through all the microbatches in a fill-drain manner.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the GPipe schedule.

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Delay send waits
        fwd_sends_to_wait: List[dist.Work] = []

        # Run microbatches
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                ops = self._stage.get_fwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_recv")
                for work in works.values():
                    work.wait()

                output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]

                ops = self._stage.get_fwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_send")
                fwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)

            self._maybe_compute_loss(self._stage, output, target_mbs, i)

        # Wait for all forward sends to finish
        # This should not have performance impact because by the time the first
        # backward arrives all the forward sends should have been finished.
        for work in fwd_sends_to_wait:
            work.wait()

        # No loss function, no need to run backward
        if not self._has_backward:
            return

        # Run backward
        # Delay send waits
        bwd_sends_to_wait: List[dist.Work] = []
        for i in range(self._n_microbatches):
            with record_function(f"Backward {i}"):
                ops = self._stage.get_bwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_recv")
                for work in works.values():
                    work.wait()

                loss = self._maybe_get_loss(self._stage, i)
                self._stage.backward_one_chunk(i, loss=loss)

                ops = self._stage.get_bwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_send")
                bwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i)

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)

        # Wait for all backward sends to finish
        for work in bwd_sends_to_wait:
            work.wait()
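
# A minimal usage sketch (assuming `stage` is the pipeline stage owned by this
# rank, built elsewhere, and `x`/`y` are this rank's whole-batch input and target):
#
#   loss_fn = torch.nn.functional.cross_entropy
#   schedule = ScheduleGPipe(stage, n_microbatches=4, loss_fn=loss_fn)
#   losses: List[torch.Tensor] = []
#   out = schedule.step(x, target=y, losses=losses)
#
# `out` is the merged output on the last stage and None on all other stages.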


class Schedule1F1B(PipelineScheduleSingle):
    """
    The 1F1B schedule.
    Will perform one forward and one backward on the microbatches in steady state.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the 1F1B schedule.

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Last stage has 1 warmup, second-to-last 2 warmups, ...
        # first stage `num_stages` warmups
        warmup_chunks = min(
            self._n_microbatches,
            self._num_stages - self._stage.stage_index,
        )
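        # For example, with 4 stages and 8 microbatches: stage 0 runs
        # min(8, 4 - 0) = 4 warmup forwards, while the last stage (index 3)
        # runs min(8, 4 - 3) = 1 before entering the 1B1F steady state.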

        # Chunk counters
        fwd_mb_index = 0
        bwd_mb_index = 0
        weight_stage_mb_index = 0

        # Warmup phase
        send_work = None
        fwd_sends = []
        for _ in range(warmup_chunks):
            # Receive activations
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
            if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"):
                recv_work.wait()

            # Compute
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Clear previous chunk's forward sends (hopefully they have well
            # finished, otherwise, we are heavily communication bound, in which
            # case it doesn't create a lot of benefit to compute next chunk
            # eagerly either)
            if send_work:
                send_work.wait()

            # Send activations
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            if fwd_mb_index != warmup_chunks - 1:
                # Safe to fire
                send_work = _batch_p2p(fwd_sends, desc="fwd_send")
            # otherwise:
            #   the last forward send is left to be fused with the first 1B in the 1B1F phase below

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
            fwd_mb_index += 1

        # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below.

        # 1B1F phase
        while True:  # Don't worry, we have a break inside
            # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)

            # Now, we need to fire the fwd_sends and bwd_recvs together
            if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"):
                fuse_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(bwd_mb_index, loss=loss)

            # Get the bwd send ops, but don't fire, to be fused with the 1F below
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            bwd_mb_index += 1

            if fwd_mb_index == self._n_microbatches:
                # We are done with 1B1F, so break with some left-over bwd_sends
                break

            # We prepare 1F of the `1B1F`
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)

            # Fuse it with bwd_sends above
            if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"):
                fuse_work.wait()

            # Now do the fwd
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)

            # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around)
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            fwd_mb_index += 1

        # Remember we still have some bwd_sends left over after the break? Now it is time to fire them
        send_work = _batch_p2p(bwd_sends, desc="bwd_send")

        # Cooldown
        while bwd_mb_index < self._n_microbatches:
            # prepare bwd recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
            if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"):
                recv_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(bwd_mb_index, loss=loss)

            # Clear previous chunk's backward sends (hopefully they have well finished)
            if send_work:
                send_work.wait()

            # Get the bwd send ops, fire them
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            send_work = _batch_p2p(bwd_sends, desc="bwd_send")
            bwd_mb_index += 1

        # Wait for the last backward send to finish
        if send_work:
            send_work.wait()

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)


def _add_unshard_reshard(
    compute_actions: List[Optional[_Action]],
    max_active_stages: int = 3,
) -> List[_Action]:
    """Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP.

    UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation.
    RESHARD does the opposite, releasing memory (but doing no communication)

    We abandon the "timestep lock" during lowering

    max_active_stages controls how many prefetches we allow. It should be measured in mb and tuneable but in practice
    3 stages is probably the thing we want?
    (to account for having one f and one b active, and something else prefetching?)
    """

    def next_stage_indices(
        count: int, next_actions: List[Optional[_Action]]
    ) -> List[int]:
        """Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute."""
        seen: Set[int] = set()
        ret: List[int] = []

        for a in next_actions:
            if a is not None and a.stage_index not in seen:
                seen.add(a.stage_index)
                ret.append(a.stage_index)
                if len(ret) == count:
                    break
        return ret

    active_stages: Set[int] = set()
    fsdp_aware_actions: List[_Action] = []

    def _unshard(stage_index: int):
        active_stages.add(stage_index)
        fsdp_aware_actions.append(_Action(stage_index, UNSHARD, None))

    def _reshard(stage_index: int):
        active_stages.remove(stage_index)
        fsdp_aware_actions.append(_Action(stage_index, RESHARD, None))

    for i, action in enumerate(compute_actions):
        if action is None:
            continue

        # We prefetch the next N stages we'll see, dropping existing stages to make room
        next_n = next_stage_indices(max_active_stages, compute_actions[i:])
        # Fetch needs to be ordered correctly, so don't use a set
        fetch = list(filter(lambda s: s not in active_stages, next_n))
        # Unclear what the best policy is for eviction, but we can maintain order so we do
        evict = list(filter(lambda s: s not in next_n, active_stages))

        # logger.debug(
        #     "_add_unshard_reshard Step %d active: %s fetch %s, evict %s",
        #     i,
        #     active_stages,
        #     fetch,
        #     evict,
        # )

        for stage in evict:
            _reshard(stage)
        for stage in fetch:
            _unshard(stage)
        fsdp_aware_actions.append(action)

    return fsdp_aware_actions
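
# For example (with the default max_active_stages), lowering a two-stage
# compute list roughly yields:
#   _add_unshard_reshard([_Action(0, F, 0), _Action(1, F, 0)])
#   -> [0UNSHARD, 1UNSHARD, 0F0, 0RESHARD, 1F0]
# Stage 0 is resharded as soon as it no longer appears among the upcoming
# compute actions, while stage 1 stays unsharded through its forward.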


def _add_send_recv(
    compute_actions: Dict[int, List[_Action]],
    stage_to_rank: Callable[[int], int],
    num_stages: int,
) -> Dict[int, List[_Action]]:
    comm_actions: Dict[int, List[_Action]] = {rank: [] for rank in compute_actions}

    def _has_comms(action: _Action) -> bool:
        if action.computation_type == F:
            return action.stage_index != num_stages - 1
        elif action.computation_type == B:
            return action.stage_index != 0
        return False

    def _get_comms(action: _Action) -> Tuple[_Action, _Action]:
        assert _has_comms(action), f"{action} is not a valid comm action"
        stage_idx = action.stage_index
        ctype = action.computation_type
        mb_idx = action.microbatch_index
        send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx)
        recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1
        recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx)
        return send, recv

    def _ready_to_schedule(
        action: Optional[_Action], prev_actions: List[_Action]
    ) -> bool:
        """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
        This helps ensure a sane (non-hanging) ordering of sends and recvs.
        But it also means we might not be able to schedule our next compute action yet.
        """
        if action is None:
            return True
        elif action.computation_type == F and not action.stage_index == 0:
            expected_recv = _Action(
                action.stage_index,
                RECV_F if action.computation_type == F else RECV_B,
                action.microbatch_index,
            )
            return expected_recv in prev_actions
        elif action.computation_type == B and not action.stage_index == num_stages - 1:
            expected_recv = _Action(
                action.stage_index,
                RECV_F if action.computation_type == F else RECV_B,
                action.microbatch_index,
            )
            return expected_recv in prev_actions
        else:
            return True

    while compute_actions:
        progress = False
        # go in order of ranks even if dict keys aren't ordered
        for rank in range(len(compute_actions)):
            assert len(compute_actions[rank]) > 0
            action = compute_actions[rank][0]

            if not _ready_to_schedule(action, comm_actions[rank]):
                continue

            if action is not None:
                comm_actions[rank].append(action)
                if _has_comms(action):
                    send, recv = _get_comms(action)
                    # TODO we can avoid send/recv if the 2 stages are on the same rank.
                    # should we avoid that in the runtime or here?
                    comm_actions[rank].append(send)
                    comm_actions[stage_to_rank(recv.stage_index)].append(recv)

            compute_actions[rank].pop(0)
            if len(compute_actions[rank]) == 0:
                del compute_actions[rank]
            progress = True
        assert progress, "Malformed compute schedule, can't schedule sends/recvs"
    return comm_actions
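
# For example, with one stage per rank (stage_to_rank=lambda s: s) and
# num_stages=2, the compute-only input {0: [0F0], 1: [1F0]} lowers to
#   {0: [0F0, 0SEND_F0], 1: [1RECV_F0, 1F0]}
# Rank 1's RECV_F is inserted by rank 0's send, and 1F0 is only scheduled
# once that matching recv is already in place.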


class PipelineScheduleMulti(_PipelineSchedule):
    """
    Base class for multi-stage schedules.
    Implements the `step` method.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
        stage_index_to_group_rank: Optional[Dict[int, int]] = None,
        use_full_backward: bool = True,
    ):
        if len(stages) <= 1:
            raise ValueError(
                f"Multi-stage schedule expects at least two stages but got {len(stages)}"
            )
        # Init parent
        super().__init__(
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )
        # Self attributes
        self._stages = stages
        self._num_stages = stages[0].num_stages
        self.pp_group_size = stages[0].group_size
        self.rank = stages[0].group_rank
        # Set the pipeline stage states
        if stage_index_to_group_rank is not None:
            for stage in self._stages:
                stage.stage_index_to_group_rank = stage_index_to_group_rank
        self.stage_index_to_group_rank = stages[0].stage_index_to_group_rank

        # Set the same has_backward flag for stage object
        for stage in self._stages:
            stage.has_backward = self._has_backward

        self._should_compute_loss = (
            lambda stage: stage.is_last and self._loss_fn is not None
        )

        # This will be set during init of derived schedules
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        self.use_full_backward = use_full_backward

        # TODO: later replace this with lazy shape inference during forward
        # Prepare forward send/recv infrastructure for stage
        for stage in self._stages:
            stage._prepare_forward_infra(n_microbatches)
            if self._has_backward:
                stage._prepare_backward_infra(n_microbatches)

    def _dump_csv(self, filename):
        """Dump a CSV representation of the schedule into a file with the provided filename."""
        with open(filename, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            for rank in self.pipeline_order:
                writer.writerow(self.pipeline_order[rank])
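
    # The CSV has one row per rank; each cell is an action's repr such as
    # "0F0" or "1B0", and a None action is written as an empty cell.
    # A 2-rank, 2-microbatch fill-drain schedule might look roughly like:
    #   0F0,0F1,,0B0,0B1
    #   ,1F0,1F1,1B0,1B1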

    def _validate_schedule(self):
        # TODO(whc) this should be merged with the logic in test_schedule.py#L453-L554
        def _validate_rank_actions(
            actions: Dict[int, List[_Action | None]],
            num_stages: int,
            num_microbatches: int,
        ):
            # We will count all the actions per stage and ensure they happen in a valid order
            # (e.g. F before B before W for a given microbatch)
            stage_actions: Dict[int, Dict[_ComputationType, Set]] = {
                stage_id: {
                    F: set(),
                    B: set(),
                    W: set(),
                }
                for stage_id in range(num_stages)
            }
            for rank in actions:
                for action in actions[rank]:
                    if action is None:
                        continue
                    assert isinstance(
                        action, _Action
                    ), f"Got an invalid action: {action}, expected instance of _Action"
                    s_id = action.stage_index
                    ctype = action.computation_type
                    mb_id = action.microbatch_index
                    if ctype == F:
                        stage_actions[s_id][F].add(mb_id)
                    elif ctype == B:
                        assert (
                            mb_id in stage_actions[s_id][F]
                        ), f"Running Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
                        stage_actions[s_id][B].add(mb_id)
                    elif ctype == W:
                        assert (
                            not self.use_full_backward
                        ), "Schedule contains 'W' actions, but is configured to use full backward"
                        assert (
                            mb_id in stage_actions[s_id][B]
                        ), f"Running Weight for stage {s_id}, microbatch {mb_id} without first running Backward"
                        stage_actions[s_id][W].add(mb_id)

            for s_id in stage_actions:
                for ctype in (F, B, W):
                    stage_mb = len(stage_actions[s_id][ctype])
                    assert (
                        stage_mb == num_microbatches
                    ), f"Got {stage_mb} {ctype} microbatches for stage {s_id}, expected {num_microbatches}"

        assert (
            len(self.pipeline_order) == self.pp_group_size
        ), f"Schedule has incorrect number of ranks - expected {self.pp_group_size}, actual {len(self.pipeline_order)}"
        for rank in range(self.pp_group_size):
            assert (
                rank in self.pipeline_order
            ), f"Schedule is missing actions for rank {rank}"
        _validate_rank_actions(
            self.pipeline_order,
            self._num_stages,
            self._n_microbatches,
        )

    def _load_csv(self, filename, format="compute_only"):
        """Load a CSV representation of the schedule from a file with the provided filename.
        This API will most likely get renamed/refactored so is marked as internal for now.

        format must be "compute_only" for PipelineScheduleMulti
        """
        assert format == "compute_only"
        with open(filename, newline="") as csvfile:
            reader = csv.reader(csvfile)
            for rank, row in enumerate(reader):
                self.pipeline_order[rank] = [_Action.from_str(s) for s in row]
        self._validate_schedule()

    def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """

        # Clean per iteration
        for stage in self._stages:
            stage.clear_runtime_states()

        # Split inputs into microbatches
        args_split, kwargs_split = self._split_inputs(args, kwargs)

        # Split target into microbatches
        if target is not None:
            targets_split = list(torch.tensor_split(target, self._n_microbatches))
        else:
            targets_split = None

        # Run microbatches
        self._step_microbatches(args_split, kwargs_split, targets_split, losses)

        # Return merged results per original format
        for stage in self._stages:
            if stage.is_last:
                return self._merge_outputs(stage.output_chunks)
        # Does not contain the last stage
        return None

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Operate on the microbatches for looped schedules (multiple stages on each rank).

        TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
        not support models with skip connections.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Based on the plan in Step 1 created in __init__:
        # 2. Perform communication based on the pipeline_order
        stage_index_to_stage: Dict[int, _PipelineStageBase] = {
            stage.stage_index: stage for stage in self._stages
        }

        # determine prev_rank and next_rank based on which ranks are next to
        # the stages in the pipeline_order
        all_prev_ranks: Set[int] = set()
        all_next_ranks: Set[int] = set()
        for stage_index in stage_index_to_stage.keys():
            # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections)
            if stage_index > 0:
                all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1])
            if stage_index < self._num_stages - 1:
                all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1])

        for time_step, action in enumerate(self.pipeline_order[self.rank]):
            try:
                ops: List[dist.P2POp] = []
                if action is not None:
                    computation_type = action.computation_type
                    mb_index = action.microbatch_index
                    stage_index = action.stage_index
                    assert (
                        mb_index is not None
                    ), "All currently supported action types require valid microbatch_index"
                    if computation_type == _ComputationType.FORWARD:
                        # perform forward computation
                        stage = stage_index_to_stage[stage_index]
                        output = stage.forward_one_chunk(
                            mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
                        )
                        self._maybe_compute_loss(stage, output, target_mbs, mb_index)
                        ops.extend(stage.get_fwd_send_ops(mb_index))
                    elif computation_type == _ComputationType.BACKWARD:
                        # perform backward computation
                        stage = stage_index_to_stage[stage_index]
                        loss = self._maybe_get_loss(stage, mb_index)
                        stage.backward_one_chunk(
                            mb_index, loss=loss, full_backward=self.use_full_backward
                        )
                        ops.extend(stage.get_bwd_send_ops(mb_index))
                    elif computation_type == _ComputationType.WEIGHT:
                        # perform weight update
                        if self.use_full_backward:
                            raise ValueError(
                                f"We detected a weight update in the pipeline schedule, but \
                                {self.use_full_backward=}"
                            )
                        stage = stage_index_to_stage[stage_index]
                        stage.backward_weight_one_chunk(mb_index)
                    else:
                        raise ValueError(f"Unknown computation type {computation_type}")

                # Look at the neighboring ranks for this current timestep and determine whether
                # this current rank needs to do any recv communication
                for prev_rank in all_prev_ranks:
                    prev_rank_ops = self.pipeline_order[prev_rank]
                    prev_rank_action = None
                    if time_step < len(prev_rank_ops):
                        prev_rank_action = prev_rank_ops[time_step]
                    if prev_rank_action is not None:
                        computation_type = prev_rank_action.computation_type
                        mb_index = prev_rank_action.microbatch_index
                        stage_index = prev_rank_action.stage_index
                        assert (
                            mb_index is not None
                        ), "All currently supported action types require valid microbatch_index"
                        # Only handle sends for the forward from a previous rank
                        if computation_type == _ComputationType.FORWARD:
                            # If not the last stage, then receive fwd activations
                            if stage_index + 1 in stage_index_to_stage:
                                # TODO: We are assuming that stage will always receive from stage-1
                                # however that is not necessarily true of get_fwd_recv_ops
                                stage = stage_index_to_stage[stage_index + 1]
                                ops.extend(stage.get_fwd_recv_ops(mb_index))
                        elif (
                            computation_type == _ComputationType.BACKWARD
                            or computation_type == _ComputationType.WEIGHT
                        ):
                            # Previous rank doing backward or weight update has no influence for the current rank forward recv
                            pass
                        else:
                            raise ValueError(
                                f"Unknown computation type {computation_type}"
                            )
                for next_rank in all_next_ranks:
                    next_rank_ops = self.pipeline_order[next_rank]
                    next_rank_action = None
                    if time_step < len(next_rank_ops):
                        next_rank_action = next_rank_ops[time_step]
                    if next_rank_action is not None:
                        computation_type = next_rank_action.computation_type
                        mb_index = next_rank_action.microbatch_index
                        stage_index = next_rank_action.stage_index
                        assert (
                            mb_index is not None
                        ), "All currently supported action types require valid microbatch_index"
                        # Only handle receives for the backwards from a next rank
                        if (
                            computation_type == _ComputationType.FORWARD
                            or computation_type == _ComputationType.WEIGHT
                        ):
                            # Next rank doing forward or weight update has no influence for the current rank backward recv
                            pass
                        elif computation_type == _ComputationType.BACKWARD:
                            # If not the first stage, then receive bwd gradients
                            if stage_index - 1 in stage_index_to_stage:
                                # TODO: We are assuming that stage will always receive from stage+1
                                # however that is not necessarily true of get_bwd_recv_ops
                                stage = stage_index_to_stage[stage_index - 1]
                                ops.extend(stage.get_bwd_recv_ops(mb_index))
                        else:
                            raise ValueError(
                                f"Unknown computation type {computation_type}"
                            )

                # do the communication
                if ops:
                    _batch_p2p(ops).wait()
            except Exception as e:
                logger.error(
                    "[Rank %s] pipeline schedule %s caught the following exception \
                    at time_step %s when running action %s",
                    self.rank,
                    self.__class__.__name__,
                    time_step,
                    action,
                )
                logger.error("%s", _format_pipeline_order(self.pipeline_order))
                raise e
        # Return losses if there is a container passed in
        self._update_losses(self._stages, losses)


class _PipelineScheduleRuntime(PipelineScheduleMulti):
    """
    Provides a simple runtime that requires a 'schedule IR' including specified communication operations.

    Can be instantiated directly by creating _PipelineScheduleRuntime and calling load_csv, or can be
    subclassed and the subclass can be responsible for creating a schedule IR.
    """

    def _load_actions(
        self,
        actions: Dict[int, List[Optional[_Action]]],
        format: str = "compute_only",
    ):
        """
        Given an in-memory representation for a simple compute-only schedule, lower it to a complex schedule including
        communication actions. Stores the schedule in self, and must be called before running step_mo()
        """
        assert (
            self.stage_index_to_group_rank is not None
        ), "stage_index_to_group_rank is required for PipelineScheduleRuntime"
        self.pipeline_order_with_comms: Dict[int, List[_Action]] = {}
        if format == "compute_comms":
            for rank in actions:
                self.pipeline_order_with_comms[rank] = []
                for action in actions[rank]:
                    assert action is not None
                    self.pipeline_order_with_comms[rank].append(action)
            # TODO what level of validation should we offer for compute+comms schedule?
        elif format == "compute_only":
            # Perform schedule lowering
            for rank in actions:
                self.pipeline_order_with_comms[rank] = _add_unshard_reshard(
                    actions[rank]
                )

            self.pipeline_order_with_comms = _add_send_recv(
                self.pipeline_order_with_comms,
                stage_to_rank=lambda s: self.stage_index_to_group_rank[s],
                num_stages=self._num_stages,
            )
        else:
            raise NotImplementedError(f"{format=} is not implemented")

    def _load_csv(self, filename: str, format: str = "compute_only"):
        """Loads a csv in simple format and then lowers it to include communication actions

        format must be either "compute_only" or "compute_comms". If compute_only, the lowering passes
        will automatically be run to generate a compute_comms schedule.
        """
        if format == "compute_only":
            # this will populate self.pipeline_order
            super()._load_csv(filename)
            # this will populate self.pipeline_order_with_comms
            self._load_actions(self.pipeline_order)
        elif format == "compute_comms":
            actions = {}
            with open(filename, newline="") as csvfile:
                reader = csv.reader(csvfile)
                for rank, row in enumerate(reader):
                    actions[rank] = [_Action.from_str(s) for s in row]
            self._load_actions(actions, format=format)
        else:
            raise NotImplementedError(f"{format=} is not implemented")

    def _dump_csv(self, filename: str):
        """Dump a CSV representation of the compute + comms schedule into a file with the provided filename."""
        # TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible
        # that it does not exist if it was created from a compute_comms schedule.
        assert (
            self.pipeline_order_with_comms is not None
        ), "Must initialize compute_comms schedule before dump_csv"
        with open(filename, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            for rank in self.pipeline_order_with_comms:
                writer.writerow(self.pipeline_order_with_comms[rank])

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Operate on the microbatches for looped schedules (multiple stages on each rank).

        TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
        not support models with skip connections.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Based on the plan in Step 1 created in __init__:
        # 2. Perform communication based on the pipeline_order
        stage_index_to_stage: Dict[int, _PipelineStageBase] = {
            stage.stage_index: stage for stage in self._stages
        }

        assert (
            self.pipeline_order_with_comms is not None
        ), "Must call _load_actions() before calling _step_microbatches()"

        # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
        bwd_recv_ops: Dict[Tuple[int, int], Work] = {}
        fwd_recv_ops: Dict[Tuple[int, int], Work] = {}

        # send ops should be waited on before step() exits, mainly for hygiene
        send_ops: List[Work] = []

        # we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages
        unshard_ops: Dict[int, UnshardHandle] = {}
        unsharded_stages = set()

        def _assert_unsharded(stage_idx: int):
            """If an unshard is active for `stage_idx`, wait() it and mark `stage_idx` unsharded."""
            if stage_idx in unshard_ops:
                unshard_ops[stage_idx].wait()
                del unshard_ops[stage_idx]
                unsharded_stages.add(stage_idx)
            assert (
                stage_idx in unsharded_stages
            ), f"Attempted to compute on sharded {stage_idx=}"

        for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]):
            try:
                comp_type = action.computation_type
                mb_index: int = (
                    action.microbatch_index
                    if action.microbatch_index is not None
                    else -1
                )
                assert mb_index >= 0 or comp_type in (
                    UNSHARD,
                    RESHARD,
                ), f"{action=} missing mb_index"
                stage_idx = action.stage_index
                stage = stage_index_to_stage[stage_idx]
                stage_uses_fsdp = isinstance(stage.submod, FSDPModule)

                logger.debug(
                    "_PipelineScheduleRuntime running time_step %d, action %s",
                    time_step,
                    action,
                )

                # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections,
                # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be
                # safe to use instead.
                # However, I was wondering if I should avoid calling batched operators at all in the case that there is
                # only one operator per batch. I could iterate through the 'fwd_send_ops' one by one and run them.
                if comp_type == SEND_F:
                    send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index)))
                elif comp_type == SEND_B:
                    send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index)))
                elif comp_type == RECV_F:
                    assert (
                        stage_idx,
                        mb_index,
                    ) not in fwd_recv_ops, f"Recv twice for {stage_idx=} {mb_index=} without executing forward"
                    fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
                        stage.get_fwd_recv_ops(mb_index)
                    )
                elif comp_type == RECV_B:
                    assert (
                        stage_idx,
                        mb_index,
                    ) not in bwd_recv_ops, f"Recv twice for {stage_idx=} {mb_index=} without executing backward"
                    bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
                        stage.get_bwd_recv_ops(mb_index)
                    )
                elif comp_type == UNSHARD:
                    if stage_uses_fsdp:
                        assert (
                            stage_idx not in unsharded_stages
                            and stage_idx not in unshard_ops
                        ), f"Unsharding the same {stage_idx=} twice"
                        unshard_ops[stage_idx] = stage.submod.unshard(async_op=True)
                elif comp_type == RESHARD:
                    if stage_uses_fsdp:
                        assert (
                            stage_idx in unsharded_stages
                        ), f"Resharding {stage_idx=} without unsharding"
                        assert (
                            stage_idx not in unshard_ops
                        ), f"Resharding {stage_idx=} before finishing unshard"
                        stage.submod.reshard()
                elif comp_type == FORWARD:
                    if stage_uses_fsdp:
                        _assert_unsharded(stage_idx)

                    if not stage.is_first:
                        assert (
                            stage_idx,
                            mb_index,
                        ) in fwd_recv_ops, f"Computing {action=} before receiving input"
                        fwd_recv_ops.pop((stage_idx, mb_index)).wait()
                    output = stage.forward_one_chunk(
                        mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
                    )
                    self._maybe_compute_loss(stage, output, target_mbs, mb_index)
                elif comp_type == BACKWARD:
                    if stage_uses_fsdp:
                        _assert_unsharded(stage_idx)

                    if not stage.is_last:
                        assert (
                            stage_idx,
                            mb_index,
                        ) in bwd_recv_ops, (
                            f"Attempted to run compute {action=} before receiving input"
                        )
                        bwd_recv_ops.pop((stage_idx, mb_index)).wait()
                    loss = self._maybe_get_loss(stage, mb_index)
                    stage.backward_one_chunk(
                        mb_index, loss=loss, full_backward=self.use_full_backward
                    )
                elif comp_type == WEIGHT:
                    if stage_uses_fsdp:
                        _assert_unsharded(stage_idx)

                    if self.use_full_backward:
                        raise ValueError(
                            f"We detected a weight update in the pipeline schedule, but \
                            {self.use_full_backward=}"
                        )
                    stage.backward_weight_one_chunk(mb_index)
                else:
                    raise ValueError(f"{action=} is unknown or unsupported")
            except Exception as e:
                logger.error(
                    "_PipelineScheduleRuntime caught exception at step %s when running action %s. Full Schedule:",
                    time_step,
                    action,
                )
                # TODO(whc) what is the best practice for printing a multiline log?
                # logger will split it into multiple log lines, but this makes it hard to read (too wide)
                print(_format_pipeline_order(self.pipeline_order_with_comms))  # type: ignore[arg-type]
                raise e

        # Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them
        while len(send_ops):
            send_ops.pop().wait()

        assert len(unshard_ops) == 0, "Unused unshard operations"

        # Return losses if there is a container passed in
        self._update_losses(self._stages, losses)
|
|
|
|
|
|
class ScheduleLoopedBFS(PipelineScheduleMulti):
    """
    Breadth-First Pipeline Parallelism.
    See https://arxiv.org/abs/2211.05953 for details.
    Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank.
    What is different is that when microbatches are ready for multiple local
    stages, Looped BFS prioritizes the earlier stage, running all available
    microbatches at once.
    """

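    # Usage sketch (illustrative; `stage_a`, `stage_b`, `loss_fn`, `inputs`, `targets`,
    # and `losses` are hypothetical names, not defined in this file):
    #
    #     schedule = ScheduleLoopedBFS([stage_a, stage_b], n_microbatches=8, loss_fn=loss_fn)
    #     schedule.step(inputs, target=targets, losses=losses)
    #
    # Positional inputs are only used by the rank holding the first stage; `target` and
    # `losses` are only used by the rank holding the last stage.
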
    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            output_merge_spec=output_merge_spec,
        )

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [_Action(stage_index, computation_type, microbatch_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        # ========================================================================
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank):
        n_local_stages = len(self._stages)
        stage_indices = range(
            rank, self.pp_group_size * n_local_stages, self.pp_group_size
        )

        # Store the list of operations used for that rank
        rank_ops: List[Optional[_Action]] = []
        # Pre-padding, rank starts with no-ops based on the warmup.
        for _ in range(rank):
            rank_ops.append(None)

        for stage_index in stage_indices:
            for mb_index in range(self._n_microbatches):
                rank_ops.append(
                    _Action(stage_index, _ComputationType.FORWARD, mb_index)
                )

        # wait for the first backward to trickle up
        # which is 2 for every hop away
        post_warmup_ops = 2 * (self.pp_group_size - 1 - rank)
        rank_ops.extend([None] * post_warmup_ops)

        for stage_index in reversed(stage_indices):
            for mb_index in reversed(range(self._n_microbatches)):
                rank_ops.append(
                    _Action(stage_index, _ComputationType.BACKWARD, mb_index)
                )
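
        # For illustration (assumed toy configuration, not part of the original source):
        # with pp_group_size=2, two local stages per rank, and two microbatches, rank 0
        # yields 0F0 0F1 2F0 2F1 _ _ 2B1 2B0 0B1 0B0, where each entry is
        # stage/compute-type/microbatch and "_" marks a no-op slot.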
        return rank_ops


def _get_1f1b_rank_ops(
    n_local_stages,
    pp_group_size,
    warmup_ops,
    fwd_bwd_ops,
    cooldown_ops,
    rank,
    forward_stage_index,
    backward_stage_index,
    num_1f1b_microbatches=0,
    enable_zero_bubble=False,
):
    # All stages start with handling microbatch 0
    fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
    bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
    weight_stage_mb_index: Dict[int, int] = defaultdict(int)

    # Store the list of operations used for that rank
    rank_ops: List[Optional[_Action]] = []
    # Pre-padding, rank starts with no-ops based on the warmup.
    for _ in range(rank):
        rank_ops.append(None)
    # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
    # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
    # Formula:
    # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
    # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
    # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
    # warmup_ops = calculated above
    post_warmup_ops = (
        n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
    ) - (warmup_ops + rank)

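    # Worked example (illustrative numbers, not from the original comments): with
    # n_local_stages=2, pp_group_size=4, rank=1 and warmup_ops=8, this evaluates to
    # (2 * 4 + 2 * (4 - 1 - 1)) - (8 + 1) = 12 - 9 = 3 no-op slots.
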
    if enable_zero_bubble:
        post_warmup_ops = pp_group_size - rank - 1

    total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops

    backward_op_ids = []
    weight_op_count = 0

    for op in range(total_ops):
        # Warmup phase
        if op < warmup_ops:
            fwd_stage_index = forward_stage_index(op)
            # This will assign the current microbatch index and update it as well
            fwd_stage_mb_index[fwd_stage_index] = (
                mb_index := fwd_stage_mb_index[fwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(fwd_stage_index, _ComputationType.FORWARD, mb_index)
            )
            if op == warmup_ops - 1:
                # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
                rank_ops.extend([None] * post_warmup_ops)
        # 1F1B Phase (forward and backward)
        elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
            fwd_stage_index = forward_stage_index(op)
            fwd_stage_mb_index[fwd_stage_index] = (
                fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index)
            )
            bwd_stage_index = backward_stage_index(op)
            bwd_stage_mb_index[bwd_stage_index] = (
                bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
            )
            backward_op_ids.append(op)

            if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
                weight_stage_index = backward_stage_index(
                    backward_op_ids[weight_op_count]
                )
                weight_stage_mb_index[weight_stage_index] = (
                    weight_mb_index := weight_stage_mb_index[weight_stage_index]
                ) + 1
                rank_ops.append(
                    _Action(
                        weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
                    )
                )
                weight_op_count += 1
        # Cooldown phase
        else:
            # During cooldown phase, we need steps to align with 1f1b happening in other ranks
            # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
            if not enable_zero_bubble:
                rank_ops.append(None)

            bwd_stage_index = backward_stage_index(op)
            bwd_stage_mb_index[bwd_stage_index] = (
                bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
            ) + 1
            rank_ops.append(
                _Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
            )
            backward_op_ids.append(op)

            if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
                weight_stage_index = backward_stage_index(
                    backward_op_ids[weight_op_count]
                )
                weight_stage_mb_index[weight_stage_index] = (
                    weight_mb_index := weight_stage_mb_index[weight_stage_index]
                ) + 1
                rank_ops.append(
                    _Action(
                        weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
                    )
                )
                weight_op_count += 1

    while enable_zero_bubble and weight_op_count < len(backward_op_ids):
        weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count])
        weight_stage_mb_index[weight_stage_index] = (
            weight_mb_index := weight_stage_mb_index[weight_stage_index]
        ) + 1
        rank_ops.append(
            _Action(weight_stage_index, _ComputationType.WEIGHT, weight_mb_index)
        )
        weight_op_count += 1

    return rank_ops


class ScheduleInterleaved1F1B(PipelineScheduleMulti):
    """
    The Interleaved 1F1B schedule.
    See https://arxiv.org/pdf/2104.04473 for details.
    Will perform one forward and one backward on the microbatches in steady
    state and supports multiple stages per rank. When microbatches are ready for
    multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch
    (also called "depth first").
    """

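    # Usage sketch (illustrative; `stages`, `loss_fn`, `inputs`, `targets`, and `losses`
    # are hypothetical names, not defined in this file). n_microbatches must be a
    # multiple of the pipeline group size:
    #
    #     schedule = ScheduleInterleaved1F1B(stages, n_microbatches=8, loss_fn=loss_fn)
    #     schedule.step(inputs, target=targets, losses=losses)
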
    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        self.pp_group_size = stages[0].group_size
        # TODO: is this limitation a must?
        if n_microbatches % self.pp_group_size != 0:
            raise ValueError(
                f"Interleaved 1F1B schedule requires the number of microbatches ({n_microbatches}) "
                f"to be a multiple of the number of pipeline ranks ({self.pp_group_size})."
            )

        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
        )

        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
        self.group = stages[0].group

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [_Action(stage_index, computation_type, microbatch_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}

        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
        def get_rank_warmup_ops(rank):
            # Warms up operations for last stage
            warmups_ops_last_stage = (self.n_local_stages - 1) * self.pp_group_size
            # Increment warmup operations by 2 for each hop away from the last stage
            warmup_ops = warmups_ops_last_stage + 2 * ((self.pp_group_size - 1) - rank)
            # We cannot have more warmup operations than there are number of microbatches, so cap it there
            return min(warmup_ops, self._n_microbatches * self.n_local_stages)

        warmup_ops = get_rank_warmup_ops(rank)
        microbatch_ops = self.n_local_stages * self._n_microbatches
        # fwd_bwd_ops should encompass the remaining forwards
        fwd_bwd_ops = microbatch_ops - warmup_ops
        # cooldown_ops should encompass the remaining backwards
        cooldown_ops = microbatch_ops - fwd_bwd_ops
        # total ops encompass both forward and backward ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2

        logger.debug(
            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
            rank,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            total_ops,
        )

        # Calculates the stage index based on step and pp_group_size
        def forward_stage_index(step):
            # Get the local index from 0 to n_local_stages-1
            local_index = (step // self.pp_group_size) % self.n_local_stages
            return (local_index * self.pp_group_size) + rank

        def backward_stage_index(step):
            local_index = (
                self.n_local_stages
                - 1
                - ((step - warmup_ops) // self.pp_group_size) % self.n_local_stages
            )
            return (local_index * self.pp_group_size) + rank

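        # Illustration (assumed numbers, not from the original comments): with
        # pp_group_size=4, n_local_stages=2 and rank=1, forward_stage_index maps
        # steps 0-3 to stage 1, steps 4-7 to stage 5, steps 8-11 back to stage 1, etc.;
        # backward_stage_index walks the local stages in the opposite order, starting
        # from this rank's last local stage once the warmup steps have elapsed.
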
        return _get_1f1b_rank_ops(
            self.n_local_stages,
            self.pp_group_size,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            rank,
            forward_stage_index,
            backward_stage_index,
        )


class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
    """
    The Flexible Interleaved 1F1B schedule.

    This schedule is mostly similar to the interleaved 1F1B schedule.
    It differs by relaxing the requirement of num_microbatch % pp_size == 0.
    Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and
    it works as long as n_microbatches % num_rounds is 0. A few examples:

    1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
    2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.

    When enable_zero_bubble is True, we will use the ZB1P schedule in https://openreview.net/pdf?id=tuzTN0eIO5
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
        enable_zero_bubble: bool = False,
    ):
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            use_full_backward=not enable_zero_bubble,
        )
        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
        self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
        self.microbatches_per_round = n_microbatches // self.number_of_rounds
        self.enable_zero_bubble = enable_zero_bubble
        if n_microbatches % self.number_of_rounds != 0:
            raise ValueError(
                "Flexible Interleaved 1F1B requires the number of microbatches to be a "
                f"multiple of the number of rounds ({self.number_of_rounds}), "
                f"but got {n_microbatches}."
            )
        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [_Action(stage_index, computation_type, microbatch_index), ...]
        self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

        # This function adds bubbles to the generated schedule based on dependencies of actions.
        # Note that the ZB1P schedule will not require bubbles to be manually added and it is
        # only useful when n_microbatches <= microbatches_per_round
        self.pipeline_order = self._add_bubbles_to_actions(
            self.n_local_stages * self.pp_group_size,
        )

    def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
        def get_rank_warmup_ops(rank):
            # Warms up operations for last stage
            warmups_ops_last_stage = (
                self.n_local_stages - 1
            ) * self.microbatches_per_round
            # Increment warmup operations by 2 for each hop away from the last stage
            multiply_factor = 1 if self.enable_zero_bubble else 2
            warmup_ops = warmups_ops_last_stage + multiply_factor * (
                (self.pp_group_size - 1) - rank
            )
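
            # Illustration (assumed numbers, not from the original comments): with
            # n_local_stages=2, microbatches_per_round=4, pp_group_size=4, rank=1 and
            # enable_zero_bubble=True, this gives 4 + 1 * (3 - 1) = 6 warmup ops
            # (before the cap below).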

            # We cannot have more warmup operations than there are number of microbatches, so cap it there
            return min(warmup_ops, self._n_microbatches * self.n_local_stages)

        warmup_ops = get_rank_warmup_ops(rank)
        microbatch_ops = self.n_local_stages * self._n_microbatches
        # fwd_bwd_ops should encompass the remaining forwards
        fwd_bwd_ops = microbatch_ops - warmup_ops
        # cooldown_ops should encompass the remaining backwards
        cooldown_ops = microbatch_ops - fwd_bwd_ops
        # total ops encompass both forward and backward ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
        logger.debug(
            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
            rank,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            total_ops,
        )

        # Calculates the stage index based on step and pp_group_size

        def forward_stage_index(step):
            # Get the local index from 0 to n_local_stages-1
            local_index = (step // self.microbatches_per_round) % self.n_local_stages
            return (local_index * self.pp_group_size) + rank

        def backward_stage_index(step):
            local_index = (
                self.n_local_stages
                - 1
                - ((step - warmup_ops) // self.microbatches_per_round)
                % self.n_local_stages
            )
            return (local_index * self.pp_group_size) + rank

        if self.enable_zero_bubble:
            num_1f1b_microbatches = rank

            return _get_1f1b_rank_ops(
                self.n_local_stages,
                self.pp_group_size,
                warmup_ops,
                fwd_bwd_ops,
                cooldown_ops,
                rank,
                forward_stage_index,
                backward_stage_index,
                num_1f1b_microbatches,
                enable_zero_bubble=True,
            )

        return _get_1f1b_rank_ops(
            self.n_local_stages,
            self.pp_group_size,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            rank,
            forward_stage_index,
            backward_stage_index,
        )

    def _add_bubbles_to_actions(self, num_stages_global):
        actions = self.pipeline_order
        if not self.enable_zero_bubble:
            return actions

        def need_bubble(stage, op, microbatch, num_stages_global, seen_ops):
            if op == _ComputationType.FORWARD:
                if stage != 0 and (stage - 1, op, microbatch) not in seen_ops:
                    return True
            elif op == _ComputationType.BACKWARD:
                if stage == num_stages_global - 1:
                    return (stage, _ComputationType.FORWARD, microbatch) not in seen_ops
                return (stage + 1, op, microbatch) not in seen_ops
            return False
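
        # In other words (explanatory note, not from the original source): an action
        # must wait (a bubble is inserted) until its upstream dependency has been seen.
        # A FORWARD on stage s > 0 depends on the FORWARD of the same microbatch on
        # stage s - 1; a BACKWARD on stage s depends on the BACKWARD on stage s + 1,
        # or, for the last stage, on its own FORWARD.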

        seen_ops: Set[Tuple[int, _ComputationType, int]] = set()
        result: Dict[int, List[Optional[_Action]]] = {}
        next_pointer: Dict[int, int] = {}
        bubbles_added: Dict[int, int] = {}
        total_bubbles_added = 0

        for rank in range(self.pp_group_size):
            result[rank] = []
            next_pointer[rank] = 0
            bubbles_added[rank] = 0

        while True:
            should_stop = True

            temp_seen_ops: Set[Tuple[int, _ComputationType, int]] = set()

            for rank in range(self.pp_group_size):
                timestamp = next_pointer[rank]
                if timestamp >= len(actions[rank]):
                    continue

                should_stop = False

                if actions[rank][timestamp] is not None:
                    temp_action = actions[rank][timestamp]
                    assert temp_action is not None
                    stage_index, op, microbatch = temp_action
                    if not need_bubble(
                        stage_index, op, microbatch, num_stages_global, seen_ops
                    ):
                        result[rank].append(actions[rank][timestamp])
                        if microbatch is not None:
                            temp_seen_ops.add((stage_index, op, microbatch))
                        next_pointer[rank] += 1
                    else:
                        result[rank].append(None)
                        bubbles_added[rank] += 1
                        total_bubbles_added += 1
                else:
                    next_pointer[rank] += 1
                    result[rank].append(None)

            seen_ops.update(temp_seen_ops)
            if should_stop:
                break

        if total_bubbles_added > 0:
            logger.warning(
                "Non zero bubbles added: total_bubbles_added=%s bubbles_added=%s",
                total_bubbles_added,
                bubbles_added,
            )
        return result


class ScheduleInterleavedZeroBubble(ScheduleFlexibleInterleaved1F1B):
    """
    The Interleaved Zero Bubble schedule.
    See https://arxiv.org/pdf/2401.10241 for details.
    Will perform one forward and one backward on inputs for the microbatches in steady
    state and supports multiple stages per rank. Uses the backward for weights to fill in
    the pipeline bubble.
    """

    def __init__(
        self,
        stages: List[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            enable_zero_bubble=True,
        )


def get_schedule_class(schedule_name: str):
    """
    Maps a schedule name to its corresponding class object.

    Args:
        schedule_name (str): The name of the schedule.
    """
    schedule_map = {
        "1F1B": Schedule1F1B,
        "Interleaved1F1B": ScheduleInterleaved1F1B,
        "GPipe": ScheduleGPipe,
        "FlexibleInterleaved1F1B": ScheduleFlexibleInterleaved1F1B,
        "LoopedBFS": ScheduleLoopedBFS,
        "InterleavedZeroBubble": ScheduleInterleavedZeroBubble,
        "PipelineScheduleSingle": PipelineScheduleSingle,
        "PipelineScheduleMulti": PipelineScheduleMulti,
    }
    if schedule_name not in schedule_map:
        raise ValueError(f"Unknown schedule name: {schedule_name}")
    return schedule_map[schedule_name]
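

# Example (illustrative): resolving a schedule class by name, e.g. from a training config.
#
#     schedule_cls = get_schedule_class("Interleaved1F1B")
#     assert schedule_cls is ScheduleInterleaved1F1B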