I am done

This commit is contained in:
2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@ -0,0 +1,63 @@
"""
:mod:`torch.optim` is a package implementing various optimization algorithms.
Most commonly used methods are already supported, and the interface is general
enough, so that more sophisticated ones can also be easily integrated in the
future.
"""
from torch.optim import lr_scheduler as lr_scheduler, swa_utils as swa_utils
from torch.optim._adafactor import Adafactor as Adafactor
from torch.optim.adadelta import Adadelta as Adadelta
from torch.optim.adagrad import Adagrad as Adagrad
from torch.optim.adam import Adam as Adam
from torch.optim.adamax import Adamax as Adamax
from torch.optim.adamw import AdamW as AdamW
from torch.optim.asgd import ASGD as ASGD
from torch.optim.lbfgs import LBFGS as LBFGS
from torch.optim.nadam import NAdam as NAdam
from torch.optim.optimizer import Optimizer as Optimizer
from torch.optim.radam import RAdam as RAdam
from torch.optim.rmsprop import RMSprop as RMSprop
from torch.optim.rprop import Rprop as Rprop
from torch.optim.sgd import SGD as SGD
from torch.optim.sparse_adam import SparseAdam as SparseAdam
# Adafactor is implemented in the private ``torch.optim._adafactor`` module
# (see the import above); report ``torch.optim`` as its module instead.
Adafactor.__module__ = "torch.optim"

# Remove the optimizer submodule names from this package namespace so that only
# the names imported explicitly above remain visible here.
# NOTE(review): these names are never assigned in this file (hence the
# ``name-defined`` ignores); presumably they are bound as a side effect of the
# ``from torch.optim.<submodule> import ...`` statements -- confirm before
# changing the import style.
del adadelta # type: ignore[name-defined] # noqa: F821
del adagrad # type: ignore[name-defined] # noqa: F821
del adam # type: ignore[name-defined] # noqa: F821
del adamw # type: ignore[name-defined] # noqa: F821
del sparse_adam # type: ignore[name-defined] # noqa: F821
del adamax # type: ignore[name-defined] # noqa: F821
del asgd # type: ignore[name-defined] # noqa: F821
del sgd # type: ignore[name-defined] # noqa: F821
del radam # type: ignore[name-defined] # noqa: F821
del rprop # type: ignore[name-defined] # noqa: F821
del rmsprop # type: ignore[name-defined] # noqa: F821
del optimizer # type: ignore[name-defined] # noqa: F821
del nadam # type: ignore[name-defined] # noqa: F821
del lbfgs # type: ignore[name-defined] # noqa: F821

# Public API of the package: every optimizer class imported above plus the
# ``lr_scheduler`` and ``swa_utils`` submodules.
__all__ = [
    "Adafactor",
    "Adadelta",
    "Adagrad",
    "Adam",
    "Adamax",
    "AdamW",
    "ASGD",
    "LBFGS",
    "lr_scheduler",
    "NAdam",
    "Optimizer",
    "RAdam",
    "RMSprop",
    "Rprop",
    "SGD",
    "SparseAdam",
    "swa_utils",
]

View File

@ -0,0 +1,656 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch import Tensor
from .optimizer import (
_disable_dynamo_if_unsupported,
_get_scalar_dtype,
_maximize_doc,
Optimizer,
ParamsT,
TensorListList,
)
__all__ = ["Adafactor", "adafactor"]
class Adafactor(Optimizer):
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-2,
beta2_decay: float = -0.8,
eps: Tuple[Optional[float], float] = (None, 1e-3),
d: float = 1.0,
weight_decay: float = 0.0,
*,
foreach: Optional[bool] = None,
maximize: bool = False,
):
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Learning rate should be >= 0 but is: {lr}")
if not 0.0 >= beta2_decay:
raise ValueError(f"beta2_decay should be <= 0 but is: {beta2_decay}")
if eps[0] is not None and not 0.0 <= eps[0]:
raise ValueError(f"epsilon1 should be >= 0 but is: {eps[0]}")
if not 0.0 <= eps[1]:
raise ValueError(f"epsilon2 should be >= 0 but is: {eps[1]}")
if not 1.0 <= d:
raise ValueError(f"Clipping threshold d should be >= 1 but is: {d}")
if not 0.0 <= weight_decay:
raise ValueError(f"weight_decay should be >= 0 but is: {weight_decay}")
defaults = dict(
lr=lr,
beta2_decay=beta2_decay,
eps=eps,
d=d,
weight_decay=weight_decay,
foreach=foreach,
maximize=maximize,
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("foreach", None)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
step_val = float(p_state["step"])
p_state["step"] = torch.tensor(step_val, dtype=_get_scalar_dtype())
def _init_group(
self,
group,
params_with_grad,
grads,
row_vars,
col_vars,
variances,
state_steps,
):
for p in group["params"]:
if p.grad is None:
continue
if torch.is_complex(p):
raise RuntimeError("Adafactor does not support complex parameters")
if p.grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients")
params_with_grad.append(p)
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
# note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
# This is because kernel launches are costly on CUDA and XLA.
state["step"] = torch.tensor(0.0, dtype=_get_scalar_dtype())
if p.grad.dim() > 1:
row_shape = list(p.grad.shape)
row_shape[-1] = 1
# Row factor of variance, NOT the same shape as grads (will be reduced along last dim)
state["row_var"] = p.grad.new_zeros(row_shape)
col_shape = list(p.grad.shape)
col_shape[-2] = 1
# Col factor of variance, NOT the same shape as grads (will be reduced along penultimate dim)
state["col_var"] = p.grad.new_zeros(col_shape)
else:
state["variance"] = torch.zeros_like(
p.grad, memory_format=torch.preserve_format
)
row_vars.append(state.get("row_var", None))
col_vars.append(state.get("col_var", None))
variances.append(state.get("variance", None))
state_steps.append(state["step"])
return False # has_complex
@torch.no_grad()
def step(self, closure=None):
r"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
self._cuda_graph_capture_health_check()
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
row_vars: List[Optional[Tensor]] = []
col_vars: List[Optional[Tensor]] = []
variances: List[Optional[Tensor]] = []
state_steps: List[Tensor] = []
eps1, eps2 = group["eps"]
has_complex = self._init_group(
group,
params_with_grad,
grads,
row_vars,
col_vars,
variances,
state_steps,
)
adafactor(
params_with_grad,
grads,
row_vars,
col_vars,
variances,
state_steps,
d=group["d"],
lr=group["lr"],
beta2_decay=group["beta2_decay"],
weight_decay=group["weight_decay"],
eps1=eps1,
eps2=eps2,
foreach=group["foreach"],
maximize=group["maximize"],
grad_scale=getattr(self, "grad_scale", None),
found_inf=getattr(self, "found_inf", None),
has_complex=has_complex,
)
return loss
# Class docstring for Adafactor, assembled from three pieces so that the shared
# ``_maximize_doc`` snippet can be interpolated (middle f-string) while the
# LaTeX-heavy parts stay plain raw strings with literal braces.
Adafactor.__doc__ = (
    r"""Implements Adafactor algorithm.

    .. math::
        \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{(lr)}, \: \tau
                \text{(}\beta_2\text{ decay)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, \\
            &\hspace{15mm} \: \epsilon_1, \epsilon_2 \text{ (epsilons)}, \: d \text{(clipping threshold)}, \\
            &\hspace{15mm} \: \lambda \text{(weight decay)},
                \: \textit{maximize} \\
            &\textbf{initialize} : \: R_0 \leftarrow 0 \text{ (second moment row factor)}, \\
            &\hspace{23mm} \: C_0 \leftarrow 0 \text{ (second moment col factor)}, \\
            &\hspace{23mm} \: \widehat{V}_0 \leftarrow 0 \text{ (second moment for vectors)} \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
            &\hspace{10mm}G_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}G_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\widehat{\beta}_{2_t} \leftarrow 1 - t^{\tau} \\
            &\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}}) \\
            &\hspace{5mm}\alpha_t \leftarrow max(\epsilon_2,
                \text{RMS}(\theta_{t-1}))\rho_t \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
            &\hspace{5mm}\textbf{if} \: \text{dim}(G_t) > 1: \\
            &\hspace{10mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
                (1-\widehat{\beta}_{2_t})(G_t \odot G_t) \cdot 1_m \\
            &\hspace{10mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
                (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t) \\
            &\hspace{10mm}\widehat{V}_t \leftarrow
                \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)} \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}\widehat{V}_t \leftarrow \widehat{\beta}_{2_t}\widehat{V}_{t-1}+
                (1-\widehat{\beta}_{2_t}) \cdot (G_t \odot G_t) \\
            &\hspace{5mm}U_t \leftarrow
                \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\
            &\hspace{5mm}\widehat{U}_t \leftarrow \frac{U_t}{max(1, \frac{\text{RMS}(U_t)}{d})} \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \alpha_t \widehat{U}_t \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
        \end{aligned}

    For further details regarding the algorithm we refer to `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): unlike other optimizers, Adafactor does not require a
            learning rate, and Shazeer, Noam, and Mitchell Stern do not use lr at all.
            Deviating from the paper, this implementation uses lr for applying weight
            decay and as the maximum value for relative step size rho_t. Note that in
            the paper, a constant of 0.01 is used as the maximum value for relative
            step size, and so we set 0.01 as the default value. (default: 1e-2)
        beta2_decay (float, optional): the decay rate of beta2. beta2 standardly refers
            to the coefficient used for computing the running average of the gradient
            squared. (default: -0.8)
        eps (Tuple[float, float], optional): epsilon1 is the term added to the denominator
            of the update calculation to improve numerical stability. This use of epsilon1
            deviates from the algorithm written in the paper! See note below for more details.
            If epsilon1 is None, the machine epsilon of the parameter dtype
            (``torch.finfo(dtype).eps``) is used.
            epsilon2 is the term used to avoid having too small a weight update when applying
            parameter scaling. (default: (None, 1e-3))
        d (float, optional): the clipping threshold, used to avoid larger-than-desired
            updates. (default: 1.0)
        weight_decay (float, optional): weight decay coefficient (default: 0.0)
        foreach (bool, optional): whether foreach implementation of optimizer is used. Note
            that the foreach implementation uses ~ sizeof(params) more peak memory than the
            for-loop version due to the intermediates being a tensorlist vs just one tensor.
            As Adafactor is commonly used when memory is prohibitive, Adafactor will default
            to the slower single tensor for-loop implementation unless this flag is explicitly
            True. This behavior is contrary to other optimizers, which will attempt defaulting
            to foreach on CUDA for faster runtime. (default: None)
        {_maximize_doc}"""
    + r"""
    .. Note::
        The implementation of Adafactor subtly differs from Shazeer, Noam, and Mitchell Stern
        and implementations in some other frameworks with its use of learning rate and
        :math:`\epsilon_1`.

        Regarding the learning rate hyperparameter: Shazeer, Noam, and Mitchell Stern do not
        use lr at all, as the stated algorithm uses :math:`\rho_t` and update clipping to
        affect the step size.

        This implementation allows `lr` to influence the maximum value for :math:`\rho_t`:

        .. math::
            \begin{aligned}
                &\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}})
            \end{aligned}

        This differs from Shazeer, Noam, and Mitchell Stern, who use a constant of 0.01 as
        the maximum value of :math:`\rho_t`

        .. math::
            \begin{aligned}
                &\hspace{5mm}\rho_t \leftarrow min(0.01, \frac{1}{\sqrt{t}})
            \end{aligned}

        Shazeer, Noam, and Mitchell Stern do not enforce an opinion on how weight decay should
        be computed, and so we use the learning rate as a coefficient for decoupled weight
        decay, similar to what is suggested in `Decoupled Weight Decay Regularization`_.

        Regarding the use of :math:`\epsilon_1`: The implementation attempts to replicate the
        presumed intention of Shazeer, Noam, and Mitchell Stern to use :math:`\epsilon_1` as
        a stabilizing term when the squared gradient becomes small.

        This stabilization can be written as

        .. math::
            \begin{aligned}
                &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
                    (1-\widehat{\beta}_{2_t})(G_t \odot G_t) \cdot 1_m \\
                &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
                    (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t) \\
                &\hspace{5mm}\widehat{V}_t \leftarrow
                    \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)} \\
                &\hspace{5mm}U_t \leftarrow \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\
            \end{aligned}

        where the row and column factors of gradient squared :math:`R_t` and :math:`C_t`
        are left alone, and we apply :math:`\epsilon_1` at the final calculation of
        the variance estimate :math:`\widehat{V}_t` and for the update :math:`U_t`.

        This is in contrast to Shazeer, Noam, and Mitchell Stern and other frameworks which
        apply :math:`\epsilon_1` to both row and column factors of the squared gradient, but
        not in the calculations after:

        .. math::
            \begin{aligned}
                &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
                    (1-\widehat{\beta}_{2_t})(G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \cdot 1_m \\
                &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
                    (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \\
                &\hspace{5mm}\widehat{V}_t \leftarrow \frac{R_t \cdot C_t}{1^\top_n \cdot R_t} \\
                &\hspace{5mm}U_t \leftarrow \frac{G_t}{\sqrt{\widehat{V}_t}} \\
            \end{aligned}

    .. _Adafactor\: Adaptive Learning Rates with Sublinear Memory Cost:
        https://arxiv.org/pdf/1804.04235
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    """
)
def _single_tensor_adafactor(
    params: List[Tensor],
    grads: List[Tensor],
    # If grad is 1-dimensional (aka a vector), there is no factorization necessary
    # so row_var and col_var will be None while variance will be filled.
    # Contrarily, for a grad with multiple dimensions, we will factor along the last
    # 2 dimensions, and so row_var and col_var will be filled and variance will be None.
    row_vars: List[Optional[Tensor]],
    col_vars: List[Optional[Tensor]],
    variances: List[Optional[Tensor]],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    d: float,
    lr: Union[Tensor, float],
    beta2_decay: float,
    weight_decay: float,
    eps1: Optional[float],
    eps2: float,
    maximize: bool,
    has_complex: bool,
):
    """Apply one Adafactor update to each parameter, one tensor at a time.

    All tensors in ``params``, ``row_vars``/``col_vars``/``variances`` and
    ``state_steps`` are updated in place; nothing is returned.

    Args:
        params: parameters to update.
        grads: gradients, aligned with ``params``.
        row_vars, col_vars: factored second-moment state; required (non-None)
            for entries whose grad has more than one dimension.
        variances: unfactored second-moment state; required (non-None) for
            entries whose grad has at most one dimension.
        state_steps: per-parameter step counters, incremented by one here.
        grad_scale, found_inf: must both be ``None``.
        d: clipping threshold for the RMS of the update.
        lr: upper bound for the relative step size and weight-decay coefficient.
        beta2_decay: exponent ``tau`` in ``1 - beta2_t = step ** tau``.
        weight_decay: decoupled weight decay coefficient (``0`` disables it).
        eps1: stabilizer for the variance estimate; when ``None`` the machine
            epsilon of *each parameter's own dtype* is used.
        eps2: lower bound for the parameter RMS used in the step size.
        maximize: negate the gradients when True.
        has_complex: accepted for signature parity; not used here.
    """
    assert (
        grad_scale is None and found_inf is None
    ), "Grad scaling should occur outside of optimizer.step()"

    if torch.jit.is_scripting():
        # this assert is due to JIT being dumb and not realizing that the ops below
        # have overloads to handle both float and Tensor lrs, so we just assert it's
        # a float since most people using JIT are using floats
        assert isinstance(lr, float)

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        step_t = state_steps[i]
        row_var = row_vars[i]
        col_var = col_vars[i]
        variance = variances[i]

        # Resolve the default epsilon1 per parameter. Rebinding ``eps1`` itself
        # would make the first parameter's dtype epsilon stick for every later
        # parameter, even ones with a different dtype.
        if eps1 is None:
            param_eps1 = torch.finfo(param.dtype).eps
        else:
            param_eps1 = eps1

        # update step
        step_t += 1
        step_float = step_t.item()

        one_minus_beta2_t = step_float**beta2_decay
        rho_t = min(lr, 1 / (step_float**0.5))
        # Scale the relative step size by the RMS of the parameter (floored at eps2).
        alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t

        # Perform stepweight decay
        if weight_decay != 0:
            param.mul_(1 - lr * weight_decay)

        if grad.dim() > 1:
            assert (
                row_var is not None and col_var is not None
            ), "row_var and col_var should be defined when grad is multidimensional"
            # same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
            row_mean = (
                torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1))
            )
            row_var.lerp_(row_mean, one_minus_beta2_t)
            # same as (g * g).mean(dim=-2) w/o materializing an intermediate size g
            col_mean = (
                torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2))
            )
            col_var.lerp_(col_mean, one_minus_beta2_t)
            var_estimate = row_var @ col_var
            var_estimate.div_(
                row_var.mean(dim=-2, keepdim=True).clamp_(min=param_eps1)
            )
        else:
            assert (
                variance is not None
            ), "variance should be defined when grad is a vector"
            grad_squared = grad * grad
            variance.lerp_(grad_squared, one_minus_beta2_t)
            # avoid writing into variance during update
            var_estimate = variance.clone()

        # square the eps1 as we sqrt after to keep eps1's magnitude
        update = var_estimate.clamp_(min=param_eps1 * param_eps1).rsqrt_()
        update.mul_(grad)
        # Clip: shrink the update when its RMS exceeds the threshold d.
        denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d))
        param.add_(update, alpha=-alpha / denom)
def _group_tensors_by_device_dtype_and_is_multidim(
tensorlists: TensorListList,
) -> Dict[
Tuple[Optional[torch.device], Optional[torch.dtype], bool],
List[List[Optional[Tensor]]],
]:
"""Groups tensors by device, dtype, AND multidimensionality -- whether the tensor
has multiple dims or just one dim (is a vector). This allows the foreach impl of
Adafactor to assume that every group of params will either be factored or not."""
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(tensorlists)
ultra_grouped_tensors: Dict[
Tuple[Optional[torch.device], Optional[torch.dtype], bool],
List[List[Optional[Tensor]]],
] = {}
for (device, dtype), (tensorlists, _) in grouped_tensors.items():
matrix_key = (device, dtype, True)
vector_key = (device, dtype, False)
# assumes grad is the second tensorlist
for j, tensor in enumerate(tensorlists[1]):
assert tensor is not None, "grad should not be None"
if tensor.dim() > 1:
if matrix_key not in ultra_grouped_tensors:
ultra_grouped_tensors[matrix_key] = [[] for _ in tensorlists]
for i in range(len(tensorlists)):
ultra_grouped_tensors[matrix_key][i].append(tensorlists[i][j])
else:
if vector_key not in ultra_grouped_tensors:
ultra_grouped_tensors[vector_key] = [[] for _ in tensorlists]
for i in range(len(tensorlists)):
ultra_grouped_tensors[vector_key][i].append(tensorlists[i][j])
return ultra_grouped_tensors
def _multi_tensor_adafactor(
    params: List[Tensor],
    grads: List[Tensor],
    # If grad is 1-dimensional (aka a vector), there is no factorization necessary
    # so row_var and col_var will be None while variance will be filled.
    # Contrarily, for a grad with multiple dimensions, we will factor along the last
    # 2 dimensions, and so row_var and col_var will be filled and variance will be None.
    row_vars: List[Optional[Tensor]],
    col_vars: List[Optional[Tensor]],
    variances: List[Optional[Tensor]],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    d: float,
    lr: Union[Tensor, float],
    beta2_decay: float,
    weight_decay: float,
    eps1: Optional[float],
    eps2: float,
    maximize: bool,
    has_complex: bool,
):
    """Apply one Adafactor update using ``torch._foreach_*`` ops.

    Inputs are bucketed by (device, dtype, grad-is-multidimensional) and each
    bucket is updated with list-wise ops. Parameters and state tensors are
    modified in place; nothing is returned. An empty ``params`` is a no-op.

    Args:
        params: parameters to update.
        grads: gradients, aligned with ``params``.
        row_vars, col_vars: factored second-moment state for multidimensional grads.
        variances: unfactored second-moment state for vector grads.
        state_steps: per-parameter step counters, incremented by one here.
        grad_scale, found_inf: must both be ``None``.
        d: clipping threshold for the RMS of the update.
        lr: upper bound for the relative step size and weight-decay coefficient.
        beta2_decay: exponent ``tau`` in ``1 - beta2_t = step ** tau``.
        weight_decay: decoupled weight decay coefficient (``0`` disables it).
        eps1: stabilizer for the variance estimate; when ``None`` the machine
            epsilon of *each bucket's own dtype* is used.
        eps2: lower bound for the parameter RMS used in the step size.
        maximize: negate the gradients when True.
        has_complex: accepted for signature parity; not used here.
    """
    if len(params) == 0:
        return

    assert (
        grad_scale is None and found_inf is None
    ), "Grad scaling should occur outside of optimizer.step()"

    grouped_tensors = _group_tensors_by_device_dtype_and_is_multidim(
        [params, grads, row_vars, col_vars, variances, state_steps]  # type: ignore[list-item]
    )
    for (_, dtype, is_multidim), (
        (
            device_params_,
            device_grads_,
            device_row_vars_,
            device_col_vars_,
            device_variances_,
            device_state_steps_,
        )
    ) in grouped_tensors.items():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        # Resolve the default epsilon1 per bucket. Rebinding ``eps1`` itself
        # would make the first bucket's dtype epsilon stick for every later
        # bucket, even ones with a different dtype.
        if eps1 is None:
            assert (
                dtype is not None
            ), "dtype is needed to compute eps1 when eps1 is unset"
            group_eps1 = torch.finfo(dtype).eps
        else:
            group_eps1 = eps1

        if TYPE_CHECKING:
            assert device_state_steps[0] is not None

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1.0)

        one_minus_beta2_ts = []
        beta2_ts = []
        rho_ts = []
        for s in device_state_steps:
            # Read the step once; each .item() is a separate tensor-to-host read.
            step_float = s.item()
            one_minus_beta2_t = step_float**beta2_decay
            one_minus_beta2_ts.append(one_minus_beta2_t)
            beta2_ts.append(1 - one_minus_beta2_t)
            rho_ts.append(min(lr, 1 / (step_float**0.5)))

        # Scale each relative step size by the RMS of its parameter (floored at eps2).
        alphas = [
            max(eps2, p.norm(2).item() / (p.numel() ** 0.5)) * r
            for p, r in zip(device_params, rho_ts)
        ]

        # Perform stepweight decay
        if weight_decay != 0:
            torch._foreach_mul_(device_params, 1 - lr * weight_decay)

        if is_multidim:
            device_row_vars = cast(List[Tensor], device_row_vars_)
            device_col_vars = cast(List[Tensor], device_col_vars_)
            assert (
                device_row_vars[0] is not None and device_col_vars[0] is not None
            ), "row_var and col_var should be defined when grad is multidimensional"
            # same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
            row_means = [
                torch.norm(grad, dim=-1, keepdim=True) for grad in device_grads
            ]
            torch._foreach_mul_(row_means, row_means)
            torch._foreach_div_(row_means, [grad.size(-1) for grad in device_grads])
            # row_var = beta2_t * row_var + (1 - beta2_t) * row_mean
            torch._foreach_mul_(device_row_vars, beta2_ts)
            torch._foreach_mul_(row_means, one_minus_beta2_ts)
            torch._foreach_add_(device_row_vars, row_means)
            del row_means

            # same as (g * g).mean(dim=-2) w/o materializing an intermediate size g
            col_means = [
                torch.norm(grad, dim=-2, keepdim=True) for grad in device_grads
            ]
            torch._foreach_mul_(col_means, col_means)
            torch._foreach_div_(col_means, [grad.size(-2) for grad in device_grads])
            # col_var = beta2_t * col_var + (1 - beta2_t) * col_mean
            torch._foreach_mul_(device_col_vars, beta2_ts)
            torch._foreach_mul_(col_means, one_minus_beta2_ts)
            torch._foreach_add_(device_col_vars, col_means)
            del col_means

            var_estimates = [
                row_var @ col_var
                for row_var, col_var in zip(device_row_vars, device_col_vars)
            ]
            row_var_means = [
                row_var.mean(dim=-2, keepdim=True) for row_var in device_row_vars
            ]
            torch._foreach_clamp_min_(row_var_means, group_eps1)
            torch._foreach_div_(var_estimates, row_var_means)
            del row_var_means
        else:
            device_variances = cast(List[Tensor], device_variances_)
            assert (
                device_variances[0] is not None
            ), "variance should be defined when grad is a vector"

            grads_squared = torch._foreach_mul(device_grads, device_grads)
            # variance = beta2_t * variance + (1 - beta2_t) * grad^2
            torch._foreach_mul_(device_variances, beta2_ts)
            torch._foreach_mul_(grads_squared, one_minus_beta2_ts)
            torch._foreach_add_(device_variances, grads_squared)
            del grads_squared

            # avoid writing into variance during update
            var_estimates = [v.clone() for v in device_variances]

        # square the eps1 as we sqrt after to keep eps1's magnitude
        torch._foreach_clamp_min_(var_estimates, group_eps1 * group_eps1)
        torch._foreach_sqrt_(var_estimates)
        torch._foreach_reciprocal_(var_estimates)
        torch._foreach_mul_(var_estimates, device_grads)
        updates = var_estimates

        # Fold the clipping factor and the sign of the step into one scalar per tensor.
        alphas = [
            -a / (max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d)))
            for a, update in zip(alphas, updates)
        ]
        torch._foreach_mul_(updates, alphas)
        torch._foreach_add_(device_params, updates)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adafactor)
