
Gradient Accumulation Results #1567

@AakashKumarNain

Description

I am noticing a very weird pattern with gradient accumulation on my end. Without gradient accumulation, training proceeds smoothly, but when gradient accumulation is enabled, the optimization gets stuck at a loss value that never improves beyond that point. I am not sure what the underlying issue could be, but here is some information about the training run:

  1. The model is instantiated with the same seed in both cases.
  2. The dataloader is deterministic and produces the same sequences in both cases (a quick way to verify points 1 and 2 is sketched after this list).
  3. The per-layer learning rates and per-layer weight decay values are exactly the same.
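
To double-check points 1 and 2, the pytrees from the two runs can be compared leaf by leaf. A minimal sketch (trees_identical, params_a, and params_b are illustrative names, not from my actual training script):

import jax
import jax.numpy as jnp

def trees_identical(params_a, params_b) -> bool:
    # Bitwise-compare two parameter pytrees leaf by leaf.
    leaves_a = jax.tree_util.tree_leaves(params_a)
    leaves_b = jax.tree_util.tree_leaves(params_b)
    return len(leaves_a) == len(leaves_b) and all(
        a.shape == b.shape and bool(jnp.all(a == b))
        for a, b in zip(leaves_a, leaves_b)
    )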

Here is the optimizer definition:

import jax
import jax.numpy as jnp
import optax
from functools import partial
from jax.tree_util import DictKey, GetAttrKey, SequenceKey


def make_adamw_layerwise_like_nanochat(
    params,
    *,
    d_model: int,
    other_peak_lr: float,
    other_min_lr: float,
    total_train_steps: int,
    warmup_steps: int = 30,
    b1: float = 0.9,
    b2: float = 0.95,
    embedding_lr: float = 3e-3,
    unembedding_lr: float = 3e-4,
):
    """
    Parameter groups:
      - params.embed      -> AdamW with nanochat embedding LR (scaled by (d_model/768)^-0.5)
      - params.lm_head    -> AdamW with nanochat unembedding LR (scaled by (d_model/768)^-0.5)
      - everything else   -> AdamW with warmup+cosine schedule
    """
    # Group LR multipliers. Note: the (d_model / 768)**-0.5 scaling mentioned
    # in the docstring is not applied here, so d_model is currently unused.
    emb_lr = embedding_lr * other_peak_lr
    unemb_lr = unembedding_lr * other_peak_lr

    schedules = {
        "embed": optax.constant_schedule(emb_lr),
        "lm_head": optax.constant_schedule(unemb_lr),
        "other": optax.warmup_cosine_decay_schedule(
            init_value=other_min_lr,
            peak_value=other_peak_lr,
            warmup_steps=warmup_steps,
            decay_steps=max(1, total_train_steps - warmup_steps),
            end_value=other_min_lr,
        ),
    }

    def _path_names(path):
        # Flatten a JAX key path into a list of plain string names.
        out = []
        for k in path:
            if isinstance(k, GetAttrKey):
                out.append(k.name)
            elif isinstance(k, SequenceKey):
                out.append(str(k.idx))
            elif isinstance(k, DictKey):
                out.append(str(k.key))
            else:
                out.append(str(k))
        return out

    def label_fn(path, leaf):
        # Top-level fields in the GPT pytree: embed, blocks, lm_head
        names = _path_names(path)
        top = names[0] if names else ""
        if top == "embed":
            return "embed"
        if top == "lm_head":
            return "lm_head"
        return "other"

    param_labels = jax.tree_util.tree_map_with_path(label_fn, params)

    def make_adamw(lr_schedule, weight_decay=0.0):
        return optax.adamw(
            learning_rate=lr_schedule,
            b1=b1,
            b2=b2,
            weight_decay=weight_decay,
            mu_dtype=jnp.float32
        )

    # Debug helper (currently unused below): logs each group's step count and LR.
    def log_step_and_lr(label, schedule):
        def init_fn(params):
            return {"count": jnp.array(0, dtype=jnp.int32)}
        def update_fn(updates, state, params=None):
            step = state["count"]
            lr = schedule(step)
            jax.debug.print("[{label}] step {step} lr {lr}", label=label, step=step, lr=lr)
            return updates, {"count": step + 1}
        return optax.GradientTransformation(init_fn, update_fn)

    tx = optax.multi_transform(
        {
            "embed": make_adamw(schedules["embed"]),
            "lm_head": make_adamw(schedules["lm_head"]),
            "other": make_adamw(schedules["other"], weight_decay=0.001),
        },
        param_labels,
    )
    return tx
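
For context, the optimizer is constructed with a call along these lines (the values shown are illustrative placeholders, not my exact config):

optim = make_adamw_layerwise_like_nanochat(
    model,
    d_model=768,
    other_peak_lr=6e-4,
    other_min_lr=6e-5,
    total_train_steps=total_train_steps,
    warmup_steps=30,
)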

Initialization of the optimizer with and without gradient accumulation:

# When gradient accumulation is enabled (default grad_accum_steps=4),
# the base optimizer is wrapped in MultiSteps; otherwise it is used as-is.
optim = optax.MultiSteps(optim, every_k_schedule=grad_accum_steps)
optim_state = optim.init(model)
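
As I understand the optax docs, MultiSteps keeps a mini_step counter plus the accumulated gradients in its state, and only steps the inner optimizer on every k-th update call. That makes the state easy to inspect while debugging (a small sketch under that assumption):

# mini_step cycles through 0..grad_accum_steps-1 within one accumulation window;
# gradient_step counts completed inner-optimizer steps.
print(optim_state.mini_step, optim_state.gradient_step)
print(optim.has_updated(optim_state))  # True right after an inner step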

Here is how the optimizer state is updated inside the training loop:

@partial(jax.jit, static_argnames=("optim",), donate_argnums=(0, 4))
def train_step_accum(params, x_batch, y_batch, freqs, optim_state, optim):
    (loss, logits), grads = jax.value_and_grad(compute_loss, has_aux=True)(
        params, x_batch, y_batch, freqs
    )
    # On accumulation micro-steps the updates are discarded; only the
    # optimizer state (the accumulated gradients) advances.
    _, optim_state = optim.update(grads, optim_state, params)
    return params, loss, optim_state


@partial(jax.jit, static_argnames=("optim",), donate_argnums=(0, 4))
def train_step(params, x_batch, y_batch, freqs, optim_state, optim):
    (loss, logits), grads = jax.value_and_grad(compute_loss, has_aux=True)(
        params, x_batch, y_batch, freqs
    )
    updates, optim_state = optim.update(grads, optim_state, params)
    updated_params = optax.apply_updates(params, updates)
    return updated_params, loss, optim_state
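
Since MultiSteps.update returns all-zero updates on the accumulation micro-steps, I believe the two functions above could collapse into a single step that always applies the updates; a minimal sketch (untested in my setup):

@partial(jax.jit, static_argnames=("optim",), donate_argnums=(0, 4))
def train_step_unified(params, x_batch, y_batch, freqs, optim_state, optim):
    (loss, logits), grads = jax.value_and_grad(compute_loss, has_aux=True)(
        params, x_batch, y_batch, freqs
    )
    # On micro-steps 0..k-2, MultiSteps emits zero updates, so apply_updates
    # is a no-op; the averaged update arrives on the k-th call.
    updates, optim_state = optim.update(grads, optim_state, params)
    params = optax.apply_updates(params, updates)
    return params, loss, optim_state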


for step in range(total_train_steps):
    ...
    for micro_step in range(grad_accum_steps):
        x, y = get_next_batch(starts, ends, bsz, seqlen, tokens, data_sharding)
        if micro_step < grad_accum_steps - 1:
            model, loss, optim_state = train_step_accum(model, x, y, freqs, optim_state, optim)
            train_step_loss += loss
        else:
            model, loss, optim_state = train_step(model, x, y, freqs, optim_state, optim)
            train_step_loss += loss
    avg_train_loss = train_step_loss / grad_accum_steps
    ...
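
One way to narrow this down is to check, outside the training loop, that the mean of the micro-batch gradients matches the gradient over the full batch (this assumes compute_loss averages over the batch and the micro-batches are equal-sized; xs and ys holding grad_accum_steps micro-batches are hypothetical names):

grad_fn = jax.grad(lambda p, x, y: compute_loss(p, x, y, freqs)[0])

# Mean of per-micro-batch gradients: this is what MultiSteps accumulates...
micro_grads = [grad_fn(model, x, y) for x, y in zip(xs, ys)]
mean_grads = jax.tree_util.tree_map(lambda *g: sum(g) / len(g), *micro_grads)

# ...versus the gradient over the concatenated batch.
full_grads = grad_fn(model, jnp.concatenate(xs), jnp.concatenate(ys))
diffs = jax.tree_util.tree_map(lambda a, b: jnp.max(jnp.abs(a - b)),
                               mean_grads, full_grads)
print("max grad diff:", max(float(d) for d in jax.tree_util.tree_leaves(diffs)))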

Can someone please explain the reason behind this behavior? I would have expected more stable training and faster convergence with gradient accumulation, but right now it is completely broken.
