From 1a61aae4091e837e0ebe9346897fc9d0ef1dd36f Mon Sep 17 00:00:00 2001 From: nijkah Date: Thu, 8 Jun 2023 07:49:09 +0000 Subject: [PATCH 1/5] [Feature] Add GradScaler for ZeroOptim --- .../data_parallel/grad_scaler/__init__.py | 4 + .../grad_scaler/base_grad_scaler.py | 82 ++++++++++++ .../grad_scaler/dynamic_grad_scaler.py | 118 ++++++++++++++++++ 3 files changed, 204 insertions(+) create mode 100644 oslo/torch/nn/parallel/data_parallel/grad_scaler/__init__.py create mode 100644 oslo/torch/nn/parallel/data_parallel/grad_scaler/base_grad_scaler.py create mode 100644 oslo/torch/nn/parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py diff --git a/oslo/torch/nn/parallel/data_parallel/grad_scaler/__init__.py b/oslo/torch/nn/parallel/data_parallel/grad_scaler/__init__.py new file mode 100644 index 00000000..3caa50f4 --- /dev/null +++ b/oslo/torch/nn/parallel/data_parallel/grad_scaler/__init__.py @@ -0,0 +1,4 @@ +from .base_grad_scaler import BaseGradScaler +from .dynamic_grad_scaler import DynamicGradScaler + +__ALL__ = ['BaseGradScaler', 'DynamicGradScaler'] \ No newline at end of file diff --git a/oslo/torch/nn/parallel/data_parallel/grad_scaler/base_grad_scaler.py b/oslo/torch/nn/parallel/data_parallel/grad_scaler/base_grad_scaler.py new file mode 100644 index 00000000..03867da5 --- /dev/null +++ b/oslo/torch/nn/parallel/data_parallel/grad_scaler/base_grad_scaler.py @@ -0,0 +1,82 @@ +from abc import ABC, abstractmethod +from typing import Dict + +import torch +from torch import Tensor + +from oslo.torch.utils.logging import get_dist_logger + +__all__ = ['BaseGradScaler'] + + +class BaseGradScaler(ABC): + """A base class for the gradient scaler. 
+ + Args: + initial_scale (float): the initial loss scale + verbose (bool): whether to log messages + """ + + def __init__(self, initial_scale: float, verbose: bool): + assert initial_scale > 0 + self._scale = torch.cuda.FloatTensor([initial_scale]) + self._verbose = verbose + + if self._verbose: + self._logger = get_dist_logger() + + # @property + def scale(self, loss) -> Tensor: + """Returns the loss scale. + """ + + return self._scale * loss + + def step(self, optimizer): + self.found_inf = optimizer._check_overflow() + return optimizer.step() + + @property + def inv_scale(self) -> Tensor: + """Returns the inverse of the loss scale. + """ + + return self._scale.double().reciprocal().float() + + def state_dict(self) -> Dict: + """Returns the states of the gradient scaler as a dict object. + """ + + state_dict = dict() + state_dict['scale'] = self.scale + return state_dict + + def load_state_dict(self, state_dict: Dict) -> None: + """Load the states of the gradient scaler from a dict object. + + Args: + state_dict (dict): the states of the gradient scaler + """ + + self._scale = state_dict['scale'] + + @abstractmethod + def update(self) -> None: + """Update the loss scale. + + Args: + overflow (bool): whether overflow occurs + """ + pass + + def log(self, message, *args, **kwargs): + """Log messages. 
+ + Args: + message (str): the message to log + *args: positional arguments for :class:`oslo.torch.utils.logging.DistributedLogger` + **kwargs: key-word arguments for :class:`oslo.torch.utils.logging.DistributedLogger` + """ + + if self._verbose: + self._logger.info(message, *args, **kwargs) \ No newline at end of file diff --git a/oslo/torch/nn/parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py b/oslo/torch/nn/parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py new file mode 100644 index 00000000..86c8829d --- /dev/null +++ b/oslo/torch/nn/parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py @@ -0,0 +1,118 @@ +from typing import Optional + +import torch + +from .base_grad_scaler import BaseGradScaler + +__all__ = ['DynamicGradScaler'] + + +class DynamicGradScaler(BaseGradScaler): + """A gradient scaler which uses dynamic loss scale + + Args: + initial_scale (float): the initial loss scale, defaults to 2**16 + growth_factor (float): the multiplication factor for increasing loss scale, defaults to 2 + backoff_factor (float): the multiplication factor for decreasing loss scale, defaults to 0.5 + growth_interval (int): the number of steps to increase loss scale when no overflow occurs, defaults to 1000 + min_scale (float): the minimum loss scale, defaults to None + max_scale (float): the maximum loss scale, defaults to None + hysteresis (int): the number of overflows before decreasing loss scale, defaults to 2 + verbose (bool): whether to log messages, defaults to False + """ + + def __init__(self, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + min_scale: Optional[float] = None, + max_scale: Optional[float] = None, + hysteresis: int = 2, + verbose: bool = False): + super().__init__(initial_scale, verbose) + if min_scale: + self._min_scale = torch.cuda.FloatTensor([min_scale]) + else: + self._min_scale = None + + if max_scale: + self._max_scale = 
torch.cuda.FloatTensor([max_scale]) + else: + self._max_scale = None + + self._growth_factor = growth_factor + self._backoff_factor = backoff_factor + self._growth_interval = growth_interval + self._growth_step = 0 + self._hysteresis = hysteresis + self._hysteresis_step = 0 + self._sanity_checks() + + def _sanity_checks(self) -> None: + """Check if the arguments are correct. + """ + + if self._min_scale: + assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative' + assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale' + if self._max_scale: + assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative' + assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale' + assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1' + assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1' + assert self._hysteresis >= 0, 'The hysteresis cannot be negative' + + def update(self) -> None: + """Update the loss scale. 
+ + Args: + overflow (bool): whether overflow occurs + """ + if self.found_inf: + self._hysteresis_step += 1 + self._growth_step = 0 + + if self._hysteresis_step >= self._hysteresis: + self._backoff_scale() + self.log(f"Overflow occurs, the loss scale is adjusted to {self.scale.item()}", ranks=[0]) + else: + self._growth_step += 1 + if self._growth_step == self._growth_interval: + self._growth_step = 0 + self._hysteresis_step = 0 + self._grow_scale() + self.log( + f"No overflow for consecutive {self._growth_interval} steps, " + f"the loss scale is adjusted to {self.scale.item()}", + ranks=[0]) + + def _backoff_scale(self) -> None: + """Decrease the loss scale + """ + + self._scale = self._scale * self._backoff_factor + if self._min_scale: + self._scale = torch.max(self._scale, self._min_scale) + + def _grow_scale(self) -> None: + """Increase the loss scale + """ + + self._scale = self._scale * self._growth_factor + if self._max_scale: + self._scale = torch.min(self._scale, self._max_scale) + + def state_dict(self): + state_dict = dict() + state_dict['scale'] = self._scale + state_dict['growth_factor'] = self._growth_factor + state_dict['backoff_factor'] = self._backoff_factor + state_dict['hysteresis'] = self._hysteresis + return state_dict + + def load_state_dict(self, state_dict): + self._scale = state_dict['scale'].cuda(torch.cuda.current_device()) + self._growth_factor = state_dict['growth_factor'] + self._backoff_factor = state_dict['backoff_factor'] + self._hysteresis = state_dict['hysteresis'] \ No newline at end of file From 38bc86a6ef03ba8f377c994c602deb56fc707364 Mon Sep 17 00:00:00 2001 From: nijkah Date: Thu, 8 Jun 2023 07:57:43 +0000 Subject: [PATCH 2/5] Lint --- .../data_parallel/grad_scaler/__init__.py | 2 +- .../grad_scaler/base_grad_scaler.py | 19 ++--- .../grad_scaler/dynamic_grad_scaler.py | 81 +++++++++++-------- 3 files changed, 57 insertions(+), 45 deletions(-) diff --git a/oslo/torch/nn/parallel/data_parallel/grad_scaler/__init__.py 
b/oslo/torch/nn/parallel/data_parallel/grad_scaler/__init__.py index 3caa50f4..c4c362e5 100644 --- a/oslo/torch/nn/parallel/data_parallel/grad_scaler/__init__.py +++ b/oslo/torch/nn/parallel/data_parallel/grad_scaler/__init__.py @@ -1,4 +1,4 @@ from .base_grad_scaler import BaseGradScaler from .dynamic_grad_scaler import DynamicGradScaler -__ALL__ = ['BaseGradScaler', 'DynamicGradScaler'] \ No newline at end of file +__ALL__ = ["BaseGradScaler", "DynamicGradScaler"] diff --git a/oslo/torch/nn/parallel/data_parallel/grad_scaler/base_grad_scaler.py b/oslo/torch/nn/parallel/data_parallel/grad_scaler/base_grad_scaler.py index 03867da5..12469be1 100644 --- a/oslo/torch/nn/parallel/data_parallel/grad_scaler/base_grad_scaler.py +++ b/oslo/torch/nn/parallel/data_parallel/grad_scaler/base_grad_scaler.py @@ -6,7 +6,7 @@ from oslo.torch.utils.logging import get_dist_logger -__all__ = ['BaseGradScaler'] +__all__ = ["BaseGradScaler"] class BaseGradScaler(ABC): @@ -27,28 +27,25 @@ def __init__(self, initial_scale: float, verbose: bool): # @property def scale(self, loss) -> Tensor: - """Returns the loss scale. - """ + """Returns the loss scale.""" return self._scale * loss - + def step(self, optimizer): self.found_inf = optimizer._check_overflow() return optimizer.step() @property def inv_scale(self) -> Tensor: - """Returns the inverse of the loss scale. - """ + """Returns the inverse of the loss scale.""" return self._scale.double().reciprocal().float() def state_dict(self) -> Dict: - """Returns the states of the gradient scaler as a dict object. 
- """ + """Returns the states of the gradient scaler as a dict object.""" state_dict = dict() - state_dict['scale'] = self.scale + state_dict["scale"] = self.scale return state_dict def load_state_dict(self, state_dict: Dict) -> None: @@ -58,7 +55,7 @@ def load_state_dict(self, state_dict: Dict) -> None: state_dict (dict): the states of the gradient scaler """ - self._scale = state_dict['scale'] + self._scale = state_dict["scale"] @abstractmethod def update(self) -> None: @@ -79,4 +76,4 @@ def log(self, message, *args, **kwargs): """ if self._verbose: - self._logger.info(message, *args, **kwargs) \ No newline at end of file + self._logger.info(message, *args, **kwargs) diff --git a/oslo/torch/nn/parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py b/oslo/torch/nn/parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py index 86c8829d..22b81ca9 100644 --- a/oslo/torch/nn/parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py +++ b/oslo/torch/nn/parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py @@ -4,7 +4,7 @@ from .base_grad_scaler import BaseGradScaler -__all__ = ['DynamicGradScaler'] +__all__ = ["DynamicGradScaler"] class DynamicGradScaler(BaseGradScaler): @@ -21,15 +21,17 @@ class DynamicGradScaler(BaseGradScaler): verbose (bool): whether to log messages, defaults to False """ - def __init__(self, - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - min_scale: Optional[float] = None, - max_scale: Optional[float] = None, - hysteresis: int = 2, - verbose: bool = False): + def __init__( + self, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + min_scale: Optional[float] = None, + max_scale: Optional[float] = None, + hysteresis: int = 2, + verbose: bool = False, + ): super().__init__(initial_scale, verbose) if min_scale: self._min_scale = torch.cuda.FloatTensor([min_scale]) @@ -50,18 +52,29 @@ def 
__init__(self, self._sanity_checks() def _sanity_checks(self) -> None: - """Check if the arguments are correct. - """ + """Check if the arguments are correct.""" if self._min_scale: - assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative' - assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale' + assert ( + self._min_scale > 0 + ), "The minimum gradient scale cannot be zero or negative" + assert ( + self._min_scale <= self._scale + ), "The minimum gradient scale cannot be greater than the current scale" if self._max_scale: - assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative' - assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale' - assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1' - assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1' - assert self._hysteresis >= 0, 'The hysteresis cannot be negative' + assert ( + self._max_scale > 0 + ), "The maximum gradient scale cannot be zero or negative" + assert ( + self._max_scale >= self._scale + ), "The maximum gradient scale cannot be smaller than the current scale" + assert ( + self._growth_factor > 1 + ), "The growth factor cannot be equal or smaller than 1" + assert ( + 0 < self._backoff_factor < 1 + ), "The backoff factor must be between 0 and 1" + assert self._hysteresis >= 0, "The hysteresis cannot be negative" def update(self) -> None: """Update the loss scale. 
@@ -75,7 +88,10 @@ def update(self) -> None: if self._hysteresis_step >= self._hysteresis: self._backoff_scale() - self.log(f"Overflow occurs, the loss scale is adjusted to {self.scale.item()}", ranks=[0]) + self.log( + f"Overflow occurs, the loss scale is adjusted to {self.scale.item()}", + ranks=[0], + ) else: self._growth_step += 1 if self._growth_step == self._growth_interval: @@ -85,19 +101,18 @@ def update(self) -> None: self.log( f"No overflow for consecutive {self._growth_interval} steps, " f"the loss scale is adjusted to {self.scale.item()}", - ranks=[0]) + ranks=[0], + ) def _backoff_scale(self) -> None: - """Decrease the loss scale - """ + """Decrease the loss scale""" self._scale = self._scale * self._backoff_factor if self._min_scale: self._scale = torch.max(self._scale, self._min_scale) def _grow_scale(self) -> None: - """Increase the loss scale - """ + """Increase the loss scale""" self._scale = self._scale * self._growth_factor if self._max_scale: @@ -105,14 +120,14 @@ def _grow_scale(self) -> None: def state_dict(self): state_dict = dict() - state_dict['scale'] = self._scale - state_dict['growth_factor'] = self._growth_factor - state_dict['backoff_factor'] = self._backoff_factor - state_dict['hysteresis'] = self._hysteresis + state_dict["scale"] = self._scale + state_dict["growth_factor"] = self._growth_factor + state_dict["backoff_factor"] = self._backoff_factor + state_dict["hysteresis"] = self._hysteresis return state_dict def load_state_dict(self, state_dict): - self._scale = state_dict['scale'].cuda(torch.cuda.current_device()) - self._growth_factor = state_dict['growth_factor'] - self._backoff_factor = state_dict['backoff_factor'] - self._hysteresis = state_dict['hysteresis'] \ No newline at end of file + self._scale = state_dict["scale"].cuda(torch.cuda.current_device()) + self._growth_factor = state_dict["growth_factor"] + self._backoff_factor = state_dict["backoff_factor"] + self._hysteresis = state_dict["hysteresis"] From 
979c55430aee147671d086cee2a4e8af6f0ebcbf Mon Sep 17 00:00:00 2001 From: nijkah Date: Thu, 20 Jul 2023 12:00:16 +0000 Subject: [PATCH 3/5] fix CUDA --- oslo/lightseq2/csrc/kernels/cuda/cuda_util.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/oslo/lightseq2/csrc/kernels/cuda/cuda_util.cu b/oslo/lightseq2/csrc/kernels/cuda/cuda_util.cu index 770b08c6..032c8220 100644 --- a/oslo/lightseq2/csrc/kernels/cuda/cuda_util.cu +++ b/oslo/lightseq2/csrc/kernels/cuda/cuda_util.cu @@ -1,4 +1,5 @@ #include +#include #include #include "cuda_util.h" From 944d5596a4d1a265ef4994c9d24b0c427de37cf1 Mon Sep 17 00:00:00 2001 From: nijkah Date: Thu, 20 Jul 2023 12:00:37 +0000 Subject: [PATCH 4/5] relax cuda assertion --- oslo/lightseq2/training/ops/pytorch/builder/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oslo/lightseq2/training/ops/pytorch/builder/builder.py b/oslo/lightseq2/training/ops/pytorch/builder/builder.py index 329cb693..fb059fb7 100644 --- a/oslo/lightseq2/training/ops/pytorch/builder/builder.py +++ b/oslo/lightseq2/training/ops/pytorch/builder/builder.py @@ -60,7 +60,7 @@ def assert_no_cuda_mismatch(): torch_cuda_version = ".".join(torch.version.cuda.split(".")[:2]) # This is a show-stopping error, should probably not proceed past this if sys_cuda_version != torch_cuda_version: - if cuda_major == 11 and torch_cuda_version.split(".")[0] == "11": + if cuda_major in [11, 12] and torch_cuda_version.split(".")[0] == "11": # it works to build against installed cuda-11 while torch was built with cuda-11 return raise Exception( From 1cfd3ac88198ace4a030347e9e414d5d05adae36 Mon Sep 17 00:00:00 2001 From: nijkah Date: Thu, 20 Jul 2023 12:00:47 +0000 Subject: [PATCH 5/5] fix bug --- .../parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oslo/torch/nn/parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py 
b/oslo/torch/nn/parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py index 22b81ca9..d9dde8a3 100644 --- a/oslo/torch/nn/parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py +++ b/oslo/torch/nn/parallel/data_parallel/grad_scaler/dynamic_grad_scaler.py @@ -89,7 +89,7 @@ def update(self) -> None: if self._hysteresis_step >= self._hysteresis: self._backoff_scale() self.log( - f"Overflow occurs, the loss scale is adjusted to {self.scale.item()}", + f"Overflow occurs, the loss scale is adjusted to {self._scale.item()}", ranks=[0], ) else: @@ -100,7 +100,7 @@ def update(self) -> None: self._grow_scale() self.log( f"No overflow for consecutive {self._growth_interval} steps, " - f"the loss scale is adjusted to {self.scale.item()}", + f"the loss scale is adjusted to {self._scale.item()}", ranks=[0], )