def adafactor(
    params: List[Tensor],
    grads: List[Tensor],
    row_vars: List[Optional[Tensor]],
    col_vars: List[Optional[Tensor]],
    variances: List[Optional[Tensor]],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    d: float,
    lr: Union[float, Tensor],
    beta2_decay: float,
    weight_decay: float,
    eps1: float,
    eps2: float,
    maximize: bool,
):
    r"""Functional API that performs Adafactor algorithm computation.

    Dispatches to the foreach implementation when ``foreach`` is truthy and to
    the single-tensor for-loop implementation otherwise (including ``None``).

    See :class:`~torch.optim.Adafactor` for details.
    """
    # Outside of compilation, insist that every step counter is a tensor.
    if not torch._utils.is_compiling():
        if any(not isinstance(t, torch.Tensor) for t in state_steps):
            raise RuntimeError(
                "`state_steps` argument must contain a list of singleton tensors"
            )

    impl = _multi_tensor_adafactor if foreach else _single_tensor_adafactor

    impl(
        params,
        grads,
        row_vars,
        col_vars,
        variances,
        state_steps,
        d=d,
        lr=lr,
        beta2_decay=beta2_decay,
        weight_decay=weight_decay,
        eps1=eps1,
        eps2=eps2,
        maximize=maximize,
        grad_scale=grad_scale,
        found_inf=found_inf,
        has_complex=has_complex,
    )

View File

@ -0,0 +1,84 @@
# mypy: allow-untyped-defs
r"""Functional interface."""
import math
from typing import List
from torch import Tensor
from .adadelta import adadelta # type: ignore[attr-defined] # noqa: F401
from .adagrad import _make_sparse, adagrad # type: ignore[attr-defined] # noqa: F401
from .adam import adam # type: ignore[attr-defined] # noqa: F401
from .adamax import adamax # type: ignore[attr-defined] # noqa: F401
from .adamw import adamw # type: ignore[attr-defined] # noqa: F401
from .asgd import asgd # type: ignore[attr-defined] # noqa: F401
from .nadam import nadam # type: ignore[attr-defined] # noqa: F401
from .radam import radam # type: ignore[attr-defined] # noqa: F401
from .rmsprop import rmsprop # type: ignore[attr-defined] # noqa: F401
from .rprop import rprop # type: ignore[attr-defined] # noqa: F401
from .sgd import sgd # type: ignore[attr-defined] # noqa: F401
# TODO: use foreach API in optim._functional to do all the computation
def sparse_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[int],
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    maximize: bool,
):
    r"""Functional API that performs Sparse Adam algorithm computation.

    Only the entries addressed by each (coalesced) sparse gradient are touched
    in ``params``, ``exp_avgs`` and ``exp_avg_sqs``; all three are updated in
    place. Parameters whose gradient holds no values are skipped.

    See :class:`~torch.optim.SparseAdam` for details.
    """
    for i, param in enumerate(params):
        grad = -grads[i] if maximize else grads[i]
        grad = grad.coalesce()  # the update is non-linear so indices must be unique
        indices = grad._indices()
        values = grad._values()
        if values.numel() == 0:
            # Skip update for empty grad
            continue
        shape = grad.size()

        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        def make_sparse(vals):
            # Wrap ``vals`` in a sparse tensor laid out like ``grad``; fall back
            # to an empty tensor of grad's size when there is nothing to index.
            if indices.dim() != 0 and vals.dim() != 0:
                return grad.new(indices, vals, shape)
            return grad.new().resize_as_(grad)

        # Decay the first and second moment running average coefficient
        #      old <- b * old + (1 - b) * new
        # <==> old += (1 - b) * (new - old)
        prev_avg = exp_avg.sparse_mask(grad)._values()
        avg_delta = values.sub(prev_avg).mul_(1 - beta1)
        exp_avg.add_(make_sparse(avg_delta))

        prev_avg_sq = exp_avg_sq.sparse_mask(grad)._values()
        avg_sq_delta = values.pow(2).sub_(prev_avg_sq).mul_(1 - beta2)
        exp_avg_sq.add_(make_sparse(avg_sq_delta))

        # Dense addition again is intended, avoiding another sparse_mask
        numer = avg_delta.add_(prev_avg)
        avg_sq_delta.add_(prev_avg_sq)
        denom = avg_sq_delta.sqrt_().add_(eps)
        del avg_delta, avg_sq_delta

        correction1 = 1 - beta1**step
        correction2 = 1 - beta2**step
        scaled_lr = lr * math.sqrt(correction2) / correction1

        param.add_(make_sparse(-scaled_lr * numer.div_(denom)))

View File

@ -0,0 +1,30 @@
"""
:mod:`torch.optim._multi_tensor` is a package implementing various optimization algorithms.
Most commonly used methods are already supported, and the interface is general
enough, so that more sophisticated ones can be also easily integrated in the
future.
"""
from functools import partialmethod
from torch import optim
def partialclass(cls, *args, **kwargs):
    """Return a subclass of ``cls`` whose ``__init__`` has the given arguments pre-bound.

    The returned class is named ``NewCls``; extra positional and keyword
    arguments supplied at instantiation are forwarded after/alongside the
    pre-bound ones, as with :func:`functools.partialmethod`.
    """
    prebound_init = partialmethod(cls.__init__, *args, **kwargs)

    class NewCls(cls):
        __init__ = prebound_init

    return NewCls
# Subclasses of the ``torch.optim`` optimizers with ``foreach=True`` pre-bound
# in ``__init__`` (listed alphabetically).
ASGD = partialclass(optim.ASGD, foreach=True)
Adadelta = partialclass(optim.Adadelta, foreach=True)
Adagrad = partialclass(optim.Adagrad, foreach=True)
Adam = partialclass(optim.Adam, foreach=True)
AdamW = partialclass(optim.AdamW, foreach=True)
Adamax = partialclass(optim.Adamax, foreach=True)
NAdam = partialclass(optim.NAdam, foreach=True)
RAdam = partialclass(optim.RAdam, foreach=True)
RMSprop = partialclass(optim.RMSprop, foreach=True)
Rprop = partialclass(optim.Rprop, foreach=True)
SGD = partialclass(optim.SGD, foreach=True)

View File

@ -0,0 +1,15 @@
from functools import partial
from torch import optim
# ``functools.partial`` wrappers around the ``torch.optim`` optimizers with
# ``foreach=True`` pre-bound (listed alphabetically).
ASGD = partial(optim.ASGD, foreach=True)
Adadelta = partial(optim.Adadelta, foreach=True)
Adagrad = partial(optim.Adagrad, foreach=True)
Adam = partial(optim.Adam, foreach=True)
AdamW = partial(optim.AdamW, foreach=True)
Adamax = partial(optim.Adamax, foreach=True)
NAdam = partial(optim.NAdam, foreach=True)
RAdam = partial(optim.RAdam, foreach=True)
RMSprop = partial(optim.RMSprop, foreach=True)
Rprop = partial(optim.Rprop, foreach=True)
SGD = partial(optim.SGD, foreach=True)

View File

