I am done

2024-10-30 22:14:35 +01:00
parent 720dc28c09
commit 40e2a747cf
36901 changed files with 5011519 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
from .autocast_mode import (
_enter_autocast,
_exit_autocast,
autocast,
custom_bwd,
custom_fwd,
is_autocast_available,
)
from .grad_scaler import GradScaler
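Taken together, these two exports cover the mixed-precision training pattern documented in the files below: run the forward pass and loss under ``autocast`` and drive backward/step through ``GradScaler``. A minimal sketch, assuming a CUDA device; the toy model, optimizer, and data here are illustrative only:

import torch
import torch.nn as nn

# Hypothetical toy setup; any model / optimizer / loss works the same way.
model = nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
scaler = torch.amp.GradScaler("cuda")

for _ in range(3):
    inp = torch.randn(4, 8, device="cuda")
    target = torch.randn(4, 2, device="cuda")
    optimizer.zero_grad()
    # Forward pass and loss run under autocast; backward stays outside.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = loss_fn(model(inp), target)
    # Scale the loss, let the scaler decide whether to step, then update the scale.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()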

View File

@@ -0,0 +1,513 @@
# mypy: allow-untyped-defs
import collections
import functools
import warnings
from typing import Any, Optional
import torch
from torch.types import _dtype
try:
import numpy as np
HAS_NUMPY = True
except ModuleNotFoundError:
HAS_NUMPY = False
np = None # type: ignore[assignment]
__all__ = [
"autocast_decorator",
"autocast",
"is_autocast_available",
"custom_fwd",
"custom_bwd",
]
def is_autocast_available(device_type: str) -> bool:
r"""
Return a bool indicating if autocast is available on :attr:`device_type`.
Args:
device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and so on.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
"""
return torch._C._is_autocast_available(device_type)
def autocast_decorator(autocast_instance, func):
@functools.wraps(func)
def decorate_autocast(*args, **kwargs):
with autocast_instance:
return func(*args, **kwargs)
decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode" # type: ignore[attr-defined]
return decorate_autocast
class autocast:
r"""
Instances of :class:`autocast` serve as context managers or decorators that
allow regions of your script to run in mixed precision.
In these regions, ops run in an op-specific dtype chosen by autocast
to improve performance while maintaining accuracy.
See the :ref:`Autocast Op Reference<autocast-op-reference>` for details.
When entering an autocast-enabled region, Tensors may be any type.
You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs when using autocasting.
:class:`autocast` should wrap only the forward pass(es) of your network, including the loss
computation(s). Backward passes under autocast are not recommended.
Backward ops run in the same type that autocast used for corresponding forward ops.
Example for CUDA Devices::
# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
for input, target in data:
optimizer.zero_grad()
# Enables autocasting for the forward pass (model + loss)
with torch.autocast(device_type="cuda"):
output = model(input)
loss = loss_fn(output, target)
# Exits the context manager before backward()
loss.backward()
optimizer.step()
See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage (along with gradient scaling)
in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions).
:class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model::
class AutocastModel(nn.Module):
...
@torch.autocast(device_type="cuda")
def forward(self, input):
...
Floating-point Tensors produced in an autocast-enabled region may be ``float16``.
After returning to an autocast-disabled region, using them with floating-point
Tensors of different dtypes may cause type mismatch errors. If so, cast the Tensor(s)
produced in the autocast region back to ``float32`` (or other dtype if desired).
If a Tensor from the autocast region is already ``float32``, the cast is a no-op,
and incurs no additional overhead.
CUDA Example::
# Creates some tensors in default dtype (here assumed to be float32)
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")
with torch.autocast(device_type="cuda"):
# torch.mm is on autocast's list of ops that should run in float16.
# Inputs are float32, but the op runs in float16 and produces float16 output.
# No manual casts are required.
e_float16 = torch.mm(a_float32, b_float32)
# Also handles mixed input types
f_float16 = torch.mm(d_float32, e_float16)
# After exiting autocast, calls f_float16.float() to use with d_float32
g_float32 = torch.mm(d_float32, f_float16.float())
CPU Training Example::
# Creates model and optimizer in default precision
model = Net()
optimizer = optim.SGD(model.parameters(), ...)
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
# Runs the forward pass with autocasting.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
CPU Inference Example::
# Creates model in default precision
model = Net().eval()
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
for input in data:
# Runs the forward pass with autocasting.
output = model(input)
CPU Inference Example with Jit Trace::
class TestModel(nn.Module):
def __init__(self, input_size, num_classes):
super().__init__()
self.fc1 = nn.Linear(input_size, num_classes)
def forward(self, x):
return self.fc1(x)
input_size = 2
num_classes = 2
model = TestModel(input_size, num_classes).eval()
# For now, we suggest disabling the JIT autocast pass
# due to this issue: https://github.com/pytorch/pytorch/issues/75956
torch._C._jit_set_autocast_mode(False)
with torch.autocast(device_type="cpu", cache_enabled=False):
model = torch.jit.trace(model, torch.randn(1, input_size))
model = torch.jit.freeze(model)
# Run the model
for _ in range(3):
model(torch.randn(1, input_size))
Type mismatch errors *in* an autocast-enabled region are a bug; if this is what you observe,
please file an issue.
``autocast(enabled=False)`` subregions can be nested in autocast-enabled regions.
Locally disabling autocast can be useful, for example, if you want to force a subregion
to run in a particular ``dtype``. Disabling autocast gives you explicit control over
the execution type. In the subregion, inputs from the surrounding region
should be cast to ``dtype`` before use::
# Creates some tensors in default dtype (here assumed to be float32)
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")
with torch.autocast(device_type="cuda"):
e_float16 = torch.mm(a_float32, b_float32)
with torch.autocast(device_type="cuda", enabled=False):
# Calls e_float16.float() to ensure float32 execution
# (necessary because e_float16 was created in an autocasted region)
f_float32 = torch.mm(c_float32, e_float16.float())
# No manual casts are required when re-entering the autocast-enabled region.
# torch.mm again runs in float16 and produces float16 output, regardless of input types.
g_float16 = torch.mm(d_float32, f_float32)
The autocast state is thread-local. If you want it enabled in a new thread, the context manager or decorator
must be invoked in that thread. This affects :class:`torch.nn.DataParallel` and
:class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process
(see :ref:`Working with Multiple GPUs<amp-multigpu>`).
Args:
device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and 'hpu'.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
enabled(bool, optional): Whether autocasting should be enabled in the region.
Default: ``True``
dtype(torch_dtype, optional): Data type for ops run in autocast. It uses the default value
(``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU), given by
:func:`~torch.get_autocast_dtype`, if :attr:`dtype` is ``None``.
Default: ``None``
cache_enabled(bool, optional): Whether the weight cache inside autocast should be enabled.
Default: ``True``
"""
def __init__(
self,
device_type: str,
dtype: Optional[_dtype] = None,
enabled: bool = True,
cache_enabled: Optional[bool] = None,
):
if not isinstance(device_type, str):
raise ValueError(
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
)
if dtype is None:
dtype = torch.get_autocast_dtype(device_type)
if torch._jit_internal.is_scripting():
self._enabled = enabled
self.device = device_type
self.fast_dtype = dtype
assert dtype is not None
return
self.device = device_type
if not is_autocast_available(self.device):
raise RuntimeError(
f"User specified an unsupported autocast device_type '{self.device}'"
)
self.custom_backend_name = torch._C._get_privateuse1_backend_name()
self.fast_dtype = torch.get_autocast_dtype(self.device)
if self.device == self.custom_backend_name:
necessary_funcs = [
"get_amp_supported_dtype",
]
message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not "
message += "registered a module or the module miss some necessary funcs. The backend should register "
message += "a module by `torch._register_device_module`, and the module must have these funcs: \n"
message += "`get_amp_supported_dtype() -> List[torch.dtype]`. \n"
assert hasattr(torch, self.custom_backend_name), message
self.custom_device_mod = getattr(torch, self.custom_backend_name)
for func in necessary_funcs:
assert hasattr(self.custom_device_mod, func), (
message + f"But the func `{func}` is missing. \n"
)
self._cache_enabled = torch.is_autocast_cache_enabled()
if (
enabled
and torch.cuda.amp.common.amp_definitely_not_available()
and self.device == "cuda"
):
warnings.warn(
"User provided device_type of 'cuda', but CUDA is not available. Disabling"
)
enabled = False
if dtype is not None:
self.fast_dtype = dtype
if cache_enabled is not None:
self._cache_enabled = cache_enabled
if self.device == "cpu":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype and enabled:
error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "CPU Autocast only supports dtype of "
error_message += (
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
)
warnings.warn(error_message)
enabled = False
elif self.device == "xpu":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == "ipu":
supported_dtypes = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtypes:
error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == "hpu":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == self.custom_backend_name:
supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
if self.fast_dtype not in supported_dtype:
error_message = f"In {self.custom_backend_name} autocast, but the target dtype is not supported. "
error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
error_message += (
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
)
warnings.warn(error_message)
enabled = False
elif self.device == "cuda":
if (
enabled
and self.fast_dtype == torch.bfloat16
and not torch.cuda.is_bf16_supported()
):
raise RuntimeError(
"Current CUDA Device does not support bfloat16. Please switch dtype to float16."
)
elif self.device == "mps":
supported_dtype = [torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += (
"MPS Autocast only supports dtype of torch.bfloat16 currently."
)
warnings.warn(error_message)
enabled = False
elif self.device == "xla":
supported_dtype = [torch.float16, torch.bfloat16]
if self.fast_dtype not in supported_dtype:
error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += (
"XLA Autocast only supports dtype of torch.bfloat16 currently."
)
warnings.warn(error_message)
enabled = False
self._enabled = enabled
def __enter__(self):
if torch._jit_internal.is_scripting():
assert self.fast_dtype is not None
return self
self.prev_cache_enabled = torch.is_autocast_cache_enabled()
self.prev = torch.is_autocast_enabled(self.device)
self.prev_fastdtype = torch.get_autocast_dtype(self.device)
torch.set_autocast_enabled(self.device, self._enabled)
torch.set_autocast_dtype(self.device, self.fast_dtype) # type: ignore[arg-type]
torch.autocast_increment_nesting()
torch.set_autocast_cache_enabled(self._cache_enabled)
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override]
if torch._jit_internal.is_scripting():
return
# Drop the cache when we exit to a nesting level that's outside any instance of autocast.
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.set_autocast_enabled(self.device, self.prev)
torch.set_autocast_dtype(self.device, self.prev_fastdtype)
torch.set_autocast_cache_enabled(self.prev_cache_enabled)
return False
def __call__(self, func):
if torch._jit_internal.is_scripting():
return func
return autocast_decorator(self, func)
# These functions aren't meant for public usage.
# They are what we trace into a graph during pre_dispatch tracing
# when we encounter an autocast context manager.
def _enter_autocast(*vals):
# For pre-dispatch tracing, if a TorchFunction mode is active, we'll want to trace this into a graph.
if torch._C._is_torch_function_mode_enabled():
return torch.overrides.handle_torch_function(
torch.amp._enter_autocast, [], *vals
)
mode = torch.amp.autocast(*vals)
mode.__enter__()
return mode
def _exit_autocast(mode):
if torch._C._is_torch_function_mode_enabled():
return torch.overrides.handle_torch_function(torch.amp._exit_autocast, [], mode)
mode.__exit__(None, None, None)
# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which
# may be falsely detected as "Iterables."
def _cast(value, device_type: str, dtype: _dtype):
if isinstance(value, torch.Tensor):
is_eligible = (
value.is_floating_point()
and value.device.type == device_type
and (value.dtype is not torch.float64)
)
return value.to(dtype) if is_eligible else value
elif isinstance(value, (str, bytes)):
return value
elif HAS_NUMPY and isinstance(value, np.ndarray):
return value
elif isinstance(value, collections.abc.Mapping):
return {
_cast(k, device_type, dtype): _cast(v, device_type, dtype)
for k, v in value.items()
}
elif isinstance(value, collections.abc.Iterable):
iterable = (_cast(v, device_type, dtype) for v in value)
if isinstance(value, (list, tuple)):
return type(value)(iterable)
else:
return iterable
else:
return value
def custom_fwd(
fwd=None,
*,
device_type: str,
cast_inputs: Optional[_dtype] = None,
):
"""
Create a helper decorator for ``forward`` methods of custom autograd functions.
Autograd functions are subclasses of :class:`torch.autograd.Function`.
See the :ref:`example page<amp-custom-examples>` for more detail.
Args:
device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
when ``forward`` runs in an autocast-enabled region, casts incoming
floating-point Tensors to the target dtype (non-floating-point Tensors are not affected),
then executes ``forward`` with autocast disabled.
If ``None``, ``forward``'s internal ops execute with the current autocast state.
.. note::
If the decorated ``forward`` is called outside an autocast-enabled region,
:func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
"""
if not isinstance(device_type, str):
raise ValueError(
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
)
if fwd is None:
return functools.partial(
custom_fwd, device_type=device_type, cast_inputs=cast_inputs
)
@functools.wraps(fwd)
def decorate_fwd(*args, **kwargs):
args[0]._dtype = torch.get_autocast_dtype(device_type)
if cast_inputs is None:
args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
return fwd(*args, **kwargs)
else:
autocast_context = torch.is_autocast_enabled(device_type)
args[0]._fwd_used_autocast = False
if autocast_context:
with autocast(device_type=device_type, enabled=False):
return fwd(
*_cast(args, device_type, cast_inputs),
**_cast(kwargs, device_type, cast_inputs),
)
else:
return fwd(*args, **kwargs)
return decorate_fwd
# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
# cast_inputs supplied to custom_fwd.
def custom_bwd(bwd=None, *, device_type: str):
"""Create a helper decorator for backward methods of custom autograd functions.
Autograd functions are subclasses of :class:`torch.autograd.Function`.
Ensures that ``backward`` executes with the same autocast state as ``forward``.
See the :ref:`example page<amp-custom-examples>` for more detail.
Args:
device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
"""
if not isinstance(device_type, str):
raise ValueError(
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
)
if bwd is None:
return functools.partial(custom_bwd, device_type=device_type)
@functools.wraps(bwd)
def decorate_bwd(*args, **kwargs):
with autocast(
device_type=device_type,
enabled=args[0]._fwd_used_autocast,
dtype=args[0]._dtype,
):
return bwd(*args, **kwargs)
return decorate_bwd
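For reference, a hedged sketch of how ``custom_fwd`` / ``custom_bwd`` are typically applied to a :class:`torch.autograd.Function`, following the ``cast_inputs`` contract described in the docstrings above; the ``ScaledAdd`` function, shapes, and values are made up for illustration:

import torch
from torch.amp import custom_fwd, custom_bwd

class ScaledAdd(torch.autograd.Function):
    # With cast_inputs=torch.float32, incoming floating-point tensors are cast
    # to float32 and forward runs with autocast disabled.
    @staticmethod
    @custom_fwd(device_type="cuda", cast_inputs=torch.float32)
    def forward(ctx, a, b):
        return a + 2 * b

    # backward re-runs with the same autocast state that forward observed.
    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_out):
        return grad_out, 2 * grad_out

a = torch.randn(4, device="cuda", requires_grad=True)
b = torch.randn(4, device="cuda", requires_grad=True)
with torch.autocast(device_type="cuda"):
    out = ScaledAdd.apply(a, b)  # executes in float32 despite the autocast region
out.sum().backward()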

View File

@@ -0,0 +1,685 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import inspect
import warnings
from collections import abc, defaultdict
from enum import Enum
from typing import Any, cast, Dict, Iterable, List, Optional, overload, Tuple, Union
import torch
__all__ = ["OptState", "GradScaler"]
class _MultiDeviceReplicator:
"""Lazily serves copies of a tensor to requested devices.
Copies are cached per-device.
"""
def __init__(self, master_tensor: torch.Tensor) -> None:
self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
def get(self, device: torch.device) -> torch.Tensor:
retval = self._per_device_tensors.get(device, None)
if retval is None:
retval = self.master.to(device=device, non_blocking=True, copy=True)
self._per_device_tensors[device] = retval
return retval
# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
# as well as associated "enum" values. These are defined at top level because
# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
# causes a circular reference, which we'd rather avoid.
class OptState(Enum):
READY = 0
UNSCALED = 1
STEPPED = 2
def _refresh_per_optimizer_state() -> Dict[str, Any]:
return {"stage": OptState.READY, "found_inf_per_device": {}}
class GradScaler:
"""An instance ``scaler`` of :class:`GradScaler`.
Helps perform the steps of gradient scaling
conveniently.
* ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
* ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
* ``scaler.update()`` updates ``scaler``'s scale factor.
Example::
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
scaler.scale(loss).backward()
# scaler.step() first unscales gradients of the optimizer's params.
# If gradients don't contain infs/NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
(along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
and multiple losses/optimizers.
``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
without incurring inf or NaN gradient values.
``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
* If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
* If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
``growth_factor``.
The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
Args:
device (str, optional, default="cuda"): Device type to use. Possible values are: 'cuda' and 'cpu'.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
init_scale (float, optional, default=2.**16): Initial scale factor.
growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
:meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
:meth:`update` if inf/NaN gradients occur in an iteration.
growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
that must occur for the scale to be multiplied by ``growth_factor``.
enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
Default: ``True``
"""
def __init__(
self,
device: str = "cuda",
init_scale: float = 2.0**16,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000,
enabled: bool = True,
) -> None:
self._device = device
self._enabled = enabled
if self._device == "cuda":
if enabled and torch.cuda.amp.common.amp_definitely_not_available():
warnings.warn(
"torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling."
)
self._enabled = False
if self._enabled:
assert growth_factor > 1.0, "The growth factor must be > 1.0."
assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
self._init_scale = init_scale
# self._scale will be lazily initialized during the first call to scale()
self._scale: Optional[torch.Tensor] = None
self._growth_factor = growth_factor
self._backoff_factor = backoff_factor
self._growth_interval = growth_interval
self._init_growth_tracker = 0
# self._growth_tracker will be lazily initialized during the first call to scale()
self._growth_tracker: Optional[torch.Tensor] = None
self._per_optimizer_states: Dict[int, Dict[str, Any]] = defaultdict(
_refresh_per_optimizer_state
)
def _check_scale_growth_tracker(
self, funcname: str
) -> Tuple[torch.Tensor, torch.Tensor]:
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
assert self._scale is not None, (
f"Attempted {funcname} but _scale is None. " + fix
)
assert self._growth_tracker is not None, (
f"Attempted {funcname} but _growth_tracker is None. " + fix
)
return (self._scale, self._growth_tracker)
def _lazy_init_scale_growth_tracker(self, dev: torch.device) -> None:
assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
self._growth_tracker = torch.full(
(), self._init_growth_tracker, dtype=torch.int32, device=dev
)
@overload
def scale(self, outputs: torch.Tensor) -> torch.Tensor:
...
@overload
def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]:
...
@overload
def scale(self, outputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
...
@overload
def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]:
...
def scale(
self,
outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
"""
Multiplies ('scales') a tensor or list of tensors by the scale factor.
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
unmodified.
Args:
outputs (Tensor or iterable of Tensors): Outputs to scale.
"""
if not self._enabled:
return outputs
# Short-circuit for the common case.
if isinstance(outputs, torch.Tensor):
if self._scale is None:
self._lazy_init_scale_growth_tracker(outputs.device)
assert self._scale is not None
return outputs * self._scale.to(device=outputs.device, non_blocking=True)
# Invoke the more complex machinery only if we're treating multiple outputs.
stash: List[
_MultiDeviceReplicator
] = [] # holds a reference that can be overwritten by apply_scale
def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
if isinstance(val, torch.Tensor):
if len(stash) == 0:
if self._scale is None:
self._lazy_init_scale_growth_tracker(val.device)
assert self._scale is not None
stash.append(_MultiDeviceReplicator(self._scale))
return val * stash[0].get(val.device)
if isinstance(val, abc.Iterable):
iterable = map(apply_scale, val)
if isinstance(val, (list, tuple)):
return type(val)(iterable)
return iterable
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
def _unscale_grads_(
self,
optimizer: torch.optim.Optimizer,
inv_scale: torch.Tensor,
found_inf: torch.Tensor,
allow_fp16: bool,
) -> Dict[torch.device, torch.Tensor]:
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
per_device_found_inf = _MultiDeviceReplicator(found_inf)
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
# There could be hundreds of grads, so we'd like to iterate through them just once.
# However, we don't know their devices or dtypes in advance.
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdict type annotations.
per_device_and_dtype_grads: Dict[
torch.device, Dict[torch.dtype, List[torch.Tensor]]
] = defaultdict(lambda: defaultdict(list))
with torch.no_grad():
for group in optimizer.param_groups:
for param in group["params"]:
assert isinstance(param, torch.Tensor)
if param.grad is None:
continue
if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError("Attempting to unscale FP16 gradients.")
if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should check the coalesced _values().
if param.grad.dtype is torch.float16:
param.grad = param.grad.coalesce()
to_unscale = param.grad._values()
else:
to_unscale = param.grad
# TODO: is there a way to split by device and dtype without appending in the inner loop?
per_device_and_dtype_grads[to_unscale.device][
to_unscale.dtype
].append(to_unscale)
for device, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
torch._amp_foreach_non_finite_check_and_unscale_(
grads,
per_device_found_inf.get(device),
per_device_inv_scale.get(device),
)
return per_device_found_inf._per_device_tensors
def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
:meth:`unscale_` is optional, serving cases where you need to
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
between the backward pass(es) and :meth:`step`.
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
Args:
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
.. note::
:meth:`unscale_` does not incur a CPU-GPU sync.
.. warning::
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
and only after all gradients for that optimizer's assigned parameters have been accumulated.
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
.. warning::
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
"""
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED:
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update()."
)
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.double().reciprocal().float()
found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
optimizer, inv_scale, found_inf, False
)
optimizer_state["stage"] = OptState.UNSCALED
def _maybe_opt_step(
self,
optimizer: torch.optim.Optimizer,
optimizer_state: Dict[str, Any],
*args: Any,
**kwargs: Any,
) -> Optional[float]:
retval: Optional[float] = None
if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
retval = optimizer.step(*args, **kwargs)
return retval
def step(
self, optimizer: torch.optim.Optimizer, *args: Any, **kwargs: Any
) -> Optional[float]:
"""Invoke ``unscale_(optimizer)`` followed by parameter update, if gradients are not infs/NaN.
:meth:`step` carries out the following two operations:
1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
Returns the return value of ``optimizer.step(*args, **kwargs)``.
Args:
optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
args: Any arguments.
kwargs: Any keyword arguments.
.. warning::
Closure use is not currently supported.
"""
if not self._enabled:
return optimizer.step(*args, **kwargs)
if "closure" in kwargs:
raise RuntimeError(
"Closure use is not currently supported if GradScaler is enabled."
)
self._check_scale_growth_tracker("step")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError(
"step() has already been called since the last update()."
)
retval: Optional[float] = None
if getattr(optimizer, "_step_supports_amp_scaling", False):
# This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
# The contract with custom optimizers is that their step() should accept an additional,
# optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
# it can query its own state, invoke unscale_ on itself, etc
# The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument
# to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale`
# and `found_inf` to the passed optimizer so that the optimizer can utilize those
# to skip the parameter updates or unscale gradients before updating parameters in
# the fused kernel, e.g. `FusedAdamMathFunctor`.
# In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
# while the method is expected to be called by users side, i.e. their optimizers.
kwargs_ = kwargs
has_grad_scaler_kwarg = (
"grad_scaler" in inspect.signature(optimizer.step).parameters
)
if has_grad_scaler_kwarg:
warnings.warn(
"GradScaler is going to stop passing itself as a keyword argument to the passed "
"optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
"`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
FutureWarning,
)
kwargs_.update({"grad_scaler": self})
else:
if optimizer_state["stage"] is OptState.READY:
self._check_inf_per_device(optimizer)
scaler = self._get_scale_async()
assert scaler is not None
found_inf = cast(
torch.Tensor,
sum(
[ # noqa: C419
t.to(scaler.device, non_blocking=True)
for t in optimizer_state["found_inf_per_device"].values()
]
),
)
# Take the product of the scales, if the user has already set `optimizer.grad_scale`.
optimizer.grad_scale = ( # type: ignore[attr-defined]
getattr(optimizer, "grad_scale", None)
if optimizer_state["stage"] == OptState.UNSCALED
else scaler * getattr(optimizer, "grad_scale", 1)
)
optimizer.found_inf = found_inf # type: ignore[attr-defined]
retval = optimizer.step(*args, **kwargs_)
optimizer_state["stage"] = OptState.STEPPED
if not has_grad_scaler_kwarg:
del optimizer.grad_scale # type: ignore[attr-defined]
del optimizer.found_inf # type: ignore[attr-defined]
return retval
if optimizer_state["stage"] is OptState.READY:
self.unscale_(optimizer)
assert (
len(optimizer_state["found_inf_per_device"]) > 0
), "No inf checks were recorded for this optimizer."
retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
optimizer_state["stage"] = OptState.STEPPED
return retval
def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
"""Update the scale factor.
If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly; it is used to fill GradScaler's internal scale tensor, so if
``new_scale`` is a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.Tensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
.. warning::
For performance reasons, we do not check the scale factor value to avoid synchronizations,
so the scale factor is not guaranteed to be above 1. If the scale falls below 1 and/or
you are seeing NaNs in your gradients or loss, something is likely wrong. For example,
bf16-pretrained models are often incompatible with AMP/fp16 due to differing dynamic ranges.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
assert self._scale is not None
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale)
else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \
torch.FloatTensor with requires_grad=False."
assert new_scale.device.type == self._device, reason
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale)
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [
found_inf.to(device=_scale.device, non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
torch._amp_update_scale_(
_scale,
_growth_tracker,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval,
)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _get_scale_async(self) -> Optional[torch.Tensor]:
return self._scale
def get_scale(self) -> float:
"""Return a Python float containing the current scale, or 1.0 if scaling is disabled.
.. warning::
:meth:`get_scale` incurs a CPU-GPU sync.
"""
if self._enabled:
return (
self._init_scale
if (scale := self._get_scale_async()) is None
else cast(float, scale.item())
)
return 1.0
def get_growth_factor(self) -> float:
r"""Return a Python float containing the scale growth factor."""
return self._growth_factor
def set_growth_factor(self, new_factor: float) -> None:
r"""Set a new scale growth factor.
Args:
new_factor (float): Value to use as the new scale growth factor.
"""
self._growth_factor = new_factor
def get_backoff_factor(self) -> float:
r"""Return a Python float containing the scale backoff factor."""
return self._backoff_factor
def set_backoff_factor(self, new_factor: float) -> None:
r"""Set a new scale backoff factor.
Args:
new_factor (float): Value to use as the new scale backoff factor.
"""
self._backoff_factor = new_factor
def get_growth_interval(self) -> int:
r"""Return a Python int containing the growth interval."""
return self._growth_interval
def set_growth_interval(self, new_interval: int) -> None:
r"""Set a new growth interval.
Args:
new_interval (int): Value to use as the new growth interval.
"""
self._growth_interval = new_interval
def _get_growth_tracker(self) -> int:
if self._enabled:
return (
self._init_growth_tracker
if self._growth_tracker is None
else cast(int, self._growth_tracker.item())
)
return 0
def is_enabled(self) -> bool:
r"""Return a bool indicating whether this instance is enabled."""
return self._enabled
def state_dict(self) -> Dict[str, Any]:
r"""Return the state of the scaler as a :class:`dict`.
It contains five entries:
* ``"scale"`` - a Python float containing the current scale
* ``"growth_factor"`` - a Python float containing the current growth factor
* ``"backoff_factor"`` - a Python float containing the current backoff factor
* ``"growth_interval"`` - a Python int containing the current growth interval
* ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
If this instance is not enabled, returns an empty dict.
.. note::
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
should be called after :meth:`update`.
"""
if self._enabled:
return {
"scale": self.get_scale(),
"growth_factor": self._growth_factor,
"backoff_factor": self._backoff_factor,
"growth_interval": self._growth_interval,
"_growth_tracker": self._get_growth_tracker(),
}
return {}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
r"""Load the scaler state.
If this instance is disabled, :meth:`load_state_dict` is a no-op.
Args:
state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
"""
if not self._enabled:
return
if len(state_dict) == 0:
raise RuntimeError(
"The source state dict is empty, possibly because it was saved "
"from a disabled instance of GradScaler."
)
self._init_scale = cast(float, state_dict["scale"])
if self._scale is not None:
self._scale.fill_(state_dict["scale"])
self._growth_factor = cast(float, state_dict["growth_factor"])
self._backoff_factor = cast(float, state_dict["backoff_factor"])
self._growth_interval = cast(int, state_dict["growth_interval"])
self._init_growth_tracker = cast(int, state_dict["_growth_tracker"])
if self._growth_tracker is not None:
self._growth_tracker.fill_(state_dict["_growth_tracker"])
def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
if self._enabled:
assert len(self._per_optimizer_states) == 0, (
"A GradScaler instance may only be pickled at the beginning "
"of an iteration, or at the end after scaler.update()."
)
# Pickling _scale and _growth_tracker Tensors directly triggers
# "warnings.warn("pickle support for Storage will be removed in 1.5..."
# so instead, we set the unpickled instance up to reinitialize them lazily.
state["_init_scale"] = self.get_scale()
state["_init_growth_tracker"] = self._get_growth_tracker()
state["_scale"] = None
state["_growth_tracker"] = None
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
def _check_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device)
self._per_optimizer_states[id(optimizer)][
"found_inf_per_device"
] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
def _found_inf_per_device(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
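Since :meth:`state_dict` / :meth:`load_state_dict` carry no inline example above, here is a hedged sketch of checkpointing the scaler state; the filename is a placeholder and the training iterations are elided:

import torch

scaler = torch.amp.GradScaler("cuda")
# ... training iterations using scaler.scale()/step()/update() ...

# Capture the scaler state after update(), as recommended in state_dict().
torch.save({"scaler": scaler.state_dict()}, "checkpoint.pt")

# Later / elsewhere: restore into a fresh, enabled GradScaler.
restored = torch.amp.GradScaler("cuda")
restored.load_state_dict(torch.load("checkpoint.pt")["scaler"])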