Commit 9a5a0d1

Merge pull request #11 from tfaod/lion-baseline
[submission] lion with optimal hyperparameters (excl ogbg)
2 parents 075a310 + 8b6d5eb commit 9a5a0d1

3 files changed

Lines changed: 457 additions & 0 deletions

File tree

submissions/self_tuning/lion/__init__.py

Whitespace-only changes.
Lines changed: 302 additions & 0 deletions
@@ -0,0 +1,302 @@
from __future__ import annotations

import collections
from typing import Any, Dict, Iterator, List, Optional, Tuple

import torch
import torch.distributed.nn as dist_nn
from absl import logging
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.optim.optimizer import Optimizer

from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup

USE_PYTORCH_DDP = pytorch_setup()[0]

# Optimal hyperparameters across all workloads, excluding ogbg.
HPARAMS = {
  'dropout_rate': 0.1,
  'learning_rate': 2e-4,
  'one_minus_beta1': 0.05,
  'beta2': 0.98,
  'weight_decay': 0.5,
  'warmup_factor': 0.02,
}
HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS)
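
Editor's aside, not part of the committed file: converting the dict into a namedtuple gives the fixed hyperparameters the same attribute-style access the harness's hyperparameter objects provide, so the rest of the submission can keep reading values such as HPARAMS.learning_rate and probing optional ones with hasattr. A minimal illustration:

import collections

hp = {'learning_rate': 2e-4, 'beta2': 0.98}
HP = collections.namedtuple('Hyperparameters', hp.keys())(**hp)
print(HP.learning_rate)          # 0.0002, attribute access instead of dict lookup
print(hasattr(HP, 'grad_clip'))  # False, so the submission falls back to no clipping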

# Modified from https://github.com/google/automl/blob/master/lion/lion_pytorch.py.
class Lion(Optimizer):

  def __init__(
    self,
    params,
    lr: float = 1e-4,
    betas: Tuple[float, float] = (0.9, 0.99),
    weight_decay: float = 0.0,
  ):
    if not 0.0 <= lr:
      raise ValueError('Invalid learning rate: {}'.format(lr))
    if not 0.0 <= betas[0] < 1.0:
      raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
    if not 0.0 <= betas[1] < 1.0:
      raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
    defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
    super().__init__(params, defaults)

  @torch.no_grad()
  def step(self, closure=None):
    """Performs a single optimization step.

    Args:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.

    Returns:
      The loss.
    """
    loss = None
    if closure is not None:
      with torch.enable_grad():
        loss = closure()

    for group in self.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue

        # Perform decoupled (stepweight) weight decay.
        p.data.mul_(1 - group['lr'] * group['weight_decay'])

        grad = p.grad
        state = self.state[p]
        # State initialization.
        if len(state) == 0:
          # Exponential moving average of gradient values.
          state['exp_avg'] = torch.zeros_like(p)

        exp_avg = state['exp_avg']
        beta1, beta2 = group['betas']

        # Weight update: sign of the beta1-interpolation of momentum and gradient.
        update = exp_avg * beta1 + grad * (1 - beta1)
        p.add_(update.sign_(), alpha=-group['lr'])

        # Decay the momentum running average coefficient.
        exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

    return loss
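
Editor's aside, not part of the committed file: a minimal sketch of driving the Lion class above on a throwaway model (the toy model and numbers here are arbitrary, and the imports from the top of the file are assumed to be in scope):

model = torch.nn.Linear(4, 1)
opt = Lion(model.parameters(), lr=2e-4, betas=(0.95, 0.98), weight_decay=0.5)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
opt.step()       # each weight moves by +/- lr (or 0) from the sign update, plus the decoupled weight-decay shrink
opt.zero_grad()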

def init_optimizer_state(
  workload: spec.Workload,
  model_params: spec.ParameterContainer,
  model_state: spec.ModelAuxiliaryState,
  hyperparameters: spec.Hyperparameters,
  rng: spec.RandomState,
) -> spec.OptimizerState:
  """Creates a Lion optimizer and a learning rate schedule."""
  del model_state
  del rng
  del hyperparameters

  hyperparameters = HPARAMS

  optimizer_state = {
    'optimizer': Lion(
      model_params.parameters(),
      lr=HPARAMS.learning_rate,
      betas=(1.0 - HPARAMS.one_minus_beta1, HPARAMS.beta2),
      weight_decay=HPARAMS.weight_decay,
    )
  }

  def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
    warmup_steps = int(hyperparameters.warmup_factor * step_hint)
    warmup = LinearLR(
      optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
    )
    cosine_steps = max(step_hint - warmup_steps, 1)
    cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
    return SequentialLR(
      optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
    )

  optimizer_state['scheduler'] = pytorch_cosine_warmup(
    workload.step_hint, HPARAMS, optimizer_state['optimizer']
  )
  optimizer_state['hyperparameters'] = hyperparameters

  return optimizer_state
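
Editor's aside, not part of the committed file: the schedule built above is a linear warmup over warmup_factor * step_hint steps followed by cosine decay over the remaining steps; the real step hint comes from workload.step_hint. A self-contained sketch of the same composition with a hypothetical step hint of 10,000 (so 200 warmup steps at warmup_factor = 0.02):

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

step_hint = 10_000                    # hypothetical; the submission uses workload.step_hint
warmup_steps = int(0.02 * step_hint)  # 200
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=2e-4)
warmup = LinearLR(opt, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps)
cosine = CosineAnnealingLR(opt, T_max=max(step_hint - warmup_steps, 1))
sched = SequentialLR(opt, schedulers=[warmup, cosine], milestones=[warmup_steps])
for _ in range(step_hint):
  opt.step()
  sched.step()
# By the end, opt.param_groups[0]['lr'] has decayed to ~0 from its 2e-4 peak.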

def update_params(
  workload: spec.Workload,
  current_param_container: spec.ParameterContainer,
  current_params_types: spec.ParameterTypeTree,
  model_state: spec.ModelAuxiliaryState,
  hyperparameters: spec.Hyperparameters,
  batch: Dict[str, spec.Tensor],
  loss_type: spec.LossType,
  optimizer_state: spec.OptimizerState,
  eval_results: List[Tuple[int, float]],
  global_step: int,
  rng: spec.RandomState,
  train_state: Optional[Dict[str, Any]] = None,
) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del loss_type
  del train_state
  del eval_results
  del hyperparameters

  hyperparameters = HPARAMS

  current_model = current_param_container
  current_model.train()
  optimizer_state['optimizer'].zero_grad()

  logits_batch, new_model_state = workload.model_fn(
    params=current_model,
    augmented_and_preprocessed_input_batch=batch,
    model_state=model_state,
    mode=spec.ForwardPassMode.TRAIN,
    rng=rng,
    update_batch_norm=True,
  )

  label_smoothing = (
    hyperparameters.label_smoothing
    if hasattr(HPARAMS, 'label_smoothing')
    else 0.0
  )
  if hasattr(hyperparameters, 'grad_clip'):
    grad_clip = hyperparameters.grad_clip
  else:
    grad_clip = None

  loss_dict = workload.loss_fn(
    label_batch=batch['targets'],
    logits_batch=logits_batch,
    mask_batch=batch.get('weights'),
    label_smoothing=label_smoothing,
  )
  summed_loss = loss_dict['summed']
  n_valid_examples = loss_dict['n_valid_examples']
  if USE_PYTORCH_DDP:
    # Use dist_nn.all_reduce to ensure correct loss and gradient scaling.
    summed_loss = dist_nn.all_reduce(summed_loss)
    n_valid_examples = dist_nn.all_reduce(n_valid_examples)
  loss = summed_loss / n_valid_examples

  loss.backward()

  if grad_clip is not None:
    torch.nn.utils.clip_grad_norm_(
      current_model.parameters(), max_norm=grad_clip
    )
  optimizer_state['optimizer'].step()
  optimizer_state['scheduler'].step()

  # Log training metrics - loss, grad_norm, batch_size.
  if global_step <= 100 or global_step % 500 == 0:
    with torch.no_grad():
      parameters = [p for p in current_model.parameters() if p.grad is not None]
      grad_norm = torch.norm(
        torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
      )
    if workload.metrics_logger is not None:
      workload.metrics_logger.append_scalar_metrics(
        {
          'loss': loss.item(),
          'grad_norm': grad_norm.item(),
        },
        global_step,
      )
    logging.info(
      '%d) loss = %0.3f, grad_norm = %0.3f',
      global_step,
      loss.item(),
      grad_norm.item(),
    )

  return (optimizer_state, current_param_container, new_model_state)
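
Editor's aside, not part of the committed file: under DDP both the summed loss and the number of valid examples are all-reduced before dividing, so every worker computes the true global mean rather than an average of per-worker means. A small worked example with two hypothetical workers:

# Worker A: summed_loss = 10.0 over 4 valid examples; worker B: 6.0 over 2.
# Global mean after all-reduce: (10.0 + 6.0) / (4 + 2) = 16.0 / 6 ≈ 2.667
# Naive average of per-worker means: (2.5 + 3.0) / 2 = 2.75, which is wrong
# whenever the workers hold different numbers of valid examples.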

def prepare_for_eval(
  workload: spec.Workload,
  current_param_container: spec.ParameterContainer,
  current_params_types: spec.ParameterTypeTree,
  model_state: spec.ModelAuxiliaryState,
  hyperparameters: spec.Hyperparameters,
  loss_type: spec.LossType,
  optimizer_state: spec.OptimizerState,
  eval_results: List[Tuple[int, float]],
  global_step: int,
  rng: spec.RandomState,
) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)


def get_batch_size(workload_name):
  # Return the global batch size.
  if hasattr(HPARAMS, 'batch_size'):
    return HPARAMS.batch_size
  if workload_name == 'criteo1tb':
    return 262_144
  elif workload_name == 'fastmri':
    return 32
  elif workload_name == 'imagenet_resnet':
    return 1024
  elif workload_name == 'imagenet_resnet_silu':
    return 512
  elif workload_name == 'imagenet_resnet_gelu':
    return 512
  elif workload_name == 'imagenet_vit':
    return 1024
  elif workload_name == 'librispeech_conformer':
    return 256
  elif workload_name == 'librispeech_deepspeech':
    return 256
  elif workload_name == 'ogbg':
    return 512
  elif workload_name == 'wmt':
    return 128
  elif workload_name == 'mnist':
    return 16
  else:
    raise ValueError(f'Unsupported workload name: {workload_name}.')


def data_selection(
  workload: spec.Workload,
  input_queue: Iterator[Dict[str, spec.Tensor]],
  optimizer_state: spec.OptimizerState,
  current_param_container: spec.ParameterContainer,
  model_state: spec.ModelAuxiliaryState,
  hyperparameters: spec.Hyperparameters,
  global_step: int,
  rng: spec.RandomState,
) -> Dict[str, spec.Tensor]:
  """Select data from the infinitely repeating, pre-shuffled input queue.

  Each element of the queue is a batch of training examples and labels.
  """
  del workload
  del optimizer_state
  del current_param_container
  del model_state
  del hyperparameters
  del global_step
  del rng
  batch = next(input_queue)
  return batch
