diff --git a/oslo/lightseq2/csrc/kernels/cuda/cuda_util.cu b/oslo/lightseq2/csrc/kernels/cuda/cuda_util.cu
index 770b08c6..032c8220 100644
--- a/oslo/lightseq2/csrc/kernels/cuda/cuda_util.cu
+++ b/oslo/lightseq2/csrc/kernels/cuda/cuda_util.cu
@@ -1,4 +1,5 @@
 #include
+#include
 #include

 #include "cuda_util.h"
diff --git a/oslo/lightseq2/training/ops/pytorch/builder/builder.py b/oslo/lightseq2/training/ops/pytorch/builder/builder.py
index 329cb693..fb059fb7 100644
--- a/oslo/lightseq2/training/ops/pytorch/builder/builder.py
+++ b/oslo/lightseq2/training/ops/pytorch/builder/builder.py
@@ -60,7 +60,7 @@ def assert_no_cuda_mismatch():
     torch_cuda_version = ".".join(torch.version.cuda.split(".")[:2])
     # This is a show-stopping error, should probably not proceed past this
     if sys_cuda_version != torch_cuda_version:
-        if cuda_major == 11 and torch_cuda_version.split(".")[0] == "11":
-            # it works to build against installed cuda-11 while torch was built with cuda-11
+        if cuda_major in [11, 12] and torch_cuda_version.split(".")[0] == "11":
+            # it works to build against an installed cuda-11/12 toolkit while torch was built with cuda-11
             return
     raise Exception(
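In effect, the relaxed check accepts a CUDA 11.x or 12.x system toolkit whenever torch itself was built against CUDA 11. A standalone sketch of that predicate (the function name and plain-string arguments are illustrative; the real `assert_no_cuda_mismatch` reads the versions from the installed toolkit and from `torch.version.cuda`):

```python
# Hypothetical standalone version of the compatibility predicate above.
def cuda_versions_compatible(sys_cuda_version: str, torch_cuda_version: str) -> bool:
    if sys_cuda_version == torch_cuda_version:
        return True
    sys_major = int(sys_cuda_version.split(".")[0])
    # A CUDA 11 or 12 toolkit can build extensions for a torch built with CUDA 11.
    return sys_major in [11, 12] and torch_cuda_version.split(".")[0] == "11"


assert cuda_versions_compatible("12.1", "11.8")      # newly accepted by this patch
assert not cuda_versions_compatible("12.1", "10.2")  # still a hard mismatch
```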
diff --git a/oslo/torch/nn/parallel/data_parallel/grad_scaler/__init__.py b/oslo/torch/nn/parallel/data_parallel/grad_scaler/__init__.py
new file mode 100644
index 00000000..c4c362e5
--- /dev/null
+++ b/oslo/torch/nn/parallel/data_parallel/grad_scaler/__init__.py
@@ -0,0 +1,4 @@
+from .base_grad_scaler import BaseGradScaler
+from .dynamic_grad_scaler import DynamicGradScaler
+
+__all__ = ["BaseGradScaler", "DynamicGradScaler"]
diff --git a/oslo/torch/nn/parallel/data_parallel/grad_scaler/base_grad_scaler.py b/oslo/torch/nn/parallel/data_parallel/grad_scaler/base_grad_scaler.py
new file mode 100644
index 00000000..12469be1
--- /dev/null
+++ b/oslo/torch/nn/parallel/data_parallel/grad_scaler/base_grad_scaler.py
@@ -0,0 +1,76 @@
+from abc import ABC, abstractmethod
+from typing import Dict
+
+import torch
+from torch import Tensor
+
+from oslo.torch.utils.logging import get_dist_logger
+
+__all__ = ["BaseGradScaler"]
+
+
+class BaseGradScaler(ABC):
+    """A base class for the gradient scaler.
+
+    Args:
+        initial_scale (float): the initial loss scale
+        verbose (bool): whether to log messages
+    """
+
+    def __init__(self, initial_scale: float, verbose: bool):
+        assert initial_scale > 0
+        self._scale = torch.cuda.FloatTensor([initial_scale])
+        self._verbose = verbose
+
+        if self._verbose:
+            self._logger = get_dist_logger()
+
+    def scale(self, loss) -> Tensor:
+        """Scale the loss by the current loss scale."""
+
+        return self._scale * loss
+
+    def step(self, optimizer):
+        """Record the optimizer's overflow state, then run its step."""
+
+        self.found_inf = optimizer._check_overflow()
+        return optimizer.step()
+
+    @property
+    def inv_scale(self) -> Tensor:
+        """Returns the inverse of the loss scale."""
+
+        return self._scale.double().reciprocal().float()
+
+    def state_dict(self) -> Dict:
+        """Returns the states of the gradient scaler as a dict object."""
+
+        state_dict = dict()
+        state_dict["scale"] = self._scale
+        return state_dict
+
+    def load_state_dict(self, state_dict: Dict) -> None:
+        """Load the states of the gradient scaler from a dict object.
+
+        Args:
+            state_dict (dict): the states of the gradient scaler
+        """
+
+        self._scale = state_dict["scale"]
+
+    @abstractmethod
+    def update(self) -> None:
+        """Update the loss scale based on the overflow state recorded by step."""
+        pass
+
+    def log(self, message, *args, **kwargs):
+        """Log messages.
+
+        Args:
+            message (str): the message to log
+            *args: positional arguments for :class:`oslo.torch.utils.logging.DistributedLogger`
+            **kwargs: keyword arguments for :class:`oslo.torch.utils.logging.DistributedLogger`
+        """
+
+        if self._verbose:
+            self._logger.info(message, *args, **kwargs)
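For reference, a minimal sketch of the contract a `BaseGradScaler` subclass has to satisfy: only `update` is abstract, and `step` stores the overflow flag that `update` later consumes. The toy subclass and the optimizer are illustrative, not part of this patch; the optimizer is assumed to expose `_check_overflow()` as the base class requires:

```python
import torch

from oslo.torch.nn.parallel.data_parallel.grad_scaler import BaseGradScaler


class ConstantGradScaler(BaseGradScaler):
    """Toy scaler that never changes its loss scale, so update() is a no-op."""

    def update(self) -> None:
        pass


# Intended call order inside one training step (requires a CUDA device,
# since the scale is stored in a torch.cuda.FloatTensor):
#
#   scaler = ConstantGradScaler(initial_scale=2**16, verbose=False)
#   scaler.scale(loss).backward()   # multiply the loss by the scale before backward
#   scaler.step(optimizer)          # records scaler.found_inf, then runs optimizer.step()
#   scaler.update()                 # reacts to the recorded overflow flag
```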
diff --git a/oslo/torch/nn/parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py b/oslo/torch/nn/parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py
new file mode 100644
index 00000000..d9dde8a3
--- /dev/null
+++ b/oslo/torch/nn/parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py
@@ -0,0 +1,132 @@
+from typing import Optional
+
+import torch
+
+from .base_grad_scaler import BaseGradScaler
+
+__all__ = ["DynamicGradScaler"]
+
+
+class DynamicGradScaler(BaseGradScaler):
+    """A gradient scaler which uses a dynamic loss scale.
+
+    Args:
+        initial_scale (float): the initial loss scale, defaults to 2**16
+        growth_factor (float): the multiplication factor for increasing the loss scale, defaults to 2
+        backoff_factor (float): the multiplication factor for decreasing the loss scale, defaults to 0.5
+        growth_interval (int): the number of steps without overflow before increasing the loss scale, defaults to 1000
+        min_scale (float): the minimum loss scale, defaults to None
+        max_scale (float): the maximum loss scale, defaults to None
+        hysteresis (int): the number of consecutive overflows tolerated before decreasing the loss scale, defaults to 2
+        verbose (bool): whether to log messages, defaults to False
+    """
+
+    def __init__(
+        self,
+        initial_scale: float = 2**16,
+        growth_factor: float = 2,
+        backoff_factor: float = 0.5,
+        growth_interval: int = 1000,
+        min_scale: Optional[float] = None,
+        max_scale: Optional[float] = None,
+        hysteresis: int = 2,
+        verbose: bool = False,
+    ):
+        super().__init__(initial_scale, verbose)
+        if min_scale is not None:
+            self._min_scale = torch.cuda.FloatTensor([min_scale])
+        else:
+            self._min_scale = None
+
+        if max_scale is not None:
+            self._max_scale = torch.cuda.FloatTensor([max_scale])
+        else:
+            self._max_scale = None
+
+        self._growth_factor = growth_factor
+        self._backoff_factor = backoff_factor
+        self._growth_interval = growth_interval
+        self._growth_step = 0
+        self._hysteresis = hysteresis
+        self._hysteresis_step = 0
+        self._sanity_checks()
+
+    def _sanity_checks(self) -> None:
+        """Check if the arguments are correct."""
+
+        if self._min_scale is not None:
+            assert (
+                self._min_scale > 0
+            ), "The minimum gradient scale cannot be zero or negative"
+            assert (
+                self._min_scale <= self._scale
+            ), "The minimum gradient scale cannot be greater than the current scale"
+        if self._max_scale is not None:
+            assert (
+                self._max_scale > 0
+            ), "The maximum gradient scale cannot be zero or negative"
+            assert (
+                self._max_scale >= self._scale
+            ), "The maximum gradient scale cannot be smaller than the current scale"
+        assert (
+            self._growth_factor > 1
+        ), "The growth factor cannot be equal to or smaller than 1"
+        assert (
+            0 < self._backoff_factor < 1
+        ), "The backoff factor must be between 0 and 1"
+        assert self._hysteresis >= 0, "The hysteresis cannot be negative"
+
+    def update(self) -> None:
+        """Update the loss scale based on the overflow state recorded by step."""
+
+        if self.found_inf:
+            self._hysteresis_step += 1
+            self._growth_step = 0
+
+            if self._hysteresis_step >= self._hysteresis:
+                self._backoff_scale()
+                self.log(
+                    f"Overflow occurred, the loss scale is adjusted to {self._scale.item()}",
+                    ranks=[0],
+                )
+        else:
+            self._growth_step += 1
+            if self._growth_step == self._growth_interval:
+                self._growth_step = 0
+                self._hysteresis_step = 0
+                self._grow_scale()
+                self.log(
+                    f"No overflow for {self._growth_interval} consecutive steps, "
+                    f"the loss scale is adjusted to {self._scale.item()}",
+                    ranks=[0],
+                )
+
+    def _backoff_scale(self) -> None:
+        """Decrease the loss scale."""
+
+        self._scale = self._scale * self._backoff_factor
+        if self._min_scale is not None:
+            self._scale = torch.max(self._scale, self._min_scale)
+
+    def _grow_scale(self) -> None:
+        """Increase the loss scale."""
+
+        self._scale = self._scale * self._growth_factor
+        if self._max_scale is not None:
+            self._scale = torch.min(self._scale, self._max_scale)
+
+    def state_dict(self):
+        state_dict = dict()
+        state_dict["scale"] = self._scale
+        state_dict["growth_factor"] = self._growth_factor
+        state_dict["backoff_factor"] = self._backoff_factor
+        state_dict["growth_interval"] = self._growth_interval
+        state_dict["hysteresis"] = self._hysteresis
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
+        self._growth_factor = state_dict["growth_factor"]
+        self._backoff_factor = state_dict["backoff_factor"]
+        self._growth_interval = state_dict["growth_interval"]
+        self._hysteresis = state_dict["hysteresis"]
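To see the hysteresis and growth bookkeeping in action, a small sketch that drives `update()` directly with a fabricated overflow flag instead of a real optimizer (in training, `step()` would record `found_inf`); it needs a CUDA device because the scale lives in a `torch.cuda.FloatTensor`:

```python
import torch

from oslo.torch.nn.parallel.data_parallel.grad_scaler import DynamicGradScaler

scaler = DynamicGradScaler(initial_scale=2**16, growth_interval=2, hysteresis=2)

# Two consecutive overflows hit the hysteresis limit and back off the scale
# (2**16 -> 2**15); two clean steps then reach the growth interval and grow
# it back (2**15 -> 2**16).
for found_inf in [True, True, False, False]:
    scaler.found_inf = found_inf  # normally recorded by scaler.step(optimizer)
    scaler.update()

print(scaler.scale(torch.cuda.FloatTensor([1.0])))  # tensor([65536.], device='cuda:0')
print(scaler.state_dict())
```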