The JAX implementation of the schedule-free AdamW algorithm exhibits significant training curve discrepancies compared to the PyTorch reference across several workloads, including Librispeech DeepSpeech, ImageNet, WMT, and Criteo1TB.
Following an initial debugging session with @priyakasimbeg, a key issue was identified: the JAX code incorrectly used a single variable (y) for both training and validation phases. The intended logic requires using x for validation and y for training.
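For context, the x/y split is intrinsic to the schedule-free method (Defazio et al.): gradients are evaluated at the interpolated point y, while the averaged iterate x is what should be used for evaluation/validation. A minimal sketch of that bookkeeping, using plain SGD on a 1-D quadratic rather than the actual AdamW inner step, and with illustrative names (`schedule_free_sgd`, `grad_fn`) that are not from the codebase:

```python
def schedule_free_sgd(grad_fn, z0, lr=0.1, beta=0.9, steps=1000):
    """Illustrative schedule-free update loop (not the repo's implementation).

    z: base iterate updated by the gradient step.
    x: running average of z -- use this for validation/evaluation.
    y: interpolation of z and x -- gradients (training) are taken here.
    """
    z = z0
    x = z0
    for t in range(1, steps + 1):
        y = (1 - beta) * z + beta * x   # training point: gradient is taken at y
        z = z - lr * grad_fn(y)
        c = 1.0 / t                     # uniform averaging weight
        x = (1 - c) * x + c * z         # x tracks the average of the z iterates
    return x  # evaluate/validate at x, never at y

# Minimize f(w) = (w - 3)^2, so grad_fn(w) = 2 * (w - 3).
x_final = schedule_free_sgd(lambda w: 2.0 * (w - 3.0), z0=0.0)
```

Using y in place of x at evaluation time, as the pre-fix JAX code effectively did, reports metrics at the interpolated training point rather than the averaged iterate, which is one plausible source of the curve discrepancies.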
A fix was issued in pull request #16. However, the training results for Librispeech still do not align with PyTorch, and other issues have emerged for WMT and Criteo1TB, specifically deadlocks and out-of-memory errors.
Further in-depth debugging is necessary to bring the JAX training results in line with PyTorch to finalize this MLCommons submission.