@ -0,0 +1,461 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import Any, cast, Dict, List, Optional, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
_get_capturable_supported_devices,
_get_scalar_dtype,
_maximize_doc,
_use_grad_for_differentiable,
_view_as_real,
Optimizer,
ParamsT,
)
__all__ = ["Adadelta", "adadelta"]
class Adadelta(Optimizer):
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1.0,
rho: float = 0.9,
eps: float = 1e-6,
weight_decay: float = 0,
foreach: Optional[bool] = None,
*,
capturable: bool = False,
maximize: bool = False,
differentiable: bool = False,
):
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= rho <= 1.0:
raise ValueError(f"Invalid rho value: {rho}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
rho=rho,
eps=eps,
weight_decay=weight_decay,
maximize=maximize,
capturable=capturable,
foreach=foreach,
differentiable=differentiable,
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("foreach", None)
group.setdefault("maximize", False)
group.setdefault("differentiable", False)
group.setdefault("capturable", False)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
step_val = float(p_state["step"])
p_state["step"] = (
torch.tensor(
step_val, dtype=_get_scalar_dtype(), device=p.device
)
if group["capturable"]
else torch.tensor(step_val, dtype=_get_scalar_dtype())
)
def _init_group(
self,
group: Dict[str, Any],
params_with_grad: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
acc_deltas: List[Tensor],
state_steps: List[Tensor],
):
has_complex = False
p: Tensor
for p in group["params"]:
if p.grad is None:
continue
has_complex |= torch.is_complex(p)
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError("Adadelta does not support sparse gradients")
grads.append(p.grad)
state = self.state[p]
# Lazy state initialization
if len(state) == 0:
state["step"] = (
torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
if group["capturable"]
else torch.zeros((), dtype=_get_scalar_dtype())
)
state["square_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
state["acc_delta"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
square_avgs.append(state["square_avg"])
acc_deltas.append(state["acc_delta"])
state_steps.append(state["step"])
return has_complex
@_use_grad_for_differentiable
def step(self, closure=None):
    """Run one Adadelta update over every param group.

    Args:
        closure (Callable, optional): A closure that reevaluates the model
            and returns the loss.

    Returns:
        The value returned by ``closure``, or ``None`` when no closure is given.
    """
    self._cuda_graph_capture_health_check()

    loss = None
    if closure is not None:
        # The closure has to build a graph, so gradients are enabled around it.
        with torch.enable_grad():
            loss = closure()

    hyperparameter_names = (
        "lr",
        "rho",
        "eps",
        "weight_decay",
        "foreach",
        "maximize",
        "differentiable",
        "capturable",
    )
    for group in self.param_groups:
        params_with_grad: List[Tensor] = []
        grads: List[Tensor] = []
        square_avgs: List[Tensor] = []
        acc_deltas: List[Tensor] = []
        state_steps: List[Tensor] = []

        # Read the group's settings before collecting its tensors.
        hyperparameters = {name: group[name] for name in hyperparameter_names}

        has_complex = self._init_group(
            group, params_with_grad, grads, square_avgs, acc_deltas, state_steps
        )

        adadelta(
            params_with_grad,
            grads,
            square_avgs,
            acc_deltas,
            state_steps,
            has_complex=has_complex,
            **hyperparameters,
        )

    return loss
# The class docstring is attached after the class definition because it is built
# by concatenation: a plain raw string holding the LaTeX algorithm listing (so its
# braces need no escaping), followed by a raw f-string that interpolates the
# option snippets ({_foreach_doc}, {_capturable_doc}, {_maximize_doc},
# {_differentiable_doc}).
Adadelta.__doc__ = (
r"""Implements Adadelta algorithm.
.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)},
\: f(\theta) \text{ (objective)}, \: \rho \text{ (decay)},
\: \lambda \text{ (weight decay)} \\
&\textbf{initialize} : v_0 \leftarrow 0 \: \text{ (square avg)},
\: u_0 \leftarrow 0 \: \text{ (accumulate variables)} \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}if \: \lambda \neq 0 \\
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
&\hspace{5mm} v_t \leftarrow v_{t-1} \rho + g^2_t (1 - \rho) \\
&\hspace{5mm}\Delta x_t \leftarrow \frac{\sqrt{u_{t-1} +
\epsilon }}{ \sqrt{v_t + \epsilon} }g_t \hspace{21mm} \\
&\hspace{5mm} u_t \leftarrow u_{t-1} \rho +
\Delta x^2_t (1 - \rho) \\
&\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \Delta x_t \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}
For further details regarding the algorithm we refer to `ADADELTA: An Adaptive Learning Rate Method`_.
"""
+ rf"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
rho (float, optional): coefficient used for computing a running average
of squared gradients (default: 0.9). A higher value of `rho` will
result in a slower average, which can be helpful for preventing
oscillations in the learning process.
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-6).
lr (float, Tensor, optional): coefficient that scale delta before it is applied
to the parameters (default: 1.0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
{_foreach_doc}
{_capturable_doc}
{_maximize_doc}
{_differentiable_doc}
.. _ADADELTA\: An Adaptive Learning Rate Method:
https://arxiv.org/abs/1212.5701
"""
)
def _single_tensor_adadelta(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    acc_deltas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lr: float,
    rho: float,
    eps: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    """Apply one Adadelta update, one parameter at a time.

    ``params``, ``square_avgs``, ``acc_deltas`` and ``state_steps`` are all
    modified in place. ``has_complex`` is accepted for signature parity with
    the multi-tensor variant; complexity is re-checked per parameter here.
    """
    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if capturable and not torch._utils.is_compiling():
        supported_devices = _get_capturable_supported_devices(supports_xla=False)
        assert all(
            p.device.type == step.device.type and p.device.type in supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {supported_devices}."

    one_minus_rho = 1 - rho
    for param, grad, sq_avg, delta_acc, step_t in zip(
        params, grads, square_avgs, acc_deltas, state_steps
    ):
        step_t += 1

        if maximize:
            grad = -grad
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        param_is_complex = torch.is_complex(param)
        if param_is_complex:
            # Work on real views so the element-wise math below treats the
            # real and imaginary parts independently.
            sq_avg = torch.view_as_real(sq_avg)
            delta_acc = torch.view_as_real(delta_acc)
            grad = torch.view_as_real(grad)

        # Running average of squared gradients.
        sq_avg.mul_(rho)
        sq_avg.addcmul_(grad, grad, value=one_minus_rho)

        denom = sq_avg.add(eps).sqrt_()
        update = delta_acc.add(eps).sqrt_()
        if differentiable:
            update = update.clone()
        update.div_(denom).mul_(grad)

        # Running average of squared updates.
        delta_acc.mul_(rho)
        delta_acc.addcmul_(update, update, value=one_minus_rho)

        if param_is_complex:
            update = torch.view_as_complex(update)
        param.add_(update, alpha=-lr)
def _multi_tensor_adadelta(
params: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
acc_deltas: List[Tensor],
state_steps: List[Tensor],
*,
lr: float,
rho: float,
eps: float,
weight_decay: float,
maximize: bool,
differentiable: bool,
capturable: bool,
has_complex: bool,
):
assert not differentiable, "_foreach ops don't support autograd"
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices(
supports_xla=False
)
assert all(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
if len(params) == 0:
return
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, square_avgs, acc_deltas, state_steps] # type: ignore[list-item]
)
for (
device_params_,
device_grads_,
device_square_avgs_,
device_acc_deltas_,
device_state_steps_,
), _ in grouped_tensors.values():
device_params = cast(List[Tensor], device_params_)
device_grads = cast(List[Tensor], device_grads_)
device_square_avgs = cast(List[Tensor], device_square_avgs_)
device_acc_deltas = cast(List[Tensor], device_acc_deltas_)
device_state_steps = cast(List[Tensor], device_state_steps_)
if has_complex:
_view_as_real(
device_params, device_grads, device_square_avgs, device_acc_deltas
)
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else:
torch._foreach_add_(device_state_steps, 1)
if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
if weight_decay != 0:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
torch._foreach_mul_(device_square_avgs, rho)
torch._foreach_addcmul_(
device_square_avgs, device_grads, device_grads, value=1 - rho
)
std = torch._foreach_add(device_square_avgs, eps)
torch._foreach_sqrt_(std)
deltas = torch._foreach_add(device_acc_deltas, eps)
torch._foreach_sqrt_(deltas)
torch._foreach_div_(deltas, std)
torch._foreach_mul_(deltas, device_grads)
torch._foreach_mul_(device_acc_deltas, rho)
torch._foreach_addcmul_(device_acc_deltas, deltas, deltas, value=1 - rho)
# If LR is a tensor, the else branch will internally call item()
# which will cause silent incorrectness if we are capturing
if capturable and isinstance(lr, torch.Tensor):
torch._foreach_mul_(deltas, -lr)
torch._foreach_add_(device_params, deltas)
else:
torch._foreach_add_(device_params, deltas, alpha=-lr)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adadelta)
def adadelta(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    acc_deltas: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    capturable: bool = False,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    rho: float,
    eps: float,
    weight_decay: float,
    maximize: bool,
):
    r"""Functional API that performs Adadelta algorithm computation.

    Picks the multi-tensor (foreach) or the single-tensor implementation and
    forwards all tensors and hyperparameters to it.

    See :class:`~torch.optim.Adadelta` for details.
    """
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling():
        for step_t in state_steps:
            if not isinstance(step_t, torch.Tensor):
                raise RuntimeError(
                    "API has changed, `state_steps` argument must contain a list of singleton tensors"
                )

    # We still respect when the user inputs False for foreach.
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    # Default to the per-tensor loop; switch to foreach only outside scripting.
    update_fn = _single_tensor_adadelta
    if foreach and not torch.jit.is_scripting():
        update_fn = _multi_tensor_adadelta

    update_fn(
        params,
        grads,
        square_avgs,
        acc_deltas,
        state_steps,
        lr=lr,
        rho=rho,
        eps=eps,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )

View File

@ -0,0 +1,564 @@
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Union
import torch
from torch import Tensor
from .optimizer import (
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_differentiable_doc,
_foreach_doc,
_get_scalar_dtype,
_get_value,
_maximize_doc,
_use_grad_for_differentiable,
_view_as_real,
Optimizer,
ParamsT,
)
__all__ = ["Adagrad", "adagrad"]
class Adagrad(Optimizer):
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-2,
lr_decay: float = 0,
weight_decay: float = 0,
initial_accumulator_value: float = 0,
eps: float = 1e-10,
foreach: Optional[bool] = None,
*,
maximize: bool = False,
differentiable: bool = False,
fused: Optional[bool] = None,
):
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= lr_decay:
raise ValueError(f"Invalid lr_decay value: {lr_decay}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= initial_accumulator_value:
raise ValueError(
f"Invalid initial_accumulator_value value: {initial_accumulator_value}"
)
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
defaults = dict(
lr=lr,
lr_decay=lr_decay,
eps=eps,
weight_decay=weight_decay,
initial_accumulator_value=initial_accumulator_value,
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
fused=fused,
)
super().__init__(params, defaults)
if fused:
if differentiable:
raise RuntimeError("`fused` does not support `differentiable`")
if foreach:
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
self._need_device_dtype_check_for_fused = True
for group in self.param_groups:
for p in group["params"]:
state = self.state[p]
state["step"] = (
torch.zeros(
(),
dtype=_get_scalar_dtype(is_fused=group["fused"]),
device=p.device,
)
if group["fused"]
else torch.tensor(0.0, dtype=_get_scalar_dtype())
)
init_value = (
complex(initial_accumulator_value, initial_accumulator_value)
if torch.is_complex(p)
else initial_accumulator_value
)
state["sum"] = torch.full_like(
p, init_value, memory_format=torch.preserve_format
)
def __setstate__(self, state):
super().__setstate__(state)
# define "fused" for
# MYPY error: Name "fused" may be undefined
fused = None
for group in self.param_groups:
group.setdefault("foreach", None)
group.setdefault("maximize", False)
group.setdefault("differentiable", False)
fused = group.setdefault("fused", None)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
state_values[0]["step"]
)
if not step_is_tensor:
for s in state_values:
s["step"] = torch.tensor(
float(s["step"]), dtype=_get_scalar_dtype(is_fused=fused)
)
def share_memory(self):
for group in self.param_groups:
for p in group["params"]:
state = self.state[p]
state["sum"].share_memory_()
def _init_group(self, group, params_with_grad, grads, state_sums, state_steps):
has_sparse_grad, has_complex = False, False
for p in group["params"]:
if p.grad is not None:
if group["fused"] and getattr(
self,
"_need_device_dtype_check_for_fused",
True,
):
_device_dtype_check_for_fused(p, cuda_unsupported=True)
self._need_device_dtype_check_for_fused = False
has_sparse_grad |= p.grad.is_sparse
has_complex |= torch.is_complex(p)
params_with_grad.append(p)
grads.append(p.grad)
state = self.state[p]
state_sums.append(state["sum"])
state_steps.append(state["step"])
return has_sparse_grad, has_complex
@_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
state_sums: List[Tensor] = []
state_steps: List[Tensor] = []
has_sparse_grad, has_complex = self._init_group(
group, params_with_grad, grads, state_sums, state_steps
)
adagrad(
params_with_grad,
grads,
state_sums,
state_steps,
lr=group["lr"],
weight_decay=group["weight_decay"],
lr_decay=group["lr_decay"],
eps=group["eps"],
has_sparse_grad=has_sparse_grad,
foreach=group["foreach"],
maximize=group["maximize"],
differentiable=group["differentiable"],
has_complex=has_complex,
fused=group["fused"],
grad_scale=getattr(self, "grad_scale", None),
found_inf=getattr(self, "found_inf", None),
)
return loss
# The class docstring is attached after the class definition because it is built
# by concatenation: a plain raw string holding the LaTeX algorithm listing (so its
# braces need no escaping), followed by a raw f-string that interpolates the
# option snippets ({_foreach_doc}, {_maximize_doc}, {_differentiable_doc})
# imported from `.optimizer`.
Adagrad.__doc__ = (
r"""Implements Adagrad algorithm.
.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
\text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
&\hspace{12mm} \tau \text{ (initial accumulator value)}, \: \eta\text{ (lr decay)}\\
&\textbf{initialize} : state\_sum_0 \leftarrow \tau \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm} \tilde{\gamma} \leftarrow \gamma / (1 +(t-1) \eta) \\
&\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
&\hspace{5mm}state\_sum_t \leftarrow state\_sum_{t-1} + g^2_t \\
&\hspace{5mm}\theta_t \leftarrow
\theta_{t-1}- \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t}+\epsilon} \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}
For further details regarding the algorithm we refer to `Adaptive Subgradient Methods for Online Learning
and Stochastic Optimization`_.
"""
+ rf"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, Tensor, optional): learning rate (default: 1e-2)
lr_decay (float, optional): learning rate decay (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
initial_accumulator_value (float, optional): initial value of the
sum of squares of gradients (default: 0)
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-10)
{_foreach_doc}
{_maximize_doc}
{_differentiable_doc}
fused (bool, optional): whether the fused implementation (CPU only) is used.
Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
are supported. (default: None). Please note that the fused implementations does not
support sparse or complex gradients.
.. _Adaptive Subgradient Methods for Online Learning and Stochastic
Optimization: http://jmlr.org/papers/v12/duchi11a.html
"""
)
def adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting these as kwargs for now as functional API is compiled by torch/distributed/optim
    has_sparse_grad: bool = False,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adagrad algorithm computation.

    Resolves the ``fused``/``foreach`` defaults, picks the matching
    implementation (fused, multi-tensor or single-tensor) and forwards all
    tensors and hyperparameters to it.

    See :class:`~torch.optim.Adagrad` for details.
    """
    for step_t in state_steps:
        if not isinstance(step_t, torch.Tensor):
            raise RuntimeError(
                "API has changed, `state_steps` argument must contain a list of singleton tensors"
            )

    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None:
        if foreach is None:
            _, foreach = _default_to_fused_or_foreach(
                params, differentiable, use_fused=False
            )
        fused = False
    if foreach is None:
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    # Fused takes precedence over foreach; the per-tensor loop is the fallback.
    update_fn = _single_tensor_adagrad
    if fused and not torch.jit.is_scripting():
        update_fn = _fused_adagrad
    elif foreach and not torch.jit.is_scripting():
        update_fn = _multi_tensor_adagrad

    update_fn(
        params,
        grads,
        state_sums,
        state_steps,
        lr=lr,
        weight_decay=weight_decay,
        lr_decay=lr_decay,
        eps=eps,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )
def _make_sparse(grad, grad_indices, values):
    """Build a sparse COO tensor shaped like ``grad`` from ``grad_indices`` and ``values``."""
    return torch.sparse_coo_tensor(grad_indices, values, grad.size())
def _single_tensor_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
    differentiable: bool,
    has_complex: bool,
):
    """Apply one Adagrad update, one parameter at a time.

    ``params``, ``state_sums`` and ``state_steps`` are modified in place.
    Dense and sparse gradients are both handled; sparse gradients cannot be
    combined with ``weight_decay``.

    Args:
        params: parameters to update.
        grads: gradient for each parameter.
        state_sums: running sum of squared gradients for each parameter.
        state_steps: step counter tensor for each parameter.
        grad_scale: must be ``None``; AMP scaling is not supported here.
        found_inf: must be ``None``; AMP scaling is not supported here.
        lr: base learning rate.
        weight_decay: L2 penalty coefficient added to the gradient.
        lr_decay: per-step learning-rate decay.
        eps: term added to the denominator.
        has_sparse_grad: accepted for signature parity; sparsity is checked
            per gradient.
        maximize: negate gradients before the update.
        differentiable: use the out-of-place ``+ eps`` so autograd can track it.
        has_complex: accepted for signature parity; complexity is checked per
            parameter.

    Raises:
        RuntimeError: if ``weight_decay`` is non-zero and a gradient is sparse.
    """
    assert grad_scale is None and found_inf is None

    for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps):
        # update step
        step_t += 1
        step = _get_value(step_t)
        grad = grad if not maximize else -grad

        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError(
                    "weight_decay option is not compatible with sparse gradients"
                )
            grad = grad.add(param, alpha=weight_decay)

        # Learning rate after decay for the current step.
        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            # Only the accumulator entries touched by this gradient are needed.
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(
                _make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr
            )
        else:
            if torch.is_complex(param):
                # Operate on real views: they alias the complex tensors, so the
                # in-place ops below update the originals directly. No view back
                # to complex is needed afterwards because these names are loop
                # locals that are never read again.
                grad = torch.view_as_real(grad)
                state_sum = torch.view_as_real(state_sum)
                param = torch.view_as_real(param)
            state_sum.addcmul_(grad, grad, value=1)
            if differentiable:
                std = state_sum.sqrt() + eps
            else:
                std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
def _multi_tensor_adagrad(
params: List[Tensor],
grads: List[Tensor],
state_sums: List[Tensor],
state_steps: List[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
lr: float,
weight_decay: float,
lr_decay: float,
eps: float,
has_sparse_grad: bool,
maximize: bool,
differentiable: bool,
has_complex: bool,
):
assert not differentiable, "_foreach ops don't support autograd"
assert grad_scale is None and found_inf is None
# Foreach functions will throw errors if given empty lists
if len(params) == 0:
return
grouped_tensorlists = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, state_sums, state_steps] # type: ignore[list-item]
)
for (
device_params_,
device_grads_,
device_state_sums_,
device_state_steps_,
), _ in grouped_tensorlists.values():
device_params = cast(List[Tensor], device_params_)
device_grads = cast(List[Tensor], device_grads_)
device_state_sums = cast(List[Tensor], device_state_sums_)
device_state_steps = cast(List[Tensor], device_state_steps_)
device_has_sparse_grad = has_sparse_grad and any(
grad.is_sparse for grad in device_grads
)
if device_has_sparse_grad:
_single_tensor_adagrad(
device_params,
device_grads,
device_state_sums,
device_state_steps,
lr=lr,
weight_decay=weight_decay,
lr_decay=lr_decay,
eps=eps,
has_sparse_grad=True,
maximize=maximize,
differentiable=differentiable,
has_complex=has_complex,
grad_scale=grad_scale,
found_inf=found_inf,
)
continue
# Handle complex parameters
if has_complex:
_view_as_real(device_params, device_grads, device_state_sums)
if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else:
torch._foreach_add_(device_state_steps, 1)
if weight_decay != 0:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
minus_clr = [
-lr / (1 + (_get_value(step) - 1) * lr_decay) for step in device_state_steps
]
torch._foreach_addcmul_(device_state_sums, device_grads, device_grads, value=1)
std = torch._foreach_sqrt(device_state_sums)
torch._foreach_add_(std, eps)
if weight_decay != 0 or maximize:
# Again, re-use the intermediate memory (device_grads) already allocated
torch._foreach_mul_(device_grads, minus_clr)
numerator = device_grads
else:
numerator = torch._foreach_mul(device_grads, minus_clr) # type: ignore[assignment]
torch._foreach_addcdiv_(device_params, numerator, std)
def _fused_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
    differentiable: bool,
    has_complex: bool,
) -> None:
    """Apply one Adagrad update through the fused ``torch._fused_adagrad_`` kernel.

    Tensors are bucketed by device and dtype and each bucket is handed to the
    fused kernel. ``params``, ``state_sums`` and ``state_steps`` are modified
    in place.

    Args:
        params: parameters to update.
        grads: gradient for each parameter.
        state_sums: running sum of squared gradients for each parameter.
        state_steps: step counter tensor for each parameter.
        grad_scale: optional gradient scale tensor forwarded to the kernel.
        found_inf: optional tensor forwarded to the kernel; it is also
            subtracted from every step counter after the kernel call, undoing
            the increment when its value is 1.
        lr, weight_decay, lr_decay, eps, maximize: hyperparameters forwarded
            to the kernel.
        has_sparse_grad: must be False.
        differentiable: must be False.
        has_complex: must be False.

    Raises:
        RuntimeError: if sparse gradients, complex parameters or
            ``differentiable=True`` are requested.
    """
    if not params:
        return
    if has_sparse_grad or has_complex:
        raise RuntimeError("`fused` does not support sparse grad or complex param")

    if differentiable:
        raise RuntimeError(
            "adagrad with fused=True does not support differentiable=True"
        )

    # Per-device caches so each scalar is moved to a given device at most once.
    grad_scale_dict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else None
    )
    found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, state_sums, state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            device_params_,
            device_grads_,
            device_state_sums_,
            device_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_state_sums = cast(List[Tensor], device_state_sums_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None and grad_scale_dict is not None:
            if device not in grad_scale_dict:
                grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)  # type: ignore[index]
            device_grad_scale = grad_scale_dict[device]  # type: ignore[index]
        if found_inf is not None and found_inf_dict is not None:
            # The cache is keyed by device, so membership must be tested with
            # the device (as in the grad_scale branch above). Testing with the
            # tensor itself never hits and re-issues the transfer every time.
            if device not in found_inf_dict:
                found_inf_dict[device] = found_inf.to(device, non_blocking=True)  # type: ignore[index]
            device_found_inf = found_inf_dict[device]  # type: ignore[index]

        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adagrad_(
            device_params,
            device_grads,
            device_state_sums,
            device_state_steps,
            lr=lr,
            lr_decay=lr_decay,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        if device_found_inf is not None:
            # Subtracting found_inf rolls the step increment back when it is 1.
            torch._foreach_sub_(
                device_state_steps, [device_found_inf] * len(device_state_steps)
            )

View File

@ -0,0 +1,803 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
_fused_doc,
_get_capturable_supported_devices,
_get_scalar_dtype,
_get_value,
_maximize_doc,
_stack_if_compiling,
_use_grad_for_differentiable,
_view_as_real,
DeviceDict,
Optimizer,
ParamsT,
)
__all__ = ["Adam", "adam"]
class Adam(Optimizer):
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0,
amsgrad: bool = False,
*,
foreach: Optional[bool] = None,
maximize: bool = False,
capturable: bool = False,
differentiable: bool = False,
fused: Optional[bool] = None,
):
if isinstance(lr, Tensor):
if foreach and not capturable:
raise ValueError(
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
if lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
maximize=maximize,
foreach=foreach,
capturable=capturable,
differentiable=differentiable,
fused=fused,
)
super().__init__(params, defaults)
if fused:
if differentiable:
raise RuntimeError("`fused` does not support `differentiable`")
self._step_supports_amp_scaling = True
# TODO(crcrpar): [low prec params & their higher prec copy]
# Support AMP with FP16/BF16 model params which would need
# higher prec copy of params to do update math in higher prec to
# alleviate the loss of information.
if foreach:
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("amsgrad", False)
group.setdefault("maximize", False)
group.setdefault("foreach", None)
group.setdefault("capturable", False)
group.setdefault("differentiable", False)
fused = group.setdefault("fused", None)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
step_val = float(p_state["step"])
p_state["step"] = (
torch.tensor(
step_val,
dtype=_get_scalar_dtype(is_fused=fused),
device=p.device,
)
if group["capturable"] or group["fused"]
else torch.tensor(step_val, dtype=_get_scalar_dtype())
)
def _init_group(
self,
group,
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
):
has_complex = False
for p in group["params"]:
if p.grad is not None:
has_complex |= torch.is_complex(p)
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError(
"Adam does not support sparse gradients, please consider SparseAdam instead"
)
grads.append(p.grad)
state = self.state[p]
# Lazy state initialization
if len(state) == 0:
if group["fused"]:
_device_dtype_check_for_fused(p)
# note(crcrpar): [special device hosting for step]
# Deliberately host `step` on CPU if both capturable and fused are off.
# This is because kernel launches are costly on CUDA and XLA.
state["step"] = (
torch.zeros(
(),
dtype=_get_scalar_dtype(is_fused=group["fused"]),
device=p.device,
)
if group["capturable"] or group["fused"]
else torch.tensor(0.0, dtype=_get_scalar_dtype())
)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if group["amsgrad"]:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
if group["amsgrad"]:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
if group["differentiable"] and state["step"].requires_grad:
raise RuntimeError(
"`requires_grad` is not supported for `step` in differentiable mode"
)
# Foreach without capturable does not support a tensor lr
if (
group["foreach"]
and torch.is_tensor(group["lr"])
and not group["capturable"]
):
raise RuntimeError(
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
state_steps.append(state["step"])
return has_complex
@_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
self._cuda_graph_capture_health_check()
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
max_exp_avg_sqs: List[Tensor] = []
state_steps: List[Tensor] = []
beta1, beta2 = group["betas"]
has_complex = self._init_group(
group,
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
)
adam(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=group["amsgrad"],
has_complex=has_complex,
beta1=beta1,
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
eps=group["eps"],
maximize=group["maximize"],
foreach=group["foreach"],
capturable=group["capturable"],
differentiable=group["differentiable"],
fused=group["fused"],
grad_scale=getattr(self, "grad_scale", None),
found_inf=getattr(self, "found_inf", None),
)
return loss
# The class docstring is attached after the class definition because it is built
# by concatenation: a plain raw string holding the LaTeX algorithm listing (so its
# braces need no escaping), followed by a raw f-string that interpolates the
# option snippets ({_foreach_doc}, {_maximize_doc}, {_capturable_doc},
# {_differentiable_doc}, {_fused_doc}) imported from `.optimizer`.
Adam.__doc__ = (
r"""Implements Adam algorithm.
.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
\text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
&\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
\:\textit{maximize} \\
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
&\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
&\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
&\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
&\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
&\hspace{5mm}\textbf{if} \: amsgrad \\
&\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
\widehat{v_t}) \\
&\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
\big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}
For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
"""
+ rf"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
is not yet supported for all our implementations. Please use a float
LR if you are not also specifying fused=True or capturable=True.
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (bool, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
{_foreach_doc}
{_maximize_doc}
{_capturable_doc}
{_differentiable_doc}
{_fused_doc}
.. Note::
A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
)
def _single_tensor_adam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
amsgrad: bool,
has_complex: bool,
beta1: float,
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
eps: float,
maximize: bool,
capturable: bool,
differentiable: bool,
):
assert grad_scale is None and found_inf is None
if torch.jit.is_scripting():
# this assert is due to JIT being dumb and not realizing that the ops below
# have overloads to handle both float and Tensor lrs, so we just assert it's
# a float since most people using JIT are using floats
assert isinstance(lr, float)
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step_t = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
# update step
step_t += 1
if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)
if torch.is_complex(param):
grad = torch.view_as_real(grad)
exp_avg = torch.view_as_real(exp_avg)
exp_avg_sq = torch.view_as_real(exp_avg_sq)
if amsgrad:
max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
param = torch.view_as_real(param)
# Decay the first and second moment running average coefficient
exp_avg.lerp_(grad, 1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
if capturable or differentiable:
step = step_t
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
step_size = lr / bias_correction1
step_size_neg = step_size.neg()
bias_correction2_sqrt = bias_correction2.sqrt()
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
if differentiable:
max_exp_avg_sq = max_exp_avg_sqs[i].clone()
else:
max_exp_avg_sq = max_exp_avg_sqs[i]
max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))
# Uses the max. for normalizing running avg. of gradient
# Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
# (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
denom = (
max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
).add_(eps / step_size_neg)
else:
denom = (
exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
).add_(eps / step_size_neg)
param.addcdiv_(exp_avg, denom)
else:
step = _get_value(step_t)
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
step_size = lr / bias_correction1
bias_correction2_sqrt = bias_correction2**0.5
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
# Use the max. for normalizing running avg. of gradient
denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
else:
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
param.addcdiv_(exp_avg, denom, value=-step_size)
# Lastly, switch back to complex view
if amsgrad and torch.is_complex(params[i]):
max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])
def _multi_tensor_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    has_complex: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    """Apply one Adam update using ``torch._foreach_*`` ops on grouped tensors.

    The input lists are grouped by device and dtype and each group is updated
    with a sequence of in-place foreach ops. Returns immediately when
    ``params`` is empty.

    Raises:
        RuntimeError: if ``lr`` is a Tensor while ``capturable`` is False.

    ``grad_scale`` and ``found_inf`` must be ``None`` and ``differentiable``
    must be False (both enforced with asserts).
    """
    if len(params) == 0:
        return

    if isinstance(lr, Tensor) and not capturable:
        raise RuntimeError(
            "lr as a Tensor is not supported for capturable=False and foreach=True"
        )

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    assert grad_scale is None and found_inf is None

    assert not differentiable, "_foreach ops don't support autograd"

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (
        device_params_,
        device_grads_,
        device_exp_avgs_,
        device_exp_avg_sqs_,
        device_max_exp_avg_sqs_,
        device_state_steps_,
    ), _ in grouped_tensors.values():
        # The casts only narrow the static types; they do nothing at runtime.
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        # Handle complex parameters
        if has_complex:
            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                _view_as_real(
                    device_params,
                    device_grads,
                    device_exp_avgs,
                    device_exp_avg_sqs,
                    device_max_exp_avg_sqs,
                )
            else:
                _view_as_real(
                    device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
                )

        if maximize:
            # Out-of-place negation: device_grads now refers to fresh tensors.
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                # No intermediate exists yet, so add out of place rather than
                # writing into the gradient tensors that were passed in.
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            device_exp_avg_sqs, device_grads, device_grads, 1 - beta2
        )

        # Delete the local intermediate since it won't be used anymore to save on peak memory
        del device_grads

        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]

        if capturable:
            # Start from beta^t; the in-place ops below turn these buffers into
            # the step size and sqrt(bias_correction2) respectively.
            bias_correction1 = torch._foreach_pow(beta1, device_state_steps)
            bias_correction2 = torch._foreach_pow(beta2, device_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction1, 1)
            torch._foreach_sub_(bias_correction2, 1)
            # we do not negate bias_correction1 as it'll need to be negated later anyway
            torch._foreach_neg_(bias_correction2)

            # foreach_div doesn't allow a scalar as the first arg
            torch._foreach_div_(bias_correction1, lr)
            torch._foreach_reciprocal_(bias_correction1)

            torch._foreach_sqrt_(bias_correction2)

            # Re-assign for clarity as we maintain minimal intermediates: we'll have
            # step_size = - lr / (1 - beta1 ^ t) where t = num_steps
            # bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
            step_size = bias_correction1
            bias_correction2_sqrt = bias_correction2

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)  # type: ignore[assignment]

                # Set intermediate to the max. for normalizing running avg. of gradient when amsgrad
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_div_(exp_avg_sq_sqrt, step_size)

            # at this point, exp_avg_sq_sqrt = - (1 - beta1^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
            # (the amsgrad branch uses max_exp_avg_sq in place of exp_avg_sq)
            torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
        else:
            # Non-capturable path: bias corrections are built with Python
            # arithmetic on the values returned by _get_value.
            bias_correction1 = [
                1 - beta1 ** _get_value(step) for step in device_state_steps
            ]
            bias_correction2 = [
                1 - beta2 ** _get_value(step) for step in device_state_steps
            ]

            # The step size is negated here so addcdiv_ below subtracts the update.
            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

            bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2]  # type: ignore[arg-type]

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_addcdiv_(
                device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size  # type: ignore[arg-type]
            )
def _fused_adam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
amsgrad: bool,
has_complex: bool, # Needed for consistency.
beta1: float,
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
eps: float,
maximize: bool,
capturable: bool, # Needed for consistency.
differentiable: bool,
) -> None:
if not params:
return
if differentiable:
raise RuntimeError("Adam with fused=True does not support differentiable=True")
grad_scale_dict: DeviceDict = (
{grad_scale.device: grad_scale} if grad_scale is not None else {}
)
found_inf_dict: DeviceDict = (
{found_inf.device: found_inf} if found_inf is not None else {}
)
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
# treating it as a scalar.
lr_dict: Optional[DeviceDict] = (
{lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
)
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item]
)
for (device, _), (
(
device_params_,
device_grads_,
device_exp_avgs_,
device_exp_avg_sqs_,
device_max_exp_avg_sqs,
device_state_steps_,
),
_,
) in grouped_tensors.items():
device_params = cast(List[Tensor], device_params_)
device_grads = cast(List[Tensor], device_grads_)
device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
device_state_steps = cast(List[Tensor], device_state_steps_)
if device.type == "mps": # type: ignore[union-attr]
assert found_inf is None and grad_scale is None
device_grad_scale, device_found_inf = None, None
if grad_scale is not None:
device_grad_scale = grad_scale_dict.setdefault(
device, grad_scale.to(device, non_blocking=True)
)
if found_inf is not None:
device_found_inf = found_inf_dict.setdefault(
device, found_inf.to(device, non_blocking=True)
)
if lr_dict is not None and device not in lr_dict:
lr_dict[device] = lr.to(device=device, non_blocking=True) # type: ignore[union-attr]
lr = lr_dict[device]
torch._foreach_add_(device_state_steps, 1)
torch._fused_adam_(
device_params,
device_grads,
device_exp_avgs,
device_exp_avg_sqs,
device_max_exp_avg_sqs, # type: ignore[arg-type]
device_state_steps,
amsgrad=amsgrad,
lr=lr, # type: ignore[arg-type]
beta1=beta1,
beta2=beta2,
weight_decay=weight_decay,
eps=eps,
maximize=maximize,
grad_scale=device_grad_scale,
found_inf=device_found_inf,
)
if device_found_inf is not None:
torch._foreach_sub_(
device_state_steps, [device_found_inf] * len(device_state_steps)
)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam)
def adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adam algorithm computation.

    Picks the fused, foreach or single-tensor implementation from the
    ``fused``/``foreach`` flags and forwards every argument to it.

    See :class:`~torch.optim.Adam` for details.
    """
    # An explicit True/False for either flag is honoured as given. Only when
    # both are unset do we pick a default, and that default is foreach (never
    # fused: use_fused=False is intentional, to give the fused implementation
    # bake-in time even though it is typically faster).
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # A Tensor lr with capturable=False is unsupported by foreach, so do
        # not switch it on by default in that case.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False

    # Whatever is still unset is treated as "off".
    if foreach is None:
        foreach = False
    if fused is None:
        fused = False

    # This validation is slow during compilation, so it is skipped there; it
    # can be re-added in dynamo if it turns out to be strictly needed.
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if torch.jit.is_scripting():
        if foreach:
            raise RuntimeError("torch.jit.script not supported with foreach optimizers")
        if fused:
            raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if fused and not torch.jit.is_scripting():
        update_fn = _fused_adam
    elif foreach and not torch.jit.is_scripting():
        update_fn = _multi_tensor_adam
    else:
        update_fn = _single_tensor_adam

    update_fn(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        has_complex=has_complex,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )

View File

@ -0,0 +1,473 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
_get_capturable_supported_devices,
_get_scalar_dtype,
_get_value,
_maximize_doc,
_use_grad_for_differentiable,
_view_as_real,
Optimizer,
ParamsT,
)
__all__ = ["Adamax", "adamax"]
class Adamax(Optimizer):
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 2e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0,
foreach: Optional[bool] = None,
*,
maximize: bool = False,
differentiable: bool = False,
capturable: bool = False,
):
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("foreach", None)
group.setdefault("maximize", False)
group.setdefault("differentiable", False)
group.setdefault("capturable", False)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
step_val = float(p_state["step"])
p_state["step"] = (
torch.tensor(
step_val, dtype=_get_scalar_dtype(), device=p.device
)
if group["capturable"]
else torch.tensor(step_val, dtype=_get_scalar_dtype())
)
def _init_group(
self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps
):
has_complex = False
for p in group["params"]:
if p.grad is None:
continue
has_complex |= torch.is_complex(p)
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError("Adamax does not support sparse gradients")
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = (
torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
if group["capturable"]
else torch.tensor(0.0, dtype=_get_scalar_dtype())
)
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
state["exp_inf"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
exp_avgs.append(state["exp_avg"])
exp_infs.append(state["exp_inf"])
state_steps.append(state["step"])
return has_complex
@_use_grad_for_differentiable
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
self._cuda_graph_capture_health_check()
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_infs: List[Tensor] = []
state_steps: List[Tensor] = []
beta1, beta2 = group["betas"]
eps = group["eps"]
lr = group["lr"]
weight_decay = group["weight_decay"]
foreach = group["foreach"]
maximize = group["maximize"]
differentiable = group["differentiable"]
capturable = group["capturable"]
has_complex = self._init_group(
group, params_with_grad, grads, exp_avgs, exp_infs, state_steps
)
adamax(
params_with_grad,
grads,
exp_avgs,
exp_infs,
state_steps,
eps=eps,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
has_complex=has_complex,
)
return loss
# Class docstring for Adamax, attached after the class definition. It is the
# concatenation of a raw string (algorithm description as a LaTeX math block)
# and a raw f-string that interpolates the shared argument descriptions
# `_foreach_doc`, `_maximize_doc`, `_differentiable_doc` and `_capturable_doc`.
Adamax.__doc__ = (
r"""Implements Adamax algorithm (a variant of Adam based on infinity norm).
.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
\text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)},
\: \lambda \text{ (weight decay)}, \\
&\hspace{13mm} \epsilon \text{ (epsilon)} \\
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
u_0 \leftarrow 0 \text{ ( infinity norm)} \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}if \: \lambda \neq 0 \\
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
&\hspace{5mm}u_t \leftarrow \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon) \\
&\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}
For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
"""
+ rf"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, Tensor, optional): learning rate (default: 2e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
{_foreach_doc}
{_maximize_doc}
{_differentiable_doc}
{_capturable_doc}
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
"""
)
def _single_tensor_adamax(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_infs: List[Tensor],
state_steps: List[Tensor],
*,
eps: float,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
maximize: bool,
differentiable: bool,
capturable: bool,
has_complex: bool,
):
for i, param in enumerate(params):
grad = grads[i]
grad = grad if not maximize else -grad
exp_avg = exp_avgs[i]
exp_inf = exp_infs[i]
step_t = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
# update step
step_t += 1
if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)
if torch.is_complex(param):
param = torch.view_as_real(param)
grad = torch.view_as_real(grad)
exp_avg = torch.view_as_real(exp_avg)
exp_inf = torch.view_as_real(exp_inf)
# Update biased first moment estimate.
exp_avg.lerp_(grad, 1 - beta1)
# Update the exponentially weighted infinity norm.
if not differentiable:
torch.maximum(
exp_inf.mul_(beta2),
grad.abs().add_(eps),
out=exp_inf,
)
else:
norm_buf = torch.cat(
[exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)],
0,
)
exp_inf.copy_(torch.amax(norm_buf, 0, keepdim=False))
if capturable:
# why jump through extra hoops and negate bias_correction? check out #121238
# once fixed, we should use bias_correction with addcdiv value=-1 for readability
neg_bias_correction = beta1**step_t - 1
neg_bias_correction.div_(lr)
denom = exp_inf * neg_bias_correction
param.addcdiv_(exp_avg, denom)
else:
bias_correction = 1 - beta1 ** _get_value(step_t)
clr = lr / bias_correction
param.addcdiv_(exp_avg, exp_inf, value=-clr)
def _multi_tensor_adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    """Apply one Adamax update using ``torch._foreach_*`` ops on grouped tensors.

    The input lists are grouped by device and dtype and each group is updated
    with a sequence of in-place foreach ops. ``differentiable`` must be False
    (asserted). Returns immediately when ``params`` is empty.
    """
    assert not differentiable, "_foreach ops don't support autograd"

    if len(params) == 0:
        return

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_infs, state_steps]  # type: ignore[list-item]
    )
    for (
        grouped_params_,
        grouped_grads_,
        grouped_exp_avgs_,
        grouped_exp_infs_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        # The casts only narrow the static types; they do nothing at runtime.
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
        grouped_exp_infs = cast(List[Tensor], grouped_exp_infs_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        if has_complex:
            _view_as_real(
                grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs
            )

        if maximize:
            # Out-of-place negation: grouped_grads now refers to fresh tensors.
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        if weight_decay != 0:
            if maximize:
                # Re-use the intermediate memory (grouped_grads) already allocated for maximize
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
            else:
                # No intermediate exists yet, so add out of place rather than
                # writing into the gradient tensors that were passed in.
                grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                    grouped_grads, grouped_params, alpha=weight_decay
                )

        # Update biased first moment estimate.
        torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)

        # Update the exponentially weighted infinity norm.
        torch._foreach_mul_(grouped_exp_infs, beta2)

        # in this case, we need to introduce a copy of the grads
        # since one has not been introduced previously
        if not maximize and weight_decay == 0:
            grouped_grads = torch._foreach_abs(grouped_grads)  # type: ignore[assignment]
        else:
            # grouped_grads is already an intermediate here, so abs in place.
            torch._foreach_abs_(grouped_grads)

        torch._foreach_add_(grouped_grads, eps)
        torch._foreach_maximum_(grouped_exp_infs, grouped_grads)

        bias_corrections: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            bias_corrections = torch._foreach_pow(beta1, grouped_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_corrections, 1)
            torch._foreach_div_(bias_corrections, lr)
            # bias_corrections now holds (beta1^t - 1) / lr, so dividing
            # exp_avg by exp_inf * that applies the negative, bias-corrected step.

            denom = torch._foreach_mul(grouped_exp_infs, bias_corrections)
            torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom)
        else:
            bias_corrections = [
                1 - beta1 ** _get_value(step) for step in grouped_state_steps
            ]
            # Negated per-tensor step sizes, so addcdiv_ below subtracts the update.
            step_size = [(_get_value(lr) / bc) * -1 for bc in bias_corrections]
            torch._foreach_addcdiv_(
                grouped_params, grouped_exp_avgs, grouped_exp_infs, step_size
            )
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamax)
def adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
):
    r"""Functional API that performs adamax algorithm computation.

    Picks the foreach or single-tensor implementation from ``foreach`` (using
    the library default when it is ``None``) and forwards every argument to it.

    See :class:`~torch.optim.Adamax` for details.
    """
    # Skipped while compiling; otherwise every step counter must be a tensor.
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # No explicit choice from the caller: fall back to the default policy.
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if torch.jit.is_scripting():
        if foreach:
            raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        update_fn = _multi_tensor_adamax
    else:
        update_fn = _single_tensor_adamax

    update_fn(
        params,
        grads,
        exp_avgs,
        exp_infs,
        state_steps,
        eps=eps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
        capturable=capturable,
    )

View File

@ -0,0 +1,801 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
_fused_doc,
_get_capturable_supported_devices,
_get_scalar_dtype,
_get_value,
_maximize_doc,
_stack_if_compiling,
_use_grad_for_differentiable,
_view_as_real,
DeviceDict,
Optimizer,
ParamsT,
)
__all__ = ["AdamW", "adamw"]
class AdamW(Optimizer):
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 1e-2,
amsgrad: bool = False,
*,
maximize: bool = False,
foreach: Optional[bool] = None,
capturable: bool = False,
differentiable: bool = False,
fused: Optional[bool] = None,
):
if isinstance(lr, Tensor):
if foreach and not capturable:
raise ValueError(
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
if lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
foreach=foreach,
maximize=maximize,
capturable=capturable,
differentiable=differentiable,
fused=fused,
)
super().__init__(params, defaults)
if fused:
if differentiable:
raise RuntimeError("`fused` does not support `differentiable`")
self._step_supports_amp_scaling = True
if foreach:
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("amsgrad", False)
group.setdefault("maximize", False)
group.setdefault("foreach", None)
group.setdefault("capturable", False)
group.setdefault("differentiable", False)
fused = group.setdefault("fused", None)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
step_val = float(p_state["step"])
p_state["step"] = (
torch.tensor(
step_val,
dtype=_get_scalar_dtype(is_fused=fused),
device=p.device,
)
if group["capturable"] or group["fused"]
else torch.tensor(step_val, dtype=_get_scalar_dtype())
)
def _init_group(
self,
group,
params_with_grad,
grads,
amsgrad,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
):
has_complex = False
for p in group["params"]:
if p.grad is None:
continue
has_complex |= torch.is_complex(p)
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError("AdamW does not support sparse gradients")
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
if group["fused"]:
_device_dtype_check_for_fused(p)
# note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
# This is because kernel launches are costly on CUDA and XLA.
state["step"] = (
torch.zeros(
(),
dtype=_get_scalar_dtype(is_fused=group["fused"]),
device=p.device,
)
if group["capturable"] or group["fused"]
else torch.tensor(0.0, dtype=_get_scalar_dtype())
)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
if group["amsgrad"]:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
if group["differentiable"] and state["step"].requires_grad:
raise RuntimeError(
"`requires_grad` is not supported for `step` in differentiable mode"
)
# Foreach without capturable does not support a tensor lr
if (
group["foreach"]
and isinstance(group["lr"], Tensor)
and not group["capturable"]
):
raise RuntimeError(
"lr as a Tensor is not supported for capturable=False and foreach=True"
)
state_steps.append(state["step"])
return has_complex
@_use_grad_for_differentiable
def step(self, closure=None):
    """Apply one AdamW update to every parameter group.

    Args:
        closure (Callable, optional): A callable that re-evaluates the model
            and returns the loss. It is invoked with gradients enabled.

    Returns:
        The value returned by ``closure``, or ``None`` when no closure is given.
    """
    self._cuda_graph_capture_health_check()

    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        # Scratch lists that `_init_group` fills in place for this group.
        params_with_grad: List[Tensor] = []
        grads: List[Tensor] = []
        exp_avgs: List[Tensor] = []
        exp_avg_sqs: List[Tensor] = []
        max_exp_avg_sqs: List[Tensor] = []
        state_steps: List[Tensor] = []
        amsgrad: bool = group["amsgrad"]
        beta1, beta2 = cast(Tuple[float, float], group["betas"])

        has_complex = self._init_group(
            group,
            params_with_grad,
            grads,
            amsgrad,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
        )

        # Hand the collected tensors to the functional implementation.
        adamw(
            params_with_grad,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
            amsgrad=amsgrad,
            beta1=beta1,
            beta2=beta2,
            lr=group["lr"],
            weight_decay=group["weight_decay"],
            eps=group["eps"],
            maximize=group["maximize"],
            foreach=group["foreach"],
            capturable=group["capturable"],
            differentiable=group["differentiable"],
            fused=group["fused"],
            # These attributes are optional on the optimizer; default to None.
            grad_scale=getattr(self, "grad_scale", None),
            found_inf=getattr(self, "found_inf", None),
            has_complex=has_complex,
        )

    return loss
# The class docstring is attached after the class definition so that the
# shared documentation fragments (`_maximize_doc`, `_foreach_doc`,
# `_capturable_doc`, `_differentiable_doc`, `_fused_doc`) can be interpolated
# into the second, rf-string half. The first half is a plain raw string
# holding the LaTeX pseudo-code of the algorithm.
AdamW.__doc__ = (
r"""Implements AdamW algorithm.
.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
\text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
\: \epsilon \text{ (epsilon)} \\
&\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
\: \textit{maximize} \\
&\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
\text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
&\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
&\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
&\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
&\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
&\hspace{5mm}\textbf{if} \: amsgrad \\
&\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
\widehat{v_t}) \\
&\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
\big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}
For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_.
"""
+ rf"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
is not yet supported for all our implementations. Please use a float
LR if you are not also specifying fused=True or capturable=True.
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 1e-2)
amsgrad (bool, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
{_maximize_doc}
{_foreach_doc}
{_capturable_doc}
{_differentiable_doc}
{_fused_doc}
.. Note::
A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
)
def _single_tensor_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[Tensor, float],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    """Apply the AdamW update one parameter at a time.

    All tensors are updated in place: each step counter is incremented, the
    parameter is scaled by ``1 - lr * weight_decay``, the moment buffers are
    refreshed and the bias-corrected Adam step is added to the parameter.

    ``grad_scale`` and ``found_inf`` must both be ``None`` here. When
    ``capturable`` or ``differentiable`` is set, the step count is kept as a
    tensor throughout; otherwise it is read out with ``_get_value``.
    """
    assert grad_scale is None and found_inf is None

    if torch.jit.is_scripting():
        # TorchScript does not see that the ops below have overloads for both
        # float and Tensor learning rates, so pin lr to float there (which is
        # what most scripted callers pass anyway).
        assert isinstance(lr, float)

    for i, param in enumerate(params):
        grad = -grads[i] if maximize else grads[i]
        first_moment = exp_avgs[i]
        second_moment = exp_avg_sqs[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        # Remember complexity before `param` is rebound to its real view.
        param_is_complex = torch.is_complex(param)
        if param_is_complex:
            grad = torch.view_as_real(grad)
            first_moment = torch.view_as_real(first_moment)
            second_moment = torch.view_as_real(second_moment)
            if amsgrad:
                max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
            param = torch.view_as_real(param)

        # Advance the step counter.
        step_t += 1

        # Decoupled weight decay: shrink the parameter directly.
        param.mul_(1 - lr * weight_decay)

        # Update the running averages of the gradient and its square.
        first_moment.lerp_(grad, 1 - beta1)
        second_moment.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        if capturable or differentiable:
            # Keep everything as tensors (no host sync, autograd-friendly).
            step = step_t

            bias_correction1 = 1 - beta1**step
            bias_correction2 = 1 - beta2**step

            step_size = lr / bias_correction1
            step_size_neg = step_size.neg()

            bias_correction2_sqrt = bias_correction2.sqrt()

            if amsgrad:
                # Track the maximum second-moment average seen so far. In
                # differentiable mode the previous value is cloned before the
                # in-place copy below overwrites it.
                if differentiable:
                    previous_max = max_exp_avg_sqs[i].clone()
                else:
                    previous_max = max_exp_avg_sqs[i]

                max_exp_avg_sqs[i].copy_(torch.maximum(previous_max, second_moment))
                normalizer = max_exp_avg_sqs[i]
            else:
                normalizer = second_moment

            # The 1-element step-size math is folded into the denominator to
            # avoid an extra param-sized read+write; it cannot go into
            # addcdiv_ because `value` must be a Number, not a Tensor.
            denom = (
                normalizer.sqrt() / (bias_correction2_sqrt * step_size_neg)
            ).add_(eps / step_size_neg)

            param.addcdiv_(first_moment, denom)
        else:
            step = _get_value(step_t)

            bias_correction1 = 1 - beta1**step
            bias_correction2 = 1 - beta2**step

            step_size = lr / bias_correction1

            bias_correction2_sqrt = bias_correction2**0.5

            if amsgrad:
                # Track the maximum second-moment average seen so far.
                torch.maximum(max_exp_avg_sqs[i], second_moment, out=max_exp_avg_sqs[i])
                normalizer = max_exp_avg_sqs[i]
            else:
                normalizer = second_moment

            denom = (normalizer.sqrt() / bias_correction2_sqrt).add_(eps)

            param.addcdiv_(first_moment, denom, value=-step_size)

        # Restore the complex view of the running-max buffer for the caller.
        if amsgrad and param_is_complex:
            max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])
def _multi_tensor_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[Tensor, float],
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    """Apply the AdamW update with ``torch._foreach_*`` ops on lists of tensors.

    The inputs are grouped with ``Optimizer._group_tensors_by_device_and_dtype``
    and every group is updated in place with list-wide ops.

    Returns immediately when ``params`` is empty. ``differentiable`` must be
    false and ``grad_scale`` / ``found_inf`` must be ``None`` (both asserted).

    Raises:
        RuntimeError: if ``lr`` is a Tensor while ``capturable`` is false.
    """
    if len(params) == 0:
        return

    if isinstance(lr, Tensor) and not capturable:
        raise RuntimeError(
            "lr as a Tensor is not supported for capturable=False and foreach=True"
        )

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    assert not differentiable, "_foreach ops don't support autograd"

    assert grad_scale is None and found_inf is None

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (
        device_params_,
        device_grads_,
        device_exp_avgs_,
        device_exp_avg_sqs_,
        device_max_exp_avg_sqs_,
        device_state_steps_,
    ), _ in grouped_tensors.values():
        # The casts only narrow the types for the type checker.
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        # Complex tensors are handed to `_view_as_real` so the math below
        # runs on their real views; the max buffers are included only when
        # amsgrad is on.
        if has_complex:
            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
                _view_as_real(
                    device_params,
                    device_grads,
                    device_exp_avgs,
                    device_exp_avg_sqs,
                    device_max_exp_avg_sqs,
                )
            else:
                _view_as_real(
                    device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
                )

        # Out-of-place negation: the caller's gradient tensors are not modified.
        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        # Perform stepweight decay (decoupled: the parameters themselves are
        # scaled; skipped entirely when weight_decay is zero)
        if weight_decay != 0:
            torch._foreach_mul_(device_params, 1 - lr * weight_decay)

        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            device_exp_avg_sqs, device_grads, device_grads, 1 - beta2
        )

        # Delete the local intermediate since it won't be used anymore to save on peak memory
        del device_grads

        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]

        if capturable:
            # Tensor-only path: the step counts are never read back as Python
            # numbers in this branch.
            bias_correction1 = torch._foreach_pow(beta1, device_state_steps)
            bias_correction2 = torch._foreach_pow(beta2, device_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction1, 1)
            torch._foreach_sub_(bias_correction2, 1)
            # we do not negate bias_correction1 as it'll need to be negated later anyway
            torch._foreach_neg_(bias_correction2)

            # foreach_div doesn't allow a scalar as the first arg
            torch._foreach_div_(bias_correction1, lr)
            torch._foreach_reciprocal_(bias_correction1)

            torch._foreach_sqrt_(bias_correction2)

            # Re-assign for clarity as we maintain minimal intermediates: we'll have
            # step_size = - lr / (1 - beta1 ^ t) where t = num_steps
            # bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
            step_size = bias_correction1
            bias_correction2_sqrt = bias_correction2

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)

                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_div_(exp_avg_sq_sqrt, step_size)

            # at this point, exp_avg_sq_sqrt = - (1 - beta1^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
            torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
        else:
            # Scalar path: step counts are read out with `_get_value` and the
            # bias corrections are computed per tensor in Python.
            bias_correction1 = [
                1 - beta1 ** _get_value(step) for step in device_state_steps
            ]
            bias_correction2 = [
                1 - beta2 ** _get_value(step) for step in device_state_steps
            ]

            # Negated here so the addcdiv_ below subtracts the update.
            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

            bias_correction2_sqrt = [
                bc**0.5 for bc in bias_correction2  # type: ignore[arg-type]
            ]

            if amsgrad:
                device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)

                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
            torch._foreach_add_(exp_avg_sq_sqrt, eps)
            torch._foreach_addcdiv_(
                device_params,
                device_exp_avgs,
                exp_avg_sq_sqrt,
                step_size,  # type: ignore[arg-type]
            )
def _fused_adamw(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: Union[Tensor, float],
weight_decay: float,
eps: float,
maximize: bool,
capturable: bool, # Needed for consistency.
differentiable: bool,
has_complex: bool, # Needed for consistency.
) -> None:
if not params:
return
if differentiable:
raise RuntimeError("Adam with fused=True does not support differentiable=True")
grad_scale_dict: DeviceDict = (
{grad_scale.device: grad_scale} if grad_scale is not None else {}
)
found_inf_dict: DeviceDict = (
{found_inf.device: found_inf} if found_inf is not None else {}
)
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
# treating it as a scalar.
lr_dict: Optional[DeviceDict] = (
{lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
)
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item]
)
for (device, _), (
(
device_params_,
device_grads_,
device_exp_avgs_,
device_exp_avg_sqs_,
device_max_exp_avg_sqs,
device_state_steps_,
),
_,
) in grouped_tensors.items():
device_params = cast(List[Tensor], device_params_)
device_grads = cast(List[Tensor], device_grads_)
device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
device_state_steps = cast(List[Tensor], device_state_steps_)
if device.type == "mps": # type: ignore[union-attr]
assert found_inf is None and grad_scale is None
device_grad_scale, device_found_inf = None, None
if grad_scale is not None:
device_grad_scale = grad_scale_dict.setdefault(
device, grad_scale.to(device, non_blocking=True)
)
if found_inf is not None:
device_found_inf = found_inf_dict.setdefault(
device, found_inf.to(device, non_blocking=True)
)
if lr_dict is not None and device not in lr_dict:
lr = lr_dict.setdefault(
device, lr.to(device=device, non_blocking=True) # type: ignore[union-attr]
)
torch._foreach_add_(device_state_steps, 1)
torch._fused_adamw_(
device_params,
device_grads,
device_exp_avgs,
device_exp_avg_sqs,
device_max_exp_avg_sqs, # type: ignore[arg-type]
device_state_steps,
amsgrad=amsgrad,
lr=lr, # type: ignore[arg-type]
beta1=beta1,
beta2=beta2,
weight_decay=weight_decay,
eps=eps,
maximize=maximize,
grad_scale=device_grad_scale,
found_inf=device_found_inf,
)
if device_found_inf is not None:
torch._foreach_sub_(
device_state_steps, [device_found_inf] * len(device_state_steps)
)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamw)
def adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional entry point for AdamW: pick an implementation and run it.

    Dispatches to the fused, foreach (multi-tensor) or single-tensor
    implementation depending on ``fused`` / ``foreach``, and forwards all
    tensors and hyper-parameters to it.

    See :class:`~torch.optim.AdamW` for details.
    """
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # An explicit True/False for foreach or fused is always respected; the
    # default is only chosen when neither was given. Passing use_fused=False
    # here is intentional: the fused implementation gets bake-in time before
    # becoming the default, even if it is typically faster.
    if foreach is None and fused is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Tensor lr with capturable=False is unsupported by foreach, so do not
        # auto-select foreach in that case.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False

    # Anything still unspecified is off.
    if foreach is None:
        foreach = False
    if fused is None:
        fused = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    # fused takes precedence over foreach; single-tensor is the fallback.
    if fused and not torch.jit.is_scripting():
        func = _fused_adamw
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamw
    else:
        func = _single_tensor_adamw

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
        has_complex=has_complex,
    )

View File

