I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
import torch.distributed.tensor._ops # force import all built-in dtensor ops
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh # noqa: F401
from torch.distributed.tensor._api import (
distribute_module,
distribute_tensor,
DTensor,
empty,
full,
ones,
rand,
randn,
zeros,
)
from torch.distributed.tensor.placement_types import (
Partial,
Placement,
Replicate,
Shard,
)
from torch.optim.optimizer import (
_foreach_supported_types as _optim_foreach_supported_types,
)
from torch.utils._foreach_utils import (
_foreach_supported_types as _util_foreach_supported_types,
)
# All public APIs from dtensor package
__all__ = [
"DTensor",
"distribute_tensor",
"distribute_module",
"Shard",
"Replicate",
"Partial",
"Placement",
"ones",
"empty",
"full",
"rand",
"randn",
"zeros",
]
# Append DTensor to the list of supported types for foreach implementation for optimizer
# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA.
if DTensor not in _optim_foreach_supported_types:
_optim_foreach_supported_types.append(DTensor)
if DTensor not in _util_foreach_supported_types:
_util_foreach_supported_types.append(DTensor)
# Set namespace for exposed private names
DTensor.__module__ = "torch.distributed.tensor"
distribute_tensor.__module__ = "torch.distributed.tensor"
distribute_module.__module__ = "torch.distributed.tensor"
ones.__module__ = "torch.distributed.tensor"
empty.__module__ = "torch.distributed.tensor"
full.__module__ = "torch.distributed.tensor"
rand.__module__ = "torch.distributed.tensor"
randn.__module__ = "torch.distributed.tensor"
zeros.__module__ = "torch.distributed.tensor"
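# --- Illustrative usage sketch (not part of this file) ---
# A minimal example of the public API exported above, assuming torch.distributed
# is initialized with 4 ranks (e.g. launched via `torchrun --nproc_per_node=4`)
# and using a CPU device mesh for simplicity.
if __name__ == "__main__":
    mesh = init_device_mesh("cpu", (4,))                               # 1-D mesh over 4 ranks
    full = torch.randn(8, 8)
    sharded = distribute_tensor(full, mesh, placements=[Shard(0)])     # rows split across ranks
    replicated = sharded.redistribute(mesh, placements=[Replicate()])  # gather back a full copy
    assert replicated.to_local().shape == full.shape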

File diff suppressed because it is too large

View File

@@ -0,0 +1,373 @@
# mypy: allow-untyped-defs
import logging
import math
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._dtensor_spec as dtensor_spec
from torch._C._distributed_c10d import _resolve_process_group
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import (
_get_group_size_by_name,
broadcast,
get_global_rank,
get_group_rank,
get_rank,
GroupMember,
ProcessGroup,
scatter,
Work,
)
logger = logging.getLogger(__name__)
if not torch._running_with_deploy():
@torch.library.register_fake("_dtensor::shard_dim_alltoall")
def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name):
group_size = _get_group_size_by_name(group_name)
stacked_list = [torch.empty_like(input) for _ in range(group_size)]
group = _resolve_process_group(group_name)
group_rank = get_group_rank(group, get_rank())
return torch.cat(stacked_list, dim=gather_dim).chunk(group_size, dim=shard_dim)[
group_rank
]
else:
import warnings
warnings.warn(
"PyTorch Distributed functional collectives do not work with torch::deploy."
)
def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim):
if mesh.device_type == "cpu":
# Gloo does not support alltoall, so falling back to allgather + chunk
# TODO: This logs way too much
logger.warning(
"CPU process group does not support alltoall yet, falling back with allgather + chunk!"
)
out = funcol.all_gather_tensor(input, gather_dim, (mesh, mesh_dim))
if isinstance(out, funcol.AsyncCollectiveTensor):
# stick to the same behavior for the alltoall case, remove this once we enable alltoall async
out = out.wait()
out = torch.chunk(out, mesh.size(mesh_dim), dim=shard_dim)[
mesh.get_local_rank(mesh_dim)
]
return out.contiguous() if not out.is_contiguous() else out
group_name = funcol._resolve_group_name((mesh, mesh_dim))
# TODO: enable async op for shard_dim_alltoall
return torch.ops._dtensor.shard_dim_alltoall(
input, gather_dim, shard_dim, group_name
)
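# --- Illustrative sketch (not part of this file) ---
# Single-process illustration of what the CPU fallback above computes: all-gather
# along `gather_dim`, then keep this rank's chunk along `shard_dim`. With 2 ranks,
# gather_dim=0 and shard_dim=1:
_rank0_local, _rank1_local = torch.zeros(2, 4), torch.ones(2, 4)
_gathered = torch.cat([_rank0_local, _rank1_local], dim=0)   # what all_gather_tensor yields, shape (4, 4)
_rank0_out = torch.chunk(_gathered, 2, dim=1)[0]             # rank 0 keeps columns 0-1
_rank1_out = torch.chunk(_gathered, 2, dim=1)[1]             # rank 1 keeps columns 2-3
assert _rank0_out.shape == (4, 2) and _rank1_out.shape == (4, 2)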
def mesh_scatter(
output: torch.Tensor,
scatter_list: List[torch.Tensor],
mesh: DeviceMesh,
mesh_dim: int = 0,
async_op: bool = False,
) -> Optional[Work]:
"""
Scatter a list of tensors to a device mesh dimension. By default we
use the first rank of the mesh dimension as the source of truth, i.e.
for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will
scatter the tensor list on rank 0 to ranks 0/1, and the tensor list on
rank 2 to ranks 2/3.
Args:
output (torch.Tensor): the tensor to receive the scattered list.
scatter_list (List[torch.Tensor]): the tensor list to be scattered.
mesh_dim (int, optional): indicates which mesh dimension we want
to scatter on; by default we choose the first rank on the
mesh dimension as the source of truth.
Returns:
A :class:`Work` object if ``async_op=True``, otherwise ``None``.
"""
# TODO: Ideally we should use the meta tensor way
# (to register a meta kernel for the collective op)
# so that it would avoid the communication. Need to
# remove the check below once that is done.
if output.is_meta:
return None
dim_group = mesh.get_group(mesh_dim)
assert isinstance(dim_group, ProcessGroup)
# src needs to be a global rank
src_for_dim = 0
if dim_group is not GroupMember.WORLD:
src_for_dim = get_global_rank(dim_group, 0)
if src_for_dim == get_rank():
fut = scatter(
output,
scatter_list=scatter_list,
src=src_for_dim,
group=dim_group,
async_op=async_op,
)
else:
fut = scatter(
output,
scatter_list=None,
src=src_for_dim,
group=dim_group,
async_op=async_op,
)
return fut
def mesh_broadcast(
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int = 0,
async_op: bool = False,
) -> Optional[Work]:
"""
Broadcast the tensor to a device mesh dimension. By default we
use the first rank of the mesh dimension as the source of truth, i.e.
for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will
broadcast the tensor on rank 0 to ranks 0/1, and the tensor on rank 2
to ranks 2/3.
Args:
tensor (torch.Tensor): tensor to broadcast.
mesh_dim (int, optional): indicates which mesh dimension we want
to broadcast on; by default we choose the first rank on the
mesh dimension as the source of truth.
Returns:
A :class:`Work` object if ``async_op=True``, otherwise ``None``.
"""
# TODO: Ideally we should use the meta tensor way
# (to register a meta kernel for the collective op)
# so that it would avoid the communication. Need to
# remove the check below once that is done.
if tensor.is_meta:
return None
dim_group = mesh.get_group(mesh_dim)
assert isinstance(dim_group, ProcessGroup)
# src needs to be a global rank
src_for_dim = 0
if dim_group is not GroupMember.WORLD:
src_for_dim = get_global_rank(dim_group, 0)
return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op)
def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
if pad_size == 0:
return tensor
pad = [0, 0] * (tensor.ndim - pad_dim)
pad[-1] = pad_size
return torch.nn.functional.pad(tensor, pad)
def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
if pad_size == 0:
return tensor
return tensor.narrow(
pad_dim,
start=0,
length=tensor.size(pad_dim) - pad_size,
)
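# --- Illustrative sketch (not part of this file) ---
# pad_tensor appends `pad_size` zero entries at the end of `pad_dim`, and
# unpad_tensor narrows them away again, so the two are inverses:
_t = torch.arange(6).reshape(2, 3)
_padded = pad_tensor(_t, pad_dim=0, pad_size=1)              # shape (3, 3), last row is zeros
_restored = unpad_tensor(_padded, pad_dim=0, pad_size=1)     # back to shape (2, 3)
assert torch.equal(_restored, _t)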
def fill_empty_tensor_to_shards(
shards: List[torch.Tensor], shard_dim: int, num_empty_tensors: int
) -> List[torch.Tensor]:
if num_empty_tensors == 0:
return shards
tensor_size = list(shards[0].size())
tensor_size = [
size if idx != shard_dim else 0 for idx, size in enumerate(tensor_size)
]
tensor = shards[0].new_zeros(tensor_size)
for _ in range(num_empty_tensors):
shards.append(tensor)
return shards
def check_tensor_meta(
local_tensor, check_shape_stride=False
) -> Optional["dtensor_spec.TensorMeta"]:
local_metadata = {
"dtype": local_tensor.dtype,
"requires_grad": local_tensor.requires_grad,
}
if check_shape_stride:
local_metadata.update(
{"shape": local_tensor.shape, "stride": local_tensor.stride()}
)
gathered_metadata = [None for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(gathered_metadata, local_metadata)
# Check if metadata is consistent across ranks
if not all(meta == local_metadata for meta in gathered_metadata):
raise ValueError(
"Inconsistent tensor metadata (including shape and stride) across ranks."
)
return None
def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int:
assert spec.tensor_meta is not None, "spec should have tensor meta defined!"
return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape)
@dataclass
class MeshTopoInfo:
"""
Mesh information for collective cost estimation
"""
mesh: DeviceMesh
mesh_dim_devices: List[int]
mesh_dim_bandwidth: List[float]
mesh_dim_latency: List[float]
@staticmethod
@lru_cache(None)
def build_from_mesh(mesh: DeviceMesh) -> "MeshTopoInfo":
# Generate mesh topology info for intra-host/inter-host communication pattern
# Note that we make a bunch of assumptions for simplicity:
# 1. we assume the mesh is homogeneous and uses the GPU/NCCL model
# 2. we assume the GPU arch is Ampere or Hopper
# 3. we assume all collectives use ring-based algorithms for now
num_devices_per_host = _mesh_resources.num_devices_per_host(mesh.device_type)
# the base bw number (intra-node), GB/s
base_bw = 87.7
mesh_dim_bandwidth = [base_bw] * mesh.ndim
# the latency in terms of us (intra-node, nv-link)
mesh_dim_latency = [0.6] * mesh.ndim
mesh_dim_devices = [1] * mesh.ndim
total_num_devices = 1
for mesh_dim in reversed(range(mesh.ndim)):
num_devices = mesh.size(mesh_dim)
mesh_dim_devices[mesh_dim] = num_devices
total_num_devices *= num_devices
if total_num_devices > num_devices_per_host:
# magic number for inter-host communication bandwidth/latency factor
# This number assumes latest GPU arch, i.e. Ampere or Hopper
# TODO: see if we need to tweak this or offer a way for user
# to specify the bandwidths/latency
mesh_dim_bandwidth[mesh_dim] *= 0.22
# set to ethernet latency for inter-host
mesh_dim_latency[mesh_dim] = 2.7
return MeshTopoInfo(
mesh, mesh_dim_devices, mesh_dim_bandwidth, mesh_dim_latency
)
def allgather_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
num_hops = num_devices_on_mesh_dim - 1
# base latency + comm latency
latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] # us
bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth # s
return latency + bw * 1e6 # rescale to us
def allreduce_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
# allreduce has almost 2x the comm bytes compared to allgather/reduce_scatter
num_hops = 2 * num_devices_on_mesh_dim - 1
latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]
bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth
return latency + bw * 1e6
def reduce_scatter_cost(
bytes_gb: float,
mesh_topo: MeshTopoInfo,
mesh_dim: int,
) -> float:
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
num_hops = num_devices_on_mesh_dim - 1
# base latency + comm latency
latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim]
bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth
return latency + bw * 1e6
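# --- Worked example (not part of this file) ---
# Plugging the intra-host defaults from MeshTopoInfo.build_from_mesh into the
# allgather cost model above, for 1 GB gathered across 8 devices on one host:
_num_devices = 8
_bw_gbps, _link_latency_us, _base_latency_us = 87.7, 0.6, 6.6
_num_hops = _num_devices - 1                                     # ring allgather: 7 hops
_latency_us = _base_latency_us + _num_hops * _link_latency_us    # 6.6 + 7 * 0.6 = 10.8 us
_bw_term_us = (1.0 * _num_hops / _num_devices) / _bw_gbps * 1e6  # ~9977 us for 1 GB
# total ~= 9988 us, i.e. dominated by the bandwidth term for large tensors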
def redistribute_cost(
current_spec: "dtensor_spec.DTensorSpec",
target_spec: "dtensor_spec.DTensorSpec",
) -> float:
"""
This function returns the cost of redistribute from current to target DTensorSpec.
NOTE:
1. Only communication cost is considered here, since computation costs for redistribute
are quite trivial (i.e. we only need a narrow or a simple division)
2. Only redistribute cost on the same mesh is considered; cross-mesh communication cost is
not needed for operator strategy estimation/selection.
"""
if current_spec.mesh != target_spec.mesh:
# make the cost infinite if the meshes are not the same
# TODO: see if we want to support this once there's cross mesh communication
return float("inf")
if current_spec.is_replicated():
# short-cut:
# comm cost is 0 if current spec is already full replication
return 0.0
mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
cost = 0.0
comm_bytes_gb = (
spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
)
# Transformations considered for redistribute cost:
# 1. allgather 2. alltoall
# 3. allreduce 4. reduce_scatter
for i, (current, target) in enumerate(
zip(current_spec.placements, target_spec.placements)
):
if current == target:
continue
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
if current.is_shard() and target.is_replicate():
# allgather gives larger comm bytes
comm_bytes_gb *= num_devices_on_mesh_dim
# add up allgather comm cost
cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
elif current.is_shard() and target.is_shard():
# should be alltoall comm; since we haven't implemented it yet, add a penalty
# to favor allgather instead
cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + 1.0
elif current.is_partial() and target.is_replicate():
# add up allreduce comm cost
cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
elif current.is_partial() and target.is_shard():
# add up reduce_scatter comm cost
cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
# after reduce_scatter the comm bytes for further collectives shrink by the number of devices on this mesh dim.
comm_bytes_gb /= num_devices_on_mesh_dim
elif current.is_shard() and target.is_partial():
# ban shard -> partial as it does not make sense to perform
# this redistribute
return float("inf")
return cost

View File

@@ -0,0 +1,510 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import functools
import logging
import operator
import warnings
from typing import cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
import torch
import torch.distributed as dist
import torch.distributed.tensor._api as dtensor
import torch.distributed.tensor._random as random
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
_is_inplace_op,
_is_out_variant_op,
OpInfo,
OpSchema,
OutputSpecType,
)
from torch.distributed.tensor._random import is_rng_supported_mesh
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor._sharding_prop import ShardingPropagator
from torch.distributed.tensor._tp_conv import (
convolution_backward_handler,
convolution_handler,
)
from torch.distributed.tensor._utils import try_find_mesh_from_args
from torch.distributed.tensor.placement_types import Partial, Placement, Replicate
if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh
try:
from torch.utils import _cxx_pytree as pytree
except ImportError:
from torch.utils import _pytree as pytree # type: ignore[no-redef]
aten = torch.ops.aten
logger = logging.getLogger(__name__)
def decompose_handler(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> object:
"""
Decomposes an op into core ATen ops; this handler is mostly here
for inference mode usage where the ops are not core ATen ops.
"""
r = op_call.decompose(*args, **kwargs)
if r is not NotImplemented:
return r
else:
raise RuntimeError("Decomposition failed")
def is_same_size_handler(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> bool:
lhs = cast(torch.Tensor, args[0])
rhs = cast(torch.Tensor, args[1])
return lhs.shape == rhs.shape
def found_inf_reduce_handler(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> None:
op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
local_tensor_args = pytree.tree_unflatten(
cast(List[object], op_info.local_args), op_info.args_tree_spec
)
local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
grad_dtensor = cast(list[dtensor.DTensor], args[0])[0]
grad_placements = grad_dtensor.placements
mesh = grad_dtensor.device_mesh
found_inf_placements: list[Placement] = []
for placement in grad_placements:
if isinstance(placement, Replicate):
found_inf_placements.append(placement)
else:
found_inf_placements.append(Partial("max"))
target_tensor = cast(torch.Tensor, args[1])
spec = DTensorSpec(
mesh=mesh,
placements=tuple(found_inf_placements),
tensor_meta=TensorMeta(
shape=target_tensor.size(),
stride=target_tensor.stride(),
dtype=target_tensor.dtype,
),
)
found_inf_dtensor = dtensor.DTensor(
local_tensor=target_tensor, spec=spec, requires_grad=False
)
found_inf = found_inf_dtensor.full_tensor()
target_tensor.copy_(found_inf)
class OpDispatcher:
"""
Op dispatching class instance to handle args/kwargs pre-processing (un-wrapping), sharding
propagation, redistribute local args, local compute, and post-processing (re-wrapping). It
also handles any op specific logic if necessary.
NOTE: Given the runtime overhead of Tensor subclass (__torch_dispatch__), the OpDispatcher
is designed to minimize the CPU overhead by using the tricks of proper unflattening, faster
pytree if needed, and leveraging various caching mechanisms implemented in the sharding
propagation and redistribute modules. The CPU overhead is critical to eager mode performance,
one need to carefully measure the CPU overhead when making significant changes to the
OpDispatcher and ShardingPropagator.
"""
def __init__(self) -> None:
self.sharding_propagator = ShardingPropagator()
self._random_ops = {
aten.native_dropout.default,
aten.normal_.default,
aten.rand_like.default,
aten.randn_like.default,
aten.randint_like.default,
aten.randint_like.low_dtype,
aten.randint_like.low_dtype_out,
aten.uniform_.default,
aten.bernoulli.default,
aten.bernoulli_.float,
}
self._custom_op_handlers = {
aten.linear.default: decompose_handler,
aten.is_same_size.default: is_same_size_handler,
aten.convolution.default: convolution_handler,
aten.convolution_backward.default: convolution_backward_handler,
aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler,
}
# This flag is used internally to control whether we treat the torch.Tensor (non-DTensor)
# as implicitly replicated or throw an error to the user.
# NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave
# it as False by default.
self._allow_implicit_replication = False
def dispatch(
self,
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> object:
"""
Main dispatching logic
"""
# operators that do not need to go through sharding propagation
if op_call in self._custom_op_handlers:
return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator]
# extract local tensor and sharding infos to a OpInfo
op_info = self.unwrap_to_op_info(op_call, args, kwargs)
logger.debug("Dispatching op_call: %s", op_info.schema)
self.sharding_propagator.propagate(op_info)
output_sharding = op_info.output_sharding
logger.debug("output_sharding for %s: %s", op_call, output_sharding)
assert output_sharding is not None, "output sharding should not be None"
mesh = op_info.mesh
if mesh.get_coordinate() is not None:
# computation that happens in the current rank of the mesh, normal case
if output_sharding.needs_redistribute:
# If sharding propagation decision needs redistribute, perform redistribute
# on args first, which could potentially modify args (i.e. allgather certain arg)
assert output_sharding.redistribute_schema is not None
self.redistribute_local_args(
op_info, output_sharding.redistribute_schema
)
local_tensor_args = (
pytree.tree_unflatten(
cast(List[object], op_info.local_args), op_info.args_tree_spec
)
if op_info.args_tree_spec
else op_info.local_args
)
# run local op computation with potentially modified args/kwargs
local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
if op_call in self._random_ops:
if not random._rng_tracker and is_rng_supported_mesh(mesh):
# Default to `OffsetBasedRNGTracker` if the parallelism API
# did not already construct one
random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type)
first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast(
torch.Tensor, local_tensor_args[0]
)
rng_context = (
random._rng_tracker._distribute_region(first_arg._spec)
if random._rng_tracker and not first_local_arg.is_meta
else contextlib.nullcontext()
)
# For DTensor random operator, run it within a RNGTracker context to
# ensure the random number generator is properly distributed.
with rng_context:
local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
else:
# normal case, run local sharded op computation
local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
else:
# For a non-participating device (happens on rank that does not belong to
# the device mesh), we do:
# 1. if the return type is scalar, set the local result to None.
# 2. if the return type is Tensor or List[Tensor], return empty
# tensor(s) with correct dtype.
spec = output_sharding.output_spec
ret_list = op_info.schema.op._schema.returns
if spec is None:
# For a scalar return type, the non-participating device has None
# as its local result
local_results = None
else:
def default_tensor(spec: DTensorSpec) -> torch.Tensor:
if spec.tensor_meta is not None:
shape = spec.tensor_meta.shape
dtype = spec.tensor_meta.dtype
if len(shape) == 0:
# scalar tensor
return torch.zeros((), dtype=dtype)
else:
# non-scalar tensor
return torch.tensor([], dtype=dtype)
else:
raise RuntimeError(f"{spec} has no tensor metadata.")
if isinstance(spec, DTensorSpec):
# return a Tensor value
local_results = default_tensor(spec)
elif isinstance(spec, Sequence):
# return a List[Tensor] value
local_results = [
default_tensor(s) if s is not None else None for s in spec
]
assert isinstance(local_results, List)
if None in local_results:
ret_type = str(ret_list[0].type)
raise NotImplementedError(
f"return type {ret_type} in DTensor op is not supported"
)
if output_sharding.output_spec is None:
if op_call == aten.equal.default:
# For the equal operator, the local results from all devices should be all-gathered
# and a reduce op (AND) will be performed on the list of results to ensure SPMD
# execution. We can extend this for more ops if necessary.
obj_list = [None for _ in range(dist.get_world_size())]
dist.all_gather_object(obj_list, local_results) # type: ignore[possibly-undefined]
obj_list = list(filter(lambda x: x is not None, obj_list))
# perform reduce on the collection with AND op
local_results = functools.reduce(operator.and_, obj_list, True)
if _is_inplace_op(op_call):
# inplace op should return self instead of re-wrapping
if output_sharding.output_spec is not None:
return args[0]
else:
return None
elif _is_out_variant_op(op_call):
# out variant could possibly have multiple out args (i.e. lu_unpack.out)
output_specs = (
(output_sharding.output_spec,)
if not isinstance(output_sharding.output_spec, tuple)
else output_sharding.output_spec
)
out_dts = []
spec_idx = 0
for argument in op_call._schema.arguments:
if argument.is_out:
out_dt = cast(dtensor.DTensor, kwargs[argument.name])
out_dt._spec = cast(DTensorSpec, output_specs[spec_idx])
out_dts.append(out_dt)
spec_idx += 1
assert len(out_dts) >= 1, "out variant should have at least one out arg"
return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
else:
return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined]
@staticmethod
def redistribute_local_args(
op_info: OpInfo,
suggested_input_schema: OpSchema,
) -> None:
# NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
if op_info.args_tree_spec is not None:
flatten_args_schema_to_reshard = tuple(
pytree.tree_leaves(suggested_input_schema.args_schema)
)
else:
flatten_args_schema_to_reshard = suggested_input_schema.args_schema
new_local_args: List[object] = []
for i, arg_spec in enumerate(op_info.flat_args_schema):
reshard_arg_spec = flatten_args_schema_to_reshard[i]
if isinstance(arg_spec, DTensorSpec):
local_tensor = cast(torch.Tensor, op_info.local_args[i])
if arg_spec != reshard_arg_spec:
resharded_local_tensor = redistribute_local_tensor(
local_tensor, arg_spec, reshard_arg_spec
)
new_local_args.append(resharded_local_tensor)
else:
new_local_args.append(local_tensor)
else:
new_local_args.append(reshard_arg_spec)
op_info.local_args = tuple(new_local_args)
def unwrap_to_op_info(
self,
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> OpInfo:
# get runtime schema info to determine whether to use pytree to flatten inputs
runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
op_call, None
)
if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
# flatten args/kwargs when op says necessary
tree_args, args_spec = pytree.tree_flatten(args)
args_list: Sequence[object] = tree_args
else:
args_list, args_spec = args, None
args_schema: List[object] = []
kwargs_schema: Dict[str, object] = {}
local_args: List[object] = []
local_kwargs: Dict[str, object] = {}
mesh: Optional[DeviceMesh] = None
for arg in args_list:
if isinstance(arg, dtensor.DTensor):
local_args.append(arg._local_tensor)
if mesh is not None and mesh != arg.device_mesh:
# TODO: trying to replicate the dtensor spec in the missing dimension would work
# for most foreach cases, except when the first DTensor in
# the list is one that also needs to be replicated. We need to revisit
# how we want to handle this corner case. For now, this case would hit
# the cross mesh error even if implicit replication is turned on.
spec = self._try_replicate_dtensor_spec_in_missing_dim(
op_call, arg, mesh
)
args_schema.append(spec)
else:
mesh = arg.device_mesh
args_schema.append(arg._spec)
elif isinstance(arg, torch.Tensor):
mesh = mesh or try_find_mesh_from_args(op_call, args_list)
args_schema.append(
self._try_replicate_spec_for_scalar_tensor(op_call, arg, mesh)
)
local_args.append(arg)
else:
args_schema.append(arg)
local_args.append(arg)
for k, v in kwargs.items():
if isinstance(v, dtensor.DTensor):
local_kwargs[k] = v._local_tensor
if mesh is not None and mesh != v.device_mesh:
spec = self._try_replicate_dtensor_spec_in_missing_dim(
op_call, v, mesh
)
kwargs_schema[k] = spec
else:
mesh = v.device_mesh
kwargs_schema[k] = v._spec
elif isinstance(v, torch.Tensor):
mesh = mesh or try_find_mesh_from_args(op_call, args_list)
kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
op_call, v, mesh
)
local_kwargs[k] = v
else:
kwargs_schema[k] = v
local_kwargs[k] = v
assert mesh is not None, f"found no DeviceMesh from dtensor args for {op_call}!"
op_info = OpInfo(
mesh,
OpSchema(
op_call,
pytree.tree_unflatten(args_schema, args_spec)
if args_spec
else tuple(args_schema),
kwargs_schema,
schema_info=runtime_schema_info,
),
args_schema,
tuple(local_args),
local_kwargs,
args_spec,
)
return op_info
@staticmethod
def wrap(res: object, spec: OutputSpecType) -> object:
if isinstance(res, torch.Tensor):
if spec is not None:
assert isinstance(
spec, DTensorSpec
), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
else:
# if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor
assert res.ndim == 0, "output tensor should be scalar!"
return res
elif isinstance(res, (list, tuple)):
assert spec is not None and isinstance(
spec, (list, tuple)
), f"output spec does not match with output! Expected list/tuple, got {spec}."
res_list = []
for e, s in zip(res, spec):
res_list.append(OpDispatcher.wrap(e, s))
return tuple(res_list) if isinstance(res, tuple) else res_list
else:
# if the res contains only non tensor values (i.e. int/float/none), we simply return it
# without rewrapping to DTensor.
return res
def _try_replicate_spec_for_scalar_tensor(
self,
op_call: torch._ops.OpOverload,
tensor_arg: torch.Tensor,
mesh: "DeviceMesh",
) -> DTensorSpec:
# util function to produce a replicate spec for a scalar tensor arg/kwarg
if tensor_arg.numel() == 1 and tensor_arg.ndim == 1:
warnings.warn(
"Found a non-scalar tensor with numel=1 and ndim!=0, "
"we are implicitly creating a replicated DTensor for it. "
"However, please consider changing it to a scalar tensor "
"or explicitly create a DTensor under distributed enviroment."
)
if tensor_arg.numel() == 1 or self._allow_implicit_replication:
# scalar tensor can be safely treated as replicated
replication_spec = DTensorSpec(
mesh,
(Replicate(),) * mesh.ndim,
tensor_meta=TensorMeta(
shape=tensor_arg.shape,
stride=tensor_arg.stride(),
dtype=tensor_arg.dtype,
),
)
else:
raise RuntimeError(
f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
" torch.Tensor to DTensor before calling distributed operators!"
)
return replication_spec
def _try_replicate_dtensor_spec_in_missing_dim(
self,
op_call: torch._ops.OpOverload,
dtensor_arg: "dtensor.DTensor",
mesh: "DeviceMesh",
) -> DTensorSpec:
# util function to produce a new spec for a DTensor arg/kwarg
# that puts Replicate() placement in the missing dimension for foreach ops
from torch.distributed.device_mesh import _mesh_resources
cur_mesh = dtensor_arg.device_mesh
root_mesh = _mesh_resources.get_root_mesh(cur_mesh)
if (
self._allow_implicit_replication
and "foreach" in op_call.__name__
and root_mesh == mesh
):
placements = [Replicate() for _ in range(root_mesh.ndim)]
cur_mesh_root_idx = _mesh_resources.get_root_mesh_dim(cur_mesh)
placements[cur_mesh_root_idx] = dtensor_arg.placements[0] # type: ignore[call-overload]
replicate_spec = DTensorSpec(
root_mesh,
tuple(placements),
tensor_meta=TensorMeta(
shape=dtensor_arg.shape,
stride=dtensor_arg.stride(),
dtype=dtensor_arg.dtype,
),
)
else:
raise NotImplementedError(
f"{op_call}: DTensor does not support cross-mesh operation yet! "
f"Got meshes: {mesh} {cur_mesh}"
)
return replicate_spec
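# --- Illustrative sketch (not part of this file) ---
# End to end, the dispatcher above is what runs when an op is called on DTensors:
# under e.g. `torchrun --nproc_per_node=2`, adding two sharded DTensors unwraps the
# local shards, propagates sharding, runs the local aten op, and re-wraps a DTensor.
if __name__ == "__main__":
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import distribute_tensor
    from torch.distributed.tensor.placement_types import Shard

    _mesh = init_device_mesh("cpu", (2,))
    _a = distribute_tensor(torch.ones(4, 4), _mesh, [Shard(0)])
    _b = distribute_tensor(torch.ones(4, 4), _mesh, [Shard(0)])
    _c = _a + _b   # routed through DTensor.__torch_dispatch__ -> OpDispatcher.dispatch
    assert _c.placements == (Shard(0),)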

View File

@@ -0,0 +1,276 @@
from dataclasses import dataclass
from typing import Any, cast, List, NamedTuple, Optional, Tuple
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
Partial,
Placement,
Replicate,
Shard,
)
class TensorMeta(NamedTuple):
# simple named tuple to represent tensor metadata
# intentionally kept simple, only for sharding
# propagation purposes.
shape: torch.Size
stride: Tuple[int, ...]
dtype: torch.dtype
# used internally to propagate the placements
@dataclass
class DTensorSpec:
mesh: DeviceMesh
placements: Tuple[Placement, ...]
# tensor meta will only be set during sharding propagation
tensor_meta: Optional[TensorMeta] = None
def __post_init__(self) -> None:
if not isinstance(self.placements, tuple):
self.placements = tuple(self.placements)
self._hash: Optional[int] = None
def __setattr__(self, attr: str, value: Any) -> None:
super().__setattr__(attr, value)
# Make sure to recompute the hash in case any of the hashed attributes
# change (though we do not expect `mesh` or `placements` to change)
if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"):
self._hash = None
def _hash_impl(self) -> int:
# hashing and equality check for DTensorSpec are used to cache the sharding
# propagation results. We only need to consider the mesh, placements, shape
# dtype and stride.
# Caveat: we need to keep this in mind and sync hash and eq if we add more
# fields to them.
if self.tensor_meta is not None:
return hash(
(
self.mesh,
self.placements,
self.tensor_meta.shape,
self.tensor_meta.stride,
self.tensor_meta.dtype,
)
)
return hash((self.mesh, self.placements))
def __hash__(self) -> int:
# We lazily cache the spec to avoid recomputing the hash upon each
# use, where we make sure to update the hash when the `tensor_meta`
# changes by overriding `__setattr__`. This must be lazy so that Dynamo
# does not try to hash non-singleton `SymInt`s for the stride.
if self._hash is None:
self._hash = self._hash_impl()
return self._hash
def __eq__(self, __o: object) -> bool:
if not (
isinstance(__o, DTensorSpec)
and self.mesh == __o.mesh
and self.placements == __o.placements
):
return False
if self.tensor_meta is None or __o.tensor_meta is None:
return self.tensor_meta == __o.tensor_meta
return (
self.tensor_meta.shape == __o.tensor_meta.shape # type: ignore[union-attr]
and self.tensor_meta.stride == __o.tensor_meta.stride # type: ignore[union-attr]
and self.tensor_meta.dtype == __o.tensor_meta.dtype # type: ignore[union-attr]
)
def __str__(self) -> str:
"""
human readable representation of the DTensorSpec
"""
if len(self.placements) == 1:
placement_str = str(self.placements[0])
else:
placement_str = str(self.placements)
if self.tensor_meta is not None:
tensor_shape = str(tuple(self.tensor_meta.shape))
else:
tensor_shape = "unknown shape"
return f"Spec({placement_str} on {tensor_shape})"
@property
def shape(self) -> torch.Size:
if self.tensor_meta is None:
raise ValueError("tensor_meta is not set")
return self.tensor_meta.shape
@property
def stride(self) -> Tuple[int, ...]:
if self.tensor_meta is None:
raise ValueError("tensor_meta is not set")
return self.tensor_meta.stride
@property
def ndim(self) -> int:
if self.tensor_meta is None:
raise ValueError("tensor_meta is not set")
return len(self.tensor_meta.shape)
@property
def num_shards(self) -> int:
num_shards = 1
for i, placement in enumerate(self.placements):
if placement.is_shard():
num_shards *= self.mesh.size(i)
return num_shards
@property
def device_mesh(self) -> DeviceMesh:
# simple aliasing for the mesh field, makes some
# checks that mix DTensor/DTensorSpec easier
return self.mesh
@property
def dim_map(self) -> List[int]:
"""
dim_map is a property we derive from `placements` of
the distributed tensor. It simply returns a list of ints
where dim_map[i] denotes the sharding mapping to the mesh
dimension, and len(dim_map) == dist_tensor.ndim
dim_map[i] = -1: means tensor dim i is replicated on the mesh
dim_map[i] = j: means tensor dim i is sharded on mesh dim j
For example, for a dist tensor with shape
[18, 20, 30], device_mesh([0, 1, 2, 3]), and placements
[Shard(1)], the dim_map of this placement would be
[-1, 0, -1]. This representation is pretty helpful during
sharding propagation, where we can know exactly whether each
tensor dimension is sharded or not.
Note that if placements contains `_Partial`, we have to
explicitly deal with it, so that when we create a DTensorSpec
with dim_map, we could properly record the pending sums.
"""
# dims mapping of dist tensor sharding
# returns a list of length tensor ndim, where -1 represents replicate
# and an int >= 0 represents shard on that device mesh dim
r = [-1] * self.ndim
for i, placement in enumerate(self.placements):
if placement.is_shard():
shard_dim = cast(Shard, placement).dim
if r[shard_dim] > -1:
raise ValueError(
f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]},"
" DTensor operator implementation does not support things like hybrid"
" sharding strategies yet (i.e. [Shard(0), Shard(0)])"
)
r[shard_dim] = i
return r
@property
def num_shards_map(self) -> List[int]:
"""
num_shards_map is a property we derive from `placements` of
the distributed tensor. Unlike `dim_map`, `num_shards_map`
denotes how many shards each tensor dim has. Like `dim_map`:
len(num_shards_map) == dist_tensor.ndim
num_shards_map[i] = 1: means tensor dim i is not sharded
num_shards_map[i] = j: means tensor dim i has j shards in total
For example, we have a dist tensor of shape [18, 20, 30],
a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements
([Shard(1), Shard(0)]), the num_shards_map of this distributed tensor
would be: [4, 2, 1].
"""
r = [1] * self.ndim
for i, placement in enumerate(self.placements):
if placement.is_shard():
shard_dim = cast(Shard, placement).dim
r[shard_dim] *= self.mesh.size(i)
return r
@property
def sums(self) -> List[int]:
"""
sums is a property we derive from `placements` of the
distributed tensor. It simply returns a list of the mesh dims
that carry a pending sum (Partial placement).
"""
return [
idx
for idx, placement in enumerate(self.placements)
if placement.is_partial()
]
@classmethod
def from_dim_map(
cls,
mesh: DeviceMesh,
dim_map: List[int],
sums: List[int],
tensor_meta: Optional[TensorMeta] = None,
) -> "DTensorSpec":
"""
Construct a DTensorSpec from dim_map list and pending sum.
Args:
mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec
dim_map (List[int]): a list of integers that represents the sharding on each
tensor dimension, see the `dim_map` property doc for details
sums (List[int]): a list of integers that represents the device mesh dimensions
on which the dist tensor has pending sums.
tensor_meta (TensorMeta): DTensor metadata
Return:
a class:`DTensorSpec` object
"""
# by default replicate on device mesh dims
placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)]
# find all mesh dims that need pending reductions
for s in sums:
placements[s] = Partial()
for i, m in enumerate(dim_map):
if m >= 0:
placement = placements[m]
if placement.is_shard():
placement = cast(Shard, placement)
raise RuntimeError(
f"DeviceMesh dimension cann't be mapped to two dimension of the same tensor: {i} and {placement.dim}"
)
elif placement.is_partial():
raise RuntimeError(
f"DeviceMesh dimension {m} cannot be both shard and partial!"
)
placements[m] = Shard(i)
return cls(mesh, tuple(placements), tensor_meta=tensor_meta)
def is_replicated(self) -> bool:
"""
return True if the current DTensorSpec replicates on all mesh dims (devices)
"""
return all(placement.is_replicate() for placement in self.placements)
def is_sharded(self) -> bool:
"""
return True if the current DTensorSpec is sharded on any mesh dims (devices)
"""
return any(placement.is_shard() for placement in self.placements)
def shallow_copy_with_tensor_meta(
self, tensor_meta: Optional[TensorMeta]
) -> "DTensorSpec":
"""
Shallow copy the DTensorSpec with a new tensor_meta.
"""
assert tensor_meta is not None, "shallow copy with no tensor_meta!"
return DTensorSpec(
self.mesh,
self.placements,
tensor_meta=tensor_meta,
)
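# --- Illustrative sketch (not part of this file) ---
# Round trip between placements and dim_map, assuming a 1-D cpu DeviceMesh over
# 4 ranks (e.g. launched via `torchrun --nproc_per_node=4`): a dist tensor sharded
# as [Shard(1)] has dim_map [-1, 0, -1], and from_dim_map recovers the placements.
if __name__ == "__main__":
    from torch.distributed.device_mesh import init_device_mesh

    _mesh = init_device_mesh("cpu", (4,))
    _spec = DTensorSpec.from_dim_map(_mesh, dim_map=[-1, 0, -1], sums=[])
    assert _spec.placements == (Shard(1),)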

View File

@@ -0,0 +1,457 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch._ops import OpOverload
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.placement_types import Placement
try:
from torch.utils._cxx_pytree import tree_leaves, tree_map_only, TreeSpec
except ImportError:
from torch.utils._pytree import ( # type: ignore[no-redef, assignment]
tree_leaves,
tree_map_only,
TreeSpec,
)
# Common type aliases
ArgsType = Tuple[object, ...]
KwargsType = Dict[str, object]
PlacementList = List[Optional[Placement]]
# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so the output type should
# be the same set of possibilities.
OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]
def _rebuild_tensor_from_dtensor_meta(arg) -> object:
"""
This is used to propagate tensor metadata; it must run under fake mode
"""
assert arg.tensor_meta is not None, "DTensorSpec does not contain tensor_meta."
return torch.empty_strided(
arg.tensor_meta.shape,
arg.tensor_meta.stride,
dtype=arg.tensor_meta.dtype,
)
def _is_inplace_op(op: OpOverload):
# simple analysis of function schema to determine
# if this is an inplace variant, it might not
# be entirely correct, but it's good enough for now.
return op._schema.name[-1] == "_"
def _is_out_variant_op(op: OpOverload):
# simple analysis of function schema to determine
# if this is an out variant, it might not
# be entirely correct, but it's good enough for now.
return "out" in op._schema.overload_name
def _pretty_print_spec(spec: object) -> str:
if spec is None:
return "None"
elif isinstance(spec, DTensorSpec):
return "".join([str(p) for p in spec.placements])
elif isinstance(spec, Sequence):
return "(" + ", ".join([_pretty_print_spec(s) for s in spec]) + ")"
else:
raise RuntimeError(f"Unknown spec type to print: spec={spec}")
@dataclass
class PlacementStrategy:
"""
A placement strategy describes acceptable sharding placements of the output
and the tensor arguments of an operation.
note: when the op return value is a single DTensor object, output_specs is
DTensorSpec; when the return value is a tuple of Optional[DTensor],
output_specs is a tuple of Optional[DTensorSpec].
"""
output_specs: Union[DTensorSpec, Tuple[Optional[DTensorSpec], ...]]
input_specs: Optional[Sequence[DTensorSpec]] = None
# redistribute costs for this op placement strategy
# we need a nested list to record the cost for each
# operand of this operator, and for each operand of
# this operator it might have multiple placement strategies
redistribute_cost: Optional[List[List[float]]] = None
@cached_property
def output_spec(self) -> DTensorSpec:
"""
This function requires that the strategy have exactly one DTensorSpec as the
output spec. If the output_specs is a tuple, we throw an exception.
"""
if isinstance(self.output_specs, DTensorSpec):
return self.output_specs
else:
raise ValueError(
f"function output_spec expects a single DTensorSpec but got: {self.output_specs}"
)
def input_spec(self, index: int = 0) -> DTensorSpec:
assert self.input_specs is not None, "input_specs of PlacementStrategy is None!"
assert len(self.input_specs) > index, (
f"Invalid index {index} for input_specs of length "
f"{len(self.input_specs)}: {self.input_specs}"
)
return self.input_specs[index]
def __str__(self) -> str:
if self.input_specs is not None:
input_specs_str = f"{_pretty_print_spec(self.input_specs)} -> "
else:
input_specs_str = ""
output_spec_str = _pretty_print_spec(self.output_specs)
return f"{input_specs_str}{output_spec_str}"
class StrategyType:
"""
Base class type for op strategy. We have two StrategyType subclasses:
OpStrategy and TupleStrategy
"""
class OpStrategy(StrategyType):
"""
OpStrategy that consists of a list of placement strategies associated with the op
"""
def __init__(self, strategies: List[PlacementStrategy]) -> None:
super().__init__()
self.strategies: List[PlacementStrategy] = strategies
def __str__(self) -> str:
strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
mesh_shape = self.mesh_shape
return f"[{strategy_list_str}] @ mesh: {mesh_shape}"
def max_num_shards(self) -> int:
"""
Returns the max number of shards across all placement strategies
"""
return max(strategy.output_spec.num_shards for strategy in self.strategies)
@property
def mesh_shape(self):
output_spec = self.strategies[0].output_specs
if isinstance(output_spec, DTensorSpec):
return output_spec.mesh.shape
else:
assert isinstance(
output_spec, tuple
), "found no DTensorSpec in the OpStrategy!"
assert output_spec[0] is not None
return output_spec[0].mesh.shape
@property
def ndim(self):
return self.strategies[0].output_spec.ndim
@property
def shape(self):
return self.strategies[0].output_spec.shape
class TupleStrategy(StrategyType):
"""
TupleStrategy represents that the output strategy of this op is a tuple
of strategies, i.e. if the output of this op is a tuple of tensors or a list of tensors
with possibly different placement strategies, we should return a TupleStrategy that
contains a tuple of OpStrategy, where each child represents the sharding strategy
of "each element" of the tuple/list of tensors the op returns.
NOTE: if the output of the op is a List[Tensor] and its elements share the same placement
strategy, then we should return a single OpStrategy instead of a TupleStrategy
"""
def __init__(self, childs: Sequence[StrategyType]) -> None:
super().__init__()
self.childs: Sequence[StrategyType] = childs
def __str__(self) -> str:
child_strategies_str = ", ".join(
[f"{str(strat)}" for idx, strat in enumerate(self.childs)]
)
return f"TupleStrategy({child_strategies_str})"
@dataclass
class RuntimeSchemaInfo:
"""
RuntimeSchemaInfo stores the operator schema related information for runtime (eager)
execution. This is mainly used in two ways: 1. to generate a hash for args to determine
whether to re-run sharding prop or not; 2. to determine if we need pytree
"""
# This static_argnum records static arg "starting index" for ops that have non-tensor
# args/kwargs which would affect sharding propagation results. All args starting from
# this index would be hashed to our sharding cache.
# Note that only a few ops need this information, e.g. view, transpose, var.dim, etc.
static_argnum: int = 100
# This static_kwargkey records static kwarg names which would affect sharding prop
static_kwargkey: Optional[List[str]] = None
# each op can decide if it wants to use pytree flatten/unflatten during operator
# eager execution; by default we don't do flatten/unflatten unless the
# op indicates it needs to. This is to accelerate eager performance.
needs_pytree: bool = False
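# --- Illustrative sketch (not part of this file) ---
# Hypothetical schema info for an op like transpose(input, dim0, dim1): the non-tensor
# dims start at positional index 1 and affect sharding propagation, so every arg from
# that index on is hashed; a "dim" kwarg would be hashed too if listed in static_kwargkey.
_example_schema_info = RuntimeSchemaInfo(static_argnum=1, static_kwargkey=["dim"], needs_pytree=False)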
@dataclass
class OpSchema:
"""
OpSchema is a data class that describes an operator's input schema; it includes
DTensorSpecs (instead of DTensors) and non-tensor args/kwargs (positional order
preserved). It is mainly used by DTensor's dispatching logic to perform various
actions (i.e. sharding propagation, caching sharding decisions, redistribute, etc.)
NOTE: this should be used as a read only data class
TODO: make this a frozen dataclass
Args:
op: the operator overload we are intercepting
args_schema: contains the args except that the DTensor args have been replaced
with their DTensorSpec or OpStrategy
kwargs_schema: contains the kwargs except that the DTensor kwargs have been replaced
with their DTensorSpec or OpStrategy
"""
op: OpOverload
args_schema: ArgsType
kwargs_schema: KwargsType
schema_info: Optional[RuntimeSchemaInfo] = None
@property
def args_spec(self) -> Tuple[DTensorSpec, ...]:
"""
args_spec: Tuple[DTensorSpec, ...]: contains a clean tuple of args specs
with NO non-DTensor positional arguments (i.e. int/float/tuple, etc)
mainly used by sharding propagation to propagate the output spec
"""
args = (
tree_leaves(self.args_schema)
if self.schema_info is not None and self.schema_info.needs_pytree
else self.args_schema
)
return tuple(item for item in args if isinstance(item, DTensorSpec))
@property
def args_strategy(self) -> Tuple[OpStrategy, ...]:
# filter out non-relevant values from args schema to get a clean OpStrategy list
# separate with args_spec for the ease of type annotation
# TODO: see if we should merge this with args_spec
args = (
tree_leaves(self.args_schema)
if self.schema_info is not None and self.schema_info.needs_pytree
else self.args_schema
)
return tuple(item for item in args if isinstance(item, OpStrategy))
def __repr__(self) -> str:
args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema])
return (
f"OpSchema(op={self.op},"
f" args_schema=({args_schema}),"
f" kwargs_schema={self.kwargs_schema})"
)
def __str__(self) -> str:
args_schema: List[str] = []
mesh_shape = None
for arg in self.args_schema:
if isinstance(arg, DTensorSpec):
args_schema.append(str(arg))
mesh_shape = arg.mesh.shape
elif isinstance(arg, OpStrategy):
assert len(arg.strategies) == 1
args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs))
mesh_shape = arg.mesh_shape
elif isinstance(arg, TupleStrategy):
first_op_strtgy = arg.childs[0]
assert isinstance(first_op_strtgy, OpStrategy)
mesh_shape = first_op_strtgy.mesh_shape
args_schema.append(str(arg))
else:
args_schema.append(str(arg))
return f"Op(op={self.op}, args_schema={', '.join(args_schema)} @ mesh: {mesh_shape})"
def __post_init__(self) -> None:
has_symints = False
for a in self.args_schema:
if isinstance(a, DTensorSpec) and a.tensor_meta is not None:
if any(isinstance(s, torch.SymInt) for s in a.tensor_meta.shape):
has_symints = True
break
self.has_symints = has_symints
def arg_type_tensor_or_tensor_list_like(self, arg_idx: int) -> bool:
arg = self.args_schema[arg_idx]
is_tensor = isinstance(arg, DTensorSpec)
if is_tensor:
return True
if not isinstance(arg, list):
return False
return all(isinstance(e, DTensorSpec) or e is None for e in arg)
def return_type_tuple_tensor_like(self) -> bool:
# all dispatch ops could only return Tuple[Tensor] or have None/ints/floats
# in the tuple, but the first element must be a Tensor, so this check is enough
return_types = self.op._schema.returns
return len(return_types) > 1 and isinstance(
return_types[0].type, torch.TensorType
)
def return_type_tensor(self) -> bool:
return_types = self.op._schema.returns
# all dispatch ops only return Tensor or Tuple[Tensor] for tensor like
# return types, so this check is enough for tensor like types
return isinstance(return_types[0].type, torch.TensorType)
def __hash__(self) -> int:
# Only hash args and kwargs that op indicates to hash
if not self.schema_info:
static_argnum = len(self.args_schema)
static_kwargkey = None
else:
static_argnum = self.schema_info.static_argnum
static_kwargkey = self.schema_info.static_kwargkey
args_to_hash = tuple(
tuple(e) if isinstance(e, list) else e
for i, e in enumerate(self.args_schema)
if self.arg_type_tensor_or_tensor_list_like(i) or i >= static_argnum
)
if static_kwargkey is not None:
kwargs_to_hash = tuple(
self.kwargs_schema.get(k, None) for k in static_kwargkey
)
return hash((self.op, args_to_hash, kwargs_to_hash))
else:
return hash((self.op, args_to_hash))
def __eq__(self, other: object) -> bool:
# early return checks
if not isinstance(other, OpSchema):
return False
if self.op != other.op:
return False
if len(self.args_schema) != len(other.args_schema):
return False
# compare each element and early return if any of them is different
if not self.schema_info:
static_argnum = len(self.args_schema)
static_kwargkey = None
else:
static_argnum = self.schema_info.static_argnum
static_kwargkey = self.schema_info.static_kwargkey
for i, (self_arg, other_arg) in enumerate(
zip(self.args_schema, other.args_schema)
):
if isinstance(self_arg, DTensorSpec) and self_arg != other_arg:
return False
elif i >= static_argnum and self_arg != other_arg:
return False
# check kwarg equality when there's a static kwarg key
if static_kwargkey:
for key in static_kwargkey:
if self.kwargs_schema.get(key, None) != other.kwargs_schema.get(
key, None
):
return False
return True
def gen_fake_args(self) -> ArgsType:
"""
gen_fake_args: generate fake args for the operator, this is mainly used
by sharding propagation rules to generate fake args for the operator
to run the local tensor operator and get the output spec.
"""
return tree_map_only(
DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.args_schema
)
def gen_fake_kwargs(self) -> KwargsType:
"""
gen_fake_kwargs: generate fake kwargs for the operator, this is mainly used
by sharding propagation rules to generate fake kwargs for the operator
to run the local tensor operator and get the output spec.
"""
return tree_map_only(
DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.kwargs_schema
)
def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
suggestion_args_spec = self.args_spec
new_arg_schema: List[object] = []
idx_of_args_spec = 0
if (
origin_schema.schema_info is not None
and origin_schema.schema_info.needs_pytree
):
args_schema: Sequence[Any] = tree_leaves(origin_schema.args_schema)
else:
args_schema = origin_schema.args_schema
for arg in args_schema:
if isinstance(arg, DTensorSpec):
new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
idx_of_args_spec += 1
else:
new_arg_schema.append(arg)
self.args_schema = tuple(new_arg_schema)
self.kwargs_schema = origin_schema.kwargs_schema
@dataclass
class OutputSharding:
"""
OutputSharding is a data class that is used by the sharding propagation,
it could set the output_spec upon successful propagation. If needs_redistribute
is set to True, a redistribute_schema would be returned together to indicate
the input arguments need to be redistributed before the op execution.
NOTE: the redistribute_schema generated by sharding propagation should be
exactly the same as the operator OpSchema, except the DTensorSpecs
"""
output_spec: OutputSpecType
redistribute_schema: Optional[OpSchema] = None
needs_redistribute: bool = False
@dataclass
class OpInfo:
"""
All runtime op execution info is packed here
"""
mesh: DeviceMesh
schema: OpSchema
flat_args_schema: List[object]
local_args: Sequence[object]
local_kwargs: Dict[str, object]
args_tree_spec: Optional[TreeSpec] = None
# the output sharding info
output_sharding: Optional[OutputSharding] = None

View File

@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from ._conv_ops import * # noqa: F403
from ._embedding_ops import * # noqa: F403
from ._experimental_ops import * # noqa: F403
from ._math_ops import * # noqa: F403
from ._matrix_ops import * # noqa: F403
from ._pointwise_ops import * # noqa: F403
from ._random_ops import * # noqa: F403
from ._tensor_ops import * # noqa: F403
from ._view_ops import * # noqa: F403

View File

@@ -0,0 +1,288 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import cast, Dict, List, Optional, Tuple
import torch
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
_is_inplace_op,
_is_out_variant_op,
OpSchema,
OutputSharding,
)
from torch.distributed.tensor._ops.utils import prod
from torch.distributed.tensor._utils import compute_local_shape
def _replace_char_in_str(string: str, new_char: str, idx: int) -> str:
return string[:idx] + new_char + string[idx + 1 :]
def _gen_reshard_suggestions(
op_schema: OpSchema,
input_dims: List[str],
input_specs: Tuple[DTensorSpec, ...],
dim_to_sharding: Dict[str, int],
pending_sum: List[int],
) -> OutputSharding:
suggested_arg_specs: List[DTensorSpec] = []
for input_dim, input_spec in zip(input_dims, input_specs):
dim_map = [dim_to_sharding[dim] for dim in input_dim]
suggested_arg_specs.append(
DTensorSpec.from_dim_map(
mesh=input_spec.mesh,
dim_map=dim_map,
sums=pending_sum,
tensor_meta=input_spec.tensor_meta,
)
)
suggested_schema = OpSchema(op_schema.op, tuple(suggested_arg_specs), {})
suggested_schema._inplace_rewrap_schema_suggestion(op_schema)
return OutputSharding(
None,
redistribute_schema=suggested_schema,
)
def einop_rule(
equation: str,
op_schema: OpSchema,
*,
linearity: bool = False,
enforce_sharding: Optional[Dict[str, int]] = None,
) -> OutputSharding:
"""
Propagate the sharding of inputs to output for ops whose data moves according to einsum notation.
This is mostly borrowed from @zdevito's sharding simulator. Examples:
mk,kn->mn - einsum
ij,ij->ij - addition
ij,j->ij - broadcasted addition
ij->i - reduction
Other ops could use this propagation algorithm when applicable; note
that einsum propagation only deals with lists of specs (DTensor specs)
as it only works on lists of tensors!
linearity in einop_rule means that the calling op `f` follows this rule:
f(a + b) = f(a) + f(b)
In this case we can propagate the partial sum, note that linearity in einop
only applies to partial sum, not other operations like min/max (which are
associative but not linear).
"""
# parse einop equation and extract arg specs
inputs, outputs = equation.split("->")
input_dims, output_dims = inputs.split(","), outputs.split(",")
input_specs = op_schema.args_spec
# NOTE: only support single output unless needed in future
output_dim = output_dims[0]
dim_to_sharding: Dict[str, int] = {}
dim_to_size: Dict[str, int] = {}
# record pending sum, key is mesh dimension, value is pending sum
# counter across input specs
pending_sums_counter: Dict[int, int] = {}
seen_shardings: Dict[int, str] = {}
needs_reshard = False
def merge_sharding(dim: str, a: int, b: int) -> int:
# merge the sharding of inputs if possible, i.e. we can merge
# replicate and shard to shard, but this will trigger a reshard operation
if a != b:
if a == -1 or b == -1:
# reshard the replicate to match the sharded one
nonlocal needs_reshard
needs_reshard = True
return a if a != -1 else b
else:
# TODO: further merge the sharding properly (i.e. reshard one input to replicate)
raise RuntimeError(
f"{equation}: dim {dim} sharded two different ways: {a} and {b}"
)
else:
return a
for input_dim, input_spec in zip(input_dims, input_specs):
# deal with partial sums
input_sums = input_spec.sums
for sum_dim in input_sums:
if sum_dim not in pending_sums_counter:
seen_shardings[sum_dim] = "+"
# update pending sum counter for pending sum mesh
# dimension with the occurrence from each input
pending_sums_counter[sum_dim] = pending_sums_counter.get(sum_dim, 0) + 1
for idx, (dim, mesh_dim) in enumerate(zip(input_dim, input_spec.dim_map)):
if enforce_sharding and dim in enforce_sharding:
if enforce_sharding[dim] != mesh_dim:
needs_reshard = True
dim_to_sharding[dim] = enforce_sharding[dim]
dim_to_size[dim] = input_spec.shape[idx]
elif dim not in dim_to_sharding:
dim_to_sharding[dim] = mesh_dim
dim_to_size[dim] = input_spec.shape[idx]
else:
dim_to_sharding[dim] = merge_sharding(
dim, dim_to_sharding[dim], mesh_dim
)
assert dim_to_size[dim] == input_spec.shape[idx]
            # after merging the sharding, we check whether multiple tensor dims
            # are sharded on the same mesh dim.
merged_sharding_for_dim = dim_to_sharding[dim]
if merged_sharding_for_dim != -1:
if (
merged_sharding_for_dim in seen_shardings
and dim != seen_shardings[merged_sharding_for_dim]
):
needs_reshard = True
seen_shardings[merged_sharding_for_dim] += dim
else:
seen_shardings[merged_sharding_for_dim] = dim
if pending_sums_counter and not linearity:
        # return a reshard suggestion with no pending sum; since we have already
        # properly merged the sharding, this reshard suggestion is legit to use
return _gen_reshard_suggestions(
op_schema, input_dims, input_specs, dim_to_sharding, []
)
else:
        # It's an op that supports linearity, but not all input arguments are
        # partial; we fail the sharding propagation with a suggestion to make
        # all inputs partial on the corresponding mesh dim (all inputs should be
        # partial for the mesh dims in order to execute locally and delay the
        # sum reduction)
for value in pending_sums_counter.values():
if value != len(input_specs):
needs_reshard = True
for mesh_dim, dims in seen_shardings.items():
if len(dims) > 1:
            # we found different input dims being sharded on the same mesh dim;
            # in order to perform the local op computation we need to reshard
            # some inputs. Based on a simple heuristic, we keep the sharding of
            # the dim whose sharded inputs carry the largest size, so the
            # smaller inputs (least comm volume) get resharded to replicate.
            # TODO: consider a more advanced heuristic to pick the best sharding
costs = []
for d in dims:
cost = 0
for input_dim, input_spec in zip(input_dims, input_specs):
if (
d in input_dim
and input_spec.dim_map[input_dim.index(d)] == mesh_dim
):
assert input_spec.tensor_meta is not None
global_shape = input_spec.tensor_meta.shape
local_shape = compute_local_shape(
global_shape, input_spec.mesh, input_spec.placements
)
cost += prod(local_shape) * input_spec.mesh.size(mesh_dim)
costs.append(cost)
d_to_keep_sharding = dims[costs.index(max(costs))]
for d in dims:
                # update dim_to_sharding to keep the sharding of the dim with the
                # highest comm cost and make the rest of the dims replicate
if d != d_to_keep_sharding:
dim_to_sharding[d] = -1
pending_sums = list(pending_sums_counter.keys())
if needs_reshard:
return _gen_reshard_suggestions(
op_schema, input_dims, input_specs, dim_to_sharding, pending_sums
)
# generate output pending sum if a dim is sharded, and it appears in input
# but not output
for dim, shard_on_mesh in dim_to_sharding.items():
if dim not in output_dims[0] and shard_on_mesh != -1:
pending_sums.append(shard_on_mesh)
# if no need to reshard, we directly generate the output sharding
output_dim_map = []
output_shape = []
for dim in output_dim:
if dim == "1":
# find output dim that is a singleton dimension, mark sharding and shape
output_dim_map.append(-1)
output_shape.append(1)
else:
output_dim_map.append(dim_to_sharding[dim])
output_shape.append(dim_to_size[dim])
# XXX: since we still need to have intermediate shape calculation, we need
# to pass in the shape here. We should remove this once sharding decomp works
# for ops like addmm
assert input_specs[0].tensor_meta is not None
tensor_meta = TensorMeta(
torch.Size(output_shape),
input_specs[0].tensor_meta.stride,
input_specs[0].tensor_meta.dtype,
)
return OutputSharding(
DTensorSpec.from_dim_map(
input_specs[0].mesh,
output_dim_map,
pending_sums,
tensor_meta=tensor_meta,
)
)
def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputSharding:
"""
Propagate the sharding for pointwise operations.
Examples:
ij,ij->ij - addition/mul
ij,j->ij - broadcasted addition
"""
alphabet = "abcdefghijklmnopqrstuvwxyz"
    # find the max_dim first in case we need broadcasting
input_specs = op_schema.args_spec
max_dim = max(input.ndim for input in input_specs)
dimchars = []
singleton_counter: List[int] = [0] * max_dim
for input in input_specs:
start_dim = max_dim - input.ndim
p = alphabet[start_dim:max_dim]
        # handle the "broadcasting to a common shape" case
        # see https://pytorch.org/docs/stable/notes/broadcasting.html
        # If any of the dimensions is a singleton dimension (i.e. 1),
        # we mark its dim char as a special "1" to distinguish it from
        # the non-singleton dimensions, so that sharding propagation
        # simply ignores the singleton dimension.
if len(input_specs) > 1:
for i in range(max_dim):
if i < start_dim:
                    # treat the leading missing dim chars as singletons
singleton_counter[i] += 1
elif input.shape[i - start_dim] == 1:
# mark singleton dim char as a special "1" in einop rule
singleton_counter[i] += 1
p = _replace_char_in_str(p, "1", (i - start_dim))
dimchars.append(p)
out_dimchars = alphabet[:max_dim]
    # check whether a dim char was replaced with the singleton marker in all
    # inputs; if so, we also need to replace it in the output dimension chars.
for output_dim_idx in range(len(out_dimchars)):
out_dimchar = out_dimchars[output_dim_idx]
if singleton_counter[output_dim_idx] == len(input_specs):
out_dimchars = _replace_char_in_str(out_dimchars, "1", output_dim_idx)
fmt = f"{','.join(p for p in dimchars)}->{out_dimchars}"
enforce_sharding: Dict[str, int] = {}
if _is_inplace_op(op_schema.op):
# inplace op should keep the input sharding it writes to
for out_dimchar, mesh_dim in zip(out_dimchars, input_specs[0].dim_map):
enforce_sharding[out_dimchar] = mesh_dim
elif _is_out_variant_op(op_schema.op):
out_spec = cast(DTensorSpec, op_schema.kwargs_schema["out"])
for out_dimchar, mesh_dim in zip(out_dimchars, out_spec.dim_map):
enforce_sharding[out_dimchar] = mesh_dim
return einop_rule(
fmt,
op_schema,
linearity=linearity,
enforce_sharding=enforce_sharding,
)

View File

@ -0,0 +1,110 @@
# mypy: allow-untyped-decorators
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement convolution related ops for distributed tensor
from typing import List
import torch
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpSchema, OutputSharding
from torch.distributed.tensor._ops.utils import register_prop_rule
aten = torch.ops.aten
@register_prop_rule(aten.convolution.default)
def convolution_rules(op_schema: OpSchema) -> OutputSharding:
(
input_spec,
weight_spec,
bias_spec,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
) = op_schema.args_schema
assert isinstance(input_spec, DTensorSpec)
assert isinstance(weight_spec, DTensorSpec)
assert isinstance(bias_spec, DTensorSpec)
assert input_spec.tensor_meta is not None
assert weight_spec.tensor_meta is not None
in_shape = input_spec.tensor_meta.shape
weight_shape = weight_spec.tensor_meta.shape
assert isinstance(stride, List)
assert isinstance(padding, List)
assert isinstance(dilation, List)
assert isinstance(weight_shape, torch.Size)
N, C_in, H_in, W_in = in_shape[0], in_shape[1], in_shape[2], in_shape[3]
C_out = weight_shape[0]
H_out = (H_in + 2 * padding[0] - dilation[0] * (weight_shape[2] - 1) - 1) // stride[
0
] + 1
W_out = (W_in + 2 * padding[1] - dilation[1] * (weight_shape[3] - 1) - 1) // stride[
1
] + 1
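    # e.g. for H_in = 32, a 3x3 kernel, padding 1, dilation 1, stride 1:
    # H_out = (32 + 2 * 1 - 1 * (3 - 1) - 1) // 1 + 1 = 32 (shape preserved)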
output_shape = [N, C_out, H_out, W_out]
output_stride = (C_out * H_out * W_out, H_out * W_out, W_out, 1)
output_dim_map = input_spec.dim_map
pending_sums = input_spec.sums
tensor_meta = TensorMeta(
torch.Size(output_shape),
output_stride,
input_spec.tensor_meta.dtype,
)
return OutputSharding(
DTensorSpec.from_dim_map(
input_spec.mesh,
output_dim_map,
pending_sums,
tensor_meta=tensor_meta,
)
)
@register_prop_rule(aten.convolution_backward.default)
def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
(
grad_output_spec,
input_spec,
weight_spec,
bias_shape_opt,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
output_mask,
) = op_schema.args_schema
assert isinstance(grad_output_spec, DTensorSpec)
assert isinstance(input_spec, DTensorSpec)
assert isinstance(weight_spec, DTensorSpec)
assert isinstance(bias_shape_opt, List)
assert input_spec.tensor_meta is not None
weight_tensor_meta = weight_spec.tensor_meta
bias_tensor_meta = TensorMeta(
torch.Size(bias_shape_opt),
(1,),
input_spec.tensor_meta.dtype,
)
grad_input_spec = input_spec
grad_weight_spec = DTensorSpec.from_dim_map(
input_spec.mesh,
[-1, -1, -1, -1],
[0],
tensor_meta=weight_tensor_meta,
)
grad_bias_spec = DTensorSpec.from_dim_map(
input_spec.mesh,
[-1],
[0],
tensor_meta=bias_tensor_meta,
)
return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec])

View File

@ -0,0 +1,181 @@
import itertools
from dataclasses import dataclass
from typing import List, Set, Tuple
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy
from torch.distributed.tensor.placement_types import (
Partial,
Placement,
Replicate,
Shard,
)
@dataclass
class EinsumDims:
contracting_dims: List[str]
batch_dims: List[str]
lhs_out_only_dims: List[str]
rhs_out_only_dims: List[str]
@classmethod
def parse_equation(cls, equation: str) -> Tuple[List[str], str]:
        """
        Parse the einsum equation string into the per-input dim chars and the output dim chars
        """
inputs, outputs = equation.split("->")
input_dims, output_dims = inputs.split(","), outputs.split(",")
# NOTE: only support at most two inputs, and single output
# extend to support more inputs if needed in future
assert len(input_dims) <= 2, "Only support at most two inputs"
assert len(output_dims) == 1, "Only support single output"
output_dim = output_dims[0]
return input_dims, output_dim
@classmethod
def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims":
"""
Parse the dims and extract the contracting, batch, and free dimensions
for the left and right hand sides.
"""
dim_char_set: Set[str] = set()
for input_dim in input_dims:
dim_char_set.update(input_dim)
        # get a deterministic order of all dim chars
all_dim_chars = sorted(dim_char_set)
# parse input and output dimensions
lhs_out_only_dims, rhs_out_only_dims = [], []
batch_dims, contracting_dims = [], []
for dim_char in all_dim_chars:
if dim_char not in output_dim:
contracting_dims.append(dim_char)
else:
is_batch_dim = True
for input_dim in input_dims:
is_batch_dim = is_batch_dim and dim_char in input_dim
if is_batch_dim:
batch_dims.append(dim_char)
else:
assert (
len(input_dims) == 2
), "free dimension only supported for two inputs!"
lhs, rhs = input_dims
if dim_char in lhs:
lhs_out_only_dims.append(dim_char)
elif dim_char in rhs:
rhs_out_only_dims.append(dim_char)
else:
raise RuntimeError("Invalid dimension character")
return cls(
contracting_dims=contracting_dims,
batch_dims=batch_dims,
lhs_out_only_dims=lhs_out_only_dims,
rhs_out_only_dims=rhs_out_only_dims,
)
def gen_einsum_strategies(
equation: str,
mesh: DeviceMesh,
*,
linearity: bool = False,
) -> OpStrategy:
"""
Generate a strategy list for the ops that follow einsum style notation.
"""
# parse einop equation and extract dims
input_dims, output_dim = EinsumDims.parse_equation(equation)
edims = EinsumDims.parse_dims(input_dims, output_dim)
all_mesh_dim_strategies = []
# generate strategies for each mesh dim
for mesh_dim in range(mesh.ndim):
mesh_dim_strategies = []
# placement list stores placements of [output, input1, input2, ...]
# first we always have replicate all for inputs and output
placement_list: List[Placement] = [Replicate()] * (len(input_dims) + 1)
mesh_dim_strategies.append(placement_list)
if mesh.size(mesh_dim) <= 1:
# only replicate strategy for mesh dim with size 1
# TODO: see if this is valid for the submesh case
continue
# split batch dim
for batch_dim in edims.batch_dims:
output_batch_dim = output_dim.index(batch_dim)
placement_list = [Shard(output_batch_dim)]
for input_dim in input_dims:
input_batch_dim = input_dim.index(batch_dim)
placement_list.append(Shard(input_batch_dim))
mesh_dim_strategies.append(placement_list)
# split contracting dim
for contracting_dim in edims.contracting_dims:
placement_list = [Partial()]
for input_dim in input_dims:
input_contracting_dim = input_dim.index(contracting_dim)
placement_list.append(Shard(input_contracting_dim))
mesh_dim_strategies.append(placement_list)
# split lhs free dim
for lhs_dim in edims.lhs_out_only_dims:
lhs_free_dim = output_dim.index(lhs_dim)
# this means split the lhs input and output
# i.e. S(0), R -> S(0)
lhs_placement_list: List[Placement] = [
Shard(lhs_free_dim),
Shard(lhs_free_dim),
Replicate(),
]
mesh_dim_strategies.append(lhs_placement_list)
# split rhs free dim
for rhs_dim in edims.rhs_out_only_dims:
rhs_free_dim = output_dim.index(rhs_dim)
rhs_placement_list: List[Placement] = [
Shard(rhs_free_dim),
Replicate(),
Shard(rhs_free_dim),
]
mesh_dim_strategies.append(rhs_placement_list)
# linearity strategy
if linearity:
linearity_placement_list: List[Placement] = [Partial()]
for input_dim in input_dims:
linearity_placement_list.append(Partial())
mesh_dim_strategies.append(linearity_placement_list)
all_mesh_dim_strategies.append(mesh_dim_strategies)
# generate strategies for entire mesh
strategy_combs = itertools.product(*all_mesh_dim_strategies)
    # TODO: filter out invalid strategies; at this point we generate
    # all possible strategies without considering whether the tensor
    # dim could be sharded or not, we would need to filter out invalid
    # strategies based on the actual tensor shape
    # (i.e. for Shard, tensor dim size must be > mesh size)
all_strategies = []
for strategy_comb in strategy_combs:
spec_list = []
for specs in zip(*strategy_comb):
spec_list.append(DTensorSpec(mesh, tuple(specs)))
strat = PlacementStrategy(output_specs=spec_list[0], input_specs=spec_list[1:])
all_strategies.append(strat)
return OpStrategy(all_strategies)

View File

@ -0,0 +1,274 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement embedding related ops for distributed tensor
from dataclasses import dataclass, field
from typing import cast, Optional
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._op_schema import (
OpSchema,
OpStrategy,
PlacementList,
StrategyType,
)
from torch.distributed.tensor._ops.utils import (
expand_to_full_mesh_op_strategy,
register_op_strategy,
)
from torch.distributed.tensor.placement_types import (
Partial,
Placement,
Replicate,
Shard,
)
aten = torch.ops.aten
@dataclass
class MaskBuffer:
data: Optional[torch.Tensor] = None
# refcount allows shared usage of the MaskBuffer, as long as all users have the same data
refcount: int = 0
def materialize_mask(self, mask):
if self.refcount == 0:
self.data = mask
else:
assert self.data is not None
if not torch.equal(self.data, mask):
raise RuntimeError(
"MaskBuffer has been materialized with conflicting data"
)
self.refcount += 1
def release_mask(self):
if self.refcount == 0 or self.data is None:
raise RuntimeError("MaskBuffer has not been materialized")
self.refcount -= 1
if self.refcount == 0:
self.data = None
def apply_mask(self, tensor):
if self.refcount == 0 or self.data is None:
raise RuntimeError("MaskBuffer has not been materialized")
        # NOTE: _MaskPartial is being used by the embedding op and the gather op.
        # For gather, the mask has the same dimension as the output tensor, whereas
        # the output of the embedding op has an additional dimension compared to the
        # input, hence the two cases in the output masking logic below.
if tensor.ndim == self.data.ndim:
tensor[self.data] = 0.0
else:
tensor[self.data, :] = 0.0
@dataclass(frozen=True)
class _MaskPartial(Partial):
"""
    A partial mask placement devised for the rowwise sharded embedding op, where we
    need to mask and adjust the indices to the local embedding shard; embedding
    masking is a special type of the Partial placement.
NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor
lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor.
"""
mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)
# required fields for computing the local offset and deriving the mask
offset_shape: Optional[torch.Size] = None
offset_dim: int = 0
def _partition_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
# override parent logic to perform partial mask for embedding
num_chunks = mesh.size(mesh_dim)
# get local shard size and offset on the embedding_dim
assert (
self.offset_shape is not None
), "offset_shape needs to be set for _MaskPartial"
local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim(
self.offset_shape[self.offset_dim],
num_chunks,
mesh.get_local_rank(mesh_dim),
return_offset=True,
)
# Build the input mask and save it for the current partial placement
# this is so that the output of embedding op can reuse the same partial
# placement saved mask to perform mask + reduction
mask = (tensor < local_offset_on_dim) | (
tensor >= local_offset_on_dim + local_shard_size
)
# mask the input tensor
masked_tensor = tensor.clone() - local_offset_on_dim
masked_tensor[mask] = 0
# materialize the mask buffer to be used for reduction
self.mask_buffer.materialize_mask(mask)
return masked_tensor
def _reduce_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
        # by the time we need the reduction, we should have already saved the mask
assert self.mask_buffer.data is not None
        # apply the mask to the tensor that is pending reduction
self.mask_buffer.apply_mask(tensor)
# clear the mask buffer
self.mask_buffer.release_mask()
# perform sum reduction
return funcol.all_reduce(
tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim)
)
def _reduce_shard_value(
self,
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
shard_spec: Placement,
) -> torch.Tensor:
        # by the time we need the reduction, we should have already saved the mask
assert self.mask_buffer.data is not None
        # apply the mask to the tensor that is pending reduction
self.mask_buffer.apply_mask(tensor)
# clear the mask buffer
self.mask_buffer.release_mask()
# call reduce_shard_tensor of the shard_spec.
shard_spec = cast(Shard, shard_spec)
return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)
def __eq__(self, other: object) -> bool:
if not isinstance(other, _MaskPartial):
return False
# if either data is not None, we invalidate the sharding cache, as this indicates
# the current MaskPartial placement is still in use and should not be used for cache hit.
if self.mask_buffer.data is not None or other.mask_buffer.data is not None:
return False
return (
self.reduce_op == other.reduce_op
and self.offset_shape == other.offset_shape
and self.offset_dim == other.offset_dim
)
def __hash__(self) -> int:
return 1 + hash(
(
self.reduce_op,
self.offset_shape,
self.offset_dim,
)
)
def __repr__(self) -> str:
"""
machine readable representation of the MaskPartial placement
"""
return f"_MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})"
def __str__(self) -> str:
"""
human readable representation of the MaskPartial placement
"""
return "MaskP"
@register_op_strategy(aten.embedding.default)
def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
"""
This strategy handles embedding op. We have two possible embedding shardings:
rowwise and colwise
"""
weight_strategy = cast(OpStrategy, op_schema.args_schema[0])
indices_strategy = cast(OpStrategy, op_schema.args_schema[1])
weight_shape = weight_strategy.shape
indices_shape = indices_strategy.shape
output_emd_dim = len(indices_shape)
single_mesh_dim_strategies = []
# placement list stores placements of [output, weight, input_indices]
# first we always have replicate all for inputs and output
all_replicate: PlacementList = [Replicate()] * 3
single_mesh_dim_strategies.append(all_replicate)
# colwise sharding, output shard on last dim, weight shard on dim 1, input replicate
colwise_sharding: PlacementList = [Shard(output_emd_dim), Shard(1), Replicate()]
single_mesh_dim_strategies.append(colwise_sharding)
# rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial
embedding_partial_placement = _MaskPartial(offset_shape=weight_shape, offset_dim=0)
    # NOTE we want to reuse the same mask partial placement so that we can reuse the
    # same mask generated from the input indices and use it for the output reduction
rowwise_sharding: PlacementList = [
embedding_partial_placement,
Shard(0),
embedding_partial_placement,
]
single_mesh_dim_strategies.append(rowwise_sharding)
# batch dim sharding, weight replicated, input can shard on any dim, output follows input
for input_dim in range(len(indices_shape)):
batch_sharding: PlacementList = [
Shard(input_dim),
Replicate(),
Shard(input_dim),
]
single_mesh_dim_strategies.append(batch_sharding)
return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies)
@register_op_strategy(aten.embedding_dense_backward.default)
def embedding_dense_backward_strategy(
mesh: DeviceMesh, op_schema: OpSchema
) -> StrategyType:
"""
This strategy handles embedding op. We have two possible embedding shardings:
rowwise and colwise
"""
grad_out_strategy = cast(OpStrategy, op_schema.args_schema[0])
indices_strategy = cast(OpStrategy, op_schema.args_schema[1])
grad_out_shape = grad_out_strategy.shape
indices_shape = indices_strategy.shape
grad_out_ndim = len(grad_out_shape)
single_mesh_dim_strategies = []
# placement list stores placements of [output, weight, input_indices]
# first we always have replicate all for inputs and output
all_replicate: PlacementList = [Replicate()] * 3
single_mesh_dim_strategies.append(all_replicate)
# colwise sharding backward, grad_out shard on last dim, input replicate,
# weight grad shard colwise
colwise_sharding: PlacementList = [Shard(1), Shard(grad_out_ndim - 1), Replicate()]
single_mesh_dim_strategies.append(colwise_sharding)
# batch dim sharding, weight replicated, grad_out/input have same sharding
# that can shard on any dim, weight grad partial
for input_dim in range(len(indices_shape)):
batch_sharding: PlacementList = [Partial(), Shard(input_dim), Shard(input_dim)]
single_mesh_dim_strategies.append(batch_sharding)
# grad_out partial, input replicate, weight grad keep partial
partial_sharding: PlacementList = [Partial(), Partial(), Replicate()]
single_mesh_dim_strategies.append(partial_sharding)
return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies)

View File

@ -0,0 +1,28 @@
# mypy: allow-untyped-decorators
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement experimental ops for distributed tensor
import torch
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
OpSchema,
OpStrategy,
PlacementStrategy,
StrategyType,
)
from torch.distributed.tensor._ops.utils import register_op_strategy
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Replicate
aten = torch.ops.aten
@register_op_strategy(aten.slice_backward.default)
def slice_backward_rules(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
"""
slice_backward is a new_zeros + slice_scatter, we only allow replication
on the input/output for now since new_zeros would produce replication
"""
replicate_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim))
return OpStrategy([PlacementStrategy(replicate_spec)])

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,500 @@
# mypy: allow-untyped-decorators
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement matrix related ops for distributed tensor
from typing import List
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
OpSchema,
OpStrategy,
PlacementList,
PlacementStrategy,
RuntimeSchemaInfo,
)
from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies
from torch.distributed.tensor._ops.utils import (
expand_to_full_mesh_op_strategy,
generate_redistribute_costs,
infer_broadcast_dims_map,
is_tensor_shardable,
map_placements_after_broadcast,
register_op_strategy,
)
from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
aten = torch.ops.aten
@register_op_strategy(aten.t.default)
def transpose_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
self_strategy = op_schema.args_schema[0]
assert isinstance(self_strategy, OpStrategy)
transpose_strategies = []
for input_strategy in self_strategy.strategies:
input_spec = input_strategy.output_spec
# follow the input spec but transpose the Shard placements
output_placements = [
Shard(1 - p.dim) if isinstance(p, Shard) else p
for p in input_spec.placements
]
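        # e.g. for a 2-D tensor, Shard(0) becomes Shard(1) and vice versa,
        # while Replicate/Partial placements are kept as-is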
transpose_strategy = PlacementStrategy(
output_specs=DTensorSpec(
mesh=input_strategy.output_spec.mesh,
placements=tuple(output_placements),
),
input_specs=(input_strategy.output_spec,),
)
transpose_strategies.append(transpose_strategy)
return OpStrategy(strategies=transpose_strategies)
def _mm_like_strategy(
mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
) -> OpStrategy:
self_strategy, mat2_strategy = op_schema.args_schema
assert isinstance(self_strategy, OpStrategy)
assert isinstance(mat2_strategy, OpStrategy)
# generate all possible strategies for mm
mm_strategy = gen_einsum_strategies(mm_equation, mesh)
# filter out invalid strategies and associate costs
strategies = mm_strategy.strategies
filtered_strategies = []
for strtg in strategies:
assert strtg.input_specs is not None
self_spec = strtg.input_specs[0]
mat2_spec = strtg.input_specs[1]
if is_tensor_shardable(self_strategy.shape, self_spec) and is_tensor_shardable(
mat2_strategy.shape, mat2_spec
):
redistribute_cost = [
generate_redistribute_costs(self_strategy, self_spec),
generate_redistribute_costs(mat2_strategy, mat2_spec),
]
strtg.redistribute_cost = redistribute_cost
filtered_strategies.append(strtg)
mm_strategy.strategies = filtered_strategies
return mm_strategy
def _addmm_like_strategy(
mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
) -> OpStrategy:
self_strategy, mat1_strategy, mat2_strategy = op_schema.args_schema
assert isinstance(self_strategy, OpStrategy)
assert isinstance(mat1_strategy, OpStrategy)
assert isinstance(mat2_strategy, OpStrategy)
self_shape = self_strategy.shape
mm_out_shape = torch.Size(
[
mat2_strategy.shape[-1] if i == len(mat1_strategy.shape) - 1 else dim_size
for i, dim_size in enumerate(mat1_strategy.shape)
]
)
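    # e.g. for mat1 of shape (m, k) and mat2 of shape (k, n) this yields (m, n);
    # the self arg is broadcast against this mm output shape further below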
# generate all possible strategies for mm
mm_strategy = gen_einsum_strategies(mm_equation, mesh)
# filter out invalid strategies and associate costs
strategies = mm_strategy.strategies
filtered_strategies = []
for strtg in strategies:
        # construct new strategy by considering the self arg
assert strtg.input_specs is not None
mat1_spec = strtg.input_specs[0]
mat2_spec = strtg.input_specs[1]
out_spec = strtg.output_spec
# self arg's spec should follow the output of mm, but need
# to consider broadcast for the self arg
broadcast_dims_map = infer_broadcast_dims_map(mm_out_shape, self_shape)
self_placements = map_placements_after_broadcast(
out_spec.placements, mm_out_shape, broadcast_dims_map
)
self_spec = DTensorSpec(mesh=mesh, placements=self_placements)
if is_tensor_shardable(mat1_strategy.shape, mat1_spec) and is_tensor_shardable(
mat2_strategy.shape, mat2_spec
):
# update input specs with new self spec
strtg.input_specs = (self_spec, mat1_spec, mat2_spec)
# associate costs
redistribute_cost = [
generate_redistribute_costs(self_strategy, self_spec),
generate_redistribute_costs(mat1_strategy, mat1_spec),
generate_redistribute_costs(mat2_strategy, mat2_spec),
]
strtg.redistribute_cost = redistribute_cost
filtered_strategies.append(strtg)
mm_strategy.strategies = filtered_strategies
return mm_strategy
@register_op_strategy(aten.mm.default)
def mm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
return _mm_like_strategy("mk,kn->mn", mesh, op_schema)
@register_op_strategy(aten.addmm.default)
def addmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
return _addmm_like_strategy("mk,kn->mn", mesh, op_schema)
@register_op_strategy(aten.bmm.default)
def bmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
return _mm_like_strategy("bmk,bkn->bmn", mesh, op_schema)
@register_op_strategy(aten.baddbmm.default)
def baddmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
return _addmm_like_strategy("bmk,bkn->bmn", mesh, op_schema)
@register_op_strategy(
aten._scaled_dot_product_flash_attention.default, schema_info=RuntimeSchemaInfo(5)
)
def scaled_dot_product_flash_attention_strategy(
mesh: DeviceMesh, op_schema: OpSchema
) -> OpStrategy:
# NOTE: currently we only support some simple strategies to support tensor parallelism
# TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation
# as it involves: matmul, pointwise, reduction ops together.
return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5]
q_input_strategy = op_schema.args_schema[0]
assert isinstance(q_input_strategy, OpStrategy)
# assuming q/k/v have the same shape
qkv_shape = q_input_strategy.shape
single_mesh_dim_strategies = []
# placement list stores placements of [outputs, inputs]
    # in the sdpa case, we have 3 valid tensor outputs and 3 tensor inputs
# first we can always accept full replication for both inputs and outputs
all_replicate: PlacementList = [
Replicate(),
Replicate(),
None, # cum_seq_q
None, # cum_seq_k
None, # max_q
None, # max_k
None, # philox_seed
None, # philox_offset
Replicate(),
Replicate(),
Replicate(),
Replicate(),
]
single_mesh_dim_strategies.append(all_replicate)
# second we can accept the sharding pattern of tensor parallelism, which
# shard on the num of head dim
qkv_sharding = Shard(1) # num head dim
output_sharding = Shard(1) # num head dim
logsumexp_sharding = Shard(1) # num head dim
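    # (q/k/v here are assumed to be laid out as [batch, num_heads, seq_len, head_dim],
    # so Shard(1) is the heads dim used by tensor parallelism and Shard(2), used
    # below for Context Parallelism, is the sequence dim)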
if return_debug_mask:
debug_attn_mask_sharding: Placement = Shard(1) # num head dim
else:
# empty debug mask, replicated
debug_attn_mask_sharding = Replicate()
num_heads_dim_sharding: PlacementList = [
output_sharding,
logsumexp_sharding,
None, # cum_seq_q
None, # cum_seq_k
None, # max_q
None, # max_k
None, # philox_seed
None, # philox_offset
debug_attn_mask_sharding,
qkv_sharding,
qkv_sharding,
qkv_sharding,
]
single_mesh_dim_strategies.append(num_heads_dim_sharding)
# Context Parallelism: shards on the sequence dim
single_mesh_dim_strategies.append(
[
Shard(2), # output
Shard(2), # logsumexp
None, # cum_seq_q
None, # cum_seq_k
None, # max_q
None, # max_k
None, # philox_seed
None, # philox_offset
Shard(2), # debugattn
Shard(2), # q
Shard(2), # k
Shard(2), # v
]
)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=9
)
@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default)
def scaled_dot_product_flash_attention_backward_strategy(
mesh: DeviceMesh, op_schema: OpSchema
) -> OpStrategy:
q_input_strategy = op_schema.args_schema[1]
assert isinstance(q_input_strategy, OpStrategy)
# assuming q/k/v have the same shape
qkv_shape = q_input_strategy.shape
tensor_input_indices = [
i
for i, arg_spec in enumerate(op_schema.args_schema)
if isinstance(arg_spec, OpStrategy)
]
num_tensor_inputs = len(tensor_input_indices)
single_mesh_dim_strategies = []
# placement list stores placements of [outputs, inputs]
    # in the sdpa backward case, we have 3 tensor outputs and 6 to 10 tensor inputs
# first we can always accept full replication for both inputs and outputs
all_replicate: PlacementList = [Replicate()] * (3 + num_tensor_inputs)
single_mesh_dim_strategies.append(all_replicate)
# second we can accept the sharding pattern of tensor parallelism, which
# shard on the num of head dim
grad_output_sharding = Shard(1) # num head dim
qkv_sharding = Shard(1) # num head dim
output_sharding = Shard(1) # num head dim
logsumexp_sharding = Shard(1) # num head dim
grad_qkv_sharding = Shard(1) # num head dim
num_heads_dim_sharding: PlacementList = [
grad_qkv_sharding,
grad_qkv_sharding,
grad_qkv_sharding,
grad_output_sharding,
qkv_sharding,
qkv_sharding,
qkv_sharding,
output_sharding,
logsumexp_sharding,
]
# accept replicate on the rest tensor inputs, potentially
# cum_seq_q, cum_seq_k, philox_seed, philox_offset
# at indices 6, 7, 12, 13, respectively
num_heads_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6))
single_mesh_dim_strategies.append(num_heads_dim_sharding)
# Context Parallelism: shards on the sequence dim
seq_dim_sharding: PlacementList = [
Shard(2), # grad_q
Shard(2), # grad_k
Shard(2), # grad_v
Shard(2), # grad_output
Shard(2), # q
Shard(2), # k
Shard(2), # v
Shard(2), # output
Shard(2), # logsumexp
]
# accept replicate on the rest tensor inputs, potentially
# cum_seq_q, cum_seq_k, philox_seed, philox_offset
# at indices 6, 7, 12, 13, respectively
seq_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6))
single_mesh_dim_strategies.append(seq_dim_sharding)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=3
)
@register_op_strategy(aten.constant_pad_nd.default)
def constant_pad_nd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
    # TODO(d4l3k): implement a more correct strategy for constant_pad_nd
return OpStrategy(
[
PlacementStrategy(
output_specs=DTensorSpec(mesh, (Replicate(),)),
input_specs=(
DTensorSpec(mesh, (Replicate(),)),
DTensorSpec(mesh, (Replicate(),)),
),
redistribute_cost=[[1]],
)
]
)
@register_op_strategy(
aten._scaled_dot_product_efficient_attention.default,
schema_info=RuntimeSchemaInfo(4),
)
def scaled_dot_product_efficient_attention_strategy(
mesh: DeviceMesh, op_schema: OpSchema
) -> OpStrategy:
# NOTE: currently we only support some simple strategies to support tensor parallelism
q_input_strategy = op_schema.args_schema[0]
assert isinstance(q_input_strategy, OpStrategy)
# assuming q/k/v have the same shape
qkv_shape = q_input_strategy.shape
has_attn_bias = op_schema.args_schema[3] is not None
compute_log_sumexp = op_schema.args_schema[4]
single_mesh_dim_strategies: List[PlacementList] = []
# placement list stores placements of [outputs, inputs]
    # in the sdpa case, we have 2 valid tensor outputs and 3 or 4 tensor inputs
# first we can always accept full replication for both inputs and outputs
all_replicate: PlacementList = [
Replicate(),
Replicate(),
None,
None,
Replicate(),
Replicate(),
Replicate(),
]
if has_attn_bias:
all_replicate.append(Replicate()) # attn bias
# Context Parallelism: shards on the sequence dim
single_mesh_dim_strategies.append(
[
Shard(2), # output
Shard(2), # logsumexp
None, # philox_seed
None, # philox_offset
Shard(2), # q
Shard(2), # k
Shard(2), # v
]
)
single_mesh_dim_strategies.append(all_replicate)
# second we can accept the sharding pattern of tensor parallelism, which
# shard on the heads dimension
qkv_sharding = Shard(1)
output_sharding = Shard(1)
if compute_log_sumexp:
logsumexp_sharding: Placement = Shard(1)
else:
# empty logsumexp, replicated
logsumexp_sharding = Replicate()
num_heads_dim_sharding = [
output_sharding,
logsumexp_sharding,
None,
None,
qkv_sharding,
qkv_sharding,
qkv_sharding,
]
if has_attn_bias:
num_heads_dim_sharding.append(Shard(1))
single_mesh_dim_strategies.append(num_heads_dim_sharding)
return expand_to_full_mesh_op_strategy(
mesh,
op_schema,
single_mesh_dim_strategies,
input_index=4,
)
@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default)
def scaled_dot_product_efficient_attention_backward_strategy(
mesh: DeviceMesh, op_schema: OpSchema
) -> OpStrategy:
q_input_strategy = op_schema.args_schema[1]
assert isinstance(q_input_strategy, OpStrategy)
# assuming q/k/v have the same shape
qkv_shape = q_input_strategy.shape
has_attn_bias = op_schema.args_schema[4] is not None
tensor_input_indices = [
i
for i, arg_spec in enumerate(op_schema.args_schema)
if isinstance(arg_spec, OpStrategy)
]
single_mesh_dim_strategies = []
# placement list stores placements of [outputs, inputs]
    # in the sdpa backward case, we have 4 tensor outputs and 8 or 9 tensor inputs
# NOTE: Output sharding of grad_bias on heads dim if attn_bias is present;
# otherwise grad_bias will be empty and its DTensorSpec will be removed.
# first we can always accept full replication for both inputs and outputs
all_replicate: PlacementList = [Replicate()] * (12 + has_attn_bias)
if not has_attn_bias:
all_replicate[3] = None # grad bias is None if attn_bias is not present
single_mesh_dim_strategies.append(all_replicate)
# second we can accept the sharding pattern of tensor parallelism, which
# shard on the heads dimension
grad_output_sharding = Shard(1)
qkv_sharding = Shard(1)
output_sharding = Shard(1)
logsumexp_sharding = Shard(1)
grad_qkv_sharding = Shard(1)
grad_bias_sharding = Shard(1) if has_attn_bias else None
num_heads_dim_sharding: PlacementList = [
grad_qkv_sharding,
grad_qkv_sharding,
grad_qkv_sharding,
grad_bias_sharding,
grad_output_sharding,
qkv_sharding,
qkv_sharding,
qkv_sharding,
# the place for optional input attn_bias,
output_sharding,
logsumexp_sharding,
]
# input sharding of attn_bias on heads dim if present
if has_attn_bias:
num_heads_dim_sharding.insert(8, Shard(1))
# accept replicate on the rest scalar tensor inputs
# namely philox_seed and philox_offset
num_heads_dim_sharding.extend([Replicate(), Replicate()])
single_mesh_dim_strategies.append(num_heads_dim_sharding)
# Context Parallelism: shards on the sequence dim
seq_dim_sharding: PlacementList = [
Shard(2), # grad_q
Shard(2), # grad_k
Shard(2), # grad_v
Shard(1) if has_attn_bias else None, # grad_bias
Shard(2), # grad_output
Shard(2), # q
Shard(2), # k
Shard(2), # v
Shard(2), # output
Shard(2), # logsumexp
]
    # input sharding of attn_bias on heads dim if present
    if has_attn_bias:
        seq_dim_sharding.insert(8, Shard(1))
    # accept replicate on the rest scalar tensor inputs,
    # namely philox_seed and philox_offset
seq_dim_sharding.extend([Replicate(), Replicate()])
single_mesh_dim_strategies.append(seq_dim_sharding)
return expand_to_full_mesh_op_strategy(
mesh,
op_schema,
single_mesh_dim_strategies,
input_index=4,
)

View File

@ -0,0 +1,688 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import List, Sequence, Tuple
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
_is_inplace_op,
_is_out_variant_op,
OpSchema,
OpStrategy,
PlacementStrategy,
RuntimeSchemaInfo,
StrategyType,
TupleStrategy,
)
from torch.distributed.tensor._ops.utils import (
generate_redistribute_costs,
infer_broadcast_dims_map,
map_placements_after_broadcast,
normalize_dim,
register_op_strategy,
)
from torch.distributed.tensor.placement_types import (
Partial,
Placement,
Replicate,
Shard,
)
aten = torch.ops.aten
# leave the remaining pointwise_ops list here for convenience:
# the ops below are pointwise ops that are yet to be supported,
# and they might not be a complete list.
# pointwise_ops = [
# "fake_quantize_per_channel_affine",
# "fake_quantize_per_tensor_affine",
# "floor_divide", # floor_divide is deprecated
# "frexp", # multiple output pointwise op, need to add support
# "gradient", # need investigation on this op
# "imag", # complex data type only
# "quantized_batch_norm",
# "quantized_max_pool1d",
# "quantized_max_pool2d",
# "real", # complex data type only
# ]
linear_pointwise_ops = [
aten.div.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op.
aten.div_.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op.
aten.to.dtype,
aten.add.Tensor,
aten.add_.Tensor,
]
pointwise_ops = [
# please keep the entries below alphabetically sorted
aten.__ilshift__.Scalar,
aten.__ilshift__.Tensor,
aten.__irshift__.Scalar,
aten.__irshift__.Tensor,
aten.__lshift__.Scalar,
aten.__lshift__.Tensor,
aten.__rshift__.Scalar,
aten.__rshift__.Tensor,
aten._conj.default,
aten.abs.default,
aten.abs.out,
aten.abs_.default,
aten.acos.default,
aten.acos.out,
aten.acos_.default,
aten.acosh.default,
aten.acosh.out,
aten.acosh_.default,
aten.add.Scalar,
aten.add.out,
aten.add_.Scalar,
aten.addcdiv.default,
aten.addcdiv.out,
aten.addcdiv_.default,
aten.addcmul.default,
aten.addcmul.out,
aten.addcmul_.default,
aten.angle.default,
aten.angle.out,
aten.asin.default,
aten.asin.out,
aten.asin_.default,
aten.asinh.default,
aten.asinh.out,
aten.asinh_.default,
aten.atan.default,
aten.atan.out,
aten.atan2.default,
aten.atan2.out,
aten.atan2_.default,
aten.atan_.default,
aten.atanh.default,
aten.atanh.out,
aten.atanh_.default,
aten.bitwise_and.Scalar,
aten.bitwise_and.Scalar_Tensor,
aten.bitwise_and.Scalar_out,
aten.bitwise_and.Tensor,
aten.bitwise_and.Tensor_out,
aten.bitwise_and_.Scalar,
aten.bitwise_and_.Tensor,
aten.bitwise_left_shift.Scalar_Tensor,
aten.bitwise_left_shift.Tensor,
aten.bitwise_left_shift.Tensor_Scalar,
aten.bitwise_left_shift.Tensor_Scalar_out,
aten.bitwise_left_shift.Tensor_out,
aten.bitwise_left_shift_.Tensor,
aten.bitwise_left_shift_.Tensor_Scalar,
aten.bitwise_not.default,
aten.bitwise_not.out,
aten.bitwise_not_.default,
aten.bitwise_or.Scalar,
aten.bitwise_or.Scalar_Tensor,
aten.bitwise_or.Scalar_out,
aten.bitwise_or.Tensor,
aten.bitwise_or.Tensor_out,
aten.bitwise_or_.Scalar,
aten.bitwise_or_.Tensor,
aten.bitwise_right_shift.Scalar_Tensor,
aten.bitwise_right_shift.Tensor,
aten.bitwise_right_shift.Tensor_Scalar,
aten.bitwise_right_shift.Tensor_Scalar_out,
aten.bitwise_right_shift.Tensor_out,
aten.bitwise_right_shift_.Tensor,
aten.bitwise_right_shift_.Tensor_Scalar,
aten.bitwise_xor.Scalar,
aten.bitwise_xor.Scalar_Tensor,
aten.bitwise_xor.Scalar_out,
aten.bitwise_xor.Tensor,
aten.bitwise_xor.Tensor_out,
aten.bitwise_xor_.Scalar,
aten.bitwise_xor_.Tensor,
aten.ceil.default,
aten.ceil.out,
aten.ceil_.default,
aten.clamp.default,
aten.clamp.out,
aten.clamp_.default,
aten.clip.default,
aten.clip.out,
aten.clip_.default,
aten.conj_physical.default,
aten.conj_physical.out,
aten.conj_physical_.default,
aten.copysign.Scalar,
aten.copysign.Scalar_out,
aten.copysign.Tensor,
aten.copysign.out,
aten.copysign_.Scalar,
aten.copysign_.Tensor,
aten.cos.default,
aten.cos.out,
aten.cos_.default,
aten.cosh.default,
aten.cosh.out,
aten.cosh_.default,
aten.deg2rad.default,
aten.deg2rad.out,
aten.deg2rad_.default,
aten.digamma.default,
aten.digamma.out,
aten.digamma_.default,
aten.div.Tensor,
aten.div.Tensor_mode,
aten.div.out,
aten.div.out_mode,
aten.div_.Tensor,
aten.div_.Tensor_mode,
aten.eq.Tensor,
aten.eq.Tensor_out,
aten.eq.Scalar,
aten.eq.Scalar_out,
aten.erf.default,
aten.erf.out,
aten.erf_.default,
aten.erfc.default,
aten.erfc.out,
aten.erfc_.default,
aten.erfinv.default,
aten.erfinv.out,
aten.erfinv_.default,
aten.exp.default,
aten.exp.out,
aten.exp2.default,
aten.exp2.out,
aten.exp2_.default,
aten.exp_.default,
aten.expm1.default,
aten.expm1.out,
aten.expm1_.default,
aten.float_power.Scalar,
aten.float_power.Scalar_out,
aten.float_power.Tensor_Scalar,
aten.float_power.Tensor_Scalar_out,
aten.float_power.Tensor_Tensor,
aten.float_power.Tensor_Tensor_out,
aten.float_power_.Scalar,
aten.float_power_.Tensor,
aten.floor.default,
aten.floor.out,
aten.floor_.default,
aten.fmod.Scalar,
aten.fmod.Scalar_out,
aten.fmod.Tensor,
aten.fmod.Tensor_out,
aten.fmod_.Scalar,
aten.fmod_.Tensor,
aten.frac.default,
aten.frac.out,
aten.frac_.default,
aten.ge.Scalar,
aten.ge.Tensor,
aten.gelu.default,
    aten.gt.Scalar,
    aten.gt.Scalar_out,
    aten.gt.Tensor,
    aten.gt.Tensor_out,
aten.hypot.default,
aten.hypot.out,
aten.hypot_.default,
aten.i0.default,
aten.i0.out,
aten.i0_.default,
aten.igamma.default,
aten.igamma.out,
aten.igamma_.default,
aten.igammac.default,
aten.igammac.out,
aten.igammac_.default,
aten.isinf.default,
aten.isnan.default,
aten.isneginf.default,
aten.isneginf.out,
aten.isposinf.default,
aten.isposinf.out,
aten.ldexp.default,
aten.ldexp.out,
aten.ldexp_.default,
aten.lt.Tensor,
aten.lt.Tensor_out,
aten.lt.Scalar,
aten.lt.Scalar_out,
aten.le.Scalar,
aten.le.Tensor,
aten.lerp.Scalar,
aten.lerp.Scalar_out,
aten.lerp.Tensor,
aten.lerp.Tensor_out,
aten.lerp_.Scalar,
aten.lerp_.Tensor,
aten.lgamma.default,
aten.lgamma.out,
aten.lgamma_.default,
aten.log.default,
aten.log.out,
aten.log10.default,
aten.log10.out,
aten.log10_.default,
aten.log1p.default,
aten.log1p.out,
aten.log1p_.default,
aten.log2.default,
aten.log2.out,
aten.log2_.default,
aten.log_.default,
aten.logaddexp.default,
aten.logaddexp.out,
aten.logaddexp2.default,
aten.logaddexp2.out,
aten.logical_and.default,
aten.logical_and.out,
aten.logical_and_.default,
aten.logical_not.default,
aten.logical_not.out,
aten.logical_not_.default,
aten.logical_or.default,
aten.logical_or.out,
aten.logical_or_.default,
aten.logical_xor.default,
aten.logical_xor.out,
aten.logical_xor_.default,
aten.logit.default,
aten.logit.out,
aten.logit_.default,
aten.masked_fill.Scalar,
aten.maximum.out,
aten.mul.Scalar,
aten.mul.Tensor,
aten.mul.out,
aten.mul_.Scalar,
aten.mul_.Tensor,
aten.mvlgamma.default,
aten.mvlgamma.out,
aten.mvlgamma_.default,
aten.native_dropout_backward.default,
aten.native_dropout_backward.out,
aten.nan_to_num.default,
aten.nan_to_num.out,
aten.nan_to_num_.default,
aten.ne.Scalar,
aten.neg.default,
aten.neg.out,
aten.neg_.default,
aten.nextafter.default,
aten.nextafter.out,
aten.nextafter_.default,
aten.polygamma.default,
aten.polygamma.out,
aten.polygamma_.default,
aten.positive.default,
aten.pow.Scalar,
aten.pow.Scalar_out,
aten.pow.Tensor_Scalar,
aten.pow.Tensor_Scalar_out,
aten.pow.Tensor_Tensor,
aten.pow.Tensor_Tensor_out,
aten.pow_.Scalar,
aten.pow_.Tensor,
aten.reciprocal.default,
aten.reciprocal.out,
aten.reciprocal_.default,
aten.rad2deg.default,
aten.rad2deg.out,
aten.rad2deg_.default,
aten.relu.default,
aten.relu_.default,
aten.remainder.Scalar,
aten.remainder.Scalar_Tensor,
aten.remainder.Scalar_out,
aten.remainder.Tensor,
aten.remainder.Tensor_out,
aten.remainder_.Scalar,
aten.remainder_.Tensor,
aten.round.decimals,
aten.round.decimals_out,
aten.round.default,
aten.round.out,
aten.round_.decimals,
aten.round_.default,
aten.rsqrt.default,
aten.rsqrt.out,
aten.rsqrt_.default,
aten.rsub.Scalar,
aten.sgn.default,
aten.sgn.out,
aten.sgn_.default,
aten.sigmoid.default,
aten.sigmoid.out,
aten.sigmoid_.default,
aten.sign.default,
aten.sign.out,
aten.sign_.default,
aten.signbit.default,
aten.signbit.out,
aten.silu.default,
aten.silu.out,
aten.sin.default,
aten.sin.out,
aten.sin_.default,
aten.sinc.default,
aten.sinc.out,
aten.sinc_.default,
aten.sinh.default,
aten.sinh.out,
aten.sinh_.default,
aten.sqrt.default,
aten.sqrt.out,
aten.sqrt_.default,
aten.square.default,
aten.square.out,
aten.square_.default,
aten.sub.Scalar,
aten.sub.Tensor,
aten.sub.out,
aten.sub_.Scalar,
aten.sub_.Tensor,
aten.tan.default,
aten.tan.out,
aten.tan_.default,
aten.tanh.default,
aten.tanh.out,
aten.tanh_.default,
aten.true_divide.Tensor,
aten.trunc.default,
aten.trunc.out,
aten.trunc_.default,
aten.where.self,
aten.where.self_out,
aten.xlogy.OutScalar_Self,
aten.xlogy.OutScalar_Other,
aten.xlogy.OutTensor,
aten.xlogy.Scalar_Other,
aten.xlogy.Scalar_Self,
aten.xlogy.Tensor,
aten.xlogy_.Scalar_Other,
aten.xlogy_.Tensor,
# backward point-wise ops
# please keep the entries below alphabetically sorted
aten.gelu_backward.default,
aten.sigmoid_backward.default,
aten.silu_backward.default,
aten.tanh_backward.default,
aten.threshold_backward.default,
]
def pointwise_strategy(
mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False
) -> OpStrategy:
max_shards_strategy_index = -1
max_shards = -1
if _is_inplace_op(op_schema.op):
# inplace op should follow the first arg strategy
followed_strategy = op_schema.args_schema[0]
elif _is_out_variant_op(op_schema.op):
# out variant op should follow the out kwarg strategy
followed_strategy = op_schema.kwargs_schema["out"]
else:
        # normal pointwise op, we choose to follow the arg with
        # the max shards in case operands need resharding
for idx, arg_strategy in enumerate(op_schema.args_schema):
if not isinstance(arg_strategy, OpStrategy):
continue
arg_max_shards = arg_strategy.max_num_shards()
if arg_max_shards > max_shards:
max_shards_strategy_index = idx
max_shards = arg_max_shards
followed_strategy = op_schema.args_schema[max_shards_strategy_index]
assert isinstance(
followed_strategy, OpStrategy
), f"no strategy to follow for {op_schema}!"
return common_pointwise_strategy(
mesh, op_schema.args_schema, followed_strategy, linearity
)
def common_pointwise_strategy(
mesh: DeviceMesh,
args_schema: Sequence[object],
followed_strategy: OpStrategy,
linearity: bool,
) -> OpStrategy:
# handle broadcasting
common_shape = torch.broadcast_shapes(
*[arg.shape for arg in args_schema if isinstance(arg, OpStrategy)]
)
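    # e.g. torch.broadcast_shapes((8, 1), (4,)) returns torch.Size([8, 4])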
pointwise_strategy = OpStrategy([])
for placement_strategy in followed_strategy.strategies:
spec_to_follow = placement_strategy.output_spec
out_placements: List[Placement] = []
for placement in spec_to_follow.placements:
if isinstance(placement, Shard):
shard_dim = normalize_dim(placement.dim, len(spec_to_follow.shape))
common_ndim = len(common_shape)
new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim
out_placements.append(Shard(new_shard_dim))
elif isinstance(placement, Partial) and not linearity:
                # clear the partial placement if the op does not support linearity;
                # by default we just replicate the partial, need to see if this
                # is optimal for all cases
out_placements.append(Replicate())
else:
out_placements.append(placement)
input_specs: List[DTensorSpec] = []
redistribute_costs: List[List[float]] = []
for input_arg in args_schema:
if isinstance(input_arg, OpStrategy):
                # every arg follows the out_placements, but we need to handle broadcasting
input_arg_spec = input_arg.strategies[0].output_spec
input_arg_dims_map = infer_broadcast_dims_map(
common_shape, input_arg_spec.shape
)
input_target_placements = map_placements_after_broadcast(
tuple(out_placements),
common_shape,
input_arg_dims_map,
)
input_arg_target_spec = DTensorSpec(
mesh=mesh,
placements=input_target_placements,
tensor_meta=input_arg_spec.tensor_meta,
)
input_specs.append(input_arg_target_spec)
redistribute_costs.append(
generate_redistribute_costs(input_arg, input_arg_target_spec)
)
pointwise_strategy.strategies.append(
PlacementStrategy(
output_specs=DTensorSpec(
mesh=mesh,
placements=tuple(out_placements),
),
input_specs=input_specs,
redistribute_cost=redistribute_costs,
)
)
return pointwise_strategy
def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
"""
Linear pointwise operators can propagate pending reductions.
For example, c = add(a, b); if a is pending sum, then c will be
pending sum as well without any communication overhead.
"""
return pointwise_strategy(mesh, op_schema, linearity=True)
for op in linear_pointwise_ops:
register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
linear_pointwise_strategy
)
for op in pointwise_ops:
register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
pointwise_strategy
)
# TODO: add all for_each ops
for_each_ops = [
aten._foreach_abs.default,
aten._foreach_abs_.default,
aten._foreach_addcdiv_.Scalar,
aten._foreach_addcdiv_.ScalarList,
aten._foreach_addcdiv_.Tensor,
aten._foreach_addcmul.Scalar,
aten._foreach_addcmul_.Scalar,
aten._foreach_addcmul_.ScalarList,
aten._foreach_addcmul_.Tensor,
aten._foreach_clamp_max_.Scalar,
aten._foreach_clamp_min_.Scalar,
aten._foreach_div_.List,
aten._foreach_div_.Scalar,
aten._foreach_div_.ScalarList,
aten._foreach_div_.Tensor,
aten._foreach_div.List,
aten._foreach_div.Scalar,
aten._foreach_div.ScalarList,
aten._foreach_div.Tensor,
aten._foreach_lerp_.Scalar,
aten._foreach_maximum_.List,
aten._foreach_mul.Scalar,
aten._foreach_mul.ScalarList,
aten._foreach_mul.Tensor,
aten._foreach_mul.List,
aten._foreach_mul_.Scalar,
aten._foreach_mul_.ScalarList,
aten._foreach_mul_.Tensor,
aten._foreach_mul_.List,
aten._foreach_neg.default,
aten._foreach_neg_.default,
aten._foreach_reciprocal_.default,
aten._foreach_sub.Scalar,
aten._foreach_sub_.Scalar,
aten._foreach_sub.List,
aten._foreach_sub_.List,
aten._foreach_sub.ScalarList,
aten._foreach_sub_.ScalarList,
aten._foreach_sqrt.default,
aten._foreach_sqrt_.default,
aten._foreach_zero_.default,
aten._foreach_exp.default,
aten._foreach_exp_.default,
aten._foreach_cos.default,
aten._foreach_cos_.default,
aten._foreach_log.default,
aten._foreach_log_.default,
aten._amp_foreach_non_finite_check_and_unscale_.default,
]
for_each_linearity_ops = [
aten._foreach_add.Scalar,
aten._foreach_add_.Scalar,
aten._foreach_add_.ScalarList,
aten._foreach_add.List,
aten._foreach_add_.List,
]
def list_pointwise_strategy(
mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False
) -> StrategyType:
"""
Apply the pointwise strategy to the zipped arguments. For example, if we
run a foreach add of two lists l1 and l2, then we apply the pointwise
strategy on each pair (l1[i], l2[i]). If the first argument is a list but
the second (or later) one is a tensor, then we broadcast the tensor by
replicating it into a list with the length of the first argument.
Args:
mesh (DeviceMesh): device mesh for pointwise ops
op_schema (OpSchema): schema of the operator to generate strategy for
linearity (bool): specify whether op(a) + op(b) = op(a + b)
Returns:
OpStrategy: generated strategy
"""
def args_tuple_strategies(args_schema: Tuple[object, ...]) -> List[TupleStrategy]:
first_arg = args_schema[0]
assert isinstance(first_arg, TupleStrategy)
strategy_len = len(first_arg.childs)
tuple_strategies: List[TupleStrategy] = []
for arg_idx, arg in enumerate(args_schema):
if isinstance(arg, TupleStrategy):
# every tuple strategy should have the same length
assert len(arg.childs) == strategy_len
tuple_strategies.append(arg)
elif isinstance(arg, OpStrategy):
if arg_idx > 0: # implicitly broadcast
tuple_strategies.append(
TupleStrategy([arg for _ in range(strategy_len)])
)
else:
raise RuntimeError(
f"list op only supports tuple strategy! {op_schema}"
)
return tuple_strategies
args_strategies = args_tuple_strategies(op_schema.args_schema)
follow_strategy: TupleStrategy = args_strategies[0]
list_strategy: List[OpStrategy] = []
for child_idx, child_strtgy in enumerate(follow_strategy.childs):
assert isinstance(child_strtgy, OpStrategy)
args_schema: List[StrategyType] = [
arg_strategy.childs[child_idx] for arg_strategy in args_strategies
]
pointwise_strategy: OpStrategy = common_pointwise_strategy(
mesh, args_schema, child_strtgy, linearity
)
list_strategy.append(pointwise_strategy)
return TupleStrategy(list_strategy)
def list_linear_pointwise_strategy(
mesh: DeviceMesh, op_schema: OpSchema
) -> StrategyType:
"""
    foreach list op strategy that supports linearity
"""
return list_pointwise_strategy(mesh, op_schema, linearity=True)
for op in for_each_ops:
register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
list_pointwise_strategy
)
for op in for_each_linearity_ops:
register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
list_linear_pointwise_strategy
)
fused_ops = [
aten._fused_adam_.default,
aten._fused_adam.default,
aten._fused_adam.tensor_lr,
aten._fused_adam_.tensor_lr,
aten._fused_adamw_.default,
aten._fused_adamw.default,
aten._fused_adamw.tensor_lr,
aten._fused_adamw_.tensor_lr,
]
for op in fused_ops:
register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
list_pointwise_strategy
)

View File

@ -0,0 +1,38 @@
# mypy: allow-untyped-decorators
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._op_schema import (
OpSchema,
OpStrategy,
PlacementStrategy,
StrategyType,
)
from torch.distributed.tensor._ops.utils import is_tensor_partial, register_op_strategy
aten = torch.ops.aten
@register_op_strategy(
[
aten.normal_.default,
aten.uniform_.default,
aten.native_dropout.default,
aten.bernoulli_.float,
aten.bernoulli.default,
]
)
def random_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
self_strategy = op_schema.args_schema[0]
assert isinstance(self_strategy, OpStrategy)
random_strategy = OpStrategy([])
for arg_strategy in self_strategy.strategies:
arg_spec = arg_strategy.output_spec
if is_tensor_partial(arg_spec):
# TODO: figure out how inplace random op should behave when it's partial
raise RuntimeError(f"{op_schema.op} with Partial is not supported yet!")
random_strategy.strategies.append(PlacementStrategy(output_specs=arg_spec))
return random_strategy

View File

@ -0,0 +1,792 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import cast, List, Optional, Sequence, Tuple
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
_is_inplace_op,
OpSchema,
OpStrategy,
OutputSharding,
PlacementList,
PlacementStrategy,
RuntimeSchemaInfo,
StrategyType,
TupleStrategy,
)
from torch.distributed.tensor._ops._common_rules import pointwise_rule
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor._ops.utils import (
expand_to_full_mesh_op_strategy,
is_tensor_dim_sharded,
is_tensor_evenly_shardable,
is_tensor_partial,
normalize_dim,
register_op_strategy,
register_prop_rule,
)
from torch.distributed.tensor.placement_types import (
Partial,
Placement,
Replicate,
Shard,
)
aten = torch.ops.aten
def default_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
# Default strategy by default just propagate the first input strategy
select_strategy = op_schema.args_schema[0]
assert isinstance(select_strategy, OpStrategy)
default_strategy = []
for strategy in select_strategy.strategies:
# we create new DTensorSpecs even for default strategy to assure that
# the tensor metas are distinct between the arguments and outputs
default_strategy.append(
PlacementStrategy(
output_specs=DTensorSpec(
mesh=strategy.output_spec.mesh,
placements=strategy.output_spec.placements,
)
)
)
return OpStrategy(default_strategy)
register_op_strategy(
[
aten.clone.default,
aten.contiguous.default,
aten.copy_.default,
aten.detach.default,
aten.fill_.Scalar,
aten.zero_.default,
]
)(default_strategy)
register_op_strategy(
aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"])
)(default_strategy)
@register_op_strategy(
[
aten.equal.default,
aten.is_same_size.default,
]
)
def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
    # equal_strategy deals with ops that compare two tensors. We need to make sure the
    # sharding layout is the same for both operands, so we choose to follow the arg with
    # the max number of shards. We still keep is_same_size here for completeness, as the
    # two ops share the same strategy in theory.
self_strategy, other_strategy = op_schema.args_schema
assert isinstance(self_strategy, OpStrategy)
assert isinstance(other_strategy, OpStrategy)
select_strategy = (
self_strategy
if self_strategy.max_num_shards() >= other_strategy.max_num_shards()
else other_strategy
)
equal_strategy = OpStrategy([])
for arg_strategy in select_strategy.strategies:
arg_spec = arg_strategy.output_spec
if is_tensor_partial(arg_spec):
            # if the arg_spec has Partial, reshard to replicate
# otherwise local shard tensor comparison would be invalid
output_spec = DTensorSpec(
mesh=arg_spec.mesh,
placements=tuple(
Replicate() if isinstance(p, Partial) else p
for p in arg_spec.placements
),
)
equal_strategy.strategies.append(
PlacementStrategy(output_specs=output_spec)
)
else:
equal_strategy.strategies.append(PlacementStrategy(arg_spec))
return equal_strategy
@register_op_strategy(
[
aten.empty_like.default,
aten.ones_like.default,
aten.rand_like.default,
aten.randn_like.default,
aten.zeros_like.default,
],
schema_info=RuntimeSchemaInfo(1, ["dtype"]),
)
@register_op_strategy(
[aten.full_like.default],
schema_info=RuntimeSchemaInfo(2, ["dtype"]),
)
@register_op_strategy(
[
aten.randint_like.default,
aten.randint_like.low_dtype,
aten.randint_like.low_dtype_out,
],
schema_info=RuntimeSchemaInfo(3, ["dtype"]),
)
def create_like_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
    # create_like_strategy deals with ops that create tensors with the same
    # shape as the input, but with specific content that does not depend on
    # the input. We can propagate sharding, but we have to make sure we
    # move from partial to replicated.
select_strategy = op_schema.args_schema[0]
create_like_strategy = OpStrategy([])
assert isinstance(select_strategy, OpStrategy)
for arg_strategy in select_strategy.strategies:
arg_spec = arg_strategy.output_spec
if is_tensor_partial(arg_spec):
            # if the arg_spec has Partial, accept Partial
# in the input_specs but output replicate for
# those corresponding mesh dims
output_spec = DTensorSpec(
mesh=arg_spec.mesh,
placements=tuple(
Replicate() if isinstance(p, Partial) else p
for p in arg_spec.placements
),
)
create_like_strategy.strategies.append(
PlacementStrategy(output_specs=output_spec, input_specs=(arg_spec,))
)
else:
create_like_strategy.strategies.append(PlacementStrategy(arg_spec))
return create_like_strategy
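# Hedged illustration (not part of the original file) of the Partial -> Replicate
# handling above: if the output of ones_like stayed Partial, reducing it would
# yield world_size instead of 1. The mesh shape is an illustrative assumption.
def _example_ones_like_on_partial() -> None:
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, Partial

    mesh = init_device_mesh("cpu", (2,))
    x = DTensor.from_local(torch.randn(4, 4), mesh, [Partial()])
    ones = torch.ones_like(x)
    # the input stays Partial, but the output placement is Replicate, so
    # ones.full_tensor() is all ones rather than a sum across ranks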
@register_op_strategy(
[
aten.new_empty.default,
aten.new_full.default,
aten.new_ones.default,
aten.new_zeros.default,
aten.new_empty_strided.default,
],
schema_info=RuntimeSchemaInfo(1, ["dtype"]),
)
def new_factory_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
# Currently there are two strategies:
# 1. let the output be replicated
# 2. let the output follow the input if input and output have the same shape
input_strategy = op_schema.args_schema[0]
assert isinstance(input_strategy, OpStrategy)
input_shape = input_strategy.shape
output_shape = op_schema.args_schema[1]
assert isinstance(output_shape, list)
new_factory_strategy = OpStrategy([])
for arg_strategy in input_strategy.strategies:
input_spec = arg_strategy.output_spec
replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim))
new_factory_strategy.strategies.append(
PlacementStrategy(
output_specs=replica_spec,
input_specs=(input_spec,),
redistribute_cost=[[0.0] * mesh.ndim],
)
)
if tuple(input_shape) == tuple(output_shape) and input_spec.is_sharded():
# NOTE: for new_empty_strided, currently the non-replicate sharding
# is supported only when the shape is evenly shardable
if (
op_schema.op == aten.new_empty_strided.default
and not is_tensor_evenly_shardable(input_shape, input_spec)
):
continue
new_factory_strategy.strategies.append(
PlacementStrategy(
output_specs=input_spec,
input_specs=(input_spec,),
# encouraging new tensor placement to be the same as input
redistribute_cost=[[-0.1] * mesh.ndim],
)
)
return new_factory_strategy
@register_op_strategy(aten.bucketize.Tensor)
def gen_bucketize_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
"""Just propagate input sharding, but expect replicated for boundaries input."""
input_strategy = op_schema.args_schema[0]
bucketize_strategy = OpStrategy([])
assert isinstance(input_strategy, OpStrategy)
for arg_strategy in input_strategy.strategies:
arg_spec = DTensorSpec(mesh, arg_strategy.output_spec.placements)
replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim))
bucketize_strategy.strategies.append(
PlacementStrategy(
output_specs=arg_spec, input_specs=(arg_spec, replica_spec)
)
)
return bucketize_strategy
@register_op_strategy(aten.slice.Tensor, schema_info=RuntimeSchemaInfo(1))
def gen_slice_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
"""Forward all shardings except the slice dimension."""
defaults = (None, 0, None, None, 1)
input_strategy, dim, start, end, step = (
op_schema.args_schema + defaults[len(op_schema.args_schema) :]
)
assert isinstance(input_strategy, OpStrategy)
input_shape = input_strategy.shape
input_ndim = input_strategy.ndim
assert isinstance(dim, int)
if start is None:
start = 0
if end is None or end > input_shape[dim]:
end = input_shape[dim]
assert isinstance(start, int)
assert isinstance(end, int)
assert isinstance(step, int)
# normalize args
slice_dim = normalize_dim(dim, input_ndim)
start = normalize_dim(start, input_shape[dim])
end = normalize_dim(end, input_shape[dim])
redundant_slice = start == 0 and end == input_shape[dim] and step == 1
slice_strategy = OpStrategy([])
for arg_strategy in input_strategy.strategies:
arg_spec = arg_strategy.output_spec
if not is_tensor_dim_sharded(arg_spec, dim=slice_dim) or redundant_slice:
# only add the strategy if the slice dim is not sharded
out_spec = DTensorSpec(mesh, arg_spec.placements)
slice_strategy.strategies.append(PlacementStrategy(output_specs=out_spec))
if not slice_strategy.strategies:
        # if all strategies are filtered out, unshard all specs on the slice dim
        # of the input strategy and use that as the op strategy
for arg_strategy in input_strategy.strategies:
arg_spec = arg_strategy.output_spec
unshard_spec = DTensorSpec(
mesh, unshard_tensor_dim(arg_spec.placements, dim=slice_dim)
)
slice_strategy.strategies.append(
PlacementStrategy(output_specs=unshard_spec)
)
return slice_strategy
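# Hedged usage sketch (not part of the original file) of the slice behaviour
# above. The mesh shape is an illustrative assumption; run with 2 ranks.
def _example_slice_on_dtensor() -> None:
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import distribute_tensor, Shard

    mesh = init_device_mesh("cpu", (2,))
    x = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])
    y = x[:, 1:3]  # slicing a non-sharded dim: sharding is forwarded as-is
    z = x[2:6]     # slicing the sharded dim: falls back to unsharding dim 0 first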
def unshard_tensor_dim(
placements: Sequence[Placement], dim: int
) -> Tuple[Placement, ...]:
"""Disallow the given tensor dimension to be sharded."""
return tuple(
p if (not isinstance(p, Shard) or p.dim != dim) else Replicate()
for p in placements
)
def replicate_tensor_dim(
placements: Sequence[Placement], dim: int
) -> Tuple[Placement, ...]:
"""Force the given tensor dimension to be replicated."""
    # Not using p.is_shard() to avoid mypy complaining about Placement not having
    # attribute dim.
return tuple(
Replicate() if p.is_partial() or isinstance(p, Shard) and p.dim == dim else p
for p in placements
)
@register_op_strategy(aten.slice_scatter.default, schema_info=RuntimeSchemaInfo(2))
def gen_slice_scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
    # 1. the number of dimensions in input and src needs to match.
    # 2. the number of elements on all dims other than the scatter dim needs to match between input and src.
    # 3. the number of elements in src on the scatter dim needs to match the slice size.
# Given the above:
# - We suggest for src to follow the sharding of input, except on the scatter dimension,
# where our best bet for now is to make them replicated as a fall-back.
# TODO: Ideally we'd like to make sure the output is re-sharded afterwards to keep input sharding.
input_strategy = op_schema.args_schema[0]
assert isinstance(input_strategy, OpStrategy)
input_ndim = input_strategy.ndim
slice_dim = (
cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0
)
slice_dim = normalize_dim(slice_dim, input_ndim)
slice_scatter_strategy = OpStrategy([])
# by default follow the input strategy for both input and src
for arg_strategy in input_strategy.strategies:
arg_spec = arg_strategy.output_spec
if not (
is_tensor_dim_sharded(arg_spec, dim=slice_dim)
or is_tensor_partial(arg_spec)
):
# only add the strategy if the slice_scatter dim is not sharded or partial
slice_scatter_strategy.strategies.append(
PlacementStrategy(output_specs=arg_spec)
)
if not slice_scatter_strategy.strategies:
        # if all strategies are filtered out, replicate all specs on the slice_scatter dim
        # of the input strategy and use that as the op strategy
for arg_strategy in input_strategy.strategies:
arg_spec = arg_strategy.output_spec
replicate_spec = DTensorSpec(
mesh, replicate_tensor_dim(arg_spec.placements, dim=slice_dim)
)
slice_scatter_strategy.strategies.append(
PlacementStrategy(output_specs=replicate_spec)
)
return slice_scatter_strategy
@register_op_strategy(aten._local_scalar_dense.default)
def replica_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
"""Only allow replication on the input/output."""
replicate_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim))
return OpStrategy([PlacementStrategy(replicate_spec)])
@register_op_strategy(
[aten.scatter_.value, aten.scatter.value, aten.scatter_.src, aten.scatter.src],
schema_info=RuntimeSchemaInfo(1),
)
def scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
input_strategy = cast(OpStrategy, op_schema.args_schema[0])
single_mesh_dim_strategies = []
# placement list stores placements of [output, input, index, src]
# first we always have replicate all for inputs and output
if len(op_schema.args_strategy) < 3:
        # scatter_.src/scatter.src where src is a scalar number instead of a tensor
all_replicate: PlacementList = [Replicate()] * 3
else:
all_replicate = [Replicate()] * 4
single_mesh_dim_strategies.append(all_replicate)
# TODO: see if we can support input sharding pattern
inplace_op = _is_inplace_op(op_schema.op)
op_strategy = expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, inplace_op=inplace_op
)
return op_strategy
@register_op_strategy(aten.gather.default)
def gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
input_strategy = cast(OpStrategy, op_schema.args_schema[0])
dim = cast(int, op_schema.args_schema[1])
index_strategy = cast(OpStrategy, op_schema.args_schema[2])
input_shape = input_strategy.shape
index_shape = index_strategy.shape
single_mesh_dim_strategies = []
# placement list stores placements of [output, input, index]
# first we always have replicate all for inputs and output
all_replicate: PlacementList = [Replicate()] * 3
single_mesh_dim_strategies.append(all_replicate)
# input sharding, input sharded, index accepts mask partial, output follows index
# this only works when the input is sharded on the gather dimension, and
# index has size 1 on the gather dimension
if index_shape[dim] == 1:
index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
input_sharding: PlacementList = [
index_partial_placement,
Shard(dim),
index_partial_placement,
]
single_mesh_dim_strategies.append(input_sharding)
# index sharding, input replicated, index sharded, output follows index
# this only works when the sharding dimension is the gather dimension
index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)]
single_mesh_dim_strategies.append(index_sharding)
return expand_to_full_mesh_op_strategy(
mesh, op_schema, single_mesh_dim_strategies, input_index=1
)
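# Hedged usage sketch (not part of the original file): gather on a DTensor picks
# from the single-mesh-dim strategies listed above (replicate-all, input sharded
# on the gather dim with a size-1 index, or index sharded on the gather dim).
# The mesh shape and tensor sizes are illustrative assumptions; run with 2 ranks.
def _example_gather_on_dtensor() -> None:
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import distribute_tensor, Replicate, Shard

    mesh = init_device_mesh("cpu", (2,))
    x = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])
    index = distribute_tensor(torch.zeros(1, 4, dtype=torch.int64), mesh, [Replicate()])
    out = torch.gather(x, 0, index)  # index has size 1 on the gather dim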
def _derive_follow_placements_from_tuple_strategy(
tuple_strategy: TupleStrategy,
) -> Sequence[Placement]:
"""
    derive the placements to follow from the tuple strategy, mainly used by
    aten.stack and aten.cat, where each operand has the same shape and
    correspondingly expects the same sharding
"""
def merge_placement(
cur_placement: Placement, new_placement: Placement
) -> Placement:
        # semantics: if we already have a follow placement, we check it against
        # the current arg placement to see if we want to merge/adjust the
        # placement we follow, with the priority: Partial -> Shard -> Replicate
if cur_placement == new_placement:
return cur_placement
if cur_placement.is_partial():
if new_placement.is_shard():
# follow new placement
return new_placement
elif new_placement.is_partial():
# different partial types, we can't merge and have to replicate all here
return Replicate()
else:
# follow partial
return cur_placement
elif cur_placement.is_shard():
if new_placement.is_shard():
                # cur/new placements are different shardings (i.e. different shard dims),
                # so currently fall back to replicating all args
return Replicate()
else:
# for partial/replicate, follow the current shard placement
return cur_placement
else:
# current replicate, just follow new placement
return new_placement
follow_placements: Optional[List[Placement]] = None
for arg_strategy in tuple_strategy.childs:
assert isinstance(arg_strategy, OpStrategy)
for placement_strategy in arg_strategy.strategies:
arg_placements = placement_strategy.output_spec.placements
if follow_placements is None:
follow_placements = list(arg_placements)
continue
mesh_ndim = len(follow_placements)
assert follow_placements is not None
for mesh_idx in range(mesh_ndim):
# merge placements with the priority
follow_placements[mesh_idx] = merge_placement(
follow_placements[mesh_idx], arg_placements[mesh_idx]
)
assert follow_placements is not None, "follow placements should not be None!"
return follow_placements
def normalize_shard_for_stack(
placements: Sequence[Placement], insert_dim: int = 0
) -> Sequence[Placement]:
    # the stack op "inserts" a new dim, so all sharded dims >= the inserted dim need
    # to be normalized with a new Shard placement
normalized_placements: List[Placement] = []
for placement in placements:
if isinstance(placement, Shard) and placement.dim >= insert_dim:
normalized_placements.append(Shard(placement.dim + 1))
else:
normalized_placements.append(placement)
return normalized_placements
@register_op_strategy(aten.stack.default, RuntimeSchemaInfo(1, needs_pytree=True))
def stack_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
args_schema = op_schema.args_schema
input_tuple_strategy = args_schema[0]
assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}"
first_input_strategy = input_tuple_strategy.childs[0]
assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}"
common_input_ndim = first_input_strategy.ndim
dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
# normalize the dim to be within the common input ndim
dim = normalize_dim(dim, common_input_ndim)
follow_placements = _derive_follow_placements_from_tuple_strategy(
input_tuple_strategy
)
    # create op strategy based on the follow placements
op_strategy = OpStrategy([])
input_specs = tuple(
DTensorSpec(mesh, tuple(follow_placements))
for _ in range(len(input_tuple_strategy.childs))
)
follow_placements = normalize_shard_for_stack(follow_placements, dim)
op_strategy.strategies.append(
PlacementStrategy(
output_specs=DTensorSpec(mesh, tuple(follow_placements)),
input_specs=input_specs,
)
)
return op_strategy
@register_op_strategy(aten.cat.default, RuntimeSchemaInfo(1, needs_pytree=True))
def cat_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
args_schema = op_schema.args_schema
input_tuple_strategy = args_schema[0]
assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}"
first_input_strategy = input_tuple_strategy.childs[0]
assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}"
common_input_ndim = first_input_strategy.ndim
dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
# normalize the dim to be within the common input ndim
dim = normalize_dim(dim, common_input_ndim)
follow_placements = _derive_follow_placements_from_tuple_strategy(
input_tuple_strategy
)
# for cat we unshard the cat dim if it is sharded
follow_placements = unshard_tensor_dim(follow_placements, dim)
    # create op strategy based on the follow placements
op_strategy = OpStrategy([])
input_specs = tuple(
DTensorSpec(mesh, tuple(follow_placements))
for _ in range(len(input_tuple_strategy.childs))
)
op_strategy.strategies.append(
PlacementStrategy(
output_specs=DTensorSpec(mesh, tuple(follow_placements)),
input_specs=input_specs,
)
)
return op_strategy
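# Hedged usage sketch (not part of the original file): cat follows the merged
# placements of its operands and unshards the cat dim when needed. The mesh
# shape is an illustrative assumption; run with 2 ranks.
def _example_cat_on_dtensor() -> None:
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import distribute_tensor, Shard

    mesh = init_device_mesh("cpu", (2,))
    a = distribute_tensor(torch.randn(4, 8), mesh, [Shard(1)])
    b = distribute_tensor(torch.randn(4, 8), mesh, [Shard(1)])
    out = torch.cat([a, b], dim=0)  # cat dim is not sharded: Shard(1) is kept
    # torch.cat([a, b], dim=1) would instead unshard dim 1 before concatenating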
@register_prop_rule(aten.index_select.default, schema_info=RuntimeSchemaInfo(1))
def prop_index_select(op_schema: OpSchema) -> OutputSharding:
values_spec, dim, indices_spec = op_schema.args_schema
assert isinstance(values_spec, DTensorSpec)
assert isinstance(dim, int)
assert isinstance(indices_spec, DTensorSpec)
all_indices_spec: List[Optional[DTensorSpec]] = [
indices_spec if dim == i else None for i in range(values_spec.ndim)
]
result = prop_index(
OpSchema(
op=op_schema.op,
args_schema=(values_spec, all_indices_spec),
kwargs_schema=op_schema.kwargs_schema,
)
)
if result.redistribute_schema:
schema_suggestion = result.redistribute_schema
result.redistribute_schema = OpSchema(
op=op_schema.op,
args_schema=(
schema_suggestion.args_schema[0],
dim,
schema_suggestion.args_schema[1][dim],
),
kwargs_schema=op_schema.kwargs_schema,
)
return result
@register_prop_rule(aten.index.Tensor, schema_info=RuntimeSchemaInfo(needs_pytree=True))
def prop_index(op_schema: OpSchema) -> OutputSharding:
"""
Expect replicated on the first input; _mostly_ pointwise on the second input.
TODO: exception: when the dtype of second input is "bool", then a torch.nonzero needs to be triggered first.
"""
# Current sharding constraints:
# For values:
# 1. We currently require that the dimension of values_spec be replicated or partial
# if they are being indexed on.
# 2. Other dimensions of values_spec can remain sharded if they are so.
# For indices:
# Indices can be either sharded or replicated. All index tensors need to be sharded
# in a compatible way, following the pointwise rule (including resolving Partial
# into either sharded or replicated)
values_spec, multi_indices_spec = op_schema.args_schema
assert isinstance(values_spec, DTensorSpec)
assert isinstance(multi_indices_spec, list)
multi_indices_spec = cast(List[Optional[DTensorSpec]], multi_indices_spec)
valid_indices_spec: List[Tuple[int, DTensorSpec]] = [
(i, a) for i, a in enumerate(multi_indices_spec) if a is not None
]
# 1. All indices have to be sharded equally. Moreover, indices can be broadcast.
# Here, we piggyback on the pointwise sharding rule for indices.
indices_out = pointwise_rule(
OpSchema(
op=op_schema.op,
args_schema=tuple(v[1] for v in valid_indices_spec),
kwargs_schema={},
)
)
need_reshard_on_indices = indices_out.output_spec is None
if not need_reshard_on_indices:
# this means that our inputs are already sharded properly and we will use that as our indices_spec
assert isinstance(indices_out.output_spec, DTensorSpec)
indices_spec: DTensorSpec = indices_out.output_spec
else:
assert indices_out.redistribute_schema is not None
valid_indices_suggestion = indices_out.redistribute_schema
for i, v in enumerate(valid_indices_suggestion.args_spec):
multi_indices_spec[valid_indices_spec[i][0]] = v
# we'll need to call pointwise_rule again to see what's our ideal indices_spec and then
# use that to compute our ideal values_spec
indices_output_spec = pointwise_rule(valid_indices_suggestion).output_spec
assert isinstance(indices_output_spec, DTensorSpec)
indices_spec = indices_output_spec
lookup_dims = {v[0] for v in valid_indices_spec}
need_reshard_on_values = tuple(
(isinstance(vp, Shard) and (vp.dim in lookup_dims or isinstance(ip, Shard)))
for vp, ip in zip(values_spec.placements, indices_spec.placements)
)
if not need_reshard_on_indices and not any(need_reshard_on_values):
value_placements = values_spec.placements
all_dims_consecutive = all(
b[0] - a[0] == 1
for b, a in zip(valid_indices_spec[1:], valid_indices_spec[:-1])
)
if all_dims_consecutive:
            # if all index vectors are consecutive, insert at the dimension of the first index
insert_dim: int = valid_indices_spec[0][0]
else:
# else, insert on the first dimension
insert_dim = 0
def place(vp: Placement, ip: Placement) -> Placement:
if isinstance(vp, Shard):
return Shard(
vp.dim
if vp.dim < insert_dim
# accounts for the offset in output dimensions
else vp.dim
+ indices_spec.ndim
- sum(1 if vp.dim > v[0] else 0 for v in valid_indices_spec)
)
if isinstance(ip, Shard):
return Shard(ip.dim + insert_dim)
# Partial or Replicated
return vp
value_placements = tuple(
place(vp, ip)
for vp, ip in zip(values_spec.placements, indices_spec.placements)
)
result = OutputSharding(
output_spec=DTensorSpec(
mesh=values_spec.mesh,
placements=value_placements,
)
)
return result
else:
result = OutputSharding(
output_spec=None,
redistribute_schema=OpSchema(
op=op_schema.op,
args_schema=(
DTensorSpec(
mesh=values_spec.mesh,
placements=tuple(
[
Replicate() if need_reshard_on_values[i] else v
for i, v in enumerate(values_spec.placements)
]
),
tensor_meta=values_spec.tensor_meta,
),
multi_indices_spec,
),
kwargs_schema=op_schema.kwargs_schema,
),
)
return result
@register_prop_rule(
[
aten.split.Tensor,
aten.split_with_sizes.default,
aten.split_with_sizes_copy.default,
],
schema_info=RuntimeSchemaInfo(1),
)
def split_rule(op_schema: OpSchema) -> OutputSharding:
output_spec_list: List[DTensorSpec] = []
input_spec = cast(DTensorSpec, op_schema.args_schema[0])
ndim = input_spec.ndim
split_size_or_sections = op_schema.args_schema[1]
dim = cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0
dim = normalize_dim(dim, ndim)
# TODO: tensor to split cannot have Partial
# in its placements for now. Will need to
# support in future.
if input_spec.sums:
raise NotImplementedError(
f"splitting distributed tensor with "
f"Partial placement is not implemented!\n"
f"DTensorSpec={input_spec}"
)
# TODO: just like slice op, split replicates before
# splitting on a sharded dimension
need_reshard = False
if is_tensor_dim_sharded(input_spec, dim=dim):
need_reshard = True
input_spec = DTensorSpec(
mesh=input_spec.mesh,
placements=unshard_tensor_dim(input_spec.placements, dim=dim),
tensor_meta=input_spec.tensor_meta,
)
if need_reshard:
return OutputSharding(
None,
redistribute_schema=OpSchema(
op=op_schema.op,
args_schema=(input_spec,) + op_schema.args_schema[1:],
kwargs_schema=op_schema.kwargs_schema,
),
)
def size_split(N, i):
# Last chunk will be smaller if the tensor size N
# along the given dimension dim is not divisible by i.
assert i > 0
return [i] * (N // i) + ([N % i] if N % i != 0 else [])
output_size_list = (
size_split(input_spec.shape[dim], split_size_or_sections)
if isinstance(split_size_or_sections, int)
else split_size_or_sections
)
output_spec_list = [
DTensorSpec(
mesh=input_spec.mesh,
placements=input_spec.placements,
)
for _ in range(len(output_size_list))
]
return OutputSharding(output_spec_list)
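# Hedged usage sketch (not part of the original file): splitting along a
# non-sharded dim keeps the input sharding, while splitting along the sharded
# dim triggers the reshard suggested above. The mesh shape is an illustrative
# assumption; run with 2 ranks.
def _example_split_on_dtensor() -> None:
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import distribute_tensor, Shard

    mesh = init_device_mesh("cpu", (2,))
    x = distribute_tensor(torch.randn(8, 6), mesh, [Shard(0)])
    chunks = torch.split(x, 2, dim=1)     # outputs keep Shard(0)
    resharded = torch.split(x, 2, dim=0)  # dim 0 is sharded: input is replicated first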

View File

@ -0,0 +1,666 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from dataclasses import dataclass
from typing import (
Callable,
cast,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import torch
from torch import Tensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
OpSchema,
OpStrategy,
PlacementStrategy,
RuntimeSchemaInfo,
StrategyType,
)
from torch.distributed.tensor._ops.utils import (
generate_redistribute_costs,
normalize_dim,
normalize_dims,
prod,
register_op_strategy,
)
from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
aten = torch.ops.aten
Shape = Tuple[int, ...]
@dataclass
class DimSpec:
"""Specifies how an output dimension maps to an input dimension."""
def inputs(self) -> Iterable["DimSpec"]:
return ()
# Rules that map each dimension of the output to dimensions of the input tensor
DimMap = Tuple[DimSpec, ...]
@dataclass
class Singleton(DimSpec):
"""Output dimension is a singleton."""
@dataclass
class InputDim(DimSpec):
"""Output dimension maps directly to an input dimension."""
input_dim: int
@dataclass
class Broadcast(DimSpec):
"""Output is the broadcast of a singleton input dimension."""
dim: DimSpec
dim_size: int
@classmethod
def new(cls, dim: DimSpec, dim_size: int) -> DimSpec:
return Broadcast(dim, dim_size)
def inputs(self) -> Iterable[DimSpec]:
return (self.dim,)
@dataclass
class NewDim(DimSpec):
"""This is a new dimension created by the op."""
size: int
@classmethod
def new(cls, size: int) -> DimSpec:
return Singleton() if size == 1 else NewDim(size)
@dataclass
class Repeat(DimSpec):
"""Output dimension is the input dimension repeated n-times."""
input_dim: DimSpec
times: int
@classmethod
def new(cls, dim: DimSpec, times: int) -> DimSpec:
if times == 1:
return dim
elif isinstance(dim, Singleton):
# repeating a singleton is the same as broadcasting it
return Broadcast(dim, times)
else:
return Repeat(dim, times)
def inputs(self) -> Iterable[DimSpec]:
return (self.input_dim,)
@dataclass
class Flatten(DimSpec):
"""Flatten a set of input dimensions, ensuring right-most adjacent elements remain adjacent in the output."""
input_dims: Sequence[DimSpec]
@classmethod
def new(cls, dims: Sequence[DimSpec]) -> DimSpec:
if len(dims) == 0:
# flattening a scalar leads to a singleton
return Singleton()
elif len(dims) == 1:
# flattening a single dimension is no-op
return dims[0]
else:
return Flatten(dims)
def inputs(self) -> Iterable[DimSpec]:
return self.input_dims
@dataclass
class Split(DimSpec):
"""
This dimension is a member of a decomposition of the input dim.
Note that input_dim itself could be a Flattened set of input dims.
"""
input_dim: DimSpec
group_shape: Shape
split_id: int
@classmethod
def new(cls, dim: DimSpec, group_shape: Tuple[int, ...], idx: int) -> DimSpec:
assert len(group_shape) > 0
if len(group_shape) == 1:
# not really a group, just return the input dim back
assert idx == 0
return dim
elif group_shape[idx] == 1:
return Singleton()
else:
# remove singletons from group
# group_mapping = [(new_index, (shape, old_index)) ...]
group_mapping = list(
enumerate((s, i) for i, s in enumerate(group_shape) if s != 1)
)
new_group_shape = tuple(m[1][0] for m in group_mapping)
new_idx = next(filter(lambda x: x[1][1] == idx, group_mapping))[0]
return Split(dim, new_group_shape, new_idx)
def inputs(self) -> Iterable[DimSpec]:
return (self.input_dim,)
def dim_pad_left(ndim: int, min_dims: int) -> DimMap:
return (Singleton(),) * max(0, min_dims - ndim) + tuple(
InputDim(i) for i in range(ndim)
)
def dim_atleast_3d(ndim: int) -> DimMap:
if ndim == 0:
return (Singleton(), Singleton(), Singleton())
elif ndim == 1:
return (Singleton(), InputDim(0), Singleton())
elif ndim == 2:
return (InputDim(0), InputDim(1), Singleton())
else:
return tuple(InputDim(i) for i in range(ndim))
def expand(input_shape: Shape, shape: Shape) -> DimMap:
"""Implement broadcast on multiple dimensions."""
assert len(shape) >= len(input_shape)
# 1. create padded input dimensions
padded_input = dim_pad_left(len(input_shape), len(shape))
# 2. check that input shapes are compatible
mapping = []
for p, desired_s in zip(padded_input, shape):
if isinstance(p, Singleton):
actual_s = 1
assert desired_s >= 0
else:
assert isinstance(p, InputDim), f"DimSpec not supported in expand: {p}"
actual_s = input_shape[p.input_dim]
assert actual_s == 1 or desired_s == -1 or desired_s == actual_s
mapping.append(
p
if desired_s in (1, -1) or desired_s == actual_s
else Broadcast.new(p, desired_s)
)
return tuple(mapping)
def normalize_sizes(sizes: Union[Shape, Tuple[Shape]]) -> Shape:
if isinstance(sizes[0], int):
return cast(Shape, sizes)
elif len(sizes) == 1:
return sizes[0]
else:
raise RuntimeError("Size must be int... or tuple")
def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap:
if ndim == 0:
return (Singleton(),)
elif ndim == 1:
return (InputDim(0),)
else:
# only flattening dims from start_dim to end_dim (inclusive)
# other dims are passed through
if end_dim < 0:
end_dim += ndim
results: List[DimSpec] = [InputDim(i) for i in range(start_dim)]
results.append(
Flatten.new(tuple(InputDim(i) for i in range(start_dim, end_dim + 1)))
)
results.extend([InputDim(i) for i in range(end_dim + 1, ndim)])
return tuple(results)
def dim_movedim(
ndim: int,
input: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
) -> DimMap:
input = normalize_dims(input, ndim)
destination = normalize_dims(destination, ndim)
assert len(input) == len(destination)
input_set = set(input)
assert len(input_set) == len(input), "Found repeated input dims"
assert len(set(destination)) == len(destination), "Found repeated output dims"
assert max(input) < ndim
assert max(destination) < ndim
dest = [-1] * ndim
for i, d in zip(input, destination):
dest[d] = i
unused_inputs_iter = iter(i for i in range(ndim) if i not in input_set)
for i in range(ndim):
if dest[i] == -1:
dest[i] = next(unused_inputs_iter)
return tuple(InputDim(i) for i in dest)
def dim_repeat(ndim: int, sizes: Shape) -> DimMap:
sizes = normalize_sizes(sizes)
assert (
len(sizes) >= ndim
), f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}."
pad = len(sizes) - ndim
return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple(
Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:])
)
def infer_size(total_size: int, sizes: Shape) -> Shape:
"""
One dimension input to view may be "-1".
Infer the size of this dimension given the total_size.
"""
infers = [i for i, s in enumerate(sizes) if s == -1]
size = prod(sizes)
assert len(infers) <= 1, "can only infer one size"
if infers:
size = -size
missing_size = total_size // size
assert (
total_size % size == 0
), f"size inferred for -1 is not integral {sizes} should have {total_size} elements."
return tuple(s if s != -1 else missing_size for s in sizes)
assert size == total_size, f"sizes do not match {total_size} vs {size}"
return sizes
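# Minimal sketch (not part of the original file) of infer_size on a shape that
# contains a single -1 entry, and on a fully specified shape.
def _example_infer_size() -> None:
    assert infer_size(24, (2, -1, 3)) == (2, 4, 3)
    assert infer_size(6, (2, 3)) == (2, 3)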
def view_groups(from_size: Shape, to_size: Shape) -> DimMap:
"""
Decompose a reshape operation into forwarding, flattening, or splitting dimensions for each output dimension.
A view or reshape operation can be decomposed into a set of 3 types of smaller operations:
1) Forward a dimension from input to output
2) Flatten a set of dimensions into a single dimension
3) Split one dimension into multiple dimensions
    view_groups identifies these operations and returns, for each output dimension, which
    operation was performed on the input dimensions. For example:
view_groups([2, 3, 4], [2, 12]) -> (
InputDim(0),
Flatten((InputDim(1), InputDim(2)))
)
    - output dimension 0 maps to input dimension 0
    - output dimension 1 maps to the flattened input dimensions 1 and 2
view_groups([2, 3], [3, 2]) -> (
Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0),
Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1),
)
- in the above, input is flattened into a single dimension and then split
into two separate dimensions with different sizes from the input.
"""
from_nelem = prod(from_size)
to_size = infer_size(from_nelem, normalize_sizes(to_size))
assert from_nelem == prod(to_size), "Total view shape does not add up"
from_idx = 0
to_idx = 0
from_len = len(from_size)
to_len = len(to_size)
result_pp = []
while from_idx < from_len or to_idx < to_len:
from_group_dim, to_group_shape = [], []
if from_idx >= from_len:
f = 1
else:
f = from_size[from_idx]
from_group_dim.append(from_idx)
from_idx += 1
if to_idx >= to_len:
t = 1
else:
t = to_size[to_idx]
to_group_shape.append(t)
to_idx += 1
# if any of the groups is singleton, great, we need to backtrack though
if f == 1 and t != 1:
# produces ([1], [])
to_idx -= 1
to_group_shape = []
elif f != 1 and t == 1:
# produces ([], [1])
from_idx -= 1
from_group_dim = []
else:
# produces ([1], [1]), ([2], [2]), ([2,3], [6])
while f != t:
if f < t:
nf = from_size[from_idx]
from_group_dim.append(from_idx)
from_idx += 1
f *= nf
else:
nt = to_size[to_idx]
to_group_shape.append(nt)
to_idx += 1
t *= nt
if len(to_group_shape) > 0:
flattened = Flatten.new(
tuple(InputDim(fi) for fi in from_group_dim if from_size[fi] >= 1)
)
result_pp += [
Split.new(flattened, tuple(to_group_shape), i)
for i in range(len(to_group_shape))
]
return tuple(result_pp)
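# Minimal sketch (not part of the original file) checking the two decompositions
# described in the view_groups docstring above.
def _example_view_groups() -> None:
    # [2, 3, 4] -> [2, 12]: dim 0 is forwarded, dims 1 and 2 are flattened
    rules = view_groups((2, 3, 4), (2, 12))
    assert isinstance(rules[0], InputDim) and rules[0].input_dim == 0
    assert isinstance(rules[1], Flatten)
    # [2, 3] -> [3, 2]: flatten everything, then split into the new group shape
    rules = view_groups((2, 3), (3, 2))
    assert all(isinstance(r, Split) for r in rules)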
def dim_tile(ndim: int, dims: Tuple[int, ...]) -> DimMap:
if len(dims) < ndim:
dims = (1,) * (ndim - len(dims)) + dims
return dim_repeat(ndim, dims)
def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap:
dim1 = normalize_dim(dim1, ndim)
dim2 = normalize_dim(dim2, ndim)
assert dim1 < ndim
assert dim2 < ndim
dimmap = [InputDim(i) for i in range(ndim)]
swapdim = dimmap[dim1]
dimmap[dim1] = dimmap[dim2]
dimmap[dim2] = swapdim
return tuple(dimmap)
def dim_squeeze(shape: Shape, dim: Optional[int] = None) -> DimMap:
# FIXME: this is wrong when dim=None and one of the dimensions
# equals size of the mesh. For example squeeze(DTensor(tensor(4), Shard[0])) could
# end up as squeeze(tensor(1)) if we have 4 devices; this would lead to
# removal of a dimension that is not actually a singleton.
return tuple(
InputDim(i)
for i, s in enumerate(shape)
if s > 1 or (dim is not None and i != normalize_dim(dim, len(shape)))
)
def dim_unsqueeze(ndim: int, dim: int) -> DimMap:
dims = tuple(InputDim(i) for i in range(ndim))
if dim < 0:
dim += ndim + 1
return dims[:dim] + (Singleton(),) + dims[dim:]
def dim_view_as_real(shape: Shape) -> DimMap:
ndim = len(shape)
results: List[DimSpec] = [InputDim(i) for i in range(ndim - 1)]
# each complex number is split into two real numbers,
# resulting in one more dimension of size 2
results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 0))
results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 1))
return tuple(results)
def dim_reduction(
ndim: int, dim_or_dims: Optional[Union[int, Sequence[int]]], keepdim: bool
) -> DimMap:
"""
    General fallback for reduction ops where Partial() does not apply.
    This will cause the incoming tensor to be replicated on the reducing dimensions.
"""
if dim_or_dims is None:
dim_or_dims = tuple(range(ndim))
if isinstance(dim_or_dims, int):
dim_or_dims = (dim_or_dims,)
dim_or_dims = tuple(d if d >= 0 else d + ndim for d in dim_or_dims)
return tuple(
InputDim(i) if i not in dim_or_dims else Singleton()
for i in range(ndim)
if i not in dim_or_dims or keepdim
)
dim_maps: Dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = {
torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1),
torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2),
torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim),
torch.broadcast_to: lambda input, shape: expand(input.shape, shape),
Tensor.expand: lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)),
torch.flatten: lambda tensor: dim_flatten(tensor.ndim),
torch.movedim: lambda input, source, destination: dim_movedim(
input.ndim, source, destination
),
torch.permute: lambda input, dims: tuple(
InputDim(i) for i in normalize_dims(dims, input.ndim)
),
torch.ravel: lambda tensor: dim_flatten(tensor.ndim),
Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes),
torch.reshape: lambda input, shape: view_groups(input.shape, shape),
torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim),
torch.tile: lambda input, dims: dim_tile(input.ndim, dims),
torch.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1),
torch.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim),
Tensor.view: lambda input, *shape: view_groups(input.shape, shape),
torch.view_as_complex: lambda input: dim_flatten(input.ndim, input.ndim - 2),
torch.view_as_real: lambda input: dim_view_as_real(input.shape),
}
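# Minimal sketch (not part of the original file): each dim_maps entry turns an
# op's arguments into a DimMap over the input dimensions.
def _example_dim_maps() -> None:
    rule = dim_maps[torch.transpose](torch.empty(2, 3), 0, 1)
    assert rule == (InputDim(1), InputDim(0))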
def propagate_shape_and_sharding(
input_src_placements: Sequence[Placement],
local_in_shape: Shape,
rule: DimMap,
mesh_sizes: Shape,
) -> Tuple[Sequence[Placement], Sequence[Placement]]:
"""
Determine input target sharding and output sharding based on
given global tensor shape and input source sharding.
Sharding propagation follows mapped dimensions:
- An output dimension that maps directly to an input dimension is sharded equally
- An output dimension that is a flattened set of input dimensions can only be
sharded if only the leftmost flattened dimension is sharded.
- An output dimension that is a split of the input dimension can only be sharded
if the leftmost split size is divisible by the mesh dimension
"""
assert len(input_src_placements) == len(mesh_sizes)
    # for each input dim, a list of booleans (one per mesh dim) indicating whether
    # that input dim can be sharded on that mesh dim
mesh_ndim = len(mesh_sizes)
shardable_dims: Dict[int, List[bool]] = {}
# in case an input dimension disappears (e.g. collapsing, reduction)
# we cannot shard in that dimension (we need a replication fall-back rule)
seen_input_dims: Set[int] = set()
def collect_used_inputs(cmd: DimSpec) -> None:
if isinstance(cmd, InputDim):
seen_input_dims.add(cmd.input_dim)
for inp in cmd.inputs():
collect_used_inputs(inp)
for cmd in rule:
collect_used_inputs(cmd)
for dim in range(len(local_in_shape)):
shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim
def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]:
if isinstance(cmd, InputDim):
return cmd
elif isinstance(cmd, Flatten):
for dim in cmd.input_dims[1:]:
if isinstance(dim, InputDim):
shardable_dims[dim.input_dim] = [False] * mesh_ndim
dim0 = cmd.input_dims[0]
return dim0 if isinstance(dim0, InputDim) else None
elif isinstance(cmd, Split):
in_dim = get_in_dim_to_shard(cmd.input_dim)
out_size = cmd.group_shape[cmd.split_id]
if cmd.split_id == 0 and in_dim is not None:
# we need to check that the input dimension is divisible
# by the size of the submesh we're sharding it on
# NOTE: it would be possible to shard the same input dimension
# on more than one mesh dimension. In that case, the dimension
# needs to be divisible by the product of mesh sizes.
# In order to keep the problem more tractable, we will not consider
# double resharding as a suggestion (e.g. [Shard(0), Shard(0) ])
# but we will allow it if that's the input and it's compatible
# 1. is this dimension shardable on each individual mesh dim?
shardable_dims[in_dim.input_dim] = [
out_size % mesh_dim_size == 0 for mesh_dim_size in mesh_sizes
]
# 2. here we special case things like [Shard(0), Shard(0)]
submesh_size = 1
for size, shard in zip(mesh_sizes, input_src_placements):
if isinstance(shard, Shard) and shard.dim == in_dim:
submesh_size *= size
assert (
out_size % submesh_size == 0
), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}."
# we will only shard our first component of the split
return in_dim if cmd.split_id == 0 else None
elif isinstance(cmd, Repeat):
in_dim = get_in_dim_to_shard(cmd.input_dim)
if in_dim is not None:
shardable_dims[in_dim.input_dim] = [False] * mesh_ndim
return None
else:
return None
# for each output dim, find the corresponding input dim in terms of sharding prop
shard_dim_map = {}
for dim, cmd in enumerate(rule):
in_dim = get_in_dim_to_shard(cmd)
if in_dim is not None:
shard_dim_map[in_dim.input_dim] = dim
input_tgt_placements = [
Replicate()
if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim]
else p
for mesh_dim, p in enumerate(input_src_placements)
]
output_placements = [
Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p
for p in input_tgt_placements
]
return input_tgt_placements, output_placements
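# Minimal sketch (not part of the original file): flattening a tensor that is
# sharded on the leftmost flattened dim keeps the shard, while a shard on any
# other flattened dim falls back to replication, per the rules above. The 1-D
# mesh of size 2 is an illustrative assumption.
def _example_propagate_flatten() -> None:
    rule = view_groups((4, 6), (24,))
    tgt, out = propagate_shape_and_sharding([Shard(0)], (4, 6), rule, (2,))
    assert tgt == [Shard(0)] and out == [Shard(0)]
    tgt, out = propagate_shape_and_sharding([Shard(1)], (4, 6), rule, (2,))
    assert tgt == [Replicate()] and out == [Replicate()]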
def register_op_strategy_map(
aten_op_overload: torch._ops.OpOverload,
local_op_name: Callable[..., torch.Tensor],
schema_info: Optional[RuntimeSchemaInfo] = None,
) -> None:
dim_map: Callable[..., DimMap] = dim_maps[local_op_name]
@register_op_strategy(aten_op_overload, schema_info=schema_info)
def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
rules = dim_map(*op_schema.args_schema, **op_schema.kwargs_schema)
input_strategy = cast(OpStrategy, op_schema.args_schema[0])
global_in_shape = input_strategy.shape
assert global_in_shape is not None, "Shape required."
output_strategy = OpStrategy([])
for input_placement_strategy in input_strategy.strategies:
input_src_spec = input_placement_strategy.output_spec
input_tgt_placements, output_placements = propagate_shape_and_sharding(
input_src_spec.placements,
tuple(global_in_shape),
rules,
mesh.shape,
)
# TODO: optimize this. we shouldn't simply blindly replicate
# unshardable dims ...
# FIXME: this can be wrong for situations where we have
# [Shard(0), Shard(0)]
input_tgt_spec = DTensorSpec(
placements=tuple(input_tgt_placements),
mesh=input_src_spec.mesh,
tensor_meta=input_src_spec.tensor_meta,
)
redistribute_costs = [
generate_redistribute_costs(input_strategy, input_tgt_spec)
]
output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements))
output_strategy.strategies.append(
PlacementStrategy(
output_specs=output_spec,
input_specs=(input_tgt_spec,),
redistribute_cost=redistribute_costs,
)
)
return output_strategy
register_op_strategy_map(aten.squeeze.default, torch.squeeze)
register_op_strategy_map(
aten.squeeze.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
aten.reshape.default, torch.reshape, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
aten._unsafe_view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
aten.unsqueeze.default, torch.unsqueeze, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
aten.expand.default, Tensor.expand, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
aten.permute.default, torch.permute, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
aten.repeat.default, Tensor.repeat, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
aten.transpose.int, torch.transpose, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(aten.view_as_complex.default, torch.view_as_complex)
register_op_strategy_map(aten.view_as_real.default, torch.view_as_real)

View File

@ -0,0 +1,280 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import functools
import itertools
import operator
from typing import cast, Iterable, List, Optional, Sequence, Tuple, Union
import torch
from torch.distributed.tensor._api import DTensor
from torch.distributed.tensor._collective_utils import redistribute_cost
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor._op_schema import (
OpSchema,
OpStrategy,
PlacementList,
PlacementStrategy,
RuntimeSchemaInfo,
)
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
Partial,
Placement,
Replicate,
Shard,
)
# convenient wrapper to register sharding propagation rules
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def register_prop_rule(op, schema_info=None):
# pyre-fixme[53]: Captured variable `func` is not annotated.
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def wrapper(impl):
overloads = op if isinstance(op, list) else [op]
for overload in overloads:
DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule(
overload, impl, schema_info
)
return impl
return wrapper
def register_op_strategy(op, schema_info=None):
# pyre-fixme[53]: Captured variable `func` is not annotated.
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
# For every ATen op that accepts any args in this list,
# the arg itself can impact the strides (and potentially the sharding strategy)
# of the output tensor.
# thus, we will detect ATen schemas with any of these args and ensure
# that they get specialized here.
arg_names_that_require_specializing_cache_strategy = [
"memory_format",
]
def wrapper(impl):
if isinstance(op, list):
overloads = op
else:
overloads = [op]
for overload in overloads:
curr_schema_info = None
if schema_info is None:
specialized_args = [
a.name
for a in overload._schema.arguments
if a.name in arg_names_that_require_specializing_cache_strategy
]
if any(specialized_args):
curr_schema_info = RuntimeSchemaInfo(
static_kwargkey=specialized_args
)
else:
curr_schema_info = schema_info
DTensor._op_dispatcher.sharding_propagator.register_op_strategy(
overload, impl, curr_schema_info
)
return impl
return wrapper
def as_list(
x: Union[List[object], object]
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
) -> Union[List[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type]
# During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
# which is an object but treated as a list by the tracer. Therefore, keep
# `immutable_list` intact here as well.
if type(x) is list or isinstance(x, torch.fx.immutable_collections.immutable_list):
return x
else:
return [x]
def normalize_dim(dim: int, ndim: int) -> int:
return dim if dim >= 0 else dim + ndim
def normalize_dims(dims: Union[int, Sequence[int]], ndim: int) -> Sequence[int]:
"""Normalize a dim or a sequence of dims, so that they are all positive."""
if isinstance(dims, int):
dims = (normalize_dim(dims, ndim),)
elif isinstance(dims, list):
dims = [normalize_dim(dim, ndim) for dim in dims]
elif isinstance(dims, tuple):
dims = tuple([normalize_dim(dim, ndim) for dim in dims])
return dims
def prod(xs: Iterable[int]) -> int:
return functools.reduce(operator.mul, xs, 1)
def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
"""Check if the shape is shardable according to the spec."""
# number of shards in each tensor dimension
shards_map = [1] * len(shape)
for i, placement in enumerate(spec.placements):
if placement.is_shard():
shard_dim = cast(Shard, placement).dim
shards_map[shard_dim] *= spec.mesh.size(i)
for i, dim_size in enumerate(shape):
# TODO: maybe we should determine is_shardable based on
# whether it's evenly sharded or not
if shards_map[i] > 1 and dim_size < shards_map[i]:
return False
return True
def is_tensor_evenly_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
"""Check if the shape is evenly shardable according to the spec."""
# number of shards in each tensor dimension
shards_map = [1] * len(shape)
for i, placement in enumerate(spec.placements):
if placement.is_shard():
shard_dim = cast(Shard, placement).dim
shards_map[shard_dim] *= spec.mesh.size(i)
for i, dim_size in enumerate(shape):
if shards_map[i] > 1 and (dim_size % shards_map[i] != 0):
return False
return True
def is_tensor_dim_sharded(spec: DTensorSpec, dim: int) -> bool:
"""Return True if tensor dim is sharded."""
return any(p.is_shard(dim) for p in spec.placements)
def is_tensor_partial(spec: DTensorSpec) -> bool:
"""Return True if tensor is partial on the mesh."""
return any(p.is_partial() for p in spec.placements)
def infer_broadcast_dims_map(
common_shape: torch.Size, input_shape: torch.Size
) -> List[int]:
# infer the broadcast dims map, where it maps from the common shape dim to the input shape dim
# this is aligned with the broadcast semantics
common_ndim = len(common_shape)
input_ndim = len(input_shape)
broadcast_dims_map = [-1] * common_ndim
for idx in range(-1, -1 - input_ndim, -1):
if input_shape[idx] == common_shape[idx]:
broadcast_dims_map[common_ndim + idx] = input_ndim + idx
return broadcast_dims_map
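# Minimal sketch (not part of the original file): mapping common-shape dims back
# to input-shape dims; -1 marks dims that only exist after broadcasting.
def _example_broadcast_dims_map() -> None:
    common = torch.Size([2, 3, 4])
    assert infer_broadcast_dims_map(common, torch.Size([1, 4])) == [-1, -1, 1]
    assert infer_broadcast_dims_map(common, torch.Size([3, 4])) == [-1, 0, 1]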
def map_placements_after_broadcast(
placements: Tuple[Placement, ...],
shape: torch.Size,
broadcast_dims_map: List[int],
) -> Tuple[Placement, ...]:
"""Map each placement based on the output shape after broadcast."""
new_placements: List[Placement] = []
for placement in placements:
if isinstance(placement, (Replicate, Partial)):
new_placements.append(placement)
else:
assert isinstance(placement, Shard)
shard_dim = normalize_dim(placement.dim, len(shape))
new_shard_dim = broadcast_dims_map[shard_dim]
if new_shard_dim != -1:
# there's a map from the common shape shard dim to
# the input shape shard dim before broadcasting,
# use that instead
new_placements.append(Shard(new_shard_dim))
else:
# there's no map between common shape shard dim and
# the input shape shard dim before broadcasting,
                # in this case it means implicit broadcasting happened
                # in this dim, so we can just mark it as replicate
                # and implicit broadcast will broadcast automatically
# to the sharded shape
new_placements.append(Replicate())
return tuple(new_placements)
def generate_redistribute_costs(
src_strategy: OpStrategy, dst_spec: DTensorSpec
) -> List[float]:
redistribute_costs: List[float] = []
for strat in src_strategy.strategies:
redistribute_costs.append(redistribute_cost(strat.output_spec, dst_spec))
return redistribute_costs
def expand_to_full_mesh_op_strategy(
mesh: DeviceMesh,
op_schema: OpSchema,
single_mesh_dim_strategies: List[PlacementList],
*,
input_index: int = 1,
inplace_op: bool = False,
) -> OpStrategy:
# Expand the single_mesh_dim_strategies to full mesh dim strategies.
all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
all_strategies = []
for strategy_comb in strategy_combs:
spec_list: List[Optional[DTensorSpec]] = []
for specs in zip(*strategy_comb):
if specs[0] is not None:
spec_list.append(DTensorSpec(mesh, specs))
else:
spec_list.append(None)
input_specs: List[DTensorSpec] = [
s for s in spec_list[input_index:] if isinstance(s, DTensorSpec)
]
input_args_strategy = op_schema.args_strategy
assert len(input_specs) == len(input_args_strategy)
self_spec = input_args_strategy[0].strategies[0].output_spec
if inplace_op and self_spec.placements != input_specs[0].placements:
            # if it's an inplace op, we only allow the placement strategy to be added when the
            # input_spec matches the first argument's runtime sharding; otherwise we skip
continue
# check inputs shardable
inputs_shardable = all(
is_tensor_shardable(inp.shape, s)
for inp, s in zip(input_args_strategy, input_specs)
)
# only add to the all_strategies list when all inputs are shardable
if inputs_shardable:
redistribute_cost = [
generate_redistribute_costs(input_strategy, input_spec)
for input_strategy, input_spec in zip(input_args_strategy, input_specs)
]
if input_index > 1:
output_specs = tuple(spec_list[:input_index])
else:
if spec_list[0] is not None:
output_specs = spec_list[0] # type: ignore[assignment]
else:
raise RuntimeError("output spec is None")
strategy = PlacementStrategy(
output_specs=output_specs,
input_specs=input_specs,
redistribute_cost=redistribute_cost,
)
all_strategies.append(strategy)
return OpStrategy(all_strategies)
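# Minimal sketch (not part of the original file) of the expansion step above:
# per-mesh-dim placement lists are combined via a cartesian product, one choice
# per mesh dimension, before shardability filtering.
def _example_single_dim_expansion() -> None:
    per_mesh_dim: List[PlacementList] = [
        [Replicate(), Replicate(), Replicate()],  # [output, input, index]
        [Shard(0), Shard(0), Shard(0)],
    ]
    # on a 2-D mesh this yields 2 ** 2 = 4 candidate full-mesh strategies
    combos = list(itertools.product(per_mesh_dim, repeat=2))
    assert len(combos) == 4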

View File

@ -0,0 +1,381 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import warnings
from typing import Dict, List, Optional
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed.device_mesh import _get_device_handle, DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.placement_types import Shard
__all__ = [
"is_rng_supported_mesh",
"manual_seed",
"OffsetBasedRNGTracker",
"TensorParallelRNGTracker",
]
_rng_tracker: Optional["_RNGStateTracker"] = None
def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool:
"""Checks if the current device of ``device_mesh`` supports DTensor's random APIs.
    Currently DTensor Random APIs only support cuda/cuda-like devices. We suggest
users call this API to test the availability before using our random APIs.
Args:
device_mesh (:class:`DeviceMesh`): The device mesh on which we check if the
random ops APIs are supported.
Returns:
A bool value. True if ``device_mesh`` supports DTensor Random APIs; False otherwise.
.. warning::
Currently we only support correct RNG on cuda/cuda-like devices.
"""
device_handle = _get_device_handle(device_mesh.device_type)
if device_handle and hasattr(device_handle, "set_rng_state"):
return True
else:
# TODO: Logs way too much
warnings.warn(
f"DTensor random operators may not have complete support on {device_mesh.device_type} device mesh"
)
return False
def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
"""Sets the seed for generating random numbers for the calling rank.
Args:
seed (int): The desired seed.
device_mesh (:class:`DeviceMesh`): The device mesh to set the seed.
Returns:
None
.. warning::
When calling this function, :func:`manual_seed` must be called from all ranks of the
default ``ProcessGroup`` even if some ranks may not be a part of the ``device_mesh``,
with the same ``seed`` value.
If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it,
``manual_seed`` will not set its GPU device's generator seed.
Current implementation only supports a GPU device mesh.
"""
device_handle = _get_device_handle(device_mesh.device_type)
if not device_handle:
raise NotImplementedError(
f"DTensor randomness only supports cuda/cuda-like device type, but got {device_mesh.device_type}"
)
# allgather the seed over the default PG
object_list = [seed] * dist.get_world_size()
dist.all_gather_object(object_list, seed)
for rank, object in enumerate(object_list):
if seed != int(object):
raise RuntimeError(
f"calling manual_seed function over {device_mesh} but received different seed values on ranks:",
f"seed on rank {dist.get_rank()} is {seed}, and seed on rank {rank} is {object}!",
)
    # instantiate an RNG tracker if we haven't already. By default DTensor uses an
    # OffsetBasedRNGTracker to perform random operators.
global _rng_tracker
if not _rng_tracker:
_rng_tracker = OffsetBasedRNGTracker(device_mesh.device_type)
# the current rank is in mesh
if device_mesh.get_coordinate() is not None:
if isinstance(_rng_tracker, TensorParallelRNGTracker):
_rng_tracker._manual_seed(device_mesh, seed)
elif isinstance(_rng_tracker, OffsetBasedRNGTracker):
_rng_tracker._manual_seed(seed)
else:
raise RuntimeError(
f"Unknown type of cuda RNG state tracker: _rng_tracker = {_rng_tracker}"
)
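# Hedged usage sketch (not part of the original file): manual_seed must be called
# collectively with the same seed on every rank. The 2-GPU mesh is an
# illustrative assumption.
def _example_manual_seed() -> None:
    from torch.distributed.device_mesh import init_device_mesh

    mesh = init_device_mesh("cuda", (2,))
    manual_seed(1234, mesh)  # same seed on every rank, or a RuntimeError is raised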
class _RNGStateTracker:
"""
_RNGStateTracker stores Random Number Generator (RNG) state (a ByteTensor object)
in a dict, mapping from a corresponding tag to each state tensor. It also provides
a set of convenient utility methods to help access/modify the state tensors. The most
important interface is _distribute_region which will be used when DTensor executes
a random op (an operator that calls RNG).
"""
def __init__(self, device_type: str = "cuda"):
self._device_type = device_type
self._device_handle = _get_device_handle(device_type)
if not (self._device_handle and self._device_handle.is_available()):
raise RuntimeError(
f"{self.__class__.__name__} instantiation requires the presence of CUDA/CUDA-like device"
)
self._states: Dict[str, Tensor] = {}
self._devices = [self._device_handle.current_device()]
self._use_distribute_region = True
@property
def rng_states(self) -> Dict[str, Tensor]:
return self._states
@property
def distribute_region_enabled(self) -> bool:
return self._use_distribute_region
@distribute_region_enabled.setter
def distribute_region_enabled(self, value) -> None:
self._use_distribute_region = value
def rng_state_is_sync(self, name) -> bool:
return name in self.rng_states
def get_seed(self, name: str) -> int:
if name not in self.rng_states:
raise RuntimeError(
f"{self.__class__.__name__} does not have random state for {name}"
)
seed_tensor = (self.rng_states[name])[0:8].view(dtype=torch.int64)
return int(seed_tensor.item())
def set_seed(self, name: str, seed: int) -> None:
seed_tensor = torch.tensor([seed]).view(torch.uint8)
offset_tensor = torch.tensor([0]).view(torch.uint8)
self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])
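# For example (illustrative): ``set_seed("parallel-rng", 1234)`` stores a 16-byte
# state tensor where bytes [0:8] hold the int64 seed (1234) and bytes [8:16] hold
# the int64 offset, initialized to 0; ``get_seed`` (and the subclasses' offset
# helpers) read these fields back via ``view(dtype=torch.int64)``.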
def _distribute_region(self, spec: DTensorSpec):
pass
class OffsetBasedRNGTracker(_RNGStateTracker):
"""
This subclass of ``_RNGStateTracker`` defines the default policy of how RNG states
should be shared and synchronized among all ranks to respect the semantics of DTensor
random operators.
"""
def __init__(self, device_type: str = "cuda"):
super().__init__(device_type)
# synchronize RNG state using rank 0's current one
rng_state = self._device_handle.get_rng_state().to(device_type)
dist.broadcast(rng_state, 0)
self.rng_states["parallel-rng"] = rng_state.to("cpu")
def _manual_seed(self, parallel_seed: int) -> None:
self.set_seed("parallel-rng", parallel_seed)
@contextlib.contextmanager
def _distribute_region(self, spec: DTensorSpec):
# check if the parallel rng state has been synchronized or not
if not self.rng_state_is_sync("parallel-rng"):
raise RuntimeError(
"OffsetBasedRNGTracker requires the random state to be synchronized "
"before entering into a distribute region!"
)
if self.distribute_region_enabled:
old_offset = self.get_offset("parallel-rng")
self._set_pre_op_offset(spec)
with torch.random.fork_rng(self._devices, device_type=self._device_type):
self._device_handle.set_rng_state(self.rng_states["parallel-rng"])
try:
yield # execute the region code
finally:
# update offset to synchronize among ranks
self._set_post_op_offset(spec, old_offset)
else:
yield
def get_offset(self, name: str) -> int:
if name not in self.rng_states:
raise RuntimeError(
f"{self.__class__.__name__} does not have random state for {name}"
)
offset_tensor = (self.rng_states[name])[8:].view(dtype=torch.int64)
return int(offset_tensor.item())
def set_offset(self, name: str, offset: int) -> None:
if name not in self.rng_states:
raise RuntimeError(
f"{self.__class__.__name__} does not have random state for {name}"
)
seed_tensor = (self.rng_states[name])[0:8]
offset_tensor = torch.tensor([offset]).view(torch.uint8)
self.rng_states[name] = torch.cat([seed_tensor, offset_tensor])
def _set_pre_op_offset(self, spec: DTensorSpec) -> None:
"""Set the starting RNG offset for current device's local shard before actual
op execution. The pre_op_offset value should start from the current RNG offset
and increment by the size of local shard until it reaches the size of the whole
DTensor. For different ranks that hold the same DTensor shard, their pre_op_offset
will be the same.
Args:
spec (:class:`DTensorSpec`): the spec of the DTensor object on which
we prepare the offset for running random ops.
Returns:
None
.. warning::
Note that the current implementation does not consider the DTensor's contiguity.
Example:
take a DTensor of shape [8, 16] as an example. Assume that the DTensor
is placed on a device mesh with placements ([Shard(1), Replicate(), Shard(0)]),
and the mesh is:
[[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
``spec.mesh.get_coordinate()`` provides the coordinate of the current rank
in the mesh. For example, the coordinate of rank 5 is (1, 0, 1).
Another concept to introduce besides rank coordinate is shard coordinate.
Each rank holds a local shard of the DTensor. In the example, the DTensor
is partitioned into 4 [4, 8] shards, each with 2 replicas: for the first shard,
rank 0 (coord (0, 0, 0)) and rank 2 (coord (0, 1, 0)) each hold one replica.
That is, the local shards on rank 0 and rank 2 correspond to the same
shard of the DTensor. To denote each DTensor shard, we use a shard coordinate
(in the example, it will be a tuple (i, j) where shard (i, j) has the slice
DTensor[4 * i : 4 * (i + 1), 8 * j : 8 * (j + 1)], 0 <= i < 2, 0 <= j < 2).
Once we have rank coordinate and shard coordinate, we can calculate on each rank
what shard of the DTensor the rank holds, with the help of dim_map. The dim_map
of the above DTensor is [2, 0] so the shard coordinate of a rank with rank coord
(x, y, z) is simply (z, x), i.e. (rank_coord[dim_map[0]], rank_coord[dim_map[1]]).
Following this calculation,
rank 0 and rank 2 hold the shard of coord (0, 0);
rank 1 and rank 3 hold the shard of coord (1, 0);
rank 4 and rank 6 hold the shard of coord (0, 1);
rank 5 and rank 7 hold the shard of coord (1, 1);
The last value to calculate before obtaining the starting offset is the shard linear index.
The starting offset for each rank will be its shard_linear_index * local_tensor_numel.
"""
dtensor_shape = spec.shape
mesh = spec.mesh
dim_map = spec.dim_map
# Compute the shard coordinate:
# for each tensor dim we record a pair (idx, range): if a DTensor is partitioned
# on its dim i into n shards and the current rank holds the j-th one, then
# shard_coord[i] == j and shard_size[i] == n
coordinate = mesh.get_coordinate()
assert coordinate is not None
shard_coord = [
coordinate[mesh_dim] if mesh_dim >= 0 else 0 for mesh_dim in dim_map
]
shard_size = [
mesh.size(mesh_dim) if mesh_dim >= 0 else 1 for mesh_dim in dim_map
]
# compute shard linear index
shard_linear_idx = self._calc_shard_linear_idx(shard_coord, shard_size)
# compute starting offset using the first shard's size
local_size_on_rank_0 = list(dtensor_shape)
for idx, placement in enumerate(spec.placements):
if isinstance(placement, Shard):
mesh_dim_size = mesh.size(idx)
shard_dim = placement.dim
local_size_on_rank_0[shard_dim] = placement._local_shard_size_on_dim(
dtensor_shape[shard_dim],
mesh_dim_size,
0,
return_offset=False,
)[0]
from torch.distributed.tensor._ops.utils import prod
local_size = prod(local_size_on_rank_0)
# get current RNG offset
current_offset = self.get_offset("parallel-rng")
# pytorch: offset must be multiple of 4
# source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4
self.set_offset("parallel-rng", current_offset + offset_incr)
def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None:
"""Sets the RNG to a synchronized state after running the local random op. Every
rank should set its RNG offset to `old_offset + DTensor.numel()` where old_offset is
the offset before calling `set_pre_op_offset` i.e. the offset before running DTensor
random ops.
Args:
spec (:class:`DTensorSpec`): the spec of the DTensor object on which
we post-process the offset for running random ops.
Returns:
None
"""
dtensor_shape = spec.shape
from torch.distributed.tensor._ops.utils import prod
numel = prod(dtensor_shape)
# pytorch: offset must be multiple of 4
# source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
numel = (numel + 3) // 4 * 4
self.set_offset("parallel-rng", old_offset + numel)
def _calc_shard_linear_idx(
self, shard_coord: List[int], shard_size: List[int]
) -> int:
# compute shard linear index
shard_linear_idx = 0
shard_coord_stride = 1
for idx, size in zip(reversed(shard_coord), reversed(shard_size)):
shard_linear_idx += idx * shard_coord_stride
shard_coord_stride *= size
return shard_linear_idx
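# A worked example of the offset arithmetic above (illustrative only; it mirrors the
# ``_set_pre_op_offset`` docstring): a DTensor of shape [8, 16] on a (2, 2, 2) mesh
# with placements [Shard(1), Replicate(), Shard(0)] has dim_map == [2, 0] and each
# local shard is [4, 8] (32 elements). For rank 5 with coordinate (1, 0, 1):
#   shard_coord = [coord[2], coord[0]] = [1, 1], shard_size = [2, 2]
#   shard_linear_idx = 1 * 1 + 1 * 2 = 3
#   offset increment = (3 * 32 + 3) // 4 * 4 = 96
# so rank 5 advances its RNG offset by 96, while ranks holding shard (0, 0) advance by 0.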
class TensorParallelRNGTracker(_RNGStateTracker):
def __init__(self, device_type: str = "cuda"):
super().__init__(device_type)
# copy the default RNG state
self.rng_states["tensor-parallel-rng"] = self._device_handle.get_rng_state()
def _manual_seed(
self,
tp_mesh: DeviceMesh,
base_seed: int = 1234,
):
tensor_parallel_rank = tp_mesh.get_local_rank()
# this magic number 2718 comes from Megatron's code
# (https://github.com/NVIDIA/Megatron-LM/blob/060415572f4365a2e895f8036c4e37dad0efbdf5/megatron/core/tensor_parallel/random.py#L162-L163)
MegatronMagicNum = 2718
tensor_parallel_seed = base_seed + MegatronMagicNum + tensor_parallel_rank
self.set_seed("tensor-parallel-rng", tensor_parallel_seed)
@contextlib.contextmanager
def _distribute_region(self, spec: DTensorSpec):
# check if the tensor parallel rng state has been synchronized or not
if not self.rng_state_is_sync("tensor-parallel-rng"):
raise RuntimeError(
"TensorParallelRNGTracker requires the random state to be synchronized "
"before entering into a distribute region!"
)
if self.distribute_region_enabled:
with torch.random.fork_rng(self._devices, device_type=self._device_type):
self._device_handle.set_rng_state(
self.rng_states["tensor-parallel-rng"]
)
try:
yield
finally:
self.rng_states[
"tensor-parallel-rng"
] = self._device_handle.get_rng_state()
else:
yield

View File

@ -0,0 +1,351 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from functools import lru_cache
from typing import cast, List, NamedTuple, Tuple
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._api as dtensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
Partial,
Placement,
Replicate,
Shard,
)
logger = logging.getLogger(__name__)
class _TransformInfo(NamedTuple):
mesh_dim: int
src_dst_placements: Tuple[Placement, Placement]
# logical_shape on this mesh dimension
logical_shape: List[int]
@lru_cache(maxsize=None)
def _gen_transform_infos(
src_spec: DTensorSpec,
dst_spec: DTensorSpec,
) -> List[_TransformInfo]:
"""
Generate the transform infos from the source placements to the target placements.
Transforming from the source placements to the target placements may take multiple
steps, i.e. it might decompose Si -> Sj into Si -> R -> Sj.
This also detects misaligned/nested shardings between src/dst placements.
E.g. suppose the redistribution to perform is (Shard(0), Shard(0)) -> (Replicate(), Shard(0));
in this case Shard(0) -> Shard(0) for mesh dimension 1 actually needs resharding, because
the former is a nested sharding of tensor dimension 0, which is already sharded, whereas
the latter is the first sharding on tensor dimension 0.
"""
transform_infos: List[_TransformInfo] = []
device_mesh = src_spec.device_mesh
my_coordinate = device_mesh.get_coordinate()
assert my_coordinate is not None
# logical shape records the logical tensor shape on the mesh dimension
# this is useful to ensure uneven sharding gets correct output shape
initial_logical_shape = list(src_spec.shape)
mesh_dims_to_logical_shape = [initial_logical_shape]
if device_mesh.ndim == 1:
# if device_mesh is 1D, redistribute is a simple direct transformation
transform_infos.append(
_TransformInfo(
mesh_dim=0,
src_dst_placements=(src_spec.placements[0], dst_spec.placements[0]),
logical_shape=initial_logical_shape,
)
)
return transform_infos
# Handle multi-dim device mesh placement redistribution
# First, we need to build the logical shape for each mesh dim
# for correctly all-gathering uneven shards on each mesh dim (with dynamic padding)
for i, (src, dst) in enumerate(zip(src_spec.placements, dst_spec.placements)):
current_logical_shape = mesh_dims_to_logical_shape[i]
if isinstance(src, Shard):
if i < device_mesh.ndim - 1:
# calculate and save the logical shape for this sharding
mesh_dim_size = device_mesh.size(mesh_dim=i)
local_shard_size, _ = src._local_shard_size_on_dim(
current_logical_shape[src.dim],
mesh_dim_size,
my_coordinate[i],
)
new_logical_shape = list(current_logical_shape)
new_logical_shape[src.dim] = local_shard_size
mesh_dims_to_logical_shape.append(new_logical_shape)
else:
mesh_dims_to_logical_shape.append(current_logical_shape)
# Next, we need to derive the transform infos from src to dst placements,
# here we use a greedy search with step by step state transformations
current_placements = list(src_spec.placements)
target_placements = list(dst_spec.placements)
if src_spec.num_shards > 1:
# If src_spec has sharding, it could potentially be misaligned with dst_spec;
# a common case of this is nested sharding (i.e. (S(0), S(0)) -> (R, S(0))).
# In those cases, we first traverse from inner placement to outer placement
# to detect misaligned shardings and properly replicate the nested sharding first.
for mesh_dim in reversed(range(len(current_placements))):
current = current_placements[mesh_dim]
target = target_placements[mesh_dim]
# If target is not Shard, we can directly redistribute since we are traversing from inner
# to outer placements here
if isinstance(target, Shard):
# If target is Shard, check for nested sharding on the tensor dim BEFORE the current mesh_dim
shard_dim = target.dim
current_mesh_sharding, target_mesh_sharding = [], []
for i, (s, p) in enumerate(zip(current_placements, target_placements)):
if i >= mesh_dim:
break
if s.is_shard(shard_dim):
current_mesh_sharding.append(i)
if p.is_shard(shard_dim):
target_mesh_sharding.append(i)
if current_mesh_sharding != target_mesh_sharding:
# if current/target_placements have misaligned sharding on the tensor dim BEFORE the current
# mesh_dim, we need to replicate the tensor on the mesh dim first to clear the nested sharding
target = Replicate()
if current != target:
transform_infos.append(
_TransformInfo(
mesh_dim=mesh_dim,
src_dst_placements=(current, target),
logical_shape=mesh_dims_to_logical_shape[mesh_dim],
)
)
current_placements[mesh_dim] = target
# We always traverse from outer placement to inner placement to collect the remaining
# needed transform infos (i.e. the replication from nested sharding might need to further
# perform resharding to Shard again)
for mesh_dim, (current, target) in enumerate(
zip(current_placements, target_placements)
):
if current != target:
transform_infos.append(
_TransformInfo(
mesh_dim=mesh_dim,
src_dst_placements=(current, target),
logical_shape=mesh_dims_to_logical_shape[mesh_dim],
)
)
current_placements[mesh_dim] = target
return transform_infos
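# An illustrative trace of the decomposition above (not executed here): redistributing
# (Shard(0), Shard(0)) -> (Replicate(), Shard(0)) on a 2-D mesh produces three steps:
#   mesh dim 1: Shard(0) -> Replicate()   # clear the nested (inner) sharding first
#   mesh dim 0: Shard(0) -> Replicate()
#   mesh dim 1: Replicate() -> Shard(0)   # re-establish the target sharding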
def redistribute_local_tensor(
local_tensor: torch.Tensor,
current_spec: DTensorSpec,
target_spec: DTensorSpec,
*,
async_op: bool = False,
is_backward: bool = False,
) -> torch.Tensor:
"""
This redistributes the local tensor (torch.Tensor) from the current DTensorSpec to
the target DTensorSpec, which involves the necessary collective calls to transform
the local shard of the DTensor from its current spec to the target spec.
"""
if current_spec.mesh != target_spec.mesh:
# TODO: alltoall/permute reshuffling to change device_mesh if they are not the same
raise NotImplementedError("Cross device mesh comm not supported yet!")
new_local_tensor = None
device_mesh = current_spec.mesh
my_coordinate = device_mesh.get_coordinate()
if my_coordinate is None:
# if rank is not part of mesh, we skip redistribute and simply return local_tensor,
# which should be an empty tensor
return local_tensor
transform_infos = _gen_transform_infos(current_spec, target_spec)
for transform_info in transform_infos:
i = transform_info.mesh_dim
current, target = transform_info.src_dst_placements
num_chunks = device_mesh.size(mesh_dim=i)
if current == target:
# short cut, just use the original local tensor
new_local_tensor = local_tensor
continue
logger.debug("redistribute from %s to %s on mesh dim %s", current, target, i)
if target.is_replicate():
# Case 1: target is Replicate
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_value(
local_tensor, device_mesh, i
)
elif current.is_shard():
current_placement = cast(Shard, current)
new_local_tensor = current_placement._to_replicate_tensor(
local_tensor, device_mesh, i, transform_info.logical_shape
)
else:
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
elif target.is_shard():
# Case 2: target is Shard
target_placement = cast(Shard, target)
target_dim = target_placement.dim
if current.is_partial():
partial_spec = cast(Partial, current)
new_local_tensor = partial_spec._reduce_shard_value(
local_tensor, device_mesh, i, target_placement
)
elif current.is_replicate():
# split the tensor and return the corresponding cloned local shard
new_local_tensor = target_placement._replicate_to_shard(
local_tensor, device_mesh, i, my_coordinate[i]
)
else:
assert (
current.is_shard()
), f"Current placement should be shard but found {current}"
shard_spec = cast(Shard, current)
if shard_spec.dim != target_placement.dim:
new_local_tensor = shard_spec._to_new_shard_dim(
local_tensor,
device_mesh,
i,
transform_info.logical_shape,
target_placement.dim,
)
elif target.is_partial():
if current.is_replicate():
partial_spec = cast(Partial, target)
# Skip the replicate-to-partial transformation when we are in the backward pass.
# In this case we keep the grad as replicate: converting the replicated
# gradients back to partial, although logically conforming to the same layout,
# is useless, because you would have to reduce them later, which is more
# expensive than simply keeping them replicated. For this reason, we keep the
# replicated grad here.
new_local_tensor = (
partial_spec._partition_value(local_tensor, device_mesh, i)
if not is_backward
else local_tensor
)
elif current.is_shard():
if not is_backward:
raise RuntimeError(
f"redistribute from {current} to {target} not supported yet"
)
# for backward shard -> partial, we just need to convert the shard to replicate
current_placement = cast(Shard, current)
new_local_tensor = current_placement._to_replicate_tensor(
local_tensor, device_mesh, i, transform_info.logical_shape
)
else:
# partial -> partial no op, should never hit
new_local_tensor = local_tensor
assert new_local_tensor is not None
local_tensor = new_local_tensor
assert new_local_tensor is not None, "redistribute failed!"
if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor):
new_local_tensor = new_local_tensor.wait()
return new_local_tensor
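# This function backs the public ``DTensor.redistribute`` API; a minimal usage sketch
# (illustrative only, assuming an already-initialized 1-D device mesh named ``mesh``):
#
#   dt = distribute_tensor(torch.arange(16.0).reshape(4, 4), mesh, [Shard(0)])
#   replicated = dt.redistribute(mesh, [Replicate()])  # all-gather along mesh dim 0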
class Redistribute(torch.autograd.Function):
@staticmethod
def forward( # type: ignore[override]
# pyre-fixme[2]: Parameter must be annotated.
ctx,
input: "dtensor.DTensor",
device_mesh: DeviceMesh,
placements: Tuple[Placement, ...],
async_op: bool = False,
):
current_spec = input._spec
ctx.current_spec = current_spec
ctx.async_op = async_op
if current_spec.placements != placements:
target_spec = DTensorSpec(
device_mesh, placements, tensor_meta=input._spec.tensor_meta
)
local_tensor = input._local_tensor
output = redistribute_local_tensor(
local_tensor, current_spec, target_spec, async_op=async_op
)
else:
# use the same local tensor if placements are the same.
output = input._local_tensor
target_spec = current_spec
return dtensor.DTensor(
output,
target_spec,
requires_grad=input.requires_grad,
)
@staticmethod
def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override]
previous_spec = ctx.current_spec
current_spec = grad_output._spec
async_op = ctx.async_op
local_tensor = grad_output._local_tensor
output = redistribute_local_tensor(
local_tensor,
current_spec,
previous_spec,
async_op=async_op,
is_backward=True,
)
# normalize the target placement to replicate if it is partial
normalized_placements: List[Placement] = []
for previous_placement in previous_spec.placements:
if previous_placement.is_partial():
# keep target placement to replicate instead of partial in this case
normalized_placements.append(Replicate())
else:
normalized_placements.append(previous_placement)
spec = DTensorSpec(
previous_spec.device_mesh,
tuple(normalized_placements),
tensor_meta=TensorMeta(
shape=grad_output.shape,
stride=grad_output.stride(),
dtype=grad_output.dtype,
),
)
output_dtensor = dtensor.DTensor(
output,
spec,
requires_grad=grad_output.requires_grad,
)
return (
output_dtensor,
None,
None,
None,
)

View File

@ -0,0 +1,497 @@
# mypy: allow-untyped-defs
import threading
from functools import lru_cache
from itertools import chain
from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch._ops import OpOverload
from torch._subclasses import FakeTensorMode
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
OpInfo,
OpSchema,
OpStrategy,
OutputSharding,
OutputSpecType,
PlacementStrategy,
RuntimeSchemaInfo,
StrategyType,
TupleStrategy,
)
from torch.distributed.tensor._utils import (
compute_local_shape,
compute_local_stride,
try_find_mesh_from_args,
)
aten = torch.ops.aten
def _length(obj) -> int:
if obj is None:
return 0
if not isinstance(obj, Sequence):
return 1
return len(obj)
class LocalLRUCache(threading.local):
def __init__(self, user_function: Callable) -> None:
self.cache = lru_cache(None)(user_function)
def __call__(self, *args, **kwargs) -> object:
return self.cache(*args, **kwargs)
def cache_info(self):
return self.cache.cache_info()
class ShardingPropagator:
def __init__(self) -> None:
self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
self.op_strategy_funcs: Dict[
OpOverload,
Callable[[DeviceMesh, OpSchema], StrategyType],
] = {}
# op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop
self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {}
self.propagate_op_sharding = LocalLRUCache(
self.propagate_op_sharding_non_cached
)
# op map to save indices of shape (and stride) args which may need to be modified in sharding prop
self.op_to_shape_and_stride_idx: Dict[
OpOverload, Union[int, Tuple[int, int]]
] = {
# new factory ops
aten.new_empty.default: 1,
aten.new_full.default: 1,
aten.new_ones.default: 1,
aten.new_zeros.default: 1,
aten.new_empty_strided.default: (1, 2),
# view ops
aten.expand.default: 1,
aten.reshape.default: 1,
aten.view.default: 1,
aten._unsafe_view.default: 1,
}
def register_sharding_prop_rule(
self,
op_overload: OpOverload,
rule_func: Callable[[OpSchema], OutputSharding],
schema_info: Optional[RuntimeSchemaInfo] = None,
):
"""
Register a sharding propagation rule for an operator.
"""
self.op_to_rules[op_overload] = rule_func
if schema_info is not None:
self.op_to_schema_info[op_overload] = schema_info
def register_op_strategy(
self,
op_overload: OpOverload,
strategy_func: Callable[[DeviceMesh, OpSchema], StrategyType],
schema_info: Optional[RuntimeSchemaInfo] = None,
):
"""
Register a sharding strategy generator for an operator.
"""
self.op_strategy_funcs[op_overload] = strategy_func
if schema_info is not None:
self.op_to_schema_info[op_overload] = schema_info
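# A hypothetical registration sketch (illustrative only; ``my_lib.my_op`` and the
# strategy function below are made up and not part of this module):
#
#   def my_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
#       in_strategy = op_schema.args_schema[0]
#       assert isinstance(in_strategy, OpStrategy)
#       # follow the first input's placements for the output
#       return OpStrategy(
#           [PlacementStrategy(s.output_spec) for s in in_strategy.strategies]
#       )
#
#   propagator.register_op_strategy(torch.ops.my_lib.my_op.default, my_op_strategy)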
@lru_cache # noqa: B019
def _propagate_tensor_meta(
self, op_schema: OpSchema
) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
"""
Propagate the tensor metadata; it returns either a single TensorMeta
or a list/tuple of TensorMetas.
"""
if op_schema.op == aten.equal.default:
# data dependent ops can't be used for fake propagation
return None
# NOTE: We must call the tracing in fake tensor mode so that it
# avoids materializing memory
with FakeTensorMode():
fake_args = op_schema.gen_fake_args()
fake_kwargs = op_schema.gen_fake_kwargs()
fake_out = op_schema.op(*fake_args, **fake_kwargs)
if isinstance(fake_out, torch.Tensor):
return TensorMeta(
shape=fake_out.shape, stride=fake_out.stride(), dtype=fake_out.dtype
)
elif isinstance(fake_out, (tuple, list)):
tensor_meta_list: List[Optional[TensorMeta]] = []
for fake_out_item in fake_out:
if isinstance(fake_out_item, torch.Tensor):
tensor_meta_list.append(
TensorMeta(
shape=fake_out_item.shape,
stride=fake_out_item.stride(),
dtype=fake_out_item.dtype,
)
)
else:
tensor_meta_list.append(None)
return (
tuple(tensor_meta_list)
if isinstance(fake_out, tuple)
else tensor_meta_list
)
else:
# if the fake output is not a tensor or a tuple/list of tensors, return None
return None
def _wrap_output_spec_tensor_meta(
self,
op: OpOverload,
output_specs: OutputSpecType,
output_tensor_meta: Union[None, TensorMeta, Sequence[Optional[TensorMeta]]],
) -> None:
"""
Wrap the output_specs with the tensor metadata from the output.
"""
if isinstance(output_specs, DTensorSpec):
if not isinstance(output_tensor_meta, TensorMeta):
# Either error due to ShardingPropagator or due to incorrect OutputSpec
if not isinstance(output_tensor_meta, (tuple, list)):
raise ValueError(
"ShardingPropagator error: output does not have an associated TensorMeta"
)
raise ValueError(
f"For the op {op.name()}, `output_specs` has 1 output which does not equal the "
f"number of op outputs: {len(output_tensor_meta)}."
)
output_specs.tensor_meta = output_tensor_meta
elif isinstance(output_specs, (tuple, list)):
if not isinstance(output_tensor_meta, (tuple, list)) or len(
output_specs
) != len(output_tensor_meta):
raise ValueError(
f"For the op {op.name()}, `output_specs` has {len(output_specs)} outputs which does not equal the "
f"number of op outputs {_length(output_tensor_meta)}."
)
for i, spec in enumerate(output_specs):
if isinstance(spec, DTensorSpec):
output_tensor_meta_i = output_tensor_meta[i]
if not isinstance(output_tensor_meta_i, TensorMeta):
raise ValueError(
f"ShardingPropagator error: output {i} does not have an associated TensorMeta"
)
spec.tensor_meta = output_tensor_meta_i
def propagate(self, op_info: OpInfo) -> None:
# We cannot use an lru cache if we know that inputs will have dynamic shapes,
# because SymInts are not hashable.
# This is generally ok because this only happens during tracing in torch.compile,
# and tracing does not need to be as fast as eagermode DTensor usages.
if op_info.schema.has_symints:
output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)
else:
output_sharding = cast(
OutputSharding, self.propagate_op_sharding(op_info.schema)
)
op_info.output_sharding = output_sharding
def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding:
"""
Propagate the sharding for an operator given the op_schema.
"""
# special case op, we don't need to propagate for local
# scalar. TODO: figure out a better way to handle this
if op_schema.op is aten._local_scalar_dense.default:
return OutputSharding(None, op_schema)
out_tensor_meta = self._propagate_tensor_meta(op_schema)
def spec_to_strategy(spec: object) -> object:
if isinstance(spec, DTensorSpec):
return OpStrategy([PlacementStrategy(spec)])
elif (
isinstance(spec, (list, tuple))
and len(spec) > 0
and isinstance(spec[0], DTensorSpec)
):
# tensor list create tuple strategy
tuple_strategy = [spec_to_strategy(s) for s in spec]
tuple_strategy = cast(Sequence[StrategyType], tuple_strategy)
return TupleStrategy(
tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy
)
else:
return spec
if op_schema.op in self.op_strategy_funcs:
# generate op strategy for the op.
mesh = try_find_mesh_from_args(op_schema.op, op_schema.args_schema)
# swap the args spec with args strategies
args_op_strategy = [spec_to_strategy(i) for i in op_schema.args_schema]
kwargs_op_strategy = {
k: spec_to_strategy(v) for k, v in op_schema.kwargs_schema.items()
}
# construct a new OpSchema on args for strategy based propagation
strategy_schema: OpSchema = OpSchema(
op=op_schema.op,
args_schema=tuple(args_op_strategy),
kwargs_schema=kwargs_op_strategy,
)
op_strategy = self.op_strategy_funcs[op_schema.op](mesh, strategy_schema)
if isinstance(op_strategy, OpStrategy):
# single Op strategy
output_strategy = self._select_strategy(op_strategy)
# check if we need to redistribute the input
needs_redistribute = False
expected_input_specs: List[DTensorSpec] = []
# in case where the op does not specify input_specs and output_specs
# is a DTensorSpec, we use output_specs as the spec for each DTensor
# input arg.
if output_strategy.input_specs is None:
assert isinstance(output_strategy.output_specs, DTensorSpec)
for idx, input_spec in enumerate(op_schema.args_spec):
desired_spec = (
output_strategy.output_spec
if output_strategy.input_specs is None
else output_strategy.input_specs[idx]
)
expected_input_specs.append(
desired_spec.shallow_copy_with_tensor_meta(
input_spec.tensor_meta
)
)
if input_spec.placements != desired_spec.placements:
needs_redistribute = True
suggestion_schema = None
if needs_redistribute:
suggestion_schema = OpSchema(
op_schema.op, tuple(expected_input_specs), {}
)
suggestion_schema._inplace_rewrap_schema_suggestion(op_schema)
# shape and stride args need to be modified for
# view ops and new factory ops, potentially
if op_schema.op in self.op_to_shape_and_stride_idx:
assert isinstance(output_strategy.output_spec, DTensorSpec)
# It happens when the output has the same shape as the input
# and the input placements are not all Replicate().
if output_strategy.output_spec.is_sharded():
schema = suggestion_schema or op_schema
assert isinstance(out_tensor_meta, TensorMeta)
suggestion_schema = self._adjust_shape_and_stride_args(
out_tensor_meta, schema, output_strategy.output_spec, mesh
)
needs_redistribute = True
# construct output spec for the op
if op_schema.return_type_tuple_tensor_like():
# for ops that return multiple tensors and the output_specs is not
# a tuple, we use a tuple of that single output spec as the new
# output_specs
output_specs: OutputSpecType = output_strategy.output_specs
if isinstance(output_specs, DTensorSpec):
output_specs = tuple(
[
# create a new DTensorSpec with the same placement as the
# output_specs in output_strategy
DTensorSpec(
mesh=output_specs.mesh,
placements=output_specs.placements,
tensor_meta=output_specs.tensor_meta,
)
for _ in range(len(op_schema.op._schema.returns))
]
)
elif op_schema.return_type_tensor():
output_specs = output_strategy.output_specs
else:
output_specs = None
output_sharding = OutputSharding(
output_specs,
suggestion_schema,
needs_redistribute=needs_redistribute,
)
elif isinstance(op_strategy, TupleStrategy):
# tuple strategy output sharding processing
# runtime selected placement strategy for each TupleStrategy input arg
selected_strategies: List[PlacementStrategy] = []
out_spec_list: List[DTensorSpec] = []
for strategy in op_strategy.childs:
assert isinstance(strategy, OpStrategy)
selected_strategy = self._select_strategy(strategy)
selected_strategies.append(selected_strategy)
out_spec_list.append(selected_strategy.output_spec)
needs_redistribute = False
suggestion_args: List[object] = []
tensor_or_list_tensor_arg_idx = 0
for arg in op_schema.args_schema:
if (
arg
and isinstance(arg, (list, tuple))
and isinstance(arg[0], DTensorSpec)
):
expected_input_spec_list: List[DTensorSpec] = []
for idx, arg_spec in enumerate(arg):
expected_input_spec = selected_strategies[idx].input_spec(
tensor_or_list_tensor_arg_idx
)
expected_input_spec = (
expected_input_spec.shallow_copy_with_tensor_meta(
arg_spec.tensor_meta
)
)
if arg_spec.placements != expected_input_spec.placements:
needs_redistribute = True
expected_input_spec_list.append(expected_input_spec)
suggestion_args.append(
tuple(expected_input_spec_list)
if isinstance(arg, tuple)
else expected_input_spec_list
)
tensor_or_list_tensor_arg_idx += 1
elif isinstance(arg, DTensorSpec):
expected_input_spec = selected_strategies[0].input_spec(
tensor_or_list_tensor_arg_idx
)
expected_input_spec = (
expected_input_spec.shallow_copy_with_tensor_meta(
arg.tensor_meta
)
)
if arg.placements != expected_input_spec.placements:
needs_redistribute = True
suggestion_args.append(expected_input_spec)
tensor_or_list_tensor_arg_idx += 1
else:
suggestion_args.append(arg)
suggestion_schema = None
if needs_redistribute:
suggestion_schema = OpSchema(
op_schema.op, tuple(suggestion_args), op_schema.kwargs_schema
)
output_sharding = OutputSharding(
tuple(out_spec_list) if out_tensor_meta is not None else None,
suggestion_schema,
needs_redistribute=needs_redistribute,
)
else:
raise ValueError("Unsupported op strategy type")
# associate the output sharding with the output tensor metadata
self._wrap_output_spec_tensor_meta(
op_schema.op, output_sharding.output_spec, out_tensor_meta
)
return output_sharding
elif op_schema.op in self.op_to_rules:
# propagate the sharding with rule
sharding_prop_func = self.op_to_rules[op_schema.op]
# step 1. there's sharding propagation rule, run
# sharding propagation to get the output sharding
try:
output_sharding = sharding_prop_func(op_schema)
except NotImplementedError as e:
raise e
except Exception as e:
raise RuntimeError(
f"Sharding propagation failed on op {op_schema}.\n" f"Error: {e}"
) from e
# step 2. if can't get output_spec from sharding
# propagation (i.e. no rules apply for input
# placements), we return the output sharding
# with schema suggestions, which can be used to
# decide how to do redistribute on inputs
if output_sharding.output_spec is None:
if output_sharding.redistribute_schema is None:
raise RuntimeError(
f"Sharding propagation failed on op {op_schema}!"
)
else:
# we do auto redistribute on inputs if necessary
# run sharding propagation again with suggested schema
propagation_res = sharding_prop_func(
output_sharding.redistribute_schema
)
# we set the output sharding with the new propagation result
# so that dispatching knows that both output_spec and redistribute_schema
# exist, which indicates a reshard is needed
output_sharding.output_spec = propagation_res.output_spec
output_sharding.needs_redistribute = True
# associate the output sharding with the output tensor metadata
self._wrap_output_spec_tensor_meta(
op_schema.op, output_sharding.output_spec, out_tensor_meta
)
return output_sharding
else:
raise NotImplementedError(
f"Operator {op_schema.op} does not have a sharding strategy registered."
)
def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy:
if len(strategy.strategies) == 1:
# short cut with only one possible strategy
return strategy.strategies[0]
strategy_costs: List[float] = []
for strtg in strategy.strategies:
assert (
strtg.redistribute_cost is not None
), "must set redistribute cost each strategy!"
redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost))
strategy_costs.append(redistribute_cost)
# for eager execution, we just select the one with the minimal redistribute cost
return strategy.strategies[strategy_costs.index(min(strategy_costs))]
def _adjust_shape_and_stride_args(
self,
out_tensor_meta: TensorMeta,
schema: OpSchema,
spec: DTensorSpec,
mesh: DeviceMesh,
) -> OpSchema:
shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op]
if isinstance(shape_stride_idx, tuple):
shape_idx, stride_idx = shape_stride_idx
else:
shape_idx = shape_stride_idx
stride_idx = None
expected_input_schema = list(schema.args_schema)
# adjust shape to be the same as that of the _local_tensor
# of the DTensor input arg at index 0, which is inferred
expected_input_schema[shape_idx] = compute_local_shape(
out_tensor_meta.shape, mesh, spec.placements
)
# adjust the stride arg for aten.new_empty_strided.default
if stride_idx:
expected_input_schema[stride_idx] = compute_local_stride(
out_tensor_meta.stride, mesh, spec.placements
)
return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema)

View File

@ -0,0 +1,316 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, List, Tuple
import torch
from torch.distributed.checkpoint.metadata import (
ChunkStorageMetadata,
MetadataIndex,
TensorProperties,
TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
TensorWriteData,
WriteItem,
WriteItemType,
)
aten = (
torch.ops.aten
) # pyre-ignore[5]: Globally accessible variable `aten` has no type specified.
class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__
"""
A wrapper class to hold local shards of a DTensor.
This class is used largely for checkpointing purposes and implicitly subtypes
the _Checkpointable protocol.
"""
__slots__ = ["_local_shards", "_storage_meta"]
_local_shards: List[torch.Tensor]
_storage_meta: TensorStorageMetadata
@staticmethod
def __new__(
cls, local_shards: List[torch.Tensor], local_offsets: List[Tuple[int, ...]]
) -> "LocalShardsWrapper":
assert len(local_shards) > 0
assert len(local_shards) == len(local_offsets)
assert all(
tensor.device == local_shards[0].device for tensor in local_shards[1:]
)
# we calculate the total tensor size by concatenating on the second tensor dimension
cat_tensor_shape = list(local_shards[0].size())
if len(local_shards) > 1: # column-wise sharding
for shard in local_shards[1:]:
cat_tensor_shape[1] += shard.size()[1]
wrapper_properties = TensorProperties.create_from_tensor(local_shards[0])
wrapper_shape = torch.Size(cat_tensor_shape)
chunks_meta = [
ChunkStorageMetadata(
offsets=torch.Size(offset),
sizes=shard.size(),
)
for shard, offset in zip(local_shards, local_offsets)
]
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
torch.Size(cat_tensor_shape),
)
r._local_shards = local_shards
r._storage_meta = TensorStorageMetadata(
properties=wrapper_properties,
size=wrapper_shape,
chunks=chunks_meta,
)
return r
# necessary for ops dispatching from this subclass to its local shards
@classmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
dispatcher = {
torch.ops._c10d_functional.all_gather_into_tensor.default: cls.handle_all_gather_into_tensor,
torch.ops._c10d_functional.wait_tensor.default: cls.handle_wait_tensor,
aten._to_copy.default: cls.handle_to_copy,
aten.view.default: cls.handle_view,
aten.equal.default: cls.handle_equal,
aten.detach.default: cls.handle_detach,
aten.clone.default: cls.handle_clone,
}
if func in dispatcher:
return dispatcher[func](
args, kwargs
) # pyre-ignore [29] - `Variable[_VT]` is not a function.
else:
raise NotImplementedError(
f"{func} is not supported for LocalShardsWrapper!"
)
@staticmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def handle_all_gather_into_tensor(args, kwargs):
dim = args[0].local_sizes()[0][1]
cat_tensor = torch.cat(
[t.view(-1) for t in args[0].local_shards()], dim=0
).view(-1, dim)
return torch.ops._c10d_functional.all_gather_into_tensor.default(
cat_tensor, *args[1:], **kwargs
)
@staticmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def handle_wait_tensor(args, kwargs):
return torch.ops._c10d_functional.wait_tensor(args[0])
@staticmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def handle_to_copy(args, kwargs):
res_shards_list = [
aten._to_copy.default(shard, *args[1:], **kwargs)
for shard in args[0].local_shards()
]
return LocalShardsWrapper(res_shards_list, args[0].local_offsets())
@staticmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def handle_view(args, kwargs):
# TODO, do we need to change the shape of associated offsets?
res_shards_list = [
aten.view.default(shard, args[1], **kwargs)
for shard in args[0].local_shards()
]
return LocalShardsWrapper(res_shards_list, args[0].local_offsets())
@staticmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def handle_equal(args, kwargs):
"""
LocalShardsWrapper equal impl also checks for equality of storage metadata
and the order of shards
"""
a, b = args[0], args[1]
if len(a.local_shards()) != len(b.local_shards()):
return False
if not all(
aten.equal.default(x, y) for x, y in zip(a.local_shards(), b.local_shards())
):
return False
if not a.storage_metadata() == b.storage_metadata():
return False
return True
@staticmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def handle_detach(args, kwargs):
self_ls = args[0]
detached_local_shards = [
aten.detach.default(shard) for shard in self_ls.local_shards()
]
self_ls._local_shards = detached_local_shards
self_ls._storage_meta.properties.requires_grad = False
return self_ls
@staticmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def handle_clone(args, kwargs):
self_ls = args[0]
desired_memory_format = kwargs.get("memory_format", None)
if desired_memory_format and desired_memory_format != torch.preserve_format:
raise NotImplementedError(
f"{desired_memory_format} is not supported for LocalShardsWrapper!"
)
cloned_local_shards = [
shard.clone(memory_format=desired_memory_format)
for shard in self_ls._local_shards
]
return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets())
@property
def device(self) -> torch._C.device: # type: ignore[override]
return self._local_shards[0].device
@property
def is_meta(self) -> bool: # type: ignore[override]
return self._local_shards[0].is_meta
# pyre-ignore[14]
def is_pinned(self) -> bool: # type: ignore[override]
return self._storage_meta.properties.pin_memory
# pyre-ignore[14]
def requires_grad_(self, requires_grad: bool = True) -> "LocalShardsWrapper":
self._storage_meta.properties.requires_grad = requires_grad
[shard.requires_grad_(requires_grad) for shard in self._local_shards]
return self
def local_shards(self) -> List[torch.Tensor]:
"""
Returns a list of :class:`torch.Tensor` corresponding to the
local shards for this rank. Returns an empty list if the current rank
does not host any shards for this Tensor.
"""
return self._local_shards
def local_sizes(self) -> List[torch.Size]:
"""
Returns a list of :class:`torch.Size` corresponding to the
local sizes for the shards on this rank. Returns an empty list if the current rank
does not host any shards for this Tensor.
"""
return [chunk.sizes for chunk in self._storage_meta.chunks]
def local_offsets(self) -> List[torch.Size]:
"""
Returns a list of :class:`torch.Size` corresponding to the
local offsets for the shards on this rank. Returns an empty list if the current rank
does not host any shards for this Tensor.
"""
return [chunk.offsets for chunk in self._storage_meta.chunks]
@property
def local_chunks(self) -> List[ChunkStorageMetadata]:
"""
Returns a :class:`List[ChunkStorageMetadata]` object corresponding to the
metadata for each tensor shard
"""
return self._storage_meta.chunks
def storage_metadata(self) -> TensorStorageMetadata:
"""
Returns a :class:`TensorStorageMetadata` object corresponding to the
metadata for the local tensor on current rank
"""
return self._storage_meta
def __create_write_items__(
self, fqn: str, object: Any
) -> List[WriteItem]: # pyre-ignore[2]
"""
For compatibility with DCP, we support creation of WriteItems
such that they can be saved properly.
"""
return [
WriteItem(
index=MetadataIndex(fqn, chunks.offsets),
type=WriteItemType.SHARD,
tensor_data=TensorWriteData(
chunk=ChunkStorageMetadata(
offsets=chunks.offsets,
sizes=chunks.sizes,
),
properties=self._storage_meta.properties,
size=object.size(),
),
)
for tensor, chunks in zip(self.local_shards(), self.local_chunks)
]
def __create_chunk_list__(self) -> List[ChunkStorageMetadata]:
"""
For compatibility with DCP, we support creation of chunk lists
such that they can be saved properly.
"""
return self._storage_meta.chunks
def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor:
"""
For compatibility with DCP, we support finding shard based on index
Return a 'torch.Tensor' shard based on 'MetadataIndex'.
"""
# Fast lookup path
if index.index is not None:
if (
len(self._local_shards) > index.index
and self._storage_meta.chunks[index.index].offsets == index.offset
):
return self._local_shards[index.index]
if index.offset is not None:
for shard, chunk in zip(self._local_shards, self._storage_meta.chunks):
if chunk.offsets == index.offset:
return shard
raise ValueError(
f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'"
)
def _get_tensor_size_bytes(self) -> int:
object_size = 0
for shard in self.local_shards():
object_size += shard.nelement() * shard.element_size()
return object_size
# pyre-fixme[3]: Return type must be annotated.
def __hash__(self):
return id(self)
# pyre-fixme[14]: `__repr__` overrides method defined in `torch._tensor.Tensor` inconsistently.
# pyre-fixme[3]: Return type must be annotated.
def __repr__(self):
return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}"
def __str__(self) -> str:
return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}"

View File

@ -0,0 +1,279 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement matrix related ops for distributed tensor
from typing import cast, Dict, List, Tuple
import torch
import torch.distributed as dist
import torch.distributed.tensor._api as dtensor
aten = torch.ops.aten
def _requires_data_exchange(padding):
# TODO: whether data exchange is required is currently determined by padding
return padding[1] != 0
def _is_supported(input_size, kernel_size, stride, padding, dilation):
if dilation[1] != 1:
raise RuntimeError("Dilation must be 1 for tensor parallel convolution.")
if padding[1] != 0:
if stride[1] != 1:
raise RuntimeError(
"Stride must be 1 when there is padding for tensor parallel convolution."
)
if kernel_size[3] // 2 > input_size[3]:
raise RuntimeError(
"kernel_size[3] // 2 should be less than or equal to input_size[3] for tensor parallel convolution."
)
else:
if not (input_size[3] % stride[1] == 0 and stride[1] == kernel_size[3]):
raise RuntimeError(
"It requires that input_size[3] is divisible by stride[1] and stride[1] equals kernel_size[3] "
"when there is padding for tensor parallel convolution."
)
return True
def _ring_send_recv_construct(in_tensor, d1, d2, left, right, rank, size):
# dist comms and reconstruct local input tensor
send_to_right = in_tensor[:, :, :, -d1:].contiguous()
send_to_left = in_tensor[:, :, :, :d2].contiguous()
recv_from_right = torch.zeros_like(send_to_left)
recv_from_left = torch.zeros_like(send_to_right)
send_op_right = dist.P2POp(dist.isend, send_to_right, right)
send_op_left = dist.P2POp(dist.isend, send_to_left, left)
recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right)
recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left)
reqs = dist.batch_isend_irecv(
[send_op_right, send_op_left, recv_op_left, recv_op_right]
)
for req in reqs:
req.wait()
if rank == 0:
in_tensor = torch.cat([in_tensor, recv_from_right], dim=-1)
elif rank == size - 1:
in_tensor = torch.cat([recv_from_left, in_tensor], dim=-1)
else:
in_tensor = torch.cat([recv_from_left, in_tensor, recv_from_right], dim=-1)
return in_tensor
def _ring_send_recv_aggregate(grad_in_tensor, d1, d2, left, right, rank, size):
# dist comms and aggregate gradients for edge pixels
send_to_right = grad_in_tensor[:, :, :, -d2:].contiguous()
send_to_left = grad_in_tensor[:, :, :, :d1].contiguous()
recv_from_right = torch.zeros_like(send_to_left)
recv_from_left = torch.zeros_like(send_to_right)
send_op_right = dist.P2POp(dist.isend, send_to_right, right)
send_op_left = dist.P2POp(dist.isend, send_to_left, left)
recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right)
recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left)
reqs = dist.batch_isend_irecv(
[send_op_right, send_op_left, recv_op_left, recv_op_right]
)
for req in reqs:
req.wait()
if rank == 0:
grad_in_tensor = grad_in_tensor[:, :, :, :-d2]
grad_in_tensor[:, :, :, -d1:] = torch.add(
grad_in_tensor[:, :, :, -d1:], recv_from_right
)
elif rank == size - 1:
grad_in_tensor = grad_in_tensor[:, :, :, d1:]
grad_in_tensor[:, :, :, :d2] = torch.add(
grad_in_tensor[:, :, :, :d2], recv_from_left
)
else:
grad_in_tensor = grad_in_tensor[:, :, :, d1:-d2]
grad_in_tensor[:, :, :, -d1:] = torch.add(
grad_in_tensor[:, :, :, -d1:], recv_from_right
)
grad_in_tensor[:, :, :, :d2] = torch.add(
grad_in_tensor[:, :, :, :d2], recv_from_left
)
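# Halo widths used by the two ring exchanges above (illustrative): for a kernel of
# width k the total overlap is d = k - 1, split into d1 = d // 2 columns received from
# the left neighbor and d2 = d - d1 columns received from the right neighbor. E.g.
# k = 3 gives d1 = d2 = 1, so each interior rank exchanges a single column of pixels
# with each neighbor along the sharded width dimension.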
def tp_convolution(
op_call: torch._ops.OpOverload,
local_tensor_args: Tuple[object, ...],
local_tensor_kwargs: Dict[str, object],
) -> object:
assert op_call == aten.convolution.default
assert len(local_tensor_args) == 9
rank = dist.get_rank()
size = dist.get_world_size()
in_tensor = cast(torch.Tensor, local_tensor_args[0])
weight = cast(torch.Tensor, local_tensor_args[1])
stride, padding, dilation = local_tensor_args[3:6]
assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
assert isinstance(padding, List)
if not _requires_data_exchange(padding):
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
return local_results
else:
# step 0 compute the overlap pixels of the input tensor
d = weight.shape[3] - 1
d1 = d // 2
d2 = d - d1
assert d1 + d2 == d
right = (rank + 1) % size
left = (rank - 1 + size) % size
# step1 reconstruct local input tensor
in_tensor = _ring_send_recv_construct(
in_tensor, d1, d2, left, right, rank, size
)
# step2 feed local input tensor to op_call
local_tensor_args_list = list(local_tensor_args)
local_tensor_args_list[0] = in_tensor
local_tensor_args = cast(Tuple[object, ...], local_tensor_args_list)
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
# step3 remove extra outputs from the results
padding_w = padding[1]
w = local_results.size(3)
if rank == 0:
local_results = local_results[:, :, :, : w - padding_w]
elif rank == size - 1:
local_results = local_results[:, :, :, padding_w:]
else:
local_results = local_results[:, :, :, padding_w : w - padding_w]
return local_results
def tp_convolution_backward(
op_call: torch._ops.OpOverload,
local_tensor_args: Tuple[object, ...],
local_tensor_kwargs: Dict[str, object],
) -> object:
assert op_call == aten.convolution_backward.default
assert len(local_tensor_args) == 11
rank = dist.get_rank()
size = dist.get_world_size()
grad_out_tensor = cast(torch.Tensor, local_tensor_args[0])
in_tensor = cast(torch.Tensor, local_tensor_args[1])
weight = cast(torch.Tensor, local_tensor_args[2])
stride, padding, dilation = local_tensor_args[4:7]
assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
assert isinstance(padding, List)
if not _requires_data_exchange(padding):
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
return local_results
else:
# step 0 compute the overlap pixels of the input tensor
d = weight.shape[3] - 1
d1 = d // 2
d2 = d - d1
assert d1 + d2 == d
right = (rank + 1) % size
left = (rank - 1 + size) % size
# step1 reconstruct local input tensor
in_tensor = _ring_send_recv_construct(
in_tensor, d1, d2, left, right, rank, size
)
# step2 reconstruct local gradient output tensor
N, C_out, H_out, _ = grad_out_tensor.shape
padding_w = padding[1]
if rank == 0:
grad_out_tensor = torch.nn.functional.pad(
grad_out_tensor, (0, padding_w), "constant", 0
)
elif rank == size - 1:
grad_out_tensor = torch.nn.functional.pad(
grad_out_tensor, (padding_w, 0), "constant", 0
)
else:
grad_out_tensor = torch.nn.functional.pad(
grad_out_tensor, (padding_w, padding_w), "constant", 0
)
# step3 feed local input tensor to op_call
local_tensor_args_list = list(local_tensor_args)
local_tensor_args_list[0] = grad_out_tensor
local_tensor_args_list[1] = in_tensor
local_tensor_args = cast(Tuple[object, ...], local_tensor_args_list)
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
# step4 aggregate gradients for edge pixels
grad_in_tensor = local_results[0]
grad_in_tensor = _ring_send_recv_aggregate(
grad_in_tensor, d1, d2, left, right, rank, size
)
local_results = list(local_results)
local_results[0] = grad_in_tensor
local_results = cast(Tuple[object, ...], local_results)
return local_results
def convolution_handler(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> object:
# extract local tensor and sharding infos to a OpInfo
op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
# sharding propagation
dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
output_sharding = op_info.output_sharding
assert output_sharding is not None, "output sharding should not be None"
# local propagation
local_results = tp_convolution(
op_call, tuple(op_info.local_args), op_info.local_kwargs
)
return dtensor.DTensor._op_dispatcher.wrap(
local_results, output_sharding.output_spec
)
def convolution_backward_handler(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> object:
# Redistribute grad_output tensor to the same placement as input tensor
args = list(args)
assert isinstance(args[0], dtensor.DTensor) and isinstance(args[1], dtensor.DTensor)
args[0] = args[0].redistribute(args[1].device_mesh, args[1].placements)
args = tuple(args)
# extract local tensor and sharding infos to a OpInfo
op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
# sharding propagation
dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
output_sharding = op_info.output_sharding
assert output_sharding is not None, "output sharding should not be None"
# local propagation
local_results = tp_convolution_backward(
op_call, tuple(op_info.local_args), op_info.local_kwargs
)
return dtensor.DTensor._op_dispatcher.wrap(
local_results, output_sharding.output_spec
)

View File

@ -0,0 +1,316 @@
from typing import cast, List, Sequence, Tuple
import torch
import torch.distributed.tensor._api as dtensor
from torch._prims_common import ShapeType
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.placement_types import (
_StridedShard,
Partial,
Placement,
Replicate,
Shard,
)
# TODO: audit existing code base to see if we can safely remove this API.
def compute_local_shape(
global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
) -> Tuple[int, ...]:
"""
Compute the shape of a local shard of the given DTensor on its current
coordinate of the mesh.
"""
my_coordinate = mesh.get_coordinate()
if my_coordinate is None:
# if rank not in the mesh, return empty shape
return (0,)
else:
local_shape = list(global_shape) # start with global shape
ndim = len(global_shape)
for idx, placement in enumerate(placements):
mesh_dim_size = mesh.size(idx)
if isinstance(placement, Shard):
shard_dim = placement.dim
assert (
shard_dim < ndim
), f"Sharding dim {shard_dim} greater than tensor ndim {ndim}"
local_shard_size, _ = placement._local_shard_size_on_dim(
local_shape[shard_dim], mesh_dim_size, my_coordinate[idx]
)
assert isinstance(local_shard_size, int)
local_shape[shard_dim] = local_shard_size
return tuple(local_shape)
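# For example (illustrative): a global dim of size 8 sharded with Shard(0) over a mesh
# dim of size 3 yields local sizes 3, 3, 2 on mesh coordinates 0, 1, 2 (torch.chunk-style
# sharding, where trailing shards may be smaller or even empty).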
def compute_local_shape_and_global_offset(
global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
"""
Compute the local tensor shape and the global offsets into the original tensor
of a DTensor on its current global rank. This is useful for checkpointing purposes.
Example (2 hosts with 4 GPUs each):
# Below is a DeviceMesh with mesh_shape of (2, 4)
mesh = DeviceMesh(device_type="cuda",
mesh=[
[0, 1, 2, 3],
[4, 5, 6, 7]
],
)
Let's say we distribute a global_tensor of shape (8,4) over the above DeviceMesh
with placements of [Shard(0), Shard(0)].
The local shape and global offset will be as follows:
rank0 -- local_shape:[1, 4], global_offset:[0, 0]
rank1 -- local_shape:[1, 4], global_offset:[1, 0]
rank2 -- local_shape:[1, 4], global_offset:[2, 0]
rank3 -- local_shape:[1, 4], global_offset:[3, 0]
rank4 -- local_shape:[1, 4], global_offset:[4, 0]
rank5 -- local_shape:[1, 4], global_offset:[5, 0]
rank6 -- local_shape:[1, 4], global_offset:[6, 0]
rank7 -- local_shape:[1, 4], global_offset:[7, 0]
Let's say we distribute a global_tensor of shape (2,) over the above DeviceMesh with
placements of [Shard(0)]. Not all ranks will have a non-empty local tensor.
The local shape and global offset will be as follows:
rank0 -- local_shape:[1,], global_offset:[0,]
rank1 -- local_shape:[1,], global_offset:[1,]
rank2 -- local_shape:[0,], global_offset:[2,]
rank3 -- local_shape:[0,], global_offset:[2,]
rank4 -- local_shape:[0,], global_offset:[2,]
rank5 -- local_shape:[0,], global_offset:[2,]
rank6 -- local_shape:[0,], global_offset:[2,]
rank7 -- local_shape:[0,], global_offset:[2,]
"""
my_coordinate = mesh.get_coordinate()
if my_coordinate is None:
# if rank not in the mesh, return empty offset
return ((), ())
else:
local_shape = list(global_shape)
global_offset = [0] * len(global_shape)
shard_idx_stride_by_mesh_dim = [
[0] * mesh.ndim for _ in range(len(global_shape))
] # index by (shard_dim, mesh_dim)
num_shards_by_tensor_dim = [1] * len(global_shape)
for idx, placement in enumerate(placements):
mesh_dim_size = mesh.size(idx)
if isinstance(placement, Shard):
shard_dim = placement.dim
local_offset = [0] * len(global_shape)
assert shard_dim < len(
local_shape
), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
shard_size, shard_offset = placement._local_shard_size_on_dim(
local_shape[shard_dim],
mesh_dim_size,
my_coordinate[idx],
return_offset=True,
)
local_shape[shard_dim] = shard_size
local_offset[shard_dim] = shard_offset
# On a given dimension, if the local_offset[shard_dim] is smaller than global_offset[shard_dim],
                # it means that this dimension has already been sharded in a previous placement.
# Therefore, we cannot simply replace the global_offset[shard_dim] with local_offset[shard_dim].
# Instead, for the given shard_dim, we need to add local_offset[shard_dim] to existing global_offset[shard_dim].
if global_offset[shard_dim] <= local_offset[shard_dim]:
global_offset[shard_dim] = local_offset[shard_dim]
else:
global_offset[shard_dim] += local_offset[shard_dim]
num_shards_by_tensor_dim[shard_dim] *= mesh_dim_size
# NOTE: the offset compute relies on the local shard index and it has no
# problem when strided sharding is not present. To correctly compute, we assume
# that the ``_StridedShard.split_factor`` field encodes how many partitions
# each local tensor will be further split into when sharding on higher mesh
# dimensions. However, this number is only correct if the DTensor is not
# sharded after the strided sharding completes. For example,
# [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] is the placements
# where the DTensor's dim-0 is first sharded on device mesh dim-0, then on
# device mesh dim-2, and last on mesh dim-1. We define the
# "_StridedShard(0, split_factor=2), Shard(0)" part as the strided sharding
# part because strided sharding happens on mesh dim-1 and it was caused by
# the fact that sharding on dim-2 occurred ahead. In this case, there's no
# further sharding after this strided sharding part and ``split_factor``
# correctly encodes the number. Another example is
# [_StridedShard(0, split_factor=2), Shard(0), Shard(0)] where the DTensor's
# dim-0 is first sharded on mesh dim-1, then on mesh dim-0, and last on mesh
# dim-2. This violates our assumption that no further sharding shall occur
# after the strided sharding part and ``split_factor`` won't correctly
# encode the number of further split. So far, the only case where _StridedShard
# placement would appear is FSDP2 + TP on 2D mesh and the above case could only
# happen on mesh of 3 or more dimensions.
# TODO: change this function to correctly address this.
# TODO: this logic can be applied to contiguous sharding as well
strided_sharding = any(isinstance(p, _StridedShard) for p in placements)
if strided_sharding:
strided_part_seen = [False] * len(global_shape)
strided_part_end = [False] * len(global_shape)
for idx, placement in enumerate(placements):
mesh_dim_size = mesh.size(idx)
if isinstance(placement, Shard):
shard_dim = placement.dim
if strided_part_end[shard_dim]:
raise NotImplementedError(
f"Strided sharding does not allow Shard() to appear after "
f"the strided part has ended. {placement} at idx {idx} in "
f"{placements} violates this assumption."
)
if strided_part_seen[shard_dim]:
strided_part_end[shard_dim] = True
if isinstance(placement, _StridedShard):
strided_part_seen[shard_dim] = True
shard_idx_stride_by_mesh_dim[shard_dim][
idx
] = num_shards_by_tensor_dim[shard_dim] // (
placement.split_factor * mesh_dim_size
)
else:
num_shards_by_tensor_dim[shard_dim] //= mesh_dim_size
shard_idx_stride_by_mesh_dim[shard_dim][
idx
] = num_shards_by_tensor_dim[shard_dim]
shard_idx = [
sum([x * y for x, y in zip(shard_idx_stride, my_coordinate)])
for shard_dim, shard_idx_stride in enumerate(
shard_idx_stride_by_mesh_dim
)
]
global_offset = [x * y for x, y in zip(local_shape, shard_idx)]
return tuple(local_shape), tuple(global_offset)
def compute_global_tensor_info(
tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement]
) -> Tuple[List[int], List[int]]:
"""
Compute the global size and stride of a DTensor from the given local tensor.
    The local size is multiplied by `world_size` per sharding dim.
    The local stride is multiplied by `world_size` per sharding dim, for every other
    dimension whose stride is at least as large as the stride of the sharding dim.
For example, if we have a local tensor with size (4, 8, 2) and stride (16, 1, 8).
If the DTensor placements are [Shard(2)] and world_size is 2;
then the global size is (4, 8, 4) and stride is (16 * 2, 1, 8).
Args:
tensor (:class:`torch.Tensor`):
Local tensor which DTensor will be constructed from.
mesh (:class:`DeviceMesh`):
Object which describes the mesh topology
of devices for the DTensor.
placements (Sequence[:class:`Placement`]]):
The attribute of the DTensor that describes its layout
on the mesh topology.
Return:
        tensor_shape: A List of int which specifies the size of the DTensor built
            on top of the local tensor.
        tensor_stride: A List of int which specifies the stride of the DTensor.
"""
tensor_shape = list(tensor.size())
tensor_stride = list(tensor.stride())
for idx, placement in enumerate(placements):
mesh_dim_size = mesh.size(idx)
if placement.is_shard():
shard_placement = cast(Shard, placement)
if shard_placement.dim < 0:
raise AssertionError(
"Shard placements should have negative dims normalized in "
f"the user-facing APIs: {shard_placement}"
)
shard_dim = shard_placement.dim
assert (
shard_dim < tensor.ndim
), f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}."
local_dim_size = tensor_shape[shard_dim]
tensor_shape[shard_dim] = local_dim_size * mesh_dim_size
# recover tensor stride by modifying the stride that larger than
# the current stride on the shard_dim
for i in range(len(tensor_stride)):
if i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim]:
# rescale the stride by the shard size
tensor_stride[i] = tensor_stride[i] * mesh_dim_size
elif not isinstance(placement, (Replicate, Partial)):
raise RuntimeError(f"placement type {type(placement)} not supported!")
return tensor_shape, tensor_stride
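# Illustrative sketch (not part of the original module): reproducing the docstring
# example above, assuming ``mesh2`` is a hypothetical 1-D mesh of world size 2.
#
#   >>> local = torch.empty(4, 2, 8).permute(0, 2, 1)  # size (4, 8, 2), stride (16, 1, 8)
#   >>> compute_global_tensor_info(local, mesh2, [Shard(2)])
#   ([4, 8, 4], [32, 1, 8])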
def try_find_mesh_from_args(
op_call: torch._ops.OpOverload, args: Sequence[object]
) -> DeviceMesh:
"""
Find the device mesh object from args.
    It raises a ValueError if no mesh is found.
NOTE: we can optimize this search if needed
"""
for arg in args:
if isinstance(arg, (dtensor.DTensor, DTensorSpec)):
return arg.device_mesh
elif (
isinstance(arg, (list, tuple))
and len(arg) > 0
and isinstance(arg[0], (dtensor.DTensor, DTensorSpec))
):
return arg[0].device_mesh
raise ValueError(f"Cannot find device mesh from args for op : {op_call}.")
def compute_local_stride(
global_stride: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
) -> Tuple[int, ...]:
"""
Compute the stride of a local tensor shard, given the global stride of the DTensor.
NOTE: Currently this function is assuming the DTensor is evenly shardable.
"""
stride_divisors = [1] * len(global_stride)
for mesh_idx, p in enumerate(placements):
if p.is_shard():
i = cast(Shard, p).dim
# tensor dimension i is sharded on mesh dimension mesh_idx,
# so we need to divide all the strides larger than stride[i]
# (by the submesh size)
for j in range(len(global_stride)):
if global_stride[j] > global_stride[i]:
stride_divisors[j] *= mesh.size(mesh_idx)
return tuple(
global_stride[i] // stride_divisors[i] for i in range(len(global_stride))
)
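# Illustrative sketch (not part of the original module): an evenly shardable,
# contiguous (8, 4) tensor has global stride (4, 1). Sharding dim 1 over a
# hypothetical 2-rank 1-D mesh ``mesh2`` halves every stride larger than the
# sharded dim's stride, while sharding dim 0 leaves the strides intact.
#
#   >>> compute_local_stride((4, 1), mesh2, [Shard(1)])
#   (2, 1)
#   >>> compute_local_stride((4, 1), mesh2, [Shard(0)])
#   (4, 1)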
def normalize_to_torch_size(size) -> torch.Size: # type: ignore[no-untyped-def]
"""
Unify variable types of size argument to torch.Size
Acceptable types include:
int, Sequence[int], Tuple[int], Tuple[Sequence[int]],
or torch.Size
"""
if isinstance(size, torch.Size):
return size
if isinstance(size, int):
torch_size = [size]
elif len(size) == 1 and isinstance(size[0], Sequence):
torch_size = list(size[0])
else:
torch_size = list(size)
return torch.Size(torch_size)
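# Illustrative sketch (not part of the original module): the accepted input forms.
#
#   >>> normalize_to_torch_size(4)
#   torch.Size([4])
#   >>> normalize_to_torch_size((2, 3))
#   torch.Size([2, 3])
#   >>> normalize_to_torch_size([(2, 3)])   # a single nested sequence is unwrapped
#   torch.Size([2, 3])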

View File

@ -0,0 +1,24 @@
# mypy: allow-untyped-defs
from torch.distributed.tensor.debug._comm_mode import CommDebugMode
from torch.distributed.tensor.debug._visualize_sharding import visualize_sharding
__all__ = ["CommDebugMode", "visualize_sharding"]
def _get_sharding_prop_cache_info():
"""
Get the cache info for the sharding propagation cache, used for debugging purpose only.
This would return a named tuple showing hits, misses, maxsize and cursize of the sharding
propagator cache.
"""
from torch.distributed.tensor._api import DTensor
return (
DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_info() # type:ignore[attr-defined]
)
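# Illustrative usage sketch (not part of the original module): the returned value is
# the standard ``functools.lru_cache`` CacheInfo named tuple; the numbers below are
# placeholders.
#
#   >>> from torch.distributed.tensor.debug import _get_sharding_prop_cache_info
#   >>> _get_sharding_prop_cache_info()
#   CacheInfo(hits=42, misses=7, maxsize=None, currsize=7)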
# Set namespace for exposed private names
CommDebugMode.__module__ = "torch.distributed.tensor.debug"
visualize_sharding.__module__ = "torch.distributed.tensor.debug"

View File

@ -0,0 +1,735 @@
# mypy: allow-untyped-defs
import copy
import json
import re
import weakref
from collections import defaultdict
from typing import Any, Dict
import torch
import torch.nn
from torch._guards import detect_fake_mode
from torch.autograd.graph import register_multi_grad_hook
from torch.distributed._tools.mod_tracker import ModTracker
from torch.distributed.tensor._api import DTensor
from torch.nn.modules.module import (
register_module_forward_hook,
register_module_forward_pre_hook,
register_module_full_backward_pre_hook,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten
__all__ = ["CommDebugMode"]
funcol_native = torch.ops._c10d_functional
funcol_py = torch.ops.c10d_functional
funcol_autograd = torch.ops._c10d_functional_autograd
c10d_ops = torch.ops.c10d
NATIVE_TO_PY_MAPPING = {
funcol_native.all_gather_into_tensor: funcol_py.all_gather_into_tensor,
funcol_native.all_gather_into_tensor_coalesced: funcol_py.all_gather_into_tensor_coalesced,
funcol_native.all_reduce: funcol_py.all_reduce,
funcol_native.all_reduce_coalesced: funcol_py.all_reduce_coalesced,
funcol_native.all_to_all_single: funcol_py.all_to_all_single,
funcol_native.broadcast: funcol_py.broadcast,
funcol_native.reduce_scatter_tensor: funcol_py.reduce_scatter_tensor,
funcol_native.reduce_scatter_tensor_coalesced: funcol_py.reduce_scatter_tensor_coalesced,
# functional ops
funcol_autograd.all_to_all_single: funcol_py.all_to_all_single,
}
c10d_collective_ops = {
c10d_ops._allgather_base_,
c10d_ops._reduce_scatter_base_,
c10d_ops.allgather_,
c10d_ops.allgather_coalesced_,
c10d_ops.allgather_into_tensor_coalesced_,
c10d_ops.allreduce_,
c10d_ops.allreduce_coalesced_,
c10d_ops.alltoall_,
c10d_ops.alltoall_base_,
c10d_ops.broadcast_,
c10d_ops.gather_,
c10d_ops.scatter_,
c10d_ops.reduce_,
c10d_ops.reduce_scatter_,
c10d_ops.reduce_scatter_tensor_coalesced_,
}
trivial_ops = {
"aten.detach.default",
"aten.t.default",
"aten.view.default",
"aten._to_copy.default",
"aten.as_strided.default",
"aten.transpose.int",
}
class _CommModeModuleTracker(ModTracker):
"""
    Inherits ModTracker and expands on its functionality to track the
    parameters and sharding information of a model at the module level
"""
def __init__(self):
super().__init__()
self.module_helper_dict = {}
self.module_parameters_dict = {}
self.module_parents_dict = {}
self.register_forward_hook_handles = {}
self.parent_dict = {}
self.parent_list = []
self.sharding_dict = {}
self.activation_checkpointing = False
self.name = ""
def _fw_set_module_hook(self, mod, input, output):
"""
Updates the current module after module finishes running and
all other hooks are resolved
"""
if self.is_bw:
self.activation_checkpointing = True
else:
self.activation_checkpointing = False
if not self.activation_checkpointing:
# module is no longer parent of next modules
self.parent_list.pop()
# set current module to previous parent module
self.name = self.parent_list[-1]
def _fw_pre_hook(self, mod, input):
"""
This function is called before the forward pass of a module. It
collects the parameters and sharding information of a module and
stores it in a dictionary.
"""
if self.is_bw:
self.activation_checkpointing = True
else:
self.activation_checkpointing = False
self.name = super()._get_mod_name(mod)
w_mod = weakref.ref(mod)
# adds current sub-module to module tracker parent class
super()._get_append_fn(w_mod, self.name, False)()
args, _ = tree_flatten(input)
tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
if not self.is_bw and tensors:
register_multi_grad_hook(
tensors, super()._get_pop_fn(w_mod, self.name, True)
)
if not self.activation_checkpointing:
# contains information about module ordering and depth in the module tree
if self.name not in self.module_helper_dict:
self.module_helper_dict[self.name] = {}
self.module_helper_dict[self.name]["module_type"] = (
str(type(mod)).replace("<", "").replace(">", "")
)
self.module_helper_dict[self.name]["depth"] = len(self.parents) - 1
for param_name, param in mod.named_parameters(recurse=False):
if self.name not in self.module_parameters_dict:
self.module_parameters_dict[self.name] = {}
self.module_parameters_dict[self.name][param_name] = param.data
if isinstance(param.data, DTensor):
key_name = self.name + "." + param_name
self.sharding_dict[key_name] = param.data.placements
if "parameters" not in self.module_helper_dict[self.name]:
self.module_helper_dict[self.name]["parameters"] = {}
self.module_helper_dict[self.name]["parameters"][param_name] = str(
param.data.placements
)
# used to store module's parents to ensure correctness in backward pass/checkpointing
if self.name not in self.module_parents_dict:
self.module_parents_dict[self.name] = copy.deepcopy(self.parents)
# used to create parent-child module associations for json dumps
parent = self.parent_list[-1]
if parent not in self.parent_dict:
self.parent_dict[parent] = []
self.parent_dict[parent].append(self.name)
self.parent_list.append(self.name)
self.register_forward_hook_handles[self.name] = mod.register_forward_hook(
self._fw_set_module_hook
)
def _fw_post_hook(self, mod, input, output):
"""
        This function is called after the forward pass of a module finishes.
        It updates the module tracker and removes the module from the parent data
super()._fw_post_hook(mod, input, output)
def _bw_hook(self, mod, output):
"""
        This function is called when the backward pass of a module starts. It
        updates the current module for backward passes
"""
self.activation_checkpointing = False
self.name = super()._get_mod_name(mod)
def __enter__(self):
self.activation_checkpointing = False
self.module_parameters_dict.clear()
self.sharding_dict.clear()
self.parent_dict.clear()
self.parent_list = ["Global"]
self.module_helper_dict.clear()
self.module_helper_dict["Global"] = {"depth": 0}
self.module_parents_dict.clear()
self.module_parents_dict["Global"] = set()
self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
self._fw_post_handle = register_module_forward_hook(self._fw_post_hook)
self.register_forward_hook_handles.clear()
self._bw_handle = register_module_full_backward_pre_hook(self._bw_hook)
self.name = "Global"
def __exit__(self, *args):
super().__exit__(*args)
self._bw_handle.remove()
# removes all forward_hook handles added in the pre-hook
for handle in self.register_forward_hook_handles.values():
handle.remove()
def print_paramater_info(self):
print(self.module_parameters_dict)
def print_sharding_info(self):
for key, value in self.sharding_dict.items():
print(key + ": " + str(value))
class CommDebugMode(TorchDispatchMode):
"""
:class:`CommDebugMode` is a context manager that counts the number of
functional collectives within its context. It does this using a
``TorchDispatchMode``.
    .. note:: Not all collectives are supported yet.
Example usage
.. code-block:: python
mod = ...
comm_mode = CommDebugMode()
with comm_mode:
mod.sum().backward()
print(comm_mode.get_comm_counts())
"""
def __init__(self):
self.comm_counts: Dict[Any, int] = defaultdict(int)
self.comm_module_counts = {}
self.comm_module_operation_counts = {}
self.comm_registry = set()
for native_op, py_op in NATIVE_TO_PY_MAPPING.items():
self.comm_registry.add(native_op)
self.comm_registry.add(py_op)
self.comm_registry.add(torch.ops._dtensor.shard_dim_alltoall)
self.advanced_module_tracker = _CommModeModuleTracker()
def generate_json_dump(self, file_name="comm_mode_log.json", noise_level=3):
"""
        Creates a json file used to build the browser visual
        0. prints module-level collective counts
        1. prints DTensor operations not included in trivial operations
2. prints operations not included in trivial operations
3. prints all operations
"""
(
include_DTensor_ops,
include_module_data,
include_ops,
include_trivial_ops,
) = self._set_noise_parameters(noise_level)
# recursively builds json data
def add_json_information(json_dict, fqn):
json_dict["fqn"] = fqn
json_dict["module_type"] = ""
json_dict["parameters"] = []
json_dict["children"] = []
json_dict["collectives_forward"] = []
json_dict["collectives_backward"] = []
json_dict["operations_forward"] = []
json_dict["operations_backward"] = []
# adds module layer type and parameters, and their sharding
if (
"module_type" in self.advanced_module_tracker.module_helper_dict[fqn]
and include_module_data
):
json_dict[
"module_type"
] = self.advanced_module_tracker.module_helper_dict[fqn]["module_type"]
if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]:
for (
param_name,
placement,
) in self.advanced_module_tracker.module_helper_dict[fqn][
"parameters"
].items():
json_dict["parameters"].append((param_name, placement))
# adds module collective information
if fqn in self.comm_module_counts:
for collective, count in self.comm_module_counts[fqn][
"forward"
].items():
json_dict["collectives_forward"].append((str(collective), count))
for collective, count in self.comm_module_counts[fqn][
"backward"
].items():
json_dict["collectives_backward"].append((str(collective), count))
# adds module operation information
forward_operations = []
backward_operations = []
checkpointing_operations = []
# only get operations if the minimum operation noise level is set to true
if include_DTensor_ops:
if fqn in self.comm_module_operation_counts:
(
forward_operations,
backward_operations,
checkpointing_operations,
) = self._get_operations_list(
self.comm_module_operation_counts[fqn]
)
            # remove all operations that don't have DTensor inputs
if not include_ops:
forward_operations = [
op for op in forward_operations if len(op["input_sharding"])
]
backward_operations = [
op for op in backward_operations if len(op["input_sharding"])
]
checkpointing_operations = [
op for op in checkpointing_operations if len(op["input_sharding"])
]
# remove all operations in trivial operations set
if not include_trivial_ops:
forward_operations = [
op
for op in forward_operations
if str(op["name"]) not in trivial_ops
]
backward_operations = [
op
for op in backward_operations
if str(op["name"]) not in trivial_ops
]
checkpointing_operations = [
op
for op in checkpointing_operations
if str(op["name"]) not in trivial_ops
]
# converts operation information into string format for json.dumps()
forward_operations = copy.deepcopy(forward_operations)
for op in forward_operations:
op["name"] = str(op["name"])
for i in range(len(op["input_sharding"])):
op["input_sharding"][i] = str(op["input_sharding"][i])
op["input_shape"][i] = str(op["input_shape"][i])
backward_operations = copy.deepcopy(backward_operations)
for op in backward_operations:
op["name"] = str(op["name"])
for i in range(len(op["input_sharding"])):
op["input_sharding"][i] = str(op["input_sharding"][i])
op["input_shape"][i] = str(op["input_shape"][i])
checkpointing_operations = copy.deepcopy(checkpointing_operations)
for op in checkpointing_operations:
op["name"] = str(op["name"])
for i in range(len(op["input_sharding"])):
op["input_sharding"][i] = str(op["input_sharding"][i])
op["input_shape"][i] = str(op["input_shape"][i])
json_dict["operations_forward"] = forward_operations
json_dict["operations_backward"] = backward_operations
json_dict["operations_checkpointing"] = checkpointing_operations
if fqn not in self.advanced_module_tracker.parent_dict:
return json_dict
# recursively adds module's children
for ele in self.advanced_module_tracker.parent_dict[fqn]:
json_dict["children"].append(add_json_information({}, ele))
return json_dict
json_dict: Dict[str, Any] = {}
add_json_information(json_dict, "Global")
        # converts dictionary into json file
with open(file_name, "w") as json_file:
json.dump(json_dict, json_file, indent=4)
def generate_comm_debug_tracing_table(self, noise_level=3):
"""
        Generates a detailed table displaying operations and collective tracing information
        at the module level. The amount of information depends on noise_level
        0. prints module-level collective counts
        1. prints DTensor operations not included in trivial operations, and module information
2. prints operations not included in trivial operations
3. prints all operations
"""
(
include_DTensor_ops,
include_module_data,
include_ops,
include_trivial_ops,
) = self._set_noise_parameters(noise_level)
table = ""
for fqn in self.advanced_module_tracker.module_helper_dict:
# setting up indentations for table formatting
indent = " " * (
2 * self.advanced_module_tracker.module_helper_dict[fqn]["depth"]
)
table += f"{indent}{fqn}\n"
if include_module_data:
if (
"module_type"
in self.advanced_module_tracker.module_helper_dict[fqn]
):
module_type = self.advanced_module_tracker.module_helper_dict[fqn][
"module_type"
]
table += f"{indent}*module type: {module_type}\n"
if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]:
table += f"{indent}*Parameter List\n"
for (
param_name,
placement,
) in self.advanced_module_tracker.module_helper_dict[fqn][
"parameters"
].items():
table += f"{indent} *{param_name}: {placement}\n"
indent += " "
collective_indent = " " * (
2 * self.advanced_module_tracker.module_helper_dict[fqn]["depth"] + 2
)
operation_indent = " " * (
2 * self.advanced_module_tracker.module_helper_dict[fqn]["depth"] + 3
)
# separate the module's collective and operations by forward and backward
forward_collectives = {}
backward_collectives = {}
if fqn in self.comm_module_counts:
forward_collectives = self.comm_module_counts[fqn]["forward"]
backward_collectives = self.comm_module_counts[fqn]["backward"]
forward_operations = []
backward_operations = []
checkpointing_operations = []
if include_DTensor_ops:
if fqn in self.comm_module_operation_counts:
(
forward_operations,
backward_operations,
checkpointing_operations,
) = self._get_operations_list(
self.comm_module_operation_counts[fqn]
)
def add_tracing_information(table, collectives_dict, operation_list):
"""
adds tracing information for module's forward or backward
"""
for collective, count in collectives_dict.items():
table += (
f"\033[1;33m{collective_indent}*{collective}: {count}\033[0m\n"
)
def add_operations(
table, operation, collective_indent, operation_indent
):
"""
adds operation information to the table
"""
table += f"\033[1;33m{collective_indent}**{operation_name}\033[0m\n"
if len(operation["input_shape"]):
operation_shape = operation["input_shape"]
operation_sharding = operation["input_sharding"]
operation_device_mesh = operation["device_mesh"]
table += f"\033[1;31m{operation_indent}shape: {operation_shape}\033[0m\n"
table += f"\033[1;31m{operation_indent}sharding: {operation_sharding}\033[0m\n"
table += f"\033[1;31m{operation_indent}device mesh: {operation_device_mesh}\033[0m\n"
return table
for operation in operation_list:
operation_name = str(operation["name"])
# include all operations
if include_trivial_ops:
table = add_operations(
table, operation, collective_indent, operation_indent
)
# include all operations not in trivial operations
elif include_ops and operation_name not in trivial_ops:
table = add_operations(
table, operation, collective_indent, operation_indent
)
                    # only include DTensor operations not in the trivial set
elif (
include_DTensor_ops
and (operation_name not in trivial_ops)
and len(operation["input_shape"])
):
table = add_operations(
table, operation, collective_indent, operation_indent
)
return table
if len(forward_collectives) or len(forward_operations):
table += f"{indent}FORWARD PASS\n"
table = add_tracing_information(
table, forward_collectives, forward_operations
)
if len(backward_collectives) or len(backward_operations):
table += f"{indent}BACKWARD PASS\n"
table = add_tracing_information(
table, backward_collectives, backward_operations
)
if len(checkpointing_operations):
table += f"{indent}ACTIVATION CHECKPOINTING\n"
table = add_tracing_information(table, {}, checkpointing_operations)
return table
def _get_operations_list(self, module_operation_counts):
forward_operations = [
op for op in module_operation_counts["operations_list"] if not op["is_bw"]
]
backward_operations = [
op
for op in module_operation_counts["operations_list"]
if op["is_bw"] and not op["is_activation_checkpointing"]
]
checkpointing_operations = [
op
for op in module_operation_counts["operations_list"]
if op["is_activation_checkpointing"]
]
return forward_operations, backward_operations, checkpointing_operations
def get_total_counts(self) -> int:
return sum(self.comm_counts.values())
def get_comm_counts(self) -> Dict[Any, int]:
"""Returns the communication counts as a dictionary.
Returns:
Dict[Any, int]: The communication counts as a dictionary.
"""
return self.comm_counts
def get_parameter_info(self) -> Dict[str, Dict[str, Any]]:
return self.advanced_module_tracker.module_parameters_dict
def get_sharding_info(self) -> Dict[str, Dict[str, Any]]:
return self.advanced_module_tracker.sharding_dict
def __enter__(self):
self.comm_counts.clear()
self.comm_module_counts.clear()
self.comm_module_counts["Global"] = {}
self.comm_module_counts["Global"]["forward"] = defaultdict(int)
self.comm_module_counts["Global"]["backward"] = defaultdict(int)
self.comm_module_operation_counts.clear()
super().__enter__()
self.advanced_module_tracker.__enter__()
return self
def __exit__(self, *args):
self.advanced_module_tracker.__exit__()
super().__exit__(*args)
def log_comm_debug_tracing_table_to_file(
self, file_name="comm_mode_log.txt", noise_level=3
):
"""
Alternative to console CommDebugMode output, writes to file specified by the user
"""
ansi_escape = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")
table = ansi_escape.sub("", self.generate_comm_debug_tracing_table(noise_level))
with open(file_name, "w") as log_file:
log_file.write(table)
def _set_noise_parameters(self, noise_level):
"""
sets variables controlling what information displays based on noise level
"""
include_DTensor_ops = False
include_module_data = False
include_ops = False
include_trivial_ops = False
if noise_level > 0:
include_DTensor_ops = True
include_module_data = True
if noise_level > 1:
include_ops = True
if noise_level > 2:
include_trivial_ops = True
return (
include_DTensor_ops,
include_module_data,
include_ops,
include_trivial_ops,
)
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
# When running this mode with DTensor, ordinarily all modes will
# run **before** subclasses get a chance to run.
# Returning NotImplemented here gives us a chance to let DTensor
# run and desugar into comms ops, before CommDebugMode sees them.
# sets up operation-level collective count
if self.advanced_module_tracker.name not in self.comm_module_operation_counts:
# dictionary should hold module input and output shape, operations list and collective counter
self.comm_module_operation_counts[self.advanced_module_tracker.name] = {
"operations_list": []
}
operation_dict = {}
operation_dict["name"] = func
operation_dict["input_shape"] = []
operation_dict["input_sharding"] = []
operation_dict["device_mesh"] = ""
# tracks if the operation is part of the backward pass
operation_dict["is_bw"] = self.advanced_module_tracker.is_bw
# tracks if the operation is part of activation checkpointing
operation_dict[
"is_activation_checkpointing"
] = self.advanced_module_tracker.activation_checkpointing
if any(t == DTensor for t in types):
for ele in args:
if isinstance(ele, DTensor):
# saves shapes and placements of all DTensor args
operation_dict["input_shape"].append(ele.shape)
operation_dict["input_sharding"].append(ele.placements)
operation_dict["device_mesh"] = str(ele.device_mesh)
self.comm_module_operation_counts[self.advanced_module_tracker.name][
"operations_list"
].append(operation_dict)
return NotImplemented
kwargs = kwargs if kwargs else {}
out = func(*args, **kwargs)
func_packet = func._overloadpacket
# We have many tests that use CommDebugMode to verify the occurrence of
# collectives. These tests do so by querying comm_counts with legacy
# funcol ops as key. For the purpose of native funcol migration, we
# need these tests to work for both legacy and native funcol. To avoid
# the need to modify all tests to accommodate the two implementations,
# we make CommDebugMode translate native funcol ops into legacy funcol
# ops until the migration finishes.
if func_packet in self.comm_registry or func_packet in c10d_collective_ops:
if func_packet in NATIVE_TO_PY_MAPPING:
func_packet = NATIVE_TO_PY_MAPPING[func_packet]
self.comm_counts[func_packet] += 1
key = "forward"
if self.advanced_module_tracker.is_bw:
key = "backward"
# adds collective count to current module
if self.advanced_module_tracker.name not in self.comm_module_counts:
self.comm_module_counts[self.advanced_module_tracker.name] = {}
self.comm_module_counts[self.advanced_module_tracker.name][
"forward"
] = defaultdict(int)
self.comm_module_counts[self.advanced_module_tracker.name][
"backward"
] = defaultdict(int)
self.comm_module_counts[self.advanced_module_tracker.name][key][
func_packet
] += 1
# adds collective count to parent modules
for par in self.advanced_module_tracker.module_parents_dict[
self.advanced_module_tracker.name
]:
# makes sure we aren't double counting when current sub-module hasn't been removed from parents
if par != self.advanced_module_tracker.name:
if par not in self.comm_module_counts:
self.comm_module_counts[par] = {}
self.comm_module_counts[par]["forward"] = defaultdict(int)
self.comm_module_counts[par]["backward"] = defaultdict(int)
self.comm_module_counts[par][key][func_packet] += 1
# if tensor op uses fake tensors, return
if detect_fake_mode(args):
return out
# add tensor operation to module operation list
self.comm_module_operation_counts[self.advanced_module_tracker.name][
"operations_list"
].append(operation_dict)
return out

View File

@ -0,0 +1,105 @@
# mypy: allow-untyped-defs
from operator import itemgetter
from typing import List
import torch
import torch.fx
import torch.nn as nn
from functorch.compile import make_boxed_func
from torch._functorch.compilers import aot_module
from torch._inductor.decomposition import select_decomp_table
from torch.distributed.tensor import DTensor
inductor_decomps = select_decomp_table()
graphs: List[torch.fx.GraphModule] = []
def fwd_bwd_compiler(fx_g, _):
graphs.append(fx_g)
return make_boxed_func(fx_g)
def get_inductor_decomp_graphs(model: nn.Module, args, kwargs):
"""
Obtain forward and backward graphs of a model with inductor decompositions using tracing and aot_module.
Convenient util to get the fwd and bwd graphs of an arbitrary model
    with inductor decompositions. Note that this simply does tracing
    with aot_module and does not ensure correctness. This is useful to track
    the ops needed in DTensor.
"""
compiled_mod = aot_module(
model, fw_compiler=fwd_bwd_compiler, decompositions=inductor_decomps
)
output = compiled_mod(*args, **kwargs)
if output.ndim != 0:
# if output is not a scalar tensor, by default sum it in order to
# run backward
output = output.sum()
output.backward()
# one fwd, one bwd graph
assert len(graphs) == 2
return graphs
def print_op_coverage_summary(model: nn.Module, args, kwargs, *, output_csv=False):
"""
    Util to print the operator coverage summary of a certain model with tabulate.
    Requires the tabulate module to be installed.
"""
# python module required for summary
import csv
from tabulate import tabulate
fwd_graph, bwd_graph = get_inductor_decomp_graphs(model, args, kwargs)
op_counts = {}
for node in fwd_graph.graph.nodes:
if node.op == "call_function" and isinstance(
node.target, torch._ops.OpOverload
):
if node.target not in op_counts:
op_counts[node.target] = 0
op_counts[node.target] += 1
for node in bwd_graph.graph.nodes:
if node.op == "call_function" and isinstance(
node.target, torch._ops.OpOverload
):
if node.target not in op_counts:
op_counts[node.target] = 0
op_counts[node.target] += 1
op_infos = []
for op, count in op_counts.items():
supported = op in DTensor._op_dispatcher.sharding_propagator.op_to_rules
op_infos.append([op, str(op._schema), count, supported])
    # sort the op info based on the total count index
count_idx = 2
op_infos.sort(key=itemgetter(count_idx), reverse=True)
headers = ["Operator", "Schema", "Total Count", "Supported"]
print(tabulate(op_infos, headers=headers))
if output_csv:
# Open a CSV file for writing
with open("op_summary.csv", "w", newline="") as csv_file:
# Create a CSV writer object
csv_writer = csv.writer(csv_file)
csv_writer.writerow(headers)
# Write each table row to the CSV file
for row in op_infos:
csv_writer.writerow(row)
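# Illustrative usage sketch (not part of the original module): the model and inputs
# below are placeholders; this requires the ``tabulate`` package and prints one row
# per decomposed ATen op with its count and DTensor sharding-rule coverage.
#
#   >>> model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
#   >>> print_op_coverage_summary(model, [torch.randn(4, 8)], {}, output_csv=False)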

View File

@ -0,0 +1,178 @@
# mypy: allow-untyped-defs
from typing import List, Sequence, Tuple
import numpy as np
from torch._prims_common import ShapeType
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor.placement_types import Placement, Shard
__all__ = ["visualize_sharding"]
def _mesh_to_coordinate(mesh, device_type):
"""
    Given an n-dimensional device mesh, this function creates a map from each
    device to its coordinate
"""
# Convert the n-dimensional list to a NumPy array
np_mesh = np.array(mesh.mesh.tolist())
# Create a dictionary to map each value to its coordinate
device_to_coordinate_map = {}
for coord, value in np.ndenumerate(np_mesh):
# device is unique in device_mesh
device_to_coordinate_map[f"{device_type}:{str(value)}"] = list(coord)
return device_to_coordinate_map
def _convert_offset_to_ranges(all_offsets):
"""
    Creating a table with the tabulate package is easier when we specify row and column ranges.
    This function converts offsets to ranges.
"""
converted_blocks = []
for offset in all_offsets:
shape, offset, value = offset
# Calculate row_range and column_range
row_range = (offset[0], offset[0] + shape[0] - 1)
column_range = (offset[1], offset[1] + shape[1] - 1)
        # Convert value to string to match the desired format
converted_block = {
"row_range": row_range,
"column_range": column_range,
"value": str(value),
}
converted_blocks.append(converted_block)
return converted_blocks
def _create_table(blocks):
"""
Creates a tabulate table given row and column ranges with device name
"""
try:
from tabulate import tabulate
except ImportError as e:
raise ImportError("tabulate package is required to visualize sharding") from e
# Extract unique row and column ranges
row_ranges = sorted({block["row_range"] for block in blocks})
col_ranges = sorted({block["column_range"] for block in blocks})
# Create a matrix initialized with empty strings
matrix = [["" for _ in col_ranges] for _ in row_ranges]
# Fill the matrix with values
for block in blocks:
row_index = row_ranges.index(block["row_range"])
col_index = col_ranges.index(block["column_range"])
if matrix[row_index][col_index] == "":
matrix[row_index][col_index] = block["value"]
else:
matrix[row_index][col_index] += ", " + block["value"]
# Prepare headers
row_headers = [f"Row {r[0]}-{r[1]}" for r in row_ranges]
col_headers = [f"Col {c[0]}-{c[1]}" for c in col_ranges]
return tabulate(matrix, headers=col_headers, showindex=row_headers)
def _compute_local_shape_and_global_offset(
global_shape: ShapeType,
mesh: DeviceMesh,
placements: Sequence[Placement],
my_coordinate: List[int],
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
"""
Same as torch.distributed._tensor._utils.compute_local_shape_and_global_offset but
with custom my_coordinate input. This is the modified implementation for visualize_sharding.
"""
if my_coordinate is None:
# if rank not in the mesh, return empty offset
return ((), ())
else:
local_shape = list(global_shape)
global_offset = [0] * len(global_shape)
for idx, placement in enumerate(placements):
mesh_dim_size = mesh.size(idx)
if isinstance(placement, Shard):
shard_dim = placement.dim
local_offset = [0] * len(global_shape)
assert shard_dim < len(
local_shape
), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
shard_size, shard_offset = placement._local_shard_size_on_dim(
local_shape[shard_dim],
mesh_dim_size,
my_coordinate[idx],
return_offset=True,
)
local_shape[shard_dim] = shard_size
local_offset[shard_dim] = shard_offset
# On a given dimension, if the local_offset[shard_dim] is smaller than global_offset[shard_dim],
                # it means that this dimension has already been sharded in a previous placement.
# Therefore, we cannot simply replace the global_offset[shard_dim] with local_offset[shard_dim].
# Instead, for the given shard_dim, we need to add local_offset[shard_dim] to existing global_offset[shard_dim].
if global_offset[shard_dim] <= local_offset[shard_dim]:
global_offset[shard_dim] = local_offset[shard_dim]
else:
global_offset[shard_dim] += local_offset[shard_dim]
return tuple(local_shape), tuple(global_offset)
def visualize_sharding(dtensor, header=""):
"""
Visualizes sharding in the terminal for :class:`DTensor` that are 1D or 2D.
.. note:: This requires the ``tabulate`` package. No sharding info will be printed for empty tensors
"""
if dtensor.numel() == 0: # we do not print for empty dtensors
return
if len(dtensor.shape) >= 3:
raise RuntimeError(
"visualize sharding is only implemented for 1D or 2D dtensor"
)
placements = dtensor.placements
device_mesh = dtensor.device_mesh
device_type = dtensor.device_mesh.device_type
if device_mesh.get_coordinate() is None: # current rank is not in the mesh
return
# Only display the visualization once for each DTensor, on the rank whose
# coordinate is 0 on all dimensions. For example, if the mesh is a full mesh,
# we will only print on rank 0.
local_rank_zero_on_all_dim = all(
device_mesh.get_local_rank(mesh_dim=dim) == 0 for dim in range(device_mesh.ndim)
)
if not local_rank_zero_on_all_dim:
return
device_map = _mesh_to_coordinate(device_mesh, device_type)
all_offsets = []
for device in device_map:
local_shape, global_offset = _compute_local_shape_and_global_offset(
dtensor.shape, device_mesh, placements, device_map[device]
)
all_offsets.append([local_shape, global_offset, device])
# Convert offsets to blocks with row_ranges for tabulate
blocks = _convert_offset_to_ranges(all_offsets)
# Print the table
print(header)
print(_create_table(blocks))
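# Illustrative usage sketch (not part of the original module): assumes 4 ranks, an
# initialized process group, and ``distribute_tensor``/``init_device_mesh`` imported
# from ``torch.distributed.tensor``; a row/column-range table is printed on the rank
# whose mesh coordinate is all zeros.
#
#   >>> import torch
#   >>> from torch.distributed.tensor import distribute_tensor, init_device_mesh, Shard
#   >>> mesh = init_device_mesh("cuda", (4,))
#   >>> dt = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])
#   >>> visualize_sharding(dt, header="row-sharded weight")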

View File

@ -0,0 +1,9 @@
from torch.distributed.device_mesh import ( # noqa: F401
_get_device_handle,
_mesh_resources,
DeviceMesh,
init_device_mesh,
)
__all__ = ["init_device_mesh", "DeviceMesh"]

View File

@ -0,0 +1,32 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from contextlib import contextmanager
from torch.distributed.tensor._api import DTensor
from torch.distributed.tensor.experimental._func_map import local_map
from torch.distributed.tensor.experimental._register_sharding import register_sharding
__all__ = ["implicit_replication", "local_map", "register_sharding"]
@contextmanager
def implicit_replication():
"""
    This context manager allows :class:`DTensor` to implicitly treat all non-DTensor ``torch.Tensor`` s
    in the program as replicate :class:`DTensor` s during the operator computation.
    .. warning:: This might possibly lead to incorrect results if the ``torch.Tensor`` s are not
        actually replicated in practice; please use it at your discretion.
"""
try:
DTensor._op_dispatcher._allow_implicit_replication = True
yield
finally:
DTensor._op_dispatcher._allow_implicit_replication = False
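# Illustrative usage sketch (not part of the original module): ``dt`` stands for an
# existing DTensor and ``bias`` for a plain torch.Tensor that really is identical on
# every rank.
#
#   >>> with implicit_replication():
#   ...     out = dt + bias  # ``bias`` is implicitly treated as a Replicate() DTensor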
# Set namespace for exposed private names
implicit_replication.__module__ = "torch.distributed.tensor.experimental"
local_map.__module__ = "torch.distributed.tensor.experimental"
register_sharding.__module__ = "torch.distributed.tensor.experimental"

View File

@ -0,0 +1,867 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import itertools
import logging
import types
import weakref
from enum import Enum
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Optional,
Protocol,
Set,
Tuple,
Union,
)
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
import torch.nn.functional as F
from torch import nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import distribute_module, DTensor, Replicate, Shard
from torch.distributed.tensor.parallel.style import ParallelStyle
# TODO: expose a single API
__all__ = ["context_parallel"]
aten = torch.ops.aten
logger = logging.getLogger(__name__)
# Whether to upcast parameters and gradients to float32 to avoid accumulation
# errors. It is likely that this should always be True, but we currently keep
# this variable for experimental purposes.
_convert_to_f32 = True
class _CausalBehavior(Enum):
SKIP = None
NOT_IS_CAUSAL = False
IS_CAUSAL = True
def _is_causal_behavior(
rank: int, world_size: int, i: int, is_causal: bool
) -> _CausalBehavior:
"""
    Calculate the is_causal behavior for each KV block. The attention can be
    calculated in full, not at all, or with the causal mask applied.
"""
if not is_causal:
return _CausalBehavior.NOT_IS_CAUSAL
if i == 0:
return _CausalBehavior.IS_CAUSAL
source_rank = (rank - i) % world_size
if source_rank < rank:
return _CausalBehavior.NOT_IS_CAUSAL
else:
return _CausalBehavior.SKIP
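# Illustrative sketch (not part of the original module): with is_causal=True and a
# ring of 4 ranks, each rank attends fully to KV blocks that originate from lower
# ranks, applies the causal mask only to its own block (round 0), and skips blocks
# from higher ranks.
#
#   >>> [_is_causal_behavior(rank=1, world_size=4, i=i, is_causal=True) for i in range(4)]
#   [<_CausalBehavior.IS_CAUSAL: True>, <_CausalBehavior.NOT_IS_CAUSAL: False>,
#    <_CausalBehavior.SKIP: None>, <_CausalBehavior.SKIP: None>]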
def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:
"""
When tracing the code, the result tensor is not an AsyncCollectiveTensor,
so we cannot call ``wait()``.
"""
if isinstance(tensor, ft_c.AsyncCollectiveTensor):
return tensor.wait()
return tensor
class _SDPAMerger:
"""A class to help to merge the local SDPA result."""
def __init__(self, convert_to_f32: bool):
self._out: Optional[torch.Tensor] = None
self._lse: Optional[torch.Tensor] = None
self._convert_to_f32 = convert_to_f32
self._out_dtype = torch.float32
self._lse_dtype = torch.float32
def _merge_one(self, block_out: torch.Tensor, block_lse: torch.Tensor) -> None:
block_lse = block_lse.unsqueeze(dim=-1)
if self._lse is None:
self._lse = block_lse
self._out = block_out
else:
# The algorithm from
# github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
# gives a relatively stable result.
self._out = self._out - F.sigmoid(block_lse - self._lse) * (
self._out - block_out
)
self._lse = self._lse - F.logsigmoid(self._lse - block_lse)
def step(self, out: torch.Tensor, lse: torch.Tensor) -> None:
self._out_dtype = out.dtype
self._lse_dtype = lse.dtype
if self._convert_to_f32:
out = out.to(torch.float32)
lse = lse.to(torch.float32)
self._merge_one(out, lse)
def results(self) -> Tuple[torch.Tensor, torch.Tensor]:
assert self._out is not None
assert self._lse is not None
out, lse = self._out, self._lse.squeeze(-1)
return out.to(self._out_dtype), lse.to(self._lse_dtype)
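# Note on the merge above (illustrative, not part of the original module): the two
# update rules in ``_merge_one`` form an online logsumexp merge. With
# w = sigmoid(block_lse - lse) = exp(block_lse) / (exp(lse) + exp(block_lse)),
# the output update equals the weighted average (1 - w) * out + w * block_out,
# and the lse update equals logaddexp(lse, block_lse), since
# lse - logsigmoid(lse - block_lse) = lse + log(1 + exp(block_lse - lse))
# = log(exp(lse) + exp(block_lse)).
#
#   >>> a, b = torch.tensor(0.5), torch.tensor(1.5)
#   >>> torch.allclose(a - F.logsigmoid(a - b), torch.logaddexp(a, b))
#   True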
def _scaled_dot_product_ring_flash_attention(
mesh: DeviceMesh,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
is_causal: bool = False,
return_debug_mask: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
if return_debug_mask:
raise NotImplementedError("return_debug_mask is not supported yet")
return _templated_ring_attention(
mesh,
aten._scaled_dot_product_flash_attention,
query=query,
key=key,
value=value,
is_causal=is_causal,
dropout_p=dropout_p,
scale=scale,
)
def _scaled_dot_product_ring_efficient_attention(
mesh: DeviceMesh,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor] = None,
compute_log_sumexp: bool = True,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
if attn_bias is not None:
raise NotImplementedError("attn_bias is not supported yet")
if not compute_log_sumexp:
raise NotImplementedError("compute_log_sumexp must be set")
return _templated_ring_attention(
mesh,
aten._scaled_dot_product_efficient_attention,
query=query,
key=key,
value=value,
is_causal=is_causal,
attn_bias=attn_bias,
dropout_p=dropout_p,
scale=scale,
compute_log_sumexp=compute_log_sumexp,
)
class _AttentionOp(Protocol):
def __call__(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
**kwargs: object,
) -> Tuple[torch.Tensor, ...]:
...
def _ring_rotate(
block: torch.Tensor, pg: dist.ProcessGroup, send_to_next: bool
) -> torch.Tensor:
size = dist.get_world_size(pg)
dsts = (
list(range(1, size)) + [0]
if send_to_next
else [size - 1] + list(range(0, size - 1))
)
return ft_c.permute_tensor(block, dsts, pg)
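# Illustrative sketch (not part of the original module): the destination list built
# above for a hypothetical group of size 4 is
#
#   send_to_next=True  -> dsts == [1, 2, 3, 0]
#   send_to_next=False -> dsts == [3, 0, 1, 2]
#
# i.e. the intent is that each rank's KV block moves one hop forward (or backward)
# around the ring on every call.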
def _templated_ring_attention(
mesh: DeviceMesh,
op: _AttentionOp,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
**kwargs: object,
) -> Tuple[torch.Tensor, ...]:
"""
This is a generalized ring attention implementation that can support multiple attention ops.
Parameters
----------
    op:
        The attention op to use
    **kwargs:
        additional kwargs are passed to the op
Returns
-------
out:
The merged attention output
softmax_lse:
The logsumexp of the merged attention output
"""
if is_causal and (query.size(2) != key.size(2)):
raise NotImplementedError(
"is_causal requires the same query and context sequence lengths"
)
if isinstance(mesh, dist.ProcessGroup):
pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = mesh
else:
pg = mesh.get_group()
assert isinstance(pg, dist.ProcessGroup), "process group must be single dimension"
rank = dist.get_rank(pg)
size = dist.get_world_size(pg)
next_kv = None
    # Without making key and value contiguous(), the loss curve is bad.
# TODO(fegin): figure out why this is a requirement since SDPA does not have
# this requirement.
key = key.contiguous()
value = value.contiguous()
sdpa_merger = _SDPAMerger(_convert_to_f32)
rest: List[Any]
out: torch.Tensor
logsumexp: torch.Tensor
for i in range(size):
# overlap communication with compute
if next_kv is not None:
next_kv = _maybe_wait(next_kv)
key = next_kv[: key.numel()].reshape(key.shape)
value = next_kv[key.numel() :].reshape(value.shape)
if i < (size - 1):
next_kv = torch.cat([key.flatten(), value.flatten()])
next_kv = _ring_rotate(next_kv, pg, send_to_next=True)
is_causal_behavior = _is_causal_behavior(
rank=rank, world_size=size, i=i, is_causal=is_causal
)
if is_causal_behavior != _CausalBehavior.SKIP:
out, logsumexp, *rest = op(
query,
key,
value,
is_causal=is_causal_behavior.value,
**kwargs,
)
sdpa_merger.step(out, logsumexp)
return *sdpa_merger.results(), *rest
def _sdpa_handler(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> object:
    # extract local tensor and sharding infos to an OpInfo
op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
logger.debug("Dispatching op_call: %s", op_info.schema)
# sharding propagation
# TODO: remove the context parallel strategy from the default propagation
# rule. Either figure out how to dynamically enable it or just don't call
# propagate.
DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
output_sharding = op_info.output_sharding
assert output_sharding is not None, "output sharding should not be None"
assert not output_sharding.needs_redistribute, "inputs need to be redistributed"
if op_call == aten._scaled_dot_product_flash_attention.default:
local_results = _scaled_dot_product_ring_flash_attention(
op_info.mesh,
*op_info.local_args, # type: ignore[arg-type]
**op_info.local_kwargs, # type: ignore[arg-type]
)
elif op_call == aten._scaled_dot_product_efficient_attention.default:
local_results = _scaled_dot_product_ring_efficient_attention(
op_info.mesh,
*op_info.local_args, # type: ignore[arg-type]
**op_info.local_kwargs, # type: ignore[arg-type]
)
else:
raise NotImplementedError(
"CP only supports flash attention and memory efficient attention now."
)
return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec)
def _sdpa_backward_handler(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> object:
# Redistribute grad_output tensor to the same placement as output tensor
args = list(args)
args = tuple(args)
    # extract local tensor and sharding infos to an OpInfo
op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
logger.debug("Dispatching op_call: %s", op_info.schema)
# sharding propagation
DTensor._op_dispatcher.sharding_propagator.propagate(op_info)
output_sharding = op_info.output_sharding
assert output_sharding is not None, "output sharding should not be None"
assert not output_sharding.needs_redistribute, "inputs need to be redistributed"
if op_call == aten._scaled_dot_product_flash_attention_backward.default:
local_results = _scaled_dot_product_ring_flash_attention_backward(
op_info.mesh,
*op_info.local_args, # type: ignore[arg-type]
**op_info.local_kwargs, # type: ignore[arg-type]
)
elif op_call == aten._scaled_dot_product_efficient_attention_backward.default:
local_results = _scaled_dot_product_ring_efficient_attention_backward(
op_info.mesh,
*op_info.local_args, # type: ignore[arg-type]
**op_info.local_kwargs, # type: ignore[arg-type]
)
else:
raise NotImplementedError(f"{op_call=}")
return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec)
def _templated_ring_attention_backward(
mesh: DeviceMesh,
op: _AttentionOp,
grad_out: torch.Tensor,
grad_out_name: str,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
is_causal: bool,
**kwargs: Any,
) -> Tuple[torch.Tensor, ...]:
pg = mesh.get_group()
assert isinstance(pg, dist.ProcessGroup), "must be single dimension"
rank = dist.get_rank(pg)
size = dist.get_world_size(pg)
next_kv = None
next_grad_kv = None
rest: List[Any]
grad_query_, grad_key_, grad_value_ = None, None, None
accum_dtype = torch.float32 if _convert_to_f32 else query.dtype
grad_query = torch.zeros_like(query, dtype=accum_dtype)
grad_key = torch.zeros_like(key, dtype=accum_dtype)
grad_value = torch.zeros_like(value, dtype=accum_dtype)
key = key.contiguous()
value = value.contiguous()
for i in range(size):
if next_kv is not None:
buffer = _maybe_wait(next_kv)
pointer = 0
key = buffer[pointer : pointer + key.numel()].reshape(key.shape)
pointer += key.numel()
value = buffer[pointer : pointer + value.numel()].reshape(value.shape)
pointer += value.numel()
if i != size - 1:
next_kv = torch.cat([key.flatten(), value.flatten()])
next_kv = _ring_rotate(next_kv, pg, send_to_next=True)
is_causal_behavior = _is_causal_behavior(
rank=rank, world_size=size, i=i, is_causal=is_causal
)
if is_causal_behavior != _CausalBehavior.SKIP:
kwargs[grad_out_name] = grad_out
grad_query_, grad_key_, grad_value_, *rest = op(
query=query,
key=key,
value=value,
out=out,
logsumexp=logsumexp,
is_causal=is_causal_behavior.value,
**kwargs,
)
else:
grad_query_ = torch.zeros_like(query, dtype=accum_dtype)
grad_key_ = torch.zeros_like(key, dtype=accum_dtype)
grad_value_ = torch.zeros_like(value, dtype=accum_dtype)
        # Get the grad key and grad value for round i.
if i > 0:
pointer = 0
assert next_grad_kv is not None
next_grad_kv = _maybe_wait(next_grad_kv)
grad_key = next_grad_kv[pointer : pointer + grad_key.numel()].reshape(
grad_key.shape
)
pointer += grad_key.numel()
grad_value = next_grad_kv[pointer : pointer + grad_value.numel()].reshape(
grad_value.shape
)
grad_key += grad_key_
grad_value += grad_value_
        # Send the grad key and grad value to the next rank.
next_grad_kv = torch.cat([grad_key.flatten(), grad_value.flatten()])
next_grad_kv = _ring_rotate(next_grad_kv, pg, send_to_next=True)
grad_query += grad_query_
assert next_grad_kv is not None
assert grad_key_ is not None
assert grad_value_ is not None
grad_query = grad_query.to(query.dtype)
next_grad_kv = _maybe_wait(next_grad_kv).to(key.dtype)
grad_key = next_grad_kv[: grad_key.numel()].reshape(grad_key.shape)
grad_value = next_grad_kv[grad_value.numel() :].reshape(grad_value.shape)
return (
grad_query,
grad_key,
grad_value,
*rest,
)
def _scaled_dot_product_ring_flash_attention_backward(
mesh: DeviceMesh,
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
cum_seq_q: torch.Tensor,
cum_seq_k: torch.Tensor,
max_q: int,
max_k: int,
dropout_p: float,
is_causal: bool,
philox_seed: torch.Tensor,
philox_offset: torch.Tensor,
*,
scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
return _templated_ring_attention_backward(
mesh,
aten._scaled_dot_product_flash_attention_backward.default,
grad_out=grad_out,
grad_out_name="grad_out",
query=query,
key=key,
value=value,
out=out,
logsumexp=logsumexp,
is_causal=is_causal,
cum_seq_q=cum_seq_q,
cum_seq_k=cum_seq_k,
max_q=max_q,
max_k=max_k,
dropout_p=dropout_p,
philox_seed=philox_seed,
philox_offset=philox_offset,
scale=scale,
)
def _scaled_dot_product_ring_efficient_attention_backward(
mesh: DeviceMesh,
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
bias: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
philox_seed: torch.Tensor,
philox_offset: torch.Tensor,
dropout_p: float,
grad_input_mask: Tuple[bool, ...],
is_causal: bool = False,
*,
scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
return _templated_ring_attention_backward(
mesh,
aten._scaled_dot_product_efficient_attention_backward.default,
grad_out=grad_out,
grad_out_name="grad_out_",
query=query,
key=key,
value=value,
attn_bias=bias,
out=out,
logsumexp=logsumexp,
philox_seed=philox_seed,
philox_offset=philox_offset,
dropout_p=dropout_p,
grad_input_mask=grad_input_mask,
is_causal=is_causal,
scale=scale,
)
customized_ops = {
aten._scaled_dot_product_flash_attention.default: _sdpa_handler,
aten._scaled_dot_product_flash_attention_backward.default: _sdpa_backward_handler,
aten._scaled_dot_product_efficient_attention.default: _sdpa_handler,
aten._scaled_dot_product_efficient_attention_backward.default: _sdpa_backward_handler,
}
_replaced_functions: Dict[Callable, Tuple[str, Callable]] = {}
def _distribute_function(
fn: Callable,
fn_module: types.ModuleType,
device_mesh: DeviceMesh,
input_fn: Optional[Callable] = None,
output_fn: Optional[Callable] = None,
) -> None:
"""
``distribute_function`` is an experimental API that allows users to "distribute"
the inputs and outputs of a function. Similar to ``distribute_module``, this API
installs hooks to the ``fn`` to convert the inputs and outputs. There are two
major differences between ``distribute_function`` and ``distribute_module``.
    First, a function does not have parameters and buffers; as a result,
    ``distribute_function`` itself won't convert any parameters/buffers but simply
    installs the input and output hooks. The tensor conversion will happen in the hooks.
    Another difference is that an nn.Module subclass can have several instances, and each
    instance can be fed into ``distribute_module`` independently without affecting the
    other instances. A function, on the other hand, is a singleton object. So if a function
    is distributed by ``distribute_function``, all subsequent calls to the function
    will invoke the installed hooks.
Args:
fn (Callable): the function to be distributed.
        fn_module (types.ModuleType): the Python module in which the function is declared.
e.g., if ``fn`` is ``torch.nn.functional.scaled_dot_product_attention``,
``fn_module`` is ``torch.nn.functional``.
device_mesh (:class:`DeviceMesh`): the device mesh that will be used by the
input and output hooks to distribute the tensors.
        input_fn (Optional[Callable]): the hook to distribute or convert the input
            arguments of ``fn``.
        output_fn (Optional[Callable]): the hook to distribute or convert the output
            of ``fn``.
"""
def wrapper(
target_fn: Callable, input_fn: Optional[Callable], output_fn: Optional[Callable]
) -> Callable:
def inner_fn(*args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any:
if input_fn is not None:
args, kwargs = input_fn(device_mesh, *args, **kwargs)
output = target_fn(*args, **kwargs)
if output_fn is not None:
output = output_fn(device_mesh, output)
return output
return inner_fn
global _replaced_functions
if fn in _replaced_functions:
return
wrapper_fn = wrapper(fn, input_fn, output_fn)
setattr(fn_module, fn.__name__, wrapper_fn)
_replaced_functions[wrapper_fn] = (fn.__name__, fn)
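# Illustrative usage sketch (not part of the original module): ``sdpa_input_fn`` and
# ``sdpa_output_fn`` are hypothetical hooks that mirror how the context-parallel code
# below wraps F.scaled_dot_product_attention; ``mesh`` is an existing 1-D DeviceMesh.
#
#   >>> def sdpa_input_fn(mesh, *args, **kwargs):
#   ...     # convert plain tensors to sequence-sharded DTensors (hypothetical)
#   ...     return args, kwargs
#   >>> def sdpa_output_fn(mesh, output):
#   ...     # convert DTensor outputs back to local tensors (hypothetical)
#   ...     return output
#   >>> _distribute_function(
#   ...     F.scaled_dot_product_attention, torch.nn.functional, mesh,
#   ...     input_fn=sdpa_input_fn, output_fn=sdpa_output_fn,
#   ... )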
def _restore_function(fn: Callable, fn_module: types.ModuleType) -> None:
"""Restore the function that is replaced by _distribute_function."""
global _original_functions
global _wrapper_functions
if fn not in _replaced_functions:
return
original_name, original_fn = _replaced_functions[fn]
setattr(fn_module, original_name, original_fn)
@contextlib.contextmanager
def _enable_cp_dispatcher() -> Generator[None, None, None]:
"""Enables DTensor dispatcher to dispatch SDPA to CP."""
old_handlers = DTensor._op_dispatcher._custom_op_handlers
DTensor._op_dispatcher._custom_op_handlers = {**old_handlers, **customized_ops}
yield
DTensor._op_dispatcher._custom_op_handlers = old_handlers
class _AttentionContextParallel(ParallelStyle):
"""
Applies context parallel optimizations to the attention layer.
    This will work for nn.MultiheadAttention and custom attention layers that
    call F.scaled_dot_product_attention with a similar signature.
    This expects the `forward` method to consume either:
* a single tensor for self attention
* one argument for each of: query, key, value
This currently only supports ring attention and the
SDPBackend.FLASH_ATTENTION backend. See sdpa_kernel.
Non-flash attention backends will result in incorrect results.
"""
# use a weakref dictionary to store context managers for each nn.Module
_CONTEXT_MANAGERS: "weakref.WeakKeyDictionary[nn.Module, Any]" = (
weakref.WeakKeyDictionary()
)
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
if not isinstance(device_mesh, DeviceMesh):
raise ValueError(
f"{type(device_mesh)} is not supported by {type(self)} yet."
)
if not device_mesh.ndim == 1:
raise ValueError(f"{type(self)} only supports a 1D device mesh, but found {device_mesh.ndim}D.")
return distribute_module(
module,
device_mesh,
input_fn=self._input_fn, # type: ignore[arg-type]
output_fn=self._output_fn, # type: ignore[arg-type]
)
@classmethod
def _input_fn(
cls,
module: nn.Module,
inputs: Tuple[Union[torch.Tensor, int, float], ...],
device_mesh: DeviceMesh,
) -> Tuple[Union[torch.Tensor, int, float], ...]:
# TODO(d4l3k): this should be Shard(2); need to fix Linear layer rules
placement = [Replicate()]
def backward_hook(grad: torch.Tensor) -> None:
if module in cls._CONTEXT_MANAGERS:
cls._CONTEXT_MANAGERS[module].__exit__(None, None, None)
del cls._CONTEXT_MANAGERS[module]
# convert inputs to DTensor
inp = []
for input in inputs:
if isinstance(input, torch.Tensor) and not isinstance(input, DTensor):
input = DTensor.from_local(
input.contiguous(), device_mesh, placement, run_check=False
)
if isinstance(input, torch.Tensor) and input.requires_grad:
input.register_hook(backward_hook)
inp.append(input)
manager = _enable_cp_dispatcher()
manager.__enter__()
cls._CONTEXT_MANAGERS[module] = manager
return tuple(inp)
@classmethod
def _output_fn(
cls,
module: nn.Module,
outputs: Union[torch.Tensor, Tuple[Union[torch.Tensor, int, float], ...]],
device_mesh: DeviceMesh,
) -> Union[
Union[torch.Tensor, int, float], Tuple[Union[torch.Tensor, int, float], ...]
]:
cls._CONTEXT_MANAGERS[module].__exit__(None, None, None)
del cls._CONTEXT_MANAGERS[module]
def backward_hook(grad: torch.Tensor) -> None:
if module not in cls._CONTEXT_MANAGERS:
manager = _enable_cp_dispatcher()
manager.__enter__()
cls._CONTEXT_MANAGERS[module] = manager
# back to local tensor
out = []
for output in [outputs] if isinstance(outputs, torch.Tensor) else outputs:
output = output.to_local() if isinstance(output, DTensor) else output
if isinstance(output, torch.Tensor) and output.requires_grad:
output.register_hook(backward_hook)
out.append(output)
if isinstance(outputs, torch.Tensor):
return out[0]
return tuple(out)
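# Illustrative usage sketch (an assumption for clarity, not part of this file):
# ``_AttentionContextParallel`` is a ``ParallelStyle``, so it can be handed to
# ``parallelize_module`` like any other style. The model, the submodule name
# "attention", and the mesh shape below are hypothetical.
#
#   from torch.distributed.device_mesh import init_device_mesh
#   from torch.distributed.tensor.parallel import parallelize_module
#
#   cp_mesh = init_device_mesh("cuda", (8,))
#   model = parallelize_module(model, cp_mesh, {"attention": _AttentionContextParallel()})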
@contextlib.contextmanager
def _context_parallel(seq_dim: int, mesh: DeviceMesh) -> Generator[None, None, None]:
"""Replace SDPA with the CP-wrapped version and enable DTensor CP dispatcher."""
def attention_input_fn(
mesh: DeviceMesh, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
placement = [Shard(seq_dim)]
all_args = []
for arg in itertools.chain(args, kwargs.values()):
if isinstance(arg, torch.Tensor) and not isinstance(arg, DTensor):
arg = DTensor.from_local(arg, mesh, placement, run_check=False)
all_args.append(arg)
new_args = tuple(all_args[0 : len(args)])
new_kwargs = dict(zip(kwargs.keys(), all_args[len(args) :]))
return new_args, new_kwargs
def attention_output_fn(mesh: DeviceMesh, outputs: Any) -> Any:
new_outputs = []
for output in [outputs] if isinstance(outputs, torch.Tensor) else outputs:
output = output.to_local() if isinstance(output, DTensor) else output
new_outputs.append(output)
if isinstance(outputs, torch.Tensor):
return new_outputs[0]
return tuple(new_outputs)
# TODO: provide a more robust way to replace SDPA.
# Currently we use monkey patch to replace scaled_dot_product_attention with the
# wrapped fn. This is okay if users call ``torch.nn.functional.scaled_dot_product_attention``
# through the module, but will not work if users do
# ``from torch.nn.functional import scaled_dot_product_attention`` and call the imported name directly.
_distribute_function(
F.scaled_dot_product_attention,
F,
mesh,
attention_input_fn,
attention_output_fn,
)
with _enable_cp_dispatcher():
yield
_restore_function(F.scaled_dot_product_attention, F)
def _get_sequence_shard(
buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int
) -> torch.Tensor:
return buffer.chunk(mesh.size(), dim=seq_dim)[mesh.get_local_rank()]
def _context_parallel_buffers(
mesh: DeviceMesh,
buffers: List[torch.Tensor],
buffer_seq_dims: List[int],
) -> List[torch.Tensor]:
"""Shard the buffers along the sequence dimensions according to CP rules."""
new_buffers = []
for buffer, seq_dim in zip(buffers, buffer_seq_dims):
new_buffers.append(_get_sequence_shard(buffer, mesh, seq_dim))
return new_buffers
@contextlib.contextmanager
@torch.no_grad()
def context_parallel(
mesh: DeviceMesh,
*,
buffers: Optional[List[torch.Tensor]] = None,
buffer_seq_dims: Optional[List[int]] = None,
no_restore_buffers: Optional[Set[torch.Tensor]] = None,
) -> Generator[None, None, None]:
"""
``context_parallel`` is an experimental API to enable context
parallelism (CP). This API performs two actions: 1) patch SDPA
(``torch.nn.functional.scaled_dot_product_attention``) with the CP-enabled
one, 2) shard ``buffers`` along the sequence dimension so that each rank
preserves its corresponding shard according to ``mesh``.
Args:
mesh (:class:`DeviceMesh`): the device mesh for the context parallelism.
buffers (Optional[List[torch.Tensor]]): buffers whose usage depends
on the sequence dimension. Examples are the input batch, labels and
positional embedding buffers. These buffers must be sharded along
the sequence dimension to ensure correctness. The sharding happens
in-place, so the buffers' shapes change within the context.
The buffers are restored after the context finishes.
``no_restore_buffers`` can be used to specify which buffers don't
need to be restored. Note that ``buffers`` should not contain any
nn.Parameter.
buffer_seq_dims (Optional[List[int]]): the sequence dimensions of ``buffers``.
no_restore_buffers (Optional[Set[torch.Tensor]]): buffers in these set
won't be restored after the context exits. This set must be a subset
of ``buffers``. If the buffers won't be used after the context exits,
these buffers can be put in this list to avoid extra restore time.
.. warning::
`torch.distributed._tensor.experimental.attention.context_parallel` is a
prototype feature in PyTorch. The API is subject to change.
"""
buffers = [] if buffers is None else buffers
buffer_seq_dims = [] if buffer_seq_dims is None else buffer_seq_dims
no_restore_buffers = set() if no_restore_buffers is None else no_restore_buffers
if len(buffers) != len(buffer_seq_dims):
raise ValueError(
"`buffer_seq_dims` must have the same number of elements as `buffers`."
)
for buffer in no_restore_buffers:
# Cannot use `if buffer not in buffers`, which would incur tensor comparison.
if not any(b is buffer for b in buffers):
raise ValueError("`no_restore_buffers` must be a subset of `buffers`.")
original_buffers = [None if b in no_restore_buffers else b.clone() for b in buffers]
chunks = _context_parallel_buffers(mesh, buffers, buffer_seq_dims)
for buffer, chunk in zip(buffers, chunks):
chunk = chunk.clone()
buffer.resize_(chunk.shape)
buffer.copy_(chunk)
with _context_parallel(seq_dim=2, mesh=mesh):
yield
for buffer, original_buffer in zip(buffers, original_buffers):
if original_buffer is not None:
buffer.resize_(original_buffer.shape)
buffer.copy_(original_buffer)
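# Illustrative usage sketch (an assumption, not part of this file): shard the batch and
# labels along their sequence dimension and run the model under the experimental
# ``context_parallel`` context. The model, shapes, loss and mesh below are hypothetical.
#
#   from torch.distributed.device_mesh import init_device_mesh
#
#   cp_mesh = init_device_mesh("cuda", (2,))
#   inputs = torch.randn(8, 1024, 512, device="cuda")            # (batch, seq, dim)
#   labels = torch.randint(0, 32000, (8, 1024), device="cuda")   # (batch, seq)
#   with context_parallel(
#       cp_mesh,
#       buffers=[inputs, labels],
#       buffer_seq_dims=[1, 1],
#       no_restore_buffers={inputs, labels},  # skip restoring if unused afterwards
#   ):
#       loss = loss_fn(model(inputs), labels)  # SDPA inside ``model`` is CP-enabled here
#       loss.backward()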

View File

@ -0,0 +1,228 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Callable, Optional, Sequence, Tuple, Union
import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor import DeviceMesh, DTensor
from torch.distributed.tensor.placement_types import Placement
try:
from torch.utils import _cxx_pytree as pytree
except ImportError:
from torch.utils import _pytree as pytree # type: ignore[no-redef]
__all__ = ["local_map"]
PlacementType = Optional[Sequence[Placement]]
InputPlacements = Optional[Tuple[PlacementType, ...]]
OutputPlacements = Union[PlacementType, Tuple[PlacementType, ...]]
def local_map(
func: Callable,
out_placements: OutputPlacements,
in_placements: Optional[InputPlacements] = None,
device_mesh: Optional[DeviceMesh] = None,
*,
redistribute_inputs: bool = False,
):
"""
:meth:`local_map` is an experimental API that allows users to pass :class:`DTensor` s
to a function that is written to be applied on ``torch.Tensor`` s. It does so by extracting
the local components of the :class:`DTensor` s, calling the function, and wrapping the outputs
into :class:`DTensor` s according to the ``out_placements``.
Args:
func (Callable): the function to be applied on each local shard of
:class:`DTensor` s.
out_placements (Union[`PlacementType`, Tuple[`PlacementType`, ...]]):
the desired placements of the :class:`DTensor` s in ``func``'s flattened output.
If the flattened ``output`` is a single value, the ``out_placements`` should be
of type `PlacementType`. Otherwise if the flattened ``output`` has multiple
values, the ``out_placements`` should be a tuple of `PlacementType` values 1:1
mapping to the flattened ``output``.
Besides, for :class:`Tensor` output, we use `PlacementType` as its
placements (a `Tuple[Placement]` value). For non-Tensor output, the `PlacementType`
should be `None`.
Note that the only exception is when no :class:`DTensor` argument is passed
in. In this case, even if `out_placements` is not `None`, the result function
should ignore the desired placements because the function is not running with
:class:`DTensor` s.
in_placements (Tuple[`PlacementType`, ...], optional):
the required placements of the :class:`DTensor` s in the flattened inputs of ``func``.
If ``in_placements`` is specified, :meth:`local_map` would examine whether the
placements of each :class:`DTensor` argument are the same as the required
placements or not. If the placements are not the same and
``redistribute_inputs`` is ``False``, an exception will be raised. Otherwise if
``redistribute_inputs`` is ``True``, the argument will be first redistributed to
the required sharding placements before passing its local tensor to ``func``.
The only exception is when required placements are not ``None`` and the
argument is a :class:`torch.Tensor`. In this case, the placements examination
will be skipped and the argument will be directly passed to ``func``.
If ``in_placements`` is ``None``, no placements examination will be performed.
Default: None
device_mesh (:class:`DeviceMesh`, optional):
the device mesh that all the :class:`DTensor` s are placed on. If not
specified, this will be inferred from the input :class:`DTensor` s' device
mesh. `local_map` requires all :class:`DTensor` s to be placed on the same
device mesh. Default: None.
redistribute_inputs (bool, optional):
the bool value indicating whether to reshard the input :class:`DTensor` s when
their placements are different from the required input placements. If this
value is ``False`` and some :class:`DTensor` input has a different placement,
an exception will be raised. Default: False.
Returns:
A ``Callable`` that applies ``func`` to each local shard of the input :class:`DTensor`
and returns a :class:`DTensor` constructed from the return value of ``func``.
Raises:
AssertionError: If the input :class:`DTensor` s are not placed on the same device
mesh, or if they are placed on a different device mesh than the ``device_mesh``
argument passed in.
AssertionError: For any non-DTensor output, we require its corresponding
output placement in ``out_placements`` be None. An AssertionError will be raised
if this is not the case.
ValueError: If ``redistribute_inputs=False`` but the input :class:`DTensor` needs
a redistribution according to ``in_placements``.
Example:
>>> # xdoctest: +SKIP("distributed")
>>> def mm_allreduce_forward(device_mesh, W, X):
>>> partial_sum_tensor = torch.mm(W, X)
>>> reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
>>> return reduced_tensor
>>>
>>> W = torch.randn(12, 8, requires_grad=False)
>>> X = torch.randn(8, 16, requires_grad=False)
>>> Y = torch.mm(W, X)
>>> row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh
>>> col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh
>>>
>>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion
>>> local_mm_allreduce_forward = local_map(
>>> mm_allreduce_forward,
>>> out_placements=[Replicate()],
>>> in_placements=[col_wise, row_wise],
>>> device_mesh=device_mesh,
>>> )
>>>
>>> W_dt = distribute_tensor(W, device_mesh, (col_wise)) # col-wisely sharded W tensor
>>> X_dt = distribute_tensor(X, device_mesh, (row_wise)) # row-wisely sharded X tensor
>>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors
.. note:: This API is currently experimental and subject to change
"""
def wrapped(*args, **kwargs):
# process input args
flat_args, args_spec = pytree.tree_flatten(args)
if in_placements is not None:
assert len(in_placements) == len(flat_args), (
f"in_placements length {len(in_placements)} does not match the number "
f"of input args {len(flat_args)}!"
)
# we assume every DTensor object is placed on the same device mesh
flat_local_args = []
nonlocal device_mesh # access var device_mesh from the outer scope
seen_dtensor_arg = False
for idx, arg in enumerate(flat_args):
if isinstance(arg, DTensor):
# TODO: the current code doesn't consider the uneven sharding case
# Need to think about what the consequence is when the input DTensor
# is uneven sharded.
if device_mesh is None: # infer device mesh from the DTensor arg
device_mesh = arg.device_mesh
# this function is applied to at least one DTensor argument
seen_dtensor_arg = True
assert arg.device_mesh == device_mesh, (
f"arg {arg} in local_map has a mismatched device mesh: "
f"{arg} has device mesh {arg.device_mesh} while "
f"the expected device mesh is {device_mesh}!"
)
if in_placements is not None:
spec = in_placements[idx]
assert (
spec is not None
), f"DTensor input {arg} expects placements but received {spec}!"
if not isinstance(spec, tuple):
spec = tuple(spec)
if arg.placements != spec:
if redistribute_inputs:
# redistribute to input placements
arg = arg.redistribute(device_mesh, spec)
else:
raise ValueError(
f"arg {arg} in local_map has a mismatched placements: "
f"arg placements is {arg.placements} but the input "
f"placements is {spec}! "
"If redistribute_inputs is wanted, set "
"redistribute_inputs=True to local_map."
)
local_arg = arg.to_local()
if isinstance(local_arg, AsyncCollectiveTensor):
local_arg = local_arg.wait()
flat_local_args.append(local_arg)
else:
# Non-Tensor input must have None in `in_placements`
if in_placements is not None and not isinstance(arg, torch.Tensor):
spec = in_placements[idx]
assert spec is None, (
f"Non-Tensor input {arg} expects None placements "
f"but received {spec}!"
)
flat_local_args.append(arg)
local_args = pytree.tree_unflatten(flat_local_args, args_spec)
out = func(*local_args, **kwargs)
if seen_dtensor_arg:
# process output
flat_out, out_spec = pytree.tree_flatten(out)
flat_dist_out = []
out_placements_tuple = (
out_placements
if isinstance(out_placements, tuple)
else (out_placements,)
)
assert len(flat_out) == len(out_placements_tuple), (
"local_map requires one PlacementType to be provided for each output value,"
f" but received {len(out_placements_tuple)} out_placements for"
f" {len(flat_out)} output values!"
)
for out, spec in zip(flat_out, out_placements_tuple):
if isinstance(out, torch.Tensor):
assert not isinstance(
out, DTensor
), f"torch.Tensor output expected but received {type(out)}: {out}"
flat_dist_out.append(
DTensor.from_local(out, device_mesh, spec, run_check=False)
)
else:
assert (
spec is None
), f"Non-tensor output {out} expects None placements but received {spec}!"
flat_dist_out.append(out)
return pytree.tree_unflatten(flat_dist_out, out_spec)
else:
return out
return wrapped
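# Illustrative note (an assumption, not part of this file): when ``wrapped`` receives no
# DTensor arguments at all, it runs ``func`` on the plain tensors and returns the result
# as-is, ignoring ``out_placements`` (see the note in the docstring above).
#
#   add_fn = local_map(torch.add, out_placements=[Replicate()])
#   out = add_fn(torch.ones(4), torch.ones(4))  # plain torch.Tensor; no DTensor wrapping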

View File

@ -0,0 +1,136 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from functools import partial
from typing import Callable, List, Sequence, Tuple, Union
import torch
from torch._ops import OpOverload
from torch.distributed.tensor import DeviceMesh, DTensor
from torch.distributed.tensor._op_schema import (
_is_inplace_op,
OpSchema,
OpStrategy,
PlacementList,
RuntimeSchemaInfo,
StrategyType,
TupleStrategy,
)
from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy
__all__ = ["register_sharding"]
def register_sharding(op: Union[OpOverload, List[OpOverload]]):
"""
:meth:`register_sharding` is an experimental API that allows users to register sharding
strategies for an operator when the tensor inputs and outputs are DTensor.
It can be useful when: (1) there doesn't exist a default sharding strategy for ``op``,
e.g. when ``op`` is a custom operator that is not supported by :class:`DTensor`; (2)
when users would like to overwrite default sharding strategies of existing operators.
Args:
op (Union[OpOverload, List[OpOverload]]):
An op or a list of ops to register the customized sharding function.
Returns:
A function decorator which can be used to wrap a function that defines the sharding
strategy for the operator specified in ``op``. The defined sharding strategy will be
registered to DTensor and will override the default sharding strategy if DTensor has
already implemented the operator. The customized sharding function takes the same inputs
as the original op (except that if an arg is a :class:`torch.Tensor`, it will be
replaced by a tensor-like object that DTensor uses internally). The function should
return a sequence of 2-tuples, each specifying acceptable output placements and its
corresponding input placements.
Example:
>>> # xdoctest: +SKIP("distributed")
>>> @register_sharding(aten._softmax.default)
>>> def custom_softmax_sharding(x, dim, half_to_float):
>>> softmax_dim = dim if dim >= 0 else dim + x.ndim
>>> acceptable_shardings = []
>>>
>>> all_replicate = ([Replicate()], [Replicate(), None, None])
>>> acceptable_shardings.append(all_replicate)
>>>
>>> for sharding_dim in range(x.ndim):
>>> if sharding_dim != softmax_dim:
>>> all_sharded = (
>>> [Shard(sharding_dim)],
>>> [Shard(sharding_dim), None, None],
>>> )
>>> acceptable_shardings.append(all_sharded)
>>>
>>> return acceptable_shardings
.. note:: This API is currently experimental and subject to change
"""
def custom_strategy(
custom_sharding_fn: Callable[
..., Sequence[Tuple[PlacementList, PlacementList]]
],
mesh: DeviceMesh,
op_schema: OpSchema,
) -> StrategyType:
def strategy_to_spec(strategy: object) -> object:
if isinstance(strategy, OpStrategy):
# take the output spec from the first strategy
return strategy.strategies[0].output_spec
elif isinstance(strategy, TupleStrategy):
return tuple(strategy_to_spec(s) for s in strategy.childs)
else:
return strategy
args_schema = tuple(strategy_to_spec(i) for i in op_schema.args_schema)
kwargs_schema = {
k: strategy_to_spec(v) for k, v in op_schema.kwargs_schema.items()
}
acceptable_shardings = custom_sharding_fn(*args_schema, **kwargs_schema)
single_mesh_dim_strategies: List[PlacementList] = []
for output_specs, input_specs in acceptable_shardings:
single_mesh_dim_strategies.append(output_specs + input_specs)
# TODO: handle out variant ops
return expand_to_full_mesh_op_strategy(
mesh,
op_schema,
single_mesh_dim_strategies,
input_index=len(op_schema.op._schema.returns),
inplace_op=_is_inplace_op(op_schema.op),
)
def wrapper(custom_sharding_fn):
def derive_schema_info(op):
# NOTE: without user directly providing RuntimeSchemaInfo, for now
# we create it in a conservative fashion as follows:
# 1. let static_argnum be the first int argument
# 2. let static_kwargkey include all the int type kwargs
# 3. always set needs_pytree=True
static_argnum = 100
static_kwargkey: List[str] = []
for i, arg in enumerate(op._schema.arguments):
if isinstance(arg.type, torch.IntType) or (
isinstance(arg.type, torch.OptionalType)
and isinstance(arg.type.getElementType(), torch.IntType)
):
static_argnum = min(i, static_argnum)
if arg.kwarg_only:
static_kwargkey.append(arg.name)
return RuntimeSchemaInfo(
static_argnum, static_kwargkey or None, needs_pytree=True
)
overloads = op if isinstance(op, list) else [op]
for overload in overloads:
DTensor._op_dispatcher.sharding_propagator.register_op_strategy(
overload,
partial(custom_strategy, custom_sharding_fn),
derive_schema_info(overload),
)
return custom_sharding_fn
return wrapper
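# Illustrative follow-up to the docstring example above (an assumption, not part of this
# file): once ``custom_softmax_sharding`` is registered, softmax on DTensor inputs is
# expected to dispatch through the registered strategy. The mesh and shapes are hypothetical.
#
#   x = distribute_tensor(torch.randn(16, 32), device_mesh, [Shard(0)])
#   y = torch.softmax(x, dim=-1)  # sharding picked from the registered acceptable list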

View File

@ -0,0 +1,552 @@
# mypy: allow-untyped-defs
import copy
import operator
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple
import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
OpSchema,
OutputSharding,
OutputSpecType,
PlacementStrategy,
)
from torch.distributed.tensor._redistribute import redistribute_local_tensor
from torch.distributed.tensor.parallel.style import ColwiseParallel, ParallelStyle
from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
from torch.export import ExportedProgram
from torch.export.exported_program import ExportGraphSignature
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.node import Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils import _pytree as pytree
__all__ = ["tensor_parallel_transformation"]
aten = torch.ops.aten
def tensor_parallel_transformation(
exported_program: ExportedProgram,
rank: int,
world_size: int,
device_type: str,
parallel_strategies: Dict[str, ParallelStyle],
) -> ExportedProgram:
"""
The entry point function to perform graph transformations on an exported program
to transform a single-device graph into a tensor parallel graph.
.. warning::
This API is experimental and subject to change.
"""
gm = exported_program.graph_module
sig = copy.deepcopy(exported_program.graph_signature)
state_dict = copy.copy(exported_program.state_dict)
with gm._set_replace_hook(sig.get_replace_hook()):
res = _TensorParallelTransformPass(
rank,
world_size,
device_type,
state_dict,
exported_program.graph_signature,
parallel_strategies,
)(gm)
assert res is not None
gm = res.graph_module
return exported_program._update(gm, sig, state_dict=state_dict)
class _TensorParallelTransformPass(PassBase):
"""
This pass is responsible for transforming a single-device graph into a tensor parallel
graph. It will mark the placement strategy of each node in the graph,
partition the graph into distributed graph, then shard the parameters/buffers accordingly.
"""
def __init__(
self,
rank: int,
world_size: int,
device_type: str,
state_dict: Dict[str, torch.Tensor],
graph_signature: ExportGraphSignature,
parallel_strategies: Dict[str, ParallelStyle],
) -> None:
super().__init__()
self.rank = rank
self.mesh = DeviceMesh(device_type, torch.arange(world_size))
self.state_dict: Dict[str, torch.Tensor] = state_dict
self.graph_signature = graph_signature
self.parallel_strategies = parallel_strategies
def call(self, graph_module) -> PassResult:
gm = copy.deepcopy(graph_module)
parameter_placements = _generate_parameter_and_buffer_placements(
list(self.state_dict.keys()), self.parallel_strategies
)
placement_strategies = _mark_sharding(
gm, self.graph_signature, self.mesh, parameter_placements
)
_partitioner(gm)
_shard_state_dict(
self.state_dict, placement_strategies, self.graph_signature, self.mesh
)
return PassResult(gm, True)
def _generate_parameter_and_buffer_placements(
params_and_buffers: List[str],
parallel_strategies: Dict[str, ParallelStyle],
) -> Dict[str, Placement]:
"""
Build parameter placements based on the given parallel style of the linear layers.
"""
parameter_placements: Dict[str, Placement] = {}
for linear_fqn, parallel_style in parallel_strategies.items():
weight_fqn = f"{linear_fqn}.weight"
bias_fqn = f"{linear_fqn}.bias"
assert weight_fqn in params_and_buffers
parameter_placements[weight_fqn] = (
Shard(0) if parallel_style == ColwiseParallel else Shard(1)
)
if bias_fqn in params_and_buffers:
parameter_placements[bias_fqn] = (
Shard(0) if parallel_style == ColwiseParallel else Replicate()
)
return parameter_placements
def _mark_tensor_parallel_shardings(
gm: GraphModule,
graph_signature: ExportGraphSignature,
mesh: DeviceMesh,
parameter_placements: Dict[str, Placement],
) -> Dict[Node, PlacementStrategy]:
"""
Mark the placement strategies of the parameter and buffer placeholder nodes.
"""
placement_strategies: Dict[Node, PlacementStrategy] = {}
num_params_and_buffers = len(graph_signature.inputs_to_parameters) + len(
graph_signature.inputs_to_buffers
)
placeholder_idx: int = 0
for node in gm.graph.nodes:
if node.op == "placeholder":
if placeholder_idx < num_params_and_buffers:
fqn: str = _get_input_node_fqn(node.name, graph_signature)
placement: Placement = (
parameter_placements[fqn]
if fqn in parameter_placements
else Replicate()
)
placement_strategies[node] = _create_placement_strategy(
node,
mesh,
placements=(placement,),
)
placeholder_idx += 1
else:
placement_strategies[node] = _create_placement_strategy(
node,
mesh,
placements=(Replicate(),),
)
return placement_strategies
def _get_input_node_fqn(input_name: str, graph_signature: ExportGraphSignature) -> str:
"""
Return the FQN of an input node.
"""
if input_name in graph_signature.inputs_to_parameters:
return graph_signature.inputs_to_parameters[input_name]
elif input_name in graph_signature.inputs_to_buffers:
return graph_signature.inputs_to_buffers[input_name]
else:
raise ValueError(
f"{input_name} not found in inputs_to_parameters or inputs_to_buffers"
)
def _mark_sharding(
gm: GraphModule,
graph_signature: ExportGraphSignature,
mesh: DeviceMesh,
parameter_placements: Dict[str, Placement],
) -> Dict[Node, PlacementStrategy]:
"""
Mark the sharding strategy for each node in the graph module.
"""
placement_strategies: Dict[
Node, PlacementStrategy
] = _mark_tensor_parallel_shardings(gm, graph_signature, mesh, parameter_placements)
for node in gm.graph.nodes:
if node.op == "placeholder":
if node not in placement_strategies:
placement_strategies[node] = _create_placement_strategy(
node, mesh, placements=(Replicate(),)
)
node.meta["sharding"] = placement_strategies[node]
elif node.op == "call_function":
if node.target == operator.getitem:
input_nodes = node.all_input_nodes
assert (
len(input_nodes) == 1
), f"non-compute ops only support one input for now, found node: {node} with {len(node.args)} inputs"
arg_strategy = placement_strategies[input_nodes[0]]
placement_strategies[node] = _create_placement_strategy(
node,
mesh,
placements=arg_strategy.output_spec.placements,
input_specs=_get_input_node_specs(node, placement_strategies),
)
node.meta["sharding"] = placement_strategies[node]
else:
op_schema = _get_op_schema(node, placement_strategies)
# get DTensor specs for inputs and outputs
if (
op_schema.op
not in DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs
and op_schema.op
not in DTensor._op_dispatcher.sharding_propagator.op_to_rules
):
# Mark all as replicated
output_sharding = _generate_default_output_sharding(
node,
mesh,
op_schema,
)
else:
output_sharding = DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding(
op_schema,
)
placement_strategies[node] = PlacementStrategy(
output_specs=_get_output_spec_from_output_sharding(output_sharding),
input_specs=output_sharding.redistribute_schema.args_spec
if output_sharding.redistribute_schema is not None
else _get_input_node_specs(node, placement_strategies),
)
node.meta["sharding"] = placement_strategies[node]
elif node.op == "output":
node.meta["sharding"] = None
else:
raise RuntimeError(f"op code {node.op} not supported")
return placement_strategies
def _get_output_spec_from_output_sharding(
output_sharding: OutputSharding,
) -> DTensorSpec:
"""
Util function to extract output spec from output sharding.
"""
if isinstance(output_sharding.output_spec, DTensorSpec):
return output_sharding.output_spec
else:
# For ops that return multiple outputs, the outputs should have the same output spec
assert isinstance(output_sharding.output_spec, Sequence)
assert output_sharding.output_spec[0] is not None
output_sharding.output_spec[0].tensor_meta = None
return output_sharding.output_spec[0]
def _create_placement_strategy(
node: Node,
mesh: DeviceMesh,
placements: Tuple[Placement, ...],
input_specs: Optional[Sequence[DTensorSpec]] = None,
) -> PlacementStrategy:
"""
Util function to construct a placement strategy for a given node.
"""
placement = PlacementStrategy(
input_specs=input_specs,
output_specs=DTensorSpec(
mesh=mesh,
placements=placements,
),
)
_populate_tensor_meta(node, placement.output_specs)
return placement
def _populate_tensor_meta(node: Node, output_spec: OutputSpecType) -> None:
"""
Util function to populate tensor meta of output_spec based on node metadata.
"""
if isinstance(node.meta["val"], Sequence):
assert isinstance(output_spec, Sequence)
for spec, fake_tensor in zip(output_spec, node.meta["val"]):
assert spec is not None
spec.tensor_meta = TensorMeta(
shape=fake_tensor.shape,
stride=fake_tensor.stride(),
dtype=fake_tensor.dtype,
)
else:
assert isinstance(output_spec, DTensorSpec)
output_spec.tensor_meta = TensorMeta(
shape=node.meta["val"].shape,
stride=node.meta["val"].stride(),
dtype=node.meta["val"].dtype,
)
def _generate_default_output_sharding(
node: Node,
mesh: DeviceMesh,
op_schema: OpSchema,
) -> OutputSharding:
"""
Util function to create a default output sharding that suggests Replicate placement for both args and outputs.
"""
def update_arg_spec(arg_spec: DTensorSpec) -> DTensorSpec:
return DTensorSpec(
mesh=arg_spec.mesh,
placements=(Replicate(),),
tensor_meta=arg_spec.tensor_meta,
)
new_op_schema = OpSchema(
op=op_schema.op,
args_schema=pytree.tree_map_only(
DTensorSpec, update_arg_spec, op_schema.args_schema
),
kwargs_schema=op_schema.kwargs_schema,
)
def create_output_spec(tensor: FakeTensor) -> DTensorSpec:
return DTensorSpec(
mesh=mesh,
placements=(Replicate(),),
tensor_meta=TensorMeta(
shape=tensor.shape,
stride=tensor.stride(),
dtype=tensor.dtype,
),
)
return OutputSharding(
output_spec=pytree.tree_map_only(
FakeTensor, create_output_spec, node.meta["val"]
),
redistribute_schema=new_op_schema,
needs_redistribute=True,
)
def _partitioner(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
Graph partitioner that partitions the single-device graph
into a distributed graph.
"""
for node in gm.graph.nodes:
node_sharding = node.meta["sharding"]
if node.op == "placeholder":
out_spec = node_sharding.output_spec
local_val = _partition_val(node.meta["val"], out_spec)
# update node value
node.meta["val"] = local_val
elif node.op == "call_function":
out_spec = node_sharding.output_spec
# check if there's misaligned sharding, insert reshard if there is
expected_input_specs = node_sharding.input_specs
for idx, input_arg in enumerate(node.all_input_nodes):
input_arg_sharding = input_arg.meta["sharding"]
input_arg_spec = input_arg_sharding.output_spec
desired_spec = (
out_spec
if expected_input_specs is None
else expected_input_specs[idx]
)
if input_arg_spec != desired_spec:
_insert_reshard_gm(
gm, node, input_arg, input_arg_spec, desired_spec
)
# convert output val to its local component
output_val = node.meta["val"]
node.meta["val"] = _partition_val(output_val, out_spec)
elif node.op == "output":
for input_arg in node.all_input_nodes:
# input args of output should be Replicate, otherwise redistribution is needed.
input_args_to_check: Sequence[Node] = (
input_arg if isinstance(input_arg, Sequence) else [input_arg]
)
for arg in input_args_to_check:
arg_sharding = arg.meta["sharding"]
arg_spec = arg_sharding.output_spec
desired_spec = copy.copy(arg_spec)
desired_spec.placements = (Replicate(),)
if arg_spec != desired_spec:
_insert_reshard_gm(gm, node, arg, arg_spec, desired_spec)
else:
raise RuntimeError(f"op code {node} not supported")
_clean_up_graph_metadata(gm)
gm.graph.lint()
gm.recompile()
return gm
def _partition_val(val: Any, spec: DTensorSpec) -> Any:
"""
util function to convert a full tensor val to its local component
"""
if isinstance(val, torch.Tensor):
local_shard = val
if val.ndim == 0:
# If it's already a scalar tensor, it is already local, we don't
# need to do anything
return local_shard
for idx, placement in enumerate(spec.placements):
if placement.is_shard():
placement = cast(Shard, placement)
num_chunks = spec.mesh.size(mesh_dim=idx)
my_coord = spec.mesh.get_coordinate()
assert my_coord is not None, "current rank not in mesh!"
my_coord_on_mesh_dim = my_coord[idx]
local_shard = placement._split_tensor(
local_shard, num_chunks, with_padding=False, contiguous=True
)[0][my_coord_on_mesh_dim]
return local_shard
elif isinstance(val, (list, tuple)):
return val.__class__(_partition_val(v, spec) for v in val)
else:
raise RuntimeError(f"val type {type(val)} not supported")
def _insert_reshard_gm(
gm: torch.fx.GraphModule,
node: Node,
input_arg: Node,
input_arg_spec: DTensorSpec,
desired_spec: DTensorSpec,
) -> None:
"""
Transform the graph for tensor redistribution.
"""
input_arg_spec.tensor_meta = input_arg.meta["tensor_meta"]
desired_spec.tensor_meta = input_arg.meta["tensor_meta"]
input_arg_tensor = input_arg.meta["val"]
# insert reshard operation
def reshard_fn(local_tensor: torch.Tensor) -> torch.Tensor:
return redistribute_local_tensor(
local_tensor,
input_arg_spec,
desired_spec,
)
reshard_gm = make_fx(reshard_fn)(input_arg_tensor)
reshard_gm_nodes = list(reshard_gm.graph.nodes)
input_node = reshard_gm_nodes[0]
with gm.graph.inserting_before(node):
# copy nn_module_stack metadata for output, all-reduce nodes
for reshard_node in reshard_gm.graph.nodes:
if reshard_node.op not in ["placeholder", "output"]:
reshard_node.meta["nn_module_stack"] = (
copy.copy(input_arg.meta["nn_module_stack"])
if not input_arg.op == "placeholder"
else copy.copy(node.meta["nn_module_stack"])
)
output_node = gm.graph.graph_copy(
reshard_gm.graph,
val_map={
input_node: input_arg,
},
)
node.replace_input_with(input_arg, output_node) # type: ignore[arg-type]
def _clean_up_graph_metadata(gm: torch.fx.GraphModule) -> None:
"""
Clean up the graph by removing sharding and partitioning related metadata
"""
for node in gm.graph.nodes:
if "sharding" in node.meta:
del node.meta["sharding"]
if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
local_tensor_meta = _extract_tensor_metadata(node.meta["val"])
node.meta["tensor_meta"] = local_tensor_meta
def _get_input_node_specs(
node: Node, placement_strategies: Dict[Node, PlacementStrategy]
) -> Tuple[DTensorSpec, ...]:
"""
Get the input specs of a node.
"""
input_specs_list: List[DTensorSpec] = []
for input_arg in node.all_input_nodes:
if input_arg in placement_strategies:
output_spec = placement_strategies[input_arg].output_specs
assert isinstance(output_spec, DTensorSpec)
input_specs_list.append(output_spec)
else:
raise ValueError(f"{input_arg} does not have output_spec populated.")
return tuple(input_specs_list)
def _get_op_schema(
node: Node, placement_strategies: Dict[Node, PlacementStrategy]
) -> OpSchema:
"""
Util function to construct the operator schema of a node.
"""
args_schema_list = pytree.tree_map_only(
Node, lambda arg: placement_strategies[arg].output_specs, node.args
)
op_schema = OpSchema(
op=cast(torch._ops.OpOverload, node.target),
args_schema=tuple(args_schema_list),
kwargs_schema=cast(Dict[str, object], node.kwargs),
)
return op_schema
def _shard_state_dict(
state_dict: Dict[str, torch.Tensor],
placement_strategies: Dict[Node, PlacementStrategy],
graph_signature: ExportGraphSignature,
mesh: DeviceMesh,
) -> None:
"""
Inplace partition the weights based on the placement strategy
"""
for node, placement_strategy in placement_strategies.items():
if node.op != "placeholder":
continue
if node.name in graph_signature.inputs_to_parameters:
fqn = graph_signature.inputs_to_parameters[node.name]
elif node.name in graph_signature.inputs_to_buffers:
fqn = graph_signature.inputs_to_buffers[node.name]
else:
continue
assert fqn in state_dict, f"{fqn} not found in state dict: {state_dict.keys()}"
original_param = state_dict[fqn]
dtensor_param = distribute_tensor(
original_param,
mesh,
placement_strategy.output_spec.placements,
)
local_param = dtensor_param.to_local()
state_dict[fqn] = (
torch.nn.Parameter(local_param)
if isinstance(original_param, torch.nn.Parameter)
else local_param
)
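# Illustrative usage sketch (an assumption, not part of this file): export a model, then
# transform the single-device graph into a tensor-parallel one. The model, its linear
# FQNs, and ``example_input`` below are hypothetical.
#
#   import torch.distributed as dist
#   from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel
#
#   ep = torch.export.export(model, (example_input,))
#   tp_ep = tensor_parallel_transformation(
#       ep,
#       rank=dist.get_rank(),
#       world_size=dist.get_world_size(),
#       device_type="cuda",
#       parallel_strategies={"mlp.w1": ColwiseParallel(), "mlp.w2": RowwiseParallel()},
#   )
#   out = tp_ep.module()(example_input)  # runs the rank-local, tensor-parallel graph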

View File

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.loss import loss_parallel
from torch.distributed.tensor.parallel.style import (
ColwiseParallel,
ParallelStyle,
PrepareModuleInput,
PrepareModuleOutput,
RowwiseParallel,
SequenceParallel,
)
__all__ = [
"ColwiseParallel",
"ParallelStyle",
"PrepareModuleInput",
"PrepareModuleOutput",
"RowwiseParallel",
"SequenceParallel",
"parallelize_module",
"loss_parallel",
]

View File

@ -0,0 +1,51 @@
from functools import partial
from typing import no_type_check, Optional, Tuple
import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor import DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec
@no_type_check
def sync_grad_hook(grad, *, device_handle=None, compute_stream=None):
if isinstance(grad, AsyncCollectiveTensor):
if compute_stream is not None:
with device_handle.stream(compute_stream):
grad = grad.wait()
else:
grad = grad.wait()
return grad
def _flatten_tensor(
tensor: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[DTensorSpec]]:
if isinstance(tensor, DTensor):
tensor._local_tensor.requires_grad_()
return tensor._local_tensor, tensor._spec
return tensor, None
@no_type_check
def _unflatten_tensor(tensor, spec, *, device_handle=None, compute_stream=None):
# unflatten is mainly called every time FSDP all-gathers parameters.
result = DTensor.from_local(
tensor,
spec.mesh,
spec.placements,
run_check=False,
shape=spec.shape,
stride=spec.stride,
)
if tensor.requires_grad:
# only register the hook if the tensor requires grad
tensor.register_hook(
partial(
sync_grad_hook,
device_handle=device_handle,
compute_stream=compute_stream,
)
)
return result
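# Illustrative round trip (an assumption, not part of this file): ``_flatten_tensor``
# peels a DTensor down to its local shard plus its spec, and ``_unflatten_tensor``
# rebuilds the DTensor from the two, e.g. after an FSDP all-gather. Mesh and shapes
# below are hypothetical.
#
#   dt = distribute_tensor(torch.randn(8, 8), device_mesh, [Shard(0)])
#   local, spec = _flatten_tensor(dt)          # plain local tensor + DTensorSpec
#   rebuilt = _unflatten_tensor(local, spec)   # a DTensor again, with the same placements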

View File

@ -0,0 +1,67 @@
# mypy: allow-untyped-defs
import warnings
from typing import Tuple, Union
from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor.placement_types import Placement
try:
from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling
except Exception:
def is_torchdynamo_compiling(): # type: ignore[misc]
return False
LayoutsType = Union[Placement, Tuple[Placement, ...]]
def _deprecate_warnings(func_name: str, extra_msg: str) -> None:
"""
Emit a ``FutureWarning`` stating that ``func_name`` is deprecated, appending ``extra_msg``.
The warning is skipped while compiling with torchdynamo, since ``warnings.warn``
does not yet work under dynamo (see the TODO below).
"""
# TODO: Will follow up with dynamo POC to make warnings.warn working with dynamo.
if not is_torchdynamo_compiling():
warnings.warn(
f"{func_name} is deprecated and will be removed soon. {extra_msg}",
FutureWarning,
stacklevel=3,
)
def _validate_tp_mesh_dim(
device_mesh: DeviceMesh,
) -> None:
"""
Check whether TP mesh dimension is valid or not.
Args:
device_mesh (:class:`DeviceMesh`):
The `device_mesh` where we perform
Tensor Parallelism on.
Return:
None. Raises an error if the mesh dimension is not valid for TP.
"""
if device_mesh.ndim > 1:
raise ValueError(
f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D! "
'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"].'
)
root_mesh = _mesh_resources.get_root_mesh(device_mesh)
# If the root mesh is not the same as device_mesh,
# the device_mesh was sliced out from the root mesh.
if root_mesh and root_mesh != device_mesh:
tp_mesh_dim_in_root = _mesh_resources.get_root_mesh_dim(device_mesh)
if tp_mesh_dim_in_root != root_mesh.ndim - 1:
raise RuntimeError(
f"Found TP device_mesh on the {tp_mesh_dim_in_root} dimension of its parent mesh. "
"Currently we only support intranode TP and TP needs to be the innermost dimension on its parent mesh."
)
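# Illustrative note (an assumption, not part of this file): with a 2-D root mesh, the TP
# submesh must be the innermost (last) mesh dimension for the check above to pass. The
# mesh shape and dim names below are hypothetical.
#
#   from torch.distributed.device_mesh import init_device_mesh
#
#   mesh_2d = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))
#   _validate_tp_mesh_dim(mesh_2d["tp"])  # passes: "tp" is the last dim of the root mesh
#   _validate_tp_mesh_dim(mesh_2d["dp"])  # raises: TP would not be the innermost dim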

View File

@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from fnmatch import fnmatch
from typing import Dict, Union
import torch
import torch.distributed.tensor._random as random
import torch.nn as nn
from torch.distributed.tensor import DeviceMesh
from torch.distributed.tensor._random import (
is_rng_supported_mesh,
TensorParallelRNGTracker,
)
from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim
from torch.distributed.tensor.parallel.style import ParallelStyle
__all__ = [
"parallelize_module",
]
def parallelize_module( # type: ignore[return]
module: nn.Module,
device_mesh: DeviceMesh,
parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
) -> nn.Module:
"""
Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan.
We parallelize module or sub_modules based on a parallelize_plan. The parallelize_plan contains
:class:`ParallelStyle`, which indicates how user wants the module or sub_module
to be parallelized.
User can also specify different parallel style per module fully qualified name (FQN).
Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`. If you have a 2-D or N-D :class:`DeviceMesh`,
slice the DeviceMesh to a 1-D sub DeviceMesh first and then pass it to this API (i.e. ``device_mesh["tp"]``).
Args:
module (:class:`nn.Module`):
Module to be parallelized.
device_mesh (:class:`DeviceMesh`):
Object which describes the mesh topology
of devices for the DTensor.
parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
The plan used to parallelize the module. It can be either a
:class:`ParallelStyle` object which contains how
we prepare input/output for Tensor Parallelism or it can be a
dict of module FQN and its corresponding :class:`ParallelStyle` object.
Return:
A :class:`nn.Module` object parallelized.
Example::
>>> # xdoctest: +SKIP("distributed")
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>>
>>> # Define the module.
>>> m = Model(...)
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
>>>
.. note:: For complex module architectures like Attention or MLP layers, we recommend composing
different ParallelStyles together (i.e. ``ColwiseParallel`` and ``RowwiseParallel``) and passing
them as a parallelize_plan to achieve the desired sharding computation.
"""
torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")
_validate_tp_mesh_dim(device_mesh)
# instantiate a TP RNG state tracker if it's not there
if is_rng_supported_mesh(device_mesh) and not isinstance(
random._rng_tracker, TensorParallelRNGTracker
):
random._rng_tracker = TensorParallelRNGTracker(device_mesh.device_type)
# TODO: we should allow user to pass in the default seed from a config
random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
# By default we execute random ops in non-tensor-parallel region. If users want
# to execute in tensor-parallel region, they can manually set this field to True
# after parallelizing the model.
random._rng_tracker.distribute_region_enabled = False
if isinstance(parallelize_plan, ParallelStyle):
return parallelize_plan._apply(module, device_mesh)
elif isinstance(parallelize_plan, dict):
for module_path, parallelize_style in parallelize_plan.items():
path_splits = module_path.split(".")
if len(path_splits) == 0:
raise ValueError(
"Expect module path to be non-empty, but got empty string!"
)
while path_splits:
atom = path_splits.pop(0)
matched_children = filter(
# `t[0]` is child name
lambda t: fnmatch(t[0], atom),
module.named_children(),
)
# apply the plan to all matched submodules
for _, submodule in matched_children:
if path_splits:
# we haven't reached the leaf, apply in dict style
leaf_path = ".".join(
path_splits
) # rest of the path after `atom`
parallelize_module(
submodule, device_mesh, {leaf_path: parallelize_style}
)
else:
# otherwise, directly apply style to this submodule
parallelize_module(submodule, device_mesh, parallelize_style)
return module
else:
raise TypeError( # pyre-ignore[7]
"Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
f" parallelize_plan, {type(parallelize_plan)} found!"
)
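# Illustrative note (an assumption, not part of this file): since sub-module matching above
# goes through ``fnmatch``, a plan key may contain wildcards and nested FQNs are resolved
# one atom at a time. The module layout below is hypothetical.
#
#   plan = {
#       "layers.*.attn.wq": ColwiseParallel(),
#       "layers.*.attn.wo": RowwiseParallel(),   # RowwiseParallel from .style
#   }
#   model = parallelize_module(model, tp_mesh, plan)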

View File

@ -0,0 +1,104 @@
# mypy: allow-untyped-defs
from typing import Any, List, Optional, Set, Tuple
import torch.nn as nn
from torch.distributed.tensor.parallel._data_parallel_utils import (
_flatten_tensor,
_unflatten_tensor,
)
__all__ = [] # type: ignore[var-annotated]
def _get_submodule_n_params(module: nn.Module, path: str):
"""
Get submodule and the direct path of parameter from the module
"""
if "." in path:
path_list = path.split(".")
parent_module_path = ".".join(path_list[:-1])
module = module.get_submodule(parent_module_path)
path = path_list[-1]
return module, path
def _update_module_param(param_list: List[Tuple[nn.Module, str, nn.Parameter]]):
"""
Update parameters within the module
"""
for item in param_list:
parent_module, module_path, t = item
assert hasattr(parent_module, module_path)
delattr(parent_module, module_path)
setattr(parent_module, module_path, t)
def _reconstruct_dtensor(module: nn.Module, _input: Any):
"""
Reconstruct DTensor parameters from local tensors
"""
param_list = []
# TODO: To add perf optimizations to this iterations
for name, t in module.named_parameters():
if hasattr(t, "_st_info"):
dtensor = _unflatten_tensor(t, t._st_info)
param_list.append((*_get_submodule_n_params(module, name), dtensor))
_update_module_param(param_list) # type: ignore[arg-type]
def _localize_dtensor(
module: nn.Module, *_: Any, ignored_params: Optional[Set[nn.Parameter]] = None
):
"""
Convert DTensor parameters to local tensors
"""
if ignored_params is None:
ignored_params = set()
param_list = []
for name, param in module.named_parameters():
if param in ignored_params:
continue
t, sharding_info = _flatten_tensor(param)
if sharding_info is not None:
t = nn.Parameter(t)
t._st_info = sharding_info # type: ignore[attr-defined]
param_list.append((*_get_submodule_n_params(module, name), t))
_update_module_param(param_list) # type: ignore[arg-type]
def _pre_dp_module_transform(module: nn.Module):
"""
Enable the composability between Tensor Parallelism (TP) and Data
Parallelism (DP) in PyTorch when using DDP. We need to convert parameters that
are DTensors to local tensors before wrapping them with the data parallelism API.
We then register two hooks: one that converts the local tensors back to DTensors
pre-forward, and one that converts the DTensors back to local tensors after forward. By
integrating this way, we avoid any special handling of DTensor parameters by DDP
and get DTensor's gradients propagated back to DP, e.g. into the gradient buckets of DDP.
For now, this API only works with ``DistributedDataParallel``. It will later support
other DP methods such as FSDP.
Args:
module (:class:`nn.Module`):
Module which has been applied TP on.
Example::
>>> # xdoctest: +SKIP("distributed")
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
>>>
>>> # Define the module and apply TP to it first.
>>> m = Model(...)
>>> parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})
>>> _pre_dp_module_transform(m)
>>> m = DDP(m)
>>>
"""
_localize_dtensor(module, None, None)
# TODO: To add test cases and ensure that it works for nested modules
module.register_forward_pre_hook(_reconstruct_dtensor)
module.register_forward_hook(_localize_dtensor)

View File

@ -0,0 +1,388 @@
# mypy: allow-untyped-defs
import copy
from typing import Any, cast, List, Optional, Tuple
import torch
import torch.distributed as dist
import torch.distributed._shard.sharding_spec as shard_spec
import torch.distributed.distributed_c10d as c10d
from torch.distributed._shard.sharded_tensor import (
Shard,
ShardedTensor,
ShardedTensorMetadata,
TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
from torch.distributed.remote_device import _remote_device
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
from torch.distributed.tensor.parallel._data_parallel_utils import (
_flatten_tensor,
_unflatten_tensor,
)
__all__ = ["DTensorExtensions"]
def _get_box(tensor: DTensor) -> Tuple[torch.Size, torch.Size]:
device_mesh = tensor.device_mesh
assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
placement = tensor.placements[0]
offsets = [0] * len(tensor.size())
num_chunks = device_mesh.size(mesh_dim=0)
if tensor.placements[0].is_shard():
shard_dim = cast(DShard, placement).dim
chunk_size = tensor.size(shard_dim) // num_chunks
offsets[shard_dim] = chunk_size
return (torch.Size(offsets), tensor._local_tensor.size())
def _get_box_for(tensor: DTensor, idx: int) -> Tuple[torch.Size, torch.Size]:
offsets, size = _get_box(tensor)
return (torch.Size([val * idx for val in offsets]), size)
def _get_local_box(tensor: DTensor) -> Tuple[torch.Size, torch.Size]:
device_mesh = tensor.device_mesh
coord = device_mesh.get_coordinate()
assert coord is not None
return _get_box_for(tensor, coord[0])
def _create_shard_md_from_dt(dt: DTensor, current_rank: int) -> ShardMetadata:
mesh = dt.device_mesh
assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
offsets, sizes = _get_local_box(dt)
return ShardMetadata(
shard_offsets=list(offsets),
shard_sizes=list(sizes),
placement=f"rank:{current_rank}/{dt._local_tensor.device}",
)
def _create_sharded_tensor_md_from_dt(
dt: DTensor, dt_pg: c10d.ProcessGroup
) -> ShardedTensorMetadata:
# This is where it gets tricky: we have to produce a ShardedTensor that has full coverage
# and yet has only one valid shard for the current rank.
shards_md = []
my_rank = dist.get_rank(dt_pg)
scapegoat_rank = 0 if my_rank > 0 else 1
if dt.placements[0].is_shard():
shard_count = dt_pg.size()
else:
shard_count = 1
for i in range(shard_count):
offsets, sizes = _get_box_for(dt, i)
shards_md.append(
ShardMetadata(
shard_offsets=list(offsets),
shard_sizes=list(sizes),
placement=(
f"rank:{scapegoat_rank if i > 0 else my_rank}/{dt._local_tensor.device}"
),
)
)
return ShardedTensorMetadata(
shards_metadata=shards_md,
size=dt.size(),
tensor_properties=TensorProperties(
dtype=dt.dtype,
layout=dt.layout,
requires_grad=dt.requires_grad,
# ignore memory_format and pin_memory as those are not supported by DT
),
)
def _get_dt_pg(dt: DTensor) -> c10d.ProcessGroup:
mesh = dt.device_mesh
assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
return mesh.get_group()
def _rewrite_spec_if_needed(
spec: shard_spec.ShardingSpec, tensor: torch.Tensor, rank: int
) -> shard_spec.ShardingSpec:
"""
Rewrite ``spec`` to match the device of ``tensor``.
FSDP.sharded_optim_state_dict sneakily ships optimizer state to CPU, so if the original ShardingSpec
produces CUDA metadata, ShardedTensor construction fails.
"""
if not isinstance(spec, ChunkShardingSpec):
return spec
# let's see if we need to rewrite any placements
rewrite = False
for p in spec.placements:
p = cast(_remote_device, p)
if p.rank() == rank and p.device() != tensor.device:
rewrite = True
break
if rewrite:
spec = copy.deepcopy(spec)
for i, placement in enumerate(spec.placements):
placement = cast(_remote_device, placement)
if placement.rank() == rank and placement.device() != tensor.device:
spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}")
return spec
def _chunk_tensor(
tensor: torch.Tensor,
rank: int,
world_size: int,
num_devices_per_node: int,
pg: dist.ProcessGroup,
) -> torch.Tensor:
if type(tensor) is ShardedTensor:
assert len(tensor.local_shards()) == 1
inner_param = tensor.local_tensor()
inner_st = _create_chunk_sharded_tensor(
inner_param,
rank,
world_size,
num_devices_per_node,
pg,
)
outer_local_shard = tensor.local_shards()[0]
shards: List[Shard] = [
Shard(inner_st, copy.deepcopy(outer_local_shard.metadata))
]
st_meta = copy.deepcopy(tensor.metadata())
st_meta.tensor_properties.requires_grad = False
st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
shards,
sharded_tensor_metadata=st_meta,
process_group=tensor._process_group,
init_rrefs=False,
)
return st_outer
elif type(tensor) is DTensor:
device_mesh = tensor.device_mesh
assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
inner_param = tensor._local_tensor
inner_st = _create_chunk_sharded_tensor(
inner_param,
rank,
world_size,
torch.cuda.device_count(),
pg,
)
dt_pg = _get_dt_pg(tensor)
# We do this differently here, we create a ST with no local shards then patch it
shards = [
Shard(inner_st, _create_shard_md_from_dt(tensor, dist.get_rank(dt_pg)))
]
st_meta = _create_sharded_tensor_md_from_dt(tensor, dt_pg)
st_meta.tensor_properties.requires_grad = False
st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
shards,
sharded_tensor_metadata=st_meta,
process_group=dt_pg,
init_rrefs=False,
)
return st_outer
else:
return _create_chunk_sharded_tensor(
tensor,
rank,
world_size,
num_devices_per_node,
pg,
)
def _chunk_dtensor(
tensor: torch.Tensor,
rank: int,
device_mesh: DeviceMesh,
) -> DTensor:
"""
Shard a tensor into chunks along the first dimension.
The local rank gets its corresponding chunk as the local tensor to create a DTensor.
"""
root_mesh = _mesh_resources.get_root_mesh(device_mesh)
if root_mesh is None:
raise RuntimeError("No parent device_mesh is found for FSDP device_mesh.")
if root_mesh.ndim < 2:
raise RuntimeError(
f"Found parent device_mesh of ndim={root_mesh.ndim}, "
"but meshes must be at least 2D."
)
# We need to explicitly call .detach() to return a new tensor detached from the current graph.
tensor = tensor.clone().detach()
# When a layer is not involved in TP, then the tensor will not be a DTensor.
# e.g. When a layer is not specified in the parallelize_plan, TP will have no effect on the layer.
# e.g. When you do PairwiseParallel on a 3 layer model, TP will have no effect on the third layer.
if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor):
# For tensors, it is replicated across tp dimension and sharded across FSDP dimension.
# TP is the inner dimension and FSDP is the outer dimension.
# Therefore, shard placements for tensor is (Shard(0), Replicate()).
replicate_placements = [Replicate() for _ in range(root_mesh.ndim)]
shard_placements = [Replicate() for _ in range(root_mesh.ndim)]
shard_placements[0] = DShard(0) # type: ignore[call-overload]
return DTensor.from_local(
tensor, root_mesh, replicate_placements, run_check=False
).redistribute(
device_mesh=root_mesh,
placements=shard_placements,
)
else:
tp_placements = tensor.placements
tp_placement = tp_placements[0]
tensor = tensor.to_local()
# For DTensors, it is sharded across tp dimension first and then sharded across FSDP dimension.
# TP is the inner dimension and FSDP is the outer dimension.
# Therefore, shard placements for tensor is (Shard(0), tp_placement).
# For higher dimensional meshes, it is replicated across other dimensions. For example, with
# HSDP the shard placements for tensor is (Replicate, Shard(0), tp_placement).
replicate_placements = [Replicate() for _ in range(root_mesh.ndim)]
replicate_placements[-1] = tp_placement # type: ignore[call-overload]
shard_placements = [Replicate() for i in range(root_mesh.ndim)] # type: ignore[misc]
shard_placements[-2] = DShard(0) # type: ignore[call-overload]
shard_placements[-1] = tp_placement # type: ignore[call-overload]
return DTensor.from_local(
tensor, root_mesh, replicate_placements, run_check=False
).redistribute(
device_mesh=root_mesh,
placements=shard_placements,
)
def _pre_load_state_dict(
tensor: torch.Tensor,
) -> Tuple[torch.Tensor, List[Shard]]:
shards = cast(ShardedTensor, tensor).local_shards()
if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor:
inner_tensor = shards[0].tensor
shards = inner_tensor.local_shards() # pyre-ignore[16]
tensor = inner_tensor
return (tensor, shards if len(shards) > 0 else [])
def _all_gather_dtensor(
tensor: DTensor,
parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
"""All gather a DTensor in its FSDP dimension and return the local tensor."""
assert parent_mesh == tensor.device_mesh
placements = list(copy.deepcopy(tensor.placements))
# FSDP + TP: [Shard(0), tp_placement] -> [Replicate(), tp_placement]
# HSDP + TP: [Replicate(), Shard(0), tp_placement] -> [Replicate(), Replicate(), tp_placement]
for i in range(0, len(placements) - 1):
placements[i] = Replicate()
tensor = tensor.redistribute(
device_mesh=tensor.device_mesh,
placements=placements,
)
return tensor.to_local()
class DTensorExtensions(FSDPExtensions):
"""
    DTensorExtensions is the TensorFlattener extension needed for 2D FSDP + TP.
This is the implementation for FSDPExtensions defined in
https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fsdp_extensions.py
"""
def __init__(self, device_handle) -> None:
super().__init__()
self.compute_stream = None
self.device_handle = device_handle
        # We have to disable dynamo by wrapping the method here instead of using the
        # decorator, as the decorator would trigger a build failure with torch deploy...
self.post_unflatten_transform = torch._dynamo.disable(self.post_unflatten_transform) # type: ignore[method-assign]
def pre_flatten_transform(
self,
tensor: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[Any]]:
return _flatten_tensor(tensor)
def post_unflatten_transform(
self, tensor: torch.Tensor, param_extension: Any
) -> torch.Tensor:
stream = self.compute_stream or self.device_handle.current_stream()
with self.device_handle.stream(stream):
            # At runtime we run the unflatten call on the compute stream since the
            # unflattened tensor might be used by fwd/bwd computations that need to
            # sync properly.
            # TODO: this is a short-term fix; get_unflat_views should happen directly
            # on the compute stream.
result = _unflatten_tensor(
tensor,
param_extension,
device_handle=self.device_handle,
compute_stream=self.compute_stream,
)
_set_fsdp_flattened(result)
return result
def chunk_tensor(
self,
tensor: torch.Tensor,
rank: int,
world_size: int,
num_devices_per_node: int,
pg: dist.ProcessGroup,
device: Optional[torch.device] = None,
) -> torch.Tensor:
return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)
def chunk_dtensor(
self,
tensor: torch.Tensor,
rank: int,
device_mesh: DeviceMesh,
) -> torch.Tensor:
return _chunk_dtensor(tensor, rank, device_mesh)
def pre_load_state_dict_transform(
self,
tensor: torch.Tensor,
) -> Tuple[torch.Tensor, List[Shard]]:
return _pre_load_state_dict(tensor)
def all_gather_dtensor(
self,
tensor: DTensor,
parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
return _all_gather_dtensor(tensor, parent_mesh)

View File

@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from functools import partial
from typing import Any, Optional, Tuple
import torch
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard
__all__ = [
"input_reshard",
]
def input_reshard(
module: torch.nn.Module,
tp_device_mesh: DeviceMesh,
input_reshard_dim: Optional[int] = None,
) -> torch.nn.Module:
"""
Register hooks to an nn.Module for input resharding, enabling sharding and restoration during backward computation.
Register hooks to an nn.Module with input resharding so that we can shard
per the given `tp_device_mesh` and `input_reshard_dim` and restore the
    input back when recomputing the activations in the backward. The reason
    why we can do this is that for Tensor Parallel (TP), the inputs are the same
    across all TP ranks.
Args:
module (:class:`nn.Module`):
Module to be registered with input resharding.
tp_device_mesh (:class:`DeviceMesh`):
Object which describes the mesh topology
of devices for Tensor Parallel.
input_reshard_dim (Optional[int]):
            The dimension along which the input is sharded. If set to None,
            the input is not sharded.
Default: None
Return:
A :class:`nn.Module` object registered with TP input resharding.
"""
cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None
def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> None:
saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks(
partial(_pack_hook_tp, tp_device_mesh, input_reshard_dim),
partial(_unpack_hook_tp, tp_device_mesh, input_reshard_dim),
)
saved_tensor_hooks.__enter__()
nonlocal cx
cx = saved_tensor_hooks # type: ignore[name-defined]
def input_reshard_backward_hook(
_: torch.nn.Module, _i: Tuple[Any, ...], _o: Any
) -> Any:
nonlocal cx
cx.__exit__() # type: ignore[name-defined, union-attr]
if input_reshard_dim is None:
return module
module.register_forward_pre_hook(input_reshard_forward_pre_hook)
module.register_forward_hook(input_reshard_backward_hook)
return module
def _pack_hook_tp(
mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor
) -> Any: # noqa: D401
"""Hook function called after FWD to shard input."""
if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements):
return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
elif (
not isinstance(x, DTensor)
and isinstance(x, torch.Tensor)
and x.numel() >= mesh.size()
):
return (
DTensor.from_local(x, device_mesh=mesh)
.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)])
.to_local()
)
else:
return x
def _unpack_hook_tp(
mesh: DeviceMesh, input_reshard_dim: int, x: Any
) -> torch.Tensor: # noqa: D401
"""Hook function called before activation recomputing in BWD to restore input."""
if (
isinstance(x, DTensor)
and len(x._spec.placements) == 1
and x._spec.placements[0].is_shard()
):
return x.redistribute(device_mesh=mesh, placements=[Replicate()])
elif (
not isinstance(x, DTensor)
and isinstance(x, torch.Tensor)
and x.numel() >= mesh.size()
):
return (
DTensor.from_local(
x, device_mesh=mesh, placements=[Shard(input_reshard_dim)]
)
.redistribute(device_mesh=mesh, placements=[Replicate()])
.to_local()
)
else:
return x
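# Usage sketch (illustrative only; the module and mesh names below are made up,
# and parallelize_module comes from torch.distributed.tensor.parallel):
#
#   tp_mesh = init_device_mesh("cuda", (8,))
#   block = parallelize_module(block, tp_mesh, {...})
#   # keep the tensors saved for backward sharded on dim 0 between forward
#   # and activation recomputation, assuming all TP ranks see the same input
#   block = input_reshard(block, tp_mesh, input_reshard_dim=0)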

View File

@ -0,0 +1,490 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
from typing import cast, Dict, Optional, Tuple
import torch
import torch._prims_common as utils
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch import Tensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor._ops._math_ops import (
_skip_dim,
Reduction,
replicate_reduction_dims,
)
from torch.distributed.tensor.placement_types import Placement
aten = torch.ops.aten
__all__ = ["loss_parallel"]
@contextlib.contextmanager
def loss_parallel():
"""
A context manager that enables loss parallelism, where efficient parallelized loss computation
can be performed when the input is sharded on the class dimension. Currently only the cross-entropy
loss is supported.
Within this context manager, one can use :func:`~torch.nn.functional.cross_entropy` or
:class:`~torch.nn.CrossEntropyLoss` as usual, with the following assumptions on the input parameters.
The corresponding ``backward()`` call, if any, also needs to happen under this context manager.
Args:
input (:class:`DTensor`):
Input logits. Assumed to be sharded on the class dimension.
target (Union[:class:`torch.Tensor`, :class:`DTensor`]):
Must be ground truth class indices (class probabilities currently not supported).
Assumed to be replicated across the ``DeviceMesh``.
weight (Union[:class:`torch.Tensor`, :class:`DTensor`], optional):
If given, assumed to be replicated across the ``DeviceMesh``.
label_smoothing:
Currently not supported.
Returns:
A replicated :class:`DTensor`.
Example:
A sharded DTensor is manually created here to showcase the usage.
In practice, it is usually the output of a TP module.
>>> # xdoctest: +SKIP("distributed")
>>> from torch.distributed.tensor.parallel import loss_parallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> device_mesh = init_device_mesh("cuda", (8,))
>>> input = torch.randn(4, 16, device="cuda", requires_grad=True)
>>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)])
>>> target = torch.randint(16, (4,), device="cuda")
>>> with loss_parallel():
>>> loss = F.cross_entropy(dist_input, target, reduction="mean")
>>> loss.backward()
>>> ...
"""
_enable_custom_loss_ops()
yield
_disable_custom_loss_ops()
# Currently only needs to support one dimensional DeviceMesh; in general return
# the mesh_dim with placements[mesh_dim].is_shard(dim)
def _find_all_reduce_mesh_dim(placements: Tuple[Placement, ...], dim: int) -> int:
if not len(placements) == 1:
raise ValueError(
"Currently loss_parallel() only supports input on one-dimensional DeviceMesh."
)
if not placements[0].is_shard(dim):
raise ValueError(
f"loss_parallel() should be enabled only when the input tensor is sharded on dimension {dim}."
)
return 0
def _cast_to_dtensor(
tensor, placements: Tuple[Placement, ...], mesh: DeviceMesh
) -> DTensor:
if isinstance(tensor, DTensor):
if tensor.placements == placements:
return tensor
else:
raise RuntimeError(f"Expected {placements} but got {tensor.placements}.")
elif isinstance(tensor, torch.Tensor):
return DTensor.from_local(
tensor, device_mesh=mesh, placements=placements, run_check=False
)
else:
raise TypeError(f"Unsupported type {type(tensor)}")
def _propagate_tensor_meta(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> TensorMeta:
op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
tensor_meta = DTensor._op_dispatcher.sharding_propagator._propagate_tensor_meta(
op_info.schema
)
if isinstance(tensor_meta, TensorMeta):
return tensor_meta
elif isinstance(tensor_meta, tuple):
return tensor_meta[0]
else:
raise RuntimeError(f"Unexpected tensor meta type: {type(tensor_meta)}.")
# NOTE: The implementation follows torch._decomp.decomposition._log_softmax,
# with all_reduce manually inserted to perform distributed computation.
def _log_softmax(x, dim, half_to_float, mesh, mesh_dim):
x = x.contiguous()
if half_to_float:
assert x.dtype == torch.half
computation_dtype, result_dtype = utils.elementwise_dtypes(
x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
x = x.to(computation_dtype)
if x.numel() == 0:
shifted = x
else:
x_max = torch.amax(x, dim, keepdim=True)
x_max = funcol.all_reduce(
x_max, reduceOp=c10d.ReduceOp.MAX.name, group=(mesh, mesh_dim)
)
shifted = x - x_max
shifted_sumexp = torch.sum(torch.exp(shifted), dim, keepdim=True)
shifted_sumexp = funcol.all_reduce(
shifted_sumexp, reduceOp=c10d.ReduceOp.SUM.name, group=(mesh, mesh_dim)
)
shifted_logsumexp = torch.log(shifted_sumexp)
result = shifted - shifted_logsumexp
if not half_to_float:
result = result.to(result_dtype)
return result
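# The function above is the standard numerically-stable log-softmax,
#     log_softmax(x)_i = (x_i - max(x)) - log(sum_j(exp(x_j - max(x)))),
# except that max(x) and the sum of exponentials are combined across the
# sharded class dimension with two all_reduces (MAX and SUM), so each rank
# ends up holding the log-softmax values for its local slice of that dimension.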
def _log_softmax_handler(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> object:
x = cast(DTensor, args[0])
dim = cast(int, args[1])
half_to_float = cast(bool, args[2])
spec = x._spec
mesh_dim = _find_all_reduce_mesh_dim(spec.placements, dim)
output_tensor_meta = _propagate_tensor_meta(op_call, args, kwargs)
res = _log_softmax(x._local_tensor, dim, half_to_float, spec.mesh, mesh_dim)
res_spec = DTensorSpec(
spec.mesh,
spec.placements,
tensor_meta=output_tensor_meta,
)
return DTensor(
res,
res_spec,
requires_grad=res.requires_grad,
)
# NOTE: As explained below at _nll_loss_and_log_softmax_backward, the
# _log_softmax_backward_handler does not actually do any computation.
def _log_softmax_backward_handler(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> object:
grad_output = cast(DTensor, args[0])
input_dtype = cast(torch.dtype, args[3])
return grad_output.to(input_dtype)
# NOTE: The implementation follows torch._decomp.decomposition._nll_loss_forward,
# with customized communication inserted to perform distributed computation.
def _nll_loss_forward(
x: Tensor,
target: Tensor,
weight: Optional[Tensor],
local_weight: Optional[Tensor],
reduction: int,
ignore_index: int,
input_shape: torch.Size,
channel_dim: int,
mesh: DeviceMesh,
mesh_dim: int,
) -> Tuple[Tensor, Tensor]:
n_dims = x.dim()
channel_dim = 1
if n_dims < 2:
channel_dim = 0
def _weight_view(weight: Tensor) -> Tensor:
if n_dims > 1:
shape = [
1,
] * n_dims
shape[channel_dim] = weight.shape[0]
w = weight.view(shape)
else:
w = weight
return w
if weight is not None:
w = _weight_view(weight)
assert local_weight is not None
local_w = _weight_view(local_weight)
x = x * local_w
safe_target = torch.where(target != ignore_index, target, 0)
safe_target_ = safe_target.unsqueeze(channel_dim)
# The following code block is a distributed version of
# result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
safe_target_partial_ = partial_placement._partition_value(
safe_target_, mesh, mesh_dim
)
result_partial = torch.gather(x, channel_dim, safe_target_partial_)
# an all_reduce happens here
result_reduced = partial_placement._reduce_value(result_partial, mesh, mesh_dim)
result = -result_reduced.squeeze(channel_dim)
result = torch.where(target != ignore_index, result, 0)
if reduction == Reduction.NONE.value and n_dims > 1:
total_weight = x.new_full((), 0.0)
return result, total_weight
if weight is not None:
new_shape = list(x.shape)
new_shape[channel_dim] = -1
w = w.expand(new_shape)
wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
wsum = torch.where(target != ignore_index, wsum, 0)
total_weight = wsum.sum()
else:
total_weight = (target != ignore_index).sum().to(x)
    # NOTE: this is correct only on a 1D DeviceMesh; otherwise an additional
    # all-reduce on result and total_weight is needed
if reduction == Reduction.SUM.value:
result = result.sum()
elif reduction == Reduction.MEAN.value:
result = result.sum() / total_weight
return result, total_weight
def _nll_loss_forward_handler(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> object:
x = cast(DTensor, args[0])
target = args[1]
weight = args[2]
reduction = cast(int, args[3])
ignore_index = cast(int, args[4])
channel_dim = 1 if x.dim() >= 2 else 0
channel_dim_size = x.shape[channel_dim]
spec = x._spec
mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)
# Check user input: if target and weight are not DTensors, convert them to DTensors;
# if they are DTensors, check that they have the desired placements.
target_placements = _skip_dim(
replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim
)
all_replicate_placements = (Replicate(),) * spec.mesh.ndim
target = _cast_to_dtensor(target, target_placements, spec.mesh)
local_weight = None
if weight is not None:
weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)
# For local computation, both (replicated) weight and (sharded) local_weight
# are needed in _nll_loss_forward(). local_weight is generated here using
# DTensor API, without incurring any communication.
sharded_placements = [
Shard(0) if i == mesh_dim else Replicate() for i in range(spec.mesh.ndim)
]
local_weight = weight.redistribute(spec.mesh, sharded_placements)._local_tensor
assert local_weight.shape[0] == x._local_tensor.shape[channel_dim]
if reduction == Reduction.NONE.value:
output_placements = target_placements
else:
output_placements = all_replicate_placements
# tensor inputs to _propagate_tensor_meta need to be DTensors
args = list(args)
args[1], args[2] = target, weight
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
result, total_weight = _nll_loss_forward(
x._local_tensor,
target._local_tensor,
weight._local_tensor if weight is not None else None,
local_weight,
reduction,
ignore_index,
x.shape,
channel_dim,
spec.mesh,
mesh_dim,
)
out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta)
return (
DTensor(
result,
out_spec,
requires_grad=result.requires_grad,
),
total_weight,
)
# NOTE: The backward computation of cross_entropy goes through two steps:
# backward for nll_loss and then backward for log_softmax. In loss parallel,
# the two steps are fused into the following function (called by _nll_loss_backward_handler)
# to avoid communication when target contains class indices not class probabilities.
# Also note that the _log_softmax_backward_handler does not perform computation.
# The implementation resembles _nll_loss_backward and _log_softmax_backward_data
# from torch._decomp.decomposition.
def _nll_loss_and_log_softmax_backward(
grad_output: Tensor,
x: Tensor,
target: Tensor,
weight: Optional[Tensor],
reduction: int,
ignore_index: int,
total_weight: Tensor,
input_shape: torch.Size,
channel_dim: int,
mesh: DeviceMesh,
mesh_dim: int,
) -> Tensor:
channel_dim = 0 if x.dim() < 2 else 1
if reduction == Reduction.MEAN.value:
grad_output = grad_output / total_weight
target = target.unsqueeze(channel_dim)
safe_target = torch.where(target != ignore_index, target, 0)
grad_input = torch.zeros_like(x)
# The following code block is a distributed version of
# grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
safe_target = safe_target.squeeze(channel_dim).flatten()
masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim)
# only update grad_input to -1 if not masked
assert partial_placement.mask_buffer.data is not None
grad_update = partial_placement.mask_buffer.data.to(grad_input.dtype) - 1.0
arange_1d = torch.arange(
masked_safe_target.shape[0], device=masked_safe_target.device
)
# The first two cases with x.dim() <= 2 are for aten.nll_loss_backward.default;
# the last case is for aten.nll_loss2d_backward.default.
if x.dim() == 1:
grad_input[masked_safe_target] = grad_update
elif x.dim() == 2:
grad_input[arange_1d, masked_safe_target] = grad_update
else:
grad_input_t = grad_input.transpose(channel_dim, -1)
        intermediate_shape = grad_input_t.shape
grad_input_2d = grad_input_t.reshape(-1, x.shape[channel_dim])
grad_input_2d[arange_1d, masked_safe_target] = grad_update
        grad_input = grad_input_2d.view(intermediate_shape).transpose(channel_dim, -1)
if grad_input.dim() > grad_output.dim() > 0:
grad_output = grad_output.unsqueeze(channel_dim)
if weight is not None:
new_shape = [1 for _ in range(x.dim())]
new_shape[channel_dim] = weight.shape[0]
weight = weight.reshape(new_shape)
# In order for fused computation to work, the following line is rewritten.
# grad_output = grad_output * weight
new_shape = list(x.shape)
new_shape[channel_dim] = -1
w = weight.expand(new_shape)
w_target = torch.gather(w, channel_dim, target)
grad_output = grad_output * w_target
grad_output = torch.where(target != ignore_index, grad_output, 0)
# NOTE: Instead of directly returning the grad_input as grad_output for log_softmax,
# here we perform backward computation for log_softmax altogether to avoid the
# otherwise extra all_gather communication.
# return grad_input * grad_output
return (grad_input + torch.exp(x)) * grad_output
def _nll_loss_backward_handler(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> object:
grad_output = cast(DTensor, args[0])
x = cast(DTensor, args[1])
target = args[2]
weight = args[3]
reduction = cast(int, args[4])
ignore_index = cast(int, args[5])
total_weight = cast(Tensor, args[6])
channel_dim = 1 if x.dim() >= 2 else 0
spec = x._spec
mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)
# if target and weight are not DTensors, convert them to DTensors
target_placements = _skip_dim(
replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim
)
all_replicate_placements = (Replicate(),) * spec.mesh.ndim
target = _cast_to_dtensor(target, target_placements, spec.mesh)
if weight is not None:
weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)
# tensor inputs to _propagate_tensor_meta need to be DTensors
args = list(args)
args[2], args[3] = target, weight
args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh)
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
result = _nll_loss_and_log_softmax_backward(
grad_output._local_tensor,
x._local_tensor,
target._local_tensor,
weight._local_tensor if weight is not None else None,
reduction,
ignore_index,
total_weight,
x.shape,
channel_dim,
spec.mesh,
mesh_dim,
)
# the output sharding is the same as input sharding: Shard(channel_dim) on mesh_dim
out_spec = DTensorSpec(
spec.mesh,
spec.placements,
tensor_meta=output_tensor_meta,
)
return DTensor(
result,
out_spec,
requires_grad=result.requires_grad,
)
customized_loss_ops = {
aten._log_softmax.default: _log_softmax_handler,
aten._log_softmax_backward_data.default: _log_softmax_backward_handler,
aten.nll_loss_forward.default: _nll_loss_forward_handler,
aten.nll_loss2d_forward.default: _nll_loss_forward_handler,
aten.nll_loss_backward.default: _nll_loss_backward_handler,
aten.nll_loss2d_backward.default: _nll_loss_backward_handler,
}
def _enable_custom_loss_ops():
DTensor._op_dispatcher._custom_op_handlers.update(customized_loss_ops)
def _disable_custom_loss_ops():
for custom_op in customized_loss_ops:
DTensor._op_dispatcher._custom_op_handlers.pop(custom_op)
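# Dispatch sketch (comments only): under loss_parallel(), a call to
# torch.nn.functional.cross_entropy on a class-dim-sharded DTensor goes through
# aten._log_softmax -> aten.nll_loss_forward in the forward pass and
# aten.nll_loss_backward -> aten._log_softmax_backward_data in the backward
# pass, each redirected to the corresponding handler registered above.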

View File

@ -0,0 +1,627 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,
DTensor,
Replicate,
Shard,
)
from torch.distributed.tensor.placement_types import Placement
__all__ = [
"ParallelStyle",
"RowwiseParallel",
"SequenceParallel",
"ColwiseParallel",
"PrepareModuleInput",
"PrepareModuleOutput",
]
class ParallelStyle(ABC):
"""
The parallel style contract defines how the module or submodule should be parallelized.
It only defines the ``apply`` method for ``parallelize_module`` to use, this allows maximum
flexibility for different kind of style implementations.
"""
@abstractmethod
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
...
class ColwiseParallel(ParallelStyle):
"""
Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding.
Users can compose it together with RowwiseParallel to achieve the sharding of more complicated modules.
    (e.g. MLP, Attention)
Keyword Args:
input_layouts (Placement, optional):
The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to
become a DTensor. If not specified, we assume the input tensor to be replicated.
output_layouts (Placement, optional):
The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
with the user desired layout. If not specified, the output tensor is sharded on the last dimension.
use_local_output (bool, optional):
Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
Returns:
A :class:`ParallelStyle` object that represents Colwise sharding of the nn.Module.
Example::
>>> # xdoctest: +SKIP(failing)
>>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...) # m is a nn.Module that contains a "w1" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor
>>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})
>>> ...
    .. note:: By default the ``ColwiseParallel`` output is sharded on the last dimension if ``output_layouts`` is not
        specified. If there are operators that require a specific tensor shape (e.g. before the paired ``RowwiseParallel``),
        keep in mind that when the output is sharded, the operator might need to be adjusted to the sharded size.
"""
def __init__(
self,
*,
input_layouts: Optional[Placement] = None,
output_layouts: Optional[Placement] = None,
use_local_output: bool = True,
):
super().__init__()
self.input_layouts = (input_layouts or Replicate(),)
self.output_layouts = (output_layouts or Shard(-1),)
# colwise linear runtime sharding (desired sharding):
# 1. requires replicate input
# 2. shard output on last dim
self.desired_input_layouts = (Replicate(),)
self.use_local_output = use_local_output
@staticmethod
def _prepare_input_fn(
input_layouts, desired_input_layouts, mod, inputs, device_mesh
):
# TODO: figure out dynamo support for instance method and switch this to instance method
# annotate module input placements/sharding with input_layouts
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(
input_tensor, device_mesh, input_layouts, run_check=False
)
# transform the input layouts to the desired layouts of ColwiseParallel
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(
placements=desired_input_layouts, async_op=True
)
return input_tensor
def _partition_linear_fn(self, name, module, device_mesh):
        # Colwise shards weight/bias to Shard(0). Since nn.Linear computes
        # input @ weight^T + bias, sharding the weight on dim 0 corresponds to
        # sharding the effective weight^T on dim 1, i.e. column-wise.
for name, param in module.named_parameters():
dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
module.register_parameter(name, dist_param)
def _partition_embedding_fn(self, name, module, device_mesh):
        # Colwise sharding of embedding.weight is straightforward: Shard(1)
for name, param in module.named_parameters():
dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(1)]))
module.register_parameter(name, dist_param)
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# outputs is a shard on last dimension DTensor, i.e. Shard(-1)
if outputs.placements != output_layouts:
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
# back to local tensor
return outputs.to_local() if use_local_output else outputs
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
if isinstance(module, nn.Linear):
partition_fn = self._partition_linear_fn
elif isinstance(module, nn.Embedding):
partition_fn = self._partition_embedding_fn
else:
raise NotImplementedError(
"ColwiseParallel currently only support nn.Linear and nn.Embedding!"
)
return distribute_module(
module,
device_mesh,
partition_fn,
partial(
self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
),
partial(
self._prepare_output_fn, self.output_layouts, self.use_local_output
),
)
class RowwiseParallel(ParallelStyle):
"""
Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding.
Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules.
    (e.g. MLP, Attention)
Keyword Args:
input_layouts (Placement, optional):
The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to
become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension.
output_layouts (Placement, optional):
The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
with the user desired layout. If not specified, the output tensor is replicated.
use_local_output (bool, optional):
Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
Returns:
A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module.
Example::
>>> # xdoctest: +SKIP(failing)
>>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...) # m is a nn.Module that contains a "w2" nn.Linear submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim
>>> # and the output of "w2" will return a replicated :class:`torch.Tensor`.
>>>
        >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()})
>>> ...
"""
def __init__(
self,
*,
input_layouts: Optional[Placement] = None,
output_layouts: Optional[Placement] = None,
use_local_output: bool = True,
):
super().__init__()
self.input_layouts = (input_layouts or Shard(-1),)
self.output_layouts = (output_layouts or Replicate(),)
self.use_local_output = use_local_output
@staticmethod
def _prepare_input_fn(
input_layouts, desired_input_layouts, mod, inputs, device_mesh
):
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(
input_tensor, device_mesh, input_layouts, run_check=False
)
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(
placements=desired_input_layouts, async_op=True
)
return input_tensor
def _partition_linear_fn(self, name, module, device_mesh):
        # Rowwise shards the weight to Shard(1) and the bias to Replicate(). Since
        # nn.Linear computes input @ weight^T + bias, sharding the weight on dim 1
        # corresponds to sharding the effective weight^T on dim 0, i.e. row-wise.
module.register_parameter(
"weight",
nn.Parameter(distribute_tensor(module.weight, device_mesh, [Shard(1)])),
)
if getattr(module, "bias", None) is not None:
# The Linear module has bias
module.register_parameter(
"bias",
nn.Parameter(
distribute_tensor(module.bias, device_mesh, [Replicate()])
),
)
def _partition_embedding_fn(self, name, module, device_mesh):
# rowwise shard embedding.weight is Shard(0)
for name, param in module.named_parameters():
dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
module.register_parameter(name, dist_param)
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# Rowwise sharding produces partial output, depending on output layouts:
# 1. to replicate -> allreduce
# 2. to shard -> reduce_scatter
if outputs.placements != output_layouts:
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
# back to local tensor if use_local_output is True
return outputs.to_local() if use_local_output else outputs
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
if isinstance(module, nn.Linear):
partition_fn = self._partition_linear_fn
# rowwise linear runtime sharding requires input tensor shard on last dim
self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1),)
elif isinstance(module, nn.Embedding):
partition_fn = self._partition_embedding_fn
# rowwise embedding runtime sharding requires input tensor replicated
self.desired_input_layouts = (Replicate(),)
else:
raise NotImplementedError(
"RowwiseParallel currently only support nn.Linear and nn.Embedding!"
)
return distribute_module(
module,
device_mesh,
partition_fn,
partial(
self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
),
partial(
self._prepare_output_fn, self.output_layouts, self.use_local_output
),
)
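# Composition sketch (illustrative; the submodule names "w1" and "w2" are made
# up): a two-layer MLP is typically parallelized by pairing ColwiseParallel with
# RowwiseParallel so that only one all-reduce is needed per forward pass:
#
#   parallelize_module(
#       mlp,
#       tp_mesh,
#       {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
#   )
#
# "w1" replicates its input and emits a Shard(-1) output, which is exactly the
# Shard(-1) input that "w2" expects; "w2" then produces a partial output that is
# all-reduced back to a replicated tensor.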
class SequenceParallel(ParallelStyle):
"""
    SequenceParallel replicates the parameters of a compatible ``nn.Module`` and runs the sharded computation with
    the input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the
    `RMSNorm python implementation <https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34>`__
This style implements the operation that is described in the paper
`Reducing Activation Recomputation in Large Transformer Models <https://arxiv.org/abs/2205.05198>`__
If the input passed in to this ``nn.Module`` is a :class:`torch.Tensor`, it assumes that the input is already sharded
on the sequence dimension and converts the input to a :class:`DTensor` sharded on the sequence dimension. If the input
passed in to this ``nn.Module`` is already a :class:`DTensor` but is not sharded on the sequence dimension, it would
redistribute the input to be sharded on the sequence dimension.
The output of the ``nn.Module`` will be sharded on the sequence dimension.
Keyword Args:
sequence_dim (int, optional):
The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to
become a DTensor that is sharded on the sequence dimension, default: 1.
use_local_output (bool, optional):
Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False.
Returns:
A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``.
Example::
>>> # xdoctest: +SKIP(failing)
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
>>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
>>>
        >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()})
>>> ...
.. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e.
``nn.LayerNorm`` or ``RMSNorm``, and they by default have ones initialization). If you have custom
inits for the weights on those modules, you need to broadcast the weights before/after parallelizing
to ensure that they are replicated.
"""
def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False):
super().__init__()
self.sequence_sharding = (Shard(sequence_dim),)
self.use_local_output = use_local_output
def _replicate_module_fn(
self, name: str, module: nn.Module, device_mesh: DeviceMesh
):
for p_name, param in module.named_parameters():
            # simple replication with the fixed ones_ init from LayerNorm/RMSNorm, which allows
            # us to simply use from_local
replicated_param = torch.nn.Parameter(
DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
)
module.register_parameter(p_name, replicated_param)
@staticmethod
def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh):
input_tensor = inputs[0]
if isinstance(input_tensor, DTensor):
# if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it
if input_tensor.placements != sequence_sharding:
input_tensor = input_tensor.redistribute(
placements=sequence_sharding, async_op=True
)
return input_tensor
elif isinstance(input_tensor, torch.Tensor):
            # assume the input passed in is already sharded on the sequence dim and create the DTensor
return DTensor.from_local(
input_tensor, device_mesh, sequence_sharding, run_check=False
)
else:
raise ValueError(
f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}"
)
@staticmethod
def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
return outputs.to_local() if use_local_output else outputs
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
return distribute_module(
module,
device_mesh,
self._replicate_module_fn,
partial(self._prepare_input_fn, self.sequence_sharding),
partial(self._prepare_output_fn, self.use_local_output),
)
class PrepareModuleInput(ParallelStyle):
"""
Configure the nn.Module's inputs to convert the input tensors of the nn.Module to DTensors at runtime according to
``input_layouts``, and perform layout redistribution according to the ``desired_input_layouts``.
Keyword Args:
input_layouts (Union[Placement, Tuple[Optional[Placement]]]):
The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to
            DTensors. If some inputs are not torch.Tensor or do not need to be converted to DTensors, ``None`` needs
            to be specified as a placeholder. default: None.
desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]):
The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module
            have the desired DTensor layouts. This argument needs to have the same length as ``input_layouts``. default: None.
input_kwarg_layouts (Dict[str, Placement]):
The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors.
default: None
desired_input_kwarg_layouts: (Dict[str, Placement]):
The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module
have the desired DTensor layouts. default: None.
use_local_output (bool, optional):
Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False.
Returns:
A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs.
Example::
>>> # xdoctest: +SKIP(failing)
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor
>>> # and then redistributed to Replicated DTensor.
>>> parallelize_module(
>>> block, # this can be a submodule or module
>>> tp_mesh,
>>> parallelize_plan={
>>> "attn": PrepareModuleInput(
>>> input_layouts=(Shard(0), None, None, ...),
>>> desired_input_layouts=(Replicate(), None, None, ...)
>>> ),
>>> }
>>> )
"""
def __init__(
self,
*,
input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None,
desired_input_layouts: Optional[
Union[Placement, Tuple[Optional[Placement]]]
] = None,
input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
use_local_output: bool = False,
):
self.input_layouts = (
(input_layouts,) if isinstance(input_layouts, Placement) else input_layouts
)
self.desired_input_layouts = (
(desired_input_layouts,)
if isinstance(desired_input_layouts, Placement)
else desired_input_layouts
)
self.use_local_output = use_local_output
if self.input_layouts is not None:
assert (
self.desired_input_layouts is not None
), "desired module inputs should not be None!"
assert len(self.input_layouts) == len(
self.desired_input_layouts
), "input_layouts and desired_input_layouts should have same length!"
self.with_kwargs = input_kwarg_layouts is not None
self.input_kwarg_layouts = input_kwarg_layouts or {}
self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {}
if self.with_kwargs:
assert len(self.input_kwarg_layouts) == len(
self.desired_input_kwarg_layouts
), "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"
def _prepare_input_arg(
self,
input: Any,
mesh: DeviceMesh,
input_layout: Optional[Placement],
desired_layout: Optional[Placement],
):
if input_layout is not None:
if isinstance(input, DTensor):
# TODO: re-enable the check once we fix the compile path
# assert inp.placements[0] == input_layout
dt_inp = input
else:
assert isinstance(
input, torch.Tensor
), "expecting input to be a torch.Tensor!"
dt_inp = DTensor.from_local(
input, mesh, (input_layout,), run_check=False
)
if desired_layout is not None and input_layout != desired_layout:
dt_inp = dt_inp.redistribute(placements=(desired_layout,))
return dt_inp.to_local() if self.use_local_output else dt_inp
else:
return input
def _prepare_input_fn(self, inputs, device_mesh):
if self.input_layouts is None:
return inputs
prepared_inputs = []
if not isinstance(inputs, tuple):
inputs = (inputs,)
if len(inputs) != len(self.input_layouts):
raise ValueError("module inputs and input_layouts should have same length!")
assert (
self.desired_input_layouts is not None
), "desired module inputs should not be None!"
for inp, input_layout, desired_layout in zip(
inputs, self.input_layouts, self.desired_input_layouts
):
prepared_inputs.append(
self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout)
)
return tuple(prepared_inputs)
def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
prepared_kwarg_inputs = {}
for kwarg_key in kwarg_inputs.keys():
kwarg_val = kwarg_inputs[kwarg_key]
input_layout = self.input_kwarg_layouts.get(kwarg_key)
desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key)
prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(
kwarg_val, device_mesh, input_layout, desired_input_layout
)
return (prepared_arg_inputs, prepared_kwarg_inputs)
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
if self.with_kwargs:
module.register_forward_pre_hook(
lambda _, inputs, kwargs: self._prepare_input_kwarg_fn(
inputs, kwargs, device_mesh
),
with_kwargs=True,
) # type: ignore[misc]
else:
module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)) # type: ignore[misc, call-arg]
return module
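# Usage sketch (illustrative; the kwarg name "attention_mask" is made up):
#
#   PrepareModuleInput(
#       input_layouts=(Shard(0),),
#       desired_input_layouts=(Replicate(),),
#       input_kwarg_layouts={"attention_mask": Shard(0)},
#       desired_input_kwarg_layouts={"attention_mask": Replicate()},
#   )
#
# annotates both the first positional input and the "attention_mask" kwarg as
# Shard(0) DTensors and redistributes each to be replicated before the wrapped
# module runs.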
class PrepareModuleOutput(ParallelStyle):
"""
Configure the nn.Module's outputs to convert the output tensors of the nn.Module to DTensors at runtime according to
``output_layouts``, and perform layout redistribution according to the ``desired_output_layouts``.
Keyword Args:
output_layouts (Union[Placement, Tuple[Placement]]):
The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to
            DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or do not need to be
            converted to DTensors, ``None`` needs to be specified as a placeholder.
desired_output_layouts (Union[Placement, Tuple[Placement]]):
The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module
have the desired DTensor layouts.
use_local_output (bool, optional):
Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True.
Returns:
A ParallelStyle object that prepares the sharding layouts of the nn.Module's outputs.
Example::
>>> # xdoctest: +SKIP(failing)
>>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor
>>> # and then redistributed to Sharded DTensor.
>>> parallelize_module(
>>> block, # this can be a submodule or module
>>> tp_mesh,
>>> parallelize_plan = PrepareModuleOutput(
>>> output_layouts=Replicate(),
>>> desired_output_layouts=Shard(0)
>>> )
>>> )
"""
def __init__(
self,
*,
output_layouts: Union[Placement, Tuple[Placement]],
desired_output_layouts: Union[Placement, Tuple[Placement]],
use_local_output: bool = True,
):
self.output_layouts = (
(output_layouts,)
if isinstance(output_layouts, Placement)
else output_layouts
)
self.desired_output_layouts = (
(desired_output_layouts,)
if isinstance(desired_output_layouts, Placement)
else desired_output_layouts
)
self.use_local_output = use_local_output
assert len(self.output_layouts) == len(
self.desired_output_layouts
), "output_layouts and desired_output_layouts should have same length!"
def _prepare_out_fn(self, outputs, device_mesh):
prepared_outputs = []
if not isinstance(outputs, tuple):
outputs = (outputs,)
if len(outputs) != len(self.output_layouts):
raise ValueError(
"module outputs and output_layouts should have same length!"
)
for out, out_layout, desired_out_layout in zip(
outputs, self.output_layouts, self.desired_output_layouts
):
if out_layout is not None:
if isinstance(out, DTensor):
# TODO: re-enable the check once we fix the compile path
# assert out.placements[0] == out_layout
dt_out = out
else:
dt_out = DTensor.from_local(
out, device_mesh, (out_layout,), run_check=False
)
if out_layout != desired_out_layout:
dt_out = dt_out.redistribute(placements=(desired_out_layout,))
prepared_outputs.append(
dt_out.to_local() if self.use_local_output else dt_out
)
else:
prepared_outputs.append(out)
if len(prepared_outputs) == 1:
return prepared_outputs[0]
else:
return tuple(prepared_outputs)
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
module.register_forward_hook(lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)) # type: ignore[misc, call-arg]
return module

View File

@ -0,0 +1,652 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from dataclasses import dataclass
from typing import cast, List, Optional, Tuple
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._collective_utils import (
fill_empty_tensor_to_shards,
mesh_broadcast,
mesh_scatter,
pad_tensor,
shard_dim_alltoall,
unpad_tensor,
)
__all__ = ["Placement", "Shard", "Replicate", "Partial"]
class Placement:
"""
The base class for the Placement type, where it describes how a DTensor is placed onto the
``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together could describe the DTensor Layout.
It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``,
and ``Partial``.
    This class is not meant to be used directly; it mainly serves as a typing stub.
"""
# convenient utils to check for placement types
def is_shard(self, dim: Optional[int] = None) -> bool:
is_shard_instance = isinstance(self, Shard)
if dim is not None and is_shard_instance:
return cast(Shard, self).dim == dim
else:
return is_shard_instance
def is_replicate(self) -> bool:
return isinstance(self, Replicate)
def is_partial(self) -> bool:
return isinstance(self, Partial)
@dataclass(frozen=True)
class Shard(Placement):
"""
The ``Shard(dim)`` placement describes the DTensor sharding on tensor dimension
``dim`` over a corresponding ``DeviceMesh`` dimension, where each rank on the
DeviceMesh dimension only holds a shard/piece of the global Tensor. The
``Shard(dim)`` placement follows the ``torch.chunk(dim)`` semantic, where the
last few shards on the DeviceMesh dimension might be empty when the tensor dimension
    is not evenly divisible on the DeviceMesh dimension. The ``Shard`` placement can be
used by all DTensor APIs (i.e. distribute_tensor, from_local, etc.)
Args:
        dim (int): The tensor dimension over which the DTensor is sharded on its
            corresponding DeviceMesh dimension.
.. warning:: sharding on a tensor dimension where the tensor dimension size is not
evenly divisible on a DeviceMesh dimension is currently experimental and subject to change.
"""
dim: int
def _split_tensor(
self,
tensor: torch.Tensor,
num_chunks: int,
*,
with_padding: bool = True,
contiguous: bool = True,
) -> Tuple[List[torch.Tensor], List[int]]:
"""
This function uses torch.chunk to split a tensor into num_chunks shards along
the Shard placement dimension, and return a list of shards with their pad sizes.
Keyword args:
with_padding (bool, optional): when True, we pad the tensor on the last
few ranks before calling the collectives (i.e. scatter/all_gather, etc.).
This is because collectives usually require equal size tensor inputs
"""
assert (
self.dim <= tensor.ndim
), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
# chunk tensor over dimension `dim` into n slices
tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim))
num_empty_tensors = num_chunks - len(tensor_list)
        # if no padding is needed or the tensor dim size is already evenly sharded,
        # we can return early.
if not with_padding or tensor.size(self.dim) % num_chunks == 0:
if contiguous:
tensor_list = [t.contiguous() for t in tensor_list]
return (
fill_empty_tensor_to_shards(tensor_list, self.dim, num_empty_tensors),
[],
)
# compute the chunk size inline with ``torch.chunk`` to calculate padding
full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks
# Compute chunk size for each chunk for ``self.dim``
chunk_sizes = [
tensor_list[idx].size(self.dim) if idx < len(tensor_list) else 0
for idx in range(num_chunks)
]
# Compute pad size on each chunk
pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes]
# Reuse tensor to fill empty chunk with empty tensor
tensor_list = fill_empty_tensor_to_shards(
tensor_list, self.dim, num_empty_tensors
)
shard_list = []
for shard, pad_size in zip(tensor_list, pad_sizes):
            # Pad the shard with zeros if needed.
if with_padding and pad_size > 0:
shard = pad_tensor(shard, self.dim, pad_size)
shard = shard.contiguous() if contiguous else shard
shard_list.append(shard)
return shard_list, pad_sizes
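    # Worked example (comments only): splitting a tensor of size 10 on
    # ``self.dim`` into num_chunks=4 gives torch.chunk sizes [3, 3, 3, 1], a
    # full_chunk_size of 3, and therefore pad_sizes [0, 0, 0, 2]; with
    # with_padding=True every returned shard has size 3 on ``self.dim``.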
@staticmethod
def _local_shard_size_on_dim(
size_on_dim: int,
num_chunks: int,
rank: int,
return_offset: bool = False,
) -> Tuple[int, int]:
"""
returns the local shard size and offset on a given tensor dim
"""
# Compute the chunk size inline with ``torch.chunk``
if size_on_dim % num_chunks == 0:
full_chunk_size = size_on_dim // num_chunks
return full_chunk_size, full_chunk_size * rank if return_offset else -1
# uneven sharding case
full_chunk_size = (size_on_dim + num_chunks - 1) // num_chunks
shard_starting_idx = full_chunk_size * rank
if size_on_dim < shard_starting_idx:
return 0, size_on_dim if return_offset else -1
else:
local_shard_size = (
min(size_on_dim, shard_starting_idx + full_chunk_size)
- shard_starting_idx
)
return local_shard_size, shard_starting_idx if return_offset else -1
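    # Worked example (comments only): size_on_dim=10 and num_chunks=4 give
    # full_chunk_size=3 and per-rank (size, offset) results of (3, 0), (3, 3),
    # (3, 6) and (1, 9) when return_offset=True; a rank whose chunk starts past
    # the end of the dimension gets (0, size_on_dim).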
def _shard_tensor(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
"""
shard and scatter a tensor on a mesh dimension (use coordinate
0 on the mesh dimension as source of truth)
"""
my_coordinate = mesh.get_coordinate()
num_chunks = mesh.size(mesh_dim=mesh_dim)
if my_coordinate is None:
# if rank is not part of mesh, we simply return an empty tensor
return tensor.new_empty(0, requires_grad=tensor.requires_grad)
scatter_list, pad_sizes = self._split_tensor(
tensor, num_chunks, with_padding=True, contiguous=True
)
mesh_dim_local_rank = my_coordinate[mesh_dim]
output = torch.empty_like(scatter_list[mesh_dim_local_rank])
mesh_scatter(output, scatter_list, mesh, mesh_dim=mesh_dim)
# Only unpad if the local_tensor was padded on the dimension.
if pad_sizes and pad_sizes[mesh_dim_local_rank] > 0:
output = unpad_tensor(output, self.dim, pad_sizes[mesh_dim_local_rank])
return output
def _reduce_shard_tensor(
self,
tensor: torch.Tensor,
mesh: DeviceMesh,
reduce_op: str,
mesh_dim: int,
) -> torch.Tensor:
"""
reduce and scatter a tensor on a mesh dimension
"""
my_coordinate = mesh.get_coordinate()
num_chunks = mesh.size(mesh_dim=mesh_dim)
if my_coordinate is None:
# if rank is not part of mesh, we simply return local_tensor,
# which should be an empty tensor
return tensor
is_padded = tensor.size(self.dim) % num_chunks != 0
if is_padded:
scattered_list, pad_sizes = self._split_tensor(
tensor, num_chunks, with_padding=True, contiguous=True
)
tensor = torch.cat(scattered_list, dim=self.dim)
elif not tensor.is_contiguous():
tensor = tensor.contiguous()
output = funcol.reduce_scatter_tensor(
tensor, reduce_op, scatter_dim=self.dim, group=(mesh, mesh_dim)
)
if is_padded:
output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]]) # type: ignore[possibly-undefined]
return output
def _to_replicate_tensor(
self,
local_tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
current_logical_shape: List[int],
) -> torch.Tensor:
"""
        This function all_gathers all shards and returns a tensor that
        is replicated on the previously sharded mesh dimension
"""
num_chunks = mesh.size(mesh_dim=mesh_dim)
# check if it's uneven, so we need to pad input tensor before all_gather
local_shape = list(local_tensor.size())
logical_dim_size = current_logical_shape[self.dim]
is_padded = logical_dim_size % num_chunks != 0
if is_padded:
full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks
pad_size = full_chunk_size - local_shape[self.dim]
local_tensor = pad_tensor(local_tensor, self.dim, pad_size)
if not local_tensor.is_contiguous():
local_tensor = local_tensor.contiguous()
result = funcol.all_gather_tensor(
local_tensor,
gather_dim=self.dim,
group=(mesh, mesh_dim),
)
if is_padded:
unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined]
result = unpad_tensor(result, self.dim, unpad_size)
return result
def _replicate_to_shard(
self,
local_tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
shard_index: int,
) -> torch.Tensor:
"""
        transform a replicated tensor into a sharded tensor on
        the current rank, which performs a local chunk
"""
num_chunks = mesh.size(mesh_dim=mesh_dim)
shards, _ = self._split_tensor(
local_tensor,
num_chunks,
with_padding=False,
contiguous=False,
)
return shards[shard_index].clone()
def _to_new_shard_dim(
self,
local_tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
current_logical_shape: List[int],
new_shard_dim: int,
) -> torch.Tensor:
"""
        transform an existing sharded tensor into a tensor sharded on
        a new dimension, which performs an alltoall
"""
my_coordinate = mesh.get_coordinate()
if my_coordinate is None:
# if rank is not part of mesh, we simply return local_tensor,
# which should be an empty tensor
return local_tensor
num_chunks = mesh.size(mesh_dim=mesh_dim)
old_dim_logical_size = current_logical_shape[self.dim]
new_dim_logical_size = current_logical_shape[new_shard_dim]
old_dim_padding = old_dim_logical_size % num_chunks != 0
new_dim_padding = new_dim_logical_size % num_chunks != 0
if old_dim_padding:
old_dim_full_chunk_size = (
old_dim_logical_size + num_chunks - 1
) // num_chunks
old_dim_pad_size = old_dim_full_chunk_size - local_tensor.size(self.dim)
local_tensor = pad_tensor(local_tensor, self.dim, old_dim_pad_size)
if new_dim_padding:
new_dim_full_chunk_size = (
new_dim_logical_size + num_chunks - 1
) // num_chunks
new_dim_pad_size = new_dim_full_chunk_size * num_chunks - local_tensor.size(
new_shard_dim
)
local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size)
if not local_tensor.is_contiguous():
local_tensor = local_tensor.contiguous()
new_tensor = shard_dim_alltoall(
local_tensor, self.dim, new_shard_dim, mesh, mesh_dim
)
if old_dim_padding:
old_dim_unpad_size = (
old_dim_full_chunk_size * num_chunks - current_logical_shape[self.dim] # type: ignore[possibly-undefined]
)
new_tensor = unpad_tensor(new_tensor, self.dim, old_dim_unpad_size) # type: ignore[possibly-undefined]
if new_dim_padding:
local_shard_size_on_new_dim = self._local_shard_size_on_dim(
new_dim_logical_size, num_chunks, my_coordinate[mesh_dim]
)[0]
new_dim_unpad_size = new_dim_full_chunk_size - local_shard_size_on_new_dim # type: ignore[possibly-undefined]
new_tensor = unpad_tensor(new_tensor, new_shard_dim, new_dim_unpad_size) # type: ignore[possibly-undefined]
return new_tensor
def __eq__(self, other: object) -> bool:
if not isinstance(other, Shard):
return False
return self.dim == other.dim
def __hash__(self) -> int:
return hash(self.dim)
def __repr__(self) -> str:
"""
machine readable representation of the Shard placement
"""
return f"Shard(dim={self.dim})"
def __str__(self) -> str:
"""human readable representation of the Shard placement"""
return f"S({self.dim})"
# kw_only is only available in python >= 3.10
kw_only_dataclass = dict(kw_only=True) if "kw_only" in dataclass.__kwdefaults__ else {}
@dataclass(frozen=True, **kw_only_dataclass)
class _StridedShard(Shard):
"""
_StridedShard is only introduced to support 2D FSDP2 + TP sharding where the tensor
is sharded on the TP mesh dimension first, then sharded on the FSDP mesh dimension.
We call this right-to-left sharding which is the opposite of the default
left-to-right sharding. See the example below:
tensor shape: [8, 8]
mesh: [[0, 1], [2, 3]], names=("dp", "tp")
placements: [Shard(0), Shard(0)]
The default sharding behavior shards the tensor on "dp" mesh dimension first then
"tp" dimension. The sharding result will be:
Rank | Mesh Coordinate | Shard Index
------------------------------------------------
0 | (0, 0) | 0 (row 0-1)
1 | (0, 1) | 1 (row 2-3)
2 | (1, 0) | 2 (row 4-5)
3 | (1, 1) | 3 (row 6-7)
While the FSDP2 + TP sharding behavior does the opposite: it shards the tensor on
"tp" mesh dim first then "dp" dim. This right-to-left sharding will produce the
result:
Rank | Mesh Coordinate | Shard Index
------------------------------------------------
0 | (0, 0) | 0 (row 0-1)
1 | (0, 1) | 2 (row 4-5)
2 | (1, 0) | 1 (row 2-3)
3 | (1, 1) | 3 (row 6-7)
    The consequence is that any attempt to redistribute this DTensor to a full
    replica will produce an incorrect result, because the shard-to-replicate
    redistribution always happens right-to-left, regardless of whether the
    sharding was done left-to-right or right-to-left. To address
this, we use _StridedShard placement to make this right-to-left sharding compatible
with our left-to-right convention on both tensor distribution and redistribution.
Now with _StridedShard, the right-to-left sharding above can be represented as:
tensor shape: [8, 8]
mesh: [[0, 1], [2, 3]], names=("dp", "tp")
placements: [_StridedShard(0, split_factor=2), Shard(0)]
And a left-to-right processing of `placements` will produce the same result, which is
different from using the `Shard` placement:
Rank | Mesh Coordinate | Shard Index
------------------------------------------------
0 | (0, 0) | 0 (row 0-1)
1 | (0, 1) | 2 (row 4-5)
2 | (1, 0) | 1 (row 2-3)
3 | (1, 1) | 3 (row 6-7)
The argument `split_factor` is the number of existing shards over the tensor sharding
dimension before processing the _StridedShard placement, as if the sharding happened
right-to-left. In the example above, the tensor should first be sharded on the "tp"
dimension into 2 shards before being sharded on the "dp" dimension. Therefore, the
`split_factor` of the _StridedShard placement on "dp" dim is 2.
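    As a local-only illustration (plain tensor ops, no DeviceMesh or process group
    involved; not the DTensor API itself), the right-to-left split of the 8-row
    example above can be reproduced like this:

    Example::

        >>> import torch
        >>> tensor = torch.arange(8).unsqueeze(1).expand(8, 8)  # row i is filled with i
        >>> # right-to-left: split on "tp" first, then split each piece on "dp"
        >>> tp_pieces = torch.chunk(tensor, 2, dim=0)
        >>> dp_pieces = [torch.chunk(p, 2, dim=0) for p in tp_pieces]
        >>> dp_pieces[0][0][:, 0]  # rank 0, mesh coordinate (dp=0, tp=0)
        tensor([0, 1])
        >>> dp_pieces[1][0][:, 0]  # rank 1, mesh coordinate (dp=0, tp=1)
        tensor([4, 5])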
TODO: strided sharding needs to work fine with uneven sharding. Now it forbids
resharding if the tensor is unevenly sharded.
TODO: we should remove _StridedShard placement once we can unify it with Shard
"""
split_factor: int
def __eq__(self, other: object) -> bool:
if isinstance(other, _StridedShard):
return self.dim == other.dim and self.split_factor == other.split_factor
elif isinstance(other, Shard):
            # TODO: this is to avoid an extra all-gather in dtensor op dispatch
            # note that sharding propagation would not produce _StridedShard, and a
            # placement inequality here would introduce an all-gather for resharding
return self.dim == other.dim
return False
def __hash__(self) -> int:
return hash((self.dim, self.split_factor))
def __repr__(self) -> str:
"""
machine readable representation of the _StridedShard placement
"""
return f"_StridedShard(dim={self.dim}, sf={self.split_factor})"
def __str__(self) -> str:
"""human readable representation of the _StridedShard placement"""
return f"_S({self.dim}, {self.split_factor})"
def _split_tensor(
self,
tensor: torch.Tensor,
num_chunks: int,
*,
with_padding: bool = True,
contiguous: bool = True,
) -> Tuple[List[torch.Tensor], List[int]]:
"""
TODO: currently _StridedShard does not support padding
"""
assert (
self.dim <= tensor.ndim
), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}"
total_split = num_chunks * self.split_factor
assert tensor.size(self.dim) % total_split == 0, (
"_StridedShard currently only allows even sharding but got tensor size"
f" {tensor.size(self.dim)} on dim {self.dim} and total split"
f" {total_split}={num_chunks} * {self.split_factor}"
)
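        # chunk the tensor into ``total_split`` equally-sized pieces, then build each
        # of the ``num_chunks`` output shards by concatenating every ``num_chunks``-th
        # piece starting at index ``i`` (i.e. shard ``i`` gets pieces i, i + num_chunks,
        # i + 2 * num_chunks, ...), realizing the right-to-left layout described in
        # the class docstring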
group_size = self.split_factor
total_split_tensor_list = list(torch.chunk(tensor, total_split, dim=self.dim))
tensor_list = [
torch.cat(
[
total_split_tensor_list[i + j * num_chunks] # stride is num_chunks
for j in range(group_size)
],
dim=self.dim,
)
for i in range(num_chunks)
]
if contiguous:
tensor_list = [t.contiguous() for t in tensor_list]
return tensor_list, []
def _to_replicate_tensor(
self,
local_tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
current_logical_shape: List[int],
) -> torch.Tensor:
"""
Note: currently _StridedShard does not support padding
"""
num_chunks = mesh.size(mesh_dim=mesh_dim)
total_split = num_chunks * self.split_factor
# NOTE: we require Strided Sharding to be even for now
assert current_logical_shape[self.dim] % total_split == 0, (
"_StridedShard requires even sharding but got tensor size "
f"{current_logical_shape[self.dim]} on dim {self.dim} and "
f"total split {total_split}=num_chunks {num_chunks} "
f"* split_factor {self.split_factor}"
)
result = funcol.all_gather_tensor(
local_tensor,
gather_dim=self.dim,
group=(mesh, mesh_dim),
)
if isinstance(result, funcol.AsyncCollectiveTensor):
result = result.wait()
tensor_shard_list = torch.chunk(result, total_split, dim=self.dim)
# rearrange the order
new_tensor_shard_list = []
for idx in range(len(tensor_shard_list)):
# the shard split of index `idx` is assigned a new index within
# _StridedShard._split_tensor:
# the original tensor was split into `total_split` chunks,
# all chunks with the same `idx % num_chunks` are merged into one
# new shard and placed on mesh's local rank `idx % num_chunks`
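            # e.g. with num_chunks=2 and split_factor=2 the gathered chunks appear in
            # original-piece order [0, 2, 1, 3]; the mapping below (0->0, 1->2, 2->1,
            # 3->3) restores them to [0, 1, 2, 3] before the final concatenation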
idx_after_split = idx % num_chunks * self.split_factor + idx // num_chunks
new_tensor_shard_list.append(tensor_shard_list[idx_after_split])
return torch.cat(new_tensor_shard_list, dim=self.dim).contiguous()
@dataclass(frozen=True)
class Replicate(Placement):
"""
The ``Replicate()`` placement describes the DTensor replicating on a corresponding
``DeviceMesh`` dimension, where each rank on the DeviceMesh dimension holds a
replica of the global Tensor. The ``Replicate`` placement can be used by all
DTensor APIs (i.e. ``distribute_tensor``, ``DTensor.from_local``, etc.)
"""
def __eq__(self, other: object) -> bool:
return isinstance(other, Replicate)
def __hash__(self) -> int:
# every replicate placement is the same
return -1
def __repr__(self) -> str:
"""
machine readable representation of the Replicate placement
"""
return "Replicate()"
def __str__(self) -> str:
"""
human readable representation of the Replicate placement
"""
return "R"
def _replicate_tensor(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
"""
Replicate (broadcast) a torch.Tensor on a mesh dimension (use
the first coordinate on the mesh dimension as source of truth)
"""
my_coordinate = mesh.get_coordinate()
if my_coordinate is None:
# if rank is not part of mesh, we simply return an empty tensor
return tensor.new_empty(0, requires_grad=tensor.requires_grad)
tensor = tensor.contiguous()
mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim)
return tensor
@dataclass(frozen=True)
class Partial(Placement):
"""
The ``Partial(reduce_op)`` placement describes the DTensor that is pending
reduction on a specified ``DeviceMesh`` dimension, where each rank on the
DeviceMesh dimension holds the partial value of the global Tensor. User can
redistribute the ``Partial`` DTensor to a ``Replicate`` or ``Shard(dim)``
placement on the specified ``DeviceMesh`` dimension using ``redistribute``,
which would trigger necessary communication operations under the hood (i.e.
``allreduce``, ``reduce_scatter``).
Args:
reduce_op (str, optional): The reduction op to be used for the partial DTensor
to produce Replicated/Sharded DTensor. Only element-wise reduction operations
are supported, including: "sum", "avg", "product", "max", "min", default: "sum".
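    As a local sketch of the semantic only (plain tensors, no process group and no
    ``redistribute`` call), the global value of a ``Partial("sum")`` DTensor is the
    element-wise sum of the per-rank partial values:

    Example::

        >>> import torch
        >>> partial_values = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
        >>> torch.stack(partial_values).sum(dim=0)  # what an allreduce would produce
        tensor([4., 6.])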
.. note:: The ``Partial`` placement can be generated as a result of the DTensor operators,
and can only be used by the ``DTensor.from_local`` API.
"""
reduce_op: str = "sum"
def _reduce_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
# Partial placement contract #1:
# _reduce_value: reduce the value of the tensor on the mesh dimension
return funcol.all_reduce(
tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim)
)
def _reduce_shard_value(
self,
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
shard_spec: Placement,
) -> torch.Tensor:
# Partial placement contract #2:
# _reduce_shard_value: reduce_scatter the value of the tensor over the mesh dimension
shard_spec = cast(Shard, shard_spec)
return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)
def _partition_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
# Partial placement contract #3:
# _partition_value: partition the value of a replicated tensor on the mesh dimension
# _partition_value is the conjugate operation of _reduce_value
        # - i.e. _partition_value on a sum reduce op is just a division operation
# - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation
# TODO: if the reduce_op is min/max, etc. the _partition_value should be a
# different operation
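        # e.g. on a mesh dim of size 4, a replicated value of 12.0 becomes a partial
        # value of 3.0 on every rank, and summing the 4 partial values (allreduce)
        # recovers the original 12.0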
assert self.reduce_op == "sum", "only support replicate to PartialSUM for now!"
num_chunks = mesh.size(mesh_dim=mesh_dim)
return tensor / num_chunks
def __eq__(self, other: object) -> bool:
if not isinstance(other, Partial):
return False
return self.reduce_op == other.reduce_op
def __hash__(self) -> int:
return 1 + hash(self.reduce_op)
def __repr__(self) -> str:
"""
machine readable representation of the Partial placement
"""
return f"Partial({self.reduce_op})"
def __str__(self) -> str:
"""
human readable representation of the Partial placement
"""
return "P"
# We keep the old _Partial name for a while for BC reasons
_Partial = Partial