I am noticing a very strange pattern with gradient accumulation on my end. Without gradient accumulation, training proceeds smoothly, but when gradient accumulation is enabled, the optimization gets stuck at a loss value that never improves afterwards. I am not sure what the underlying issue is, but here are some details about the training run:
- The model is instantiated with the same seed in both cases (see the quick check sketched right after this list).
- The dataloader is deterministic and produces the same sequences in both cases.
- The per-layer learning rates and per-layer weight decay are exactly the same in both cases.
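For the first point, I compare the two freshly initialized parameter pytrees roughly like this (a minimal sketch; `params_no_accum` and `params_accum` are placeholder names for the two initializations):

```python
import jax
import jax.numpy as jnp

# Sketch: confirm both runs start from identical initial parameters.
def params_match(params_no_accum, params_accum) -> bool:
    same = jax.tree_util.tree_map(
        lambda a, b: bool(jnp.array_equal(a, b)), params_no_accum, params_accum
    )
    return jax.tree_util.tree_all(same)
```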
Here is the optimizer definition:
```python
import jax
import jax.numpy as jnp
import optax
from jax.tree_util import DictKey, GetAttrKey, SequenceKey


def make_adamw_layerwise_like_nanochat(
    params,
    *,
    d_model: int,
    other_peak_lr: float,
    other_min_lr: float,
    total_train_steps: int,
    warmup_steps: int = 30,
    b1: float = 0.9,
    b2: float = 0.95,
    embedding_lr: float = 3e-3,
    unembedding_lr: float = 3e-4,
):
    """
    Parameter groups:
      - params.embed    -> AdamW with nanochat embedding LR (scaled by (d_model/768)^-0.5)
      - params.lm_head  -> AdamW with nanochat unembedding LR (scaled by (d_model/768)^-0.5)
      - everything else -> AdamW with warmup+cosine schedule
    """
    emb_lr = embedding_lr * other_peak_lr
    unemb_lr = 1.0 * other_peak_lr

    schedules = {
        "embed": optax.constant_schedule(emb_lr),
        "lm_head": optax.constant_schedule(unemb_lr),
        "other": optax.warmup_cosine_decay_schedule(
            init_value=other_min_lr,
            peak_value=other_peak_lr,
            warmup_steps=warmup_steps,
            decay_steps=max(1, total_train_steps - warmup_steps),
            end_value=other_min_lr,
        ),
    }

    def _path_names(path):
        out = []
        for k in path:
            if isinstance(k, GetAttrKey):
                out.append(k.name)
            elif isinstance(k, SequenceKey):
                out.append(str(k.idx))
            elif isinstance(k, DictKey):
                out.append(str(k.key))
            else:
                out.append(str(k))
        return out

    def label_fn(path, leaf):
        # Top-level fields in your GPT pytree: embed, blocks, lm_head
        names = _path_names(path)
        top = names[0] if names else ""
        if top == "embed":
            return "embed"
        if top == "lm_head":
            return "lm_head"
        return "other"

    param_labels = jax.tree_util.tree_map_with_path(label_fn, params)

    def make_adamw(lr_schedule, weight_decay=0.0):
        return optax.adamw(
            learning_rate=lr_schedule,
            b1=b1,
            b2=b2,
            weight_decay=weight_decay,
            mu_dtype=jnp.float32,
        )

    def log_step_and_lr(label, schedule):
        def init_fn(params):
            return {"count": jnp.array(0, dtype=jnp.int32)}

        def update_fn(updates, state, params=None):
            step = state["count"]
            lr = schedule(step)
            jax.debug.print("[{label}] step {step} lr {lr}", label=label, step=step, lr=lr)
            return updates, {"count": step + 1}

        return optax.GradientTransformation(init_fn, update_fn)

    tx = optax.multi_transform(
        {
            "embed": make_adamw(schedules["embed"]),
            "lm_head": make_adamw(schedules["lm_head"]),
            "other": make_adamw(schedules["other"], weight_decay=0.001),
        },
        param_labels,
    )
    return tx
```
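For completeness, the optimizer itself is constructed roughly like this (the concrete values below are placeholders, not my exact hyperparameters):

```python
# Illustrative construction call; the values are placeholders.
optim = make_adamw_layerwise_like_nanochat(
    model,
    d_model=768,
    other_peak_lr=6e-4,
    other_min_lr=6e-5,
    total_train_steps=total_train_steps,
    warmup_steps=30,
)
```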
Initialization of the optimizer with/without gradient accumulation:

```python
# When gradient accumulation is enabled (default grad_accum_steps = 4),
# the optimizer is wrapped in optax.MultiSteps before init.
optim = optax.MultiSteps(optim, every_k_schedule=grad_accum_steps)
optim_state = optim.init(model)
```
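If it helps, this is how I have been inspecting the accumulation counters between steps (assuming the `MultiStepsState` fields `mini_step` and `gradient_step` from the optax docs, where `mini_step` counts micro-steps inside the current window and `gradient_step` counts updates actually applied):

```python
# Sanity check on the MultiSteps counters (field names assumed from optax's
# MultiStepsState): mini_step should cycle through 0..grad_accum_steps-1,
# gradient_step should increase by 1 per outer training step.
print("mini_step:", int(optim_state.mini_step))
print("gradient_step:", int(optim_state.gradient_step))
```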
Here is how the optimizer state is updated inside the training loop:
```python
from functools import partial


@partial(jax.jit, static_argnames=("optim",), donate_argnums=(0, 4))
def train_step_accum(params, x_batch, y_batch, freqs, optim_state, optim):
    # Used for all but the last micro-step: the optimizer state is advanced,
    # the returned updates are discarded, and params are returned unchanged.
    (loss, logits), grads = jax.value_and_grad(compute_loss, has_aux=True)(
        params, x_batch, y_batch, freqs
    )
    _, optim_state = optim.update(grads, optim_state, params)
    return params, loss, optim_state


@partial(jax.jit, static_argnames=("optim",), donate_argnums=(0, 4))
def train_step(params, x_batch, y_batch, freqs, optim_state, optim):
    # Used for the last micro-step (and when accumulation is off):
    # the returned updates are applied to the params.
    (loss, logits), grads = jax.value_and_grad(compute_loss, has_aux=True)(
        params, x_batch, y_batch, freqs
    )
    updates, optim_state = optim.update(grads, optim_state, params)
    updated_params = optax.apply_updates(params, updates)
    return updated_params, loss, optim_state


for step in range(total_train_steps):
    ...
    for micro_step in range(grad_accum_steps):
        x, y = get_next_batch(starts, ends, bsz, seqlen, tokens, data_sharding)
        if micro_step < grad_accum_steps - 1:
            model, loss, optim_state = train_step_accum(model, x, y, freqs, optim_state, optim)
            train_step_loss += loss
        else:
            model, loss, optim_state = train_step(model, x, y, freqs, optim_state, optim)
            train_step_loss += loss
    avg_train_loss = train_step_loss / grad_accum_steps
    ...
```
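For reference, my understanding from the optax docs is that `optax.MultiSteps` returns all-zero updates on the intermediate micro-steps, so the same `train_step` could be applied on every micro-step instead of splitting the path into `train_step_accum` / `train_step`. A minimal sketch of that single-path pattern (not what I currently run):

```python
# Sketch of the single-path pattern: call the same train_step on every
# micro-step and rely on MultiSteps emitting zero updates until the k-th call.
for step in range(total_train_steps):
    train_step_loss = 0.0
    for micro_step in range(grad_accum_steps):
        x, y = get_next_batch(starts, ends, bsz, seqlen, tokens, data_sharding)
        model, loss, optim_state = train_step(model, x, y, freqs, optim_state, optim)
        train_step_loss += loss
    avg_train_loss = train_step_loss / grad_accum_steps
```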
Can someone please explain the reason behind this behavior? I would have expected more stable training and faster convergence with gradient accumulation, but right now it is completely broken.