@ -0,0 +1,465 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
_get_capturable_supported_devices,
_get_scalar_dtype,
_get_value,
_maximize_doc,
_use_grad_for_differentiable,
_view_as_real,
Optimizer,
ParamsT,
)
__all__ = ["ASGD", "asgd"]
class ASGD(Optimizer):
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-2,
lambd: float = 1e-4,
alpha: float = 0.75,
t0: float = 1e6,
weight_decay: float = 0,
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
capturable: bool = False,
):
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
lambd=lambd,
alpha=alpha,
t0=t0,
weight_decay=weight_decay,
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("foreach", None)
group.setdefault("maximize", False)
group.setdefault("differentiable", False)
group.setdefault("capturable", False)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0:
if not torch.is_tensor(p_state["step"]):
step_val = float(p_state["step"])
p_state["step"] = torch.tensor(
step_val, dtype=_get_scalar_dtype(), device=p.device
)
if not torch.is_tensor(p_state["eta"]):
p_state["eta"] = torch.tensor(
p_state["eta"], dtype=_get_scalar_dtype(), device=p.device
)
if not torch.is_tensor(p_state["mu"]):
p_state["mu"] = torch.tensor(
p_state["mu"], dtype=_get_scalar_dtype(), device=p.device
)
def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
has_complex = False
for p in group["params"]:
if p.grad is not None:
has_complex |= torch.is_complex(p)
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError("ASGD does not support sparse gradients")
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = torch.zeros(
(), device=p.device, dtype=_get_scalar_dtype()
)
state["eta"] = (
torch.as_tensor(
group["lr"], device=p.device, dtype=_get_scalar_dtype()
)
.clone()
.detach()
)
state["mu"] = torch.ones(
(), device=p.device, dtype=_get_scalar_dtype()
)
state["ax"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
mus.append(state["mu"])
axs.append(state["ax"])
etas.append(state["eta"])
state_steps.append(state["step"])
return has_complex
@_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
self._cuda_graph_capture_health_check()
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
mus: List[Tensor] = []
axs: List[Tensor] = []
etas: List[Tensor] = []
state_steps: List[Tensor] = []
has_complex = self._init_group(
group, params_with_grad, grads, mus, axs, etas, state_steps
)
asgd(
params_with_grad,
grads,
axs,
mus,
etas,
state_steps,
lambd=group["lambd"],
lr=group["lr"],
t0=group["t0"],
alpha=group["alpha"],
weight_decay=group["weight_decay"],
foreach=group["foreach"],
maximize=group["maximize"],
differentiable=group["differentiable"],
capturable=group["capturable"],
has_complex=has_complex,
)
return loss
# The class docstring is attached after the class definition so that the
# shared documentation fragments (`_foreach_doc`, `_maximize_doc`,
# `_differentiable_doc`, `_capturable_doc`) can be interpolated into it.
ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent.
It has been proposed in `Acceleration of stochastic approximation by
averaging`_.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, Tensor, optional): learning rate (default: 1e-2)
lambd (float, optional): decay term (default: 1e-4)
alpha (float, optional): power for eta update (default: 0.75)
t0 (float, optional): point at which to start averaging (default: 1e6)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
{_foreach_doc}
{_maximize_doc}
{_differentiable_doc}
{_capturable_doc}
.. _Acceleration of stochastic approximation by averaging:
https://dl.acm.org/citation.cfm?id=131098
"""
def _single_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    """Apply the ASGD update one parameter at a time.

    For each parameter, in place: the step counter is incremented, the
    parameter is decayed by ``1 - lambd * eta`` and moved by ``-eta * grad``
    (with ``weight_decay * param`` first added to the gradient when
    ``weight_decay`` is non-zero), the average ``ax`` is refreshed, and
    finally ``eta`` and ``mu`` are recomputed from the new step count.
    """
    for i, param in enumerate(params):
        grad = -grads[i] if maximize else grads[i]
        mu = mus[i]
        ax = axs[i]
        eta = etas[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type
                == mu.device.type
                == eta.device.type
                == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), (
                f"If capturable=True, params, mus, etas, and state_steps must be "
                f"on supported devices: {capturable_supported_devices}."
            )

        # Run the math on real views of complex tensors.
        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            param = torch.view_as_real(param)
            ax = torch.view_as_real(ax)

        # Advance the step counter.
        step_t += 1

        # L2 penalty, added out of place so the stored gradient is untouched.
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if capturable:
            # Tensor-only arithmetic: eta is never read back to the host.
            param.mul_(1 - lambd * eta)
            param.addcmul_(grad, eta, value=-1)  # parameter update
        else:
            eta_scalar = _get_value(eta)
            param.mul_(1 - lambd * eta_scalar)  # decay term
            param.add_(grad, alpha=-eta_scalar)  # parameter update

        # Averaging: ax += mu * (param - ax). With mu == 1 that is a plain
        # copy, which the non-capturable path special-cases.
        if capturable or mu.item() != 1:
            ax.add_(param.sub(ax).mul_(mu))
        else:
            ax.copy_(param)

        # Recompute eta and mu for the next step.
        if capturable:
            eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
            mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
        else:
            step = _get_value(step_t)
            next_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha))
            eta.copy_(next_eta)
            next_mu = torch.as_tensor(1 / max(1, step - t0))
            mu.copy_(next_mu)
def _multi_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    """Apply the ASGD update with ``torch._foreach_*`` ops on lists of tensors.

    The inputs are grouped with ``Optimizer._group_tensors_by_device_and_dtype``
    and every group is updated in place: step counters, parameters, the
    averages ``axs`` and finally ``etas`` / ``mus``.

    Returns immediately when ``params`` is empty. ``differentiable`` must be
    false (asserted).
    """
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == mu.device.type == eta.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, mu, eta, step in zip(params, mus, etas, state_steps)
        ), f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, axs, mus, etas, state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            grouped_params_,
            grouped_grads_,
            grouped_axs_,
            grouped_mus_,
            grouped_etas_,
            grouped_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        # The casts only narrow the types for the type checker.
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_axs = cast(List[Tensor], grouped_axs_)
        grouped_mus = cast(List[Tensor], grouped_mus_)
        grouped_etas = cast(List[Tensor], grouped_etas_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        # Complex tensors are handed to `_view_as_real` so the math below
        # runs on their real views.
        if has_complex:
            _view_as_real(grouped_params, grouped_grads, grouped_axs)

        # Out-of-place negation: the caller's gradient tensors are not modified.
        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        # intermediate = grad + param * lambd
        # (plus param * weight_decay when weight_decay is non-zero)
        intermediate: Union[Tuple[Tensor, ...], List[Tensor]]
        if weight_decay != 0:
            if maximize:
                # grouped_grads is already a fresh list from _foreach_neg
                # above, so it can be updated in place.
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
                intermediate = grouped_grads
            else:
                intermediate = torch._foreach_add(
                    grouped_grads, grouped_params, alpha=weight_decay
                )

            torch._foreach_add_(intermediate, grouped_params, alpha=lambd)
        else:
            intermediate = torch._foreach_add(
                grouped_grads, grouped_params, alpha=lambd
            )

        # update param
        # param * (1 - lambd * eta) - eta * grad
        # => param - param * lambd * eta - eta * grad
        # => param - eta * intermediate
        torch._foreach_addcmul_(grouped_params, intermediate, grouped_etas, value=-1)
        del intermediate

        # update grouped_axs
        # averaging: ax = ax + mu * (param - ax)
        # Note (mlazos): We can't use lerp here since it requires weight to be float64
        # and our grouping code requires dtypes to match for all tensors in a group (and it should, since
        # we use the mus in other places)
        # all dtypes need to match, so we could introduce a cast in a loop
        # but since this only adds one additional kernel launch, this looks like the cleaner
        # and faster solution
        intermediate = torch._foreach_sub(grouped_params, grouped_axs)
        torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus)
        del intermediate

        new_etas: Union[Tuple[Tensor, ...], List[Tensor]]
        new_mus: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            # Tensor-only path: the step counts are never read back as Python
            # numbers in this branch.
            # update grouped_mus: mu = 1 / max(step - t0, 1)
            new_mus = torch._foreach_sub(grouped_state_steps, t0)
            torch._foreach_maximum_(new_mus, 1.0)
            torch._foreach_reciprocal_(new_mus)
            torch._foreach_copy_(grouped_mus, new_mus)
            del new_mus

            # update eta = lr / ((1 + lambd * lr * step)^alpha)
            new_etas = torch._foreach_mul(grouped_state_steps, lambd)
            torch._foreach_mul_(new_etas, lr)
            torch._foreach_add_(new_etas, 1)
            torch._foreach_pow_(new_etas, alpha)
            torch._foreach_reciprocal_(new_etas)
            torch._foreach_mul_(new_etas, lr)
            torch._foreach_copy_(grouped_etas, new_etas)
        else:
            # Per-tensor path: fresh eta / mu tensors are built on the
            # group's device and then copied into the state tensors.
            new_etas = [
                torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device)
                for step in grouped_state_steps
            ]
            new_mus = [
                torch.as_tensor(1 / max(1, _get_value(step) - t0), device=device)
                for step in grouped_state_steps
            ]
            torch._foreach_copy_(grouped_etas, new_etas)
            torch._foreach_copy_(grouped_mus, new_mus)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd)
def asgd(
params: List[Tensor],
grads: List[Tensor],
axs: List[Tensor],
mus: List[Tensor],
etas: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
capturable: bool = False,
has_complex: bool = False,
*,
lambd: float,
lr: float,
t0: float,
alpha: float,
weight_decay: float,
):
r"""Functional API that performs asgd algorithm computation.
See :class:`~torch.optim.ASGD` for details.
"""
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_asgd
else:
func = _single_tensor_asgd
func(
params,
grads,
axs,
mus,
etas,
state_steps,
lambd=lambd,
lr=lr,
t0=t0,
alpha=alpha,
weight_decay=weight_decay,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
has_complex=has_complex,
)

View File

@ -0,0 +1,495 @@
# mypy: allow-untyped-defs
from typing import Optional, Union
import torch
from torch import Tensor
from .optimizer import Optimizer, ParamsT
__all__ = ["LBFGS"]
def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    """Return the minimizer of the cubic through two points, clamped to bounds.

    The cubic is fitted to the function values ``f1``/``f2`` and derivatives
    ``g1``/``g2`` at ``x1``/``x2``. When ``bounds`` is ``None`` the interval
    spanned by ``x1`` and ``x2`` is used. If the cubic has no real minimizer
    (negative discriminant) the midpoint of the bounds is returned.

    Ported from https://github.com/torch/optim/blob/master/polyinterp.lua
    """
    # Interval the result is clamped to.
    if bounds is None:
        lower, upper = (x1, x2) if x1 <= x2 else (x2, x1)
    else:
        lower, upper = bounds

    # Most common case: cubic interpolation of 2 points with function and
    # derivative values for both. With x2 the farthest point:
    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
    #   d2 = sqrt(d1^2 - g1*g2);
    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    discriminant = d1**2 - g1 * g2

    # Written as a negated ">=" so a NaN discriminant also takes the
    # bisection fallback.
    if not discriminant >= 0:
        return (lower + upper) / 2.0

    d2 = discriminant.sqrt()
    if x1 <= x2:
        candidate = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
    else:
        candidate = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
    return min(max(candidate, lower), upper)
def _strong_wolfe(
    obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25
):
    """Line search along ``d`` looking for a step that meets the strong Wolfe conditions.

    The search runs in two phases: a bracketing phase that grows the trial
    step until an interval containing an acceptable point is found, and a zoom
    phase that shrinks that interval with cubic interpolation.

    Args:
        obj_func: callable invoked as ``obj_func(x, t, d)``; must return
            ``(loss, flat_grad)`` for the point reached with step ``t``.
        x: opaque starting point, only forwarded to ``obj_func``.
        t: initial trial step length.
        d: search direction (tensor; ``d.abs().max()`` and ``g_new.dot(d)``
            are taken on it).
        f: objective value at step 0.
        g: gradient at step 0 (cloned, never modified in place here).
        gtd: directional derivative at step 0.
            NOTE(review): presumably ``g.dot(d)`` -- confirm against the caller.
        c1: constant of the sufficient-decrease test ``f_new > f + c1 * t * gtd``.
        c2: constant of the curvature test ``abs(gtd_new) <= -c2 * gtd``.
        tolerance_change: the zoom phase stops once the bracket width times
            ``d.abs().max()`` drops below this value.
        max_ls: cap on ``ls_iter``, which is shared by both phases.

    Returns:
        Tuple ``(f_new, g_new, t, ls_func_evals)``: loss, gradient and step of
        the lowest bracket endpoint, and the number of ``obj_func`` calls made.
    """
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        # sufficient decrease violated, or the objective stopped decreasing:
        # an acceptable point lies between the previous and the current step
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # curvature condition holds as well: the current step is accepted as is
        # (single-element bracket, zoom phase is skipped through `done`)
        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        # the slope along d became non-negative: a minimum has been stepped over
        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        # the next trial step is restricted to [t + 0.01 * (t - t_prev), 10 * t]
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step)
        )

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    # (bracket_gtd is not set on this path; the zoom loop below cannot run
    # because its `ls_iter < max_ls` condition is already false)
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)  # type: ignore[possibly-undefined]
    while not done and ls_iter < max_ls:
        # line-search bracket is so small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:  # type: ignore[possibly-undefined]
            break

        # compute new trial value
        t = _cubic_interpolate(
            bracket[0],
            bracket_f[0],
            bracket_gtd[0],  # type: ignore[possibly-undefined]
            bracket[1],
            bracket_f[1],
            bracket_gtd[1],
        )

        # test that we are making sufficient progress:
        # in case `t` is so close to boundary, we mark that we are making
        # insufficient progress, and if
        #   + we have made insufficient progress in the last step, or
        #   + `t` is at one of the boundary,
        # we will move `t` to a position which is `0.1 * len(bracket)`
        # away from the nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[high_pos] = gtd_new
            # the replaced endpoint may now be the lower one: recompute roles
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]  # type: ignore[possibly-undefined]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[low_pos] = gtd_new

    # return stuff
    # report the bracket endpoint with the lowest objective value
    t = bracket[low_pos]  # type: ignore[possibly-undefined]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]  # type: ignore[possibly-undefined]
    return f_new, g_new, t, ls_func_evals
class LBFGS(Optimizer):
    """Implements L-BFGS algorithm.

    Heavily inspired by `minFunc
    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory intensive optimizer (it requires additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
        try reducing the history size, or use a different algorithm.

    Args:
        params (iterable): iterable of parameters to optimize. Parameters must be real.
        lr (float or Tensor): learning rate; a Tensor must hold exactly one
            element (default: 1)
        max_iter (int): maximal number of iterations per optimization step
            (default: 20)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: max_iter * 1.25, computed as ``max_iter * 5 // 4``).
        tolerance_grad (float): termination tolerance on first order optimality
            (default: 1e-7).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 100).
        line_search_fn (str): either 'strong_wolfe' or None (default: None).

    Raises:
        ValueError: if ``lr`` is a multi-element Tensor or negative, or if more
            than one parameter group is given.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1,
        max_iter: int = 20,
        max_eval: Optional[int] = None,
        tolerance_grad: float = 1e-7,
        tolerance_change: float = 1e-9,
        history_size: int = 100,
        line_search_fn: Optional[str] = None,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if max_eval is None:
            # integer form of the documented default ``max_iter * 1.25``
            max_eval = max_iter * 5 // 4
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            max_eval=max_eval,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            history_size=history_size,
            line_search_fn=line_search_fn,
        )
        super().__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError(
                "LBFGS doesn't support per-parameter options " "(parameter groups)"
            )

        # the single group's parameter list; every helper below iterates it
        self._params = self.param_groups[0]["params"]
        # lazily filled by _numel()
        self._numel_cache = None

    def _numel(self):
        """Return the total element count of all parameters (cached).

        Complex parameters are counted twice, matching their real view.
        """
        if self._numel_cache is None:
            self._numel_cache = sum(
                2 * p.numel() if torch.is_complex(p) else p.numel()
                for p in self._params
            )

        return self._numel_cache

    def _gather_flat_grad(self):
        """Concatenate the gradients of all parameters into one 1-D tensor."""
        views = []
        for p in self._params:
            if p.grad is None:
                # missing gradient: zero-filled placeholder with p.numel() entries
                view = p.new(p.numel()).zero_()
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            if torch.is_complex(view):
                # expose complex entries as (real, imag) pairs, then flatten
                view = torch.view_as_real(view).view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        """Add ``step_size * update`` to the parameters in place.

        ``update`` is a flat tensor; consecutive slices of it are applied to
        the parameters in ``self._params`` order.
        """
        offset = 0
        for p in self._params:
            if torch.is_complex(p):
                p = torch.view_as_real(p)
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.add_(update[offset : offset + numel].view_as(p), alpha=step_size)
            offset += numel
        # every entry of the flat layout must have been consumed
        assert offset == self._numel()

    def _clone_param(self):
        """Return contiguous clones of all parameters."""
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        """Copy ``params_data`` (one tensor per parameter) back into the parameters."""
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
        """Evaluate loss and flat gradient at ``x + t * d``, then restore ``x``.

        The parameters are moved by ``t * d``, ``closure`` is evaluated, and the
        parameters are reset to the values stored in ``x``.
        """
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad

    @torch.no_grad()
    def step(self, closure):
        """Perform a single optimization step.

        Args:
            closure (Callable): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by the first ``closure`` call of this step.

        Raises:
            RuntimeError: if ``line_search_fn`` is neither ``None`` nor
                ``"strong_wolfe"`` (raised once a line search is attempted).
        """
        assert len(self.param_groups) == 1

        # Make sure the closure is always called with grad enabled
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = group["lr"]
        max_iter = group["max_iter"]
        max_eval = group["max_eval"]
        tolerance_grad = group["tolerance_grad"]
        tolerance_change = group["tolerance_change"]
        line_search_fn = group["line_search_fn"]
        history_size = group["history_size"]

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault("func_evals", 0)
        state.setdefault("n_iter", 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state["func_evals"] += 1

        flat_grad = self._gather_flat_grad()
        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # optimal condition
        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get("d")  # search direction
        t = state.get("t")  # step length
        old_dirs = state.get("old_dirs")  # history of gradient differences (y)
        old_stps = state.get("old_stps")  # history of parameter steps (s)
        ro = state.get("ro")  # history of 1 / (y . s)
        H_diag = state.get("H_diag")  # scale of the initial Hessian approximation
        prev_flat_grad = state.get("prev_flat_grad")
        prev_loss = state.get("prev_loss")

        # n_iter counts iterations of this call; state["n_iter"] counts them
        # across calls
        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            n_iter += 1
            state["n_iter"] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state["n_iter"] == 1:
                # very first iteration: steepest descent, empty history
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
                # the pair is only stored when y . s exceeds 1e-10
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1.0 / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if "al" not in state:
                    state["al"] = [None] * history_size
                al = state["al"]

                # iteration in L-BFGS loop collapsed to use just one buffer
                # first pass: newest to oldest history entry
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                # (d and r name the same tensor, so the in-place updates of r
                # below also update d)
                d = r = torch.mul(q, H_diag)
                # second pass: oldest to newest history entry
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state["n_iter"] == 1:
                t = min(1.0, 1.0 / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")
                else:
                    x_init = self._clone_param()

                    def obj_func(x, t, d):
                        return self._directional_evaluate(closure, x, t, d)

                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                        obj_func, x_init, t, d, loss, flat_grad, gtd
                    )
                # the line search restores the parameters after each trial, so
                # the accepted step is applied here
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # no use to re-evaluate that function here
                    with torch.enable_grad():
                        loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state["func_evals"] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        # persist the working variables so the next call can pick them up again
        state["d"] = d
        state["t"] = t
        state["old_dirs"] = old_dirs
        state["old_stps"] = old_stps
        state["ro"] = ro
        state["H_diag"] = H_diag
        state["prev_flat_grad"] = prev_flat_grad
        state["prev_loss"] = prev_loss

        return orig_loss

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,649 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation for the NAdam algorithm."""
from typing import cast, List, Optional, Tuple, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
_get_capturable_supported_devices,
_get_scalar_dtype,
_get_value,
_maximize_doc,
_stack_if_compiling,
_use_grad_for_differentiable,
_view_as_real,
Optimizer,
ParamsT,
)
__all__ = ["NAdam", "nadam"]
class NAdam(Optimizer): # noqa: D101
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 2e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0,
momentum_decay: float = 4e-3,
decoupled_weight_decay: bool = False,
*,
foreach: Optional[bool] = None,
maximize: bool = False,
capturable: bool = False,
differentiable: bool = False,
): # noqa: D107
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if not 0.0 <= momentum_decay:
raise ValueError(f"Invalid momentum_decay value: {momentum_decay}")
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
momentum_decay=momentum_decay,
decoupled_weight_decay=decoupled_weight_decay,
maximize=maximize,
foreach=foreach,
capturable=capturable,
differentiable=differentiable,
)
super().__init__(params, defaults)
def __setstate__(self, state): # noqa: D105
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("maximize", False)
group.setdefault("foreach", None)
group.setdefault("capturable", False)
group.setdefault("differentiable", False)
group.setdefault("decoupled_weight_decay", False)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0:
if not torch.is_tensor(p_state["step"]):
step_val = float(p_state["step"])
p_state["step"] = (
torch.tensor(
step_val, dtype=_get_scalar_dtype(), device=p.device
)
if group["capturable"]
else torch.tensor(step_val, dtype=_get_scalar_dtype())
)
if not torch.is_tensor(p_state["mu_product"]):
mu_prod_val = p_state["mu_product"]
p_state["mu_product"] = (
torch.tensor(
mu_prod_val, dtype=_get_scalar_dtype(), device=p.device
)
if group["capturable"]
else torch.tensor(mu_prod_val, dtype=_get_scalar_dtype())
)
def _init_group(
self,
group,
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
mu_products,
state_steps,
):
has_complex = False
for p in group["params"]:
if p.grad is not None:
has_complex |= torch.is_complex(p)
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError("NAdam does not support sparse gradients")
grads.append(p.grad)
state = self.state[p]
# Lazy state initialization
if len(state) == 0:
# note(crcrpar): [special device hosting for step]
# Deliberately host `step` and `mu_product` on CPU if capturable is False.
# This is because kernel launches are costly on CUDA and XLA.
state["step"] = (
torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
if group["capturable"]
else torch.tensor(0.0, dtype=_get_scalar_dtype())
)
state["mu_product"] = (
torch.ones((), dtype=_get_scalar_dtype(), device=p.device)
if group["capturable"]
else torch.tensor(1.0, dtype=_get_scalar_dtype())
)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
mu_products.append(state["mu_product"])
state_steps.append(state["step"])
return has_complex
@_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
self._cuda_graph_capture_health_check()
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
mu_products: List[Tensor] = []
state_steps: List[Tensor] = []
beta1, beta2 = cast(Tuple[float, float], group["betas"])
has_complex = self._init_group(
group,
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
mu_products,
state_steps,
)
nadam(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
mu_products,
state_steps,
beta1=beta1,
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
momentum_decay=group["momentum_decay"],
eps=group["eps"],
maximize=group["maximize"],
decoupled_weight_decay=group["decoupled_weight_decay"],
foreach=group["foreach"],
capturable=group["capturable"],
differentiable=group["differentiable"],
has_complex=has_complex,
)
return loss
# Class documentation for NAdam, attached after the class body. The first
# (raw) part holds the LaTeX pseudo-code; the second (raw f-string) part
# interpolates the shared option descriptions imported from ``.optimizer``.
# The reST layout (blank lines, indented directive bodies) is required for the
# ``.. math::`` block, the ``Args:`` section and the link targets to render.
NAdam.__doc__ = (
    r"""Implements NAdam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)},
                \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\
            &\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)} \\
            &\hspace{13mm} \: \textit{decoupled\_weight\_decay}, \:\textit{maximize} \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
                v_0 \leftarrow 0 \text{ ( second moment)} \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
            &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} \\
            &\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\
            &\hspace{10mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\
            &\hspace{15mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
            &\hspace{10mm}\textbf{else} \\
            &\hspace{15mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{t \psi} \big) \\
            &\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{(t+1)\psi}\big)\\
            &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
            &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
            &\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i)\\[-1.ex]
            & \hspace{11mm} + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i}) \\
            &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
            &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Incorporating Nesterov Momentum into Adam`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        momentum_decay (float, optional): momentum decay (default: 4e-3)
        decoupled_weight_decay (bool, optional): whether to use decoupled weight
            decay as in AdamW to obtain NAdamW (default: False)
        {_foreach_doc}
        {_maximize_doc}
        {_capturable_doc}
        {_differentiable_doc}

    .. _Incorporating Nesterov Momentum into Adam:
        https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101

    """
)
def _single_tensor_nadam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
mu_products: List[Tensor],
state_steps: List[Tensor],
*,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
momentum_decay: float,
eps: float,
decoupled_weight_decay: bool,
maximize: bool,
capturable: bool,
differentiable: bool,
has_complex: bool,
):
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
mu_product = mu_products[i]
step_t = state_steps[i]
if torch.is_complex(param):
param = torch.view_as_real(param)
grad = torch.view_as_real(grad)
exp_avg = torch.view_as_real(exp_avg)
exp_avg_sq = torch.view_as_real(exp_avg_sq)
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == mu_product.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), (
f"If capturable=True, params, mu_products and state_steps must be "
f"on supported devices: {capturable_supported_devices}."
)
# update step
step_t += 1
if capturable:
step = step_t
else:
step = _get_value(step_t)
bias_correction2 = 1 - beta2**step
if weight_decay != 0:
if decoupled_weight_decay:
# Perform stepweight decay
param.mul_(1 - lr * weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
# calculate the momentum cache \mu^{t} and \mu^{t+1}
mu = beta1 * (1.0 - 0.5 * (0.96 ** (step * momentum_decay)))
mu_next = beta1 * (1.0 - 0.5 * (0.96 ** ((step + 1) * momentum_decay)))
# update mu_product
mu_product *= mu
# decay the first and second moment running average coefficient
exp_avg.lerp_(grad, 1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = exp_avg_sq.div(bias_correction2).sqrt()
if differentiable or capturable:
denom = denom.add(eps)
# Make autograd track the operations
# by updating the grad and exp_avg directly and not using the
# scalar "value" argument of addcdiv.
mu_product_next = mu_product * mu_next
grad = grad * (-lr * (1.0 - mu) / (1.0 - mu_product))
exp_avg = exp_avg * (-lr * mu_next / (1.0 - mu_product_next))
param.addcdiv_(grad, denom)
param.addcdiv_(exp_avg, denom)
else:
mu_product_next = _get_value(mu_product) * mu_next
denom.add_(eps)
param.addcdiv_(
grad, denom, value=(-lr * (1.0 - mu) / (1.0 - _get_value(mu_product)))
)
param.addcdiv_(
exp_avg, denom, value=(-lr * mu_next) / (1.0 - mu_product_next)
)
def _multi_tensor_nadam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    mu_products: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    momentum_decay: float,
    eps: float,
    decoupled_weight_decay: bool,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    """Apply one NAdam update to all parameters using ``torch._foreach_*`` ops.

    The tensor lists are indexed in parallel and are first grouped by device
    and dtype; each group is then updated with list-wise operations.
    ``params``, ``exp_avgs``, ``exp_avg_sqs``, ``mu_products`` and
    ``state_steps`` are updated in place. Nothing is done for an empty
    ``params`` list.

    Raises:
        AssertionError: if ``differentiable`` is true, or if ``capturable`` is
            true (outside of compilation) and params, mu_products and
            state_steps are not all on a supported device.
    """
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == mp.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, mp, step in zip(params, mu_products, state_steps)
        ), f"If capturable=True, params, mu_products, and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps]  # type: ignore[list-item]
    )
    for (
        grouped_params_,
        grouped_grads_,
        grouped_exp_avgs_,
        grouped_exp_avg_sqs_,
        grouped_mu_products_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        # casts only narrow the static types; the lists are used as returned
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
        grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_)
        grouped_mu_products = cast(List[Tensor], grouped_mu_products_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        # handle complex
        if has_complex:
            _view_as_real(
                grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs
            )

        if maximize:
            # out-of-place negation: grouped_grads now refers to a fresh list
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        if weight_decay != 0:
            if decoupled_weight_decay:
                # Perform stepweight decay
                torch._foreach_mul_(grouped_params, 1 - lr * weight_decay)
            else:
                # Re-use the intermediate memory (grouped_grads) already allocated for maximize
                if maximize:
                    torch._foreach_add_(
                        grouped_grads, grouped_params, alpha=weight_decay
                    )
                else:
                    grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                        grouped_grads, grouped_params, alpha=weight_decay
                    )

        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)

        torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2
        )

        # new list: it is turned into the update denominator in place below
        exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs)

        bias_correction_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]
        mus: Union[Tuple[Tensor, ...], List[Tensor]]
        mu_nexts: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            # mus will be beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay))
            exponent = torch._foreach_mul(grouped_state_steps, momentum_decay)
            mus = torch._foreach_pow(0.96, exponent)
            torch._foreach_mul_(mus, -0.5)
            torch._foreach_add_(mus, 1.0)
            torch._foreach_mul_(mus, beta1)

            # mu_nexts will be beta1 * (1 - 0.5 * 0.96 ** ((step + 1) * momentum_decay))
            torch._foreach_add_(exponent, momentum_decay)
            mu_nexts = torch._foreach_pow(0.96, exponent)
            torch._foreach_mul_(mu_nexts, -0.5)
            torch._foreach_add_(mu_nexts, 1.0)
            torch._foreach_mul_(mu_nexts, beta1)

            # save peak memory as we don't need exponent anymore
            del exponent

            # sqrt(1 - beta2 ** step), built as sqrt(-(beta2 ** step - 1))
            bias_correction_sqrt = torch._foreach_pow(beta2, grouped_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction_sqrt, 1.0)
            torch._foreach_neg_(bias_correction_sqrt)
            torch._foreach_sqrt_(bias_correction_sqrt)
        else:
            # non-capturable: the same quantities computed from the step values
            bias_correction_sqrt = [
                (1 - beta2 ** _get_value(step)) ** 0.5 for step in grouped_state_steps
            ]
            mus = [
                beta1 * (1.0 - 0.5 * (0.96 ** (_get_value(step) * momentum_decay)))
                for step in grouped_state_steps
            ]
            mu_nexts = [
                beta1
                * (1.0 - 0.5 * (0.96 ** ((_get_value(step) + 1) * momentum_decay)))
                for step in grouped_state_steps
            ]

        # update mu_products
        torch._foreach_mul_(grouped_mu_products, mus)

        # exp_avg_sq_sqrt becomes sqrt(exp_avg_sq) / sqrt(1 - beta2 ** step) + eps
        torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
        torch._foreach_add_(exp_avg_sq_sqrt, eps)

        # explicitly delete bias_correction refs to save memory
        del bias_correction_sqrt

        if capturable:
            # Build up the step_size multiplier for grad, reusing mus' memory
            torch._foreach_sub_(mus, 1.0)
            torch._foreach_mul_(mus, lr)
            # foreach_sub doesn't allow a scalar as the first arg
            denom = torch._foreach_sub(grouped_mu_products, 1.0)
            torch._foreach_neg_(denom)
            torch._foreach_div_(mus, denom)
            # - lr * (1 - mu) / (1 - mu_product)
            step_size_grads = mus
            # explicitly delete denom to save memory
            del denom

            # Build up the step_size multiplier for exp_avg, reusing mu_nexts' memory
            denom = torch._foreach_mul(grouped_mu_products, mu_nexts)
            torch._foreach_mul_(mu_nexts, lr)
            # foreach_sub doesn't allow a scalar as the first arg, but it's okay because
            # we need a negative here anyway
            torch._foreach_sub_(denom, 1.0)
            torch._foreach_div_(mu_nexts, denom)
            # - lr * mu_next / (1 - mu_product * mu_next)
            step_size_expavg = mu_nexts
            # explicitly delete denom to save memory
            del denom

            # we cannot inplace into step_size_grads cuz it is a list of ScalarTensors
            # and mul'ing with grouped_grads will result in a list of bigger Tensors
            numerator = torch._foreach_mul(step_size_grads, grouped_grads)
            torch._foreach_addcmul_(numerator, step_size_expavg, grouped_exp_avgs)

            # finally, update params
            torch._foreach_addcdiv_(grouped_params, numerator, exp_avg_sq_sqrt)
        else:
            # per-parameter multipliers for the two addcdiv calls below; the
            # trailing `* -1` supplies the minus sign of the update
            step_size_grads = _stack_if_compiling(
                [
                    (_get_value(lr) * (1.0 - mu) / (1.0 - _get_value(mu_product))) * -1
                    for mu_product, mu in zip(grouped_mu_products, mus)
                ]
            )
            step_size_expavg = _stack_if_compiling(
                [
                    (
                        _get_value(lr)
                        * mu_next
                        / (1.0 - _get_value(mu_product) * mu_next)
                    )
                    * -1
                    for mu_product, mu_next in zip(grouped_mu_products, mu_nexts)
                ]
            )

            torch._foreach_addcdiv_(
                grouped_params, grouped_grads, exp_avg_sq_sqrt, step_size_grads  # type: ignore[arg-type]
            )
            torch._foreach_addcdiv_(
                grouped_params, grouped_exp_avgs, exp_avg_sq_sqrt, step_size_expavg  # type: ignore[arg-type]
            )
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_nadam)
def nadam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
mu_products: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
decoupled_weight_decay: bool = False,
foreach: Optional[bool] = None,
capturable: bool = False,
differentiable: bool = False,
has_complex: bool = False,
maximize: bool = False,
*,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
momentum_decay: float,
eps: float,
):
r"""Functional API that performs NAdam algorithm computation.
See :class:`~torch.optim.NAdam` for details.
"""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if not all(isinstance(t, torch.Tensor) for t in mu_products):
raise RuntimeError(
"API has changed, `mu_products` argument must contain a list of singleton tensors"
)
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_nadam
else:
func = _single_tensor_nadam
func(
params,
grads,
exp_avgs,
exp_avg_sqs,
mu_products,
state_steps,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
momentum_decay=momentum_decay,
maximize=maximize,
decoupled_weight_decay=decoupled_weight_decay,
eps=eps,
capturable=capturable,
differentiable=differentiable,
has_complex=has_complex,
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,608 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation for the RAdam algorithm."""
from typing import cast, List, Optional, Tuple, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
_get_capturable_supported_devices,
_get_scalar_dtype,
_get_value,
_maximize_doc,
_use_grad_for_differentiable,
_view_as_real,
Optimizer,
ParamsT,
)
__all__ = ["RAdam", "radam"]
class RAdam(Optimizer):  # noqa: D101
    # The user-facing class documentation is attached after the class body
    # through the ``RAdam.__doc__`` assignment.

    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        decoupled_weight_decay: bool = False,
        *,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        capturable: bool = False,
        differentiable: bool = False,
    ):  # noqa: D107
        # Reject invalid hyperparameters before any state is created.
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not (0.0 <= lr):
            raise ValueError(f"Invalid learning rate: {lr}")
        if not (0.0 <= eps):
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not (0.0 <= betas[0] < 1.0):
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not (0.0 <= betas[1] < 1.0):
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not (0.0 <= weight_decay):
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        # Per-group defaults handed to the base Optimizer.
        defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "maximize": maximize,
            "foreach": foreach,
            "capturable": capturable,
            "decoupled_weight_decay": decoupled_weight_decay,
            "differentiable": differentiable,
        }
        super().__init__(params, defaults)

    def __setstate__(self, state):  # noqa: D105
        super().__setstate__(state)
        # Fill in options that may be missing from the restored param groups.
        option_defaults = (
            ("foreach", None),
            ("maximize", False),
            ("differentiable", False),
            ("decoupled_weight_decay", False),
            ("capturable", False),
        )
        for group in self.param_groups:
            for option, fallback in option_defaults:
                group.setdefault(option, fallback)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) == 0 or torch.is_tensor(p_state["step"]):
                    continue
                # A non-tensor step counter is converted to a scalar tensor;
                # it lives on the parameter's device only when capturable.
                step_val = float(p_state["step"])
                if group["capturable"]:
                    p_state["step"] = torch.tensor(
                        step_val, dtype=_get_scalar_dtype(), device=p.device
                    )
                else:
                    p_state["step"] = torch.tensor(
                        step_val, dtype=_get_scalar_dtype()
                    )

    def _init_group(
        self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
    ):
        """Collect per-parameter tensors of ``group`` into the given lists.

        Parameters without a gradient are skipped. State is created lazily the
        first time a parameter is seen. Returns ``True`` when at least one
        collected parameter is complex.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        has_complex = False
        for p in group["params"]:
            if p.grad is None:
                continue

            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("RAdam does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]
            if len(state) == 0:
                # Lazy state initialization
                if group["capturable"]:
                    state["step"] = torch.zeros(
                        (), dtype=_get_scalar_dtype(), device=p.device
                    )
                else:
                    state["step"] = torch.tensor(0.0, dtype=_get_scalar_dtype())
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )

            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])
            state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            # The closure is evaluated with gradients enabled.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_avg_sqs: List[Tensor] = []
            state_steps: List[Tensor] = []
            beta1, beta2 = cast(Tuple[float, float], group["betas"])

            has_complex = self._init_group(
                group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
            )

            radam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                state_steps,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                decoupled_weight_decay=group["decoupled_weight_decay"],
                has_complex=has_complex,
            )

        return loss
# The class docstring is assembled at import time: a raw string holding the
# algorithm description, concatenated with a raw f-string whose placeholders are
# the shared documentation fragments imported from ``.optimizer``.
RAdam.__doc__ = (
r"""Implements RAdam algorithm.
.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)}, \: \beta_1, \beta_2
\text{ (betas)}, \: \theta_0 \text{ (params)}, \:f(\theta) \text{ (objective)}, \:
\lambda \text{ (weightdecay)}, \:\textit{maximize} \\
&\hspace{13mm} \epsilon \text{ (epsilon)}, \textit{decoupled\_weight\_decay} \\
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
v_0 \leftarrow 0 \text{ ( second moment)}, \\
&\hspace{18mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) -1 \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{6mm}\textbf{if} \: \textit{maximize}: \\
&\hspace{12mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{6mm}\textbf{else} \\
&\hspace{12mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{6mm} \theta_t \leftarrow \theta_{t-1} \\
&\hspace{6mm} \textbf{if} \: \lambda \neq 0 \\
&\hspace{12mm}\textbf{if} \: \textit{decoupled\_weight\_decay} \\
&\hspace{18mm} \theta_t \leftarrow \theta_{t} - \gamma \lambda \theta_{t} \\
&\hspace{12mm}\textbf{else} \\
&\hspace{18mm} g_t \leftarrow g_t + \lambda \theta_{t} \\
&\hspace{6mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
&\hspace{6mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
&\hspace{6mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
&\hspace{6mm}\rho_t \leftarrow \rho_{\infty} -
2 t \beta^t_2 /\big(1-\beta_2^t \big) \\[0.1.ex]
&\hspace{6mm}\textbf{if} \: \rho_t > 5 \\
&\hspace{12mm} l_t \leftarrow \frac{\sqrt{ (1-\beta^t_2) }}{ \sqrt{v_t} +\epsilon } \\
&\hspace{12mm} r_t \leftarrow
\sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_{\infty}}{(\rho_{\infty}-4)(\rho_{\infty}-2) \rho_t}} \\
&\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} r_t l_t \\
&\hspace{6mm}\textbf{else} \\
&\hspace{12mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t} \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}
For further details regarding the algorithm we refer to `On the variance of the adaptive learning rate and beyond`_.
This implementation provides an option to use either the original weight_decay implementation as in Adam
(where the weight_decay is applied to the gradient) or the one from AdamW (where weight_decay is applied
to the weight) through the decoupled_weight_decay option. When decoupled_weight_decay is set to False
(default), it uses the original Adam style weight decay, otherwise, it uses the AdamW style which
corresponds more closely to the `author's implementation`_ in the RAdam paper. Further information
about decoupled weight decay can be found in `Decoupled Weight Decay Regularization`_.
"""
+ rf"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, Tensor, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
decoupled_weight_decay (bool, optional): whether to use decoupled weight
decay as in AdamW to obtain RAdamW (default: False)
{_foreach_doc}
{_maximize_doc}
{_differentiable_doc}
{_capturable_doc}
.. _On the variance of the adaptive learning rate and beyond:
https://arxiv.org/abs/1908.03265
.. _author's implementation:
https://github.com/LiyuanLucasLiu/RAdam
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
"""
)
def _single_tensor_radam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
*,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
eps: float,
decoupled_weight_decay: bool,
differentiable: bool,
maximize: bool,
capturable: bool,
has_complex: bool,
):
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step_t = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
if torch.is_complex(param):
param = torch.view_as_real(param)
grad = torch.view_as_real(grad)
exp_avg = torch.view_as_real(exp_avg)
exp_avg_sq = torch.view_as_real(exp_avg_sq)
# update step
step_t += 1
step = step_t if capturable else _get_value(step_t)
if weight_decay != 0:
if decoupled_weight_decay:
param.mul_(1 - lr * weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
# Decay the first and second moment running average coefficient
exp_avg.lerp_(grad, 1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
# correcting bias for the first moving moment
bias_corrected_exp_avg = exp_avg / bias_correction1
# maximum length of the approximated SMA
rho_inf = 2 / (1 - beta2) - 1
# compute the length of the approximated SMA
rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2
def _compute_rect():
return (
(rho_t - 4)
* (rho_t - 2)
* rho_inf
/ ((rho_inf - 4) * (rho_inf - 2) * rho_t)
) ** 0.5
def _compute_adaptive_lr():
exp_avg_sq_sqrt = exp_avg_sq.sqrt()
if differentiable:
exp_avg_sq_sqrt = exp_avg_sq_sqrt.add(eps)
else:
exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps)
return (bias_correction2**0.5) / exp_avg_sq_sqrt
# Compute the variance rectification term and update parameters accordingly
if capturable:
update = torch.where(
rho_t > 5.0, _compute_rect() * _compute_adaptive_lr(), 1.0
)
param.add_(bias_corrected_exp_avg * lr * update, alpha=-1.0)
else:
if rho_t > 5.0:
param.add_(
bias_corrected_exp_avg
* lr
* _compute_adaptive_lr()
* _compute_rect(),
alpha=-1.0,
)
else:
param.add_(bias_corrected_exp_avg * lr, alpha=-1.0)
def _multi_tensor_radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    decoupled_weight_decay: bool,
    differentiable: bool,
    maximize: bool,
    capturable: bool,
    has_complex: bool,
):
    """Apply one RAdam update using ``torch._foreach_*`` ops.

    Tensors are grouped by device and dtype and each group is updated with
    list-wise ops. ``params``, ``exp_avgs``, ``exp_avg_sqs`` and
    ``state_steps`` are modified in place; nothing is returned. Returns
    immediately when ``params`` is empty, and asserts that ``differentiable``
    is false.
    """
    if len(params) == 0:
        return
    assert not differentiable, "_foreach ops don't support autograd"
    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (
        grouped_params_,
        grouped_grads_,
        grouped_exp_avgs_,
        grouped_exp_avg_sqs_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        # The casts only narrow the static types; they do nothing at runtime.
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
        grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)
        if has_complex:
            _view_as_real(
                grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs
            )
        if maximize:
            # Out-of-place negation: grouped_grads now refers to fresh tensors.
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]
        # maximum length of the approximated SMA
        # (recomputed for every group because the capturable branch below rebinds rho_inf)
        rho_inf = 2 / (1 - beta2) - 1
        # compute the length of the approximated SMA
        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
        rho_t_list: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            # The two names are used as scratch buffers here:
            # bias_correction1 <- 1 - beta2^t
            bias_correction1 = torch._foreach_pow(beta2, grouped_state_steps)
            torch._foreach_neg_(bias_correction1)
            torch._foreach_add_(bias_correction1, 1)
            # bias_correction2 <- rho_inf - 2 * t * beta2^t / (1 - beta2^t), i.e. rho_t
            bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps)
            torch._foreach_mul_(bias_correction2, grouped_state_steps)
            torch._foreach_mul_(bias_correction2, 2)
            torch._foreach_div_(bias_correction2, bias_correction1)
            torch._foreach_neg_(bias_correction2)
            torch._foreach_add_(bias_correction2, rho_inf)
            rho_t_list = bias_correction2
        else:
            # Same rho_t formula, evaluated per step via _get_value.
            rho_t_list = [
                rho_inf
                - 2
                * _get_value(step)
                * (beta2 ** _get_value(step))
                / (1 - beta2 ** _get_value(step))
                for step in grouped_state_steps
            ]
        if weight_decay != 0:
            if decoupled_weight_decay:
                torch._foreach_mul_(grouped_params, 1 - lr * weight_decay)
            else:
                # Re-use the intermediate memory (grouped_grads) already allocated for maximize
                if maximize:
                    torch._foreach_add_(
                        grouped_grads, grouped_params, alpha=weight_decay
                    )
                else:
                    # Out of place, so the caller's gradient tensors are not modified.
                    grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                        grouped_grads, grouped_params, alpha=weight_decay
                    )
        # Decay the first and second moment running average coefficient
        torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)
        torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2
        )
        # Delete the local intermediate since it won't be used anymore to save on peak memory
        del grouped_grads
        if capturable:
            # num <- sqrt((rho_t - 4)(rho_t - 2) rho_inf / ((rho_inf - 4)(rho_inf - 2) rho_t))
            num = torch._foreach_sub(rho_t_list, 4)
            sub2 = torch._foreach_sub(rho_t_list, 2)
            torch._foreach_mul_(num, sub2)
            del sub2
            torch._foreach_mul_(num, rho_inf)
            # rho_inf is rebound to the denominator factor (rho_inf - 4)(rho_inf - 2) from here on.
            rho_inf = (rho_inf - 4) * (rho_inf - 2)
            denom = torch._foreach_mul(rho_t_list, rho_inf)
            torch._foreach_div_(num, denom)
            del denom
            torch._foreach_sqrt_(num)
            # TODO(mlazos): we should try and get a foreach_where op https://github.com/pytorch/pytorch/issues/117884
            # rect is the rectification term where rho_t > 5 and 0 elsewhere.
            rect = [
                torch.where(rho_t > 5.0, n, 0.0) for n, rho_t in zip(num, rho_t_list)
            ]
            del num
            del rho_t_list
            # unrect_step_size <- -lr / (1 - beta1^t) where rect is 0, else 0.
            unrect_step_size = [torch.where(rect > 0, 0.0, 1.0) for rect in rect]
            torch._foreach_mul_(unrect_step_size, lr)
            bias_correction1 = torch._foreach_pow(beta1, grouped_state_steps)
            torch._foreach_neg_(bias_correction1)
            torch._foreach_add_(bias_correction1, 1)
            torch._foreach_div_(unrect_step_size, bias_correction1)
            torch._foreach_neg_(unrect_step_size)
            # bias_correction2 <- -sqrt(1 - beta2^t) * lr * rect / (1 - beta1^t)
            bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps)
            torch._foreach_neg_(bias_correction2)
            torch._foreach_add_(bias_correction2, 1)
            torch._foreach_sqrt_(bias_correction2)
            torch._foreach_mul_(bias_correction2, lr)
            torch._foreach_mul_(bias_correction2, rect)
            del rect
            torch._foreach_neg_(bias_correction2)
            torch._foreach_div_(bias_correction2, bias_correction1)
            del bias_correction1
        else:
            # Same quantities as the capturable branch, built as Python-level lists.
            rect = [
                (
                    (rho_t - 4)  # type: ignore[arg-type]
                    * (rho_t - 2)
                    * rho_inf
                    / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
                )
                ** 0.5
                if rho_t > 5
                else 0
                for rho_t in rho_t_list
            ]
            # 1.0 for entries that take the un-rectified update, 0 otherwise.
            unrectified = [0 if rect > 0 else 1.0 for rect in rect]
            bias_correction1 = [
                1 - beta1 ** _get_value(step) for step in grouped_state_steps
            ]
            unrect_step_size = [
                (lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)
            ]
            bias_correction2 = [
                ((1 - beta2 ** _get_value(step)) ** 0.5) * (lr * rect / bc) * -1
                for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1)
            ]
        # buffer <- bias_correction2 / (sqrt(v) + eps) + unrect_step_size
        buffer = torch._foreach_sqrt(grouped_exp_avg_sqs)
        torch._foreach_add_(buffer, eps)
        torch._foreach_div_(buffer, bias_correction2)
        torch._foreach_reciprocal_(buffer)
        torch._foreach_add_(buffer, unrect_step_size)
        # Here, buffer = sqrt(1 - beta2^t) * rect_step_size / (sqrt(v) + eps) + unrect_step_size
        # (both step sizes already carry the negative sign, so this addcmul performs the descent step)
        torch._foreach_addcmul_(grouped_params, grouped_exp_avgs, buffer)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_radam)
def radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    decoupled_weight_decay: bool = False,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    maximize: bool = False,
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
):
    r"""Functional API that performs RAdam algorithm computation.

    Validates ``state_steps``, picks the foreach or the single-tensor
    implementation, and forwards every argument to it.

    See :class:`~torch.optim.RAdam` for details.

    Raises:
        RuntimeError: if ``state_steps`` contains a non-Tensor element, or if
            ``foreach`` is requested while running under ``torch.jit.script``.
    """
    if any(not isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # No explicit preference from the caller: let the shared helper decide.
    if foreach is None:
        foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )[1]

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if not foreach or torch.jit.is_scripting():
        func = _single_tensor_radam
    else:
        func = _multi_tensor_radam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        lr=lr,
        beta1=beta1,
        beta2=beta2,
        eps=eps,
        weight_decay=weight_decay,
        decoupled_weight_decay=decoupled_weight_decay,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        has_complex=has_complex,
    )

View File

@ -0,0 +1,528 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation for the RMSprop algorithm."""
from typing import cast, List, Optional, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
_get_capturable_supported_devices,
_get_scalar_dtype,
_maximize_doc,
_use_grad_for_differentiable,
_view_as_real,
Optimizer,
ParamsT,
)
__all__ = ["RMSprop", "rmsprop"]
class RMSprop(Optimizer):  # noqa: D101
    # The user-facing class documentation is attached after the class body
    # through the ``RMSprop.__doc__`` assignment.

    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        alpha: float = 0.99,
        eps: float = 1e-8,
        weight_decay: float = 0,
        momentum: float = 0,
        centered=False,
        capturable=False,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
    ):  # noqa: D107
        # Reject invalid hyperparameters before any state is created.
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not (0.0 <= lr):
            raise ValueError(f"Invalid learning rate: {lr}")
        if not (0.0 <= eps):
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not (0.0 <= momentum):
            raise ValueError(f"Invalid momentum value: {momentum}")
        if not (0.0 <= weight_decay):
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not (0.0 <= alpha):
            raise ValueError(f"Invalid alpha value: {alpha}")

        # Per-group defaults handed to the base Optimizer.
        defaults = {
            "lr": lr,
            "momentum": momentum,
            "alpha": alpha,
            "eps": eps,
            "centered": centered,
            "weight_decay": weight_decay,
            "capturable": capturable,
            "foreach": foreach,
            "maximize": maximize,
            "differentiable": differentiable,
        }
        super().__init__(params, defaults)

    def __setstate__(self, state):  # noqa: D105
        super().__setstate__(state)
        # Fill in options that may be missing from the restored param groups.
        option_defaults = (
            ("momentum", 0),
            ("centered", False),
            ("foreach", None),
            ("maximize", False),
            ("differentiable", False),
            ("capturable", False),
        )
        for group in self.param_groups:
            for option, fallback in option_defaults:
                group.setdefault(option, fallback)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) == 0 or torch.is_tensor(p_state["step"]):
                    continue
                # A non-tensor step counter is converted to a scalar tensor;
                # it lives on the parameter's device only when capturable.
                step_val = float(p_state["step"])
                if group["capturable"]:
                    p_state["step"] = torch.tensor(
                        step_val, dtype=_get_scalar_dtype(), device=p.device
                    )
                else:
                    p_state["step"] = torch.tensor(
                        step_val, dtype=_get_scalar_dtype()
                    )

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        square_avgs,
        momentum_buffer_list,
        grad_avgs,
        state_steps,
    ):
        """Collect per-parameter tensors of ``group`` into the given lists.

        Parameters without a gradient are skipped. State is created the first
        time a parameter is seen; the momentum buffer and the gradient average
        are only created and collected when the group uses momentum or is
        centered. Returns ``True`` when at least one collected parameter is
        complex.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        has_complex = False
        for p in group["params"]:
            if p.grad is not None:
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("RMSprop does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                if len(state) == 0:
                    # State initialization
                    if group["capturable"]:
                        state["step"] = torch.zeros(
                            (), dtype=_get_scalar_dtype(), device=p.device
                        )
                    else:
                        state["step"] = torch.zeros((), dtype=_get_scalar_dtype())
                    state["square_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    if group["momentum"] > 0:
                        state["momentum_buffer"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        )
                    if group["centered"]:
                        state["grad_avg"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format
                        )

                square_avgs.append(state["square_avg"])
                state_steps.append(state["step"])
                if group["momentum"] > 0:
                    momentum_buffer_list.append(state["momentum_buffer"])
                if group["centered"]:
                    grad_avgs.append(state["grad_avg"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            # The closure is evaluated with gradients enabled.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            square_avgs: List[Tensor] = []
            grad_avgs: List[Tensor] = []
            momentum_buffer_list: List[Tensor] = []
            state_steps: List[Tensor] = []

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                square_avgs,
                momentum_buffer_list,
                grad_avgs,
                state_steps,
            )

            rmsprop(
                params_with_grad,
                grads,
                square_avgs,
                grad_avgs,
                momentum_buffer_list,
                state_steps,
                lr=group["lr"],
                alpha=group["alpha"],
                eps=group["eps"],
                weight_decay=group["weight_decay"],
                momentum=group["momentum"],
                centered=group["centered"],
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
                capturable=group["capturable"],
                has_complex=has_complex,
            )

        return loss
# The class docstring is assembled at import time: a raw string holding the
# algorithm description, concatenated with a raw f-string whose placeholders are
# the shared documentation fragments imported from ``.optimizer``.
RMSprop.__doc__ = (
r"""Implements RMSprop algorithm.
.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \alpha \text{ (alpha)},\: \gamma \text{ (lr)},
\: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\
&\hspace{13mm} \lambda \text{ (weight decay)},\: \mu \text{ (momentum)},\: centered\\
&\textbf{initialize} : v_0 \leftarrow 0 \text{ (square average)}, \:
\textbf{b}_0 \leftarrow 0 \text{ (buffer)}, \: g^{ave}_0 \leftarrow 0 \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}if \: \lambda \neq 0 \\
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
&\hspace{5mm}v_t \leftarrow \alpha v_{t-1} + (1 - \alpha) g^2_t
\hspace{8mm} \\
&\hspace{5mm} \tilde{v_t} \leftarrow v_t \\
&\hspace{5mm}if \: centered \\
&\hspace{10mm} g^{ave}_t \leftarrow g^{ave}_{t-1} \alpha + (1-\alpha) g_t \\
&\hspace{10mm} \tilde{v_t} \leftarrow \tilde{v_t} - \big(g^{ave}_{t} \big)^2 \\
&\hspace{5mm}if \: \mu > 0 \\
&\hspace{10mm} \textbf{b}_t\leftarrow \mu \textbf{b}_{t-1} +
g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \\
&\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma \textbf{b}_t \\
&\hspace{5mm} else \\
&\hspace{10mm}\theta_t \leftarrow \theta_{t-1} -
\gamma g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \hspace{3mm} \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}
For further details regarding the algorithm we refer to
`lecture notes <https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_ by G. Hinton.
and centered version `Generating Sequences
With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
The implementation here takes the square root of the gradient average before
adding epsilon (note that TensorFlow interchanges these two operations). The effective
learning rate is thus :math:`\gamma/(\sqrt{v} + \epsilon)` where :math:`\gamma`
is the scheduled learning rate and :math:`v` is the weighted moving average
of the squared gradient.
"""
+ rf"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, Tensor, optional): learning rate (default: 1e-2)
momentum (float, optional): momentum factor (default: 0)
alpha (float, optional): smoothing constant (default: 0.99)
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
centered (bool, optional) : if ``True``, compute the centered RMSProp,
the gradient is normalized by an estimation of its variance
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
{_foreach_doc}
{_maximize_doc}
{_capturable_doc}
{_differentiable_doc}
"""
)
def _single_tensor_rmsprop(
params: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
grad_avgs: List[Tensor],
momentum_buffer_list: List[Tensor],
state_steps: List[Tensor],
*,
lr: float,
alpha: float,
eps: float,
weight_decay: float,
momentum: float,
centered: bool,
maximize: bool,
differentiable: bool,
capturable: bool,
has_complex: bool,
):
for i, param in enumerate(params):
step = state_steps[i]
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert (
param.device.type == step.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
grad = grads[i]
grad = grad if not maximize else -grad
square_avg = square_avgs[i]
step += 1
if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)
is_complex_param = torch.is_complex(param)
if is_complex_param:
param = torch.view_as_real(param)
grad = torch.view_as_real(grad)
square_avg = torch.view_as_real(square_avg)
square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
if centered:
grad_avg = grad_avgs[i]
if is_complex_param:
grad_avg = torch.view_as_real(grad_avg)
grad_avg.lerp_(grad, 1 - alpha)
avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_()
else:
avg = square_avg.sqrt()
if differentiable:
avg = avg.add(eps)
else:
avg = avg.add_(eps)
if momentum > 0:
buf = momentum_buffer_list[i]
if is_complex_param:
buf = torch.view_as_real(buf)
buf.mul_(momentum).addcdiv_(grad, avg)
param.add_(buf, alpha=-lr)
else:
param.addcdiv_(grad, avg, value=-lr)
def _multi_tensor_rmsprop(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    grad_avgs: List[Tensor],
    momentum_buffer_list: List[Tensor],
    state_steps: List[Tensor],
    *,
    lr: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    momentum: float,
    centered: bool,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    """Apply one RMSprop update using ``torch._foreach_*`` ops.

    Tensors are grouped by device and dtype and each group is updated with
    list-wise ops. ``params``, ``square_avgs``, ``state_steps`` and, when
    used, ``grad_avgs`` and ``momentum_buffer_list`` are modified in place;
    nothing is returned. Returns immediately when ``params`` is empty, and
    asserts that ``differentiable`` is false.
    """
    if len(params) == 0:
        return
    assert not differentiable, "_foreach ops don't support autograd"
    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices()
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, square_avgs, grad_avgs, momentum_buffer_list, state_steps]  # type: ignore[list-item]
    )
    for (
        (
            grouped_params_,
            grouped_grads_,
            grouped_square_avgs_,
            grouped_grad_avgs_,
            grouped_momentum_buffer_list_,
            grouped_state_steps_,
        )
    ), _ in grouped_tensors.values():
        # The casts only narrow the static types; they do nothing at runtime.
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_square_avgs = cast(List[Tensor], grouped_square_avgs_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
        if has_complex:
            # Only the state lists that are actually in use are passed to _view_as_real.
            state_and_grads = [grouped_grads, grouped_square_avgs]
            if momentum > 0:
                grouped_momentum_buffer_list = cast(
                    List[Tensor], grouped_momentum_buffer_list_
                )
                state_and_grads.append(grouped_momentum_buffer_list)
            if centered:
                grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_)
                state_and_grads.append(grouped_grad_avgs)
            _view_as_real(grouped_params, *state_and_grads)
        if maximize:
            # Out-of-place negation: grouped_grads now refers to fresh tensors.
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]
        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)
        if weight_decay != 0:
            # Re-use the intermediate memory (grouped_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
            else:
                # Out of place, so the caller's gradient tensors are not modified.
                grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                    grouped_grads, grouped_params, alpha=weight_decay
                )
        # square_avg <- alpha * square_avg + (1 - alpha) * grad^2
        torch._foreach_mul_(grouped_square_avgs, alpha)
        torch._foreach_addcmul_(
            grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha
        )
        if centered:
            grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_)
            torch._foreach_lerp_(grouped_grad_avgs, grouped_grads, 1 - alpha)
            # avg <- sqrt(square_avg - grad_avg^2) + eps, built on fresh tensors
            avg = torch._foreach_addcmul(
                grouped_square_avgs, grouped_grad_avgs, grouped_grad_avgs, value=-1
            )
            torch._foreach_sqrt_(avg)
            torch._foreach_add_(avg, eps)
        else:
            # avg <- sqrt(square_avg) + eps, built on fresh tensors
            avg = torch._foreach_sqrt(grouped_square_avgs)
            torch._foreach_add_(avg, eps)
        if momentum > 0:
            grouped_momentum_buffer_list = cast(
                List[Tensor], grouped_momentum_buffer_list_
            )
            # buf <- momentum * buf + grad / avg
            torch._foreach_mul_(grouped_momentum_buffer_list, momentum)
            torch._foreach_addcdiv_(grouped_momentum_buffer_list, grouped_grads, avg)
            # If LR is a tensor, the else branch will internally call item()
            # which will cause silent incorrectness if we are capturing
            if capturable and isinstance(lr, torch.Tensor):
                momentum_lr = torch._foreach_mul(grouped_momentum_buffer_list, -lr)
                torch._foreach_add_(grouped_params, momentum_lr)
            else:
                torch._foreach_add_(
                    grouped_params, grouped_momentum_buffer_list, alpha=-lr
                )
        else:
            # If LR is a tensor, the else branch will internally call item()
            # which will cause silent incorrectness if we are capturing
            if capturable and isinstance(lr, torch.Tensor):
                # Dividing avg by -lr folds the step size into the denominator.
                torch._foreach_div_(avg, -lr)
                torch._foreach_addcdiv_(grouped_params, grouped_grads, avg)
            else:
                torch._foreach_addcdiv_(grouped_params, grouped_grads, avg, value=-lr)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop)
def rmsprop(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    grad_avgs: List[Tensor],
    momentum_buffer_list: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    momentum: float,
    centered: bool,
):
    r"""Functional API that performs rmsprop algorithm computation.

    Validates ``state_steps`` (skipped while compiling), picks the foreach or
    the single-tensor implementation, and forwards every argument to it.

    See :class:`~torch.optim.RMSprop` for details.

    Raises:
        RuntimeError: if ``state_steps`` contains a non-Tensor element, or if
            ``foreach`` is requested while running under ``torch.jit.script``.
    """
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not (
        torch._utils.is_compiling()
        or all(isinstance(t, torch.Tensor) for t in state_steps)
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # No explicit preference from the caller: let the shared helper decide.
    if foreach is None:
        foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )[1]

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if not foreach or torch.jit.is_scripting():
        func = _single_tensor_rmsprop
    else:
        func = _multi_tensor_rmsprop

    func(
        params,
        grads,
        square_avgs,
        grad_avgs,
        momentum_buffer_list,
        state_steps,
        lr=lr,
        alpha=alpha,
        eps=eps,
        weight_decay=weight_decay,
        momentum=momentum,
        centered=centered,
        maximize=maximize,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )

View File

@ -0,0 +1,464 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
r"""Implementation for the Resilient backpropagation."""
from typing import cast, List, Optional, Tuple, Union
import torch
from torch import Tensor
from .optimizer import (
_capturable_doc,
_default_to_fused_or_foreach,
_differentiable_doc,
_disable_dynamo_if_unsupported,
_foreach_doc,
_get_capturable_supported_devices,
_get_scalar_dtype,
_maximize_doc,
_use_grad_for_differentiable,
_view_as_real,
Optimizer,
ParamsT,
)
__all__ = ["Rprop", "rprop"]
class Rprop(Optimizer): # noqa: D101
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-2,
etas: Tuple[float, float] = (0.5, 1.2),
step_sizes: Tuple[float, float] = (1e-6, 50),
*,
capturable: bool = False,
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
): # noqa: D107
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 < etas[0] < 1.0 < etas[1]:
raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}")
defaults = dict(
lr=lr,
etas=etas,
step_sizes=step_sizes,
foreach=foreach,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
)
super().__init__(params, defaults)
def __setstate__(self, state): # noqa: D105
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("foreach", None)
group.setdefault("maximize", False)
group.setdefault("differentiable", False)
group.setdefault("capturable", False)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
step_val = float(p_state["step"])
p_state["step"] = (
torch.tensor(
step_val, dtype=_get_scalar_dtype(), device=p.device
)
if group["capturable"]
else torch.tensor(step_val, dtype=_get_scalar_dtype())
)
def _init_group(self, group, params, grads, prevs, step_sizes, state_steps):
has_complex = False
for p in group["params"]:
if p.grad is None:
continue
has_complex |= torch.is_complex(p)
params.append(p)
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Rprop does not support sparse gradients")
grads.append(grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = (
torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
if group["capturable"]
else torch.zeros((), dtype=_get_scalar_dtype())
)
state["prev"] = torch.zeros_like(p, memory_format=torch.preserve_format)
if p.dtype.is_complex:
# Complex Number should be as if they are two independent real numbers.
# Hence the step_size shouldn't be zero for imaginary part.
state["step_size"] = torch.full_like(
grad, complex(group["lr"], group["lr"])
)
else:
state["step_size"] = torch.full_like(grad, group["lr"])
prevs.append(state["prev"])
step_sizes.append(state["step_size"])
state_steps.append(state["step"])
return has_complex
@_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
self._cuda_graph_capture_health_check()
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params: List[Tensor] = []
grads: List[Tensor] = []
prevs: List[Tensor] = []
step_sizes: List[Tensor] = []
state_steps: List[Tensor] = []
etaminus, etaplus = group["etas"]
step_size_min, step_size_max = group["step_sizes"]
foreach = group["foreach"]
maximize = group["maximize"]
has_complex = self._init_group(
group, params, grads, prevs, step_sizes, state_steps
)
rprop(
params,
grads,
prevs,
step_sizes,
state_steps,
step_size_min=step_size_min,
step_size_max=step_size_max,
etaminus=etaminus,
etaplus=etaplus,
foreach=foreach,
maximize=maximize,
differentiable=group["differentiable"],
capturable=group["capturable"],
has_complex=has_complex,
)
return loss
# The class documentation is attached after the class body so that the shared
# option snippets (``_foreach_doc`` etc.) can be interpolated into the Args
# section.
Rprop.__doc__ = (
    r"""Implements the resilient backpropagation algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \theta_0 \in \mathbf{R}^d \text{ (params)},f(\theta)
                \text{ (objective)}, \\
            &\hspace{13mm} \eta_{+/-} \text{ (etaplus, etaminus)}, \Gamma_{max/min}
                \text{ (step sizes)} \\
            &\textbf{initialize} : g^0_{prev} \leftarrow 0,
                \: \eta_0 \leftarrow \text{lr (learning rate)} \\
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm} \textbf{for} \text{ } i = 0, 1, \ldots, d-1 \: \mathbf{do} \\
            &\hspace{10mm} \textbf{if} \: g^i_{prev} g^i_t > 0 \\
            &\hspace{15mm} \eta^i_t \leftarrow \mathrm{min}(\eta^i_{t-1} \eta_{+},
                \Gamma_{max}) \\
            &\hspace{10mm} \textbf{else if} \: g^i_{prev} g^i_t < 0 \\
            &\hspace{15mm} \eta^i_t \leftarrow \mathrm{max}(\eta^i_{t-1} \eta_{-},
                \Gamma_{min}) \\
            &\hspace{15mm} g^i_t \leftarrow 0 \\
            &\hspace{10mm} \textbf{else} \: \\
            &\hspace{15mm} \eta^i_t \leftarrow \eta^i_{t-1} \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1}- \eta_t \mathrm{sign}(g_t) \\
            &\hspace{5mm}g_{prev} \leftarrow g_t \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to the paper
    `A Direct Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm
    <http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.1417>`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-2)
        etas (Tuple[float, float], optional): pair of (etaminus, etaplus), that
            are multiplicative decrease and increase factors
            (default: (0.5, 1.2))
        step_sizes (Tuple[float, float], optional): a pair of minimal and
            maximal allowed step sizes (default: (1e-6, 50))
        {_foreach_doc}
        {_capturable_doc}
        {_maximize_doc}
        {_differentiable_doc}

    """
)
def _single_tensor_rprop(
    params: List[Tensor],
    grads: List[Tensor],
    prevs: List[Tensor],
    step_sizes: List[Tensor],
    state_steps: List[Tensor],
    *,
    step_size_min: float,
    step_size_max: float,
    etaminus: float,
    etaplus: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
    has_complex: bool,
):
    """Apply one Rprop update to every parameter, one tensor at a time.

    For each index the step counter is incremented, the element-wise step size
    is multiplied by ``etaplus``, ``etaminus`` or 1 depending on the sign of
    ``grad * prev`` and clamped to ``[step_size_min, step_size_max]``, the
    parameter is moved by ``-sign(grad) * step_size`` and ``prev`` is
    overwritten with the gradient (zeroed where the sign flipped).

    ``params``, ``prevs``, ``step_sizes`` and ``state_steps`` are modified in
    place; the tensors in ``grads`` are only read. ``has_complex`` is not read
    by this implementation.
    """
    for idx, param in enumerate(params):
        grad = -grads[idx] if maximize else grads[idx]
        prev = prevs[idx]
        step_size = step_sizes[idx]
        step = state_steps[idx]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        step += 1

        # Complex tensors are processed through their real views.
        if torch.is_complex(param):
            param = torch.view_as_real(param)
            grad = torch.view_as_real(grad)
            prev = torch.view_as_real(prev)
            step_size = torch.view_as_real(step_size)

        prev_operand = prev.clone() if differentiable else prev
        sign = grad.mul(prev_operand).sign()

        # Turn the {+1, -1, 0} signs into step-size multipliers.
        if not capturable:
            sign[sign.gt(0)] = etaplus
            sign[sign.lt(0)] = etaminus
            sign[sign.eq(0)] = 1
        else:
            sign.copy_(torch.where(sign.gt(0), etaplus, sign))
            sign.copy_(torch.where(sign.lt(0), etaminus, sign))
            sign.copy_(torch.where(sign.eq(0), 1, sign))

        # update stepsizes with step size updates
        step_size.mul_(sign)
        step_size.clamp_(step_size_min, step_size_max)

        # for dir<0, dfdx=0
        # for dir>=0 dfdx=dfdx
        # Work on a copy so the caller's gradient tensor is left untouched.
        grad = grad.clone(memory_format=torch.preserve_format)
        if not capturable:
            grad[sign.eq(etaminus)] = 0
        else:
            grad.copy_(torch.where(sign.eq(etaminus), 0, grad))

        # update parameters
        param.addcmul_(grad.sign(), step_size, value=-1)
        prev.copy_(grad)
def _multi_tensor_rprop(
params: List[Tensor],
grads: List[Tensor],
prevs: List[Tensor],
step_sizes: List[Tensor],
state_steps: List[Tensor],
*,
step_size_min: float,
step_size_max: float,
etaminus: float,
etaplus: float,
maximize: bool,
capturable: bool,
differentiable: bool,
has_complex: bool,
):
if len(params) == 0:
return
assert not differentiable, "_foreach ops don't support autograd"
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
capturable_supported_devices = _get_capturable_supported_devices()
assert all(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, prevs, step_sizes, state_steps] # type: ignore[list-item]
)
for (
grouped_params_,
grouped_grads_,
grouped_prevs_,
grouped_step_sizes_,
grouped_state_steps_,
), _ in grouped_tensors.values():
grouped_params = cast(List[Tensor], grouped_params_)
grouped_grads = cast(List[Tensor], grouped_grads_)
grouped_prevs = cast(List[Tensor], grouped_prevs_)
grouped_step_sizes = cast(List[Tensor], grouped_step_sizes_)
grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
torch._foreach_add_(
grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else:
torch._foreach_add_(grouped_state_steps, 1)
# Handle complex params
if has_complex:
_view_as_real(
grouped_params, grouped_grads, grouped_prevs, grouped_step_sizes
)
signs = torch._foreach_mul(grouped_grads, grouped_prevs)
if maximize:
torch._foreach_neg_(signs)
# At the end of the step, grouped_prevs will contain the current grads, so we reuse
# grouped_prevs memory instead of creating a new buffer, but, for clarity, we reassign
# to keep referring to the buffer as grouped_grads.
torch._foreach_copy_(grouped_prevs, grouped_grads)
if maximize:
torch._foreach_neg_(grouped_prevs)
grouped_grads = grouped_prevs
torch._foreach_sign_(signs)
if capturable:
for sign in signs:
sign.copy_(torch.where(sign.gt(0), etaplus, sign))
sign.copy_(torch.where(sign.lt(0), etaminus, sign))
sign.copy_(torch.where(sign.eq(0), 1, sign))
else:
for sign in signs:
sign[sign.gt(0)] = etaplus
sign[sign.lt(0)] = etaminus
sign[sign.eq(0)] = 1
# update stepsizes with step size updates
torch._foreach_mul_(grouped_step_sizes, signs)
for step_size in grouped_step_sizes:
step_size.clamp_(step_size_min, step_size_max)
# for dir<0, dfdx=0
# for dir>=0 dfdx=dfdx
grouped_grads = list(grouped_grads)
for i in range(len(grouped_grads)):
grouped_grads[i].copy_(
torch.where(signs[i].eq(etaminus), 0, grouped_grads[i])
)
# explicitly del signs as it's not used after here to save memory
del signs
# update parameters
grad_signs = [grad.sign() for grad in grouped_grads]
torch._foreach_addcmul_(
grouped_params, grad_signs, grouped_step_sizes, value=-1
)
# Logically, you may expect grouped_prevs to get updated to grouped_grads, but that's
# basically already happened since we've been using grouped_prevs' memory to store
# updated grouped_grads!
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rprop)
def rprop(
params: List[Tensor],
grads: List[Tensor],
prevs: List[Tensor],
step_sizes: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
capturable: bool = False,
maximize: bool = False,
differentiable: bool = False,
has_complex: bool = False,
*,
step_size_min: float,
step_size_max: float,
etaminus: float,
etaplus: float,
):
r"""Functional API that performs rprop algorithm computation.
See :class:`~torch.optim.Rprop` for details.
"""
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_rprop
else:
func = _single_tensor_rprop
func(
params,
grads,
prevs,
step_sizes,
state_steps,
step_size_min=step_size_min,
step_size_max=step_size_max,
etaminus=etaminus,
etaplus=etaplus,
capturable=capturable,
maximize=maximize,
differentiable=differentiable,
has_complex=has_complex,
)

View File

@ -0,0 +1,511 @@
# mypy: allow-untyped-defs
r"""Implementation for Stochastic Gradient Descent optimizer."""
from typing import cast, List, Optional, Union
import torch
from torch import Tensor
from .optimizer import (
_default_to_fused_or_foreach,
_device_dtype_check_for_fused,
_differentiable_doc,
_foreach_doc,
_fused_doc,
_maximize_doc,
_use_grad_for_differentiable,
DeviceDict,
Optimizer,
)
__all__ = ["SGD", "sgd"]
class SGD(Optimizer): # noqa: D101
def __init__(
self,
params,
lr: Union[float, Tensor] = 1e-3,
momentum: float = 0,
dampening: float = 0,
weight_decay: float = 0,
nesterov=False,
*,
maximize: bool = False,
foreach: Optional[bool] = None,
differentiable: bool = False,
fused: Optional[bool] = None,
): # noqa: D107
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
momentum=momentum,
dampening=dampening,
weight_decay=weight_decay,
nesterov=nesterov,
maximize=maximize,
foreach=foreach,
differentiable=differentiable,
fused=fused,
)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super().__init__(params, defaults)
if fused:
self._step_supports_amp_scaling = True
self._need_device_dtype_check_for_fused = True
if differentiable:
raise RuntimeError("`fused` does not support `differentiable`")
if foreach:
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
def __setstate__(self, state): # noqa: D105
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("nesterov", False)
group.setdefault("maximize", False)
group.setdefault("foreach", None)
group.setdefault("differentiable", False)
group.setdefault("fused", False)
def _init_group(self, group, params, grads, momentum_buffer_list):
has_sparse_grad = False
for p in group["params"]:
if p.grad is not None:
if group["fused"] and getattr(
self, "_need_device_dtype_check_for_fused", True
):
_device_dtype_check_for_fused(p)
self._need_device_dtype_check_for_fused = False
params.append(p)
grads.append(p.grad)
if p.grad.is_sparse:
has_sparse_grad = True
if group["momentum"] != 0:
state = self.state[p]
momentum_buffer_list.append(state.get("momentum_buffer"))
return has_sparse_grad
@_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params: List[Tensor] = []
grads: List[Tensor] = []
momentum_buffer_list: List[Optional[Tensor]] = []
has_sparse_grad = self._init_group(
group, params, grads, momentum_buffer_list
)
sgd(
params,
grads,
momentum_buffer_list,
weight_decay=group["weight_decay"],
momentum=group["momentum"],
lr=group["lr"],
dampening=group["dampening"],
nesterov=group["nesterov"],
maximize=group["maximize"],
has_sparse_grad=has_sparse_grad,
foreach=group["foreach"],
fused=group["fused"],
grad_scale=getattr(self, "grad_scale", None),
found_inf=getattr(self, "found_inf", None),
)
if group["momentum"] != 0:
# update momentum_buffers in state
for p, momentum_buffer in zip(params, momentum_buffer_list):
state = self.state[p]
state["momentum_buffer"] = momentum_buffer
return loss
# The class documentation is attached after the class body so that the shared
# option snippets (``_maximize_doc`` etc.) can be interpolated into the Args
# section. Directive bodies must stay indented for reStructuredText to render
# the math and note blocks.
SGD.__doc__ = (
    r"""Implements stochastic gradient descent (optionally with momentum).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
            &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
            \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\
            &\hspace{10mm}\textbf{if} \: t > 1 \\
            &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\
            &\hspace{10mm}\textbf{else} \\
            &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\
            &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\
            &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\
            &\hspace{10mm}\textbf{else} \\[-1.ex]
            &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize} \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex]
            &\hspace{5mm}\textbf{else} \\[-1.ex]
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-3)
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        {_maximize_doc}
        {_foreach_doc}
        {_differentiable_doc}
        {_fused_doc}
    """
    + r"""

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.

        Moreover, the initial value of the momentum buffer is set to the
        gradient value at the first step. This is in contrast to some other
        frameworks that initialize it to all zeros.

    """
)
def sgd(
params: List[Tensor],
d_p_list: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
has_sparse_grad: bool = False,
foreach: Optional[bool] = None,
fused: Optional[bool] = None,
grad_scale: Optional[Tensor] = None,
found_inf: Optional[Tensor] = None,
*,
weight_decay: float,
momentum: float,
lr: float,
dampening: float,
nesterov: bool,
maximize: bool,
):
r"""Functional API that performs SGD algorithm computation.
See :class:`~torch.optim.SGD` for details.
"""
# Respect when the user inputs False/True for foreach or fused. We only want to change
# the default when neither have been user-specified. Note that we default to foreach
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
# bake-in time before making it the default, even if it is typically faster.
if foreach is None and fused is None:
# why must we be explicit about an if statement for torch.jit.is_scripting here?
# because JIT can't handle Optionals nor fancy conditionals when scripting
if not torch.jit.is_scripting():
fused, foreach = _default_to_fused_or_foreach(
params, differentiable=False, use_fused=False
)
else:
foreach = False
fused = False
if foreach is None:
foreach = False
if fused is None:
fused = False
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if fused and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with fused optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_sgd
elif fused and not torch.jit.is_scripting():
func = _fused_sgd
else:
func = _single_tensor_sgd
func(
params,
d_p_list,
momentum_buffer_list,
weight_decay=weight_decay,
momentum=momentum,
lr=lr,
dampening=dampening,
nesterov=nesterov,
has_sparse_grad=has_sparse_grad,
maximize=maximize,
grad_scale=grad_scale,
found_inf=found_inf,
)
def _single_tensor_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool,
):
    """Apply one SGD update to each parameter in turn.

    Each parameter is updated in place by ``-lr`` times its gradient, after the
    optional negation (``maximize``), weight decay and momentum/Nesterov
    adjustments. Missing (``None``) momentum buffers are created as a detached
    clone of the adjusted gradient and written back into
    ``momentum_buffer_list``; existing buffers are updated in place.

    ``grad_scale`` and ``found_inf`` must both be ``None``; ``has_sparse_grad``
    is not read by this implementation.
    """
    assert grad_scale is None and found_inf is None

    for idx, param in enumerate(params):
        grad = -grads[idx] if maximize else grads[idx]

        if weight_decay != 0:
            # Out of place, so the caller's gradient tensor is left untouched.
            grad = grad.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[idx]
            if buf is None:
                # First use of this slot: seed the buffer with the gradient.
                buf = torch.clone(grad).detach()
                momentum_buffer_list[idx] = buf
            else:
                buf.mul_(momentum).add_(grad, alpha=1 - dampening)

            grad = grad.add(buf, alpha=momentum) if nesterov else buf

        param.add_(grad, alpha=-lr)
def _multi_tensor_sgd(
params: List[Tensor],
grads: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
weight_decay: float,
momentum: float,
lr: float,
dampening: float,
nesterov: bool,
maximize: bool,
has_sparse_grad: bool,
):
assert grad_scale is None and found_inf is None
if len(params) == 0:
return
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, momentum_buffer_list], with_indices=True # type: ignore[list-item]
)
for (
device_params_,
device_grads_,
device_momentum_buffer_list,
), indices in grouped_tensors.values():
device_params: List[Tensor] = cast(List[Tensor], device_params_)
device_grads: List[Tensor] = cast(List[Tensor], device_grads_)
device_has_sparse_grad = has_sparse_grad and any(
grad.is_sparse for grad in device_grads
)
if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
if weight_decay != 0:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
if momentum != 0:
bufs: List[Tensor] = []
all_states_with_momentum_buffer = True
for i in range(len(device_momentum_buffer_list)):
if device_momentum_buffer_list[i] is None:
all_states_with_momentum_buffer = False
break
else:
bufs.append(cast(Tensor, device_momentum_buffer_list[i]))
if all_states_with_momentum_buffer:
torch._foreach_mul_(bufs, momentum)
torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
else:
bufs = []
for i in range(len(device_momentum_buffer_list)):
if device_momentum_buffer_list[i] is None:
buf = device_momentum_buffer_list[i] = momentum_buffer_list[
indices[i]
] = torch.clone(device_grads[i]).detach()
else:
buf = cast(Tensor, device_momentum_buffer_list[i])
buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)
bufs.append(buf)
if nesterov:
torch._foreach_add_(device_grads, bufs, alpha=momentum)
else:
device_grads = bufs
if not device_has_sparse_grad:
# handle internal item() call if lr is a tensor
if isinstance(lr, torch.Tensor) and torch._utils.is_compiling():
grads_x_lr = torch._foreach_mul(device_grads, -lr)
torch._foreach_add_(device_params, grads_x_lr)
else:
torch._foreach_add_(device_params, device_grads, alpha=-lr)
else:
# foreach APIs don't support sparse
for i in range(len(device_params)):
device_params[i].add_(device_grads[i], alpha=-lr)
def _fused_sgd(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
    has_sparse_grad: bool,
) -> None:
    """Apply one SGD update by calling ``torch._fused_sgd_`` per tensor group.

    The inputs are grouped with ``Optimizer._group_tensors_by_device_and_dtype``
    and each group is handed to the fused kernel together with copies of
    ``grad_scale``/``found_inf`` on the group's device (transferred at most once
    per device within a call). When momentum is enabled and every entry of
    ``momentum_buffer_list`` is ``None``, the buffers are allocated with
    ``torch.empty_like`` and the kernel is told this is the first step.

    Returns immediately when ``params`` is empty.

    Raises:
        RuntimeError: if ``has_sparse_grad`` is true.
    """
    if not params:
        return
    if has_sparse_grad:
        raise RuntimeError("`_fused_sgd` does not support sparse gradients")

    # Per-device caches, seeded with the tensors on their original device.
    grad_scale_dict: DeviceDict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else {}
    )
    found_inf_dict: DeviceDict = (
        {found_inf.device: found_inf} if found_inf is not None else {}
    )

    no_momentum_buffer = momentum == 0
    is_first_step = (
        all(t is None for t in momentum_buffer_list) and not no_momentum_buffer
    )
    if is_first_step:
        for i, g in enumerate(grads):
            momentum_buffer_list[i] = torch.empty_like(g)

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, momentum_buffer_list], with_indices=False  # type: ignore[list-item]
    )
    for (device, _), (
        (device_params_, device_grads_, device_momentum_buffer_list),
        _,
    ) in grouped_tensors.items():
        device_params: List[Tensor] = cast(List[Tensor], device_params_)
        device_grads: List[Tensor] = cast(List[Tensor], device_grads_)

        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None:
            # Only transfer when this device has not been seen yet; several
            # dtype groups can share one device.
            device_grad_scale = grad_scale_dict.get(device)
            if device_grad_scale is None:
                device_grad_scale = grad_scale.to(device)
                grad_scale_dict[device] = device_grad_scale
        if found_inf is not None:
            device_found_inf = found_inf_dict.get(device)
            if device_found_inf is None:
                device_found_inf = found_inf.to(device)
                found_inf_dict[device] = device_found_inf

        torch._fused_sgd_(
            device_params,
            device_grads,
            []
            if no_momentum_buffer
            else cast(List[Tensor], device_momentum_buffer_list),
            weight_decay=weight_decay,
            momentum=momentum,
            lr=lr,
            dampening=dampening,
            nesterov=nesterov,
            maximize=maximize,
            is_first_step=is_first_step,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )

View File

@ -0,0 +1,185 @@
# mypy: allow-untyped-defs
from typing import List, Tuple, Union
import torch
from torch import Tensor
from . import _functional as F
from .optimizer import _maximize_doc, Optimizer, ParamsT
__all__ = ["SparseAdam"]
class SparseAdam(Optimizer):
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
maximize: bool = False,
):
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 < lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 < eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize)
super().__init__(params, defaults)
sparse_params = []
complex_params = []
for index, param_group in enumerate(self.param_groups):
assert isinstance(
param_group, dict
), f"param_groups must be a list of dicts, but got {type(param_group)}"
# given param group, convert given params to a list first before iterating
for d_index, d_param in enumerate(param_group["params"]):
if d_param.is_sparse:
sparse_params.append([index, d_index])
if d_param.is_complex():
complex_params.append([index, d_index])
if sparse_params:
raise ValueError(
f"Sparse params at indices {sparse_params}: SparseAdam requires dense parameter tensors"
)
if complex_params:
raise ValueError(
f"Complex params at indices {complex_params}: SparseAdam does not support complex parameters"
)
@torch.no_grad()
def step(self, closure=None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
exp_avgs: List[Tensor] = []
exp_avg_sqs: List[Tensor] = []
state_steps: List[int] = []
beta1, beta2 = group["betas"]
maximize = group.get("maximize", False)
for p in group["params"]:
if p.grad is not None:
params_with_grad.append(p)
if not p.grad.is_sparse:
raise RuntimeError(
"SparseAdam does not support dense gradients, please consider Adam instead"
)
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
# update the steps for each param group update
state["step"] += 1
# record the step after step update
state_steps.append(state["step"])
F.sparse_adam(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
eps=group["eps"],
beta1=beta1,
beta2=beta2,
lr=group["lr"],
maximize=maximize,
)
return loss
# The class documentation is attached after the class body so that the shared
# ``_maximize_doc`` snippet can be interpolated into the Args section.
SparseAdam.__doc__ = rf"""SparseAdam implements a masked version of the Adam algorithm
    suitable for sparse gradients. Currently, due to implementation constraints (explained
    below), SparseAdam is only intended for a narrow subset of use cases, specifically
    parameters of a dense layout with gradients of a sparse layout. This occurs in a
    special case where the module backwards produces grads already in a sparse layout.
    One example NN module that behaves as such is ``nn.Embedding(sparse=True)``.

    SparseAdam approximates the Adam algorithm by masking out the parameter and moment
    updates corresponding to the zero values in the gradients. Whereas the Adam algorithm
    will update the first moment, the second moment, and the parameters based on all values
    of the gradients, SparseAdam only updates the moments and parameters corresponding
    to the non-zero values of the gradients.

    A simplified way of thinking about the `intended` implementation is as such:

    1. Create a mask of the non-zero values in the sparse gradients. For example,
       if your gradient looks like [0, 5, 0, 0, 9], the mask would be [0, 1, 0, 0, 1].
    2. Apply this mask over the running moments and do computation on only the
       non-zero values.
    3. Apply this mask over the parameters and only apply an update on non-zero values.

    In actuality, we use sparse layout Tensors to optimize this approximation, which means the
    more gradients that are masked by not being materialized, the more performant the optimization.
    Since we rely on using sparse layout tensors, we infer that any materialized value in the
    sparse layout is non-zero and we do NOT actually verify that all values are not zero!
    It is important to not conflate a semantically sparse tensor (a tensor where many
    of its values are zeros) with a sparse layout tensor (a tensor where ``.is_sparse``
    returns ``True``). The SparseAdam approximation is intended for `semantically` sparse
    tensors and the sparse layout is only an implementation detail. A clearer implementation
    would be to use MaskedTensors, but those are experimental.

    .. note::

        If you suspect your gradients are semantically sparse (but do not have sparse
        layout), this variant may not be the best for you. Ideally, you want to avoid
        materializing anything that is suspected to be sparse in the first place, since
        needing to convert all your grads from dense layout to sparse layout may outweigh
        the performance gain. Here, using Adam may be the best alternative, unless you
        can easily rig up your module to output sparse grads similar to
        ``nn.Embedding(sparse=True)``. If you insist on converting your grads, you can do
        so by manually overriding your parameters' ``.grad`` fields with their sparse
        equivalents before calling ``.step()``.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        {_maximize_doc}

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980

    """

View File

@ -0,0 +1,467 @@
# mypy: allow-untyped-defs
r"""Implementation for Stochastic Weight Averaging."""
import itertools
import math
import warnings
from copy import deepcopy
from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.nn import Module
from torch.optim.lr_scheduler import _format_param, LRScheduler
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
from .optimizer import Optimizer
# Names exported by ``from torch.optim.swa_utils import *``.
__all__ = [
"AveragedModel",
"update_bn",
"SWALR",
"get_ema_multi_avg_fn",
"get_swa_multi_avg_fn",
"get_ema_avg_fn",
"get_swa_avg_fn",
]
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
# Type alias for the tensor sequences accepted by the multi-tensor ("foreach")
# averaging functions: either a tuple or a list of tensors.
PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]
def get_ema_multi_avg_fn(decay=0.999):
    """Build an in-place exponential moving average (EMA) update for lists of params.

    Args:
        decay: weight kept on the running average at every update; the
            incoming parameters are weighted by ``1 - decay``.

    Returns:
        A callable ``ema_update(ema_param_list, current_param_list, _)`` that
        updates ``ema_param_list`` in place and returns ``None``. The third
        argument is accepted for interface compatibility and ignored.
    """

    @torch.no_grad()
    def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _):
        # The dtype of the first tensor decides which code path is taken.
        head = ema_param_list[0]
        lerp_capable = torch.is_floating_point(head) or torch.is_complex(head)
        if not lerp_capable:
            # foreach lerp only handles float and complex, so fall back to an
            # explicit per-tensor update for every other dtype.
            for averaged, incoming in zip(ema_param_list, current_param_list):
                averaged.copy_(averaged * decay + incoming * (1 - decay))
            return
        torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay)

    return ema_update
def get_swa_multi_avg_fn():
    """Build an in-place stochastic weight average (SWA) update for lists of params.

    Returns:
        A callable ``swa_update(averaged_param_list, current_param_list,
        num_averaged)`` that moves every averaged tensor towards its current
        counterpart by a factor of ``1 / (num_averaged + 1)``, in place, and
        returns ``None``.
    """

    @torch.no_grad()
    def swa_update(
        averaged_param_list: PARAM_LIST,
        current_param_list: PARAM_LIST,
        num_averaged: Union[Tensor, int],
    ):
        # The dtype of the first tensor decides which code path is taken.
        head = averaged_param_list[0]
        if torch.is_floating_point(head) or torch.is_complex(head):
            # foreach lerp only handles float and complex
            torch._foreach_lerp_(
                averaged_param_list, current_param_list, 1 / (num_averaged + 1)
            )
            return

        # Other dtypes: add the scaled difference explicitly.
        diffs = torch._foreach_sub(current_param_list, averaged_param_list)
        if not isinstance(num_averaged, Tensor):
            torch._foreach_add_(
                averaged_param_list, diffs, alpha=1.0 / (num_averaged + 1)
            )
            return
        # A tensor count is passed to addcdiv as one divisor per parameter.
        torch._foreach_addcdiv_(
            averaged_param_list,
            diffs,
            [num_averaged + 1] * len(averaged_param_list),
        )

    return swa_update
def get_ema_avg_fn(decay=0.999):
    """Build an exponential moving average (EMA) update for a single param.

    Args:
        decay: weight kept on the running average; the current parameter is
            weighted by ``1 - decay``.

    Returns:
        A callable ``ema_update(ema_param, current_param, num_averaged)`` that
        returns the new averaged tensor. ``num_averaged`` is accepted for
        interface compatibility and is not used.
    """

    @torch.no_grad()
    def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged):
        retained = decay * ema_param
        incoming = (1 - decay) * current_param
        return retained + incoming

    return ema_update
def get_swa_avg_fn():
    """Build a stochastic weight average (SWA) update for a single param.

    Returns:
        A callable ``swa_update(averaged_param, current_param, num_averaged)``
        that returns the equally-weighted running average after folding
        ``current_param`` into an average of ``num_averaged`` earlier values.
    """

    @torch.no_grad()
    def swa_update(
        averaged_param: Tensor, current_param: Tensor, num_averaged: Union[Tensor, int]
    ):
        delta = current_param - averaged_param
        return averaged_param + delta / (num_averaged + 1)

    return swa_update
class AveragedModel(Module):
    r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).

    Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
    Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
    Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
    (UAI 2018).

    Exponential Moving Average is a variation of `Polyak averaging`_,
    but using exponential weights instead of equal weights across iterations.

    AveragedModel class creates a copy of the provided module :attr:`model`
    on the device :attr:`device` and allows to compute running averages of the
    parameters of the :attr:`model`.

    Args:
        model (torch.nn.Module): model to use with SWA/EMA
        device (torch.device, optional): if provided, the averaged model will be
            stored on the :attr:`device`
        avg_fn (function, optional): the averaging function used to update
            parameters; the function must take in the current value of the
            :class:`AveragedModel` parameter, the current value of :attr:`model`
            parameter, and the number of models already averaged; if None,
            an equally weighted average is used (default: None)
        multi_avg_fn (function, optional): the averaging function used to update
            parameters inplace; the function must take in the current values of the
            :class:`AveragedModel` parameters as a list, the current values of :attr:`model`
            parameters as a list, and the number of models already averaged; if None,
            an equally weighted average is used (default: None)
        use_buffers (bool): if ``True``, it will compute running averages for
            both the parameters and the buffers of the model. (default: ``False``)

    Only one of :attr:`avg_fn` and :attr:`multi_avg_fn` may be provided.

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> loader, optimizer, model, loss_fn = ...
        >>> swa_model = torch.optim.swa_utils.AveragedModel(model)
        >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
        >>>                                     T_max=300)
        >>> swa_start = 160
        >>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
        >>> for i in range(300):
        >>>      for input, target in loader:
        >>>          optimizer.zero_grad()
        >>>          loss_fn(model(input), target).backward()
        >>>          optimizer.step()
        >>>      if i > swa_start:
        >>>          swa_model.update_parameters(model)
        >>>          swa_scheduler.step()
        >>>      else:
        >>>          scheduler.step()
        >>>
        >>> # Update bn statistics for the swa_model at the end
        >>> torch.optim.swa_utils.update_bn(loader, swa_model)

    You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters.
    If no averaging function is provided, the default is to compute
    equally-weighted average of the weights (SWA).

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> # Compute exponential moving averages of the weights and buffers
        >>> ema_model = torch.optim.swa_utils.AveragedModel(
        >>>     model,
        >>>     multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.9),
        >>>     use_buffers=True,
        >>> )

    .. note::
        When using SWA/EMA with models containing Batch Normalization you may
        need to update the activation statistics for Batch Normalization.
        This can be done either by using the :meth:`torch.optim.swa_utils.update_bn`
        or by setting :attr:`use_buffers` to `True`. The first approach updates the
        statistics in a post-training step by passing data through the model. The
        second does it during the parameter update phase by averaging all buffers.
        Empirical evidence has shown that updating the statistics in normalization
        layers increases accuracy, but you may wish to empirically test which
        approach yields the best results in your problem.

    .. note::
        :attr:`avg_fn` and `multi_avg_fn` are not saved in the :meth:`state_dict` of the model.

    .. note::
        When :meth:`update_parameters` is called for the first time (i.e.
        :attr:`n_averaged` is `0`) the parameters of `model` are copied
        to the parameters of :class:`AveragedModel`. For every subsequent
        call of :meth:`update_parameters` the function `avg_fn` is used
        to update the parameters.

    .. _Averaging Weights Leads to Wider Optima and Better Generalization:
        https://arxiv.org/abs/1803.05407
    .. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should
        Average:
        https://arxiv.org/abs/1806.05594
    .. _SWALP: Stochastic Weight Averaging in Low-Precision Training:
        https://arxiv.org/abs/1904.11943
    .. _Stochastic Weight Averaging in Parallel: Large-Batch Training That
        Generalizes Well:
        https://arxiv.org/abs/2001.02312
    .. _Polyak averaging:
        https://paperswithcode.com/method/polyak-averaging
    """

    # Number of models folded into the average so far (registered as a buffer).
    n_averaged: Tensor

    def __init__(
        self,
        model: Module,
        device: Optional[Union[int, torch.device]] = None,
        avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None,
        multi_avg_fn: Optional[
            Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None]
        ] = None,
        use_buffers=False,
    ):  # noqa: D107
        super().__init__()
        assert (
            avg_fn is None or multi_avg_fn is None
        ), "Only one of avg_fn and multi_avg_fn should be provided"
        # The averaged weights live in a private deep copy of the source model.
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer(
            "n_averaged", torch.tensor(0, dtype=torch.long, device=device)
        )
        self.avg_fn = avg_fn
        self.multi_avg_fn = multi_avg_fn
        self.use_buffers = use_buffers

    def forward(self, *args, **kwargs):
        """Run the wrapped (averaged) module on the given arguments."""
        return self.module(*args, **kwargs)

    def update_parameters(self, model: Module):
        """Fold the current state of ``model`` into the running average.

        On the first call (``n_averaged == 0``) the values of ``model`` are
        copied verbatim. On later calls they are combined with the stored
        average using ``multi_avg_fn``, ``avg_fn``, or, when neither was given,
        the default equally-weighted SWA rule. When ``use_buffers`` is
        ``False`` the buffers are not averaged but copied from ``model``.
        ``n_averaged`` is incremented at the end.

        Args:
            model (torch.nn.Module): source model whose parameters (and
                buffers, if ``use_buffers``) are paired positionally with
                those of the averaged copy.
        """
        self_param = (
            itertools.chain(self.module.parameters(), self.module.buffers())
            if self.use_buffers
            else self.parameters()
        )
        model_param = (
            itertools.chain(model.parameters(), model.buffers())
            if self.use_buffers
            else model.parameters()
        )

        # ``n_averaged`` is only modified at the very end of this method, so
        # convert it to Python booleans once instead of once per parameter.
        is_first_update = bool(self.n_averaged == 0)
        has_running_average = bool(self.n_averaged > 0)

        self_param_detached: List[Optional[Tensor]] = []
        model_param_detached: List[Optional[Tensor]] = []
        for p_averaged, p_model in zip(self_param, model_param):
            # Bring the source value onto the device of the averaged copy.
            p_model_ = p_model.detach().to(p_averaged.device)
            self_param_detached.append(p_averaged.detach())
            model_param_detached.append(p_model_)
            if is_first_update:
                p_averaged.detach().copy_(p_model_)

        if has_running_average:
            if self.multi_avg_fn is not None or self.avg_fn is None:
                # Multi-tensor path: process tensors grouped by (device, dtype).
                grouped_tensors = _group_tensors_by_device_and_dtype(
                    [self_param_detached, model_param_detached]
                )
                for (device, _), (
                    [self_params, model_params],
                    _,
                ) in grouped_tensors.items():
                    if self.multi_avg_fn:
                        self.multi_avg_fn(
                            self_params, model_params, self.n_averaged.to(device)  # type: ignore[arg-type]
                        )
                    elif (
                        device is not None
                        and device.type in _get_foreach_kernels_supported_devices()
                    ):
                        # Default SWA rule using the foreach implementation.
                        multi_avg_fn = get_swa_multi_avg_fn()
                        multi_avg_fn(
                            self_params, model_params, self.n_averaged.to(device)
                        )
                    else:
                        # Default SWA rule, one tensor at a time.
                        avg_fn = get_swa_avg_fn()
                        n_averaged = self.n_averaged.to(device)
                        for p_averaged, p_model in zip(self_params, model_params):  # type: ignore[assignment]
                            p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
            else:
                # User-supplied single-tensor averaging function.
                for p_averaged, p_model in zip(  # type: ignore[assignment]
                    self_param_detached, model_param_detached
                ):
                    n_averaged = self.n_averaged.to(p_averaged.device)
                    p_averaged.detach().copy_(
                        self.avg_fn(p_averaged.detach(), p_model, n_averaged)
                    )

        if not self.use_buffers:
            # If not apply running averages to the buffers,
            # keep the buffers in sync with the source model.
            for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
                b_swa.detach().copy_(b_model.detach().to(b_swa.device))
        self.n_averaged += 1
