1 change: 1 addition & 0 deletions oslo/lightseq2/csrc/kernels/cuda/cuda_util.cu
@@ -1,4 +1,5 @@
#include <thrust/device_vector.h>
#include <thrust/transform_reduce.h>
#include <thrust/reduce.h>

#include "cuda_util.h"
2 changes: 1 addition & 1 deletion oslo/lightseq2/training/ops/pytorch/builder/builder.py
@@ -60,7 +60,7 @@ def assert_no_cuda_mismatch():
    torch_cuda_version = ".".join(torch.version.cuda.split(".")[:2])
    # This is a show-stopping error, should probably not proceed past this
    if sys_cuda_version != torch_cuda_version:
-       if cuda_major == 11 and torch_cuda_version.split(".")[0] == "11":
+       if cuda_major in [11, 12] and torch_cuda_version.split(".")[0] == "11":
            # it works to build against installed cuda-11 while torch was built with cuda-11
            return
        raise Exception(
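This relaxation lets a system CUDA 11.x or 12.x toolkit build the extensions against a PyTorch that was compiled with CUDA 11, instead of requiring an exact version match. A small standalone sketch of the comparison (the function name and the hard-coded version values below are illustrative only; the real `assert_no_cuda_mismatch` derives `cuda_major`, `sys_cuda_version`, and `torch_cuda_version` from the installed toolkit and `torch.version.cuda`):

def cuda_versions_compatible(cuda_major, sys_cuda_version, torch_cuda_version):
    # Mirrors the relaxed condition above (illustration only).
    if sys_cuda_version == torch_cuda_version:
        return True
    # A CUDA 11 or 12 toolkit may now build against torch compiled with CUDA 11.
    if cuda_major in [11, 12] and torch_cuda_version.split(".")[0] == "11":
        return True
    return False


print(cuda_versions_compatible(12, "12.1", "11.8"))  # True: newly accepted by this change
print(cuda_versions_compatible(11, "11.7", "11.8"))  # True: accepted before as well
print(cuda_versions_compatible(10, "10.2", "11.8"))  # False: the builder raises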
@@ -0,0 +1,4 @@
from .base_grad_scaler import BaseGradScaler
from .dynamic_grad_scaler import DynamicGradScaler

__all__ = ["BaseGradScaler", "DynamicGradScaler"]
@@ -0,0 +1,79 @@
from abc import ABC, abstractmethod
from typing import Dict

import torch
from torch import Tensor

from oslo.torch.utils.logging import get_dist_logger

__all__ = ["BaseGradScaler"]


class BaseGradScaler(ABC):
"""A base class for the gradient scaler.

Args:
initial_scale (float): the initial loss scale
verbose (bool): whether to log messages
"""

    def __init__(self, initial_scale: float, verbose: bool):
        assert initial_scale > 0
        self._scale = torch.cuda.FloatTensor([initial_scale])
        self._verbose = verbose

        if self._verbose:
            self._logger = get_dist_logger()

    def scale(self, loss: Tensor) -> Tensor:
        """Scale the loss by the current loss scale."""

        return self._scale * loss

    def step(self, optimizer):
        """Record the optimizer's overflow status, then run its step."""
        self.found_inf = optimizer._check_overflow()
        return optimizer.step()

    @property
    def inv_scale(self) -> Tensor:
        """Returns the inverse of the loss scale."""

        return self._scale.double().reciprocal().float()

    def state_dict(self) -> Dict:
        """Returns the states of the gradient scaler as a dict object."""

        state_dict = dict()
        state_dict["scale"] = self._scale
        return state_dict

    def load_state_dict(self, state_dict: Dict) -> None:
        """Load the states of the gradient scaler from a dict object.

        Args:
            state_dict (dict): the states of the gradient scaler
        """

        self._scale = state_dict["scale"]

    @abstractmethod
    def update(self) -> None:
        """Update the loss scale based on the overflow status recorded by `step`
        (i.e. `self.found_inf`)."""
        pass

    def log(self, message, *args, **kwargs):
        """Log messages.

        Args:
            message (str): the message to log
            *args: positional arguments for :class:`oslo.torch.utils.logging.DistributedLogger`
            **kwargs: key-word arguments for :class:`oslo.torch.utils.logging.DistributedLogger`
        """

        if self._verbose:
            self._logger.info(message, *args, **kwargs)
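Only `update()` is abstract, so a no-op subclass is enough to exercise the interface above. A minimal sketch, assuming a CUDA device (the constructor allocates a `torch.cuda.FloatTensor`) and that `BaseGradScaler` has been imported from wherever this new module lives (the package path is not shown in the diff); `ConstantGradScaler` is an illustrative name, not part of this PR:

import torch

# Assumes `BaseGradScaler` from the file above is already imported; its package
# path is not shown in this diff, so no concrete import is written here.


class ConstantGradScaler(BaseGradScaler):
    # Illustrative subclass: the loss scale never changes, so update() is a no-op.
    def update(self) -> None:
        pass


scaler = ConstantGradScaler(initial_scale=2**10, verbose=False)
loss = torch.ones(1, device="cuda")

scaled_loss = scaler.scale(loss)           # loss * 1024
restored = scaled_loss * scaler.inv_scale  # back to the original loss
print(scaler.state_dict())                 # {"scale": tensor([1024.], device="cuda:0")}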
@@ -0,0 +1,133 @@
from typing import Optional

import torch

from .base_grad_scaler import BaseGradScaler

__all__ = ["DynamicGradScaler"]


class DynamicGradScaler(BaseGradScaler):
"""A gradient scaler which uses dynamic loss scale

Args:
initial_scale (float): the initial loss scale, defaults to 2**16
growth_factor (float): the multiplication factor for increasing loss scale, defaults to 2
backoff_factor (float): the multiplication factor for decreasing loss scale, defaults to 0.5
growth_interval (int): the number of steps to increase loss scale when no overflow occurs, defaults to 1000
min_scale (float): the minimum loss scale, defaults to None
max_scale (float): the maximum loss scale, defaults to None
hysteresis (int): the number of overflows before decreasing loss scale, defaults to 2
verbose (bool): whether to log messages, defaults to False
"""

    def __init__(
        self,
        initial_scale: float = 2**16,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        min_scale: Optional[float] = None,
        max_scale: Optional[float] = None,
        hysteresis: int = 2,
        verbose: bool = False,
    ):
        super().__init__(initial_scale, verbose)
        if min_scale is not None:
            self._min_scale = torch.cuda.FloatTensor([min_scale])
        else:
            self._min_scale = None

        if max_scale is not None:
            self._max_scale = torch.cuda.FloatTensor([max_scale])
        else:
            self._max_scale = None

        self._growth_factor = growth_factor
        self._backoff_factor = backoff_factor
        self._growth_interval = growth_interval
        self._growth_step = 0
        self._hysteresis = hysteresis
        self._hysteresis_step = 0
        self._sanity_checks()

    def _sanity_checks(self) -> None:
        """Check if the arguments are correct."""

        if self._min_scale is not None:
            assert (
                self._min_scale > 0
            ), "The minimum gradient scale cannot be zero or negative"
            assert (
                self._min_scale <= self._scale
            ), "The minimum gradient scale cannot be greater than the current scale"
        if self._max_scale is not None:
            assert (
                self._max_scale > 0
            ), "The maximum gradient scale cannot be zero or negative"
            assert (
                self._max_scale >= self._scale
            ), "The maximum gradient scale cannot be smaller than the current scale"
        assert (
            self._growth_factor > 1
        ), "The growth factor must be greater than 1"
        assert (
            0 < self._backoff_factor < 1
        ), "The backoff factor must be between 0 and 1"
        assert self._hysteresis >= 0, "The hysteresis cannot be negative"

    def update(self) -> None:
        """Update the loss scale based on the overflow status recorded by `step`
        (i.e. `self.found_inf`)."""
        if self.found_inf:
            self._hysteresis_step += 1
            self._growth_step = 0

            if self._hysteresis_step >= self._hysteresis:
                self._backoff_scale()
                self.log(
                    f"Overflow occurred, the loss scale is adjusted to {self._scale.item()}",
                    ranks=[0],
                )
        else:
            self._growth_step += 1
            if self._growth_step == self._growth_interval:
                self._growth_step = 0
                self._hysteresis_step = 0
                self._grow_scale()
                self.log(
                    f"No overflow for {self._growth_interval} consecutive steps, "
                    f"the loss scale is adjusted to {self._scale.item()}",
                    ranks=[0],
                )

    def _backoff_scale(self) -> None:
        """Decrease the loss scale."""

        self._scale = self._scale * self._backoff_factor
        if self._min_scale is not None:
            self._scale = torch.max(self._scale, self._min_scale)

    def _grow_scale(self) -> None:
        """Increase the loss scale."""

        self._scale = self._scale * self._growth_factor
        if self._max_scale is not None:
            self._scale = torch.min(self._scale, self._max_scale)

    def state_dict(self):
        state_dict = dict()
        state_dict["scale"] = self._scale
        state_dict["growth_factor"] = self._growth_factor
        state_dict["backoff_factor"] = self._backoff_factor
        state_dict["hysteresis"] = self._hysteresis
        return state_dict

    def load_state_dict(self, state_dict):
        self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
        self._growth_factor = state_dict["growth_factor"]
        self._backoff_factor = state_dict["backoff_factor"]
        self._hysteresis = state_dict["hysteresis"]
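Putting both files together, the intended call pattern resembles `torch.cuda.amp.GradScaler`: scale the loss before `backward()`, let `step()` record the overflow status through the optimizer, then call `update()` to grow or back off the scale. A rough usage sketch under the following assumptions: a CUDA device is available, the classes above are importable (their package path is not shown in the diff), and the optimizer exposes the `_check_overflow()` hook that `BaseGradScaler.step()` calls — `OverflowCheckingSGD` below is a hypothetical stand-in for whatever oslo optimizer provides that hook:

import torch

# Assumes `DynamicGradScaler` from the file above is already imported; the
# package path is not shown in this diff.


class OverflowCheckingSGD(torch.optim.SGD):
    # Toy stand-in exposing the _check_overflow() hook that BaseGradScaler.step() expects
    # (standard torch.optim optimizers do not provide it).
    def _check_overflow(self) -> bool:
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None and not torch.isfinite(p.grad).all():
                    return True
        return False


model = torch.nn.Linear(16, 1).cuda()
optimizer = OverflowCheckingSGD(model.parameters(), lr=1e-3)
scaler = DynamicGradScaler(initial_scale=2**16, growth_interval=1000, hysteresis=2)

x = torch.randn(8, 16, device="cuda")
loss = model(x).mean()

optimizer.zero_grad()
scaler.scale(loss).backward()   # backpropagate the scaled loss

# Unscale gradients before the update; in a full integration the wrapped
# optimizer would be expected to do this itself.
for p in model.parameters():
    if p.grad is not None:
        p.grad.mul_(scaler.inv_scale)

scaler.step(optimizer)          # records found_inf via _check_overflow(), then steps
scaler.update()                 # grows or backs off the scale based on found_inf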