[Feat] Add FP8 training support#758

Merged
rchardx merged 58 commits intomainfrom
sxj/fp8_train
Dec 31, 2025

Conversation

@fishcrap
Collaborator

@fishcrap fishcrap commented Dec 23, 2025

Description

This PR adds comprehensive FP8 (8-bit floating point) training support to AReaL, enabling memory-efficient training with low precision while maintaining training stability. The implementation includes:

  • FP8 quantization/dequantization utilities: New fp8_utils.py and fp8_kernels.py modules providing blockwise quantization support
  • CLI configuration: Extended TrainEngineConfig with FP8-related options (fp8 mode, recipe, parameter quantization, etc.)
  • Model loading/saving: Updated HuggingFace model loading and saving to handle FP8 weights with proper conversion between PyTorch FP8 and Transformer Engine FP8 formats
  • Megatron engine integration: Enhanced MegatronEngine to support FP8 training with proper configuration propagation
  • Comprehensive test suite: Added extensive tests for FP8 conversion, BF16 comparison, and gradient correctness

The implementation supports the blockwise scheme, with integration into Transformer Engine's FP8 infrastructure for efficient GEMM operations.
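For context, here is a minimal PyTorch sketch of the blockwise idea (function names and the 128x128 block size are illustrative assumptions, not the actual `fp8_utils.py` API; the real implementation relies on Triton kernels in `fp8_kernels.py`, and `torch.float8_e4m3fn` requires PyTorch >= 2.1):

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def blockwise_quantize(x: torch.Tensor, block: int = 128):
    """Quantize a 2D tensor to FP8 with one scale per (block x block) tile."""
    rows, cols = x.shape
    assert rows % block == 0 and cols % block == 0, "pad to a multiple of the block size first"
    tiles = x.reshape(rows // block, block, cols // block, block)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = FP8_E4M3_MAX / amax  # per-tile scale into the FP8 range
    q = (tiles * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    inv_scale = (amax / FP8_E4M3_MAX).squeeze(1).squeeze(-1)  # kept for dequantization
    return q.reshape(rows, cols), inv_scale

def blockwise_dequantize(q: torch.Tensor, inv_scale: torch.Tensor, block: int = 128):
    rows, cols = q.shape
    tiles = q.to(torch.float32).reshape(rows // block, block, cols // block, block)
    return (tiles * inv_scale[:, None, :, None]).reshape(rows, cols)
```

In the actual training path, the per-block scales travel alongside the FP8 data into Transformer Engine's GEMMs rather than being materialized as in the dequantize helper above.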

Related Issue

Fixes #(issue)

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update
  • Code refactoring (no functional changes)
  • Performance improvement
  • Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • I have run formatting tools (pre-commit or manual)
  • I have run relevant unit tests and they pass
  • I have added tests for new functionality
  • I have updated documentation if needed
  • My branch is up to date with main
  • This PR introduces breaking changes (if yes, fill out details below)
  • If this PR changes documentation, I have built and previewed it locally with jb build docs
  • No critical issues raised by AI reviewers (/gemini review)

Breaking Change Details (if applicable):

N/A - This is a new feature that adds optional FP8 support without breaking existing functionality.

Additional Context

Training Curve

  • reward (fp8 vs bf16): [training-curve image comparing FP8 and BF16 reward]

TODO:

  • Memory profiling
  • Training time reduction

@gemini-code-assist
Contributor

Summary of Changes

Hello @fishcrap, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly upgrades AReaL by integrating comprehensive FP8 training support. The primary goal is to enable memory-efficient training with reduced precision without compromising model stability. This is achieved through the introduction of new FP8 quantization and dequantization utilities, extensive configuration options via the CLI, and updates to model loading and saving processes to handle FP8 weights. The core MegatronEngine has been adapted to leverage these FP8 capabilities, and new tests ensure the reliability of these low-precision operations.

Highlights

  • Comprehensive FP8 Training Support: This PR introduces full 8-bit floating point (FP8) training capabilities to AReaL, enabling more memory-efficient training while striving to maintain training stability.
  • FP8 Quantization Utilities: New modules fp8_utils.py and fp8_kernels.py have been added, providing blockwise and per-tensor quantization and dequantization functionalities, including Triton-based kernels for efficient operations.
  • Extended CLI Configuration: The TrainEngineConfig and MegatronEngineConfig have been significantly extended with numerous FP8-related options, allowing users to configure FP8 mode, scaling recipes, parameter quantization, and other precision-related settings via the command-line interface.
  • Enhanced Model Loading and Saving: HuggingFace model loading and saving mechanisms have been updated to correctly handle FP8 weights, including proper conversion between PyTorch FP8 and Transformer Engine FP8 formats, and dequantization when necessary.
  • MegatronEngine Integration: The MegatronEngine has been enhanced to seamlessly support FP8 training, ensuring that FP8 configurations are correctly propagated and applied throughout the training process.
  • New Test Suite: A comprehensive test suite (test_fp8_conversion.py) has been added to verify the correctness of FP8 conversion, compare results with BF16 baselines, and ensure gradient accuracy.
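To illustrate the loading highlight above: FP8 HF checkpoints (DeepSeek-style) store each quantized weight together with a per-block inverse-scale tensor, and recovering a high-precision weight amounts to broadcasting those scales back over their tiles. A hedged sketch with hypothetical names (the actual code path converts into Transformer Engine's FP8 format instead of materializing BF16):

```python
import torch

def dequantize_hf_fp8_weight(weight: torch.Tensor,           # float8_e4m3fn, (rows, cols)
                             weight_scale_inv: torch.Tensor,  # float32, one entry per tile
                             block: int = 128) -> torch.Tensor:
    rows, cols = weight.shape
    # Broadcast each tile's inverse scale over its block x block region.
    scales = weight_scale_inv.repeat_interleave(block, dim=0)
    scales = scales.repeat_interleave(block, dim=1)[:rows, :cols]
    return (weight.to(torch.float32) * scales).to(torch.bfloat16)
```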


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces comprehensive FP8 training support, including new utilities for quantization/dequantization, CLI configurations, and updates to model loading/saving to handle FP8 weights. The changes are extensive and well-structured. I've identified a few areas with TODO or FIXME comments in the new code, particularly in tests and utility functions, that should be addressed to ensure correctness and clarity. The overall implementation seems robust, with good integration into the existing MegatronEngine and the addition of a comprehensive test suite.

Comment thread areal/models/mcore/hf_load.py Outdated
Comment thread areal/engine/megatron_engine.py
Comment thread areal/tests/test_fp8_conversion.py Outdated
Comment thread areal/utils/fp8_utils.py Outdated
Comment thread areal/utils/megatron.py Outdated
Comment thread areal/utils/mcore/pipeline_parallel.py
@fishcrap fishcrap changed the title Sxj/fp8 train [Feat] Add FP8 training support Dec 24, 2025
@fishcrap fishcrap marked this pull request as ready for review December 25, 2025 12:07
@rchardx
Collaborator

rchardx commented Dec 26, 2025

I think end-to-end training testcases should be added to areal/tests/grpo/ or areal/tests/sft/ through new yaml configurations and new test entries.
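A hypothetical shape for such a test entry (the launcher module, entry-point script, config path, and skip condition are assumptions about the AReaL test layout, not its actual harness):

```python
import subprocess

import pytest
import torch

requires_fp8_gpu = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9,
    reason="FP8 GEMMs need Hopper-class (sm90+) GPUs",
)

@requires_fp8_gpu
def test_sft_fp8_smoke():
    # Assumed entry point and config path; a real test would point at the
    # new YAML configuration added under areal/tests/sft/.
    result = subprocess.run(
        ["python", "-m", "areal.launcher.local", "examples/sft/train.py",
         "--config", "areal/tests/sft/fp8_smoke.yaml"],
        capture_output=True, text=True, timeout=1800,
    )
    assert result.returncode == 0, result.stderr
```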

@fishcrap fishcrap requested review from garrett4wade and removed request for garrett4wade December 26, 2025 05:36
Comment thread areal/models/mcore/hf_save.py
Comment thread areal/api/cli_args.py Outdated
Comment thread areal/utils/fp8_utils.py Outdated
Comment thread areal/tests/fp8/test_fp8_rmsnorm.py
@garrett4wade
Collaborator

> I think end-to-end training testcases should be added to areal/tests/grpo/ or areal/tests/sft/ through new yaml configurations and new test entries.

@rchardx It would be good, but the tests won't run on the CI A100 nodes. We can just run them offline.

Collaborator

@rchardx rchardx left a comment


The core functionality looks solid. Once these changes (including the other requested ones) are addressed, this PR should be ready to merge.

Comment thread areal/models/mcore/hf_load.py Outdated
Comment thread areal/engine/megatron_engine.py Outdated
Comment thread areal/engine/megatron_engine.py
Comment thread areal/engine/megatron_engine.py Outdated
Comment thread areal/utils/fp8_utils.py Outdated
Comment thread areal/utils/fp8_utils.py Outdated
Comment thread areal/api/cli_args.py Outdated
Comment thread areal/models/mcore/hf_load.py Outdated
Comment thread areal/engine/megatron_engine.py Outdated
Comment thread areal/api/cli_args.py Outdated
@rchardx
Collaborator

rchardx commented Dec 27, 2025

>> I think end-to-end training testcases should be added to areal/tests/grpo/ or areal/tests/sft/ through new yaml configurations and new test entries.
>
> @rchardx It would be good, but the tests won't run on the CI A100 nodes. We can just run them offline.

Agreed.

Comment thread areal/engine/megatron_engine.py Outdated
Comment thread areal/engine/megatron_engine.py Outdated
Comment thread areal/utils/megatron.py Outdated
Comment thread areal/engine/megatron_engine.py Outdated
fishcrap and others added 5 commits December 30, 2025 11:53
- Move fp8 utilities to areal/utils/fp8/ with clearer module separation
- Implement UE8M0 quantization locally, eliminating sglang import
- Extract common utils: areal/utils/math.py, areal/utils/cuda.py
- Improve constants.py organization and naming
- Clarify high_precision_init_val comment for FP8 HF model loading

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
- Use lazy initialization for DeepGEMM detection to avoid import-time
  CUDA access failures on CPU-only environments
- Add informative error message for UE8M0 block size assertion
- Document FP8 E4M3 max value (448.0) in quantization code
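
The lazy-detection pattern described in the commit above might look like this (illustrative names; the actual helper lives somewhere under `areal/utils/`):

```python
import functools

@functools.lru_cache(maxsize=1)
def deep_gemm_available() -> bool:
    """Probe CUDA and DeepGEMM on first call rather than at import time,
    so CPU-only environments can still import the module."""
    import torch
    if not torch.cuda.is_available():
        return False
    try:
        import deep_gemm  # noqa: F401  (optional dependency)
    except ImportError:
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 9  # DeepGEMM targets Hopper-class GPUs
```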
Collaborator

@rchardx rchardx left a comment


LGTM!

@rchardx rchardx merged commit 89dda13 into main Dec 31, 2025
1 check passed
@rchardx rchardx deleted the sxj/fp8_train branch December 31, 2025 06:33
leandermaben pushed a commit to leandermaben/AReaL that referenced this pull request Mar 24, 2026
SathyaGnanakumar pushed a commit to danielkiely/AReaL that referenced this pull request Apr 29, 2026