@torch.no_grad()
def update_bn(
    loader: Iterable[Any],
    model: Module,
    device: Optional[Union[int, torch.device]] = None,
):
    r"""Update BatchNorm running_mean, running_var buffers in the model.

    It performs one pass over data in `loader` to estimate the activation
    statistics for BatchNorm layers in the model. The running statistics of
    every BatchNorm layer are reset first. If the model contains no BatchNorm
    layers, the function returns without touching `loader`.

    During the pass the model is put in training mode and the momentum of each
    BatchNorm layer is set to ``None``. The original momenta and the original
    training mode are restored afterwards, even if iterating `loader` or the
    forward pass raises.

    Args:
        loader (torch.utils.data.DataLoader): dataset loader to compute the
            activation statistics on. Each data batch should be either a
            tensor, or a list/tuple whose first element is a tensor
            containing data.
        model (torch.nn.Module): model for which we seek to update BatchNorm
            statistics.
        device (torch.device, optional): If set, data will be transferred to
            :attr:`device` before being passed into :attr:`model`.

    Example:
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> loader, model = ...
        >>> torch.optim.swa_utils.update_bn(loader, model)

    .. note::
        The `update_bn` utility assumes that each data batch in :attr:`loader`
        is either a tensor or a list or tuple of tensors; in the latter case it
        is assumed that :meth:`model.forward()` should be called on the first
        element of the list or tuple corresponding to the data batch.
    """
    # Remember each BatchNorm layer's momentum so it can be restored later.
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            momenta[module] = module.momentum

    if not momenta:
        return

    was_training = model.training
    model.train()
    for bn_module in momenta:
        bn_module.momentum = None

    try:
        for batch in loader:
            if isinstance(batch, (list, tuple)):
                batch = batch[0]
            if device is not None:
                batch = batch.to(device)

            model(batch)
    finally:
        # Always undo the temporary changes, including on failure.
        for bn_module, momentum in momenta.items():
            bn_module.momentum = momentum
        model.train(was_training)
