I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

torch/distributed/optim/__init__.py
@@ -0,0 +1,53 @@
"""
:mod:`torch.distributed.optim` exposes DistributedOptimizer, which takes a list
of remote parameters (:class:`~torch.distributed.rpc.RRef`) and runs the
optimizer locally on the workers where the parameters live. The distributed
optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to
apply the gradients on each worker.
"""
import warnings
import torch
from torch import optim
from .apply_optimizer_in_backward import (
_apply_optimizer_in_backward,
_get_in_backward_optimizers,
)
from .functional_adadelta import _FunctionalAdadelta
from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamax import _FunctionalAdamax
from .functional_adamw import _FunctionalAdamW
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_sgd import _FunctionalSGD
from .named_optimizer import _NamedOptimizer
from .utils import as_functional_optim
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"`TorchScript` support for functional optimizers is deprecated "
"and will be removed in a future PyTorch release. "
"Consider using the `torch.compile` optimizer instead.",
DeprecationWarning,
stacklevel=2,
)
# DistributedOptimizer imports torch.distributed.rpc names, so gate availability
# based on RPC being available.
if hasattr(torch._C, "_rpc_init"):
from .optimizer import DistributedOptimizer
from .post_localSGD_optimizer import PostLocalSGDOptimizer
from .zero_redundancy_optimizer import ZeroRedundancyOptimizer
__all__ = [
"as_functional_optim",
"DistributedOptimizer",
"PostLocalSGDOptimizer",
"ZeroRedundancyOptimizer",
]
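# Illustrative check (sketch only, not part of the module itself): on builds where
# RPC is not available, the three RPC-based optimizers guarded above are simply
# absent from this module.
#
#   >>> import torch.distributed.optim as dist_optim
#   >>> hasattr(dist_optim, "DistributedOptimizer")  # False when torch._C lacks _rpc_init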

torch/distributed/optim/apply_optimizer_in_backward.py
@@ -0,0 +1,120 @@
from typing import Any, Dict, Iterable, List, no_type_check, Type
import torch
__all__: List[str] = []
# WeakTensorKeyDictionary to store relevant metadata for the Tensor/Parameter
# without changing its lifetime.
# NOTE: An alternative is to add the metadata as an attribute of the tensor,
# but that would serialize the metadata whenever the Tensor is serialized.
param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary()
param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary()
@no_type_check
def _apply_optimizer_in_backward(
optimizer_class: Type[torch.optim.Optimizer],
params: Iterable[torch.nn.Parameter],
optimizer_kwargs: Dict[str, Any],
register_hook: bool = True,
) -> None:
"""
Upon ``backward()``, the optimizer specified for each parameter will fire after
the gradient has been accumulated into the parameter.
Note - gradients for these parameters will be set to None after ``backward()``.
This means that any other optimizer that was not registered via
``_apply_optimizer_in_backward`` for this parameter will effectively be a no-op,
since the parameter's gradient is ``None``.
Args:
optimizer_class: (Type[torch.optim.Optimizer]): Optimizer to apply to parameter
params: (Iterator[nn.Parameter]): parameters to apply optimizer state to
optimizer_kwargs: (Dict[str, Any]): kwargs to pass to optimizer constructor
register_hook: (bool): whether to register a hook that runs the optimizer
after gradient for this parameter is accumulated. This is the default
way that optimizer in backward is implemented, but specific use cases
(such as DDP) may wish to override this to implement custom behavior.
(Default = True)
Example::
params_generator = model.parameters()
param_1 = next(params_generator)
remainder_params = list(params_generator)
_apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": .02})
_apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": .04})
model(...).sum().backward() # after backward, parameters will already
# have their registered optimizer(s) applied.
"""
torch._C._log_api_usage_once("torch.distributed.optim.apply_optimizer_in_backward")
@no_type_check
def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None:
# view_as creates a node in the autograd graph that gives us access to the
# parameter's AccumulateGrad autograd function object. We register a
# hook on this object to fire the optimizer when the gradient for
# this parameter is ready (i.e. has been accumulated into the .grad field).
# Don't create a new acc_grad if we already have one
# i.e. for shared parameters or attaching multiple optimizers to a param.
if param not in param_to_acc_grad_map:
param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[
0
][0]
optimizer = optimizer_class([param], **optimizer_kwargs)
if not hasattr(param, "_in_backward_optimizers"):
param._in_backward_optimizers = [] # type: ignore[attr-defined]
# TODO: Remove these attributes once we have a better way of accessing
# optimizer classes and kwargs for a parameter.
param._optimizer_classes = [] # type: ignore[attr-defined]
param._optimizer_kwargs = [] # type: ignore[attr-defined]
param._in_backward_optimizers.append(optimizer) # type: ignore[attr-defined]
param._optimizer_classes.append(optimizer_class) # type: ignore[attr-defined]
param._optimizer_kwargs.append(optimizer_kwargs) # type: ignore[attr-defined]
if not register_hook:
return
def optimizer_hook(*_unused) -> None:
for opt in param._in_backward_optimizers: # type: ignore[attr-defined]
opt.step()
param.grad = None
handle = param_to_acc_grad_map[param].register_hook(optimizer_hook) # type: ignore[attr-defined]
if param not in param_to_optim_hook_handle_map:
param_to_optim_hook_handle_map[param] = []
param_to_optim_hook_handle_map[param].append(handle)
for param in params:
_apply_optimizer_in_backward_to_param(param)
def _get_in_backward_optimizers(module: torch.nn.Module) -> List[torch.optim.Optimizer]:
"""
Return a list of in-backward optimizers applied to ``module``'s parameters. Note that these
optimizers are not intended to have their ``step`` or ``zero_grad`` methods called directly
by the user; they are intended to be used for things like checkpointing.
Args:
module: (torch.nn.Module): model to retrieve in-backward optimizers for
Returns:
List[torch.optim.Optimizer]: the in-backward optimizers.
Example::
_apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {'lr': 0.01})
optims = _get_in_backward_optimizers(model)
"""
optims: List[torch.optim.Optimizer] = []
for param in module.parameters():
optims.extend(getattr(param, "_in_backward_optimizers", []))
return optims
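# Illustrative sketch of the AccumulateGrad trick used above (``model`` and
# ``inputs`` are hypothetical):
#
#   >>> p = next(model.parameters())
#   >>> acc_grad = p.view_as(p).grad_fn.next_functions[0][0]  # AccumulateGrad node for p
#   >>> handle = acc_grad.register_hook(lambda *_: print("grad accumulated for p"))
#   >>> model(inputs).sum().backward()  # the hook fires once p's gradient is accumulated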

torch/distributed/optim/functional_adadelta.py
@@ -0,0 +1,107 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional
import torch
import torch.optim._functional as F
from torch import Tensor
__all__: List[str] = []
# Define a TorchScript compatible Functional Adadelta Optimizer
# that we use in a functional way.
# Instead of using `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer to pass gradients to
# the `step` function. In this way, we can separate the gradients
# and parameters and allow a multithreaded trainer to update the
# parameters without data races from accumulating into the same .grad.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
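# Illustrative usage sketch (hypothetical tensors): gradients are passed to
# ``step`` explicitly, aligned with the parameter list, and ``param.grad`` is
# never read.
#
#   >>> w = torch.randn(2, 2)
#   >>> opt = _FunctionalAdadelta([w], lr=1.0)
#   >>> opt.step([torch.ones_like(w)])  # per-param state ("square_avg", "acc_delta") is created lazily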
@torch.jit.script
class _FunctionalAdadelta:
def __init__(
self,
params: List[Tensor],
lr: float = 1.0,
rho: float = 0.9,
eps: float = 1e-6,
weight_decay: float = 0.0,
foreach: bool = False,
maximize: bool = False,
_allow_empty_param_list: bool = False,
):
self.defaults = {
"lr": lr,
"rho": rho,
"eps": eps,
"weight_decay": weight_decay,
}
self.foreach = foreach
self.maximize = maximize
if len(params) == 0 and not _allow_empty_param_list:
raise ValueError("optimizer got an empty parameter list")
# NOTE: we only have one param_group and don't allow the user to add additional
# param groups as it's not a common use case.
self.param_group = {"params": params}
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
def step(self, gradients: List[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []
square_avgs = []
acc_deltas = []
state_steps = []
lr = self.defaults["lr"]
rho = self.defaults["rho"]
eps = self.defaults["eps"]
weight_decay = self.defaults["weight_decay"]
if len(params) != len(gradients):
raise ValueError(
"the gradients passed in does not equal to the size of the parameters!"
+ f"Params length: {len(params)}. "
+ f"Gradients length: {len(gradients)}"
)
has_complex = False
for param, gradient in zip(params, gradients):
if gradient is not None:
has_complex |= torch.is_complex(param)
params_with_grad.append(param)
grads.append(gradient)
# Lazy state initialization
if param not in self.state:
self.state[param] = {}
state = self.state[param]
state["step"] = torch.tensor(0.0)
state["square_avg"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
state["acc_delta"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
state = self.state[param]
square_avgs.append(state["square_avg"])
acc_deltas.append(state["acc_delta"])
state_steps.append(state["step"])
with torch.no_grad():
F.adadelta(
params_with_grad,
grads,
square_avgs,
acc_deltas,
state_steps,
lr=lr,
rho=rho,
eps=eps,
weight_decay=weight_decay,
foreach=self.foreach,
maximize=self.maximize,
has_complex=has_complex,
)

torch/distributed/optim/functional_adagrad.py
@@ -0,0 +1,111 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional
import torch
import torch.optim._functional as F
from torch import Tensor
__all__: List[str] = []
# Define a TorchScript compatible Functional Adagrad Optimizer
# that we use in a functional way.
# Instead of using `param.grad` when updating parameters,
# we explicitly let the user pass gradients to the `step` function.
# This is so that we can separate the gradients and parameters
# and allow a multithreaded trainer to update the parameters
# without data races from accumulating into the same .grad.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
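# Illustrative usage sketch (hypothetical tensor): unlike most of the other
# functional optimizers in this package, the per-parameter state ("sum", "step")
# is created eagerly in ``__init__`` below, and sparse gradients are supported.
#
#   >>> w = torch.randn(4, 2)
#   >>> opt = _FunctionalAdagrad([w], lr=0.1)
#   >>> opt.step([torch.ones_like(w)])  # gradient list aligned with the param list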
@torch.jit.script
class _FunctionalAdagrad:
def __init__(
self,
params: List[Tensor],
lr: float = 1e-2,
lr_decay: float = 0.0,
weight_decay: float = 0.0,
initial_accumulator_value: float = 0.0,
warmup_lr_multiplier: float = 1.0,
warmup_num_iters: float = 0.0,
eps: float = 1e-10,
coalesce_grad: bool = True,
foreach: bool = False,
fused: bool = False,
maximize: bool = False,
_allow_empty_param_list: bool = False,
):
self.defaults = {
"lr": lr,
"lr_decay": lr_decay,
"eps": eps,
"weight_decay": weight_decay,
"initial_accumulator_value": initial_accumulator_value,
"warmup_lr_multiplier": warmup_lr_multiplier,
"warmup_num_iters": warmup_num_iters,
}
self.coalesce_grad = coalesce_grad
self.foreach = foreach
self.fused = fused
self.maximize = maximize
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
if len(params) == 0 and not _allow_empty_param_list:
raise ValueError("optimizer got an empty parameter list")
# NOTE: we only have one param_group and don't allow the user to add additional
# param groups as it's not a common use case.
self.param_group = {"params": params}
# TODO: no union or any types in TorchScript, make step a scalar tensor instead.
# This is also needed if we want to share_memory on the step across processes.
for p in self.param_group["params"]:
self.state[p] = {
"sum": torch.full_like(p.data, initial_accumulator_value),
"step": torch.tensor(0.0),
}
def step(self, gradients: List[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []
state_sums = []
state_steps: List[Tensor] = []
if len(params) != len(gradients):
raise ValueError(
"the gradients passed in does not equal to the size of the parameters!"
+ f"Params length: {len(params)}. "
+ f"Gradients length: {len(gradients)}"
)
has_sparse_grad, has_complex = False, False
for param, gradient in zip(self.param_group["params"], gradients):
if gradient is not None:
has_sparse_grad |= gradient.is_sparse
has_complex |= torch.is_complex(param)
params_with_grad.append(param)
grads.append(gradient)
state = self.state[param]
state_sums.append(state["sum"])
state_steps.append(state["step"])
with torch.no_grad():
F.adagrad(
params,
grads,
state_sums,
state_steps,
lr=self.defaults["lr"],
weight_decay=self.defaults["weight_decay"],
lr_decay=self.defaults["lr_decay"],
eps=self.defaults["eps"],
has_sparse_grad=has_sparse_grad,
foreach=self.foreach,
maximize=self.maximize,
has_complex=has_complex,
fused=self.fused,
grad_scale=None,
found_inf=None,
)

torch/distributed/optim/functional_adam.py
@@ -0,0 +1,198 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional, Tuple
import torch
import torch.optim._functional as F
from torch import Tensor
__all__: List[str] = []
# Define a TorchScript compatible Functional Adam Optimizer
# that we use in a functional way.
# Instead of using `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer to pass gradients to
# the `step` function. In this way, we can separate the gradients
# and parameters and allow a multithreaded trainer to update the
# parameters without data races from accumulating into the same .grad.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
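# Illustrative usage sketch (hypothetical tensors): besides ``step``, which takes
# one gradient per parameter, ``step_param`` below updates a single parameter as
# soon as its gradient is available.
#
#   >>> w1, w2 = torch.randn(3), torch.randn(3)
#   >>> opt = _FunctionalAdam([w1, w2], lr=1e-3)
#   >>> opt.step([torch.ones_like(w1), torch.ones_like(w2)])  # whole-list update
#   >>> opt.step_param(w1, torch.ones_like(w1))               # single-parameter update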
@torch.jit.script
class _FunctionalAdam:
def __init__(
self,
params: List[Tensor],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0.0,
amsgrad: bool = False,
maximize: bool = False,
foreach: bool = False,
fused: bool = False,
_allow_empty_param_list: bool = False,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
self.defaults = {
"lr": lr,
"eps": eps,
"beta1": betas[0],
"beta2": betas[1],
"weight_decay": weight_decay,
}
self.amsgrad = amsgrad
self.maximize = maximize
self.foreach = foreach
self.fused = fused
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
if len(params) == 0 and not _allow_empty_param_list:
raise ValueError("optimizer got an empty parameter list")
# NOTE: we only have one param_group and don't allow the user to add additional
# param groups as it's not a common use case.
self.param_group = {"params": params}
def step_param(self, param: Tensor, grad: Optional[Tensor]):
"""
Similar to step, but operates on a single parameter and optionally a
gradient tensor.
"""
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
max_exp_avg_sqs = []
state_steps: List[Tensor] = []
has_complex = torch.is_complex(param)
if grad is not None:
params_with_grad.append(param)
grads.append(grad)
if param not in self.state:
self.state[param] = {}
state = self.state[param]
state["step"] = torch.tensor(0.0)
state["exp_avg"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
state["exp_avg_sq"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
if self.amsgrad:
state["max_exp_avg_sq"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
state = self.state[param]
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
if self.amsgrad:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
state_steps.append(state["step"])
with torch.no_grad():
F.adam(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=self.amsgrad,
has_complex=has_complex,
maximize=self.maximize,
beta1=self.defaults["beta1"],
beta2=self.defaults["beta2"],
lr=self.defaults["lr"],
weight_decay=self.defaults["weight_decay"],
eps=self.defaults["eps"],
foreach=self.foreach,
fused=self.fused,
grad_scale=None,
found_inf=None,
)
def step(self, gradients: List[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
max_exp_avg_sqs = []
state_steps: List[Tensor] = []
has_complex = False
if len(params) != len(gradients):
raise ValueError(
"the gradients passed in does not equal to the size of the parameters!"
+ f"Params length: {len(params)}. "
+ f"Gradients length: {len(gradients)}"
)
for param, gradient in zip(self.param_group["params"], gradients):
if gradient is not None:
has_complex |= torch.is_complex(param)
params_with_grad.append(param)
grads.append(gradient)
# Lazy state initialization
if param not in self.state:
self.state[param] = {}
state = self.state[param]
state["step"] = torch.tensor(0.0)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
if self.amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
state = self.state[param]
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
if self.amsgrad:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
state_steps.append(state["step"])
with torch.no_grad():
F.adam(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=self.amsgrad,
has_complex=has_complex,
maximize=self.maximize,
beta1=self.defaults["beta1"],
beta2=self.defaults["beta2"],
lr=self.defaults["lr"],
weight_decay=self.defaults["weight_decay"],
eps=self.defaults["eps"],
foreach=self.foreach,
fused=self.fused,
grad_scale=None,
found_inf=None,
)

torch/distributed/optim/functional_adamax.py
@@ -0,0 +1,119 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional, Tuple
import torch
import torch.optim._functional as F
from torch import Tensor
__all__: List[str] = []
# Define a TorchScript compatible Functional Adamax Optimizer
# that we use in a functional way.
# Instead of using `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer to pass gradients to
# the `step` function. In this way, we can separate the gradients
# and parameters and allow a multithreaded trainer to update the
# parameters without data races from accumulating into the same .grad.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
@torch.jit.script
class _FunctionalAdamax:
def __init__(
self,
params: List[Tensor],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0.0,
foreach: bool = False,
maximize: bool = False,
_allow_empty_param_list: bool = False,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
self.defaults = {
"lr": lr,
"eps": eps,
"beta1": betas[0],
"beta2": betas[1],
"weight_decay": weight_decay,
}
self.foreach = foreach
self.maximize = maximize
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
if len(params) == 0 and not _allow_empty_param_list:
raise ValueError("optimizer got an empty parameter list")
# NOTE: we only have one param_group and don't allow the user to add additional
# param groups as it's not a common use case.
self.param_group = {"params": params}
def step(self, gradients: List[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []
exp_avgs = []
exp_infs = []
state_steps: List[Tensor] = []
if len(params) != len(gradients):
raise ValueError(
"the gradients passed in does not equal to the size of the parameters!"
+ f"Params length: {len(params)}. "
+ f"Gradients length: {len(gradients)}"
)
has_complex = False
for param, gradient in zip(self.param_group["params"], gradients):
if gradient is not None:
has_complex |= torch.is_complex(param)
params_with_grad.append(param)
grads.append(gradient)
# Lazy state initialization
if param not in self.state:
self.state[param] = {}
state = self.state[param]
state["step"] = torch.tensor(0.0)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
# Exponentially weighted infinity norm of the gradients
state["exp_inf"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
state = self.state[param]
exp_avgs.append(state["exp_avg"])
exp_infs.append(state["exp_inf"])
state_steps.append(state["step"])
with torch.no_grad():
F.adamax(
params_with_grad,
grads,
exp_avgs,
exp_infs,
state_steps,
eps=self.defaults["eps"],
beta1=self.defaults["beta1"],
beta2=self.defaults["beta2"],
lr=self.defaults["lr"],
weight_decay=self.defaults["weight_decay"],
foreach=self.foreach,
maximize=self.maximize,
has_complex=has_complex,
)

torch/distributed/optim/functional_adamw.py
@@ -0,0 +1,199 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional, Tuple
import torch
import torch.optim._functional as F
from torch import Tensor
__all__: List[str] = []
# Define a TorchScript compatible Functional AdamW Optimizer
# that we use in a functional way.
# Instead of using `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer to pass gradients to
# the `step` function. In this way, we can separate the gradients
# and parameters and allow a multithreaded trainer to update the
# parameters without data races from accumulating into the same .grad.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
@torch.jit.script
class _FunctionalAdamW:
def __init__(
self,
params: List[Tensor],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 1e-2,
amsgrad: bool = False,
maximize: bool = False,
foreach: bool = False,
fused: bool = False,
_allow_empty_param_list: bool = False,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
self.defaults = {
"lr": lr,
"eps": eps,
"beta1": betas[0],
"beta2": betas[1],
"weight_decay": weight_decay,
}
self.amsgrad = amsgrad
self.maximize = maximize
self.foreach = foreach
self.fused = fused
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
if len(params) == 0 and not _allow_empty_param_list:
raise ValueError("optimizer got an empty parameter list")
# NOTE: we only have one param_group and don't allow the user to add additional
# param groups as it's not a common use case.
self.param_group = {"params": params}
def step_param(self, param: Tensor, grad: Optional[Tensor]):
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
max_exp_avg_sqs = []
state_steps: List[Tensor] = []
has_complex = torch.is_complex(param)
if grad is not None:
params_with_grad.append(param)
grads.append(grad)
# Lazy state initialization
if param not in self.state:
self.state[param] = {}
state = self.state[param]
state["step"] = torch.tensor(0.0)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
if self.amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
state = self.state[param]
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
if self.amsgrad:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
state_steps.append(state["step"])
with torch.no_grad():
F.adamw(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=self.amsgrad,
maximize=self.maximize,
beta1=self.defaults["beta1"],
beta2=self.defaults["beta2"],
lr=self.defaults["lr"],
weight_decay=self.defaults["weight_decay"],
eps=self.defaults["eps"],
foreach=self.foreach,
fused=self.fused,
grad_scale=None,
found_inf=None,
has_complex=has_complex,
)
def step(self, gradients: List[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
max_exp_avg_sqs = []
state_steps: List[Tensor] = []
if len(params) != len(gradients):
raise ValueError(
"the gradients passed in does not equal to the size of the parameters!"
+ f"Params length: {len(params)}. "
+ f"Gradients length: {len(gradients)}"
)
has_complex = False
for param, gradient in zip(self.param_group["params"], gradients):
if gradient is not None:
has_complex |= torch.is_complex(param)
params_with_grad.append(param)
grads.append(gradient)
# Lazy state initialization
if param not in self.state:
self.state[param] = {}
state = self.state[param]
state["step"] = torch.tensor(0.0)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
if self.amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
state = self.state[param]
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
if self.amsgrad:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
state_steps.append(state["step"])
with torch.no_grad():
F.adamw(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=self.amsgrad,
maximize=self.maximize,
beta1=self.defaults["beta1"],
beta2=self.defaults["beta2"],
lr=self.defaults["lr"],
weight_decay=self.defaults["weight_decay"],
eps=self.defaults["eps"],
foreach=self.foreach,
fused=self.fused,
grad_scale=None,
found_inf=None,
has_complex=has_complex,
)

torch/distributed/optim/functional_rmsprop.py
@@ -0,0 +1,126 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional
import torch
import torch.optim._functional as F
from torch import Tensor
__all__: List[str] = []
# Define a TorchScript compatible Functional RMSprop Optimizer
# that we use in a functional way.
# Instead of using `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer to pass gradients to
# the `step` function. In this way, we can separate the gradients
# and parameters and allow a multithreaded trainer to update the
# parameters without data races from accumulating into the same .grad.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
@torch.jit.script
class _FunctionalRMSprop:
def __init__(
self,
params: List[Tensor],
lr: float = 1e-2,
alpha: float = 0.99,
eps: float = 1e-8,
weight_decay: float = 0.0,
momentum: float = 0.0,
centered: bool = False,
foreach: bool = False,
maximize: bool = False,
_allow_empty_param_list: bool = False,
):
self.defaults = {
"lr": lr,
"alpha": alpha,
"eps": eps,
"weight_decay": weight_decay,
"momentum": momentum,
}
self.centered = centered
self.foreach = foreach
self.maximize = maximize
if len(params) == 0 and not _allow_empty_param_list:
raise ValueError("optimizer got an empty parameter list")
# NOTE: we only have one param_group and don't allow the user to add additional
# param groups as it's not a common use case.
self.param_group = {"params": params}
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
def step(self, gradients: List[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []
square_avgs = []
grad_avgs = []
momentum_buffer_list = []
state_steps = []
lr = self.defaults["lr"]
alpha = self.defaults["alpha"]
eps = self.defaults["eps"]
momentum = self.defaults["momentum"]
weight_decay = self.defaults["weight_decay"]
if len(params) != len(gradients):
raise ValueError(
"the gradients passed in does not equal to the size of the parameters!"
+ f"Params length: {len(params)}. "
+ f"Gradients length: {len(gradients)}"
)
has_complex = False
for param, gradient in zip(params, gradients):
if gradient is not None:
has_complex |= torch.is_complex(param)
params_with_grad.append(param)
grads.append(gradient)
# Lazy state initialization
if param not in self.state:
self.state[param] = {}
state = self.state[param]
state["step"] = torch.tensor(0.0)
state["square_avg"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
if momentum > 0:
state["momentum_buffer"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
if self.centered:
state["grad_avg"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
state = self.state[param]
square_avgs.append(state["square_avg"])
if momentum > 0:
momentum_buffer_list.append(state["momentum_buffer"])
if self.centered:
grad_avgs.append(state["grad_avg"])
state_steps.append(state["step"])
with torch.no_grad():
F.rmsprop(
params_with_grad,
grads,
square_avgs,
grad_avgs,
momentum_buffer_list,
state_steps,
lr=lr,
alpha=alpha,
eps=eps,
weight_decay=weight_decay,
momentum=momentum,
centered=self.centered,
foreach=self.foreach,
maximize=self.maximize,
has_complex=has_complex,
)

torch/distributed/optim/functional_rprop.py
@@ -0,0 +1,103 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional, Tuple
import torch
import torch.optim._functional as F
from torch import Tensor
__all__: List[str] = []
# Define a TorchScript compatible Functional Rprop Optimizer
# that we use in a functional way.
# Instead of using `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer to pass gradients to
# the `step` function. In this way, we can separate the gradients
# and parameters and allow a multithreaded trainer to update the
# parameters without data races from accumulating into the same .grad.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
@torch.jit.script
class _FunctionalRprop:
def __init__(
self,
params: List[Tensor],
lr: float = 1e-2,
etas: Tuple[float, float] = (0.5, 1.2),
step_sizes: Tuple[float, float] = (1e-6, 50),
foreach: bool = False,
maximize: bool = False,
_allow_empty_param_list: bool = False,
):
self.defaults = {
"lr": lr,
}
self.etas = etas
self.step_sizes = step_sizes
self.foreach = foreach
self.maximize = maximize
if len(params) == 0 and not _allow_empty_param_list:
raise ValueError("optimizer got an empty parameter list")
# NOTE: we only have one param_group and don't allow the user to add additional
# param groups as it's not a common use case.
self.param_group = {"params": params}
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
def step(self, gradients: List[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []
prevs = []
step_sizes = []
state_steps = []
lr = self.defaults["lr"]
etaminus, etaplus = self.etas
step_size_min, step_size_max = self.step_sizes
if len(params) != len(gradients):
raise ValueError(
"the gradients passed in does not equal to the size of the parameters!"
+ f"Params length: {len(params)}. "
+ f"Gradients length: {len(gradients)}"
)
has_complex = False
for param, gradient in zip(params, gradients):
if gradient is not None:
has_complex |= torch.is_complex(param)
params_with_grad.append(param)
grads.append(gradient)
# Lazy state initialization
if param not in self.state:
self.state[param] = {}
state = self.state[param]
state["step"] = torch.tensor(0.0)
state["prev"] = torch.zeros_like(
param, memory_format=torch.preserve_format
)
state["step_size"] = torch.full_like(gradient, lr)
state = self.state[param]
prevs.append(state["prev"])
step_sizes.append(state["step_size"])
state_steps.append(state["step"])
with torch.no_grad():
F.rprop(
params_with_grad,
grads,
prevs,
step_sizes,
state_steps,
step_size_min=step_size_min,
step_size_max=step_size_max,
etaminus=etaminus,
etaplus=etaplus,
foreach=self.foreach,
maximize=self.maximize,
has_complex=has_complex,
)

torch/distributed/optim/functional_sgd.py
@@ -0,0 +1,162 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional
import torch
import torch.optim._functional as F
from torch import Tensor
__all__: List[str] = []
# Define a TorchScript compatible Functional SGD Optimizer
# that we use in a functional way.
# Instead of using `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer to pass gradients to
# the `step` function. In this way, we can separate the gradients
# and parameters and allow a multithreaded trainer to update the
# parameters without data races from accumulating into the same .grad.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
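# Illustrative usage sketch (hypothetical tensor): the momentum buffer produced by
# ``F.sgd`` is written back into ``self.state`` after every call, so repeated
# ``step_param`` calls keep accumulating momentum.
#
#   >>> w = torch.randn(5)
#   >>> opt = _FunctionalSGD([w], lr=0.1, momentum=0.9)
#   >>> opt.step_param(w, torch.ones_like(w))
#   >>> opt.step_param(w, torch.ones_like(w))  # reuses state["momentum_buffer"] from the first call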
@torch.jit.script
class _FunctionalSGD:
def __init__(
self,
params: List[Tensor],
lr: float = 1e-2,
momentum: float = 0.0,
dampening: float = 0.0,
weight_decay: float = 0.0,
nesterov: bool = False,
maximize: bool = False,
foreach: bool = False,
fused: bool = False,
_allow_empty_param_list: bool = False,
):
self.defaults = {
"lr": lr,
"momentum": momentum,
"dampening": dampening,
"weight_decay": weight_decay,
}
self.nesterov = nesterov
self.maximize = maximize
self.foreach = foreach
self.fused = fused
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
if len(params) == 0 and not _allow_empty_param_list:
raise ValueError("optimizer got an empty parameter list")
# NOTE: we only have one param_group and don't allow the user to add additional
# param groups as it's not a common use case.
self.param_group = {"params": params}
def step_param(self, param: Tensor, grad: Optional[Tensor]):
"""Similar to self.step, but operates on a single parameter and
its gradient.
"""
# TODO: Once step_param interface is robust, refactor step to call
# step param on each param.
weight_decay = self.defaults["weight_decay"]
momentum = self.defaults["momentum"]
dampening = self.defaults["dampening"]
lr = self.defaults["lr"]
params = [param]
momentum_buffer_list: List[Optional[Tensor]] = []
grads = []
has_sparse_grad = False
if grad is not None:
grads.append(grad)
if grad.is_sparse:
has_sparse_grad = True
if param not in self.state:
self.state[param] = {}
state = self.state[param]
if "momentum_buffer" not in state:
momentum_buffer_list.append(None)
else:
momentum_buffer_list.append(state["momentum_buffer"])
with torch.no_grad():
F.sgd(
params,
grads,
momentum_buffer_list,
weight_decay=weight_decay,
momentum=momentum,
lr=lr,
dampening=dampening,
nesterov=self.nesterov,
maximize=self.maximize,
has_sparse_grad=has_sparse_grad,
foreach=self.foreach,
fused=self.fused,
grad_scale=None,
found_inf=None,
)
# update momentum_buffer in state
state = self.state[param]
momentum_buffer = momentum_buffer_list[0]
if momentum_buffer is not None:
state["momentum_buffer"] = momentum_buffer
def step(self, gradients: List[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []
momentum_buffer_list: List[Optional[Tensor]] = []
lr = self.defaults["lr"]
weight_decay = self.defaults["weight_decay"]
momentum = self.defaults["momentum"]
dampening = self.defaults["dampening"]
if len(params) != len(gradients):
raise ValueError(
"the gradients passed in does not equal to the size of the parameters!"
+ f"Params length: {len(params)}. "
+ f"Gradients length: {len(gradients)}"
)
has_sparse_grad = False
for param, gradient in zip(params, gradients):
if gradient is not None:
params_with_grad.append(param)
grads.append(gradient)
if gradient.is_sparse:
has_sparse_grad = True
if param not in self.state:
self.state[param] = {}
state = self.state[param]
if "momentum_buffer" not in state:
momentum_buffer_list.append(None)
else:
momentum_buffer_list.append(state["momentum_buffer"])
with torch.no_grad():
F.sgd(
params_with_grad,
grads,
momentum_buffer_list,
weight_decay=weight_decay,
momentum=momentum,
lr=lr,
dampening=dampening,
nesterov=self.nesterov,
maximize=self.maximize,
has_sparse_grad=has_sparse_grad,
foreach=self.foreach,
fused=self.fused,
grad_scale=None,
found_inf=None,
)
# update momentum_buffers in state
for i, p in enumerate(params_with_grad):
state = self.state[p]
momentum_buffer = momentum_buffer_list[i]
if momentum_buffer is not None:
state["momentum_buffer"] = momentum_buffer

torch/distributed/optim/named_optimizer.py
@@ -0,0 +1,341 @@
# mypy: allow-untyped-defs
import logging
import warnings
from copy import deepcopy
from typing import (
Any,
Callable,
Collection,
Dict,
List,
Mapping,
Optional,
overload,
Union,
)
import torch
import torch.nn as nn
from torch import optim
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
__all__: List[str] = []
logger = logging.getLogger(__name__)
class _NamedOptimizer(optim.Optimizer):
"""
``_NamedOptimizer`` takes a dict of parameters and exposes ``state_dict`` by parameter key.
We replace the original numeric key in the optim with the
fully qualified name (FQN) string. Users can initialize the optim as they would
initialize a PyTorch optim; the only difference is that they also need to
pass in the FQN of each parameter.
Args:
named_parameters (Mapping[str, Union[torch.Tensor, ShardedTensor]]):
Mapping from FQN to parameter.
optimizer_class (optim.Optimizer):
The class of optimizer to instantiate.
param_groups (Collection[Mapping[str, Any]]):
`param_groups` to pass to optimizer if specified.
The key of the inner map needs to be FQNs.
Default: None
module (nn.Module): the module whose parameters are to be updated
by the optimizer.
args: arguments to pass to the optimizer constructor.
kwargs: arguments to pass to the optimizer constructor.
Example::
>>> # xdoctest: +SKIP("distributed")
>>> from torch import optim
>>> from torch.distributed.optim import _NamedOptimizer
>>>
>>> # Define the named optimizer.
>>> m = Model(...)
>>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD)
>>> # Forward pass + backward pass.
>>> named_optim.step()
>>> ...
>>> # Call state_dict for the named optimizer returns a FQN state_dict.
>>> named_optim.state_dict()
Warning: This API is still in development and subject to change.
TODO: Add tutorial for _NamedOptimizer.
TODO: Add documentation in the docstring for the public attributes
like self.param_groups and self.named_parameters.
"""
def __init__(
self,
named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]],
optimizer_class: optim.Optimizer,
param_groups: Optional[Collection[Mapping[str, Any]]] = None,
module: Optional[nn.Module] = None,
*args,
**kwargs,
) -> None:
torch._C._log_api_usage_once("torch.distributed.optim._NamedOptimizer")
self.param_groups: Collection[Mapping[str, Any]] = param_groups # type: ignore[assignment]
self._param_groups_check()
self.named_parameters = dict(named_parameters)
params_for_optimizer = (
self.named_parameters.values() if param_groups is None else param_groups
)
self._optimizer = optimizer_class( # type: ignore[operator]
params_for_optimizer,
*args,
**kwargs,
)
self.module = module
if param_groups is None:
self.ordered_param_keys = list(self.named_parameters.keys())
else:
warnings.warn(
"Since we pass in param_groups, we will use param_groups to "
"initialize the optimizer, not all parameters of the module."
)
param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type]
ordered_param_keys = []
for group in param_groups:
for param in group["params"]:
if param not in param_to_key:
raise ValueError(
f"Expect param name {param} found in param group but is missing."
)
ordered_param_keys.append(param_to_key[param])
self.ordered_param_keys = ordered_param_keys
# Update param_groups from optimizer.
self.param_groups = self._optimizer.param_groups
def _param_groups_check(self):
if self.param_groups is not None:
for param_group in self.param_groups:
assert isinstance(param_group, dict), "param group must be a dict"
assert "params" in param_group, "param group must contain key params"
params = param_group["params"]
if isinstance(params, torch.Tensor):
params = [params]
params = list(params)
for param in params:
if not isinstance(param, torch.Tensor):
raise TypeError(
"optimizer can only optimize Tensors, "
"but one of the params is " + torch.typename(param)
)
param_group["params"] = params
def state_dict(self) -> Dict[str, Any]:
"""
Return the ``state_dict`` of the optimizer.
Instead of using a number to index
parameters, we use the module's fully qualified name (FQN) as the key.
"""
state_dict = self._optimizer.state_dict()
param_groups = state_dict["param_groups"]
ret_state = {
self.ordered_param_keys[st_key]: state_val
for st_key, state_val in state_dict["state"].items()
}
ret_groups = []
for group in param_groups:
param_keys = []
for param in group["params"]:
param_keys.append(self.ordered_param_keys[param])
ret_group = {"params": sorted(param_keys)}
for k, v in group.items():
if k != "params":
ret_group[k] = deepcopy(v)
ret_groups.append(ret_group)
return self._post_state_dict({"state": ret_state, "param_groups": ret_groups})
@overload
def step(self, closure: None = ...) -> None:
...
@overload
def step(self, closure: Callable[[], float]) -> float:
...
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""
Perform a single optimization step.
This will call :meth:`torch.optim.Optimizer.step` on the wrapped
optimizer.
"""
return self._optimizer.step(closure=closure)
@property
def state(self) -> Mapping[torch.Tensor, Any]: # type: ignore[override]
return self._optimizer.state
def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
"""
Define the default behavior to load a state_dict for ``_NamedOptimizer``.
Sample Code
```
my_model = MyModule()
optimizer = _NamedOptimizer(my_model.named_parameters(), Adagrad)
...
optim_state_dict = optimizer.state_dict()
...
...
optimizer.load_state_dict(optim_state_dict)
...
```
Args:
state_dict (Dict[str, Any]) : A ``state_dict`` to load into the optimizer.
Note that this state dict update is performed in place.
.. note:: PyTorch uses lazy init to initialize the optim states.
So it is possible that there is no optim state when the user calls
``load_state_dict``; for ``_NamedOptimizer`` we are stricter and require
that users only call ``load_state_dict`` after the state is initialized.
By doing this, we can validate the optim ``state_dict`` to be loaded.
"""
new_state_dict = self._optimizer.state_dict()
state_dict = self._pre_load_state_dict(state_dict)
state = state_dict["state"]
new_state = new_state_dict["state"]
if len(new_state) == 0:
raise ValueError(
"Expects the optim to be initialized before load but found not initialized."
)
for idx, param_key in enumerate(self.ordered_param_keys):
# When conditional training is performed, not all parameters are updated in the optim.
if param_key not in state.keys():
continue
if len(state[param_key]) != len(new_state[idx]):
raise ValueError(
f"Expects equal length as {len(new_state[idx])} for parameter {param_key} but found: {len(state[param_key])}"
)
# Iterate through all optimizer states.
for state_key, state_val in new_state[idx].items():
if state_key not in state[param_key]:
raise ValueError(
f"Expects state {state_key} for parameter {param_key} but not found."
)
src_state_val = state[param_key][state_key]
if isinstance(state_val, ShardedTensor):
assert isinstance(src_state_val, ShardedTensor)
num_shards = len(state_val.local_shards())
num_new_shards = len(src_state_val.local_shards())
if num_shards != num_new_shards:
raise ValueError(
f"Expects equal number of shards as {num_new_shards} but found {num_shards} for {param_key}/{state_key}"
)
for shard, src_shard in zip(
state_val.local_shards(), src_state_val.local_shards()
):
shard.tensor.detach().copy_(src_shard.tensor)
elif isinstance(state_val, torch.Tensor):
assert isinstance(src_state_val, torch.Tensor)
state_val.detach().copy_(src_state_val)
else:
new_state[idx][state_key] = deepcopy(src_state_val)
# Load param_groups of state_dict
src_param_groups = state_dict["param_groups"]
new_param_groups = new_state_dict["param_groups"]
src_group_map = {}
for group in src_param_groups:
param_keys = list(group["params"])
src_group_map[_gen_param_group_key(param_keys)] = group
new_group_map = {}
for new_group in new_param_groups:
param_keys = []
for param_key in new_group["params"]:
param_keys.append(self.ordered_param_keys[param_key]) # type: ignore[call-overload]
new_group_map[_gen_param_group_key(param_keys)] = new_group
for group_key, new_group in new_group_map.items():
# Not all parameters are necessarily used in training or receive gradients, so not all
# parameters end up in a param_group. Thus we skip such a group_key here.
if group_key not in src_group_map:
continue
src_group = src_group_map[group_key]
if len(src_group) != len(new_group):
raise ValueError(
f"Expects equal param_group size as {len(new_group)} for group {group_key} but found {len(src_group)}."
)
for k in src_group:
if k not in new_group:
raise ValueError(
f"Expects group key {k} to be in group {group_key} in `state_dict` but is missing."
)
if k != "params":
new_group[k] = deepcopy(src_group[k])
self._optimizer.load_state_dict(new_state_dict)
def add_param_group(self, param_group: Mapping[str, Any]) -> None:
"""
Add a param group to the :class:`_NamedOptimizer`'s `param_groups`.
Warning: This API is still in development and subject to change.
"""
assert isinstance(param_group, dict), "param group must be a dict"
params = param_group["params"]
if isinstance(params, torch.Tensor):
param_group["params"] = [params]
else:
param_group["params"] = list(params)
param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type]
for param in param_group["params"]:
if param not in param_to_key:
raise ValueError("some parameters are not in the module")
self.ordered_param_keys.append(param_to_key[param])
self._optimizer.add_param_group(param_group)
# Update param_groups from optimizer.
self.param_groups = self._optimizer.param_groups
def init_state(self) -> None:
"""
Run a dummy optimizer step, which allows us to initialize the optimizer state, since we do lazy init for most optimizers.
This allows in-place loading of optimizer state from a checkpoint.
"""
for param in self.named_parameters.values():
if param.requires_grad:
t = torch.zeros_like(param)
param.grad = torch.autograd.Variable(t)
# Calling ``step`` will load the initial state for optimizer states.
self.step(closure=None)
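# Illustrative sketch (``model`` and ``checkpoint_state`` are hypothetical):
# ``init_state`` is typically called before an in-place ``load_state_dict`` so
# that the wrapped optimizer already has state tensors to copy into.
#
#   >>> named_optim = _NamedOptimizer(model.named_parameters(), optim.Adagrad)
#   >>> named_optim.init_state()                       # materialize the lazily created state
#   >>> named_optim.load_state_dict(checkpoint_state)  # copied in place into that state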
def _pre_load_state_dict(self, state_dict) -> Dict[str, Any]:
# TODO(chienchin): This API should be FSDP agnostic and should support
# general user hooks.
if isinstance(self.module, FSDP):
return FSDP.optim_state_dict_to_load(
self.module, self._optimizer, state_dict, is_named_optimizer=True
)
return state_dict
def _post_state_dict(self, state_dict) -> Dict[str, Any]:
# TODO(chienchin): This API should be FSDP agnostic and should support
# general user hooks.
if isinstance(self.module, FSDP):
FSDP.optim_state_dict(self.module, self._optimizer, state_dict)
return state_dict
def _gen_param_group_key(param_keys: List[str]) -> str:
"""Concatenate all param keys as a unique indentifier for one param group."""
return "/".join(sorted(param_keys))

torch/distributed/optim/optimizer.py
@@ -0,0 +1,257 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import logging
from collections import defaultdict
from threading import Lock
from typing import List, Optional
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.distributed.rpc import RRef
from .utils import functional_optim_map
__all__ = ["DistributedOptimizer"]
logger = logging.getLogger(__name__)
# XXX: we define a _ScriptLocalOptimizer here to explicitly
# compile the FunctionalOptimizer class into TorchScript.
# This is because a ScriptClass instance still lives in
# Python unless you explicitly compile it as an attribute
# of a ScriptModule or pass it to a ScriptFunction.
# _ScriptLocalOptimizerInterface serves as a common
# interface type for optimizer ScriptModules.
#
# TODO (wanchaol): remove this once we add TorchScript
# class reference semantics
@jit.interface
class _ScriptLocalOptimizerInterface:
def step(self, autograd_ctx_id: int) -> None:
pass
class _ScriptLocalOptimizer(nn.Module):
# TorchScript does not support multithreaded concurrent compilation.
# request_callback might invoke concurrent compilation, so we
# serialize compilation with a lock.
compile_lock = Lock()
def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
super().__init__()
self._local_params = [rref.local_value() for rref in local_params_rref]
self.optim = optim_cls(self._local_params, *args, **kwargs)
@jit.export
def step(self, autograd_ctx_id: int):
all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
# apply functional optimizer step with a list of gradients
grads: List[Optional[Tensor]] = [
all_local_grads[p] if p in all_local_grads else None
for p in self._local_params
]
self.optim.step(grads)
# TODO (wanchaol): remove/merge this with ScriptLocalOptimizer once
# we have converted all to functional optimizer in distributed.optim
class _LocalOptimizer:
# Ideally we would only need to share a lock for instances of
# _LocalOptimizer that deal with the same parameters. We are
# making a simplifying assumption here that if there is more
# than one instance of _LocalOptimizer per worker, they will
# be optimizing the same parameters (e.g. each data parallel
# trainer will create its own instance of _LocalOptimizer but
# they will all optimize the same parameters on each worker)
global_lock = Lock()
def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
self._local_params = [rref.local_value() for rref in local_params_rref]
self.optim = optim_cls(self._local_params, *args, **kwargs)
def step(self, autograd_ctx_id):
all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
with _LocalOptimizer.global_lock:
for param, grad in all_local_grads.items():
param.grad = grad
self.optim.step()
def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
return rpc.RRef(_LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs))
def _local_optimizer_step(local_optim_rref, autograd_ctx_id):
local_optim = local_optim_rref.local_value()
local_optim.step(autograd_ctx_id)
# new/step functions combined with _ScriptLocalOptimizer to provide GIL-free optimizer
def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)
with _ScriptLocalOptimizer.compile_lock:
script_optim = jit.script(optim)
return rpc.RRef(script_optim, _ScriptLocalOptimizerInterface)
@jit.script
def _script_local_optimizer_step(
local_optim_rref: RRef[_ScriptLocalOptimizerInterface], autograd_ctx_id: int
) -> None:
local_optim = local_optim_rref.local_value()
local_optim.step(autograd_ctx_id)
def _wait_for_all(rpc_futs):
# TODO: improve error propagation
exception = None
results = []
for fut in rpc_futs:
try:
results.append(fut.wait())
except Exception as e:
results.append(e)
exception = e
if exception is not None:
raise exception
return results
class DistributedOptimizer:
"""
DistributedOptimizer takes remote references to parameters scattered
across workers and applies the given optimizer locally for each parameter.
This class uses :meth:`~torch.distributed.autograd.get_gradients` in order
to retrieve the gradients for specific parameters.
Concurrent calls to
:meth:`~torch.distributed.optim.DistributedOptimizer.step`,
either from the same or different clients, will
be serialized on each worker -- as each worker's optimizer can only work
on one set of gradients at a time. However, there is no guarantee that
the full forward-backward-optimizer sequence will execute for one client
at a time. This means that the gradients being applied may not correspond
to the latest forward pass executed on a given worker. Also, there is no
guaranteed ordering across workers.
`DistributedOptimizer` creates the local optimizer with TorchScript enabled
by default, so that optimizer updates are not blocked by the Python Global
Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed
Model Parallel). This feature is currently enabled for most optimizers. You
can also follow `the recipe`__ in PyTorch tutorials to enable TorchScript support
for your own custom optimizers.
Args:
optimizer_class (optim.Optimizer): the class of optimizer to
instantiate on each worker.
params_rref (list[RRef]): list of RRefs to local or remote parameters
to optimize.
args: arguments to pass to the optimizer constructor on each worker.
kwargs: arguments to pass to the optimizer constructor on each worker.
Example::
>>> # xdoctest: +SKIP("distributed")
>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>>
>>> with dist_autograd.context() as context_id:
>>> # Forward pass.
>>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>> loss = rref1.to_here() + rref2.to_here()
>>>
>>> # Backward pass.
>>> dist_autograd.backward(context_id, [loss.sum()])
>>>
>>> # Optimizer.
>>> dist_optim = DistributedOptimizer(
>>> optim.SGD,
>>> [rref1, rref2],
>>> lr=0.05,
>>> )
>>> dist_optim.step(context_id)
__ https://github.com/pytorch/tutorials/pull/1465
"""
def __init__(self, optimizer_class, params_rref, *args, **kwargs):
torch._C._log_api_usage_once("torch.distributed.optim.DistributedOptimizer")
per_worker_params_rref = defaultdict(list)
for param in params_rref:
per_worker_params_rref[param.owner()].append(param)
if optimizer_class in functional_optim_map and jit._state._enabled:
optim_ctor = functional_optim_map.get(optimizer_class)
else:
optim_ctor = optimizer_class
self.is_functional_optim = optim_ctor != optimizer_class
if self.is_functional_optim:
optimizer_new_func = _new_script_local_optimizer
else:
logger.warning(
"Creating the optimizer %s without TorchScript support, "
"this might result in slow computation time in multithreading environment"
"(i.e. Distributed Model Parallel training on CPU) due to the Python's "
"Global Interpreter Lock (GIL). Please file an issue if you need this "
"optimizer in TorchScript. ",
optimizer_class,
)
optimizer_new_func = _new_local_optimizer
remote_optim_futs = []
for worker, param_rrefs in per_worker_params_rref.items():
remote_optim_rref_fut = rpc.rpc_async(
worker,
optimizer_new_func,
args=(optim_ctor, param_rrefs) + args,
kwargs=kwargs,
)
remote_optim_futs.append(remote_optim_rref_fut)
self.remote_optimizers = _wait_for_all(remote_optim_futs)
def step(self, context_id):
"""
Performs a single optimization step.
This will call :meth:`torch.optim.Optimizer.step` on each worker
containing parameters to be optimized, and will block until all workers
return. The provided ``context_id`` will be used to retrieve the
corresponding :class:`~torch.distributed.autograd.context` that
contains the gradients that should be applied to the parameters.
Args:
context_id: the autograd context id for which we should run the
optimizer step.
"""
dist_autograd._is_valid_context(context_id)
optimizer_step_func = (
_script_local_optimizer_step
if self.is_functional_optim
else _local_optimizer_step
)
rpc_futs = []
for optimizer in self.remote_optimizers:
rpc_futs.append(
rpc.rpc_async(
optimizer.owner(),
optimizer_step_func,
args=(optimizer, context_id),
)
)
_wait_for_all(rpc_futs)

torch/distributed/optim/post_localSGD_optimizer.py
@@ -0,0 +1,110 @@
# mypy: allow-untyped-defs
import warnings
import torch
import torch.distributed.algorithms.model_averaging.averagers as averagers
class PostLocalSGDOptimizer(torch.optim.Optimizer):
r"""
Wraps an arbitrary :class:`torch.optim.Optimizer` and runs `post-local SGD <https://arxiv.org/abs/1808.07217>`_.
This optimizer runs the local optimizer at every step.
After the warm-up stage, it averages parameters periodically after the local optimizer is applied.
Args:
optim: The local optimizer.
averager: A model averager instance to run the post-localSGD algorithm.
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> import torch
>>> import torch.distributed as dist
>>> import torch.distributed.algorithms.model_averaging.averagers as averagers
>>> import torch.nn as nn
>>> from torch.distributed.optim import PostLocalSGDOptimizer
>>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
>>> PostLocalSGDState,
>>> post_localSGD_hook,
>>> )
>>>
>>> model = nn.parallel.DistributedDataParallel(
>>> module, device_ids=[rank], output_device=rank
>>> )
>>>
>>> # Register a post-localSGD communication hook.
>>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>>
>>> # Create a post-localSGD optimizer that wraps a local optimizer.
>>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
>>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
>>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
>>> opt = PostLocalSGDOptimizer(
>>> optim=local_optim,
>>> averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
>>> )
>>>
>>> # In the first 100 steps, DDP runs global gradient averaging at every step.
>>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
>>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
>>> for step in range(0, 200):
>>> opt.zero_grad()
>>> loss = loss_fn(output, labels)
>>> loss.backward()
>>> opt.step()
"""
def __init__(self, optim: torch.optim.Optimizer, averager: averagers.ModelAverager):
self.optim = optim
self.param_groups = self.optim.param_groups
self.averager = averager
@property
def state(self):
return self.optim.state
def __repr__(self):
return self.optim.__repr__()
def state_dict(self):
r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`state_dict`,
        but adds an extra entry to record the model averager's step to the checkpoint
        to ensure that reloading does not trigger an unnecessary warm-up again.
"""
optim_state_dict = self.optim.state_dict()
optim_state_dict["step"] = self.averager.step
return optim_state_dict
def load_state_dict(self, state_dict):
r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`load_state_dict`,
        but also restores the model averager's step value to the one
        saved in the provided ``state_dict``.
        If there is no ``"step"`` entry in ``state_dict``,
        it will issue a warning and initialize the model averager's step to 0.
"""
self.optim.load_state_dict(state_dict)
if "step" in state_dict:
self.averager.step = state_dict["step"]
else:
warnings.warn(
"Loaded state dict does not contain a step counter for an averager. "
"Setting step counter to 0."
)
self.averager.step = 0
def step(self):
r"""
Performs a single optimization step (parameter update).
"""
self.optim.step()
self.averager.average_parameters(params=self.param_groups)
def zero_grad(self, set_to_none: bool = True): # type: ignore[override]
self.optim.zero_grad(set_to_none=set_to_none)
def add_param_group(self, param_group):
self.optim.add_param_group(param_group)
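# Editorial sketch (assumption, not part of the original file): why ``state_dict`` /
# ``load_state_dict`` above carry the extra ``"step"`` entry -- the averager's step
# counter survives a checkpoint round trip, so the warm-up stage is not repeated on
# restart. ``opt`` is assumed to be a PostLocalSGDOptimizer built as in the docstring
# example, and ``path`` is a placeholder location.
def _checkpoint_round_trip_sketch(opt: PostLocalSGDOptimizer, path: str = "post_local_sgd.pt") -> int:
    torch.save(opt.state_dict(), path)  # local optimizer state plus the "step" entry
    opt.load_state_dict(torch.load(path))  # restores ``opt.averager.step`` as well
    return opt.averager.step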

View File

@ -0,0 +1,66 @@
# mypy: allow-untyped-defs
from typing import Type
from torch import optim
from .functional_adadelta import _FunctionalAdadelta
from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamax import _FunctionalAdamax
from .functional_adamw import _FunctionalAdamW
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_sgd import _FunctionalSGD
# Dict that maps a user-passed ``optimizer_class`` to its functional counterpart,
# if one is already defined inside the distributed.optim package. This lets us
# hide the functional optimizer from the user while still providing the same API.
functional_optim_map = {
optim.Adagrad: _FunctionalAdagrad,
optim.Adam: _FunctionalAdam,
optim.AdamW: _FunctionalAdamW,
optim.SGD: _FunctionalSGD,
optim.Adadelta: _FunctionalAdadelta,
optim.RMSprop: _FunctionalRMSprop,
optim.Rprop: _FunctionalRprop,
optim.Adamax: _FunctionalAdamax,
}
def register_functional_optim(key, optim):
"""
    Interface to insert a new functional optimizer into ``functional_optim_map``.
    ``fn_optim_key`` and ``fn_optimizer`` are user defined. The key and optimizer
    need not be of type :class:`torch.optim.Optimizer` (e.g. for custom optimizers).
Example::
>>> # import the new functional optimizer
>>> # xdoctest: +SKIP
>>> from xyz import fn_optimizer
>>> from torch.distributed.optim.utils import register_functional_optim
>>> fn_optim_key = "XYZ_optim"
>>> register_functional_optim(fn_optim_key, fn_optimizer)
"""
if key not in functional_optim_map:
functional_optim_map[key] = optim
def as_functional_optim(optim_cls: Type, *args, **kwargs):
try:
functional_cls = functional_optim_map[optim_cls]
except KeyError as e:
raise ValueError(
f"Optimizer {optim_cls} does not have a functional " f"counterpart!"
) from e
return _create_functional_optim(functional_cls, *args, **kwargs)
def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
return functional_optim_cls(
[],
*args,
**kwargs,
_allow_empty_param_list=True,
)
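# Editorial sketch (assumption, not part of the original file): how the helpers above
# are typically used. ``as_functional_optim`` maps an eager optimizer class to its
# functional counterpart, constructed with an empty parameter list; the caller (e.g.
# the distributed optimizer) later drives it with explicit parameters and gradients
# instead of reading ``param.grad``.
def _as_functional_optim_demo_sketch() -> _FunctionalSGD:
    fn_sgd = as_functional_optim(optim.SGD, lr=0.1)
    assert isinstance(fn_sgd, _FunctionalSGD)  # optim.SGD -> _FunctionalSGD
    return fn_sgd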

File diff suppressed because it is too large

View File

@ -0,0 +1,84 @@
# mypy: allow-untyped-defs
import enum
from typing import Any, Callable, overload
import torch
from torch.distributed.algorithms.join import Joinable, JoinHook
from torch.optim import Optimizer
class _ZeROJoinHook(JoinHook):
zero: Any = ...
def __init__(self, zero: Any) -> None: ...
def main_hook(self) -> None: ...
class _DDPBucketAssignment:
bucket_index: int
parameters: list[torch.Tensor]
offset: int
device: torch.device
tensor: torch.Tensor | None
class _OverlapStatus(enum.IntEnum):
UNINITIALIZED: int = ...
DDP_HAS_REBUILT_BUCKETS: int = ...
INITIALIZED: int = ...
class _OverlapInfo:
status: Any = ...
params_per_bucket: Any = ...
params_per_rank: Any = ...
offsets: Any = ...
broadcast_handles: Any = ...
bucket_index_to_future: Any = ...
bucket_index_to_bucket: Any = ...
bucket_indices_seen: Any = ...
assigned_ranks_per_bucket: list[set[int]] = ...
total_size: int = ...
shard_buckets: bool = ...
def __init__(self) -> None: ...
def wait_for_broadcasts(self) -> None: ...
def clear_per_iter_info(self) -> None: ...
class ZeroRedundancyOptimizer(Optimizer, Joinable):
functional_optim_map: Any = ...
initialized: bool = ...
process_group: Any = ...
world_size: int = ...
rank: int = ...
global_rank: int = ...
parameters_as_bucket_view: bool = ...
optim: Any = ...
_device_to_device_index: dict[torch.device, int] = ...
_overlap_with_ddp: bool = ...
_overlap_info: _OverlapInfo = ...
_buckets: list[list[torch.Tensor]] = ...
_bucket_assignments_per_rank: list[dict[int, _DDPBucketAssignment]] = ...
def __init__(
self,
params: Any,
optimizer_class: type[Optimizer],
process_group: Any | None = ...,
parameters_as_bucket_view: bool = ...,
overlap_with_ddp: bool = ...,
**defaults: Any,
) -> None: ...
def add_param_group(self, param_group: dict[str, Any]) -> None: ...
def consolidate_state_dict(self, to: int = ...) -> None: ...
@overload
def step(self, closure: None = ..., **kwargs: Any) -> None: ...
@overload
def step(self, closure: Callable[[], float], **kwargs: Any) -> float: ...
def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...
def state_dict(self) -> dict[str, Any]: ...
def _local_step(
self,
gradients: list[torch.Tensor | None] | None = None,
closure: Callable[[], float] | None = None,
**kwargs: Any,
) -> float | None: ...
def _get_assigned_rank(self, bucket_index: int) -> int: ...
def _init_zero_for_overlap(self) -> None: ...
def join_hook(self, **kwargs): ...
@property
def join_device(self) -> torch.device: ...
def join_process_group(self) -> Any: ...
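# Editorial sketch (assumption, not part of the original stub): typical eager-mode use
# of ZeroRedundancyOptimizer matching the constructor signature above -- shard the
# wrapped optimizer's state across the ranks of the default process group. ``model``
# is assumed to exist and ``torch.distributed.init_process_group`` to have been
# called; ``overlap_with_ddp`` is left at its default.
def _build_zero_optimizer_sketch(model: torch.nn.Module) -> ZeroRedundancyOptimizer:
    import torch.distributed as dist

    assert dist.is_initialized(), "expects an initialized default process group"
    return ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.Adam,
        lr=1e-3,
    )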