schedule free adamw jax training results doesn't match pytorch #17

@wyfEmma

Description

The JAX implementation of the schedule-free AdamW algorithm exhibits significant training curve discrepancies compared to the PyTorch reference across several workloads, including Librispeech DeepSpeech, ImageNet, WMT, and Criteo1TB.

Following an initial debugging session with @priyakasimbeg, a key issue was identified: the JAX code incorrectly used a single variable (y) for both training and validation phases. The intended logic requires using x for validation and y for training.
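For reference, the x/y split can be sketched with a minimal schedule-free SGD loop (the issue concerns the AdamW variant, but the roles of x and y are the same). The function name, parameters, and the plain-SGD base step below are illustrative assumptions, not the repository's code:

```python
def schedule_free_sgd(grad_fn, z0, lr=0.1, beta=0.9, steps=200):
    """Hypothetical sketch of the schedule-free iterate bookkeeping.

    z: base iterate, updated by the inner optimizer step
    x: running average of z -- the parameters to evaluate/validate with
    y: interpolation of z and x -- the parameters to take gradients at
    """
    z = x = z0
    for t in range(1, steps + 1):
        y = (1 - beta) * z + beta * x   # training point (gradients taken here)
        g = grad_fn(y)
        z = z - lr * g                  # base optimizer step on z
        c = 1.0 / t                     # uniform averaging weight
        x = (1 - c) * x + c * z         # x tracks the average of z
    return x, y                         # validate with x, train with y
```

Using a single variable for both phases, as the pre-fix JAX code did, amounts to evaluating validation metrics at y instead of x, which diverges from the PyTorch reference even when the update math is otherwise identical.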

A fix was issued in pull request #16. However, the training results for Librispeech still do not align with PyTorch, and other issues have emerged for WMT and Criteo1TB, specifically deadlocks and out-of-memory errors.

Further in-depth debugging is necessary to bring the JAX training results in line with PyTorch to finalize this MLCommons submission.