class SWALR(LRScheduler):
r"""Anneals the learning rate in each parameter group to a fixed value.
This learning rate scheduler is meant to be used with Stochastic Weight
Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`).
Args:
optimizer (torch.optim.Optimizer): wrapped optimizer
swa_lrs (float or list): the learning rate value for all param groups
together or separately for each group.
annealing_epochs (int): number of epochs in the annealing phase
(default: 10)
annealing_strategy (str): "cos" or "linear"; specifies the annealing
strategy: "cos" for cosine annealing, "linear" for linear annealing
(default: "cos")
last_epoch (int): the index of the last epoch (default: -1)
The :class:`SWALR` scheduler can be used together with other
schedulers to switch to a constant learning rate late in the training
as in the example below.
Example:
>>> # xdoctest: +SKIP("Undefined variables")
>>> loader, optimizer, model = ...
>>> lr_lambda = lambda epoch: 0.9
>>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
>>> lr_lambda=lr_lambda)
>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
>>> anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05)
>>> swa_start = 160
>>> for i in range(300):
>>> for input, target in loader:
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> if i > swa_start:
>>> swa_scheduler.step()
>>> else:
>>> scheduler.step()
.. _Averaging Weights Leads to Wider Optima and Better Generalization:
https://arxiv.org/abs/1803.05407
"""
def __init__(
self,
optimizer: Optimizer,
swa_lr: float,
anneal_epochs=10,
anneal_strategy: Literal["cos", "linear"] = "cos",
last_epoch=-1,
): # noqa: D107
swa_lrs = _format_param("swa_lr", optimizer, swa_lr)
for swa_lr, group in zip(swa_lrs, optimizer.param_groups):
group["swa_lr"] = swa_lr
if anneal_strategy not in ["cos", "linear"]:
raise ValueError(
"anneal_strategy must by one of 'cos' or 'linear', "
f"instead got {anneal_strategy}"
)
elif anneal_strategy == "cos":
self.anneal_func = self._cosine_anneal
elif anneal_strategy == "linear":
self.anneal_func = self._linear_anneal
if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
raise ValueError(
f"anneal_epochs must be equal or greater than 0, got {anneal_epochs}"
)
self.anneal_epochs = anneal_epochs
super().__init__(optimizer, last_epoch)
@staticmethod
def _linear_anneal(t):
return t
@staticmethod
def _cosine_anneal(t):
return (1 - math.cos(math.pi * t)) / 2
@staticmethod
def _get_initial_lr(lr, swa_lr, alpha):
if alpha == 1:
return swa_lr
return (lr - alpha * swa_lr) / (1 - alpha)
def get_lr(self):
"""Get learning rate."""
# `_get_lr_called_within_step` is only available `_enable_get_lr_call`,
# so we ignore the type error here. See `LRScheduler.step()` for more details.
if not self._get_lr_called_within_step: # type: ignore[attr-defined]
warnings.warn(
"To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.",
UserWarning,
)
# Set in `LRScheduler._initial_step()`
step = self._step_count - 1 # type: ignore[attr-defined]
if self.anneal_epochs == 0:
step = max(1, step)
prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
prev_alpha = self.anneal_func(prev_t)
prev_lrs = [
self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha)
for group in self.optimizer.param_groups
]
t = max(0, min(1, step / max(1, self.anneal_epochs)))
alpha = self.anneal_func(t)
return [
group["swa_lr"] * alpha + lr * (1 - alpha)
for group, lr in zip(self.optimizer.param_groups, prev_lrs)
]