Skip to content

[DRAFT] Modify CosWithWarmupAndLinearDecay tail-handoff and missing super().__post_init__()#676

Draft
AkshitaB wants to merge 4 commits into
mainfrom
akshitab/cos-with-warmup-and-linear-decay-fix
Draft

[DRAFT] Modify CosWithWarmupAndLinearDecay tail-handoff and missing super().__post_init__()#676
AkshitaB wants to merge 4 commits into
mainfrom
akshitab/cos-with-warmup-and-linear-decay-fix

Conversation

@AkshitaB
Copy link
Copy Markdown
Contributor

@AkshitaB AkshitaB commented May 12, 2026

Summary

TODO: unclear if we actually want this behavior.

Rewrites the tail-handoff semantics of CosWithWarmupAndLinearDecay and fixes a missing super().__post_init__() call. Original work by @Tianhua-Tao; this PR ports just the scheduler portion onto main (preserving authorship/date) and adds tests + a CHANGELOG entry.

What's changed

1. Tail starts at the cosine's terminal value, not its mid-flight value

Before: The cosine ran across the full t_max. At step t_max - decay, the linear tail attached at whatever value the cosine happened to produce there (call that cos(t_max - decay, t_max) — a value somewhere between peak and alpha_f * peak, depending on horizon and alpha_f). The tail then anneals from that mid-flight value to decay_min_lr.

After: The cosine completes (reaches alpha_f * peak) by step t_max - decay. Achieved by passing t_max - decay to super().get_lr(...) for the cosine region, and hard-coding the tail's starting value to initial_lr * alpha_f. The linear tail then anneals from alpha_f * peak down to decay_min_lr over the last decay steps.

Net effect: the class now matches the schedule its name promises — "warmup → cosine completing to alpha_f * peak by t_max - decay → linear tail to decay_min_lr over decay steps."

2. Missing super().__post_init__() call

CosWithWarmupAndLinearDecay.__post_init__ validated its own fields but never invoked super().__post_init__(). This meant the parent's deprecated-alias migration (warmup_stepswarmup) was skipped on the subclass, and the parent's warmup_fraction range check ran only via duplicated logic in the child. Now super().__post_init__() is called at the end.

Tests

src/test/optim/scheduler_test.py gains:

  • test_cos_with_warmup_and_linear_decay — verifies the three-phase schedule. The cosine region matches a bare CosWithWarmup evaluated against t_max = (full_t_max - decay). The tail starts at alpha_f * peak at exactly t_max - decay, anneals linearly to decay_min_lr by t_max.
  • test_cos_with_warmup_and_linear_decay_migrates_deprecated_warmup_steps — confirms the super().__post_init__() call now propagates the warmup_stepswarmup migration and emits the deprecation warning.

34 tests pass; make checks clean (isort, black, ruff, mypy across all 381 source files).

Commits

SHA Author Subject
0dd3f4c5 Tianhua Tao Rewrite CosWithWarmupAndLinearDecay tail semantics (the core change)
f824210d Akshita Bhagia Clean up leftover # ... reference comments and trailing whitespace
b7ab9a43 Akshita Bhagia Add tests
8834192c Akshita Bhagia Update CHANGELOG

Compatibility note

Existing callers of CosWithWarmupAndLinearDecay will see different LR values in the tail region. If you have runs in flight that depend on the old "tail attaches mid-cosine" behavior, this is a behavior change — but the class name already implies the new semantics, and the old behavior was buggy in the sense that the tail start was implicit/non-obvious.

Test plan

  • pytest src/test/optim/scheduler_test.py — 34 passed
  • make checks (isort + black + ruff + mypy) — clean
  • Confirm no existing training scripts on main depend on the old tail-attachment behavior of CosWithWarmupAndLinearDecay

🤖 Generated with Claude Code

TianhuaTao and others added 4 commits May 12, 2026 15:11
Three changes:

1. Adds the missing super().__post_init__() call. Without it, the
   deprecated 'warmup_steps' alias never gets migrated on the subclass,
   and the parent's warmup-fraction validations are skipped.

2. Hard-codes the linear-decay-tail starting LR to initial_lr * alpha_f
   (rather than computing it via super().get_lr(initial_lr, t_max-decay,
   t_max)). Combined with (3), this guarantees the tail starts from the
   value the cosine *fully completes at*, regardless of any
   floating-point drift at the boundary.

3. Passes t_max - decay to super().get_lr(...) for the main cosine
   region, so the cosine completes (reaches alpha_f * peak) by step
   t_max - decay rather than running to t_max. The linear tail then
   carries from alpha_f * peak down to decay_min_lr over the remaining
   'decay' steps.

Net effect: 'warmup -> cosine fully completing to alpha_f*peak by t_max
- decay -> linear tail to decay_min_lr over decay steps'. Before, the
cosine ran across the full t_max and the linear tail started from
wherever the cosine happened to be at t_max - decay — i.e. the tail
attached mid-flight, not at the cosine's terminal value.

Originally introduced in cb9d4a3 ("add upcycle convert and train code")
on akshitab/olmoe-dev-v2; only the scheduler.py portion is carried over
here.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Drops the leftover '# final_cosine_lr = super().get_lr(...)' and
'# return super().get_lr(...)' reference comments and a trailing-
whitespace line that came in with the rewrite, and expands the inline
comment to make the t_max-decay handoff semantics explicit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- test_cos_with_warmup_and_linear_decay: verifies the three-phase
  schedule end-to-end. The cosine region matches a bare CosWithWarmup
  evaluated against t_max = (full_t_max - decay). The tail starts at
  alpha_f * peak at exactly t_max - decay, anneals linearly to
  decay_min_lr by t_max.
- test_cos_with_warmup_and_linear_decay_migrates_deprecated_warmup_steps:
  confirms the new super().__post_init__() call propagates the deprecated
  'warmup_steps' alias migration into the subclass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@AkshitaB AkshitaB marked this pull request as draft May 12, 2026 22:27
@AkshitaB AkshitaB changed the title Fix CosWithWarmupAndLinearDecay tail-handoff and missing super().__post_init__() [DRAFT] Modify CosWithWarmupAndLinearDecay tail-handoff and missing super().__post_init__() May 12, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants