# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Dict, Iterator, List, Tuple

from absl import logging
import torch
import torch.distributed.nn as dist_nn
import torch.distributed as dist
import pickle

from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup

USE_PYTORCH_DDP = pytorch_setup()[0]

HPARAMS = {
  "learning_rate": 0.0025,
  "one_minus_beta1": 0.1,
  "beta2": 0.9955159689799007,
  "weight_decay": 0.08121616522670176,
  "warmup_factor": 0.02,
  "weight_lr_power": 2,
  "label_smoothing": 0.2,
}
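# NOTE: HPARAMS above is the single fixed setting used for every workload;
# `update_params` below deletes its `hyperparameters` argument and reads this
# dict instead.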

class AdamWScheduleFreeV2(torch.optim.Optimizer):
  r"""Schedule-Free AdamW.

  This version differs from the ScheduleFree optimizer submitted to the
  earlier AlgoPerf competition in the following ways:
  - It follows the published version of the method more closely; the
    earlier AlgoPerf submission was based on an early development version.
  - Weight decay is applied to the y sequence instead of the z sequence,
    following the published version. This does not improve the results but
    is done for consistency.
  - The r=0.5 weighting sequence is no longer used. This simplifies the
    implementation without any performance hit.

  The following change was also made to the outer loop:
  - Batch-norm buffers are now updated correctly, greatly improving
    performance on the ResNet and DeepSpeech workloads.
  """
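  # Summary of the update implemented in `step` below (weight decay and the
  # second-moment bias correction are omitted here for brevity). Notation:
  # x is the averaged iterate stored in `p.data` between steps, z is the
  # base iterate kept in `state['z']`, y is the gradient-evaluation point,
  # and c_{k+1} = weight / weight_sum:
  #
  #   y_k     = beta1 * x_k + (1 - beta1) * z_k            (extrapolate)
  #   z_{k+1} = z_k - lr_k * grad(y_k) / (sqrt(v_k) + eps)  (Adam-style step)
  #   x_{k+1} = (1 - c_{k+1}) * x_k + c_{k+1} * z_{k+1}     (weighted average)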
  def __init__(self, params,
               lr=1e-3,
               betas=(0.9, 0.999),
               eps=1e-8,
               weight_decay=0,
               weight_lr_power=2,
               warmup_steps=0,
               ):
    defaults = dict(lr=lr,
                    betas=betas,
                    eps=eps,
                    k=0,
                    weight_sum=0.0,
                    warmup_steps=warmup_steps,
                    weight_lr_power=weight_lr_power,
                    weight_decay=weight_decay)

    super().__init__(params, defaults)

  def step(self, closure):
    """Performs a single optimization step.

    Arguments:
      closure (callable): A closure that reevaluates the model and returns
        the loss. The closure is required because the gradient must be
        evaluated at the extrapolated point y.
    """
    # Swap to the extrapolated point y:
    for group in self.param_groups:
      beta1, beta2 = group['betas']
      k = group['k']

      for p in group['params']:
        # State initialization
        state = self.state[p]
        if 'z' not in state:
          state['z'] = torch.clone(p.data)
          state['exp_avg_sq'] = torch.zeros_like(p.data)

        z = state['z']

        # Extrapolate (x -> y)
        p.data.lerp_(end=z, weight=1 - beta1)

    # Evaluate gradient at the extrapolated point
    loss = closure()

    for group in self.param_groups:
      eps = group['eps']
      k = group['k']
      warmup_steps = group['warmup_steps']
      decay = group['weight_decay']
      beta1, beta2 = group['betas']
      weight_lr_power = group['weight_lr_power']

      if k < warmup_steps:
        sched = (k + 1) / warmup_steps
      else:
        sched = 1.0
      annealed_lr = group['lr'] * sched

      lr = max(annealed_lr, eps)

      weight = lr**weight_lr_power
      weight_sum = group['weight_sum'] = group['weight_sum'] + weight

      ckp1 = weight / weight_sum

      bias_correction2 = 1 - beta2**(k + 1)
      step_size = lr * math.sqrt(bias_correction2)

      for p in group['params']:
        if p.grad is None:
          continue
        grad = p.grad.data

        state = self.state[p]

        exp_avg_sq = state['exp_avg_sq']
        z = state['z']

        y = p.data.clone()

        # Unextrapolate (y -> x)
        p.data.lerp_(end=z, weight=1 - 1 / beta1)

        # Decay at y (differs from V1)
        z.sub_(y, alpha=step_size * decay)
        del y

        # Update the second moment running average
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        # The second-moment bias correction is folded into step_size above,
        # which is equivalent to correcting the denominator here up to the
        # scaling of eps (differs from V1).
        denom = exp_avg_sq.sqrt().add_(eps)

        # Take step on z
        z.addcdiv_(grad, denom, value=-step_size)

        # Take step on x
        p.data.lerp_(end=z, weight=ckp1)

      group['k'] = k + 1
    return loss

def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  del model_state

  optimizer = AdamWScheduleFreeV2(
      model_params.parameters(),
      lr=HPARAMS['learning_rate'],
      betas=(1.0 - HPARAMS['one_minus_beta1'], HPARAMS['beta2']),
      warmup_steps=int(HPARAMS['warmup_factor'] * workload.step_hint * 0.75),
      weight_decay=HPARAMS['weight_decay'],
      weight_lr_power=HPARAMS['weight_lr_power'])

  optimizer_state = {'optimizer': optimizer}
  return optimizer_state

def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor],
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del loss_type
  del hyperparameters

  # Detect whether the model contains batch-norm layers.
  current_model = current_param_container
  contains_bn = any(
      hasattr(m, "running_mean") for m in current_model.modules())

  if contains_bn and global_step % 3 == 0:
    # Update batch-norm statistics at the evaluation point x, not y.
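    # Between optimizer steps the model holds the averaged iterate x, which
    # is also the point used for evaluation, so this extra forward pass
    # refreshes the batch-norm running statistics at x. It runs only every
    # third step, presumably to limit the cost of the additional forward
    # pass.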
    with torch.no_grad():
      _, new_model_state = workload.model_fn(
          params=current_model,
          augmented_and_preprocessed_input_batch=batch,
          model_state=model_state,
          mode=spec.ForwardPassMode.TRAIN,
          rng=rng,
          update_batch_norm=True)
    model_state = new_model_state

  new_model_state = None

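  # The closure computes the loss and gradients; `AdamWScheduleFreeV2.step`
  # moves the parameters to the extrapolated point y before invoking it, so
  # the backward pass below is evaluated at y rather than at x.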
  def closure():
    nonlocal new_model_state
    optimizer_state['optimizer'].zero_grad()

    logits_batch, new_model_state = workload.model_fn(
        params=current_model,
        augmented_and_preprocessed_input_batch=batch,
        model_state=model_state,
        mode=spec.ForwardPassMode.TRAIN,
        rng=rng,
        update_batch_norm=False)

    loss_dict = workload.loss_fn(
        label_batch=batch['targets'],
        logits_batch=logits_batch,
        mask_batch=batch.get('weights'),
        label_smoothing=HPARAMS['label_smoothing'])
    summed_loss = loss_dict['summed']
    n_valid_examples = loss_dict['n_valid_examples']
    if USE_PYTORCH_DDP:
      # Use dist_nn.all_reduce to ensure correct loss and gradient scaling.
      summed_loss = dist_nn.all_reduce(summed_loss)
      n_valid_examples = dist_nn.all_reduce(n_valid_examples)
    loss = summed_loss / n_valid_examples

    loss.backward()
    return loss

  optimizer_state['optimizer'].step(closure)

  return (optimizer_state, current_param_container, new_model_state)

def get_batch_size(workload_name):
  # Return the global batch size.
  if workload_name == 'criteo1tb':
    return 262_144
  elif workload_name == 'fastmri':
    return 16
  elif workload_name == 'imagenet_resnet':
    return 1024
  elif workload_name == 'imagenet_vit':
    return 1024
  elif workload_name == 'librispeech_conformer':
    return 224
  elif workload_name == 'librispeech_deepspeech':
    return 128
  elif workload_name == 'ogbg':
    return 512
  elif workload_name == 'wmt':
    return 128
  elif workload_name == 'mnist':
    return 16
  elif workload_name == 'imagenet_resnet_gelu':
    return 512
  elif workload_name == 'imagenet_resnet_silu':
    return 512
  elif workload_name == 'finewebedu_lm':
    return 64
  else:
    raise ValueError(f'Unsupported workload name: {workload_name}.')

def data_selection(workload: spec.Workload,
                   input_queue: Iterator[Dict[str, spec.Tensor]],
                   optimizer_state: spec.OptimizerState,
                   current_param_container: spec.ParameterContainer,
                   model_state: spec.ModelAuxiliaryState,
                   hyperparameters: spec.Hyperparameters,
                   global_step: int,
                   rng: spec.RandomState) -> Dict[str, spec.Tensor]:
  """Select data from the infinitely repeating, pre-shuffled input queue.

  Each element of the queue is a batch of training examples and labels.
  """
  del workload
  del optimizer_state
  del current_param_container
  del model_state
  del hyperparameters
  del global_step
  del rng
  batch = next(input_queue)
  return batch
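
# A minimal, hypothetical smoke test for the optimizer in isolation. It is
# illustrative only and not part of the AlgoPerf submission API; the model,
# data, and hyperparameter values below are made up, and running this file
# as a script still requires the `algoperf` package imported at the top.
if __name__ == "__main__":
  torch.manual_seed(0)
  model = torch.nn.Linear(4, 1)
  inputs = torch.randn(64, 4)
  targets = torch.randn(64, 1)
  opt = AdamWScheduleFreeV2(model.parameters(), lr=1e-2, warmup_steps=10)

  def _closure():
    # The closure must zero the gradients, run backward, and return the loss.
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    return loss

  for _ in range(20):
    loss = opt.step(_closure)
  print(f"final loss: {loss.item():.4f}")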