3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -26,5 +26,6 @@
"editor.rulers": [
120
],
"autoDocstring.docstringFormat": "google-notypes"
"autoDocstring.docstringFormat": "google-notypes",
"search.exclude": { "**/logs/**": true },
}
34 changes: 29 additions & 5 deletions bionemo-recipes/recipes/esm2_native_te/.dockerignore
@@ -1,10 +1,34 @@
# Docker
Dockerfile
Dockerfile.*
.dockerignore

# Docs
README.md

# Python caches
__pycache__
.pytest_cache
.ruff_cache
.venv/

# Linting
.ruff.toml

# Profiling & debugging artifacts
memory_snapshots/
nsight_profiling/
*.nsys-rep
*.sqlite
logs/
wandb/

# Hydra / training outputs
outputs/
checkpoints/

# Checkpoint export
checkpoint_export/

# Temp / scratch
j/
2 changes: 1 addition & 1 deletion bionemo-recipes/recipes/esm2_native_te/Dockerfile
@@ -1,5 +1,5 @@
# syntax=docker/dockerfile:1.4
FROM nvcr.io/nvidia/pytorch:25.11-py3

RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=requirements.txt,target=/requirements.txt \
104 changes: 84 additions & 20 deletions bionemo-recipes/recipes/esm2_native_te/README.md
@@ -1,7 +1,8 @@
# TransformerEngine-accelerated ESM-2 training with native PyTorch training loop

This folder demonstrates how to train TE-accelerated ESM-2 with a native PyTorch training loop, including sequence
packing, FP8/MXFP8/NVFP4 precision with layer-wise control, using fully sharded data parallel (FSDP) for distributed
training.

## How to use this recipe

@@ -15,17 +16,18 @@ bionemo-framework repository. You can download a zipped directory of this folder

## Supported Models and Training Features

| Model | BF16 | FP8<sup>[1]</sup> | MXFP8<sup>[2]</sup> | NVFP4<sup>[3]</sup> | THD Input Format | Context Parallelism |
| ----------------------------------------- | ---- | ----------------- | ------------------- | ------------------- | ---------------- | ------------------- |
| [ESM-2](../../models/esm2/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [AMPLIFY](../../models/amplify/README.md) | ✅ | ❌ | ❌ | ❌ | 🚧 | 🚧 |

✅: Supported <br/>
🚧: Under development <br/>
❌: Not supported <br/>

\[1\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 9.0 and above (Hopper+) <br/>
\[2\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 10.0 or 10.3 (Blackwell); 12.0 support pending <br/>
\[3\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 10.0 and above (Blackwell+) <br/>
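
The compute-capability requirements in the footnotes above can be summarized with a small helper. This is an illustrative sketch only; the function name and its return values are not part of this recipe:

```python
def supported_low_precision_formats(major: int, minor: int) -> set[str]:
    """Map a GPU compute capability to the low-precision formats listed above."""
    formats: set[str] = set()
    if (major, minor) >= (9, 0):  # Hopper and newer
        formats.add("FP8")
    if (major, minor) in {(10, 0), (10, 3)}:  # Blackwell only; sm_120 support pending
        formats.add("MXFP8")
    if (major, minor) >= (10, 0):  # Blackwell and newer
        formats.add("NVFP4")
    return formats

print(supported_low_precision_formats(9, 0))  # Hopper H100 supports FP8 only
```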

### Installing Dependencies

@@ -72,6 +74,35 @@ Recently, we measured 2800 tokens/second/GPU training speed on H100 with Hugging
of THD sequence packing; however, we have not been able to make this configuration work on Blackwell, and this work is
still in progress.

### Low Precision Performance Benchmarks

![Performance Benchmarks Low Precision](../../../docs/docs/assets/images/esm2/esm2_low_precision/esm2_8gpu_tflops.png)

The plot above shows that the benefits of low-precision training grow with model scale. For smaller models (such as
650M or 3B parameters), the cost of quantizing activations from high precision to low precision outweighs the savings
from performing matrix multiplications in low precision. As parameter count grows, this fixed quantization cost becomes
small relative to the GEMM savings.

Note: these plots were generated with the [fsdp2](./train_fsdp2.py) script.
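
The scaling argument above can be illustrated with a toy back-of-the-envelope model. This is not a measurement, and the constants are simplified assumptions (one square matmul per token, one pass over activations to quantize):

```python
def relative_quant_overhead(hidden_size: int, tokens: int = 4096) -> float:
    """Toy ratio of activation-quantization work to GEMM work for one linear layer."""
    gemm_flops = 2 * tokens * hidden_size * hidden_size  # d x d matmul per token
    quant_ops = tokens * hidden_size                     # one pass over the activations
    return quant_ops / gemm_flops

# The relative overhead shrinks as hidden size grows, so larger models keep more
# of the low-precision GEMM speedup.
for d in (1280, 2560, 5120):
    print(d, relative_quant_overhead(d))
```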


### Convergence Results for Low Precision Training

#### MXFP8

![Convergence Benchmarks MXFP8](../../../docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-mxfp8-10node-conv.svg)

The plot above shows an ESM2-15B model trained with FSDP2 on 80 B300 GPUs (10 nodes) for 10 hours. The MXFP8 run and
the BF16 baseline produce the same loss curve: a perfect match in convergence.

#### NVFP4

![Convergence Benchmarks NVFP4](../../../docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-nvfp4-10node-conv.svg)

The plot above shows several runs of the ESM2-15B model, each with a different set of layers assigned to FP4 via
`fp4_layers` (more on how to enable this below). The run names follow the pattern
`esm2-15b-b300-mxfp8-fp4_layer_start-fp4_layer_end-N-10-mbs-26-b300`, which encodes the first and last FP4 layer of
each experiment.

As more layers are moved to FP4, end-to-end training loss degrades. This is the tradeoff between performance and
accuracy: the performance chart above shows a dramatic throughput increase, but accuracy suffers. Note that all NVFP4
experiments also use stochastic rounding and random Hadamard transforms.

For more information on NVFP4 training, see the [NVFP4 training paper](https://arxiv.org/pdf/2509.25149) and the [NVIDIA TransformerEngine FP8 primer](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html).


### Distributed Training

This recipe supports distributed training using DDP, FSDP2, and Megatron-FSDP, shown in three separate training
@@ -97,35 +128,68 @@ torchrun --nproc_per_node=2 train_fsdp2.py # or train_mfsdp.py / train_ddp.py

Multi-Node training is supported with all three strategies, see [`slurm.sh`](slurm.sh) for an example SLURM script.

### Quantized Training (FP8 / MXFP8 / NVFP4)

To run training with FP8, enable it via `fp8_config.enabled=true`. By default, all transformer layers will use FP8.

```bash
python train_fsdp2.py --config-name L0_sanity fp8_config.enabled=true
```

Similarly, to train with NVFP4 quantization:

```bash
python train_fsdp2.py --config-name L0_sanity fp4_config.enabled=true
```

Additional recipe parameters (e.g., switching to `MXFP8BlockScaling`) can be set via the hydra configuration.

#### Layer-Wise Precision

You can control which transformer layers use FP8 or FP4 by specifying 1-indexed layer numbers via `fp8_layers` and
`fp4_layers`. Layers not assigned to either format will run in BF16.

For example, to run layers 1-3 in FP8, layers 4-6 in FP4, and the rest in BF16 on a model with more than 6 layers:

```bash
python train_fsdp2.py --config-name L0_sanity \
fp8_config.enabled=true \
fp4_config.enabled=true \
'fp8_layers=[1,2,3]' \
'fp4_layers=[4,5,6]'
```

When both `fp8_config` and `fp4_config` are enabled but only one layer list is provided, the other format automatically
claims the remaining layers. For example, if `fp8_layers=[1,2,3]` is set and `fp4_config.enabled=true` with no
`fp4_layers`, then layers 4 through N will default to FP4.
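
The fill-in rule described above can be sketched as follows. This is a hypothetical illustration of the documented behavior, not the recipe's actual implementation:

```python
def assign_precisions(num_layers, fp8_layers=None, fp4_layers=None,
                      fp8_enabled=False, fp4_enabled=False):
    """Return a {1-indexed layer: precision} map following the rules above."""
    assignment = {i: "bf16" for i in range(1, num_layers + 1)}
    if fp8_enabled and fp8_layers:
        for i in fp8_layers:
            assignment[i] = "fp8"
    if fp4_enabled and fp4_layers:
        for i in fp4_layers:
            assignment[i] = "fp4"
    # Both formats enabled but only one explicit list: the other format
    # claims the remaining BF16 layers.
    if fp8_enabled and fp4_enabled:
        if fp8_layers and not fp4_layers:
            for i in assignment:
                if assignment[i] == "bf16":
                    assignment[i] = "fp4"
        elif fp4_layers and not fp8_layers:
            for i in assignment:
                if assignment[i] == "bf16":
                    assignment[i] = "fp8"
    return assignment

# On a hypothetical 8-layer model: layers 1-3 in FP8, layers 4-8 default to FP4.
layout = assign_precisions(8, fp8_layers=[1, 2, 3], fp8_enabled=True, fp4_enabled=True)
```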

#### Quantization Stats Debugging

We provide a mechanism to log tensor statistics (activations, weights, gradients) for quantized layers during training.
When layer-wise precision is used, the stats config is automatically updated so that only the relevant layers are
tracked.

To enable stats logging:

```bash
python train_fsdp2.py \
quant_stats_config.enabled=true \
quant_stats_config.quant_log_dir=./logs/quant_stats \
quant_stats_config.quant_stats_file=./fp8_debugging_stats.yaml \
fp8_config.enabled=true
```

Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. It is not yet available for
`train_mfsdp`. NVFP4 stats logging is not yet supported and will be enabled in a future TransformerEngine release;
FP8/MXFP8 stats logging works today.

The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the
[NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html)
in more detail.

Stats collection has a performance cost dependent on the `freq` parameter in the config file. `freq=1` collects stats
on every step which in our experiments caused a ~29% decrease in throughput (executed on a single RTX 5090). We
recommend using `freq>=10` to reduce this performance hit.

### Sequence Packing (THD input format)

9 changes: 7 additions & 2 deletions bionemo-recipes/recipes/esm2_native_te/dataset.py
@@ -42,6 +42,7 @@ def create_tokenized_dataset(
    max_seq_length: int = 1024,
    buffer_size: int = 10_000,
    use_lazy_tokenization: bool = True,
    tokenizer_revision: str | None = None,
):
    """Create a tokenized dataset."""
    logger.info(f"Loading dataset with kwargs: {load_dataset_kwargs}")
@@ -56,7 +57,7 @@
    )
    dataset = dataset.shuffle(seed=42, buffer_size=buffer_size)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision=tokenizer_revision if tokenizer_revision else None)

    def tokenize_function(examples):
        """Tokenize the protein sequences."""
@@ -93,6 +94,7 @@ def create_bshd_dataloader(
    use_lazy_tokenization: bool = True,
    use_stateful_dataloader: bool = False,
    mlm_probability: float = 0.15,
    tokenizer_revision: str | None = None,
):
    """Create a dataloader for the dataset.

@@ -121,6 +123,7 @@
        max_seq_length=max_seq_length,
        buffer_size=buffer_size,
        use_lazy_tokenization=use_lazy_tokenization,
        tokenizer_revision=tokenizer_revision,
    )

    if isinstance(tokenized_dataset, datasets.IterableDataset):
@@ -167,6 +170,7 @@ def create_thd_dataloader(
    use_stateful_dataloader: bool = False,
    mlm_probability: float = 0.15,
    pad_sequences_to_be_divisible_by: int | None = None,
    tokenizer_revision: str | None = None,
):
    """Create a dataloader that packs up to the maximum number of tokens per batch.

@@ -186,7 +190,7 @@
        mlm_probability: The probability of masking tokens for MLM (default 0.15). Set to 0 for no masking.
        pad_sequences_to_be_divisible_by: If provided, sequences will be padded to be divisible by this value.
            This is useful for context parallelism. Defaults to None.
        tokenizer_revision: The revision of the tokenizer to use. Defaults to None.

    Returns:
        A dataloader that can be used for training.
    """
@@ -196,6 +200,7 @@
        load_dataset_kwargs=load_dataset_kwargs,
        max_seq_length=max_seq_length,
        buffer_size=buffer_size,
        tokenizer_revision=tokenizer_revision,
    )

    assert isinstance(tokenized_dataset, datasets.IterableDataset), "THD token packing requires a streaming dataset."
33 changes: 33 additions & 0 deletions bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml
@@ -0,0 +1,33 @@
example_fp4_tensor_stat_collection:
  enabled: True
  layers:
    # Use regex to select layers 1-5 (layers.1 through layers.5 in the module naming)
    # This matches: model.esm.encoder.layers.[1-5].*.(layernorm_qkv|proj|fc1|fc2)
    layer_name_regex_pattern: 'model\.esm\.encoder\.layers\.[1-5]\..*(layernorm_qkv|proj|fc1|fc2)'
  transformer_engine:
    LogNvfp4TensorStats:
      enabled: True
      tensors_struct:
        - tensor: activation
          stats: [underflows%, mse]
          freq: 100
        - tensor: gradient
          stats: [underflows%, mse]
          freq: 100

example_fp8_tensor_stat_collection:
  enabled: True
  layers:
    # Use regex to select layers 6-10 (layers.6 through layers.10 in the module naming)
    # This matches: model.esm.encoder.layers.[6-10].*.(layernorm_qkv|proj|fc1|fc2)
    layer_name_regex_pattern: 'model\.esm\.encoder\.layers\.([6-9]|10)\..*(layernorm_qkv|proj|fc1|fc2)'
  transformer_engine:
    LogFp8TensorStats:
      enabled: True
      tensors_struct:
        - tensor: activation
          stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse]
          freq: 100
        - tensor: gradient
          stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse]
          freq: 100
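
The `layer_name_regex_pattern` entries above can be sanity-checked with Python's `re` module. The module names below are assumptions based on the comments in this file:

```python
import re

fp4_pattern = re.compile(r"model\.esm\.encoder\.layers\.[1-5]\..*(layernorm_qkv|proj|fc1|fc2)")
fp8_pattern = re.compile(r"model\.esm\.encoder\.layers\.([6-9]|10)\..*(layernorm_qkv|proj|fc1|fc2)")

# Hypothetical module names for a 12-layer encoder (layers.0 through layers.11).
names = [f"model.esm.encoder.layers.{i}.self_attention.layernorm_qkv" for i in range(12)]
fp4_hits = [n.split(".")[4] for n in names if fp4_pattern.search(n)]
fp8_hits = [n.split(".")[4] for n in names if fp8_pattern.search(n)]
print(fp4_hits)  # ['1', '2', '3', '4', '5']
print(fp8_hits)  # ['6', '7', '8', '9', '10']
```

Note that the two patterns partition layers 1-10 cleanly: `layers.11` matches neither, since `[1-5]` and `([6-9]|10)` must be followed by a literal dot.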
@@ -2,7 +2,7 @@ example_fp8_tensor_stat_collection:
  enabled: True
  layers:
    # Match the actual linear layers within attention that support FP8 stats
    layer_types: [layernorm_qkv, proj, fc1, fc2]
  transformer_engine:
    LogFp8TensorStats:
      enabled: True
@@ -16,3 +16,8 @@ example_fp8_tensor_stat_collection:
        - tensor: weight
          stats: [underflows%, scale_inv_min, scale_inv_max, mse]
          freq: 10
    LogTensorStats:
      enabled: True
      stats: [max, min, mean, std, l1_norm]
      tensors: [dgrad, wgrad, fprop]
      freq: 1
@@ -8,6 +8,7 @@ num_train_steps: 500

dataset:
  micro_batch_size: 12
  tokenizer_revision: "f29e20d2b10d0aba2036831df65cdca1befe926f"

# WandB config
wandb_init_args:
@@ -8,6 +8,7 @@ num_train_steps: 10_000

dataset:
  micro_batch_size: 16
  tokenizer_revision: "86a86f18e6bb1eb4bcf91c594e1c0ad446d8eec6"

# WandB config
wandb_init_args:
@@ -8,7 +8,7 @@ num_train_steps: 200

dataset:
  micro_batch_size: 4
  tokenizer_revision: "d81c2e5aec37b5e794d0482e3996fb045a137792"

# WandB config
wandb_init_args:
  name: "esm2_t33_650M_UR50D"
@@ -13,6 +13,7 @@ cp_size: 1
use_sequence_packing: false
dataset:
  tokenizer_name: ${model_tag}
  tokenizer_revision: null
  micro_batch_size: ???
  num_workers: 1
  max_seq_length: 1024
@@ -51,6 +52,14 @@
  fp8_model_init_kwargs:
    enabled: false # If this is set to true, fp8_config.enabled must also be set to true.

fp4_config:
  enabled: false
  fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling
  fp4_format: "E2M1"
  fp4_recipe_kwargs: {}
  fp4_model_init_kwargs:
    enabled: false # If this is set to true, fp4_config.enabled must also be set to true.

# Optimizer config
adamw_kwargs:
  lr: 4e-4
@@ -76,7 +85,23 @@ checkpoint:
logger:
  frequency: 100

quant_stats_config:
  enabled: false
  quant_stats_file: ./fp8_debugging_stats.yaml
  quant_log_dir: ./log_quant_stats

# Nsight Systems profiling config.
# To use, wrap your launch command with:
# nsys profile -s none -t cuda,nvtx -o <output_report> --force-overwrite true \
# --capture-range=cudaProfilerApi --capture-range-end=stop <your torchrun command>
nsys_profiling:
  enabled: false
  start_step: 5 # Step at which to start CUDA profiler capture
  end_step: 8 # Step at which to stop CUDA profiler capture
  ranks: [0] # Which ranks to profile (list of ints)
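
The loop below is a hypothetical sketch of how a training script can consume these options to gate profiler capture; the real scripts would call `torch.cuda.profiler.start()`/`stop()` at the marked points:

```python
def profiler_events(num_steps: int, start_step: int, end_step: int,
                    ranks: list[int], rank: int = 0):
    """Return the (step, action) pairs at which this rank would toggle the profiler."""
    events = []
    for step in range(num_steps):
        if rank in ranks and step == start_step:
            events.append((step, "start"))  # torch.cuda.profiler.start() here
        if rank in ranks and step == end_step:
            events.append((step, "stop"))   # torch.cuda.profiler.stop() here
    return events

print(profiler_events(10, start_step=5, end_step=8, ranks=[0]))  # [(5, 'start'), (8, 'stop')]
```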

# Note: Layer lists are specified 1-indexed and are converted to 0-indexed at runtime.
fp8_layers: null
fp4_layers: null
use_fp32_master_weights: null