
Commit 698d324

Merge pull request #19 from mlcommons/fineweb_edu
Add fineweb edu support for existing submissions
2 parents 24b369e + 859088f commit 698d324

6 files changed

Lines changed: 292 additions & 0 deletions


submissions/self_tuning/ademamix/submission.py

Lines changed: 2 additions & 0 deletions
@@ -352,6 +352,8 @@ def get_batch_size(workload_name):
     return 128
   elif workload_name == 'mnist':
     return 16
+  elif workload_name == 'finewebedu_lm':
+    return 64
   else:
     raise ValueError(f'Unsupported workload name: {workload_name}.')

submissions/self_tuning/lion/submission.py

Lines changed: 2 additions & 0 deletions
@@ -274,6 +274,8 @@ def get_batch_size(workload_name):
     return 128
   elif workload_name == 'mnist':
     return 16
+  elif workload_name == 'finewebedu_lm':
+    return 64
   else:
     raise ValueError(f'Unsupported workload name: {workload_name}.')

submissions/self_tuning/schedule_free_adamw_v2/__init__.py

Whitespace-only changes.

submissions/self_tuning/schedule_free_adamw_v2/requirements_all.txt

Whitespace-only changes.

submissions/self_tuning/schedule_free_adamw_v2/submission.py

Lines changed: 278 additions & 0 deletions
@@ -0,0 +1,278 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Dict, Iterator, List, Tuple

from absl import logging
import torch
import torch.distributed.nn as dist_nn
import torch.distributed as dist
import pickle

from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup

USE_PYTORCH_DDP = pytorch_setup()[0]

HPARAMS = {
    "learning_rate": 0.0025,
    "one_minus_beta1": 0.1,
    "beta2": 0.9955159689799007,
    "weight_decay": 0.08121616522670176,
    "warmup_factor": 0.02,
    "weight_lr_power": 2,
    "label_smoothing": 0.2,
}


class AdamWScheduleFreeV2(torch.optim.Optimizer):
  r"""Schedule-Free AdamW.

  This version differs from the earlier AlgoPerf-submitted Schedule-Free
  in the following ways:
  - It more closely follows the published version of the method; the
    AlgoPerf submission was based on an early development version.
  - Weight decay is applied to the y sequence instead of the z sequence,
    following the published version. This doesn't improve the results but
    is kept for consistency.
  - The r=0.5 weighting sequence is no longer used. This simplifies the
    implementation without any performance hit.

  The following change was also made to the outer loop:
  - Batch-norm buffers are now correctly updated, greatly improving the
    performance on the ResNet and DeepSpeech workloads.
  """

  def __init__(self,
               params,
               lr=1e-3,
               betas=(0.9, 0.999),
               eps=1e-8,
               weight_decay=0,
               weight_lr_power=2,
               warmup_steps=0):
    defaults = dict(
        lr=lr,
        betas=betas,
        eps=eps,
        k=0,
        weight_sum=0.0,
        warmup_steps=warmup_steps,
        weight_lr_power=weight_lr_power,
        weight_decay=weight_decay)

    super().__init__(params, defaults)

  def step(self, closure):
    """Performs a single optimization step.

    Arguments:
      closure (callable): A closure that reevaluates the model
        and returns the loss.
    """
    # Swap to the extrapolated point:
    for group in self.param_groups:
      beta1, beta2 = group['betas']
      k = group['k']

      for p in group['params']:
        # State initialization.
        state = self.state[p]
        if 'z' not in state:
          state['z'] = torch.clone(p.data)
          state['exp_avg_sq'] = torch.zeros_like(p.data)

        z = state['z']

        # Extrapolate (x -> y).
        p.data.lerp_(end=z, weight=1 - beta1)

    # Evaluate the gradient at the extrapolated point.
    loss = closure()

    for group in self.param_groups:
      eps = group['eps']
      k = group['k']
      warmup_steps = group['warmup_steps']
      decay = group['weight_decay']
      beta1, beta2 = group['betas']
      weight_lr_power = group['weight_lr_power']

      if k < warmup_steps:
        sched = (k + 1) / warmup_steps
      else:
        sched = 1.0
      annealed_lr = group['lr'] * sched

      lr = max(annealed_lr, eps)

      weight = lr**weight_lr_power
      weight_sum = group['weight_sum'] = group['weight_sum'] + weight

      ckp1 = weight / weight_sum

      bias_correction2 = 1 - beta2**(k + 1)
      step_size = lr * math.sqrt(bias_correction2)

      for p in group['params']:
        if p.grad is None:
          continue
        grad = p.grad.data

        state = self.state[p]

        exp_avg_sq = state['exp_avg_sq']
        z = state['z']

        y = p.data.clone()

        # Unextrapolate (y -> x).
        p.data.lerp_(end=z, weight=1 - 1 / beta1)

        # Decay at y (differs from V1).
        z.sub_(y, alpha=step_size * decay)
        del y

        # Decay the second moment running average coefficient.
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        # Bias correction is folded into the step size (differs from V1).
        denom = exp_avg_sq.sqrt().add_(eps)

        # Take a step on z.
        z.addcdiv_(grad, denom, value=-step_size)

        # Take a step on x.
        p.data.lerp_(end=z, weight=ckp1)

      group['k'] = k + 1
    return loss


def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  del model_state

  optimizer = AdamWScheduleFreeV2(
      model_params.parameters(),
      lr=HPARAMS['learning_rate'],
      betas=(1.0 - HPARAMS['one_minus_beta1'], HPARAMS['beta2']),
      warmup_steps=int(HPARAMS['warmup_factor'] * workload.step_hint * 0.75),
      weight_decay=HPARAMS['weight_decay'],
      weight_lr_power=HPARAMS['weight_lr_power'])

  optimizer_state = {'optimizer': optimizer}
  return optimizer_state


def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor],
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del loss_type
  del hyperparameters

  # Detect whether the model contains batch norm.
  current_model = current_param_container
  contains_bn = any(hasattr(m, "running_mean") for m in current_model.modules())

  if contains_bn and global_step % 3 == 0:
    # Update batch-norm statistics at the eval point x, not y.
    with torch.no_grad():
      _, new_model_state = workload.model_fn(
          params=current_model,
          augmented_and_preprocessed_input_batch=batch,
          model_state=model_state,
          mode=spec.ForwardPassMode.TRAIN,
          rng=rng,
          update_batch_norm=True)
    model_state = new_model_state

  new_model_state = None

  def closure():
    nonlocal new_model_state
    optimizer_state['optimizer'].zero_grad()

    logits_batch, new_model_state = workload.model_fn(
        params=current_model,
        augmented_and_preprocessed_input_batch=batch,
        model_state=model_state,
        mode=spec.ForwardPassMode.TRAIN,
        rng=rng,
        update_batch_norm=False)

    loss_dict = workload.loss_fn(
        label_batch=batch['targets'],
        logits_batch=logits_batch,
        mask_batch=batch.get('weights'),
        label_smoothing=HPARAMS['label_smoothing'])
    summed_loss = loss_dict['summed']
    n_valid_examples = loss_dict['n_valid_examples']
    if USE_PYTORCH_DDP:
      # Use dist_nn.all_reduce to ensure correct loss and gradient scaling.
      summed_loss = dist_nn.all_reduce(summed_loss)
      n_valid_examples = dist_nn.all_reduce(n_valid_examples)
    loss = summed_loss / n_valid_examples

    loss.backward()
    return loss

  optimizer_state['optimizer'].step(closure)

  return (optimizer_state, current_param_container, new_model_state)


def get_batch_size(workload_name):
  # Return the global batch size.
  if workload_name == 'criteo1tb':
    return 262_144
  elif workload_name == 'fastmri':
    return 16
  elif workload_name == 'imagenet_resnet':
    return 1024
  elif workload_name == 'imagenet_vit':
    return 1024
  elif workload_name == 'librispeech_conformer':
    return 224
  elif workload_name == 'librispeech_deepspeech':
    return 128
  elif workload_name == 'ogbg':
    return 512
  elif workload_name == 'wmt':
    return 128
  elif workload_name == 'mnist':
    return 16
  elif workload_name == 'imagenet_resnet_gelu':
    return 512
  elif workload_name == 'imagenet_resnet_silu':
    return 512
  elif workload_name == 'finewebedu_lm':
    return 64
  else:
    raise ValueError(f'Unsupported workload name: {workload_name}.')


def data_selection(workload: spec.Workload,
                   input_queue: Iterator[Dict[str, spec.Tensor]],
                   optimizer_state: spec.OptimizerState,
                   current_param_container: spec.ParameterContainer,
                   model_state: spec.ModelAuxiliaryState,
                   hyperparameters: spec.Hyperparameters,
                   global_step: int,
                   rng: spec.RandomState) -> Dict[str, spec.Tensor]:
  """Select data from the infinitely repeating, pre-shuffled input queue.

  Each element of the queue is a batch of training examples and labels.
  """
  del workload
  del optimizer_state
  del current_param_container
  del model_state
  del hyperparameters
  del global_step
  del rng
  batch = next(input_queue)
  return batch
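
Note (not part of the diff): step() above requires a gradient closure because the gradient is evaluated at the extrapolated point y rather than at the averaged iterate x, which is exactly what update_params arranges via workload.model_fn. A minimal standalone sketch of that calling convention, using a toy model and random data as placeholder assumptions and a hypothetical import path for this file:

    import torch

    from submission import AdamWScheduleFreeV2  # hypothetical import path

    model = torch.nn.Linear(10, 1)
    opt = AdamWScheduleFreeV2(model.parameters(), lr=2.5e-3, warmup_steps=100)
    inputs, targets = torch.randn(32, 10), torch.randn(32, 1)

    def closure():
      # The forward/backward pass must run inside the closure: step() first
      # moves the parameters to the extrapolated point y, then calls closure()
      # to obtain gradients there, then takes the z step and the averaging
      # step on x.
      opt.zero_grad()
      loss = torch.nn.functional.mse_loss(model(inputs), targets)
      loss.backward()
      return loss

    for _ in range(100):
      opt.step(closure)

After step() returns, the parameters hold the averaged x sequence, so the model can be evaluated directly; this is why update_params returns current_param_container unchanged and why the batch-norm statistics are refreshed at x.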
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
submission_name: Schedule-Free AdamW
submission_folder: submissions/self_tuning/schedule_free_adamw_v2
authors: >-
  Alice Yang, Aaron Defazio, Konstantin Mishchenko
affiliations: Meta AI, Samsung AI
version: "2.0"
ruleset: self-tuning
framework: PyTorch
description: >-
  A self-tuning version of Schedule-Free AdamW ([Defazio et al., 2024](https://openreview.net/forum?id=0XeNkkENuI))
  using a single hyperparameter configuration. Version 2.0 incorporates the
  batch-norm fixes made after the competition deadline and follows the
  published version of Schedule-Free in a few small details for consistency.
