diff --git a/.vscode/settings.json b/.vscode/settings.json index 41bac6a7e..e6a3603da 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -26,5 +26,6 @@ "editor.rulers": [ 120 ], - "autoDocstring.docstringFormat": "google-notypes" + "autoDocstring.docstringFormat": "google-notypes", + "search.exclude": { "**/logs/**": true }, } diff --git a/bionemo-recipes/recipes/esm2_native_te/.dockerignore b/bionemo-recipes/recipes/esm2_native_te/.dockerignore index e67ca715c..ff0577a46 100644 --- a/bionemo-recipes/recipes/esm2_native_te/.dockerignore +++ b/bionemo-recipes/recipes/esm2_native_te/.dockerignore @@ -1,10 +1,34 @@ +# Docker Dockerfile +Dockerfile.* +.dockerignore + +# Docs README.md -checkpoint_export/ -outputs/ -.ruff_cache + +# Python caches __pycache__ .pytest_cache -.ruff.toml -.dockerignore +.ruff_cache .venv/ + +# Linting +.ruff.toml + +# Profiling & debugging artifacts +memory_snapshots/ +nsight_profiling/ +*.nsys-rep +*.sqlite +logs/ +wandb/ + +# Hydra / training outputs +outputs/ +checkpoints/ + +# Checkpoint export +checkpoint_export/ + +# Temp / scratch +j/ diff --git a/bionemo-recipes/recipes/esm2_native_te/Dockerfile b/bionemo-recipes/recipes/esm2_native_te/Dockerfile index b940874af..71a793b1b 100644 --- a/bionemo-recipes/recipes/esm2_native_te/Dockerfile +++ b/bionemo-recipes/recipes/esm2_native_te/Dockerfile @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1.4 -FROM nvcr.io/nvidia/pytorch:25.12-py3 +FROM nvcr.io/nvidia/pytorch:25.11-py3 RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=requirements.txt,target=/requirements.txt \ diff --git a/bionemo-recipes/recipes/esm2_native_te/README.md b/bionemo-recipes/recipes/esm2_native_te/README.md index 6d6880873..ee646928f 100644 --- a/bionemo-recipes/recipes/esm2_native_te/README.md +++ b/bionemo-recipes/recipes/esm2_native_te/README.md @@ -1,7 +1,8 @@ # TransformerEngine-accelerated ESM-2 training with native PyTorch training loop This folder demonstrates how to train 
TE-accelerated ESM-2 with a native PyTorch training loop, including sequence -packing and FP8 precision, using fully sharded data parallel (FSDP) for distributed training. +packing, FP8/MXFP8/NVFP4 precision with layer-wise control, using fully sharded data parallel (FSDP) for distributed +training. ## How to use this recipe @@ -15,10 +16,10 @@ bionemo-framework repository. You can download a zipped directory of this folder ## Supported Models and Training Features -| Model | BF16 | FP8[1] | THD Input Format | FP8 with THD Input Format | MXFP8[2] | Context Parallelism | -| ----------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- | -| [ESM-2](../../models/esm2/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [AMPLIFY](../../models/amplify/README.md) | ✅ | ❌ | 🚧 | ❌ | ❌ | 🚧 | +| Model | BF16 | FP8[1] | MXFP8[2] | NVFP4[3] | THD Input Format | Context Parallelism | +| ----------------------------------------- | ---- | ----------------- | ------------------- | ------------------- | ---------------- | ------------------- | +| [ESM-2](../../models/esm2/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [AMPLIFY](../../models/amplify/README.md) | ✅ | ❌ | ❌ | ❌ | 🚧 | 🚧 | ✅: Supported
🚧: Under development
@@ -26,6 +27,7 @@ bionemo-framework repository. You can download a zipped directory of this folder \[1\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 9.0 and above (Hopper+)
\[2\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 10.0 and 10.3 (Blackwell), 12.0 support pending
+\[3\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 10.0 and above (Blackwell+)
### Installing Dependencies

@@ -72,6 +74,35 @@ Recently, we measured 2800 tokens/second/GPU training speed on H100 with Hugging
 of THD sequence packing, however we have not been able to make this configuration work on Blackwell and this work is
 still in progress.
 
+### Low precision performance benchmarks
+![Performance Benchmarks Low Precision](../../../docs/docs/assets/images/esm2/esm2_low_precision/esm2_8gpu_tflops.png)
+In the above plot, we can see that as we increase the scale of our models, the benefits of low precision training are apparent.
+This is because at smaller parameter models (such as 650M, 3B), the cost to quantize activations from high precision to low
+precision outweighs the benefits of performing matrix multiplication with low precision. However, as our models scale up in
+parameter count, the fixed quantization cost is lower than our GEMM operational savings.
+
+Note: these plots were generated using our [fsdp2](./train_fsdp2.py) script.
+
+
+### Convergence results for low precision training
+#### MXFP8
+![Convergence Benchmarks MXFP8](../../../docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-mxfp8-10node-conv.svg)
+In the above plot, for our ESM2-15B model that was trained with FSDP2 using 80 B300 GPU nodes for 10 hours, we can clearly see that
+our MXFP8 run and our BF16 baseline run have the same results. A perfect match in convergence.
+
+#### NVFP4
+![Convergence Benchmarks NVFP4](../../../docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-nvfp4-10node-conv.svg)
+In the above plot, for our ESM2-15B model, we show several lines. Each experiment shows a unique configuration using a custom
+amount of `fp4_layers` per run (more info on how to enable this below). Moreover, the experiments can be read as
+ `esm2-15b-b300-mxfp8-fp4_layer_start-fp4_layer_end-N-10-mbs-26-b300` which denotes at which point we start and end the fp4 layers.
+
+We see that as we add more and more layers, our E2E training results get worse. 
This is a tradeoff between performance and
+accuracy. If we look at the performance chart above, we have increased performance dramatically, but our accuracy suffers.
+It's important to note that in all NVFP4 experiments we are also utilizing stochastic rounding and random Hadamard transformations.
+
+For more information regarding NVFP4 training, please see the [paper](https://arxiv.org/pdf/2509.25149) and [here](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html).
+
+
 ### Distributed Training
 
 This recipe supports distributed training using DDP, FSDP2, and Megatron-FSDP, shown in three separate training
@@ -97,35 +128,68 @@ torchrun --nproc_per_node=2 train_fsdp2.py # or train_mfsdp.py / train_ddp.py
 
 Multi-Node training is supported with all three strategies, see [`slurm.sh`](slurm.sh) for an example SLURM script.
 
-### FP8 Training
+### Quantized Training (FP8 / MXFP8 / NVFP4)
 
-To run training with FP8, enable it by overriding the `fp8_config.enabled=true` configuration parameter. Additional FP8
-configuration parameters, including switching to `MXFP8BlockScaling`, can be set via the hydra configuration.
+To run training with FP8, enable it via `fp8_config.enabled=true`. By default, all transformer layers will use FP8.
 
 ```bash
 python train_fsdp2.py --config-name L0_sanity fp8_config.enabled=true
 ```
 
-#### FP8 Debugging
+Similarly, to train with NVFP4 quantization:
+
+```bash
+python train_fsdp2.py --config-name L0_sanity fp4_config.enabled=true
+```
 
-We also provide a mechanism to receive tensor data related to FP8 layers during training which may include activations, weights and gradients.
+Additional recipe parameters (e.g., switching to `MXFP8BlockScaling`) can be set via the hydra configuration.
 
-To enable this please select the following config options.
+#### Layer-Wise Precision
 
-```python
+You can control which transformer layers use FP8 or FP4 by specifying 1-indexed layer numbers via `fp8_layers` and
+`fp4_layers`. 
Layers not assigned to either format will run in BF16. + +For example, to run layers 1-3 in FP8, layers 4-6 in FP4, and the rest in BF16 on a model with more than 6 layers: + +```bash +python train_fsdp2.py --config-name L0_sanity \ + fp8_config.enabled=true \ + fp4_config.enabled=true \ + 'fp8_layers=[1,2,3]' \ + 'fp4_layers=[4,5,6]' +``` + +When both `fp8_config` and `fp4_config` are enabled but only one layer list is provided, the other format automatically +claims the remaining layers. For example, if `fp8_layers=[1,2,3]` is set and `fp4_config.enabled=true` with no +`fp4_layers`, then layers 4 through N will default to FP4. + +#### Quantization Stats Debugging + +We provide a mechanism to log tensor statistics (activations, weights, gradients) for quantized layers during training. +When layer-wise precision is used, the stats config is automatically updated so that only the relevant layers are +tracked. + +To enable stats logging: + +```bash python train_fsdp2.py \ -fp8_stats_config.enabled=True # whether to log stats or not -fp8_stats_config.fp8_log_dir=./logs/fp8_stats_logs_dummy # where to store the logs -fp8_stats_config.fp8_stats_file=./fp8_debugging_stats.yaml # specifies what stats you want to run. Currently this is saved in this yaml file. -fp8_config.enabled=True # set this to use FP8 otherwise stats logging won't work + quant_stats_config.enabled=true \ + quant_stats_config.quant_log_dir=./logs/quant_stats \ + quant_stats_config.quant_stats_file=./fp8_debugging_stats.yaml \ + fp8_config.enabled=true ``` -Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. It is not yet available for `train_mfsdp`. +Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. It is not yet available for +`train_mfsdp`. NVFP4 stats logging is not yet supported and will be enabled in a future TransformerEngine release; +FP8/MXFP8 stats logging works today. 
-The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the [NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html) in more detail. Below we will cover some very basic elements of the file structure. +The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the +[NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html) +in more detail. -This comes as a performance cost that is dependent on the `freq` parameter mentioned above. `freq=1` collects stats on every step which in our -experiments caused a ~29% decrease in throughput (executed on a single RTX 5090). We recommend using `freq>=10` to reduce this performance hit. +Stats collection has a performance cost dependent on the `freq` parameter in the config file. `freq=1` collects stats +on every step which in our experiments caused a ~29% decrease in throughput (executed on a single RTX 5090). We +recommend using `freq>=10` to reduce this performance hit. 
### Sequence Packing (THD input format) diff --git a/bionemo-recipes/recipes/esm2_native_te/dataset.py b/bionemo-recipes/recipes/esm2_native_te/dataset.py index c915f30ea..6f9444648 100644 --- a/bionemo-recipes/recipes/esm2_native_te/dataset.py +++ b/bionemo-recipes/recipes/esm2_native_te/dataset.py @@ -42,6 +42,7 @@ def create_tokenized_dataset( max_seq_length: int = 1024, buffer_size: int = 10_000, use_lazy_tokenization: bool = True, + tokenizer_revision: str | None = None, ): """Create a tokenized dataset.""" logger.info(f"Loading dataset with kwargs: {load_dataset_kwargs}") @@ -56,7 +57,7 @@ def create_tokenized_dataset( ) dataset = dataset.shuffle(seed=42, buffer_size=buffer_size) - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision=tokenizer_revision if tokenizer_revision else None) def tokenize_function(examples): """Tokenize the protein sequences.""" @@ -93,6 +94,7 @@ def create_bshd_dataloader( use_lazy_tokenization: bool = True, use_stateful_dataloader: bool = False, mlm_probability: float = 0.15, + tokenizer_revision: str | None = None, ): """Create a dataloader for the dataset. @@ -121,6 +123,7 @@ def create_bshd_dataloader( max_seq_length=max_seq_length, buffer_size=buffer_size, use_lazy_tokenization=use_lazy_tokenization, + tokenizer_revision=tokenizer_revision, ) if isinstance(tokenized_dataset, datasets.IterableDataset): @@ -167,6 +170,7 @@ def create_thd_dataloader( use_stateful_dataloader: bool = False, mlm_probability: float = 0.15, pad_sequences_to_be_divisible_by: int | None = None, + tokenizer_revision: str | None = None, ): """Create a dataloader that packs up to the maximum number of tokens per batch. @@ -186,7 +190,7 @@ def create_thd_dataloader( mlm_probability: The probability of masking tokens for MLM (default 0.15). Set to 0 for no masking. pad_sequences_to_be_divisible_by: If provided, sequences will be padded to be divisible by this value. 
This is useful for context parallelism. Defaults to None. - + tokenizer_revision: The revision of the tokenizer to use. Defaults to None. Returns: A dataloader that can be used for training. """ @@ -196,6 +200,7 @@ def create_thd_dataloader( load_dataset_kwargs=load_dataset_kwargs, max_seq_length=max_seq_length, buffer_size=buffer_size, + tokenizer_revision=tokenizer_revision, ) assert isinstance(tokenized_dataset, datasets.IterableDataset), "THD token packing requires a streaming dataset." diff --git a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml new file mode 100644 index 000000000..d56739a6a --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml @@ -0,0 +1,33 @@ +example_fp4_tensor_stat_collection: + enabled: True + layers: + # Use regex to select layers 0-4 (1-indexed as layers.1 through layers.5 in the naming) + # This matches: model.esm.encoder.layers.[1-5].*.(layernorm_qkv|proj|fc1|fc2) + layer_name_regex_pattern: 'model\.esm\.encoder\.layers\.[1-5]\..*(layernorm_qkv|proj|fc1|fc2)' + transformer_engine: + LogNvfp4TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [underflows%, mse] + freq: 100 + - tensor: gradient + stats: [underflows%, mse] + freq: 100 + +example_fp8_tensor_stat_collection: + enabled: True + layers: + # Use regex to select layers 0-4 (1-indexed as layers.1 through layers.5 in the naming) + # This matches: model.esm.encoder.layers.[1-5].*.(layernorm_qkv|proj|fc1|fc2) + layer_name_regex_pattern: 'model\.esm\.encoder\.layers\.([6-9]|10)\..*(layernorm_qkv|proj|fc1|fc2)' + transformer_engine: + LogFp8TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] + freq: 100 + - tensor: gradient + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] + freq: 100 diff --git 
a/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml index 7544bbedc..9653d8a04 100644 --- a/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml @@ -2,7 +2,7 @@ example_fp8_tensor_stat_collection: enabled: True layers: # Match the actual linear layers within attention that support FP8 stats - layer_types: [layernorm_qkv] + layer_types: [layernorm_qkv, proj, fc1, fc2] transformer_engine: LogFp8TensorStats: enabled: True @@ -16,3 +16,8 @@ example_fp8_tensor_stat_collection: - tensor: weight stats: [underflows%, scale_inv_min, scale_inv_max, mse] freq: 10 + LogTensorStats: + enabled: True + stats: [max, min, mean, std, l1_norm] + tensors: [dgrad, wgrad, fprop] + freq: 1 diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml index 0b91c5608..2b6f602e3 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml @@ -8,6 +8,7 @@ num_train_steps: 500 dataset: micro_batch_size: 12 + tokenizer_revision: "f29e20d2b10d0aba2036831df65cdca1befe926f" # WandB config wandb_init_args: diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml index e8e47d908..3e055907c 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml @@ -8,6 +8,7 @@ num_train_steps: 10_000 dataset: micro_batch_size: 16 + tokenizer_revision: "86a86f18e6bb1eb4bcf91c594e1c0ad446d8eec6" # WandB config wandb_init_args: diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml index 
fd027601d..804f1ae21 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml @@ -8,7 +8,7 @@ num_train_steps: 200 dataset: micro_batch_size: 4 - + tokenizer_revision: "d81c2e5aec37b5e794d0482e3996fb045a137792" # WandB config wandb_init_args: name: "esm2_t33_650M_UR50D" diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index baace7c80..2682cb3a9 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -13,6 +13,7 @@ cp_size: 1 use_sequence_packing: false dataset: tokenizer_name: ${model_tag} + tokenizer_revision: null micro_batch_size: ??? num_workers: 1 max_seq_length: 1024 @@ -51,6 +52,14 @@ fp8_config: fp8_model_init_kwargs: enabled: false # If this is set to true, fp8_config.enabled must also be set to true. +fp4_config: + enabled: false + fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling + fp4_format: "E2M1" + fp4_recipe_kwargs: {} + fp4_model_init_kwargs: + enabled: false # If this is set to true, fp4_config.enabled must also be set to true. + # Optimizer config adamw_kwargs: lr: 4e-4 @@ -76,7 +85,23 @@ checkpoint: logger: frequency: 100 -fp8_stats_config: + +quant_stats_config: enabled: false - fp8_stats_file: ./fp8_debugging_stats.yaml - fp8_log_dir: ./log_fp8_stats + quant_stats_file: ./fp8_debugging_stats.yaml + quant_log_dir: ./log_quant_stats + +# Nsight Systems profiling config. 
+# To use, wrap your launch command with: +# nsys profile -s none -t cuda,nvtx -o --force-overwrite true \ +# --capture-range=cudaProfilerApi --capture-range-end=stop +nsys_profiling: + enabled: false + start_step: 5 # Step at which to start CUDA profiler capture + end_step: 8 # Step at which to stop CUDA profiler capture + ranks: [0] # Which ranks to profile (list of ints) + +# Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime. +fp8_layers: null +fp4_layers: null +use_fp32_master_weights: null \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py new file mode 100644 index 000000000..bd9ae9674 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -0,0 +1,687 @@ +# noqa: license-check +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""TransformerEngine-optimized ESM model. + +Adapted from `modeling_esm.py` in huggingface/transformers. 
+""" + +from typing import Literal, Optional, Unpack + +# TODO: put import guard around transformer_engine here, with an informative error message around +# installation and the nvidia docker container. +import torch +import torch.cuda.nvtx as nvtx +import transformer_engine.pytorch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + MaskedLMOutput, + TokenClassifierOutput, +) +import transformer_engine.common.recipe +from transformers.models.esm.configuration_esm import EsmConfig +from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel +from transformers.utils import logging +from transformers.utils.generic import TransformersKwargs +from contextlib import nullcontext + +logger = logging.get_logger(__name__) + +# Dictionary that gets inserted into config.json to map Auto** classes to our TE-optimized model classes defined below. +# These should be prefixed with esm_nv., since we name the file esm_nv.py in our exported checkpoints. 
+AUTO_MAP = { + "AutoConfig": "esm_nv.NVEsmConfig", + "AutoModel": "esm_nv.NVEsmModel", + "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM", + "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", +} + +# From https://github.com/NVIDIA/TransformerEngine/blob/3ceb248e01a2c0dc1215fe0f46ebc235f289ba0d/transformer_engine/common/recipe/__init__.py#L86 +FP8_RECIPES = (transformer_engine.common.recipe.MXFP8BlockScaling, + transformer_engine.common.recipe.DelayedScaling, + transformer_engine.common.recipe.Float8CurrentScaling, + transformer_engine.common.recipe.Float8BlockScaling) +FP4_RECIPES = (transformer_engine.common.recipe.NVFP4BlockScaling) + + +class NVEsmConfig(EsmConfig): + """NVEsmConfig is a configuration for the NVEsm model.""" + + model_type: str = "nv_esm" + + def __init__( + self, + qkv_weight_interleaved: bool = True, + encoder_activation: str = "gelu", + attn_input_format: Literal["bshd", "thd"] = "bshd", + fuse_qkv_params: bool = True, + micro_batch_size: Optional[int] = None, + max_seq_length: Optional[int] = None, + padded_vocab_size: Optional[int] = 64, + attn_mask_type: str = "padding", + bf16_layers: Optional[list[int]] = None, + **kwargs, + ): + """Initialize the NVEsmConfig with additional TE-related config options. + + Args: + qkv_weight_interleaved: Whether to interleave the qkv weights. If set to `False`, the + QKV weight is interpreted as a concatenation of query, key, and value weights along + the `0th` dimension. The default interpretation is that the individual `q`, `k`, and + `v` weights for each attention head are interleaved. This parameter is set to `False` + when using :attr:`fuse_qkv_params=False`. + encoder_activation: The activation function to use in the encoder. + attn_input_format: The input format to use for the attention. This controls + whether the dimensions of the intermediate hidden states is 'batch first' + ('bshd') or 'sequence first' ('sbhd'). 
`s` stands for the sequence length, + `b` batch size, `h` the number of heads, `d` head size. Note that these + formats are very closely related to the `qkv_format` in the + `MultiHeadAttention` and `DotProductAttention` modules. + fuse_qkv_params: Whether to fuse the qkv parameters. If set to `True`, + `TransformerLayer` module exposes a single fused parameter for query-key-value. + This enables optimizations such as QKV fusion without concatentations/splits and + also enables the argument `fuse_wgrad_accumulation`. + micro_batch_size: The micro batch size to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propogation and activation recompute phase. + max_seq_length: The maximum sequence length to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propogation and activation recompute phase. + padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults + to vocab_size. Must be greater than or equal to vocab_size. + attn_mask_type: The type of attention mask to use. + **kwargs: Additional config options to pass to EsmConfig. + """ + super().__init__(**kwargs) + # Additional TE-related config options. 
+ self.qkv_weight_interleaved = qkv_weight_interleaved + self.encoder_activation = encoder_activation + self.attn_input_format = attn_input_format + self.fuse_qkv_params = fuse_qkv_params + self.micro_batch_size = micro_batch_size + self.max_seq_length = max_seq_length + self.attn_mask_type = attn_mask_type + self.bf16_layers = bf16_layers + # Set padded_vocab_size with default fallback to vocab_size + self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size + + # Ensure padded_vocab_size is at least as large as vocab_size + if self.padded_vocab_size is not None and self.vocab_size is not None: + assert self.padded_vocab_size >= self.vocab_size, ( + f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to vocab_size ({self.vocab_size})" + ) + + +class NVEsmEncoder(nn.Module): + """NVEsmEncoder is a TransformerEngine-optimized ESM encoder.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmEncoder. + + Args: + config (NVEsmConfig): The configuration of the model. 
+ """ + super().__init__() + self.config = config + + def _init_method(x): + torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) + + self.layers = nn.ModuleList( + [ + transformer_engine.pytorch.TransformerLayer( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + layernorm_epsilon=config.layer_norm_eps, + hidden_dropout=config.hidden_dropout_prob, + attention_dropout=config.attention_probs_dropout_prob, + qkv_weight_interleaved=config.qkv_weight_interleaved, + layer_number=i + 1, + layer_type="encoder", + self_attn_mask_type=config.attn_mask_type, + activation=config.encoder_activation, + attn_input_format=config.attn_input_format, + seq_length=config.max_seq_length, + micro_batch_size=config.micro_batch_size, + num_gqa_groups=config.num_attention_heads, + fuse_qkv_params=config.fuse_qkv_params, + params_dtype=config.dtype, + window_size=(-1, -1), + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=_init_method, + output_layer_init_method=_init_method, + ) + for i in range(config.num_hidden_layers) + ] + ) + self.layer_number_quantized_recipe_map = None + self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.position_embedding_type == "rotary": + self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + """Forward pass of the NVEsmEncoder. + + Args: + hidden_states (torch.Tensor): The hidden states. + attention_mask (torch.Tensor): The attention mask. + **kwargs: Additional arguments, see TransformersKwargs for more details. 
+ """ + all_hidden_states: tuple[torch.Tensor, ...] = () + + if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1: + # For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE + # expects a 2-dimensional tensor with shape [total_tokens, hidden_size]. + hidden_states = hidden_states.squeeze(0) + + # Ensure that rotary embeddings are computed with at a higher precision outside the torch autocast context. + with torch.autocast(device_type="cuda", enabled=False): + te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) + te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) + + # Utilize the layer number quantized recipe map to determine the context for each layer. + for layer_number, layer_module in enumerate(self.layers): + fp_recipe = self.layer_number_quantized_recipe_map[layer_number] if layer_number in self.layer_number_quantized_recipe_map else None + + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + # If BF16 desired --> use autocast(false) so it goes to BF16. + # If FP8 desired --> use nullcontext so it uses upper context manager to FP8. + # If FP4 desired --> use autocast(true, recipe=fp4_recipe) so it uses FP4. + if isinstance(fp_recipe, FP8_RECIPES): + fp_context = nullcontext() + elif isinstance(fp_recipe, FP4_RECIPES): + fp_context = transformer_engine.pytorch.autocast(enabled=True, recipe=fp_recipe) + else: + fp_context = transformer_engine.pytorch.autocast(enabled=False) + # TODO(@jomitchell): Double check that this works, make a funciton for it then unit test it. 
+ + nvtx.range_push(f"encoder_layer_{layer_number}") + with fp_context: + hidden_states = layer_module( + hidden_states, + attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + nvtx.range_pop() # encoder_layer_N + + nvtx.range_push("emb_layer_norm_after") + hidden_states = self.emb_layer_norm_after(hidden_states) + nvtx.range_pop() # emb_layer_norm_after + + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states if all_hidden_states else None, + ) + + +class NVEsmPreTrainedModel(EsmPreTrainedModel): + """An abstract class to handle weights initialization and pretrained model loading.""" + + config_class = NVEsmConfig + base_model_prefix = "esm" + supports_gradient_checkpointing = False + accepts_loss_kwargs = False + _no_split_modules = ( + "TransformerLayer", + "EsmEmbeddings", + ) + + def init_empty_weights(self): + """Handles moving the model from the meta device to the cuda device and initializing the weights.""" + # For TE layers, calling `reset_parameters` is sufficient to move them to the cuda device and apply the weight + # initialization we passed them during module creation. + for module in self.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard + # deviation. 
+ self.esm.embeddings.word_embeddings.to_empty(device="cuda") + self.esm.embeddings.apply(self._init_weights) + + # Meta-device init seems to break weight tying, so we re-tie the weights here. + self.tie_weights() + + @classmethod + def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool): + """Override the default get_init_context method to allow for fp8 model initialization.""" + return [] + + +class NVEsmModel(NVEsmPreTrainedModel): + """The ESM Encoder-only protein language model. + + This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. + """ + + def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + """Initialize a NVEsmModel. + + Args: + config (NVEsmConfig): The configuration of the model. + add_pooling_layer (bool): Whether to add a pooling layer. + """ + super().__init__(config) + self.config = config + + # Ensure pad_token_id is set properly, defaulting to 0 if not specified + if not hasattr(config, "pad_token_id") or config.pad_token_id is None: + config.pad_token_id = 0 + self.embeddings = NVEsmEmbeddings(config) + self.encoder = NVEsmEncoder(config) + self.pooler = EsmPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + """Get the input embeddings of the model.""" + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value: torch.Tensor): + """Set the input embeddings of the model. + + Args: + value (torch.Tensor): The input embeddings. + """ + self.embeddings.word_embeddings = value + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + """Forward pass of the NVEsmModel. + + Args: + input_ids (torch.Tensor): The input ids. 
+ attention_mask (torch.Tensor): The attention mask. + position_ids (torch.Tensor): The position ids. + inputs_embeds (torch.Tensor): The input embeddings. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + BaseModelOutputWithPooling: The output of the model. + """ + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # TE expects a boolean attention mask, where 1s are masked and 0s are not masked + extended_attention_mask = extended_attention_mask < -1 + + embedding_output = self.embeddings( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + **kwargs, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +class NVEsmForMaskedLM(NVEsmPreTrainedModel): + """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" + + _tied_weights_keys = ("lm_head.decoder.weight",) + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmForMaskedLM. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+ ) + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.lm_head = NVEsmLMHead(config) + + self.init_weights() + self.post_init() + + def get_output_embeddings(self): + """Get the output embeddings of the model.""" + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + """Set the output embeddings of the model.""" + self.lm_head.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MaskedLMOutput: + """Forward pass of the NVEsmForMaskedLM. + + Args: + input_ids (torch.LongTensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.LongTensor): The position ids. + inputs_embeds (torch.FloatTensor): The input embeddings. + labels (torch.LongTensor): The labels. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + MaskedLMOutput: The output of the model. 
+ """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + # Truncate logits back to original vocab_size if padding was used + if self.config.padded_vocab_size != self.config.vocab_size: + prediction_scores = prediction_scores[..., : self.config.vocab_size] + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.to(prediction_scores.device).view(-1), + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + ) + + +class NVEsmLMHead(nn.Module): + """ESM Head for masked language modeling using TransformerEngine.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmLMHead. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__() + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + with transformer_engine.pytorch.fp8_model_init(enabled=False): + self.decoder = transformer_engine.pytorch.LayerNormLinear( + config.hidden_size, + config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, + bias=True, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + def forward(self, features, **kwargs): + """Forward pass of the NVEsmLMHead. + + Args: + features (torch.Tensor): The features. 
+            **kwargs: Additional arguments.
+        """
+        with transformer_engine.pytorch.autocast(enabled=False):
+            x = self.dense(features)
+            x = torch.nn.functional.gelu(x)
+            x = self.decoder(x)
+        return x
+
+
+class NVEsmEmbeddings(nn.Module):
+    """Modified version of EsmEmbeddings to support THD inputs."""
+
+    def __init__(self, config):
+        """Initialize a NVEsmEmbeddings."""
+        super().__init__()
+        self.word_embeddings = nn.Embedding(
+            config.padded_vocab_size,
+            config.hidden_size,
+            padding_idx=config.pad_token_id,
+            dtype=config.dtype,
+        )
+
+        self.layer_norm = (
+            transformer_engine.pytorch.LayerNorm(
+                config.hidden_size,
+                eps=config.layer_norm_eps,
+                params_dtype=config.dtype,
+                device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
+            )
+            if config.emb_layer_norm_before
+            else None
+        )
+
+        if config.position_embedding_type != "rotary":
+            raise ValueError(
+                "The TE-accelerated ESM-2 model only supports rotary position embeddings, received "
+                f"{config.position_embedding_type}"
+            )
+
+        self.padding_idx = config.pad_token_id
+        self.token_dropout = config.token_dropout
+        self.mask_token_id = config.mask_token_id
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        **kwargs: Unpack[TransformersKwargs],
+    ):
+        """Forward pass of the NVEsmEmbeddings."""
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an
+        # embedding_scale factor here.
+        embeddings = inputs_embeds
+
+        if (
+            kwargs.get("cu_seq_lens_q") is not None
+            and kwargs.get("cu_seq_lens_k") is not None
+            and kwargs.get("max_length_q") is not None
+            and kwargs.get("max_length_k") is not None
+        ):
+            using_thd = True
+            attention_mask = None
+        else:
+            using_thd = False
+
+        # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
+        # flag is False then it is handled in the same way as BERT/RoBERTa. 
If it is set to True, however, + # masked tokens are treated as if they were selected for input dropout and zeroed out. + # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by + # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). + # This is analogous to the way that dropout layers scale down outputs during evaluation when not + # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). + if self.token_dropout and input_ids is not None: + embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0) + mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs + + if not using_thd: + # BSHD token dropout correction + src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1] + n_masked_per_seq = (input_ids == self.mask_token_id).sum(-1).float() + mask_ratio_observed = n_masked_per_seq / src_lengths + scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) + embeddings = (embeddings * scale_factor[:, None, None]).to(embeddings.dtype) + + else: + src_lengths = torch.diff(kwargs["cu_seq_lens_q"]) + # We need to find the number of masked tokens in each sequence in the padded batch. 
+ is_masked = (input_ids == self.mask_token_id).squeeze(0) + n_masked_per_seq = torch.nested.nested_tensor_from_jagged( + is_masked, offsets=kwargs["cu_seq_lens_q"] + ).sum(1) + mask_ratio_observed = n_masked_per_seq.float() / src_lengths + scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) + embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + + if attention_mask is not None: + embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) + + return embeddings + + +class NVEsmForTokenClassification(NVEsmPreTrainedModel): + """Adds a token classification head to the model. + + Adapted from EsmForTokenClassification in Hugging Face Transformers `modeling_esm.py`. + """ + + def __init__(self, config): + """Initialize NVEsmForTokenClassification.""" + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = transformer_engine.pytorch.Linear( + config.hidden_size, + config.num_labels, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + self.init_weights() + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + """Forward pass for the token classification head. 
+ + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/bionemo-recipes/recipes/esm2_native_te/quantization.py b/bionemo-recipes/recipes/esm2_native_te/quantization.py new file mode 100644 index 000000000..c545ef9e2 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/quantization.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Utilities for layer-wise quantization configuration (FP8/FP4)."""
+
+import logging
+import tempfile
+from pathlib import Path
+
+import yaml
+
+logger = logging.getLogger(__name__)
+
+
+def generate_layer_regex(layer_numbers: list[int] | None) -> str:
+    """Generate a regex pattern to match specific layer numbers (1-indexed).
+
+    The debug API (nvdlfw_inspect) uses 1-indexed layer names after ``infer_and_assign_layer_names``.
+
+    Args:
+        layer_numbers: List of layer numbers (1-indexed, as shown in debug logs).
+            If empty or None, returns a pattern that matches nothing.
+
+    Returns:
+        Regex pattern string for matching those layers' linear sublayers.
+    """
+    if not layer_numbers:
+        return r"model\.esm\.encoder\.layers\.DISABLED_NO_LAYERS_SPECIFIED"
+    layer_pattern = "|".join(str(n) for n in sorted(layer_numbers))
+    return rf"model\.esm\.encoder\.layers\.({layer_pattern})\..*(layernorm_qkv|proj|fc1|fc2)"
+
+
+def update_quant_stats_config(
+    config_file: str,
+    fp4_layers: list[int] | None,
+    fp8_layers: list[int] | None,
+) -> str:
+    """Update the quant stats YAML config with layer-specific regex patterns.
+
+    Args:
+        config_file: Path to the original YAML config file.
+        fp4_layers: List of layer numbers for FP4 (1-indexed).
+        fp8_layers: List of layer numbers for FP8 (1-indexed).
+
+    Returns:
+        Path to the updated config file (a temp file).
+    """
+    with open(config_file, "r") as f:
+        config = yaml.safe_load(f)
+
+    if "example_fp4_tensor_stat_collection" in config:
+        # TODO: Remove this block and replace with FP8-style regex update once a TransformerEngine
+        # release with LogNvfp4TensorStats support is available. At that point, this becomes:
+        #     fp4_regex = generate_layer_regex(fp4_layers)
+        #     config["example_fp4_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp4_regex
+        config["example_fp4_tensor_stat_collection"]["enabled"] = False
+        if fp4_layers:
+            logger.warning(
+                "NVFP4 quant stats logging is not yet supported (requires a future TransformerEngine release). "
+                f"Disabling FP4 stats collection for layers {fp4_layers}. FP8 stats will still be collected."
+            )
+        else:
+            logger.info("FP4 stats section disabled (no FP4 layers and feature not yet supported)")
+
+    if "example_fp8_tensor_stat_collection" in config:
+        fp8_regex = generate_layer_regex(fp8_layers)
+        config["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp8_regex
+        if fp8_layers:
+            logger.info(f"Updated FP8 layer regex to match layers: {fp8_layers}")
+        else:
+            logger.info("FP8 layers empty - regex set to match nothing")
+
+    # Serialize once; write the same string that we log so the log always matches the file contents.
+    config_str = yaml.dump(config, default_flow_style=False)
+    temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False)
+    temp_file.write(config_str)
+    temp_file.close()
+
+    logger.info(f"Created updated quant stats config at: {temp_file.name}")
+    logger.info(f"Updated quant stats config contents:\n{config_str}")
+
+    return temp_file.name
+
+
+def initialize_quant_stats_logging(
+    quant_stats_file: str,
+    quant_log_dir: str,
+    rank: int,
+    quant_layers: "QuantizationLayers",
+) -> None:
+    """Set up quantization stats logging via nvdlfw_inspect.
+
+    Updates the quant stats YAML config with resolved layer regex patterns, creates
+    the per-rank log directory, and initializes the debug API.
+
+    Args:
+        quant_stats_file: Path to the base quant stats YAML config file.
+        quant_log_dir: Base directory for quant stats logs (a rank subdirectory will be created).
+        rank: The global rank of this process.
+        quant_layers: Resolved quantization layer assignments.
+    """
+    import nvdlfw_inspect.api as debug_api
+    import transformer_engine
+
+    updated_config = update_quant_stats_config(
+        config_file=quant_stats_file,
+        fp4_layers=quant_layers.fp4_layers_1indexed,
+        fp8_layers=quant_layers.fp8_layers_1indexed,
+    )
+
+    rank_log_dir = Path(quant_log_dir) / f"rank_{rank}"
+    rank_log_dir.mkdir(parents=True, exist_ok=True)
+    logger.info(f"Logging quant stats to {rank_log_dir}")
+
+    te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features")
+    debug_api.initialize(
+        config_file=updated_config,
+        feature_dirs=[te_features_dir],
+        log_dir=rank_log_dir,
+        default_logging_enabled=True,
+    )
+
+
+class QuantizationLayers:
+    """Resolved layer-wise quantization assignments.
+
+    Attributes:
+        fp8_layers_0indexed: 0-indexed FP8 layer numbers (for model internals), or None.
+        fp4_layers_0indexed: 0-indexed FP4 layer numbers (for model internals), or None.
+        fp8_layers_1indexed: 1-indexed FP8 layer numbers (for user-facing logs / quant stats), or None.
+        fp4_layers_1indexed: 1-indexed FP4 layer numbers (for user-facing logs / quant stats), or None.
+    """
+
+    def __init__(
+        self,
+        fp8_layers_0indexed: list[int] | None,
+        fp4_layers_0indexed: list[int] | None,
+        fp8_layers_1indexed: list[int] | None,
+        fp4_layers_1indexed: list[int] | None,
+    ):
+        self.fp8_layers_0indexed = fp8_layers_0indexed
+        self.fp4_layers_0indexed = fp4_layers_0indexed
+        self.fp8_layers_1indexed = fp8_layers_1indexed
+        self.fp4_layers_1indexed = fp4_layers_1indexed
+
+
+def resolve_quantization_layers(
+    num_layers: int,
+    fp8_enabled: bool,
+    fp4_enabled: bool,
+    fp8_layers: list[int] | None,
+    fp4_layers: list[int] | None,
+) -> QuantizationLayers:
+    """Resolve layer-wise quantization assignments from user config.
+
+    Takes 1-indexed layer lists (as specified by the user) and returns both 0-indexed lists
+    (for model internals) and 1-indexed lists (for quant stats / debug logging). When a quantization
+    format is enabled but no layer list is provided, all layers default to that format. When one format
+    has explicit layers and the other is enabled without a layer list, the unspecified format defaults
+    to the remaining (unclaimed) layers. Layer lists given for a disabled format are ignored entirely.
+
+    Args:
+        num_layers: Total number of transformer layers in the model.
+        fp8_enabled: Whether FP8 quantization is enabled.
+        fp4_enabled: Whether FP4 quantization is enabled.
+        fp8_layers: 1-indexed list of layers for FP8, or None if not specified.
+        fp4_layers: 1-indexed list of layers for FP4, or None if not specified.
+
+    Returns:
+        QuantizationLayers with both 0-indexed and 1-indexed layer lists.
+
+    Raises:
+        ValueError: If both formats are enabled with no layer lists, or if layer lists overlap.
+    """
+    all_layers = set(range(1, num_layers + 1))
+
+    if fp8_enabled and fp4_enabled and fp8_layers is None and fp4_layers is None:
+        raise ValueError(
+            "Both fp8_config and fp4_config are enabled but neither fp8_layers nor fp4_layers is specified. "
+            "When both are enabled, you must explicitly provide layer lists to indicate which layers use which format."
+        )
+
+    # Clear layer lists of disabled formats FIRST so they cannot claim layers away from the default fill below.
+    if not fp8_enabled:
+        fp8_layers = None
+    if not fp4_enabled:
+        fp4_layers = None
+
+    # When one format has explicit layers and the other defaults, fill in the remaining layers.
+    if fp8_enabled and fp8_layers is None:
+        claimed_by_fp4 = set(fp4_layers) if fp4_layers is not None else set()
+        fp8_layers = sorted(all_layers - claimed_by_fp4)
+        if claimed_by_fp4:
+            logger.warning(
+                f"fp8_config.enabled=True with no fp8_layers specified, but fp4_layers={sorted(claimed_by_fp4)} "
+                f"are already claimed by FP4. Defaulting FP8 to the remaining layers: {fp8_layers}"
+            )
+        else:
+            logger.info(f"fp8_config.enabled=True with no fp8_layers specified, defaulting all {num_layers} layers to FP8")
+
+    if fp4_enabled and fp4_layers is None:
+        claimed_by_fp8 = set(fp8_layers) if fp8_layers is not None else set()
+        fp4_layers = sorted(all_layers - claimed_by_fp8)
+        if claimed_by_fp8:
+            logger.warning(
+                f"fp4_config.enabled=True with no fp4_layers specified, but fp8_layers={sorted(claimed_by_fp8)} "
+                f"are already claimed by FP8. Defaulting FP4 to the remaining layers: {fp4_layers}"
+            )
+        else:
+            logger.info(f"fp4_config.enabled=True with no fp4_layers specified, defaulting all {num_layers} layers to FP4")
+
+    # Validate no overlap between FP8 and FP4 layer assignments.
+    if fp8_layers is not None and fp4_layers is not None:
+        overlap = set(fp8_layers) & set(fp4_layers)
+        if overlap:
+            raise ValueError(
+                f"fp8_layers and fp4_layers cannot have overlapping layer numbers. 
" + f"Found overlap: {sorted(overlap)}" + ) + + return QuantizationLayers( + fp8_layers_0indexed=[layer - 1 for layer in fp8_layers] if fp8_layers is not None else None, + fp4_layers_0indexed=[layer - 1 for layer in fp4_layers] if fp4_layers is not None else None, + fp8_layers_1indexed=fp8_layers, + fp4_layers_1indexed=fp4_layers, + ) diff --git a/bionemo-recipes/recipes/esm2_native_te/requirements.txt b/bionemo-recipes/recipes/esm2_native_te/requirements.txt index b18607fd7..0602ca8a8 100644 --- a/bionemo-recipes/recipes/esm2_native_te/requirements.txt +++ b/bionemo-recipes/recipes/esm2_native_te/requirements.txt @@ -8,6 +8,6 @@ torchdata torchmetrics tqdm transformer_engine[pytorch] -transformers +transformers==4.57.3 wandb nvdlfw_inspect @ git+https://github.com/NVIDIA/nvidia-dlfw-inspect diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_quantization.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_quantization.py new file mode 100644 index 000000000..afe54dc3c --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_quantization.py @@ -0,0 +1,339 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re +import sys +from pathlib import Path + +import pytest +import yaml + +sys.path.append(Path(__file__).parent.parent.as_posix()) + +from quantization import generate_layer_regex, resolve_quantization_layers, update_quant_stats_config + + +class TestResolveQuantizationLayers: + """Tests for resolve_quantization_layers().""" + + def test_fp8_enabled_no_layers_defaults_all(self): + """When fp8 is enabled with no explicit layers, all layers should default to FP8.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None + ) + assert result.fp8_layers_0indexed == [0, 1, 2, 3, 4, 5] + assert result.fp8_layers_1indexed == [1, 2, 3, 4, 5, 6] + assert result.fp4_layers_0indexed is None + assert result.fp4_layers_1indexed is None + + def test_fp4_enabled_no_layers_defaults_all(self): + """When fp4 is enabled with no explicit layers, all layers should default to FP4.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=False, fp4_enabled=True, fp8_layers=None, fp4_layers=None + ) + assert result.fp8_layers_0indexed is None + assert result.fp4_layers_0indexed == [0, 1, 2, 3, 4, 5] + assert result.fp4_layers_1indexed == [1, 2, 3, 4, 5, 6] + + def test_fp8_explicit_layers(self): + """Explicit 1-indexed fp8_layers should be converted to 0-indexed.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=[1, 3, 5], fp4_layers=None + ) + assert result.fp8_layers_0indexed == [0, 2, 4] + assert result.fp8_layers_1indexed == [1, 3, 5] + assert result.fp4_layers_0indexed is None + + def test_fp4_explicit_layers(self): + """Explicit 1-indexed fp4_layers should be converted to 0-indexed.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=False, fp4_enabled=True, fp8_layers=None, fp4_layers=[2, 4, 6] + ) + assert result.fp8_layers_0indexed is None + assert result.fp4_layers_0indexed == [1, 3, 5] + assert 
result.fp4_layers_1indexed == [2, 4, 6] + + def test_mixed_fp8_fp4_explicit(self): + """Both enabled with explicit non-overlapping layers should work correctly.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 3, 4], fp4_layers=[2, 5] + ) + assert result.fp8_layers_0indexed == [0, 2, 3] + assert result.fp8_layers_1indexed == [1, 3, 4] + assert result.fp4_layers_0indexed == [1, 4] + assert result.fp4_layers_1indexed == [2, 5] + + def test_both_enabled_no_layers_raises(self): + """Both enabled with no layer lists should raise ValueError.""" + with pytest.raises(ValueError, match="Both fp8_config and fp4_config are enabled"): + resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=None, fp4_layers=None + ) + + def test_overlapping_layers_raises(self): + """Overlapping layer assignments should raise ValueError.""" + with pytest.raises(ValueError, match="fp8_layers and fp4_layers cannot have overlapping"): + resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2, 3], fp4_layers=[3, 4, 5] + ) + + def test_disabled_ignores_layers(self): + """When a format is disabled, its layers should be None even if provided.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=False, fp4_enabled=False, fp8_layers=[1, 2, 3], fp4_layers=[4, 5, 6] + ) + assert result.fp8_layers_0indexed is None + assert result.fp8_layers_1indexed is None + assert result.fp4_layers_0indexed is None + assert result.fp4_layers_1indexed is None + + def test_both_disabled(self): + """Both disabled with no layers should return all None.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=False, fp4_enabled=False, fp8_layers=None, fp4_layers=None + ) + assert result.fp8_layers_0indexed is None + assert result.fp4_layers_0indexed is None + + def test_large_model_defaults_all(self): + """Auto-population should work correctly for larger 
models (e.g. 36 layers).""" + result = resolve_quantization_layers( + num_layers=36, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None + ) + assert result.fp8_layers_0indexed == list(range(36)) + assert result.fp8_layers_1indexed == list(range(1, 37)) + + def test_fp8_enabled_empty_list(self): + """An explicit empty list should remain empty (not default to all).""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=[], fp4_layers=None + ) + assert result.fp8_layers_0indexed == [] + assert result.fp8_layers_1indexed == [] + + def test_both_enabled_fp8_specified_fp4_defaults_to_remaining(self): + """When both enabled, FP8 has explicit layers, FP4 should default to the remaining layers.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2, 3], fp4_layers=None + ) + assert result.fp8_layers_0indexed == [0, 1, 2] + assert result.fp8_layers_1indexed == [1, 2, 3] + assert result.fp4_layers_0indexed == [3, 4, 5] + assert result.fp4_layers_1indexed == [4, 5, 6] + + def test_both_enabled_fp4_specified_fp8_defaults_to_remaining(self): + """When both enabled, FP4 has explicit layers, FP8 should default to the remaining layers.""" + result = resolve_quantization_layers( + num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=None, fp4_layers=[4, 5, 6] + ) + assert result.fp8_layers_0indexed == [0, 1, 2] + assert result.fp8_layers_1indexed == [1, 2, 3] + assert result.fp4_layers_0indexed == [3, 4, 5] + assert result.fp4_layers_1indexed == [4, 5, 6] + + +class TestGenerateLayerRegex: + """Tests for generate_layer_regex().""" + + def test_single_layer(self): + """Single layer should produce a simple regex.""" + regex = generate_layer_regex([3]) + assert re.search(regex, "model.esm.encoder.layers.3.self_attention.layernorm_qkv") + assert not re.search(regex, "model.esm.encoder.layers.2.self_attention.layernorm_qkv") + + def 
test_multiple_layers(self): + """Multiple layers should match any of them.""" + regex = generate_layer_regex([1, 2, 3]) + assert re.search(regex, "model.esm.encoder.layers.1.self_attention.layernorm_qkv") + assert re.search(regex, "model.esm.encoder.layers.2.layernorm_mlp.fc1") + assert re.search(regex, "model.esm.encoder.layers.3.layernorm_mlp.fc2") + assert not re.search(regex, "model.esm.encoder.layers.4.self_attention.proj") + + def test_matches_correct_sublayers(self): + """Regex should only match layernorm_qkv, proj, fc1, fc2.""" + regex = generate_layer_regex([1]) + assert re.search(regex, "model.esm.encoder.layers.1.self_attention.layernorm_qkv_something") + assert re.search(regex, "model.esm.encoder.layers.1.self_attention.proj_something") + assert re.search(regex, "model.esm.encoder.layers.1.layernorm_mlp.fc1_something") + assert re.search(regex, "model.esm.encoder.layers.1.layernorm_mlp.fc2_something") + # Should not match unrelated sublayer names + assert not re.search(regex, "model.esm.encoder.layers.1.self_attention.some_other_thing") + + def test_none_returns_disabled_pattern(self): + """None should return a pattern that matches nothing.""" + regex = generate_layer_regex(None) + assert "DISABLED" in regex + assert not re.search(regex, "model.esm.encoder.layers.1.self_attention.layernorm_qkv") + + def test_empty_list_returns_disabled_pattern(self): + """Empty list should return a pattern that matches nothing.""" + regex = generate_layer_regex([]) + assert "DISABLED" in regex + + def test_1indexed_layer_names(self): + """Regex should use 1-indexed layer numbers (matching debug API naming).""" + regex = generate_layer_regex([1]) + # Should match layers.1 (1-indexed first layer) + assert re.search(regex, "model.esm.encoder.layers.1.self_attention.layernorm_qkv") + # Should NOT match layers.0 (0-indexed first layer) + assert not re.search(regex, "model.esm.encoder.layers.0.self_attention.layernorm_qkv") + + +class TestUpdateQuantStatsConfig: + """Tests 
for update_quant_stats_config().""" + + @pytest.fixture + def fp8_only_config(self, tmp_path): + """Create an FP8-only stats config file.""" + config = { + "example_fp8_tensor_stat_collection": { + "enabled": True, + "layers": { + "layer_name_regex_pattern": "PLACEHOLDER", + }, + "transformer_engine": { + "LogFp8TensorStats": { + "enabled": True, + "tensors_struct": [{"tensor": "activation", "stats": ["underflows%"], "freq": 10}], + } + }, + } + } + config_path = tmp_path / "fp8_stats.yaml" + with open(config_path, "w") as f: + yaml.dump(config, f) + return str(config_path) + + @pytest.fixture + def fp4_fp8_config(self, tmp_path): + """Create a combined FP4+FP8 stats config file.""" + config = { + "example_fp4_tensor_stat_collection": { + "enabled": True, + "layers": { + "layer_name_regex_pattern": "PLACEHOLDER", + }, + "transformer_engine": { + "LogNvfp4TensorStats": {"enabled": True}, + }, + }, + "example_fp8_tensor_stat_collection": { + "enabled": True, + "layers": { + "layer_name_regex_pattern": "PLACEHOLDER", + }, + "transformer_engine": { + "LogFp8TensorStats": {"enabled": True}, + }, + }, + } + config_path = tmp_path / "fp4_fp8_stats.yaml" + with open(config_path, "w") as f: + yaml.dump(config, f) + return str(config_path) + + def test_fp8_layers_updates_regex(self, fp8_only_config): + """FP8 layer list should update the regex in the output config.""" + output_path = update_quant_stats_config( + config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1, 2, 3] + ) + with open(output_path) as f: + result = yaml.safe_load(f) + regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(regex, "model.esm.encoder.layers.1.self_attention.layernorm_qkv") + assert re.search(regex, "model.esm.encoder.layers.3.layernorm_mlp.fc2") + assert not re.search(regex, "model.esm.encoder.layers.4.self_attention.proj") + + def test_none_layers_disables_matching(self, fp8_only_config): + """None layers should set regex to match 
nothing.""" + output_path = update_quant_stats_config( + config_file=fp8_only_config, fp4_layers=None, fp8_layers=None + ) + with open(output_path) as f: + result = yaml.safe_load(f) + regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert "DISABLED" in regex + + def test_fp4_section_disabled_fp8_still_updated(self, fp4_fp8_config): + """FP4 stats section should be disabled (not yet supported), FP8 should still be updated.""" + output_path = update_quant_stats_config( + config_file=fp4_fp8_config, fp4_layers=[1, 2, 3], fp8_layers=[4, 5, 6] + ) + with open(output_path) as f: + result = yaml.safe_load(f) + + # FP4 section should be disabled + assert result["example_fp4_tensor_stat_collection"]["enabled"] is False + + # FP8 regex should still match layers 4-6 + fp8_regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(fp8_regex, "model.esm.encoder.layers.5.self_attention.proj") + assert not re.search(fp8_regex, "model.esm.encoder.layers.2.self_attention.proj") + + def test_original_file_not_modified(self, fp8_only_config): + """update_quant_stats_config should write to a temp file, not modify the original.""" + with open(fp8_only_config) as f: + original_content = f.read() + + output_path = update_quant_stats_config( + config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1, 2] + ) + + assert output_path != fp8_only_config + with open(fp8_only_config) as f: + assert f.read() == original_content + + def test_preserves_other_config_fields(self, fp8_only_config): + """Non-layer fields in the config should be preserved.""" + output_path = update_quant_stats_config( + config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1] + ) + with open(output_path) as f: + result = yaml.safe_load(f) + # The transformer_engine section should still be there + assert result["example_fp8_tensor_stat_collection"]["transformer_engine"]["LogFp8TensorStats"]["enabled"] is True + + def 
test_missing_section_is_skipped(self, fp8_only_config): + """If fp4 section doesn't exist in config, it should be silently skipped.""" + # fp8_only_config has no fp4 section — passing fp4_layers should not error + output_path = update_quant_stats_config( + config_file=fp8_only_config, fp4_layers=[1, 2], fp8_layers=[3, 4] + ) + with open(output_path) as f: + result = yaml.safe_load(f) + # Only FP8 section should exist and be updated + assert "example_fp4_tensor_stat_collection" not in result + regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(regex, "model.esm.encoder.layers.3.self_attention.layernorm_qkv") + + def test_with_real_fp4_config(self): + """Test with the actual fp4_debugging_stats.yaml file.""" + config_path = Path(__file__).parent.parent / "fp4_debugging_stats.yaml" + if not config_path.exists(): + pytest.skip("fp4_debugging_stats.yaml not found") + + output_path = update_quant_stats_config( + config_file=str(config_path), fp4_layers=[1, 2, 3], fp8_layers=[4, 5, 6] + ) + with open(output_path) as f: + result = yaml.safe_load(f) + + # FP4 section should be disabled (not yet supported in current TE release) + assert result["example_fp4_tensor_stat_collection"]["enabled"] is False + + # FP8 section should still be updated and working + fp8_regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] + assert re.search(fp8_regex, "model.esm.encoder.layers.5.self_attention.proj") + assert not re.search(fp8_regex, "model.esm.encoder.layers.2.self_attention.proj") diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_train.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_train.py index 32b1e99ef..80aef834b 100644 --- a/bionemo-recipes/recipes/esm2_native_te/tests/test_train.py +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_train.py @@ -154,8 +154,8 @@ def test_sanity_ddp_fp8_stats_logging(tmp_path, recipe_path): f"+wandb_init_args.dir={tmp_path}", 
f"checkpoint.ckpt_dir={tmp_path}", "fp8_config.enabled=true", - "fp8_stats_config.enabled=true", - f"fp8_stats_config.fp8_log_dir={fp8_log_dir}", + "quant_stats_config.enabled=true", + f"quant_stats_config.quant_log_dir={fp8_log_dir}", "num_train_steps=4", ], ) @@ -211,8 +211,8 @@ def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path): f"+wandb_init_args.dir={tmp_path}", f"checkpoint.ckpt_dir={tmp_path}", "fp8_config.enabled=true", - "fp8_stats_config.enabled=true", - f"fp8_stats_config.fp8_log_dir={fp8_log_dir}", + "quant_stats_config.enabled=true", + f"quant_stats_config.quant_log_dir={fp8_log_dir}", "num_train_steps=4", ], ) diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py index 1027703f3..9d722bbd2 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py @@ -53,21 +53,21 @@ def main(args: DictConfig) -> float | None: torch.cuda.set_device(dist_config.local_rank) # TE Debug feature logging - if args.fp8_stats_config.enabled and not args.fp8_config.enabled: + if args.quant_stats_config.enabled and not args.fp8_config.enabled: raise ValueError( "fp8_stats_config.enabled is true but fp8_config.enabled is false, please set fp8_config.enabled to true in the config if you wish to collect FP8 stats" ) - if args.fp8_stats_config.enabled: - fp8_stats_file = args.fp8_stats_config.fp8_stats_file - fp8_log_dir = Path(args.fp8_stats_config.fp8_log_dir) / f"rank_{dist_config.rank}" - fp8_log_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"Logging FP8 stats to {fp8_log_dir}") + if args.quant_stats_config.enabled: + quant_stats_file = args.quant_stats_config.quant_stats_file + quant_log_dir = Path(args.quant_stats_config.quant_log_dir) / f"rank_{dist_config.rank}" + quant_log_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Logging quant stats to {quant_log_dir}") te_features_dir = 
str(Path(transformer_engine.__file__).parent / "debug" / "features") debug_api.initialize( - config_file=fp8_stats_file, + config_file=quant_stats_file, feature_dirs=[te_features_dir], - log_dir=fp8_log_dir, + log_dir=quant_log_dir, default_logging_enabled=True, ) # Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2 @@ -104,7 +104,7 @@ def main(args: DictConfig) -> float | None: optimizer = AdamW(model.parameters(), **args.adamw_kwargs) scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) model = model.to(device=device) @@ -157,7 +157,7 @@ def main(args: DictConfig) -> float | None: loss = outputs.loss loss.backward() - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.step() # Compute and clip gradient norms. total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() @@ -206,7 +206,7 @@ def main(args: DictConfig) -> float | None: # Clean up distributed training perf_logger.finish() - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.end_debug() torch.distributed.destroy_process_group() diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 28409e0c1..5e21a23e0 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -20,15 +20,20 @@ import hydra import nvdlfw_inspect.api as debug_api import torch +import torch.cuda.nvtx as nvtx import transformer_engine import transformer_engine.pytorch + from omegaconf import DictConfig, OmegaConf from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import fully_shard +from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy from torch.optim 
import AdamW + from transformer_engine.common.recipe import Format from transformers import AutoConfig, AutoModelForMaskedLM +from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM + # This import seems to be needed with meta device init and AutoModel.from_config from transformers.models.esm.modeling_esm import EsmForMaskedLM # noqa: F401 @@ -36,6 +41,7 @@ from dataset import create_bshd_dataloader, create_thd_dataloader from distributed_config import DistributedConfig from perf_logger import PerfLogger +from quantization import initialize_quant_stats_logging, resolve_quantization_layers from scheduler import get_linear_schedule_with_warmup @@ -57,23 +63,29 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) - # TE Debug feature logging - MUST be done BEFORE FSDP wrapping - if args.fp8_stats_config.enabled and not args.fp8_config.enabled: - raise ValueError( - "fp8_stats_config.enabled is true but fp8_config.enabled is false, please set fp8_config.enabled to true in the config if you wish to collect FP8 stats" - ) - - if args.fp8_stats_config.enabled: - fp8_stats_file = args.fp8_stats_config.fp8_stats_file - fp8_log_dir = Path(args.fp8_stats_config.fp8_log_dir) / f"rank_{dist_config.rank}" - fp8_log_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"Logging FP8 stats to {fp8_log_dir}") - te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features") - debug_api.initialize( - config_file=fp8_stats_file, - feature_dirs=[te_features_dir], - log_dir=fp8_log_dir, - default_logging_enabled=True, + # Load model config early so we know the number of layers for auto-populating layer lists. + config = NVEsmConfig.from_pretrained( + args.model_tag, dtype=torch.float32 if args.use_fp32_master_weights else torch.bfloat16 + ) + num_layers = config.num_hidden_layers + + # Resolve layer-wise quantization assignments. 
+ quant_layers = resolve_quantization_layers( + num_layers=num_layers, + fp8_enabled=args.fp8_config.enabled, + fp4_enabled=args.fp4_config.enabled, + fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None, + fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None, + ) + fp8_layers = quant_layers.fp8_layers_0indexed + fp4_layers = quant_layers.fp4_layers_0indexed + + if args.quant_stats_config.enabled: + initialize_quant_stats_logging( + quant_stats_file=args.quant_stats_config.quant_stats_file, + quant_log_dir=args.quant_stats_config.quant_log_dir, + rank=dist_config.rank, + quant_layers=quant_layers, ) # Create a device mesh for FSDP. @@ -84,12 +96,16 @@ def main(args: DictConfig) -> float | None: ) # Create an FP8 recipe -- this is only used if FP8 is enabled in the config. - fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( - fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs - ) + if args.fp8_config.enabled: + fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( + fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + ) + + if args.fp4_config.enabled: + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) - # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". - config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, dtype=torch.bfloat16) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. if args.use_sequence_packing: config.attn_input_format = "thd" @@ -99,19 +115,41 @@ def main(args: DictConfig) -> float | None: # versions of weights are kept. 
with ( torch.device("meta") if args.use_meta_device else nullcontext(), - transformer_engine.pytorch.fp8_model_init(recipe=fp8_recipe, **args.fp8_config.fp8_model_init_kwargs), ): - model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + # model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + model = NVEsmForMaskedLM(config) logger.info("Initialized Model:\n%s", model) # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models. transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer - for layer in transformer_stack: - fully_shard(layer, mesh=device_mesh["dp"]) - fully_shard(model, mesh=device_mesh["dp"]) + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, # Cast params to BF16 for forward/backward + reduce_dtype=torch.float32, # Gradient reductions in FP32 + output_dtype=torch.bfloat16, # Forward output dtype + ) + if args.use_fp32_master_weights: + for layer in transformer_stack: + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) + fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) + else: + for layer in transformer_stack: + fully_shard(layer, mesh=device_mesh["dp"]) + fully_shard(model, mesh=device_mesh["dp"]) + # Create a layer map for the transformer stack. 
+ layer_number_quantized_recipe_map = {} + fp8_layers_set = set(fp8_layers) if fp8_layers else set() + fp4_layers_set = set(fp4_layers) if fp4_layers else set() + for layer_number, layer in enumerate(transformer_stack): + if layer_number in fp8_layers_set: + layer_number_quantized_recipe_map[layer_number] = fp8_recipe + elif layer_number in fp4_layers_set: + layer_number_quantized_recipe_map[layer_number] = fp4_recipe + else: + layer_number_quantized_recipe_map[layer_number] = None + model.esm.encoder.layer_number_quantized_recipe_map = layer_number_quantized_recipe_map # If we're using meta device, we need to move sharded weights to the cuda device and initialize the parameters. # Note, this should happen before we create the optimizer. if args.use_meta_device: @@ -123,11 +161,12 @@ def main(args: DictConfig) -> float | None: model.apply(model._init_weights) # Assign names to layers so debug API can identify them - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore + # Note: Got an error about mixed torch.Tensor and DTensor here, so using AdamW instead. scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) # If we're using sequence packing, create a THD dataloader, otherwise create a BSHD dataloader. @@ -158,31 +197,60 @@ def main(args: DictConfig) -> float | None: perf_logger = PerfLogger(dist_config, args) + # Nsight Systems profiling setup. 
+ nsys_enabled = args.nsys_profiling.enabled + nsys_start_step = args.nsys_profiling.start_step if nsys_enabled else -1 + nsys_end_step = args.nsys_profiling.end_step if nsys_enabled else -1 + nsys_ranks = set(OmegaConf.to_container(args.nsys_profiling.ranks, resolve=True)) if nsys_enabled else set() + nsys_profiling_active = False + + if nsys_enabled and dist_config.rank in nsys_ranks: + logger.info( + f"Nsight profiling enabled for rank {dist_config.rank}: " + f"will capture steps [{nsys_start_step}, {nsys_end_step})" + ) + # Training loop step = start_step while step < args.num_train_steps: for batch in train_dataloader: - batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 + # --- Nsys: start profiler at the configured step --- + if nsys_enabled and step == nsys_start_step and dist_config.rank in nsys_ranks: + logger.info(f"[Rank {dist_config.rank}] Starting nsys capture at step {step}") + torch.cuda.cudart().cudaProfilerStart() + nsys_profiling_active = True - # Forward pass with mixed precision. - with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe): + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 + + # --- Forward pass --- + nvtx.range_push(f"step_{step}") + nvtx.range_push("forward") + with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe if args.fp8_config.enabled else None): outputs = model(**batch) + nvtx.range_pop() # forward - # Backward pass. + # --- Backward pass --- + nvtx.range_push("backward") loss = outputs.loss loss.backward() + nvtx.range_pop() # backward - # Compute and clip gradient norms. + # --- Grad clip --- + nvtx.range_push("clip_grad_norm") total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item() + nvtx.range_pop() # clip_grad_norm - # Step optimizer. 
+ # --- Optimizer step --- + nvtx.range_push("optimizer_step") optimizer.step() scheduler.step() + nvtx.range_pop() # optimizer_step - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.step() optimizer.zero_grad() + nvtx.range_pop() # step_N perf_logger.log_step( step=step, @@ -205,6 +273,12 @@ def main(args: DictConfig) -> float | None: max_checkpoints=args.checkpoint.max_checkpoints, ) + # --- Nsys: stop profiler at the configured step --- + if nsys_profiling_active and step >= nsys_end_step: + logger.info(f"[Rank {dist_config.rank}] Stopping nsys capture at step {step}") + torch.cuda.cudart().cudaProfilerStop() + nsys_profiling_active = False + step += 1 if step >= args.num_train_steps: break @@ -221,9 +295,15 @@ def main(args: DictConfig) -> float | None: dist_config=dist_config, ) + # Ensure nsys profiler is stopped if training ended before end_step. + if nsys_profiling_active: + logger.info(f"[Rank {dist_config.rank}] Stopping nsys capture at end of training (step {step})") + torch.cuda.cudart().cudaProfilerStop() + nsys_profiling_active = False + # Clean up distributed training perf_logger.finish() - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.end_debug() torch.distributed.destroy_process_group() diff --git a/docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-mxfp8-10node-conv.svg b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-mxfp8-10node-conv.svg new file mode 100644 index 000000000..5880bb7cb --- /dev/null +++ b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-mxfp8-10node-conv.svg @@ -0,0 +1,118 @@ +
5001k1.5k2kStep1313.51414.51515.516
\ No newline at end of file diff --git a/docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-nvfp4-10node-conv.svg b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-nvfp4-10node-conv.svg new file mode 100644 index 000000000..9f43a414a --- /dev/null +++ b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-nvfp4-10node-conv.svg @@ -0,0 +1,118 @@ +
5001k1.5k2kStep1313.51414.51515.516
\ No newline at end of file diff --git a/docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-mxfp8-6node-conv.png b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-mxfp8-6node-conv.png new file mode 100644 index 000000000..2a71d80f9 Binary files /dev/null and b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-mxfp8-6node-conv.png differ diff --git a/docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-nvfp4-6node-conv.png b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-nvfp4-6node-conv.png new file mode 100644 index 000000000..476689147 Binary files /dev/null and b/docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-nvfp4-6node-conv.png differ diff --git a/docs/docs/assets/images/esm2/esm2_low_precision/esm2_8gpu_tflops.png b/docs/docs/assets/images/esm2/esm2_low_precision/esm2_8gpu_tflops.png new file mode 100644 index 000000000..d89a4e615 Binary files /dev/null and b/docs/docs/assets/images/esm2/esm2_low_precision/esm2_8gpu_tflops.png differ