diff --git a/bionemo-recipes/recipes/esm2_native_te/Dockerfile b/bionemo-recipes/recipes/esm2_native_te/Dockerfile index b940874af3..71a793b1b7 100644 --- a/bionemo-recipes/recipes/esm2_native_te/Dockerfile +++ b/bionemo-recipes/recipes/esm2_native_te/Dockerfile @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1.4 -FROM nvcr.io/nvidia/pytorch:25.12-py3 +FROM nvcr.io/nvidia/pytorch:25.11-py3 RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=requirements.txt,target=/requirements.txt \ diff --git a/bionemo-recipes/recipes/esm2_native_te/Dockerfile.te_testing b/bionemo-recipes/recipes/esm2_native_te/Dockerfile.te_testing new file mode 100644 index 0000000000..c388ddf563 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/Dockerfile.te_testing @@ -0,0 +1,22 @@ +# syntax=docker/dockerfile:1.4 +FROM nvcr.io/nvidia/pytorch:25.12-py3 + +# Install sccache for faster builds +RUN --mount=type=cache,target=/var/lib/apt/lists,sharing=locked \ + --mount=type=cache,target=/var/cache/apt,sharing=locked \ + apt-get update \ + && apt-get install -y sccache \ + && rm -rf /var/lib/apt/lists/* + +# Uninstall pre-installed Transformer Engine and install from source +RUN pip uninstall -y transformer-engine && \ + NVTE_USE_CCACHE=1 NVTE_CCACHE_BIN=sccache NVTE_FRAMEWORK=pytorch NVTE_BUILD_DEBUG=1 \ + pip install -v --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@main + +# Install BioNeMo requirements +WORKDIR /workspace/bionemo +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=requirements.txt,target=/requirements.txt \ + PIP_CONSTRAINT= pip install -r /requirements.txt + +COPY . . 
diff --git a/bionemo-recipes/recipes/esm2_native_te/dataset.py b/bionemo-recipes/recipes/esm2_native_te/dataset.py index c915f30ea8..78def1ef96 100644 --- a/bionemo-recipes/recipes/esm2_native_te/dataset.py +++ b/bionemo-recipes/recipes/esm2_native_te/dataset.py @@ -42,6 +42,7 @@ def create_tokenized_dataset( max_seq_length: int = 1024, buffer_size: int = 10_000, use_lazy_tokenization: bool = True, + tokenizer_revision: str | None = None, ): """Create a tokenized dataset.""" logger.info(f"Loading dataset with kwargs: {load_dataset_kwargs}") @@ -56,7 +57,7 @@ def create_tokenized_dataset( ) dataset = dataset.shuffle(seed=42, buffer_size=buffer_size) - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision=tokenizer_revision) def tokenize_function(examples): """Tokenize the protein sequences.""" @@ -167,6 +168,7 @@ def create_thd_dataloader( use_stateful_dataloader: bool = False, mlm_probability: float = 0.15, pad_sequences_to_be_divisible_by: int | None = None, + tokenizer_revision: str | None = None, ): """Create a dataloader that packs up to the maximum number of tokens per batch. @@ -186,7 +188,7 @@ def create_thd_dataloader( mlm_probability: The probability of masking tokens for MLM (default 0.15). Set to 0 for no masking. pad_sequences_to_be_divisible_by: If provided, sequences will be padded to be divisible by this value. This is useful for context parallelism. Defaults to None. - + tokenizer_revision: The revision of the tokenizer to use. Defaults to None. Returns: A dataloader that can be used for training. """ @@ -196,6 +198,7 @@ def create_thd_dataloader( load_dataset_kwargs=load_dataset_kwargs, max_seq_length=max_seq_length, buffer_size=buffer_size, + tokenizer_revision=tokenizer_revision, ) assert isinstance(tokenized_dataset, datasets.IterableDataset), "THD token packing requires a streaming dataset."
diff --git a/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml new file mode 100644 index 0000000000..d56739a6a6 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml @@ -0,0 +1,33 @@ +example_fp4_tensor_stat_collection: + enabled: True + layers: + # Use regex to select layers 0-4 (1-indexed as layers.1 through layers.5 in the naming) + # This matches: model.esm.encoder.layers.[1-5].*.(layernorm_qkv|proj|fc1|fc2) + layer_name_regex_pattern: 'model\.esm\.encoder\.layers\.[1-5]\..*(layernorm_qkv|proj|fc1|fc2)' + transformer_engine: + LogNvfp4TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [underflows%, mse] + freq: 100 + - tensor: gradient + stats: [underflows%, mse] + freq: 100 + +example_fp8_tensor_stat_collection: + enabled: True + layers: + # Use regex to select layers 0-4 (1-indexed as layers.1 through layers.5 in the naming) + # This matches: model.esm.encoder.layers.[1-5].*.(layernorm_qkv|proj|fc1|fc2) + layer_name_regex_pattern: 'model\.esm\.encoder\.layers\.([6-9]|10)\..*(layernorm_qkv|proj|fc1|fc2)' + transformer_engine: + LogFp8TensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] + freq: 100 + - tensor: gradient + stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse] + freq: 100 diff --git a/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml index 7544bbedcf..9653d8a044 100644 --- a/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml @@ -2,7 +2,7 @@ example_fp8_tensor_stat_collection: enabled: True layers: # Match the actual linear layers within attention that support FP8 stats - layer_types: [layernorm_qkv] + layer_types: 
[layernorm_qkv, proj, fc1, fc2] transformer_engine: LogFp8TensorStats: enabled: True @@ -16,3 +16,8 @@ example_fp8_tensor_stat_collection: - tensor: weight stats: [underflows%, scale_inv_min, scale_inv_max, mse] freq: 10 + LogTensorStats: + enabled: True + stats: [max, min, mean, std, l1_norm] + tensors: [dgrad, wgrad, fprop] + freq: 1 diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml index 0b91c5608a..2b6f602e3e 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml @@ -8,6 +8,7 @@ num_train_steps: 500 dataset: micro_batch_size: 12 + tokenizer_revision: "f29e20d2b10d0aba2036831df65cdca1befe926f" # WandB config wandb_init_args: diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml index e8e47d908f..3e055907ca 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml @@ -8,6 +8,7 @@ num_train_steps: 10_000 dataset: micro_batch_size: 16 + tokenizer_revision: "86a86f18e6bb1eb4bcf91c594e1c0ad446d8eec6" # WandB config wandb_init_args: diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml index fd027601df..fc1153dc31 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml @@ -8,7 +8,7 @@ num_train_steps: 200 dataset: micro_batch_size: 4 - + tokenizer_revision: "d81c2e5aec37b5e794d0482e3996fb045a137792" # WandB config wandb_init_args: name: "esm2_t33_650M_UR50D" @@ -17,3 +17,43 @@ wandb_init_args: checkpoint: ckpt_dir: "checkpoints/esm2_t33_650M_UR50D_sanity" + +# Note: The layers are going to come 
in 1 indexed and we convert them to be 0 indexed at runtime. +fp8_layers: + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 27 + - 28 + - 29 + - 30 + - 31 + - 32 + - 33 + +fp4_layers: + - 9 + - 10 + - 11 + - 12 + - 13 + - 14 + - 15 + - 16 + - 17 + - 18 + - 19 + - 20 + - 21 + - 22 + - 23 + - 24 + - 25 + - 26 + +use_fp32_master_weights: true \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml index baace7c805..0cbc271212 100644 --- a/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml +++ b/bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml @@ -13,6 +13,7 @@ cp_size: 1 use_sequence_packing: false dataset: tokenizer_name: ${model_tag} + tokenizer_revision: null micro_batch_size: ??? num_workers: 1 max_seq_length: 1024 @@ -51,6 +52,14 @@ fp8_config: fp8_model_init_kwargs: enabled: false # If this is set to true, fp8_config.enabled must also be set to true. +fp4_config: + enabled: false + fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling + fp4_format: "E2M1" + fp4_recipe_kwargs: {} + fp4_model_init_kwargs: + enabled: false # If this is set to true, fp4_config.enabled must also be set to true. + # Optimizer config adamw_kwargs: lr: 4e-4 @@ -76,7 +85,13 @@ checkpoint: logger: frequency: 100 -fp8_stats_config: + +quant_stats_config: enabled: false - fp8_stats_file: ./fp8_debugging_stats.yaml - fp8_log_dir: ./log_fp8_stats + quant_stats_file: ./fp8_debugging_stats.yaml + quant_log_dir: ./log_quant_stats + +# Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime. 
+fp8_layers: null +fp4_layers: null +use_fp32_master_weights: null \ No newline at end of file diff --git a/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py new file mode 100644 index 0000000000..1932c0f4c7 --- /dev/null +++ b/bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py @@ -0,0 +1,683 @@ +# noqa: license-check +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""TransformerEngine-optimized ESM model. + +Adapted from `modeling_esm.py` in huggingface/transformers. +""" + +from typing import Literal, Optional, Unpack + +# TODO: put import guard around transformer_engine here, with an informative error message around +# installation and the nvidia docker container. 
+import torch +import transformer_engine.pytorch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + MaskedLMOutput, + TokenClassifierOutput, +) +import transformer_engine.common.recipe +from transformers.models.esm.configuration_esm import EsmConfig +from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel +from transformers.utils import logging +from transformers.utils.generic import TransformersKwargs +from contextlib import nullcontext + +logger = logging.get_logger(__name__) + +# Dictionary that gets inserted into config.json to map Auto** classes to our TE-optimized model classes defined below. +# These should be prefixed with esm_nv., since we name the file esm_nv.py in our exported checkpoints. +AUTO_MAP = { + "AutoConfig": "esm_nv.NVEsmConfig", + "AutoModel": "esm_nv.NVEsmModel", + "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM", + "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", +} + +# From https://github.com/NVIDIA/TransformerEngine/blob/3ceb248e01a2c0dc1215fe0f46ebc235f289ba0d/transformer_engine/common/recipe/__init__.py#L86 +FP8_RECIPES = (transformer_engine.common.recipe.MXFP8BlockScaling, + transformer_engine.common.recipe.DelayedScaling, + transformer_engine.common.recipe.Float8CurrentScaling, + transformer_engine.common.recipe.Float8BlockScaling) +# Trailing comma is required: without it this is a bare class, not a 1-tuple. +FP4_RECIPES = (transformer_engine.common.recipe.NVFP4BlockScaling,) + + +class NVEsmConfig(EsmConfig): + """NVEsmConfig is a configuration for the NVEsm model.""" + + model_type: str = "nv_esm" + + def __init__( + self, + qkv_weight_interleaved: bool = True, + encoder_activation: str = "gelu", + attn_input_format: Literal["bshd", "thd"] = "bshd", + fuse_qkv_params: bool = True, + micro_batch_size: Optional[int] = None, + max_seq_length: Optional[int] = None, + 
padded_vocab_size: Optional[int] = 64, + attn_mask_type: str = "padding", + bf16_layers: Optional[list[int]] = None, + **kwargs, + ): + """Initialize the NVEsmConfig with additional TE-related config options. + + Args: + qkv_weight_interleaved: Whether to interleave the qkv weights. If set to `False`, the + QKV weight is interpreted as a concatenation of query, key, and value weights along + the `0th` dimension. The default interpretation is that the individual `q`, `k`, and + `v` weights for each attention head are interleaved. This parameter is set to `False` + when using :attr:`fuse_qkv_params=False`. + encoder_activation: The activation function to use in the encoder. + attn_input_format: The input format to use for the attention. This controls + whether the dimensions of the intermediate hidden states is 'batch first' + ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length, + `b` batch size, `h` the number of heads, `d` head size. Note that these + formats are very closely related to the `qkv_format` in the + `MultiHeadAttention` and `DotProductAttention` modules. + fuse_qkv_params: Whether to fuse the qkv parameters. If set to `True`, + `TransformerLayer` module exposes a single fused parameter for query-key-value. + This enables optimizations such as QKV fusion without concatentations/splits and + also enables the argument `fuse_wgrad_accumulation`. + micro_batch_size: The micro batch size to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propogation and activation recompute phase. + max_seq_length: The maximum sequence length to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propogation and activation recompute phase. + padded_vocab_size: The padded vocabulary size to support FP8. 
If not provided, defaults + to vocab_size. Must be greater than or equal to vocab_size. + attn_mask_type: The type of attention mask to use. + **kwargs: Additional config options to pass to EsmConfig. + """ + super().__init__(**kwargs) + # Additional TE-related config options. + self.qkv_weight_interleaved = qkv_weight_interleaved + self.encoder_activation = encoder_activation + self.attn_input_format = attn_input_format + self.fuse_qkv_params = fuse_qkv_params + self.micro_batch_size = micro_batch_size + self.max_seq_length = max_seq_length + self.attn_mask_type = attn_mask_type + self.bf16_layers = bf16_layers + # Set padded_vocab_size with default fallback to vocab_size + self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size + + # Ensure padded_vocab_size is at least as large as vocab_size + if self.padded_vocab_size is not None and self.vocab_size is not None: + assert self.padded_vocab_size >= self.vocab_size, ( + f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to vocab_size ({self.vocab_size})" + ) + + +class NVEsmEncoder(nn.Module): + """NVEsmEncoder is a TransformerEngine-optimized ESM encoder.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmEncoder. + + Args: + config (NVEsmConfig): The configuration of the model. 
+ """ + super().__init__() + self.config = config + + def _init_method(x): + torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) + + self.layers = nn.ModuleList( + [ + transformer_engine.pytorch.TransformerLayer( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + layernorm_epsilon=config.layer_norm_eps, + hidden_dropout=config.hidden_dropout_prob, + attention_dropout=config.attention_probs_dropout_prob, + qkv_weight_interleaved=config.qkv_weight_interleaved, + layer_number=i + 1, + layer_type="encoder", + self_attn_mask_type=config.attn_mask_type, + activation=config.encoder_activation, + attn_input_format=config.attn_input_format, + seq_length=config.max_seq_length, + micro_batch_size=config.micro_batch_size, + num_gqa_groups=config.num_attention_heads, + fuse_qkv_params=config.fuse_qkv_params, + params_dtype=config.dtype, + window_size=(-1, -1), + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=_init_method, + output_layer_init_method=_init_method, + ) + for i in range(config.num_hidden_layers) + ] + ) + self.layer_number_quantized_recipe_map = None + self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.position_embedding_type == "rotary": + self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + """Forward pass of the NVEsmEncoder. + + Args: + hidden_states (torch.Tensor): The hidden states. + attention_mask (torch.Tensor): The attention mask. + **kwargs: Additional arguments, see TransformersKwargs for more details. 
+ """ + all_hidden_states: tuple[torch.Tensor, ...] = () + + if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1: + # For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE + # expects a 2-dimensional tensor with shape [total_tokens, hidden_size]. + hidden_states = hidden_states.squeeze(0) + + # Ensure that rotary embeddings are computed at a higher precision outside the torch autocast context. + with torch.autocast(device_type="cuda", enabled=False): + te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) + te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) + + # Utilize the layer number quantized recipe map to determine the context for each layer. + for layer_number, layer_module in enumerate(self.layers): + # The map may still be None (its __init__ default) when no per-layer quantization is configured. + fp_recipe = (self.layer_number_quantized_recipe_map or {}).get(layer_number) + + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + # Select the quantization context for this layer based on its configured recipe. + # If BF16 desired --> use autocast(false) so it goes to BF16. + # If FP8 desired --> use nullcontext so it uses upper context manager to FP8. + # If FP4 desired --> use autocast(true, recipe=fp4_recipe) so it uses FP4. + if isinstance(fp_recipe, FP8_RECIPES): + fp_context = nullcontext() + elif isinstance(fp_recipe, FP4_RECIPES): + fp_context = transformer_engine.pytorch.autocast(enabled=True, recipe=fp_recipe) + else: + fp_context = transformer_engine.pytorch.autocast(enabled=False) + # TODO(@jomitchell): Double check that this works, make a function for it then unit test it. 
+ + with fp_context: + hidden_states = layer_module( + hidden_states, + attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + + hidden_states = self.emb_layer_norm_after(hidden_states) + + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states if all_hidden_states else None, + ) + + +class NVEsmPreTrainedModel(EsmPreTrainedModel): + """An abstract class to handle weights initialization and pretrained model loading.""" + + config_class = NVEsmConfig + base_model_prefix = "esm" + supports_gradient_checkpointing = False + accepts_loss_kwargs = False + _no_split_modules = ( + "TransformerLayer", + "EsmEmbeddings", + ) + + def init_empty_weights(self): + """Handles moving the model from the meta device to the cuda device and initializing the weights.""" + # For TE layers, calling `reset_parameters` is sufficient to move them to the cuda device and apply the weight + # initialization we passed them during module creation. + for module in self.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard + # deviation. + self.esm.embeddings.word_embeddings.to_empty(device="cuda") + self.esm.embeddings.apply(self._init_weights) + + # Meta-device init seems to break weight tying, so we re-tie the weights here. 
+ self.tie_weights() + + @classmethod + def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool): + """Override the default get_init_context method to allow for fp8 model initialization.""" + return [] + + +class NVEsmModel(NVEsmPreTrainedModel): + """The ESM Encoder-only protein language model. + + This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. + """ + + def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + """Initialize a NVEsmModel. + + Args: + config (NVEsmConfig): The configuration of the model. + add_pooling_layer (bool): Whether to add a pooling layer. + """ + super().__init__(config) + self.config = config + + # Ensure pad_token_id is set properly, defaulting to 0 if not specified + if not hasattr(config, "pad_token_id") or config.pad_token_id is None: + config.pad_token_id = 0 + self.embeddings = NVEsmEmbeddings(config) + self.encoder = NVEsmEncoder(config) + self.pooler = EsmPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + """Get the input embeddings of the model.""" + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value: torch.Tensor): + """Set the input embeddings of the model. + + Args: + value (torch.Tensor): The input embeddings. + """ + self.embeddings.word_embeddings = value + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + """Forward pass of the NVEsmModel. + + Args: + input_ids (torch.Tensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.Tensor): The position ids. + inputs_embeds (torch.Tensor): The input embeddings. 
+ **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + BaseModelOutputWithPooling: The output of the model. + """ + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # TE expects a boolean attention mask, where 1s are masked and 0s are not masked + extended_attention_mask = extended_attention_mask < -1 + + embedding_output = self.embeddings( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + **kwargs, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +class NVEsmForMaskedLM(NVEsmPreTrainedModel): + """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" + + _tied_weights_keys = ("lm_head.decoder.weight",) + + def 
__init__(self, config: NVEsmConfig): + """Initialize a NVEsmForMaskedLM. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.lm_head = NVEsmLMHead(config) + + self.init_weights() + self.post_init() + + def get_output_embeddings(self): + """Get the output embeddings of the model.""" + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + """Set the output embeddings of the model.""" + self.lm_head.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MaskedLMOutput: + """Forward pass of the NVEsmForMaskedLM. + + Args: + input_ids (torch.LongTensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.LongTensor): The position ids. + inputs_embeds (torch.FloatTensor): The input embeddings. + labels (torch.LongTensor): The labels. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + MaskedLMOutput: The output of the model. 
+ """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + # Truncate logits back to original vocab_size if padding was used + if self.config.padded_vocab_size != self.config.vocab_size: + prediction_scores = prediction_scores[..., : self.config.vocab_size] + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.to(prediction_scores.device).view(-1), + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + ) + + +class NVEsmLMHead(nn.Module): + """ESM Head for masked language modeling using TransformerEngine.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmLMHead. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__() + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + with transformer_engine.pytorch.fp8_model_init(enabled=False): + self.decoder = transformer_engine.pytorch.LayerNormLinear( + config.hidden_size, + config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, + bias=True, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + def forward(self, features, **kwargs): + """Forward pass of the NVEsmLMHead. + + Args: + features (torch.Tensor): The features. 
+ **kwargs: Additional arguments. + """ + with transformer_engine.pytorch.autocast(enabled=False): + x = self.dense(features) + x = torch.nn.functional.gelu(x) + x = self.decoder(x) + return x + + +class NVEsmEmbeddings(nn.Module): + """Modified version of EsmEmbeddings to support THD inputs.""" + + def __init__(self, config): + """Initialize a NVEsmEmbeddings.""" + super().__init__() + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id, + dtype=config.dtype, + ) + + self.layer_norm = ( + transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.emb_layer_norm_before + else None + ) + + if config.position_embedding_type != "rotary": + raise ValueError( + "The TE-accelerated ESM-2 model only supports rotary position embeddings, received " + f"{config.position_embedding_type}" + ) + + self.padding_idx = config.pad_token_id + self.token_dropout = config.token_dropout + self.mask_token_id = config.mask_token_id + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + **kwargs: Unpack[TransformersKwargs], + ): + """Forward pass of the NVEsmEmbeddings.""" + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an + # embedding_scale factor here. + embeddings = inputs_embeds + + if ( + kwargs.get("cu_seq_lens_q") is not None + and kwargs.get("cu_seq_lens_k") is not None + and kwargs.get("max_length_q") is not None + and kwargs.get("max_length_k") is not None + ): + using_thd = True + attention_mask = None + else: + using_thd = False + + # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout + # flag is False then it is handled in the same was as BERT/RoBERTa. 
If it is set to True, however, + # masked tokens are treated as if they were selected for input dropout and zeroed out. + # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by + # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). + # This is analogous to the way that dropout layers scale down outputs during evaluation when not + # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). + if self.token_dropout and input_ids is not None: + embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0) + mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs + + if not using_thd: + # BSHD token dropout correction + src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1] + n_masked_per_seq = (input_ids == self.mask_token_id).sum(-1).float() + mask_ratio_observed = n_masked_per_seq / src_lengths + scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) + embeddings = (embeddings * scale_factor[:, None, None]).to(embeddings.dtype) + + else: + src_lengths = torch.diff(kwargs["cu_seq_lens_q"]) + # We need to find the number of masked tokens in each sequence in the padded batch. 
+ is_masked = (input_ids == self.mask_token_id).squeeze(0) + n_masked_per_seq = torch.nested.nested_tensor_from_jagged( + is_masked, offsets=kwargs["cu_seq_lens_q"] + ).sum(1) + mask_ratio_observed = n_masked_per_seq.float() / src_lengths + scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) + embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + + if attention_mask is not None: + embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) + + return embeddings + + +class NVEsmForTokenClassification(NVEsmPreTrainedModel): + """Adds a token classification head to the model. + + Adapted from EsmForTokenClassification in Hugging Face Transformers `modeling_esm.py`. + """ + + def __init__(self, config): + """Initialize NVEsmForTokenClassification.""" + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = transformer_engine.pytorch.Linear( + config.hidden_size, + config.num_labels, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + self.init_weights() + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + """Forward pass for the token classification head. 
)  # TODO: Try BF16 compute weights with FP32 master weights here.
train_dataloader, dataset_or_sampler = ( diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 28409e0c17..cfbe109a42 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -14,21 +14,29 @@ # limitations under the License. import logging +import tempfile from contextlib import nullcontext from pathlib import Path +from torch.profiler import profile, ProfilerActivity import hydra import nvdlfw_inspect.api as debug_api import torch import transformer_engine import transformer_engine.pytorch +import yaml + from omegaconf import DictConfig, OmegaConf from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import fully_shard +from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy from torch.optim import AdamW + +from transformer_engine.pytorch.optimizers import FusedAdam from transformer_engine.common.recipe import Format from transformers import AutoConfig, AutoModelForMaskedLM +from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM + # This import seems to be needed with meta device init and AutoModel.from_config from transformers.models.esm.modeling_esm import EsmForMaskedLM # noqa: F401 @@ -43,6 +51,86 @@ logger.setLevel(logging.INFO) +def generate_layer_regex(layer_numbers: list[int] | None) -> str: + """Generate a regex pattern to match specific layer numbers (1-indexed). + + Args: + layer_numbers: List of layer numbers (1-indexed, as shown in logs). + If empty or None, returns a pattern that matches nothing. + + Returns: + Regex pattern string for matching those layers' linear sublayers. 
+ """ + if not layer_numbers: + # Return a pattern that matches nothing (non-existent layer) + return r"model\.esm\.encoder\.layers\.DISABLED_NO_LAYERS_SPECIFIED" + # Use alternation for arbitrary layer lists: (1|2|3|4|5) + layer_pattern = "|".join(str(n) for n in sorted(layer_numbers)) + return rf"model\.esm\.encoder\.layers\.({layer_pattern})\..*(layernorm_qkv|proj|fc1|fc2)" + + +def update_quant_stats_config( + config_file: str, + fp4_layers: list[int] | None, + fp8_layers: list[int] | None, +) -> str: + """Update the quant stats YAML config with layer-specific regex patterns. + + Args: + config_file: Path to the original YAML config file. + fp4_layers: List of layer numbers for FP4 (1-indexed). + fp8_layers: List of layer numbers for FP8 (1-indexed). + + Returns: + Path to the updated config file (may be a temp file). + + Raises: + ValueError: If fp4_layers and fp8_layers have overlapping layer numbers. + """ + # Check for overlapping layers + fp4_set = set(fp4_layers) if fp4_layers else set() + fp8_set = set(fp8_layers) if fp8_layers else set() + overlap = fp4_set & fp8_set + if overlap: + raise ValueError( + f"fp4_layers and fp8_layers cannot have overlapping layer numbers. 
" + f"Found overlap: {sorted(overlap)}" + ) + + with open(config_file, "r") as f: + config = yaml.safe_load(f) + + # Update FP4 section if it exists (always update, even if empty to disable matching) + if "example_fp4_tensor_stat_collection" in config: + fp4_regex = generate_layer_regex(fp4_layers) + config["example_fp4_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp4_regex + if fp4_layers: + logger.info(f"Updated FP4 layer regex to match layers: {fp4_layers}") + else: + logger.info("FP4 layers empty - regex set to match nothing") + + # Update FP8 section if it exists (always update, even if empty to disable matching) + if "example_fp8_tensor_stat_collection" in config: + fp8_regex = generate_layer_regex(fp8_layers) + config["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp8_regex + if fp8_layers: + logger.info(f"Updated FP8 layer regex to match layers: {fp8_layers}") + else: + logger.info("FP8 layers empty - regex set to match nothing") + + # Write to a temp file to avoid modifying the original + temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) + yaml.dump(config, temp_file, default_flow_style=False) + temp_file.close() + + # Log the updated config for visibility + config_str = yaml.dump(config, default_flow_style=False) + logger.info(f"Created updated quant stats config at: {temp_file.name}") + logger.info(f"Updated quant stats config contents:\n{config_str}") + + return temp_file.name + + @hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2") def main(args: DictConfig) -> float | None: """Train ESM-2 with TE layers using fsdp2. 
@@ -57,22 +145,32 @@ def main(args: DictConfig) -> float | None: torch.distributed.init_process_group(backend="nccl", device_id=device) torch.cuda.set_device(dist_config.local_rank) - # TE Debug feature logging - MUST be done BEFORE FSDP wrapping - if args.fp8_stats_config.enabled and not args.fp8_config.enabled: - raise ValueError( - "fp8_stats_config.enabled is true but fp8_config.enabled is false, please set fp8_config.enabled to true in the config if you wish to collect FP8 stats" + # Parse layer lists first (1-indexed from args, used for both quant stats and internal recipe mapping) + fp8_layers_1indexed = OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None and args.fp8_config.enabled else None + fp4_layers_1indexed = OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None and args.fp4_config.enabled else None + + # Convert to 0-indexed for internal use (use 'is not None' to handle empty lists correctly) + fp8_layers = [layer - 1 for layer in fp8_layers_1indexed] if fp8_layers_1indexed is not None else None + fp4_layers = [layer - 1 for layer in fp4_layers_1indexed] if fp4_layers_1indexed is not None else None + + if args.quant_stats_config.enabled: + quant_stats_file = args.quant_stats_config.quant_stats_file + + # Update the quant stats config with layer-specific regex patterns (using 1-indexed layer numbers) + quant_stats_file = update_quant_stats_config( + config_file=quant_stats_file, + fp4_layers=fp4_layers_1indexed, + fp8_layers=fp8_layers_1indexed, ) - - if args.fp8_stats_config.enabled: - fp8_stats_file = args.fp8_stats_config.fp8_stats_file - fp8_log_dir = Path(args.fp8_stats_config.fp8_log_dir) / f"rank_{dist_config.rank}" - fp8_log_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"Logging FP8 stats to {fp8_log_dir}") + + quant_log_dir = Path(args.quant_stats_config.quant_log_dir) / f"rank_{dist_config.rank}" + quant_log_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Logging 
quant stats to {quant_log_dir}") te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features") debug_api.initialize( - config_file=fp8_stats_file, + config_file=quant_stats_file, feature_dirs=[te_features_dir], - log_dir=fp8_log_dir, + log_dir=quant_log_dir, default_logging_enabled=True, ) @@ -84,12 +182,17 @@ def main(args: DictConfig) -> float | None: ) # Create an FP8 recipe -- this is only used if FP8 is enabled in the config. - fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( - fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs - ) + if args.fp8_config.enabled: + fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)( + fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs + ) + + if args.fp4_config.enabled: + fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)( + fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs + ) - # Create an empty ESM-2 model with a masked language model head, e.g. "nvidia/esm2_t6_8M_UR50D". - config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True, dtype=torch.bfloat16) + config = NVEsmConfig.from_pretrained(args.model_tag, dtype=torch.float32 if args.use_fp32_master_weights else torch.bfloat16) # If we're using sequence packing with TE layers, we need to pass the `attn_input_format` argument. if args.use_sequence_packing: config.attn_input_format = "thd" @@ -99,19 +202,41 @@ def main(args: DictConfig) -> float | None: # versions of weights are kept. 
with ( torch.device("meta") if args.use_meta_device else nullcontext(), - transformer_engine.pytorch.fp8_model_init(recipe=fp8_recipe, **args.fp8_config.fp8_model_init_kwargs), ): - model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + # model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + model = NVEsmForMaskedLM(config) logger.info("Initialized Model:\n%s", model) # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models. transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer - for layer in transformer_stack: - fully_shard(layer, mesh=device_mesh["dp"]) - fully_shard(model, mesh=device_mesh["dp"]) + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, # Cast params to BF16 for forward/backward + reduce_dtype=torch.float32, # Gradient reductions in FP32 + output_dtype=torch.bfloat16, # Forward output dtype + ) + if args.use_fp32_master_weights: + for layer in transformer_stack: + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) + fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) + else: + for layer in transformer_stack: + fully_shard(layer, mesh=device_mesh["dp"]) + fully_shard(model, mesh=device_mesh["dp"]) + # Create a layer map for the transformer stack. 
+ layer_number_quantized_recipe_map = {} + fp8_layers_set = set(fp8_layers) if fp8_layers else set() + fp4_layers_set = set(fp4_layers) if fp4_layers else set() + for layer_number, layer in enumerate(transformer_stack): + if layer_number in fp8_layers_set: + layer_number_quantized_recipe_map[layer_number] = fp8_recipe + elif layer_number in fp4_layers_set: + layer_number_quantized_recipe_map[layer_number] = fp4_recipe + else: + layer_number_quantized_recipe_map[layer_number] = None + model.esm.encoder.layer_number_quantized_recipe_map = layer_number_quantized_recipe_map # If we're using meta device, we need to move sharded weights to the cuda device and initialize the parameters. # Note, this should happen before we create the optimizer. if args.use_meta_device: @@ -123,11 +248,20 @@ def main(args: DictConfig) -> float | None: model.apply(model._init_weights) # Assign names to layers so debug API can identify them - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.infer_and_assign_layer_names(model) # Create optimizer. Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873). optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore + # optimizer = FusedAdam(model.parameters(), + # lr=4e-4, + # betas=(0.9, 0.98), + # eps=1e-8, + # weight_decay=0.01, + # master_weights=True, + # master_weight_dtype=torch.float32, + # ) + # Note: Got an error about mixed torch.Tensor and DTensor here, so using AdamW instead. scheduler = get_linear_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs) # If we're using sequence packing, create a THD dataloader, otherwise create a BSHD dataloader. @@ -163,10 +297,16 @@ def main(args: DictConfig) -> float | None: while step < args.num_train_steps: for batch in train_dataloader: batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} # noqa: PLW2901 - - # Forward pass with mixed precision. 
- with transformer_engine.pytorch.fp8_autocast(enabled=args.fp8_config.enabled, fp8_recipe=fp8_recipe): + + # Use an outer FP8 recipe. + with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe if args.fp8_config.enabled else None): outputs = model(**batch) + + # if step == 5: # Profile step 5 + # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + # with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe): + # outputs = model(**batch) + # logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) # Backward pass. loss = outputs.loss @@ -179,7 +319,7 @@ def main(args: DictConfig) -> float | None: optimizer.step() scheduler.step() - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.step() optimizer.zero_grad() @@ -223,7 +363,7 @@ def main(args: DictConfig) -> float | None: # Clean up distributed training perf_logger.finish() - if args.fp8_stats_config.enabled: + if args.quant_stats_config.enabled: debug_api.end_debug() torch.distributed.destroy_process_group()