From e415bb9d8b1dac8f2fda761453c77e4cd9966c18 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Wed, 18 Feb 2026 23:33:01 +0000 Subject: [PATCH 1/9] draft pr - esm2 working with vllm --- .../models/esm2/src/esm/convert.py | 36 +- bionemo-recipes/models/esm2/src/esm/export.py | 14 +- .../models/esm2/src/esm/modeling_esm_te.py | 75 +- .../models/esm2/tests/test_convert.py | 185 +++++ .../models/esm2/tests/test_cp_bshd.py | 10 +- .../models/esm2/tests/test_cp_thd.py | 10 +- .../models/esm2/tests/test_distributed_fp8.py | 2 +- .../esm2/tests/test_distributed_strategies.py | 2 +- bionemo-recipes/models/esm2/tests/test_fp8.py | 192 +++++ .../esm2/tests/test_meta_device_init.py | 294 ++++++++ .../models/esm2/tests/test_modeling_esm_te.py | 415 +++++------ bionemo-recipes/models/esm2/tests/test_thd.py | 329 +++++++++ .../example_8m_checkpoint/esm_nv.py | 75 +- .../example_8m_checkpoint/esm_nv.py | 75 +- .../esm2_native_te/tests/test_stop_and_go.py | 4 +- .../recipes/esm2_native_te/train_ddp.py | 2 +- .../recipes/esm2_native_te/train_ddp_cp.py | 2 +- .../recipes/esm2_native_te/train_fsdp2.py | 4 +- .../recipes/esm2_native_te/train_fsdp2_cp.py | 4 +- .../example_8m_checkpoint/esm_nv.py | 75 +- .../vllm/esm2_vllm_converted/config.json | 46 ++ .../vllm/esm2_vllm_converted/esm_nv.py | 681 ++++++++++++++++++ .../special_tokens_map.json | 44 ++ .../vllm/esm2_vllm_converted/tokenizer.json | 176 +++++ .../esm2_vllm_converted/tokenizer_config.json | 18 + .../vllm_conversion_info.json | 5 + .../vllm/esm2_vllm_converted/vocab.txt | 33 + bionemo-recipes/vllm/vllm_test.py | 193 +++++ 28 files changed, 2596 insertions(+), 405 deletions(-) create mode 100644 bionemo-recipes/models/esm2/tests/test_convert.py create mode 100644 bionemo-recipes/models/esm2/tests/test_fp8.py create mode 100644 bionemo-recipes/models/esm2/tests/test_meta_device_init.py create mode 100644 bionemo-recipes/models/esm2/tests/test_thd.py create mode 100644 bionemo-recipes/vllm/esm2_vllm_converted/config.json 
create mode 100644 bionemo-recipes/vllm/esm2_vllm_converted/esm_nv.py create mode 100644 bionemo-recipes/vllm/esm2_vllm_converted/special_tokens_map.json create mode 100644 bionemo-recipes/vllm/esm2_vllm_converted/tokenizer.json create mode 100644 bionemo-recipes/vllm/esm2_vllm_converted/tokenizer_config.json create mode 100644 bionemo-recipes/vllm/esm2_vllm_converted/vllm_conversion_info.json create mode 100644 bionemo-recipes/vllm/esm2_vllm_converted/vocab.txt create mode 100644 bionemo-recipes/vllm/vllm_test.py diff --git a/bionemo-recipes/models/esm2/src/esm/convert.py b/bionemo-recipes/models/esm2/src/esm/convert.py index 0406e1bbb..4007ac369 100644 --- a/bionemo-recipes/models/esm2/src/esm/convert.py +++ b/bionemo-recipes/models/esm2/src/esm/convert.py @@ -23,18 +23,18 @@ mapping = { - "esm.encoder.layer.*.attention.output.dense.weight": "esm.encoder.layers.*.self_attention.proj.weight", - "esm.encoder.layer.*.attention.output.dense.bias": "esm.encoder.layers.*.self_attention.proj.bias", - "esm.encoder.layer.*.attention.LayerNorm.weight": "esm.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_weight", - "esm.encoder.layer.*.attention.LayerNorm.bias": "esm.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_bias", - "esm.encoder.layer.*.intermediate.dense.weight": "esm.encoder.layers.*.layernorm_mlp.fc1_weight", - "esm.encoder.layer.*.intermediate.dense.bias": "esm.encoder.layers.*.layernorm_mlp.fc1_bias", - "esm.encoder.layer.*.output.dense.weight": "esm.encoder.layers.*.layernorm_mlp.fc2_weight", - "esm.encoder.layer.*.output.dense.bias": "esm.encoder.layers.*.layernorm_mlp.fc2_bias", - "esm.encoder.layer.*.LayerNorm.weight": "esm.encoder.layers.*.layernorm_mlp.layer_norm_weight", - "esm.encoder.layer.*.LayerNorm.bias": "esm.encoder.layers.*.layernorm_mlp.layer_norm_bias", - "esm.encoder.emb_layer_norm_after.weight": "esm.encoder.emb_layer_norm_after.weight", - "esm.encoder.emb_layer_norm_after.bias": "esm.encoder.emb_layer_norm_after.bias", + 
"esm.encoder.layer.*.attention.output.dense.weight": "model.encoder.layers.*.self_attention.proj.weight", + "esm.encoder.layer.*.attention.output.dense.bias": "model.encoder.layers.*.self_attention.proj.bias", + "esm.encoder.layer.*.attention.LayerNorm.weight": "model.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_weight", + "esm.encoder.layer.*.attention.LayerNorm.bias": "model.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_bias", + "esm.encoder.layer.*.intermediate.dense.weight": "model.encoder.layers.*.layernorm_mlp.fc1_weight", + "esm.encoder.layer.*.intermediate.dense.bias": "model.encoder.layers.*.layernorm_mlp.fc1_bias", + "esm.encoder.layer.*.output.dense.weight": "model.encoder.layers.*.layernorm_mlp.fc2_weight", + "esm.encoder.layer.*.output.dense.bias": "model.encoder.layers.*.layernorm_mlp.fc2_bias", + "esm.encoder.layer.*.LayerNorm.weight": "model.encoder.layers.*.layernorm_mlp.layer_norm_weight", + "esm.encoder.layer.*.LayerNorm.bias": "model.encoder.layers.*.layernorm_mlp.layer_norm_bias", + "esm.encoder.emb_layer_norm_after.weight": "model.encoder.emb_layer_norm_after.weight", + "esm.encoder.emb_layer_norm_after.bias": "model.encoder.emb_layer_norm_after.bias", "lm_head.dense.weight": "lm_head.dense.weight", "lm_head.dense.bias": "lm_head.dense.bias", "lm_head.layer_norm.weight": "lm_head.decoder.layer_norm_weight", @@ -146,7 +146,7 @@ def convert_esm_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module: "esm.encoder.layer.*.attention.self.key.weight", "esm.encoder.layer.*.attention.self.value.weight", ), - target_key="esm.encoder.layers.*.self_attention.layernorm_qkv.weight", + target_key="model.encoder.layers.*.self_attention.layernorm_qkv.weight", ) def _pack_qkv_weight(ctx: io.TransformCTX, query, key, value): """Pad the embedding layer to the new input dimension.""" @@ -168,7 +168,7 @@ def _pack_qkv_weight(ctx: io.TransformCTX, query, key, value): "esm.encoder.layer.*.attention.self.key.bias", 
"esm.encoder.layer.*.attention.self.value.bias", ), - target_key="esm.encoder.layers.*.self_attention.layernorm_qkv.bias", + target_key="model.encoder.layers.*.self_attention.layernorm_qkv.bias", ) def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value): """Pad the embedding layer to the new input dimension.""" @@ -185,7 +185,7 @@ def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value): @io.state_transform( - source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.weight", + source_key="model.encoder.layers.*.self_attention.layernorm_qkv.weight", target_key=( "esm.encoder.layer.*.attention.self.query.weight", "esm.encoder.layer.*.attention.self.key.weight", @@ -214,7 +214,7 @@ def _unpack_qkv_weight(ctx: io.TransformCTX, qkv_weight): @io.state_transform( - source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.bias", + source_key="model.encoder.layers.*.self_attention.layernorm_qkv.bias", target_key=( "esm.encoder.layer.*.attention.self.query.bias", "esm.encoder.layer.*.attention.self.key.bias", @@ -259,7 +259,7 @@ def _pad_weights(ctx: io.TransformCTX, source_embed): _pad_embeddings = io.state_transform( source_key="esm.embeddings.word_embeddings.weight", - target_key="esm.embeddings.word_embeddings.weight", + target_key="model.embeddings.word_embeddings.weight", )(_pad_weights) _pad_decoder_weights = io.state_transform( @@ -268,7 +268,7 @@ def _pad_weights(ctx: io.TransformCTX, source_embed): )(_pad_weights) _unpad_embeddings = io.state_transform( - source_key="esm.embeddings.word_embeddings.weight", + source_key="model.embeddings.word_embeddings.weight", target_key="esm.embeddings.word_embeddings.weight", )(_unpad_weights) diff --git a/bionemo-recipes/models/esm2/src/esm/export.py b/bionemo-recipes/models/esm2/src/esm/export.py index 7eb6437d7..1aa81a81d 100644 --- a/bionemo-recipes/models/esm2/src/esm/export.py +++ b/bionemo-recipes/models/esm2/src/esm/export.py @@ -61,7 +61,13 @@ def export_hf_checkpoint(tag: str, export_path: Path): 
model_hf_masked_lm = AutoModelForMaskedLM.from_pretrained(f"facebook/{tag}") model_hf = AutoModel.from_pretrained(f"facebook/{tag}") model_hf_masked_lm.esm.pooler = model_hf.pooler - model_te = convert_esm_hf_to_te(model_hf_masked_lm) + + # Export with padded_vocab_size=None (defaults to vocab_size) so that the checkpoint + # stores embeddings/decoder at the real vocab_size without zero-padding. Padding is + # only needed at runtime for FP8 training efficiency; users who train with FP8 pass + # padded_vocab_size explicitly. Keeping vocab_size-sized weights in the checkpoint + # avoids shape-mismatch assertions in vLLM's VocabParallelEmbedding. + model_te = convert_esm_hf_to_te(model_hf_masked_lm, padded_vocab_size=None) model_te.save_pretrained(export_path / tag) tokenizer = AutoTokenizer.from_pretrained("esm_fast_tokenizer") # Use our PreTrainedTokenizerFast implementation. @@ -73,6 +79,12 @@ def export_hf_checkpoint(tag: str, export_path: Path): config["auto_map"] = AUTO_MAP + # Disable pooler in the exported checkpoint. NVEsmForMaskedLM saves with + # add_pooling_layer=False, so pooler weights are absent. Setting this to false + # prevents vLLM from creating a pooler module and then erroring on missing weights. + # (HuggingFace tolerates missing keys via strict=False, but vLLM does not.) + config["add_pooling_layer"] = False + with open(export_path / tag / "config.json", "w") as f: json.dump(config, f, indent=2, sort_keys=True) diff --git a/bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py b/bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py index d4ee0845e..fb2e4136d 100644 --- a/bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py +++ b/bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py @@ -70,6 +70,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", + add_pooling_layer: bool = True, **kwargs, ): """Initialize the NVEsmConfig with additional TE-related config options. 
@@ -100,6 +101,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults to vocab_size. Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. + add_pooling_layer: Whether the base model should include a pooling layer. Set to + ``False`` for exported checkpoints that are saved from ``NVEsmForMaskedLM`` + (which does not use a pooler). This avoids missing-weight errors in vLLM. **kwargs: Additional config options to pass to EsmConfig. """ super().__init__(**kwargs) @@ -111,6 +115,7 @@ def __init__( self.micro_batch_size = micro_batch_size self.max_seq_length = max_seq_length self.attn_mask_type = attn_mask_type + self.add_pooling_layer = add_pooling_layer # Set padded_vocab_size with default fallback to vocab_size self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size @@ -231,7 +236,7 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel): """An abstract class to handle weights initialization and pretrained model loading.""" config_class = NVEsmConfig - base_model_prefix = "esm" + base_model_prefix = "model" supports_gradient_checkpointing = False accepts_loss_kwargs = False _no_split_modules = ( @@ -247,11 +252,11 @@ def init_empty_weights(self): if hasattr(module, "reset_parameters"): module.reset_parameters() - # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # The embeddings layer is the only non-TE layer in this model we need to deal with. We use # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard - # deviation. - self.esm.embeddings.word_embeddings.to_empty(device="cuda") - self.esm.embeddings.apply(self._init_weights) + # deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel. 
+ self.base_model.embeddings.word_embeddings.to_empty(device="cuda") + self.base_model.embeddings.apply(self._init_weights) # Meta-device init seems to break weight tying, so we re-tie the weights here. self.tie_weights() @@ -276,14 +281,16 @@ def _init_weights(self, module): super()._init_weights(module) def state_dict(self, *args, **kwargs): - """Override state_dict to filter out TransformerEngine's _extra_state keys. + """Override state_dict to filter out non-loadable keys. - TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. - These are filtered out to ensure checkpoints can be loaded with from_pretrained(). + Filters out: + - ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5. + - ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed + in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates + over ``named_parameters``, not ``named_buffers``). """ state_dict = super().state_dict(*args, **kwargs) - # Filter out _extra_state keys which are TransformerEngine-specific and not loadable - return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")} class NVEsmModel(NVEsmPreTrainedModel): @@ -292,16 +299,20 @@ class NVEsmModel(NVEsmPreTrainedModel): This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. """ - def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None): """Initialize a NVEsmModel. Args: config (NVEsmConfig): The configuration of the model. - add_pooling_layer (bool): Whether to add a pooling layer. + add_pooling_layer (bool): Whether to add a pooling layer. If ``None``, + reads ``config.add_pooling_layer`` (defaults to ``True``). 
""" super().__init__(config) self.config = config + if add_pooling_layer is None: + add_pooling_layer = getattr(config, "add_pooling_layer", True) + # Ensure pad_token_id is set properly, defaulting to 0 if not specified if not hasattr(config, "pad_token_id") or config.pad_token_id is None: config.pad_token_id = 0 @@ -391,8 +402,9 @@ def forward( class NVEsmForMaskedLM(NVEsmPreTrainedModel): """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" - _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} - _do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized. + _tied_weights_keys: ClassVar[dict[str, str]] = { + "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" + } def __init__(self, config: NVEsmConfig): """Initialize a NVEsmForMaskedLM. @@ -408,7 +420,7 @@ def __init__(self, config: NVEsmConfig): "bi-directional self-attention." ) - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.lm_head = NVEsmLMHead(config) self.post_init() @@ -443,7 +455,7 @@ def forward( Returns: MaskedLMOutput: The output of the model. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -451,8 +463,7 @@ def forward( **kwargs, ) sequence_output = outputs[0] - with transformer_engine.pytorch.autocast(enabled=False): - prediction_scores = self.lm_head(sequence_output) + prediction_scores = self.lm_head(sequence_output) # Truncate logits back to original vocab_size if padding was used if self.config.padded_vocab_size != self.config.vocab_size: @@ -483,15 +494,15 @@ def __init__(self, config: NVEsmConfig): config (NVEsmConfig): The configuration of the model. 
""" super().__init__() - with transformer_engine.pytorch.quantized_model_init(enabled=False): - self.dense = transformer_engine.pytorch.Linear( - config.hidden_size, - config.hidden_size, - params_dtype=config.dtype, - device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", - init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), - ) + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + with transformer_engine.pytorch.fp8_model_init(enabled=False): self.decoder = transformer_engine.pytorch.LayerNormLinear( config.hidden_size, config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, @@ -511,7 +522,7 @@ def forward(self, features, **kwargs): """ # Keep the last layers of the network in higher precision to avoid numerical instability. # Please see recipes/fp8_analysis/README.md for more details. - with transformer_engine.pytorch.autocast(enabled=False): + with transformer_engine.pytorch.fp8_autocast(enabled=False): x = self.dense(features) x = torch.nn.functional.gelu(x) x = self.decoder(x) @@ -599,10 +610,6 @@ def forward( else: src_lengths = torch.diff(kwargs["cu_seq_lens_q"]) - if "cu_seq_lens_q_padded" in kwargs: - src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) - else: - src_lengths_padded = src_lengths # We need to find the number of masked tokens in each sequence in the padded batch. 
is_masked = (input_ids == self.mask_token_id).squeeze(0) n_masked_per_seq = torch.nested.nested_tensor_from_jagged( @@ -610,7 +617,7 @@ def forward( ).sum(1) mask_ratio_observed = n_masked_per_seq.float() / src_lengths scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) - reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0) + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) if self.layer_norm is not None: @@ -633,7 +640,7 @@ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = transformer_engine.pytorch.Linear( config.hidden_size, @@ -659,7 +666,7 @@ def forward( labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/bionemo-recipes/models/esm2/tests/test_convert.py b/bionemo-recipes/models/esm2/tests/test_convert.py new file mode 100644 index 000000000..30483334d --- /dev/null +++ b/bionemo-recipes/models/esm2/tests/test_convert.py @@ -0,0 +1,185 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from transformers import AutoModelForMaskedLM + + +def test_convert_te_to_hf_roundtrip(): + """Test that converting HF -> TE -> HF produces the same model.""" + from esm.convert import convert_esm_hf_to_te, convert_esm_te_to_hf + + model_hf_original = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f") + + model_te = convert_esm_hf_to_te(model_hf_original) + model_hf_converted = convert_esm_te_to_hf(model_te) + + original_state_dict = model_hf_original.state_dict() + converted_state_dict = model_hf_converted.state_dict() + original_keys = {k for k in original_state_dict.keys() if "contact_head" not in k} + converted_keys = set(converted_state_dict.keys()) + assert original_keys == converted_keys + + for key in original_state_dict.keys(): + if not key.endswith("_extra_state") and not key.endswith("inv_freq") and "contact_head" not in key: + torch.testing.assert_close(original_state_dict[key], converted_state_dict[key], atol=1e-5, rtol=1e-5) + + +def test_load_from_converted_checkpoint(te_model_checkpoint): + from esm.modeling_esm_te import NVEsmForMaskedLM + + NVEsmForMaskedLM.from_pretrained(te_model_checkpoint) + + +def test_qkv_unpacking(): + """Test that QKV unpacking works correctly.""" + from esm.convert import convert_esm_hf_to_te, convert_esm_te_to_hf + + model_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f") + model_te = convert_esm_hf_to_te(model_hf) + model_hf_converted = convert_esm_te_to_hf(model_te) + + for i in 
range(model_hf.config.num_hidden_layers): + hf_query = model_hf.state_dict()[f"esm.encoder.layer.{i}.attention.self.query.weight"] + hf_key = model_hf.state_dict()[f"esm.encoder.layer.{i}.attention.self.key.weight"] + hf_value = model_hf.state_dict()[f"esm.encoder.layer.{i}.attention.self.value.weight"] + + converted_query = model_hf_converted.state_dict()[f"esm.encoder.layer.{i}.attention.self.query.weight"] + converted_key = model_hf_converted.state_dict()[f"esm.encoder.layer.{i}.attention.self.key.weight"] + converted_value = model_hf_converted.state_dict()[f"esm.encoder.layer.{i}.attention.self.value.weight"] + + torch.testing.assert_close(hf_query, converted_query, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(hf_key, converted_key, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(hf_value, converted_value, atol=1e-5, rtol=1e-5) + + +def test_config_conversion(): + """Test that config conversion works correctly.""" + from esm.convert import convert_esm_hf_to_te, convert_esm_te_to_hf + + model_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f") + model_te = convert_esm_hf_to_te(model_hf) + model_hf_converted = convert_esm_te_to_hf(model_te) + + original_config_dict = model_hf.config.to_dict() + converted_config_dict = model_hf_converted.config.to_dict() + + for key, value in original_config_dict.items(): + assert key in converted_config_dict, f"Config field '{key}' missing in converted model" + assert converted_config_dict[key] == value, ( + f"Config field '{key}' differs: original={value}, converted={converted_config_dict[key]}" + ) + + assert model_hf_converted.config.model_type == "esm" + + te_specific_fields = [ + "qkv_weight_interleaved", + "encoder_activation", + "attn_input_format", + "fuse_qkv_params", + "micro_batch_size", + "auto_map", + ] + for field in te_specific_fields: + assert not hasattr(model_hf_converted.config, field), ( + f"TE-specific field '{field}' should not be present in converted model" + 
) + + +def test_padding_unpadding_operations(): + """Test that padding and unpadding operations work correctly for embeddings and decoder weights.""" + from esm.convert import convert_esm_hf_to_te, convert_esm_te_to_hf + + model_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f") + model_te = convert_esm_hf_to_te(model_hf) + model_hf_converted = convert_esm_te_to_hf(model_te) + + # Test word embeddings + original_embeddings = model_hf.state_dict()["esm.embeddings.word_embeddings.weight"] + converted_embeddings = model_hf_converted.state_dict()["esm.embeddings.word_embeddings.weight"] + assert original_embeddings.shape == converted_embeddings.shape, ( + f"Embedding shapes don't match: {original_embeddings.shape} vs {converted_embeddings.shape}" + ) + torch.testing.assert_close(original_embeddings, converted_embeddings, atol=1e-5, rtol=1e-5) + + # Test decoder weights + original_decoder = model_hf.state_dict()["lm_head.decoder.weight"] + converted_decoder = model_hf_converted.state_dict()["lm_head.decoder.weight"] + assert original_decoder.shape == converted_decoder.shape, ( + f"Decoder shapes don't match: {original_decoder.shape} vs {converted_decoder.shape}" + ) + torch.testing.assert_close(original_decoder, converted_decoder, atol=1e-5, rtol=1e-5) + + # Test bias + original_bias = model_hf.state_dict()["lm_head.bias"] + converted_bias = model_hf_converted.state_dict()["lm_head.bias"] + assert original_bias.shape == converted_bias.shape, ( + f"Bias shapes don't match: {original_bias.shape} vs {converted_bias.shape}" + ) + torch.testing.assert_close(original_bias, converted_bias, atol=1e-5, rtol=1e-5) + + # Test that TE model has padded dimensions + te_embeddings = model_te.state_dict()["model.embeddings.word_embeddings.weight"] + te_decoder = model_te.state_dict()["lm_head.decoder.weight"] + assert te_embeddings.shape[0] >= original_embeddings.shape[0], "TE embeddings should be padded" + assert te_decoder.shape[0] >= 
original_decoder.shape[0], "TE decoder should be padded" + + # The padded parts should be zeros (for embeddings) or min values (for bias) + if te_embeddings.shape[0] > original_embeddings.shape[0]: + padding_rows = te_embeddings[original_embeddings.shape[0] :] + torch.testing.assert_close(padding_rows, torch.zeros_like(padding_rows), atol=1e-6, rtol=1e-6) + + +def test_weight_initialization_matches_hf(): + from transformers import AutoConfig, set_seed + from transformers.models.esm.modeling_esm import EsmForMaskedLM + + from esm.convert import convert_esm_hf_to_te + from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM + + set_seed(42) + + config_hf = AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D", vocab_size=64, revision="c731040f") + model_hf = EsmForMaskedLM(config_hf) + model_te_converted = convert_esm_hf_to_te(model_hf) + + config = NVEsmConfig(**model_hf.config.to_dict()) + model_te = NVEsmForMaskedLM(config) + model_te.to("cuda") + model_te_converted.to("cuda") + + state_dict_hf = model_te_converted.state_dict() + state_dict_te = model_te.state_dict() + + for name in state_dict_hf.keys(): + if name.endswith("_extra_state"): + continue + + torch.testing.assert_close( + state_dict_te[name].mean(), + state_dict_hf[name].mean(), + atol=1e-3, + rtol=1e-4, + msg=lambda x: f"Mean mismatch for parameter {name}: {x}", + ) + + torch.testing.assert_close( + state_dict_te[name].std(), + state_dict_hf[name].std(), + atol=1e-3, + rtol=1e-4, + msg=lambda x: f"Std mismatch for parameter {name}: {x}", + ) diff --git a/bionemo-recipes/models/esm2/tests/test_cp_bshd.py b/bionemo-recipes/models/esm2/tests/test_cp_bshd.py index e770d56cc..72023a7b6 100644 --- a/bionemo-recipes/models/esm2/tests/test_cp_bshd.py +++ b/bionemo-recipes/models/esm2/tests/test_cp_bshd.py @@ -203,8 +203,8 @@ def test_context_parallel_equivalence_2process(): # Sample gradients from a few layers for comparison sample_layers = [ - model.esm.encoder.layers[0].self_attention.core_attention, 
- model.esm.encoder.layers[0].self_attention.layernorm_qkv, + model.model.encoder.layers[0].self_attention.core_attention, + model.model.encoder.layers[0].self_attention.layernorm_qkv, ] # Now grab the gradients from the sample layers @@ -256,7 +256,7 @@ def test_context_parallel_equivalence_2process(): cp_world_size = torch.distributed.get_world_size(group=cp_group) # Set up context parallelism for each layer - for i, transformer_layer in enumerate(model.module.esm.encoder.layers): + for i, transformer_layer in enumerate(model.module.model.encoder.layers): transformer_layer.set_context_parallel_group( cp_group, torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), torch.cuda.Stream() ) @@ -341,8 +341,8 @@ def test_context_parallel_equivalence_2process(): # Capture gradients from the same layers in the CP model # Note: DDP wraps the model with 'module.' prefix sample_layers_cp = [ - model.module.esm.encoder.layers[0].self_attention.core_attention, - model.module.esm.encoder.layers[0].self_attention.layernorm_qkv, + model.module.model.encoder.layers[0].self_attention.core_attention, + model.module.model.encoder.layers[0].self_attention.layernorm_qkv, ] gradients_cp = {} diff --git a/bionemo-recipes/models/esm2/tests/test_cp_thd.py b/bionemo-recipes/models/esm2/tests/test_cp_thd.py index 9695e0091..6eb1d58fe 100644 --- a/bionemo-recipes/models/esm2/tests/test_cp_thd.py +++ b/bionemo-recipes/models/esm2/tests/test_cp_thd.py @@ -194,8 +194,8 @@ def test_context_parallel_equivalence_2process(): # Sample gradients from a few layers for comparison sample_layers = [ - model.esm.encoder.layers[0].self_attention.core_attention, - model.esm.encoder.layers[0].self_attention.layernorm_qkv, + model.model.encoder.layers[0].self_attention.core_attention, + model.model.encoder.layers[0].self_attention.layernorm_qkv, ] # Now grab the gradients from the sample layers @@ -247,7 +247,7 @@ def test_context_parallel_equivalence_2process(): cp_world_size = 
torch.distributed.get_world_size(group=cp_group) # Set up context parallelism for each layer - for i, transformer_layer in enumerate(model.module.esm.encoder.layers): + for i, transformer_layer in enumerate(model.module.model.encoder.layers): transformer_layer.set_context_parallel_group( cp_group, torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), torch.cuda.Stream() ) @@ -338,8 +338,8 @@ def test_context_parallel_equivalence_2process(): # Capture gradients from the same layers in the CP model # Note: DDP wraps the model with 'module.' prefix sample_layers_cp = [ - model.module.esm.encoder.layers[0].self_attention.core_attention, - model.module.esm.encoder.layers[0].self_attention.layernorm_qkv, + model.module.model.encoder.layers[0].self_attention.core_attention, + model.module.model.encoder.layers[0].self_attention.layernorm_qkv, ] gradients_cp = {} diff --git a/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py b/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py index c195167cf..b07ac4c49 100644 --- a/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py +++ b/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py @@ -160,7 +160,7 @@ def is_main_process(self) -> bool: model = NVEsmForMaskedLM(config) if args.strategy is Strategy.FSDP2: - for layer in model.esm.encoder.layers: + for layer in model.model.encoder.layers: fully_shard(layer, mesh=device_mesh["dp"]) fully_shard(model, mesh=device_mesh["dp"]) model.to(device) diff --git a/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py b/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py index 3f9523410..452f3a18a 100644 --- a/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py +++ b/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py @@ -188,7 +188,7 @@ def run_forward_backward(use_te: bool, strategy: Strategy, input_data: dict, dis revision="c731040f", ) model = NVEsmForMaskedLM(config) - transformer_layers = 
model.esm.encoder.layers + transformer_layers = model.model.encoder.layers else: model = AutoModelForMaskedLM.from_pretrained( "facebook/esm2_t6_8M_UR50D", diff --git a/bionemo-recipes/models/esm2/tests/test_fp8.py b/bionemo-recipes/models/esm2/tests/test_fp8.py new file mode 100644 index 000000000..e5d584d3d --- /dev/null +++ b/bionemo-recipes/models/esm2/tests/test_fp8.py @@ -0,0 +1,192 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import torch +import torch.distributed.checkpoint as dcp +import transformer_engine +from torch.distributed.checkpoint.state_dict import get_model_state_dict +from transformer_engine.common import recipe as recipe_module +from transformers import DataCollatorForLanguageModeling + +from esm.collator import DataCollatorWithFlattening +from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM + + +try: + from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor +except ImportError: # TE nightly uses a new import path for QuantizedTensor + from transformer_engine.pytorch.quantized_tensor import QuantizedTensor + + +@pytest.fixture +def input_data_thd(tokenizer, tokenized_proteins): + mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15, seed=42) + data_collator = DataCollatorWithFlattening( + collator=mlm_collator, + pad_to_multiple_of=32, # MXFP8 requires the sequence length to be divisible by 32, regular FP8 requires 16. 
+ ) + + return data_collator(tokenized_proteins) + + +def test_fp8_forward_and_backward_pass(te_model_checkpoint, input_data, fp8_recipe): + model_te = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) + model_te.to("cuda") + + input_data = {k: v.to("cuda") for k, v in input_data.items()} + outputs = model_te(**input_data) + + with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + outputs_fp8 = model_te(**input_data) + outputs_fp8.loss.backward() + + if isinstance(fp8_recipe, recipe_module.NVFP4BlockScaling): + atol = 0.2 + rtol = 0.05 + else: + atol = None + rtol = None + + torch.testing.assert_close(outputs_fp8.loss, outputs.loss, atol=atol, rtol=rtol) + + +def test_fp8_forward_and_backward_pass_thd(te_model_checkpoint, input_data_thd, fp8_recipe, monkeypatch): + if torch.cuda.get_device_capability() == (12, 0): + # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default, + # but it's missing this THD implementation. 
+ monkeypatch.setenv("NVTE_FUSED_ATTN", "0") + + model_te = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16) + model_te.to("cuda") + + input_data = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} + outputs = model_te(**input_data) + + with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + outputs_fp8 = model_te(**input_data) + outputs_fp8.loss.backward() + + if isinstance(fp8_recipe, recipe_module.NVFP4BlockScaling): + atol = 0.2 + rtol = 0.05 + elif isinstance(fp8_recipe, recipe_module.DelayedScaling): + atol = 0.1 + rtol = 0.03 + else: + atol = None + rtol = None + + torch.testing.assert_close(outputs_fp8.loss, outputs.loss, atol=atol, rtol=rtol) + + +def test_fp8_model_init_forward_and_backward(te_model_checkpoint, input_data, fp8_recipe): + config = NVEsmConfig.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) + with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe): + model_te = NVEsmForMaskedLM(config) + + assert isinstance(model_te.lm_head.dense.weight, QuantizedTensor) + + model_te.to("cuda") + input_data = {k: v.to("cuda") for k, v in input_data.items()} + + with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + outputs_fp8 = model_te(**input_data) + + outputs_fp8.loss.backward() + + +@pytest.mark.xfail(reason="BIONEMO-3055: fp8 model init and pretrained loading is not currently supported.") +def test_fp8_model_init_from_pretrained(te_model_checkpoint, fp8_recipe): + # TODO: this will be renamed to quantized_model_init in the future, fp8_model_init will be removed in 3.0 + with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe): + model_te = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) + + assert isinstance(model_te.model.encoder.layers[0].layernorm_mlp.fc2_weight, QuantizedTensor) + assert 
isinstance(model_te.lm_head.dense.weight, QuantizedTensor) + + +@pytest.mark.xfail(reason="BIONEMO-3055: fp8 model init and pretrained saving is not currently supported.") +def test_fp8_model_init_save_pretrained(te_model_checkpoint, tmp_path, fp8_recipe): + config = NVEsmConfig.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) + with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe): + model_fp8 = NVEsmForMaskedLM(config) + + assert isinstance(model_fp8.model.encoder.layers[0].layernorm_mlp.fc2_weight, QuantizedTensor) + assert isinstance(model_fp8.lm_head.dense.weight, QuantizedTensor) + + model_fp8.save_pretrained(tmp_path / "fp8_checkpoint") + del model_fp8 + NVEsmForMaskedLM.from_pretrained(tmp_path / "fp8_checkpoint", dtype=torch.bfloat16) + + +def test_fp8_model_distributed_checkpointing_save_and_load(te_model_checkpoint, tmp_path, input_data, fp8_recipe): + config = NVEsmConfig.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) + with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe): + model_fp8 = NVEsmForMaskedLM(config) + + model_fp8.to("cuda") + input_data = {k: v.to("cuda") for k, v in input_data.items()} + with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + outputs = model_fp8(**input_data) + outputs.loss.backward() + + state_dict = get_model_state_dict(model_fp8) + state_dict = {key: val for key, val in state_dict.items() if not key.endswith("_extra_state")} + dcp.save(state_dict, checkpoint_id=tmp_path / "fp8_checkpoint") + + del model_fp8, state_dict + + with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe): + model_fp8 = NVEsmForMaskedLM(config) + + state_dict = model_fp8.state_dict() + state_dict = {key: val for key, val in state_dict.items() if not key.endswith("_extra_state")} + dcp.load(state_dict, checkpoint_id=tmp_path / "fp8_checkpoint") + + +def _format_bytes(num: int, suffix: str = "B") -> str: + """Format bytes as 
a human-readable string (e.g. 1.2 MB).""" + for unit in ("", "K", "M", "G", "T", "P", "E", "Z"): + if abs(num) < 1024.0: + return f"{num:3.1f} {unit}{suffix}" + num /= 1024.0 + return f"{num:.1f} Y{suffix}" + + +@pytest.mark.xfail(reason="BIONEMO-3055: fp8 model init seems to have issues.") +def test_fp8_model_init_uses_less_memory(te_model_checkpoint, fp8_recipe): + torch.cuda.empty_cache() + + config = NVEsmConfig.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) + torch.cuda.reset_peak_memory_stats() + memory_before = torch.cuda.memory_allocated() + with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe), torch.device("cuda"): + model_fp8 = NVEsmForMaskedLM(config) + peak_memory_fp8 = torch.cuda.max_memory_allocated() - memory_before + del model_fp8 + torch.cuda.empty_cache() + + torch.cuda.reset_peak_memory_stats() + memory_before = torch.cuda.memory_allocated() + with transformer_engine.pytorch.fp8_model_init(enabled=False, recipe=fp8_recipe), torch.device("cuda"): + model_bf16 = NVEsmForMaskedLM(config) + peak_memory_bf16 = torch.cuda.max_memory_allocated() - memory_before + del model_bf16 + + assert peak_memory_fp8 < peak_memory_bf16, ( + f"FP8 model init uses more memory than BF16 model init: {_format_bytes(peak_memory_fp8)} " + f"vs {_format_bytes(peak_memory_bf16)}" + ) diff --git a/bionemo-recipes/models/esm2/tests/test_meta_device_init.py b/bionemo-recipes/models/esm2/tests/test_meta_device_init.py new file mode 100644 index 000000000..ab06a2baf --- /dev/null +++ b/bionemo-recipes/models/esm2/tests/test_meta_device_init.py @@ -0,0 +1,294 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test that parameter distributions are identical with and without meta device initialization. + +These tests verify that when using meta device initialization (creating the model on meta device, then calling +`to_empty` and `_init_weights`), the resulting parameter distributions (mean and std) match those from normal +initialization. This is important because we previously observed differences in convergence between meta-device-init and +non-meta-device-init training, which suggested that the initialization was not being applied correctly after `to_empty`. +By explicitly calling `_init_weights` after `to_empty`, we ensure that parameters are properly initialized, leading to +consistent training behavior regardless of whether meta device initialization is used. 
+""" + +import os +import subprocess + +import pytest +import torch +import transformer_engine.pytorch +from torch.distributed.fsdp import fully_shard +from torch.distributed.tensor import DTensor +from transformer_engine.pytorch.tensor import QuantizedTensor +from transformers import AutoConfig, set_seed + +from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM, NVEsmForTokenClassification + + +requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + + +def verify_model_parameters_initialized_correctly( + model: NVEsmForMaskedLM, atol=1e-3, rtol=1e-4, should_be_fp8: bool = False +): + config = model.config + + for name, parameter in model.named_parameters(): + assert str(parameter.device).startswith("cuda"), f"Parameter {name} is not on the cuda device" + + for name, module in model.named_modules(): + + def msg(x): + return f"Mismatch in module {name}: {x}" + + if isinstance(module, torch.nn.Embedding): + torch.testing.assert_close(module.weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) + torch.testing.assert_close( + module.weight.std().item(), config.initializer_range, atol=atol, rtol=rtol, msg=msg + ) + + elif name == "lm_head.decoder": + # Make sure the lm_head decoder weights are still tied to the encoder weights + assert module.weight is model.model.embeddings.word_embeddings.weight, ( + "Decoder weight tying has been broken" + ) + + elif isinstance(module, transformer_engine.pytorch.Linear): + torch.testing.assert_close(module.weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) + torch.testing.assert_close( + module.weight.std().item(), config.initializer_range, atol=atol, rtol=rtol, msg=msg + ) + torch.testing.assert_close(module.bias, torch.zeros_like(module.bias), msg=msg) + if should_be_fp8: + assert isinstance(module.weight, QuantizedTensor), f"Module {name} weight is not a QuantizedTensor" + + elif isinstance(module, 
transformer_engine.pytorch.LayerNormLinear): + torch.testing.assert_close(module.weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) + torch.testing.assert_close( + module.weight.std().item(), config.initializer_range, atol=atol, rtol=rtol, msg=msg + ) + torch.testing.assert_close(module.bias, torch.zeros_like(module.bias), msg=msg) + torch.testing.assert_close(module.layer_norm_weight, torch.ones_like(module.layer_norm_weight), msg=msg) + torch.testing.assert_close(module.layer_norm_bias, torch.zeros_like(module.layer_norm_bias), msg=msg) + if should_be_fp8: + assert isinstance(module.weight, QuantizedTensor), f"Module {name} weight is not a QuantizedTensor" + + elif isinstance(module, transformer_engine.pytorch.LayerNormMLP): + torch.testing.assert_close(module.fc1_weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) + torch.testing.assert_close( + module.fc1_weight.std().item(), config.initializer_range, atol=atol, rtol=rtol, msg=msg + ) + torch.testing.assert_close(module.fc2_weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) + torch.testing.assert_close( + module.fc2_weight.std().item(), config.initializer_range, atol=atol, rtol=rtol, msg=msg + ) + torch.testing.assert_close(module.fc1_bias, torch.zeros_like(module.fc1_bias), msg=msg) + torch.testing.assert_close(module.fc2_bias, torch.zeros_like(module.fc2_bias), msg=msg) + torch.testing.assert_close(module.layer_norm_weight, torch.ones_like(module.layer_norm_weight), msg=msg) + torch.testing.assert_close(module.layer_norm_bias, torch.zeros_like(module.layer_norm_bias), msg=msg) + if should_be_fp8: + assert isinstance(module.fc1_weight, QuantizedTensor), ( + f"Module {name} fc1_weight is not a QuantizedTensor" + ) + assert isinstance(module.fc2_weight, QuantizedTensor), ( + f"Module {name} fc2_weight is not a QuantizedTensor" + ) + + elif isinstance(module, torch.nn.LayerNorm): + torch.testing.assert_close(module.weight, torch.ones_like(module.weight), msg=msg) + 
torch.testing.assert_close(module.bias, torch.zeros_like(module.bias), msg=msg) + + elif isinstance(module, transformer_engine.pytorch.attention.rope.RotaryPositionEmbedding): + dim = config.hidden_size // config.num_attention_heads + expected_inv_freq = 1.0 / (10_000.0 ** (torch.arange(0, dim, 2, dtype=torch.float32, device="cuda") / dim)) + torch.testing.assert_close(module.inv_freq, expected_inv_freq, msg=msg) + + +def verify_pretrained_model_sanity(model: NVEsmForTokenClassification, atol=1e-2, rtol=1e-3): + for name, p in model.named_parameters(): + + def msg(x): + return f"Mismatch in parameter {name}: {x}" + + assert p.numel() > 0, f"{name} is empty" + assert torch.isfinite(p).all(), f"{name} has NaN/Inf" + + max_abs = p.abs().max().item() + assert max_abs < 1e3, f"{name} extreme values: {max_abs}" + + if name == "classifier.weight": + torch.testing.assert_close(p.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) + torch.testing.assert_close(p.std().item(), model.config.initializer_range, atol=atol, rtol=rtol, msg=msg) + + if name == "classifier.bias": + torch.testing.assert_close(p, torch.zeros_like(p), msg=msg) + + +def test_cuda_init(): + config = NVEsmConfig(**AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D").to_dict()) + + set_seed(42) + model = NVEsmForMaskedLM(config) + model.to("cuda") + + verify_model_parameters_initialized_correctly(model) + + +def test_meta_init(): + config = NVEsmConfig(**AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D").to_dict()) + + set_seed(42) + with torch.device("meta"): + model = NVEsmForMaskedLM(config) + + # Assert parameters are actually on the meta device + for name, parameter in model.named_parameters(): + assert parameter.device == torch.device("meta"), f"Parameter {name} is not on the meta device" + + # Move the model to the cuda device and initialize the parameters + model.init_empty_weights() + + verify_model_parameters_initialized_correctly(model) + + +def test_cuda_fp8_init(fp8_recipe): + config = 
NVEsmConfig(**AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f").to_dict()) + + set_seed(42) + with transformer_engine.pytorch.fp8_model_init(recipe=fp8_recipe): + model = NVEsmForMaskedLM(config) + + model.to("cuda") + + verify_model_parameters_initialized_correctly(model, atol=1e-2, should_be_fp8=True) + + +def test_meta_fp8_init(fp8_recipe): + config = NVEsmConfig(**AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f").to_dict()) + + set_seed(42) + with transformer_engine.pytorch.fp8_model_init(recipe=fp8_recipe), torch.device("meta"): + model = NVEsmForMaskedLM(config) + + # Move the model to the cuda device and initialize the parameters + model.init_empty_weights() + + verify_model_parameters_initialized_correctly(model, should_be_fp8=True) + + +def test_model_for_token_classification_init(te_model_checkpoint): + set_seed(42) + + config = NVEsmConfig.from_pretrained(te_model_checkpoint) + model = NVEsmForTokenClassification.from_pretrained(te_model_checkpoint, config=config, dtype=torch.bfloat16) + # model.classifier.reset_parameters() + model.to("cuda") + verify_pretrained_model_sanity(model) + + +@pytest.mark.parametrize("num_gpus", [1, pytest.param(2, marks=requires_multi_gpu)]) +def test_meta_device_init_after_fully_shard(num_gpus: int): + cmd = [ + "torchrun", + f"--nproc_per_node={num_gpus}", + os.path.relpath(__file__), + ] + + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=240, + ) + + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"Command failed with exit code {result.returncode}") + + +if __name__ == "__main__": + torch.distributed.init_process_group(backend="cuda:nccl") + torch.cuda.set_device(torch.distributed.get_rank()) + + config = NVEsmConfig(**AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D").to_dict()) + + set_seed(42) + + with 
torch.device("meta"): + model_meta_device = NVEsmForMaskedLM(config) + + for layer in model_meta_device.model.encoder.layers: + fully_shard(layer) + fully_shard(model_meta_device) + + # Assert parameters are actually on the meta device + for name, parameter in model_meta_device.named_parameters(): + assert parameter.device == torch.device("meta"), f"Parameter {name} is not on the meta device" + + model_meta_device.init_empty_weights() + + # Assert parameters are actually on the cuda device after to_empty + for name, parameter in model_meta_device.named_parameters(): + assert str(parameter.device).startswith("cuda"), f"Parameter {name} is not on the cuda device" + + set_seed(42) + model_normal_init = NVEsmForMaskedLM(config) + + for layer in model_normal_init.model.encoder.layers: + fully_shard(layer) + fully_shard(model_normal_init) + + state_dict_meta_init = model_meta_device.state_dict() + state_dict_normal_init = model_normal_init.state_dict() + + for key in state_dict_meta_init.keys(): + if key.endswith("_extra_state"): + continue + + meta_tensor = state_dict_meta_init[key] + normal_tensor = state_dict_normal_init[key] + + torch.testing.assert_close( + normal_tensor.mean(), + meta_tensor.mean(), + atol=1e-3, + rtol=1e-4, + msg=lambda x: f"Mean mismatch for parameter {key}: {x}", + ) + + if isinstance(normal_tensor, DTensor) and isinstance(meta_tensor, DTensor): + torch.testing.assert_close( + normal_tensor.full_tensor().std(), + meta_tensor.full_tensor().std(), + atol=1e-2, + rtol=1e-4, + msg=lambda x: f"Std mismatch for parameter {key}: {x}", + ) + + else: + torch.testing.assert_close( + normal_tensor.std(), + meta_tensor.std(), + atol=1e-2, + rtol=1e-4, + msg=lambda x: f"Std mismatch for parameter {key}: {x}", + ) diff --git a/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py b/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py index c71e906a8..3a27d5112 100644 --- a/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py +++ 
b/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py @@ -13,260 +13,211 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for ESM2 model using the common test library. - -This file provides comprehensive tests for the ESM2 model including: -- Meta device initialization tests -- Golden value tests against HuggingFace reference models -- Conversion tests (HF ↔ TE) -- FP8 tests -- Model-specific tests - -Most tests are inherited from the common test library to reduce duplication. -""" - -from typing import Callable, Dict, List, Literal, Type from unittest.mock import MagicMock import torch from torch import nn -from transformers import ( - AutoTokenizer, - DataCollatorForLanguageModeling, - PretrainedConfig, - PreTrainedModel, - PreTrainedTokenizer, -) -from transformers.models.esm.modeling_esm import EsmForMaskedLM - -from esm.collator import DataCollatorWithFlattening -from esm.convert import ( - _pack_qkv_bias, - _pack_qkv_weight, - _pad_bias, - _pad_weights, - convert_esm_hf_to_te, - convert_esm_te_to_hf, - mapping, -) -from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM -from tests.common import BaseModelTest, TestTolerances - - -class TestESM2Model(BaseModelTest): - """Model tester for ESM2. - - This class provides ESM2-specific configuration for the common test suite. 
- """ - - def get_model_class(self) -> Type[PreTrainedModel]: - """Return the ESM2 TE model class.""" - return NVEsmForMaskedLM - - def get_config_class(self) -> Type[PretrainedConfig]: - """Return the ESM2 config class.""" - return NVEsmConfig - - def get_upstream_model_id(self) -> str: - """Return the upstream HuggingFace model ID.""" - return "facebook/esm2_t6_8M_UR50D" - - def get_upstream_model_revision(self) -> str: - """Return the specific revision for the upstream model.""" - return "c731040f" - - def get_upstream_model_class(self) -> Type[PreTrainedModel]: - """Return the upstream HuggingFace model class.""" - return EsmForMaskedLM - - def get_layer_path(self, model: PreTrainedModel) -> List[nn.Module]: - """Return the list of transformer layers.""" - return list(model.esm.encoder.layers) # type: ignore - - def get_reference_model_no_weights(self) -> PreTrainedModel: - """For checkpoint conversion tests to pass, we need to remove the unused contact head.""" - model = super().get_reference_model_no_weights() - del model.esm.contact_head - return model - - def get_test_input_data( - self, - format: Literal["bshd", "thd"] = "bshd", - pad_to_multiple_of: int | None = None, - ) -> Dict[str, torch.Tensor]: - """Prepare test input data (protein sequences).""" - - tokenizer = self.get_tokenizer() - - # Use real protein sequences - test_proteins = [ - "MLSATEKLSDYISSLFASVSIINSISTEDLFFLKLTCQTFSKDSEEYKAAYRILRGVQRGKVQIIEEALVS", - "MFVFFAGTLVNQDTLNFRDQLNINVVGTVRGIAQDASKYLEYAIDSV", - "MAATGSLILSDEEQAELIALAVRIVLACAGGSQNKELAAQLGVIETTVGEWRRRFAQNRVEGLRDEARPGAPSDDQ", - "MSAVLSAVASDDWTAFAKLVHPYVHWTADGITTRGRTRVMARLSGHDGVKPASSYELRDGQVYRWTS", - "MSDPAAEPPADTSGIAWRKSSYSGPNGNCVELAQISGDHVGIRNSRDLHGSVLTCTRAEFAALLCDIKAGRFDSLIL", - ] +from transformers import AutoConfig, AutoModelForMaskedLM - # Tokenize - tokenized = [tokenizer(p, truncation=True, max_length=1024) for p in test_proteins] - # Use data collator for MLM - data_collator = DataCollatorForLanguageModeling( - 
tokenizer=tokenizer, - mlm_probability=0.15, - pad_to_multiple_of=pad_to_multiple_of if format == "bshd" else None, - seed=42, - ) +def test_esm_model_for_masked_lm(input_data): + from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM - if format == "thd": - data_collator = DataCollatorWithFlattening( - collator=data_collator, - pad_sequences_to_be_divisible_by=pad_to_multiple_of, - ) + config = NVEsmConfig(**AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f").to_dict()) + model = NVEsmForMaskedLM(config) + model.to("cuda") + input_data = {k: v.to("cuda") for k, v in input_data.items()} - batch = data_collator(tokenized) + with torch.no_grad(): + outputs = model(**input_data) + assert outputs.loss - # Move to device - return {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in batch.items()} - def get_hf_to_te_converter(self) -> Callable: - """Return the HF to TE conversion function.""" - return convert_esm_hf_to_te +def test_esm_model_has_all_te_layers(input_data): + from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM - def get_te_to_hf_converter(self) -> Callable: - """Return the TE to HF conversion function.""" - return convert_esm_te_to_hf + config = NVEsmConfig(**AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f").to_dict()) + model = NVEsmForMaskedLM(config) + for name, module in model.named_modules(): + assert not isinstance(module, nn.Linear), f"Vanilla linear layer found in {name}" + assert not isinstance(module, nn.LayerNorm), f"Vanilla LayerNorm layer found in {name}" - def get_tolerances(self) -> TestTolerances: - """Return ESM2-specific test tolerances.""" - return TestTolerances( - golden_value_loss_atol=1e-2, - golden_value_loss_rtol=1e-3, - golden_value_logits_atol=2.0, # Higher tolerance needed after transformers PR#40370 - golden_value_logits_rtol=1e-4, - cp_loss_atol=0.1, - cp_loss_rtol=0.05, - ) - def get_tokenizer(self) -> PreTrainedTokenizer: - """Return the ESM2 
tokenizer.""" - return AutoTokenizer.from_pretrained("esm_fast_tokenizer") - - # ==================== ESM2-Specific Tests ==================== - - def test_convert_state_dict_explicit_check(self): - """Test detailed state dict conversion and mapping.""" - - input_data = self.get_test_input_data() - model_hf = self.get_reference_model() - model_te = self.get_converted_te_model() - - model_hf.to("cuda") - model_te.to("cuda") - input_data = {k: v.to("cuda") for k, v in input_data.items()} - - with torch.no_grad(): - outputs = model_te(**input_data) - assert outputs.loss - - te_state_dict_keys = { - k for k in model_te.state_dict().keys() if not k.endswith("_extra_state") and not k.endswith("inv_freq") - } - - # Check standard mapping - for k, v in mapping.items(): - if "*" in k: - for i in range(model_hf.config.num_hidden_layers): - k_sub = k.replace("*", str(i)) - v_sub = v.replace("*", str(i)) - torch.testing.assert_close( - model_te.state_dict()[v_sub], - model_hf.state_dict()[k_sub], - msg=lambda x: f"{k} {i} is not close: {x}", - ) - te_state_dict_keys.remove(v_sub) - else: +def test_convert_state_dict(input_data): + from esm.convert import _pack_qkv_bias, _pack_qkv_weight, _pad_bias, _pad_weights, convert_esm_hf_to_te, mapping + + model_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f") + model_te = convert_esm_hf_to_te(model_hf) + model_hf.to("cuda") + model_te.to("cuda") + input_data = {k: v.to("cuda") for k, v in input_data.items()} + + with torch.no_grad(): + outputs = model_te(**input_data) + assert outputs.loss + + te_state_dict_keys = { + k for k in model_te.state_dict().keys() if not k.endswith("_extra_state") and not k.endswith("inv_freq") + } + + for k, v in mapping.items(): + if "*" in k: + for i in range(model_hf.config.num_hidden_layers): + k_sub = k.replace("*", str(i)) + v_sub = v.replace("*", str(i)) torch.testing.assert_close( - model_te.state_dict()[v], - model_hf.state_dict()[k], - msg=lambda x: f"{k} 
is not close: {x}", + model_te.state_dict()[v_sub], + model_hf.state_dict()[k_sub], + msg=lambda x: f"{k} {i} is not close: {x}", ) - te_state_dict_keys.remove(v) - - # Check packed QKV weights - for i in range(model_hf.config.num_hidden_layers): - k = f"esm.encoder.layers.{i}.self_attention.layernorm_qkv.weight" - v = [ - f"esm.encoder.layer.{i}.attention.self.query.weight", - f"esm.encoder.layer.{i}.attention.self.key.weight", - f"esm.encoder.layer.{i}.attention.self.value.weight", - ] - - ctx_mock = MagicMock() - ctx_mock.target.config.num_attention_heads = model_hf.config.num_attention_heads - - packed_weight = _pack_qkv_weight.transform( - ctx_mock, - model_hf.state_dict()[v[0]], - model_hf.state_dict()[v[1]], - model_hf.state_dict()[v[2]], + te_state_dict_keys.remove(v_sub) + else: + torch.testing.assert_close( + model_te.state_dict()[v], + model_hf.state_dict()[k], + msg=lambda x: f"{k} is not close: {x}", ) + te_state_dict_keys.remove(v) + + # # We untie these weights so we need to compare and remove manually + # torch.testing.assert_close( + # model_te.state_dict()["lm_head.layer_norm_decoder.weight"], + # model_hf.state_dict()["lm_head.decoder.weight"], + # ) + # te_state_dict_keys.remove("lm_head.layer_norm_decoder.weight") + + for i in range(model_hf.config.num_hidden_layers): + k = f"model.encoder.layers.{i}.self_attention.layernorm_qkv.weight" + v = [ + f"esm.encoder.layer.{i}.attention.self.query.weight", + f"esm.encoder.layer.{i}.attention.self.key.weight", + f"esm.encoder.layer.{i}.attention.self.value.weight", + ] - torch.testing.assert_close(packed_weight, model_te.state_dict()[k]) - te_state_dict_keys.remove(k) - - # Check packed QKV biases - for i in range(model_hf.config.num_hidden_layers): - k = f"esm.encoder.layers.{i}.self_attention.layernorm_qkv.bias" - v = [ - f"esm.encoder.layer.{i}.attention.self.query.bias", - f"esm.encoder.layer.{i}.attention.self.key.bias", - f"esm.encoder.layer.{i}.attention.self.value.bias", - ] - - ctx_mock = 
MagicMock() - ctx_mock.target.config.num_attention_heads = model_hf.config.num_attention_heads - - packed_weight = _pack_qkv_bias.transform( - ctx_mock, - model_hf.state_dict()[v[0]], - model_hf.state_dict()[v[1]], - model_hf.state_dict()[v[2]], - ) - - torch.testing.assert_close(packed_weight, model_te.state_dict()[k]) - te_state_dict_keys.remove(k) - - # Check padded embeddings and LM head ctx_mock = MagicMock() - ctx_mock.target.config.padded_vocab_size = model_te.config.padded_vocab_size + ctx_mock.target.config.num_attention_heads = model_hf.config.num_attention_heads - torch.testing.assert_close( - _pad_weights(ctx_mock, model_hf.state_dict()["esm.embeddings.word_embeddings.weight"]), - model_te.state_dict()["esm.embeddings.word_embeddings.weight"], - ) - torch.testing.assert_close( - _pad_weights(ctx_mock, model_hf.state_dict()["lm_head.decoder.weight"]), - model_te.state_dict()["lm_head.decoder.weight"], - ) - torch.testing.assert_close( - _pad_bias.transform(ctx_mock, model_hf.state_dict()["lm_head.bias"]), - model_te.state_dict()["lm_head.decoder.bias"], + packed_weight = _pack_qkv_weight.transform( + ctx_mock, + model_hf.state_dict()[v[0]], + model_hf.state_dict()[v[1]], + model_hf.state_dict()[v[2]], ) - te_state_dict_keys.remove("esm.embeddings.word_embeddings.weight") - te_state_dict_keys.remove("lm_head.decoder.weight") - te_state_dict_keys.remove("lm_head.decoder.bias") + torch.testing.assert_close(packed_weight, model_te.state_dict()[k]) + te_state_dict_keys.remove(k) - assert len(te_state_dict_keys) == 0 + for i in range(model_hf.config.num_hidden_layers): + k = f"model.encoder.layers.{i}.self_attention.layernorm_qkv.bias" + v = [ + f"esm.encoder.layer.{i}.attention.self.query.bias", + f"esm.encoder.layer.{i}.attention.self.key.bias", + f"esm.encoder.layer.{i}.attention.self.value.bias", + ] - # Check that the tied weights are the same - assert ( - model_hf.state_dict()["esm.embeddings.word_embeddings.weight"].data_ptr() - == 
model_hf.state_dict()["lm_head.decoder.weight"].data_ptr() - ) + ctx_mock = MagicMock() + ctx_mock.target.config.num_attention_heads = model_hf.config.num_attention_heads - assert ( - model_te.state_dict()["esm.embeddings.word_embeddings.weight"].data_ptr() - == model_te.state_dict()["lm_head.decoder.weight"].data_ptr() + packed_weight = _pack_qkv_bias.transform( + ctx_mock, + model_hf.state_dict()[v[0]], + model_hf.state_dict()[v[1]], + model_hf.state_dict()[v[2]], ) + + torch.testing.assert_close(packed_weight, model_te.state_dict()[k]) + te_state_dict_keys.remove(k) + + ctx_mock = MagicMock() + ctx_mock.target.config.padded_vocab_size = model_te.config.padded_vocab_size + + torch.testing.assert_close( + _pad_weights(ctx_mock, model_hf.state_dict()["esm.embeddings.word_embeddings.weight"]), + model_te.state_dict()["model.embeddings.word_embeddings.weight"], + ) + torch.testing.assert_close( + _pad_weights(ctx_mock, model_hf.state_dict()["lm_head.decoder.weight"]), + model_te.state_dict()["lm_head.decoder.weight"], + ) + torch.testing.assert_close( + _pad_bias.transform(ctx_mock, model_hf.state_dict()["lm_head.bias"]), + model_te.state_dict()["lm_head.decoder.bias"], + ) + + te_state_dict_keys.remove("model.embeddings.word_embeddings.weight") + te_state_dict_keys.remove("lm_head.decoder.weight") + te_state_dict_keys.remove("lm_head.decoder.bias") + + assert len(te_state_dict_keys) == 0 + + # Check that the tied weights are the same + assert ( + model_hf.state_dict()["esm.embeddings.word_embeddings.weight"].data_ptr() + == model_hf.state_dict()["lm_head.decoder.weight"].data_ptr() + ) + + assert ( + model_te.state_dict()["model.embeddings.word_embeddings.weight"].data_ptr() + == model_te.state_dict()["lm_head.decoder.weight"].data_ptr() + ) + + +def test_golden_values(input_data): + from esm.convert import convert_esm_hf_to_te + + model_hf = AutoModelForMaskedLM.from_pretrained( + "facebook/esm2_t6_8M_UR50D", attn_implementation="flash_attention_2", 
revision="c731040f" + ) + model_te = convert_esm_hf_to_te(model_hf) + model_te.to(torch.bfloat16) + model_hf.to(torch.bfloat16) + + model_te.to("cuda") + model_hf.to("cuda") + input_data = {k: v.to("cuda") for k, v in input_data.items()} + + with torch.no_grad(): + te_outputs = model_te(**input_data, output_hidden_states=True) + hf_outputs = model_hf(**input_data, output_hidden_states=True) + + torch.testing.assert_close(te_outputs.loss, hf_outputs.loss, atol=1e-2, rtol=1e-3) + torch.testing.assert_close( + te_outputs.logits[input_data["attention_mask"].to(bool)], + hf_outputs.logits[input_data["attention_mask"].to(bool)], + atol=2, # This seems high, needed to increase after https://github.com/huggingface/transformers/pull/40370 + rtol=1e-4, + ) + + +def test_converted_model_roundtrip(tmp_path, input_data): + from transformer_engine.pytorch import TransformerLayer + + from esm.convert import convert_esm_hf_to_te + from esm.modeling_esm_te import NVEsmConfig, NVEsmEncoder, NVEsmForMaskedLM, NVEsmLMHead, NVEsmModel + + model_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f") + model_te = convert_esm_hf_to_te(model_hf) + + model_te.save_pretrained(tmp_path / "esm2_t6_8M_UR50D_te") + del model_te + + model_te = NVEsmForMaskedLM.from_pretrained(tmp_path / "esm2_t6_8M_UR50D_te") + + # Ensure our custom classes are still there + assert isinstance(model_te, NVEsmForMaskedLM) + assert isinstance(model_te.config, NVEsmConfig) + assert isinstance(model_te.lm_head, NVEsmLMHead) + assert isinstance(model_te.model, NVEsmModel) + assert isinstance(model_te.model.encoder, NVEsmEncoder) + assert isinstance(model_te.model.encoder.layers[0], TransformerLayer) + assert model_te.config.model_type == "nv_esm" + + model_te.to("cuda") + model_hf.to("cuda") + input_data = {k: v.to("cuda") for k, v in input_data.items()} + + with torch.no_grad(): + te_outputs = model_te(**input_data, output_hidden_states=True) + hf_outputs = model_hf(**input_data, 
output_hidden_states=True) + + torch.testing.assert_close(te_outputs.loss, hf_outputs.loss, atol=1e-1, rtol=1e-3) diff --git a/bionemo-recipes/models/esm2/tests/test_thd.py b/bionemo-recipes/models/esm2/tests/test_thd.py new file mode 100644 index 000000000..daad705ff --- /dev/null +++ b/bionemo-recipes/models/esm2/tests/test_thd.py @@ -0,0 +1,329 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import os + +import pytest +import torch +from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends +from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp +from transformers import DataCollatorForLanguageModeling + +from esm.collator import DataCollatorWithFlattening +from esm.modeling_esm_te import NVEsmConfig, NVEsmEmbeddings, NVEsmForMaskedLM + + +compute_capability = torch.cuda.get_device_capability() + +# TODO(@jomitchell): Delete once https://nvbugspro.nvidia.com/bug/5458694 is fixed. 
+requires_datacenter_hardware = pytest.mark.skipif( + not torch.cuda.is_available() + or not any( + gpu_name in torch.cuda.get_device_name(0).upper() for gpu_name in ["H100", "H200", "B100", "B200", "B300"] + ), + reason="Test requires datacenter hardware (H100, H200, B100, B200, B300)", +) + + +@pytest.fixture +def input_data_thd(tokenizer, tokenized_proteins): + """The collator here needs to exactly match the one used in the `input_data` fixture for golden values to pass.""" + data_collator = DataCollatorWithFlattening( + collator=DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm_probability=0.15, + pad_to_multiple_of=32, + seed=42, + ) + ) + return data_collator(tokenized_proteins) + + +@pytest.fixture +def input_data_thd_padded_from_input_data_thd(input_data_thd): + input_data_thd_padded = copy.deepcopy(input_data_thd) + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_data_thd_padded["input_ids"], + input_data_thd_padded["labels"], + input_data_thd_padded["cu_seq_lens_q"], + 16, + padding_token_id=1, + padding_label_id=-100, + ) + + input_data_thd_padded["input_ids"] = input_ids_padded.unsqueeze(0) + input_data_thd_padded["labels"] = labels_padded.unsqueeze(0) + input_data_thd_padded["cu_seq_lens_q_padded"] = cu_seqlens_padded.to(torch.int32) + input_data_thd_padded["cu_seq_lens_k_padded"] = cu_seqlens_padded.to(torch.int32) + input_data_thd_padded["pad_between_seqs"] = True + return input_data_thd_padded + + +@pytest.mark.parametrize("use_token_dropout", [True, False]) +def test_nv_esm_embeddings_random_init(te_model_checkpoint, input_data_thd, input_data, use_token_dropout): + config = NVEsmConfig.from_pretrained(te_model_checkpoint) + assert config.token_dropout is True + embedding = NVEsmEmbeddings(config) + embedding.token_dropout = use_token_dropout + embedding.to("cuda") + + input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} + 
input_data_thd.pop("labels") + outputs_thd = embedding(**input_data_thd) + + input_data_bshd = {k: v.to("cuda") for k, v in input_data.items()} + input_data_bshd.pop("labels") + outputs_bshd = embedding(**input_data_bshd) + + # Reshape outputs_bshd to match outputs_thd + outputs_bshd = outputs_bshd[input_data_bshd["attention_mask"].to(bool)].unsqueeze(0) + torch.testing.assert_close(outputs_thd, outputs_bshd, atol=1e-8, rtol=1e-8) + + +@pytest.mark.parametrize("use_token_dropout", [True, False]) +def test_nv_esm_embeddings_from_model(te_model_checkpoint, input_data_thd, input_data, use_token_dropout): + model = NVEsmForMaskedLM.from_pretrained( + te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16, token_dropout=use_token_dropout + ) + embedding = model.model.embeddings + assert embedding.token_dropout == use_token_dropout + embedding.to("cuda") + + input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} + input_data_thd.pop("labels") + outputs_thd = embedding(**input_data_thd) + + input_data_bshd = {k: v.to("cuda") for k, v in input_data.items()} + input_data_bshd.pop("labels") + outputs_bshd = embedding(**input_data_bshd) + + # Reshape outputs_bshd to match outputs_thd + outputs_bshd = outputs_bshd[input_data_bshd["attention_mask"].to(bool)].unsqueeze(0) + torch.testing.assert_close(outputs_thd, outputs_bshd, atol=1e-8, rtol=1e-8) + + +def test_thd_from_collator_output(te_model_checkpoint, input_data_thd): + model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16) + model_thd.to("cuda") + input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} + with torch.no_grad(): + outputs = model_thd(**input_data_thd, output_hidden_states=True) + + assert outputs.loss < 3.0 + + +@pytest.fixture(params=["flash_attn", "fused_attn"]) +def attn_impl(request, monkeypatch): + if request.param == "flash_attn": 
+ os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_FLASH_ATTN"] = "1" + _attention_backends["backend_selection_requires_update"] = True + + else: + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_FLASH_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + return request.param + + +def test_thd_losses_match(te_model_checkpoint, input_data, input_data_thd, attn_impl): + if attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 8: + pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.") + elif attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 12: + pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.") + + torch.testing.assert_close( + input_data["input_ids"][input_data["attention_mask"].to(bool)], + input_data_thd["input_ids"].flatten(0), + ) + + torch.testing.assert_close( + input_data["labels"][input_data["attention_mask"].to(bool)], + input_data_thd["labels"].flatten(0), + ) + + model_bshd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) + model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16) + model_bshd.to("cuda") + model_thd.to("cuda") + + input_data_bshd = {k: v.to("cuda") for k, v in input_data.items()} + input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} + + bshd_outputs = model_bshd(**input_data_bshd) + thd_outputs = model_thd(**input_data_thd) + + torch.testing.assert_close(bshd_outputs.loss, thd_outputs.loss) + + +def test_thd_logits_match_with_bf16_autocast(te_model_checkpoint, input_data, input_data_thd, attn_impl): + if attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 8: + pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.") + elif attn_impl == "flash_attn" and torch.cuda.get_device_capability()[0] == 8: + 
pytest.xfail("BIONEMO-2801: On Ada and Ampere, the flash attention logits don't seem to match.") + elif attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 12: + pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.") + + # Ensure the input data is the same + torch.testing.assert_close( + input_data["input_ids"][input_data["attention_mask"].to(bool)], + input_data_thd["input_ids"].flatten(0), + ) + + torch.testing.assert_close( + input_data["labels"][input_data["attention_mask"].to(bool)], + input_data_thd["labels"].flatten(0), + ) + + # Create models + model_bshd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) + model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16) + + model_bshd.to("cuda") + model_thd.to("cuda") + + input_data_bshd = {k: v.to("cuda") for k, v in input_data.items()} + input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} + + thd_outputs = model_thd(**input_data_thd, output_hidden_states=True) + bshd_outputs = model_bshd(**input_data_bshd, output_hidden_states=True) + + for i, (bshd_hidden, thd_hidden) in enumerate(zip(bshd_outputs.hidden_states, thd_outputs.hidden_states)): + torch.testing.assert_close( + bshd_hidden[input_data_bshd["attention_mask"].to(bool)], + thd_hidden.squeeze(0), + msg=lambda msg: "Hidden states do not match going into layer " + str(i + 1) + ": " + msg, + atol=1e-1 if compute_capability[0] == 8 else 1e-5, + rtol=1.6e-2, + ) + + if compute_capability[0] == 8: + break # On Ada and Ampere, we see much larger numerical errors so we stop after the first layer + + bshd_logits = bshd_outputs.logits[input_data_bshd["attention_mask"].to(bool)] + torch.testing.assert_close(bshd_logits, thd_outputs.logits, atol=1e-8, rtol=1e-8) + + +def test_thd_backwards_works(te_model_checkpoint, input_data_thd, attn_impl): + if attn_impl == 
"fused_attn" and torch.cuda.get_device_capability() == (12, 0): + pytest.xfail("BIONEMO-2840: On sm120 the THD backwards implementation is not available for fused attn.") + elif attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 8: + pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.") + + model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16) + model_thd.to("cuda") + input_data = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} + outputs = model_thd(**input_data) + outputs.loss.backward() + + +def test_thd_backwards_passes_match(te_model_checkpoint, input_data, input_data_thd, attn_impl): + if attn_impl == "fused_attn" and torch.cuda.get_device_capability() == (12, 0): + pytest.xfail("BIONEMO-2840: On sm120 the THD backwards implementation is not available for fused attn.") + elif attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 8: + pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.") + + torch.testing.assert_close( + input_data["input_ids"][input_data["attention_mask"].to(bool)], + input_data_thd["input_ids"].flatten(0), + ) + + torch.testing.assert_close( + input_data["labels"][input_data["attention_mask"].to(bool)], + input_data_thd["labels"].flatten(0), + ) + + model_bshd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) + model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16) + model_bshd.to("cuda") + model_thd.to("cuda") + + input_data_bshd = {k: v.to("cuda") for k, v in input_data.items()} + input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} + + bshd_outputs = model_bshd(**input_data_bshd) + thd_outputs = model_thd(**input_data_thd) + + thd_outputs.loss.backward() + bshd_outputs.loss.backward() + + thd_grads = 
{name: p.grad for name, p in model_thd.named_parameters() if p.grad is not None} + bshd_grads = {name: p.grad for name, p in model_bshd.named_parameters() if p.grad is not None} + + # max_diff_by_layer = {key: (thd_grads[key] - bshd_grads[key]).abs().max().item() for key in thd_grads.keys()} + + # For some reason, the word embeddings grads have a slightly higher numerical error. + thd_word_embeddings_grad = thd_grads.pop("model.embeddings.word_embeddings.weight") + bshd_word_embeddings_grad = bshd_grads.pop("model.embeddings.word_embeddings.weight") + torch.testing.assert_close( + thd_grads, + bshd_grads, + atol=1e-2 if compute_capability[0] == 8 else 1e-5, + rtol=1.6e-2, + ) + + torch.testing.assert_close(thd_word_embeddings_grad, bshd_word_embeddings_grad, atol=1e-2, rtol=1e-5) + + +@requires_datacenter_hardware +def test_thd_vs_padded_thd_equivalence( + te_model_checkpoint, input_data_thd, input_data_thd_padded_from_input_data_thd, attn_impl +): + if attn_impl == "flash_attn": + pytest.xfail("Flash attention is not supported for padded sequences.") + + input_data_thd_padded = input_data_thd_padded_from_input_data_thd + seqlens_q = input_data_thd_padded["cu_seq_lens_q_padded"][1:] - input_data_thd_padded["cu_seq_lens_q_padded"][:-1] + max_length_q = int((seqlens_q.max().item() + 63) // 64 * 64) # TODO(@jomitchell): Not sure if I need this anymore. + max_length_k = max_length_q + input_data_thd_padded["max_length_q"] = max_length_q + input_data_thd_padded["max_length_k"] = max_length_k + + input_data_thd_gpu = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} + input_data_thd_padded_gpu = { + k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd_padded.items() + } + + # Run the input data thd through. 
+ model_thd = NVEsmForMaskedLM.from_pretrained( + te_model_checkpoint, attn_input_format="thd", token_dropout=False, dtype=torch.bfloat16 + ) + model_thd.to("cuda") + outputs_thd = model_thd(**input_data_thd_gpu) + outputs_thd_padded = model_thd(**input_data_thd_padded_gpu) + + cu_seq_lens_q = input_data_thd["cu_seq_lens_q"] + cu_seq_lens_q_padded = input_data_thd_padded["cu_seq_lens_q_padded"] + cu_num_pads = cu_seq_lens_q_padded - cu_seq_lens_q + seq_lengths_real = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1] + + num_real_tokens = outputs_thd.logits.shape[0] # should be cu_seq_lens_q[-1] + + # How much we need to shift each sequence by. + offsets = torch.repeat_interleave(cu_num_pads[:-1], seq_lengths_real, dim=0) + + # The indices of the real tokens as appears in the padded logits. + real_idx = torch.arange(0, num_real_tokens) + offsets + + assert ( + input_data_thd["input_ids"].squeeze() - input_data_thd_padded["input_ids"].squeeze().index_select(0, real_idx) + ).abs().max().item() == 0 + + # Now index select the padded logits to get the real logits. + logits_unpadded = outputs_thd_padded.logits.index_select(0, real_idx.cuda()) + + torch.testing.assert_close(outputs_thd.logits, logits_unpadded, atol=1e-8, rtol=1e-5) diff --git a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py index d4ee0845e..fb2e4136d 100644 --- a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py @@ -70,6 +70,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", + add_pooling_layer: bool = True, **kwargs, ): """Initialize the NVEsmConfig with additional TE-related config options. @@ -100,6 +101,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults to vocab_size. 
Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. + add_pooling_layer: Whether the base model should include a pooling layer. Set to + ``False`` for exported checkpoints that are saved from ``NVEsmForMaskedLM`` + (which does not use a pooler). This avoids missing-weight errors in vLLM. **kwargs: Additional config options to pass to EsmConfig. """ super().__init__(**kwargs) @@ -111,6 +115,7 @@ def __init__( self.micro_batch_size = micro_batch_size self.max_seq_length = max_seq_length self.attn_mask_type = attn_mask_type + self.add_pooling_layer = add_pooling_layer # Set padded_vocab_size with default fallback to vocab_size self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size @@ -231,7 +236,7 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel): """An abstract class to handle weights initialization and pretrained model loading.""" config_class = NVEsmConfig - base_model_prefix = "esm" + base_model_prefix = "model" supports_gradient_checkpointing = False accepts_loss_kwargs = False _no_split_modules = ( @@ -247,11 +252,11 @@ def init_empty_weights(self): if hasattr(module, "reset_parameters"): module.reset_parameters() - # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # The embeddings layer is the only non-TE layer in this model we need to deal with. We use # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard - # deviation. - self.esm.embeddings.word_embeddings.to_empty(device="cuda") - self.esm.embeddings.apply(self._init_weights) + # deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel. + self.base_model.embeddings.word_embeddings.to_empty(device="cuda") + self.base_model.embeddings.apply(self._init_weights) # Meta-device init seems to break weight tying, so we re-tie the weights here. 
self.tie_weights() @@ -276,14 +281,16 @@ def _init_weights(self, module): super()._init_weights(module) def state_dict(self, *args, **kwargs): - """Override state_dict to filter out TransformerEngine's _extra_state keys. + """Override state_dict to filter out non-loadable keys. - TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. - These are filtered out to ensure checkpoints can be loaded with from_pretrained(). + Filters out: + - ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5. + - ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed + in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates + over ``named_parameters``, not ``named_buffers``). """ state_dict = super().state_dict(*args, **kwargs) - # Filter out _extra_state keys which are TransformerEngine-specific and not loadable - return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")} class NVEsmModel(NVEsmPreTrainedModel): @@ -292,16 +299,20 @@ class NVEsmModel(NVEsmPreTrainedModel): This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. """ - def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None): """Initialize a NVEsmModel. Args: config (NVEsmConfig): The configuration of the model. - add_pooling_layer (bool): Whether to add a pooling layer. + add_pooling_layer (bool): Whether to add a pooling layer. If ``None``, + reads ``config.add_pooling_layer`` (defaults to ``True``). 
""" super().__init__(config) self.config = config + if add_pooling_layer is None: + add_pooling_layer = getattr(config, "add_pooling_layer", True) + # Ensure pad_token_id is set properly, defaulting to 0 if not specified if not hasattr(config, "pad_token_id") or config.pad_token_id is None: config.pad_token_id = 0 @@ -391,8 +402,9 @@ def forward( class NVEsmForMaskedLM(NVEsmPreTrainedModel): """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" - _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} - _do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized. + _tied_weights_keys: ClassVar[dict[str, str]] = { + "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" + } def __init__(self, config: NVEsmConfig): """Initialize a NVEsmForMaskedLM. @@ -408,7 +420,7 @@ def __init__(self, config: NVEsmConfig): "bi-directional self-attention." ) - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.lm_head = NVEsmLMHead(config) self.post_init() @@ -443,7 +455,7 @@ def forward( Returns: MaskedLMOutput: The output of the model. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -451,8 +463,7 @@ def forward( **kwargs, ) sequence_output = outputs[0] - with transformer_engine.pytorch.autocast(enabled=False): - prediction_scores = self.lm_head(sequence_output) + prediction_scores = self.lm_head(sequence_output) # Truncate logits back to original vocab_size if padding was used if self.config.padded_vocab_size != self.config.vocab_size: @@ -483,15 +494,15 @@ def __init__(self, config: NVEsmConfig): config (NVEsmConfig): The configuration of the model. 
""" super().__init__() - with transformer_engine.pytorch.quantized_model_init(enabled=False): - self.dense = transformer_engine.pytorch.Linear( - config.hidden_size, - config.hidden_size, - params_dtype=config.dtype, - device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", - init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), - ) + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + with transformer_engine.pytorch.fp8_model_init(enabled=False): self.decoder = transformer_engine.pytorch.LayerNormLinear( config.hidden_size, config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, @@ -511,7 +522,7 @@ def forward(self, features, **kwargs): """ # Keep the last layers of the network in higher precision to avoid numerical instability. # Please see recipes/fp8_analysis/README.md for more details. - with transformer_engine.pytorch.autocast(enabled=False): + with transformer_engine.pytorch.fp8_autocast(enabled=False): x = self.dense(features) x = torch.nn.functional.gelu(x) x = self.decoder(x) @@ -599,10 +610,6 @@ def forward( else: src_lengths = torch.diff(kwargs["cu_seq_lens_q"]) - if "cu_seq_lens_q_padded" in kwargs: - src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) - else: - src_lengths_padded = src_lengths # We need to find the number of masked tokens in each sequence in the padded batch. 
is_masked = (input_ids == self.mask_token_id).squeeze(0) n_masked_per_seq = torch.nested.nested_tensor_from_jagged( @@ -610,7 +617,7 @@ def forward( ).sum(1) mask_ratio_observed = n_masked_per_seq.float() / src_lengths scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) - reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0) + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) if self.layer_norm is not None: @@ -633,7 +640,7 @@ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = transformer_engine.pytorch.Linear( config.hidden_size, @@ -659,7 +666,7 @@ def forward( labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py index d4ee0845e..fb2e4136d 100644 --- a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py @@ -70,6 +70,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", + add_pooling_layer: bool = True, **kwargs, ): """Initialize the NVEsmConfig with additional TE-related config options. @@ -100,6 +101,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. 
If not provided, defaults to vocab_size. Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. + add_pooling_layer: Whether the base model should include a pooling layer. Set to + ``False`` for exported checkpoints that are saved from ``NVEsmForMaskedLM`` + (which does not use a pooler). This avoids missing-weight errors in vLLM. **kwargs: Additional config options to pass to EsmConfig. """ super().__init__(**kwargs) @@ -111,6 +115,7 @@ def __init__( self.micro_batch_size = micro_batch_size self.max_seq_length = max_seq_length self.attn_mask_type = attn_mask_type + self.add_pooling_layer = add_pooling_layer # Set padded_vocab_size with default fallback to vocab_size self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size @@ -231,7 +236,7 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel): """An abstract class to handle weights initialization and pretrained model loading.""" config_class = NVEsmConfig - base_model_prefix = "esm" + base_model_prefix = "model" supports_gradient_checkpointing = False accepts_loss_kwargs = False _no_split_modules = ( @@ -247,11 +252,11 @@ def init_empty_weights(self): if hasattr(module, "reset_parameters"): module.reset_parameters() - # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # The embeddings layer is the only non-TE layer in this model we need to deal with. We use # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard - # deviation. - self.esm.embeddings.word_embeddings.to_empty(device="cuda") - self.esm.embeddings.apply(self._init_weights) + # deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel. + self.base_model.embeddings.word_embeddings.to_empty(device="cuda") + self.base_model.embeddings.apply(self._init_weights) # Meta-device init seems to break weight tying, so we re-tie the weights here. 
self.tie_weights() @@ -276,14 +281,16 @@ def _init_weights(self, module): super()._init_weights(module) def state_dict(self, *args, **kwargs): - """Override state_dict to filter out TransformerEngine's _extra_state keys. + """Override state_dict to filter out non-loadable keys. - TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. - These are filtered out to ensure checkpoints can be loaded with from_pretrained(). + Filters out: + - ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5. + - ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed + in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates + over ``named_parameters``, not ``named_buffers``). """ state_dict = super().state_dict(*args, **kwargs) - # Filter out _extra_state keys which are TransformerEngine-specific and not loadable - return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")} class NVEsmModel(NVEsmPreTrainedModel): @@ -292,16 +299,20 @@ class NVEsmModel(NVEsmPreTrainedModel): This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. """ - def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None): """Initialize a NVEsmModel. Args: config (NVEsmConfig): The configuration of the model. - add_pooling_layer (bool): Whether to add a pooling layer. + add_pooling_layer (bool): Whether to add a pooling layer. If ``None``, + reads ``config.add_pooling_layer`` (defaults to ``True``). 
""" super().__init__(config) self.config = config + if add_pooling_layer is None: + add_pooling_layer = getattr(config, "add_pooling_layer", True) + # Ensure pad_token_id is set properly, defaulting to 0 if not specified if not hasattr(config, "pad_token_id") or config.pad_token_id is None: config.pad_token_id = 0 @@ -391,8 +402,9 @@ def forward( class NVEsmForMaskedLM(NVEsmPreTrainedModel): """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" - _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} - _do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized. + _tied_weights_keys: ClassVar[dict[str, str]] = { + "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" + } def __init__(self, config: NVEsmConfig): """Initialize a NVEsmForMaskedLM. @@ -408,7 +420,7 @@ def __init__(self, config: NVEsmConfig): "bi-directional self-attention." ) - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.lm_head = NVEsmLMHead(config) self.post_init() @@ -443,7 +455,7 @@ def forward( Returns: MaskedLMOutput: The output of the model. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -451,8 +463,7 @@ def forward( **kwargs, ) sequence_output = outputs[0] - with transformer_engine.pytorch.autocast(enabled=False): - prediction_scores = self.lm_head(sequence_output) + prediction_scores = self.lm_head(sequence_output) # Truncate logits back to original vocab_size if padding was used if self.config.padded_vocab_size != self.config.vocab_size: @@ -483,15 +494,15 @@ def __init__(self, config: NVEsmConfig): config (NVEsmConfig): The configuration of the model. 
""" super().__init__() - with transformer_engine.pytorch.quantized_model_init(enabled=False): - self.dense = transformer_engine.pytorch.Linear( - config.hidden_size, - config.hidden_size, - params_dtype=config.dtype, - device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", - init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), - ) + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + with transformer_engine.pytorch.fp8_model_init(enabled=False): self.decoder = transformer_engine.pytorch.LayerNormLinear( config.hidden_size, config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, @@ -511,7 +522,7 @@ def forward(self, features, **kwargs): """ # Keep the last layers of the network in higher precision to avoid numerical instability. # Please see recipes/fp8_analysis/README.md for more details. - with transformer_engine.pytorch.autocast(enabled=False): + with transformer_engine.pytorch.fp8_autocast(enabled=False): x = self.dense(features) x = torch.nn.functional.gelu(x) x = self.decoder(x) @@ -599,10 +610,6 @@ def forward( else: src_lengths = torch.diff(kwargs["cu_seq_lens_q"]) - if "cu_seq_lens_q_padded" in kwargs: - src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) - else: - src_lengths_padded = src_lengths # We need to find the number of masked tokens in each sequence in the padded batch. 
is_masked = (input_ids == self.mask_token_id).squeeze(0) n_masked_per_seq = torch.nested.nested_tensor_from_jagged( @@ -610,7 +617,7 @@ def forward( ).sum(1) mask_ratio_observed = n_masked_per_seq.float() / src_lengths scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) - reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0) + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) if self.layer_norm is not None: @@ -633,7 +640,7 @@ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = transformer_engine.pytorch.Linear( config.hidden_size, @@ -659,7 +666,7 @@ def forward( labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py index a8b8afc6a..271e25723 100644 --- a/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py @@ -75,7 +75,7 @@ def test_stop_and_go_checkpointing_and_dataloader_restoration_single_gpu(tmp_pat # The huggingface model has a contact head that we don't use in masked language pre-training, so we delete it to # avoid errors with unused parameters. 
try: - del model.esm.contact_head + del model.model.contact_head except AttributeError: pass @@ -157,7 +157,7 @@ def test_stop_and_go_checkpointing_and_dataloader_restoration_single_gpu(tmp_pat resumed_model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) try: - del resumed_model.esm.contact_head + del resumed_model.model.contact_head except AttributeError: pass diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py index 168d25b57..3ad704916 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py @@ -85,7 +85,7 @@ def main(args: DictConfig) -> float | None: # The huggingface model has a contact head that we don't use in masked language pre-training, so we delete it to # avoid errors with unused parameters. try: - del model.esm.contact_head + del model.model.contact_head except AttributeError: pass diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py index c5a8dad34..c7150f677 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py @@ -112,7 +112,7 @@ def main(args: DictConfig) -> float | None: ) if args.cp_size > 1: - for i, transformer_layer in enumerate(model.module.esm.encoder.layers): + for i, transformer_layer in enumerate(model.module.model.encoder.layers): logger.debug(f"Rank {dist_config.rank}: Setting CP group for layer {i}") transformer_layer.set_context_parallel_group( device_mesh["cp"].get_group(), diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index c74a5ad6c..7a6917de5 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -96,7 +96,9 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized 
Model:\n%s", model) # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models. - transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer + transformer_stack = ( + model.model.encoder.layers if hasattr(model.model.encoder, "layers") else model.model.encoder.layer + ) mp_policy = MixedPrecisionPolicy( param_dtype=torch.bfloat16 diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py index 6a824cc9f..d268441bf 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py @@ -113,7 +113,9 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models. - transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer + transformer_stack = ( + model.model.encoder.layers if hasattr(model.model.encoder, "layers") else model.model.encoder.layer + ) # Fully shard takes in a DeviceMesh object, which is a 2D mesh of dimensions (CP_dimension, DP_dimension). # FSDP2 will shard the model across the DP (dim=1) dimension and then duplicate across the CP (dim=0) dimension. 
for layer in transformer_stack: diff --git a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py index d4ee0845e..fb2e4136d 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py @@ -70,6 +70,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", + add_pooling_layer: bool = True, **kwargs, ): """Initialize the NVEsmConfig with additional TE-related config options. @@ -100,6 +101,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults to vocab_size. Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. + add_pooling_layer: Whether the base model should include a pooling layer. Set to + ``False`` for exported checkpoints that are saved from ``NVEsmForMaskedLM`` + (which does not use a pooler). This avoids missing-weight errors in vLLM. **kwargs: Additional config options to pass to EsmConfig. 
""" super().__init__(**kwargs) @@ -111,6 +115,7 @@ def __init__( self.micro_batch_size = micro_batch_size self.max_seq_length = max_seq_length self.attn_mask_type = attn_mask_type + self.add_pooling_layer = add_pooling_layer # Set padded_vocab_size with default fallback to vocab_size self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size @@ -231,7 +236,7 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel): """An abstract class to handle weights initialization and pretrained model loading.""" config_class = NVEsmConfig - base_model_prefix = "esm" + base_model_prefix = "model" supports_gradient_checkpointing = False accepts_loss_kwargs = False _no_split_modules = ( @@ -247,11 +252,11 @@ def init_empty_weights(self): if hasattr(module, "reset_parameters"): module.reset_parameters() - # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # The embeddings layer is the only non-TE layer in this model we need to deal with. We use # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard - # deviation. - self.esm.embeddings.word_embeddings.to_empty(device="cuda") - self.esm.embeddings.apply(self._init_weights) + # deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel. + self.base_model.embeddings.word_embeddings.to_empty(device="cuda") + self.base_model.embeddings.apply(self._init_weights) # Meta-device init seems to break weight tying, so we re-tie the weights here. self.tie_weights() @@ -276,14 +281,16 @@ def _init_weights(self, module): super()._init_weights(module) def state_dict(self, *args, **kwargs): - """Override state_dict to filter out TransformerEngine's _extra_state keys. + """Override state_dict to filter out non-loadable keys. - TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. 
- These are filtered out to ensure checkpoints can be loaded with from_pretrained(). + Filters out: + - ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5. + - ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed + in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates + over ``named_parameters``, not ``named_buffers``). """ state_dict = super().state_dict(*args, **kwargs) - # Filter out _extra_state keys which are TransformerEngine-specific and not loadable - return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")} class NVEsmModel(NVEsmPreTrainedModel): @@ -292,16 +299,20 @@ class NVEsmModel(NVEsmPreTrainedModel): This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. """ - def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None): """Initialize a NVEsmModel. Args: config (NVEsmConfig): The configuration of the model. - add_pooling_layer (bool): Whether to add a pooling layer. + add_pooling_layer (bool): Whether to add a pooling layer. If ``None``, + reads ``config.add_pooling_layer`` (defaults to ``True``). 
""" super().__init__(config) self.config = config + if add_pooling_layer is None: + add_pooling_layer = getattr(config, "add_pooling_layer", True) + # Ensure pad_token_id is set properly, defaulting to 0 if not specified if not hasattr(config, "pad_token_id") or config.pad_token_id is None: config.pad_token_id = 0 @@ -391,8 +402,9 @@ def forward( class NVEsmForMaskedLM(NVEsmPreTrainedModel): """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" - _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} - _do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized. + _tied_weights_keys: ClassVar[dict[str, str]] = { + "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" + } def __init__(self, config: NVEsmConfig): """Initialize a NVEsmForMaskedLM. @@ -408,7 +420,7 @@ def __init__(self, config: NVEsmConfig): "bi-directional self-attention." ) - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.lm_head = NVEsmLMHead(config) self.post_init() @@ -443,7 +455,7 @@ def forward( Returns: MaskedLMOutput: The output of the model. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -451,8 +463,7 @@ def forward( **kwargs, ) sequence_output = outputs[0] - with transformer_engine.pytorch.autocast(enabled=False): - prediction_scores = self.lm_head(sequence_output) + prediction_scores = self.lm_head(sequence_output) # Truncate logits back to original vocab_size if padding was used if self.config.padded_vocab_size != self.config.vocab_size: @@ -483,15 +494,15 @@ def __init__(self, config: NVEsmConfig): config (NVEsmConfig): The configuration of the model. 
""" super().__init__() - with transformer_engine.pytorch.quantized_model_init(enabled=False): - self.dense = transformer_engine.pytorch.Linear( - config.hidden_size, - config.hidden_size, - params_dtype=config.dtype, - device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", - init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), - ) + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + with transformer_engine.pytorch.fp8_model_init(enabled=False): self.decoder = transformer_engine.pytorch.LayerNormLinear( config.hidden_size, config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, @@ -511,7 +522,7 @@ def forward(self, features, **kwargs): """ # Keep the last layers of the network in higher precision to avoid numerical instability. # Please see recipes/fp8_analysis/README.md for more details. - with transformer_engine.pytorch.autocast(enabled=False): + with transformer_engine.pytorch.fp8_autocast(enabled=False): x = self.dense(features) x = torch.nn.functional.gelu(x) x = self.decoder(x) @@ -599,10 +610,6 @@ def forward( else: src_lengths = torch.diff(kwargs["cu_seq_lens_q"]) - if "cu_seq_lens_q_padded" in kwargs: - src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) - else: - src_lengths_padded = src_lengths # We need to find the number of masked tokens in each sequence in the padded batch. 
is_masked = (input_ids == self.mask_token_id).squeeze(0) n_masked_per_seq = torch.nested.nested_tensor_from_jagged( @@ -610,7 +617,7 @@ def forward( ).sum(1) mask_ratio_observed = n_masked_per_seq.float() / src_lengths scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) - reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0) + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) if self.layer_norm is not None: @@ -633,7 +640,7 @@ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = transformer_engine.pytorch.Linear( config.hidden_size, @@ -659,7 +666,7 @@ def forward( labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
""" - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/bionemo-recipes/vllm/esm2_vllm_converted/config.json b/bionemo-recipes/vllm/esm2_vllm_converted/config.json new file mode 100644 index 000000000..430c4c272 --- /dev/null +++ b/bionemo-recipes/vllm/esm2_vllm_converted/config.json @@ -0,0 +1,46 @@ +{ + "add_cross_attention": false, + "architectures": [ + "NVEsmForMaskedLM" + ], + "attention_probs_dropout_prob": 0.0, + "attn_input_format": "bshd", + "attn_mask_type": "padding", + "auto_map": { + "AutoConfig": "esm_nv.NVEsmConfig", + "AutoModel": "esm_nv.NVEsmModel", + "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM", + "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification" + }, + "classifier_dropout": null, + "dtype": "float32", + "emb_layer_norm_before": false, + "encoder_activation": "gelu", + "esmfold_config": null, + "fuse_qkv_params": true, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "hidden_size": 320, + "initializer_range": 0.02, + "intermediate_size": 1280, + "is_decoder": false, + "is_folding_model": false, + "layer_norm_eps": 1e-05, + "mask_token_id": 32, + "max_position_embeddings": 1026, + "max_seq_length": null, + "micro_batch_size": null, + "model_type": "nv_esm", + "num_attention_heads": 20, + "num_hidden_layers": 6, + "pad_token_id": 1, + "padded_vocab_size": 64, + "position_embedding_type": "rotary", + "qkv_weight_interleaved": true, + "tie_word_embeddings": true, + "token_dropout": true, + "transformers_version": "5.0.0", + "use_cache": true, + "vocab_list": null, + "vocab_size": 33 +} diff --git a/bionemo-recipes/vllm/esm2_vllm_converted/esm_nv.py b/bionemo-recipes/vllm/esm2_vllm_converted/esm_nv.py new file mode 100644 index 000000000..00fdf2312 --- /dev/null +++ b/bionemo-recipes/vllm/esm2_vllm_converted/esm_nv.py @@ -0,0 +1,681 @@ +# noqa: license-check +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""TransformerEngine-optimized ESM model. + +Adapted from `modeling_esm.py` in huggingface/transformers. +""" + +from typing import ClassVar, Literal, Optional, Unpack + +# TODO: put import guard around transformer_engine here, with an informative error message around +# installation and the nvidia docker container. +import torch +import transformer_engine.pytorch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + MaskedLMOutput, + TokenClassifierOutput, +) +from transformers.models.esm.configuration_esm import EsmConfig +from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel +from transformers.utils import logging +from transformers.utils.generic import TransformersKwargs + + +logger = logging.get_logger(__name__) + +# Dictionary that gets inserted into config.json to map Auto** classes to our TE-optimized model classes defined below. +# These should be prefixed with esm_nv., since we name the file esm_nv.py in our exported checkpoints. 
+AUTO_MAP = { + "AutoConfig": "esm_nv.NVEsmConfig", + "AutoModel": "esm_nv.NVEsmModel", + "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM", + "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", +} + + +class NVEsmConfig(EsmConfig): + """NVEsmConfig is a configuration for the NVEsm model.""" + + model_type: str = "nv_esm" + + def __init__( + self, + qkv_weight_interleaved: bool = True, + encoder_activation: str = "gelu", + attn_input_format: Literal["bshd", "thd"] = "bshd", + fuse_qkv_params: bool = True, + micro_batch_size: Optional[int] = None, + max_seq_length: Optional[int] = None, + padded_vocab_size: Optional[int] = 64, + attn_mask_type: str = "padding", + **kwargs, + ): + """Initialize the NVEsmConfig with additional TE-related config options. + + Args: + qkv_weight_interleaved: Whether to interleave the qkv weights. If set to `False`, the + QKV weight is interpreted as a concatenation of query, key, and value weights along + the `0th` dimension. The default interpretation is that the individual `q`, `k`, and + `v` weights for each attention head are interleaved. This parameter is set to `False` + when using :attr:`fuse_qkv_params=False`. + encoder_activation: The activation function to use in the encoder. + attn_input_format: The input format to use for the attention. This controls + whether the dimensions of the intermediate hidden states is 'batch first' + ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length, + `b` batch size, `h` the number of heads, `d` head size. Note that these + formats are very closely related to the `qkv_format` in the + `MultiHeadAttention` and `DotProductAttention` modules. + fuse_qkv_params: Whether to fuse the qkv parameters. If set to `True`, + `TransformerLayer` module exposes a single fused parameter for query-key-value. + This enables optimizations such as QKV fusion without concatentations/splits and + also enables the argument `fuse_wgrad_accumulation`. 
+ micro_batch_size: The micro batch size to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propogation and activation recompute phase. + max_seq_length: The maximum sequence length to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propogation and activation recompute phase. + padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults + to vocab_size. Must be greater than or equal to vocab_size. + attn_mask_type: The type of attention mask to use. + **kwargs: Additional config options to pass to EsmConfig. + """ + super().__init__(**kwargs) + # Additional TE-related config options. + self.qkv_weight_interleaved = qkv_weight_interleaved + self.encoder_activation = encoder_activation + self.attn_input_format = attn_input_format + self.fuse_qkv_params = fuse_qkv_params + self.micro_batch_size = micro_batch_size + self.max_seq_length = max_seq_length + self.attn_mask_type = attn_mask_type + + # Set padded_vocab_size with default fallback to vocab_size + self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size + + # Ensure padded_vocab_size is at least as large as vocab_size + if self.padded_vocab_size is not None and self.vocab_size is not None: + assert self.padded_vocab_size >= self.vocab_size, ( + f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to vocab_size ({self.vocab_size})" + ) + + +class NVEsmEncoder(nn.Module): + """NVEsmEncoder is a TransformerEngine-optimized ESM encoder.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmEncoder. + + Args: + config (NVEsmConfig): The configuration of the model. 
+ """ + super().__init__() + self.config = config + + def _init_method(x): + torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) + + self.layers = nn.ModuleList( + [ + transformer_engine.pytorch.TransformerLayer( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + layernorm_epsilon=config.layer_norm_eps, + hidden_dropout=config.hidden_dropout_prob, + attention_dropout=config.attention_probs_dropout_prob, + qkv_weight_interleaved=config.qkv_weight_interleaved, + layer_number=i + 1, + layer_type="encoder", + self_attn_mask_type=config.attn_mask_type, + activation=config.encoder_activation, + attn_input_format=config.attn_input_format, + seq_length=config.max_seq_length, + micro_batch_size=config.micro_batch_size, + num_gqa_groups=config.num_attention_heads, + fuse_qkv_params=config.fuse_qkv_params, + params_dtype=config.dtype, + window_size=(-1, -1), + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=_init_method, + output_layer_init_method=_init_method, + ) + for i in range(config.num_hidden_layers) + ] + ) + self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.position_embedding_type == "rotary": + self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + """Forward pass of the NVEsmEncoder. + + Args: + hidden_states (torch.Tensor): The hidden states. + attention_mask (torch.Tensor): The attention mask. + **kwargs: Additional arguments, see TransformersKwargs for more details. + """ + all_hidden_states: tuple[torch.Tensor, ...] 
= () + + if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1: + # For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE + # expects a 2-dimensional tensor with shape [total_tokens, hidden_size]. + hidden_states = hidden_states.squeeze(0) + + # Ensure that rotary embeddings are computed with at a higher precision outside the torch autocast context. + with torch.autocast(device_type="cuda", enabled=False): + te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) + te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) + + for layer_module in self.layers: + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + hidden_states = layer_module( + hidden_states, + attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + + hidden_states = self.emb_layer_norm_after(hidden_states) + + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states if all_hidden_states else None, + ) + + +class NVEsmPreTrainedModel(EsmPreTrainedModel): + """An abstract class to handle weights initialization and pretrained model loading.""" + + config_class = NVEsmConfig + base_model_prefix = "esm" + supports_gradient_checkpointing = False + accepts_loss_kwargs = False + _no_split_modules = ( + "TransformerLayer", + "EsmEmbeddings", + ) + + def init_empty_weights(self): + """Handles moving 
the model from the meta device to the cuda device and initializing the weights.""" + # For TE layers, calling `reset_parameters` is sufficient to move them to the cuda device and apply the weight + # initialization we passed them during module creation. + for module in self.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard + # deviation. + self.esm.embeddings.word_embeddings.to_empty(device="cuda") + self.esm.embeddings.apply(self._init_weights) + + # Meta-device init seems to break weight tying, so we re-tie the weights here. + self.tie_weights() + + def _init_weights(self, module): + """Initialize module weights. + + We only use this method for standard pytorch modules, TE modules handle their own weight initialization through + `init_method` parameters and the `reset_parameters` method. + """ + if module.__module__.startswith("transformer_engine.pytorch"): + # Notably, we need to avoid calling the parent method for TE modules, since the default _init_weights will + # assume any class with `LayerNorm` in the name should have weights initialized to 1.0; breaking + # `LayerNormLinear` and `LayerNormMLP` modules that use `weight` for the linear layer and + # `layer_norm_weight` for the layer norm. Instead, we call `reset_parameters` if the module has it and the + # weights are not in fp8. We still need to figure out why this raises an error if we're using + # `quantized_model_init`. + if hasattr(module, "reset_parameters") and not getattr(module, "primary_weights_in_fp8", False): + module.reset_parameters() + return + + super()._init_weights(module) + + def state_dict(self, *args, **kwargs): + """Override state_dict to filter out TransformerEngine's _extra_state keys. 
+ + TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. + These are filtered out to ensure checkpoints can be loaded with from_pretrained(). + """ + state_dict = super().state_dict(*args, **kwargs) + # Filter out _extra_state keys which are TransformerEngine-specific and not loadable + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + + +class NVEsmModel(NVEsmPreTrainedModel): + """The ESM Encoder-only protein language model. + + This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. + """ + + def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + """Initialize a NVEsmModel. + + Args: + config (NVEsmConfig): The configuration of the model. + add_pooling_layer (bool): Whether to add a pooling layer. + """ + super().__init__(config) + self.config = config + + # Ensure pad_token_id is set properly, defaulting to 0 if not specified + if not hasattr(config, "pad_token_id") or config.pad_token_id is None: + config.pad_token_id = 0 + self.embeddings = NVEsmEmbeddings(config) + self.encoder = NVEsmEncoder(config) + self.pooler = EsmPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + """Get the input embeddings of the model.""" + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value: torch.Tensor): + """Set the input embeddings of the model. + + Args: + value (torch.Tensor): The input embeddings. + """ + self.embeddings.word_embeddings = value + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + """Forward pass of the NVEsmModel. 
+ + Args: + input_ids (torch.Tensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.Tensor): The position ids. + inputs_embeds (torch.Tensor): The input embeddings. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + BaseModelOutputWithPooling: The output of the model. + """ + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # TE expects a boolean attention mask, where 1s are masked and 0s are not masked + extended_attention_mask = extended_attention_mask < -1 + + embedding_output = self.embeddings( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + **kwargs, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +class NVEsmForMaskedLM(NVEsmPreTrainedModel): + """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" + + _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmForMaskedLM. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." 
+ ) + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.lm_head = NVEsmLMHead(config) + + self.post_init() + + def get_output_embeddings(self): + """Get the output embeddings of the model.""" + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + """Set the output embeddings of the model.""" + self.lm_head.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MaskedLMOutput: + """Forward pass of the NVEsmForMaskedLM. + + Args: + input_ids (torch.LongTensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.LongTensor): The position ids. + inputs_embeds (torch.FloatTensor): The input embeddings. + labels (torch.LongTensor): The labels. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + MaskedLMOutput: The output of the model. 
+ """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + # Truncate logits back to original vocab_size if padding was used + if self.config.padded_vocab_size != self.config.vocab_size: + prediction_scores = prediction_scores[..., : self.config.vocab_size] + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.to(prediction_scores.device).view(-1), + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + ) + + +class NVEsmLMHead(nn.Module): + """ESM Head for masked language modeling using TransformerEngine.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmLMHead. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__() + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + with transformer_engine.pytorch.fp8_model_init(enabled=False): + self.decoder = transformer_engine.pytorch.LayerNormLinear( + config.hidden_size, + config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, + bias=True, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + def forward(self, features, **kwargs): + """Forward pass of the NVEsmLMHead. + + Args: + features (torch.Tensor): The features. 
+ **kwargs: Additional arguments. + """ + # Keep the last layers of the network in higher precision to avoid numerical instability. + # Please see recipes/fp8_analysis/README.md for more details. + with transformer_engine.pytorch.fp8_autocast(enabled=False): + x = self.dense(features) + x = torch.nn.functional.gelu(x) + x = self.decoder(x) + return x + + +class NVEsmEmbeddings(nn.Module): + """Modified version of EsmEmbeddings to support THD inputs.""" + + def __init__(self, config): + """Initialize a NVEsmEmbeddings.""" + super().__init__() + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id, + dtype=config.dtype, + ) + + self.layer_norm = ( + transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.emb_layer_norm_before + else None + ) + + if config.position_embedding_type != "rotary": + raise ValueError( + "The TE-accelerated ESM-2 model only supports rotary position embeddings, received " + f"{config.position_embedding_type}" + ) + + self.padding_idx = config.pad_token_id + self.token_dropout = config.token_dropout + self.mask_token_id = config.mask_token_id + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + **kwargs: Unpack[TransformersKwargs], + ): + """Forward pass of the NVEsmEmbeddings.""" + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an + # embedding_scale factor here. 
+ embeddings = inputs_embeds
+
+ if (
+ kwargs.get("cu_seq_lens_q") is not None
+ and kwargs.get("cu_seq_lens_k") is not None
+ and kwargs.get("max_length_q") is not None
+ and kwargs.get("max_length_k") is not None
+ ):
+ using_thd = True
+ attention_mask = None
+ else:
+ using_thd = False
+
+ # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
+ # flag is False then it is handled in the same way as BERT/RoBERTa. If it is set to True, however,
+ # masked tokens are treated as if they were selected for input dropout and zeroed out.
+ # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
+ # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
+ # This is analogous to the way that dropout layers scale down outputs during evaluation when not
+ # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
+ if self.token_dropout and input_ids is not None:
+ embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
+ mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs
+
+ if not using_thd:
+ # BSHD token dropout correction
+ src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1]
+ n_masked_per_seq = (input_ids == self.mask_token_id).sum(-1).float()
+ mask_ratio_observed = n_masked_per_seq / src_lengths
+ scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
+ embeddings = (embeddings * scale_factor[:, None, None]).to(embeddings.dtype)
+
+ else:
+ src_lengths = torch.diff(kwargs["cu_seq_lens_q"])
+ # We need to find the number of masked tokens in each sequence in the padded batch.
+ is_masked = (input_ids == self.mask_token_id).squeeze(0) + n_masked_per_seq = torch.nested.nested_tensor_from_jagged( + is_masked, offsets=kwargs["cu_seq_lens_q"] + ).sum(1) + mask_ratio_observed = n_masked_per_seq.float() / src_lengths + scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) + embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + + if attention_mask is not None: + embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) + + return embeddings + + +class NVEsmForTokenClassification(NVEsmPreTrainedModel): + """Adds a token classification head to the model. + + Adapted from EsmForTokenClassification in Hugging Face Transformers `modeling_esm.py`. + """ + + def __init__(self, config): + """Initialize NVEsmForTokenClassification.""" + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = NVEsmModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = transformer_engine.pytorch.Linear( + config.hidden_size, + config.num_labels, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: + """Forward pass for the token classification head. 
+ + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/bionemo-recipes/vllm/esm2_vllm_converted/special_tokens_map.json b/bionemo-recipes/vllm/esm2_vllm_converted/special_tokens_map.json new file mode 100644 index 000000000..0d2693ece --- /dev/null +++ b/bionemo-recipes/vllm/esm2_vllm_converted/special_tokens_map.json @@ -0,0 +1,44 @@ +{ + "bos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "cls_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "mask_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "pad_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/bionemo-recipes/vllm/esm2_vllm_converted/tokenizer.json b/bionemo-recipes/vllm/esm2_vllm_converted/tokenizer.json new file mode 100644 index 
000000000..708fe8bbd --- /dev/null +++ b/bionemo-recipes/vllm/esm2_vllm_converted/tokenizer.json @@ -0,0 +1,176 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 1, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 2, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 3, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 32, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "Split", + "pattern": { + "String": "" + }, + "behavior": "Isolated", + "invert": false + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + } + ], + "pair": [ + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 1 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 1 + } + } + ], + "special_tokens": { + "": { + "id": "", + "ids": [ + 0 + ], + "tokens": [ + "" + ] + }, + "": { + "id": "", + "ids": [ + 2 + ], + "tokens": [ + "" + ] + } + } + }, + "decoder": null, + "model": { + "type": "WordLevel", + "vocab": { + "": 0, + "": 1, + "": 2, + "": 3, + "L": 4, + "A": 5, + "G": 6, + "V": 7, + "S": 8, + "E": 9, + "R": 10, + "T": 11, + "I": 12, + "D": 13, + "P": 14, + 
"K": 15, + "Q": 16, + "N": 17, + "F": 18, + "Y": 19, + "M": 20, + "H": 21, + "W": 22, + "C": 23, + "X": 24, + "B": 25, + "U": 26, + "Z": 27, + "O": 28, + ".": 29, + "-": 30, + "": 31, + "": 32 + }, + "unk_token": "" + } +} diff --git a/bionemo-recipes/vllm/esm2_vllm_converted/tokenizer_config.json b/bionemo-recipes/vllm/esm2_vllm_converted/tokenizer_config.json new file mode 100644 index 000000000..a8223b53e --- /dev/null +++ b/bionemo-recipes/vllm/esm2_vllm_converted/tokenizer_config.json @@ -0,0 +1,18 @@ +{ + "backend": "tokenizers", + "bos_token": "", + "clean_up_tokenization_spaces": false, + "cls_token": "", + "eos_token": "", + "is_local": true, + "mask_token": "", + "model_input_names": [ + "input_ids", + "attention_mask" + ], + "model_max_length": 1000000000000000019884624838656, + "model_specific_special_tokens": {}, + "pad_token": "", + "tokenizer_class": "TokenizersBackend", + "unk_token": "" +} diff --git a/bionemo-recipes/vllm/esm2_vllm_converted/vllm_conversion_info.json b/bionemo-recipes/vllm/esm2_vllm_converted/vllm_conversion_info.json new file mode 100644 index 000000000..87114c5a6 --- /dev/null +++ b/bionemo-recipes/vllm/esm2_vllm_converted/vllm_conversion_info.json @@ -0,0 +1,5 @@ +{ + "original_model": "nvidia/esm2_t6_8M_UR50D", + "conversion": "strip_esm_prefix", + "description": "Stripped 'esm.' prefix from weight keys for vLLM compatibility (lm_head.* dropped)" +} diff --git a/bionemo-recipes/vllm/esm2_vllm_converted/vocab.txt b/bionemo-recipes/vllm/esm2_vllm_converted/vocab.txt new file mode 100644 index 000000000..741f755a4 --- /dev/null +++ b/bionemo-recipes/vllm/esm2_vllm_converted/vocab.txt @@ -0,0 +1,33 @@ + + + + +L +A +G +V +S +E +R +T +I +D +P +K +Q +N +F +Y +M +H +W +C +X +B +U +Z +O +. 
+- + + diff --git a/bionemo-recipes/vllm/vllm_test.py b/bionemo-recipes/vllm/vllm_test.py new file mode 100644 index 000000000..1bf35b790 --- /dev/null +++ b/bionemo-recipes/vllm/vllm_test.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""vLLM ESM2 embedding example -- loads directly from HuggingFace. + +The NVEsm model on HuggingFace follows the standard HuggingFace convention where +wrapper classes (NVEsmForMaskedLM, NVEsmForTokenClassification) store the base +model as ``self.model``, with ``base_model_prefix = "model"``. + +This means checkpoint weight keys align with vLLM's TransformersForEmbedding +wrapper out of the box: + +- Checkpoint bare keys: ``embeddings.*``, ``encoder.*`` +- vLLM mapper adds ``model.``: ``model.embeddings.*``, ``model.encoder.*`` +- Wrapper module tree: ``model.embeddings.*``, ``model.encoder.*`` + +No conversion scripts or weight renaming needed. 
+""" + +import numpy as np +import torch +from transformers import AutoModel, AutoTokenizer +from vllm import LLM + + +# MODEL_ID = "nvidia/esm2_t6_8M_UR50D" +# To test with a local re-exported checkpoint before pushing to HuggingFace, use a path: +MODEL_ID = "/workspace/bionemo-framework/bionemo-recipes/models/esm2/exported/esm2_t6_8M_UR50D" # after running export from models/esm2 + +# Reference: nvidia model on HuggingFace Hub (same as MODEL_ID when using Hub) — check local/conversion against it. +REFERENCE_MODEL_ID = "nvidia/esm2_t6_8M_UR50D" + + +if __name__ == "__main__": + num_gpus = torch.cuda.device_count() + print(f"Detected {num_gpus} GPU(s)") + + # Load ESM2 directly from HuggingFace as a pooling/embedding model. + # No checkpoint conversion needed -- the model code and checkpoint are + # aligned so that vLLM's generic weight mapper works out of the box. + print(f"\nLoading model: {MODEL_ID}") + model = LLM( + model=MODEL_ID, + runner="pooling", + trust_remote_code=True, + # TransformerEngine layers use pydantic (ArgsKwargs) which torch.compile + # cannot trace. Use eager mode to avoid the dynamo error. + enforce_eager=True, + # vLLM's profiling run packs all tokens into a single batch-1 sequence. + # Cap batched tokens to max_position_embeddings (1026) so the rotary + # embeddings don't run out of positions. 
+ max_num_batched_tokens=1026, + ) + + # Example protein sequences + prompts = [ + "LKGHAMCLGCLHMLMCGLLAGAMCGLMKLLKCCGKCLMHLMKAMLGLKCACHHHHLLLHACAAKKLCLGAKLAMGLKLLGAHGKGLKMACGHHMLHLHMH", + "CLLCCMHMHAHHCHGHGHKCKCLMMGMALMCAGCCACGMKGGCHCCLLAHCAHAKAGKGKCKLMCKKKHGLHAGLHAMLLCHLGLGCGHHHKKCKKHKCA", + ] + + print(f"\nGenerating embeddings for {len(prompts)} sequences...") + outputs = model.embed(prompts) + + # Collect vLLM embeddings for comparison (one vector per sequence from pooling) + vllm_embeddings = [] + for i, (prompt, output) in enumerate(zip(prompts, outputs)): + embedding = output.outputs.embedding + if isinstance(embedding, list): + embedding = np.array(embedding) + vllm_embeddings.append(embedding) + print(f"\nSequence {i + 1}:") + print(f" Length: {len(prompt)} amino acids") + print(f" Embedding shape: {embedding.shape}") + print(f" First 5 dims: {embedding[:5].tolist()}") + + vllm_embeddings = np.stack(vllm_embeddings) + + print("\nSUCCESS: ESM2 embeddings generated with vLLM!") + + # ---- Native HuggingFace inference on the same sequences ---- + # Run the same model via transformers and compare outputs to vLLM. + print(f"\n--- HuggingFace (native) inference on same {len(prompts)} sequences ---") + hf_model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) + # Use float16 to match vLLM's default dtype for numerical comparison + hf_model = hf_model.to("cuda", dtype=torch.float16) + hf_model.eval() + + # Exported checkpoints use add_pooling_layer=False (no pooler weights). + # vLLM pooling runner uses seq_pooling_type='LAST' (last token). Use pooler_output when + # present, else take the last non-padding token's hidden state to match vLLM. 
+ hf_embeddings_list = [] + with torch.no_grad(): + for seq in prompts: + inputs = tokenizer(seq, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + out = hf_model(**inputs) + if out.pooler_output is not None: + vec = out.pooler_output.cpu().numpy().squeeze(0) + else: + # Last-token hidden state to match vLLM's seq_pooling_type='LAST' + last_hidden = out.last_hidden_state # (1, seq_len, hidden_size) + vec = last_hidden[0, -1, :].cpu().float().numpy() + # vLLM pooling runner L2-normalizes the embedding; do the same for comparison + norm = np.linalg.norm(vec) + if norm > 1e-9: + vec = vec / norm + hf_embeddings_list.append(vec) + + hf_embeddings = np.stack(hf_embeddings_list) + + # Compare vLLM vs HuggingFace (same sequences, same model). + # Use relaxed tolerance for fp16/FlashAttention vs fp32/PyTorch attention differences. + rtol, atol = 1e-2, 5e-4 + match = np.allclose(vllm_embeddings, hf_embeddings, rtol=rtol, atol=atol) + max_diff = np.abs(vllm_embeddings.astype(np.float64) - hf_embeddings.astype(np.float64)).max() + print("\nComparison (vLLM vs HuggingFace embedding):") + print(f" allclose(rtol={rtol}, atol={atol}): {match}") + print(f" max |diff|: {max_diff}") + if not match: + raise AssertionError( + "vLLM and HuggingFace outputs differ. " + f"max |diff| = {max_diff}; expected allclose(rtol={rtol}, atol={atol})." + ) + print(" Match: vLLM and HuggingFace results are the same.") + + del hf_model + del tokenizer + + # ---- Reference: nvidia model from HuggingFace Hub (check our conversion / local export against it) ---- + # Load nvidia/esm2_t6_8M_UR50D from Hub and run same sequences; compare to MODEL_ID (local or Hub). 
+ print(f"\n--- Reference: HuggingFace Hub {REFERENCE_MODEL_ID} (same {len(prompts)} sequences) ---")
+ ref_model = AutoModel.from_pretrained(REFERENCE_MODEL_ID, trust_remote_code=True)
+ ref_tokenizer = AutoTokenizer.from_pretrained(REFERENCE_MODEL_ID, trust_remote_code=True)
+ ref_model = ref_model.to("cuda", dtype=torch.float16)
+ ref_model.eval()
+
+ def _embed_from_output(out, use_pooler=True, last_token_and_l2=True):
+ """Get one embedding per batch item: pooler_output or last-token hidden state (L2-normalized)."""
+ if use_pooler and out.pooler_output is not None:
+ return out.pooler_output.cpu().numpy().squeeze(0)
+ last_hidden = out.last_hidden_state # (batch, seq_len, hidden_size)
+ vec = last_hidden[0, -1, :].cpu().float().numpy()
+ if last_token_and_l2:
+ norm = np.linalg.norm(vec)
+ if norm > 1e-9:
+ vec = vec / norm
+ return vec
+
+ ref_embeddings_list = []
+ with torch.no_grad():
+ for seq in prompts:
+ inputs = ref_tokenizer(seq, return_tensors="pt")
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
+ out = ref_model(**inputs)
+ # Use last-token + L2 to match our model (no pooler); validates conversion.
+ ref_embeddings_list.append(_embed_from_output(out, use_pooler=False))
+
+ ref_embeddings = np.stack(ref_embeddings_list)
+
+ # Compare our model (HF path) vs reference (nvidia Hub): conversion should match.
+ rtol_ref, atol_ref = 1e-2, 5e-4
+ match_ref = np.allclose(hf_embeddings, ref_embeddings, rtol=rtol_ref, atol=atol_ref)
+ max_diff_ref = np.abs(hf_embeddings.astype(np.float64) - ref_embeddings.astype(np.float64)).max()
+ print(f"\nComparison (our model vs reference {REFERENCE_MODEL_ID}):")
+ print(f" allclose(rtol={rtol_ref}, atol={atol_ref}): {match_ref}")
+ print(f" max |diff|: {max_diff_ref}")
+ if not match_ref:
+ raise AssertionError(
+ "Our model and reference (Hub) outputs differ. "
+ f"max |diff| = {max_diff_ref}; expected allclose(rtol={rtol_ref}, atol={atol_ref})."
+ ) + print(" Match: conversion matches reference HuggingFace Hub model.") + + del ref_model + del ref_tokenizer + + # Cleanup + del model + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() From 3f9f00ed794cb36d9f84fcac9ae8fbcfbbd0c309 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Mon, 23 Feb 2026 03:15:02 +0000 Subject: [PATCH 2/9] remove converted checkpoint from git tracking --- .../vllm/esm2_vllm_converted/config.json | 46 -- .../vllm/esm2_vllm_converted/esm_nv.py | 681 ------------------ .../special_tokens_map.json | 44 -- .../vllm/esm2_vllm_converted/tokenizer.json | 176 ----- .../esm2_vllm_converted/tokenizer_config.json | 18 - .../vllm_conversion_info.json | 5 - .../vllm/esm2_vllm_converted/vocab.txt | 33 - 7 files changed, 1003 deletions(-) delete mode 100644 bionemo-recipes/vllm/esm2_vllm_converted/config.json delete mode 100644 bionemo-recipes/vllm/esm2_vllm_converted/esm_nv.py delete mode 100644 bionemo-recipes/vllm/esm2_vllm_converted/special_tokens_map.json delete mode 100644 bionemo-recipes/vllm/esm2_vllm_converted/tokenizer.json delete mode 100644 bionemo-recipes/vllm/esm2_vllm_converted/tokenizer_config.json delete mode 100644 bionemo-recipes/vllm/esm2_vllm_converted/vllm_conversion_info.json delete mode 100644 bionemo-recipes/vllm/esm2_vllm_converted/vocab.txt diff --git a/bionemo-recipes/vllm/esm2_vllm_converted/config.json b/bionemo-recipes/vllm/esm2_vllm_converted/config.json deleted file mode 100644 index 430c4c272..000000000 --- a/bionemo-recipes/vllm/esm2_vllm_converted/config.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "add_cross_attention": false, - "architectures": [ - "NVEsmForMaskedLM" - ], - "attention_probs_dropout_prob": 0.0, - "attn_input_format": "bshd", - "attn_mask_type": "padding", - "auto_map": { - "AutoConfig": "esm_nv.NVEsmConfig", - "AutoModel": "esm_nv.NVEsmModel", - "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM", - "AutoModelForTokenClassification": 
"esm_nv.NVEsmForTokenClassification" - }, - "classifier_dropout": null, - "dtype": "float32", - "emb_layer_norm_before": false, - "encoder_activation": "gelu", - "esmfold_config": null, - "fuse_qkv_params": true, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.0, - "hidden_size": 320, - "initializer_range": 0.02, - "intermediate_size": 1280, - "is_decoder": false, - "is_folding_model": false, - "layer_norm_eps": 1e-05, - "mask_token_id": 32, - "max_position_embeddings": 1026, - "max_seq_length": null, - "micro_batch_size": null, - "model_type": "nv_esm", - "num_attention_heads": 20, - "num_hidden_layers": 6, - "pad_token_id": 1, - "padded_vocab_size": 64, - "position_embedding_type": "rotary", - "qkv_weight_interleaved": true, - "tie_word_embeddings": true, - "token_dropout": true, - "transformers_version": "5.0.0", - "use_cache": true, - "vocab_list": null, - "vocab_size": 33 -} diff --git a/bionemo-recipes/vllm/esm2_vllm_converted/esm_nv.py b/bionemo-recipes/vllm/esm2_vllm_converted/esm_nv.py deleted file mode 100644 index 00fdf2312..000000000 --- a/bionemo-recipes/vllm/esm2_vllm_converted/esm_nv.py +++ /dev/null @@ -1,681 +0,0 @@ -# noqa: license-check -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. -# Copyright 2025 NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -"""TransformerEngine-optimized ESM model. - -Adapted from `modeling_esm.py` in huggingface/transformers. -""" - -from typing import ClassVar, Literal, Optional, Unpack - -# TODO: put import guard around transformer_engine here, with an informative error message around -# installation and the nvidia docker container. -import torch -import transformer_engine.pytorch -from torch import nn -from torch.nn import CrossEntropyLoss -from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding -from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPooling, - MaskedLMOutput, - TokenClassifierOutput, -) -from transformers.models.esm.configuration_esm import EsmConfig -from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel -from transformers.utils import logging -from transformers.utils.generic import TransformersKwargs - - -logger = logging.get_logger(__name__) - -# Dictionary that gets inserted into config.json to map Auto** classes to our TE-optimized model classes defined below. -# These should be prefixed with esm_nv., since we name the file esm_nv.py in our exported checkpoints. 
-AUTO_MAP = { - "AutoConfig": "esm_nv.NVEsmConfig", - "AutoModel": "esm_nv.NVEsmModel", - "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM", - "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", -} - - -class NVEsmConfig(EsmConfig): - """NVEsmConfig is a configuration for the NVEsm model.""" - - model_type: str = "nv_esm" - - def __init__( - self, - qkv_weight_interleaved: bool = True, - encoder_activation: str = "gelu", - attn_input_format: Literal["bshd", "thd"] = "bshd", - fuse_qkv_params: bool = True, - micro_batch_size: Optional[int] = None, - max_seq_length: Optional[int] = None, - padded_vocab_size: Optional[int] = 64, - attn_mask_type: str = "padding", - **kwargs, - ): - """Initialize the NVEsmConfig with additional TE-related config options. - - Args: - qkv_weight_interleaved: Whether to interleave the qkv weights. If set to `False`, the - QKV weight is interpreted as a concatenation of query, key, and value weights along - the `0th` dimension. The default interpretation is that the individual `q`, `k`, and - `v` weights for each attention head are interleaved. This parameter is set to `False` - when using :attr:`fuse_qkv_params=False`. - encoder_activation: The activation function to use in the encoder. - attn_input_format: The input format to use for the attention. This controls - whether the dimensions of the intermediate hidden states is 'batch first' - ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length, - `b` batch size, `h` the number of heads, `d` head size. Note that these - formats are very closely related to the `qkv_format` in the - `MultiHeadAttention` and `DotProductAttention` modules. - fuse_qkv_params: Whether to fuse the qkv parameters. If set to `True`, - `TransformerLayer` module exposes a single fused parameter for query-key-value. - This enables optimizations such as QKV fusion without concatentations/splits and - also enables the argument `fuse_wgrad_accumulation`. 
- micro_batch_size: The micro batch size to use for the attention. This is needed for - JIT Warmup, a technique where jit fused functions are warmed up before training to - ensure same kernels are used for forward propogation and activation recompute phase. - max_seq_length: The maximum sequence length to use for the attention. This is needed for - JIT Warmup, a technique where jit fused functions are warmed up before training to - ensure same kernels are used for forward propogation and activation recompute phase. - padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults - to vocab_size. Must be greater than or equal to vocab_size. - attn_mask_type: The type of attention mask to use. - **kwargs: Additional config options to pass to EsmConfig. - """ - super().__init__(**kwargs) - # Additional TE-related config options. - self.qkv_weight_interleaved = qkv_weight_interleaved - self.encoder_activation = encoder_activation - self.attn_input_format = attn_input_format - self.fuse_qkv_params = fuse_qkv_params - self.micro_batch_size = micro_batch_size - self.max_seq_length = max_seq_length - self.attn_mask_type = attn_mask_type - - # Set padded_vocab_size with default fallback to vocab_size - self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size - - # Ensure padded_vocab_size is at least as large as vocab_size - if self.padded_vocab_size is not None and self.vocab_size is not None: - assert self.padded_vocab_size >= self.vocab_size, ( - f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to vocab_size ({self.vocab_size})" - ) - - -class NVEsmEncoder(nn.Module): - """NVEsmEncoder is a TransformerEngine-optimized ESM encoder.""" - - def __init__(self, config: NVEsmConfig): - """Initialize a NVEsmEncoder. - - Args: - config (NVEsmConfig): The configuration of the model. 
- """ - super().__init__() - self.config = config - - def _init_method(x): - torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) - - self.layers = nn.ModuleList( - [ - transformer_engine.pytorch.TransformerLayer( - hidden_size=config.hidden_size, - ffn_hidden_size=config.intermediate_size, - num_attention_heads=config.num_attention_heads, - layernorm_epsilon=config.layer_norm_eps, - hidden_dropout=config.hidden_dropout_prob, - attention_dropout=config.attention_probs_dropout_prob, - qkv_weight_interleaved=config.qkv_weight_interleaved, - layer_number=i + 1, - layer_type="encoder", - self_attn_mask_type=config.attn_mask_type, - activation=config.encoder_activation, - attn_input_format=config.attn_input_format, - seq_length=config.max_seq_length, - micro_batch_size=config.micro_batch_size, - num_gqa_groups=config.num_attention_heads, - fuse_qkv_params=config.fuse_qkv_params, - params_dtype=config.dtype, - window_size=(-1, -1), - device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", - init_method=_init_method, - output_layer_init_method=_init_method, - ) - for i in range(config.num_hidden_layers) - ] - ) - self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm( - config.hidden_size, - eps=config.layer_norm_eps, - params_dtype=config.dtype, - device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", - ) - if config.position_embedding_type == "rotary": - self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - **kwargs: Unpack[TransformersKwargs], - ): - """Forward pass of the NVEsmEncoder. - - Args: - hidden_states (torch.Tensor): The hidden states. - attention_mask (torch.Tensor): The attention mask. - **kwargs: Additional arguments, see TransformersKwargs for more details. - """ - all_hidden_states: tuple[torch.Tensor, ...] 
= () - - if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1: - # For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE - # expects a 2-dimensional tensor with shape [total_tokens, hidden_size]. - hidden_states = hidden_states.squeeze(0) - - # Ensure that rotary embeddings are computed with at a higher precision outside the torch autocast context. - with torch.autocast(device_type="cuda", enabled=False): - te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) - te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) - - for layer_module in self.layers: - if kwargs.get("output_hidden_states", False): - all_hidden_states = (*all_hidden_states, hidden_states) - - hidden_states = layer_module( - hidden_states, - attention_mask, - rotary_pos_emb=te_rope_emb, - cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), - cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), - cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), - cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), - max_seqlen_q=kwargs.get("max_length_q", None), - max_seqlen_kv=kwargs.get("max_length_k", None), - pad_between_seqs=kwargs.get("pad_between_seqs", None), - ) - - hidden_states = self.emb_layer_norm_after(hidden_states) - - if kwargs.get("output_hidden_states", False): - all_hidden_states = (*all_hidden_states, hidden_states) - - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states if all_hidden_states else None, - ) - - -class NVEsmPreTrainedModel(EsmPreTrainedModel): - """An abstract class to handle weights initialization and pretrained model loading.""" - - config_class = NVEsmConfig - base_model_prefix = "esm" - supports_gradient_checkpointing = False - accepts_loss_kwargs = False - _no_split_modules = ( - "TransformerLayer", - "EsmEmbeddings", - ) - - def init_empty_weights(self): - """Handles moving 
the model from the meta device to the cuda device and initializing the weights.""" - # For TE layers, calling `reset_parameters` is sufficient to move them to the cuda device and apply the weight - # initialization we passed them during module creation. - for module in self.modules(): - if hasattr(module, "reset_parameters"): - module.reset_parameters() - - # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use - # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard - # deviation. - self.esm.embeddings.word_embeddings.to_empty(device="cuda") - self.esm.embeddings.apply(self._init_weights) - - # Meta-device init seems to break weight tying, so we re-tie the weights here. - self.tie_weights() - - def _init_weights(self, module): - """Initialize module weights. - - We only use this method for standard pytorch modules, TE modules handle their own weight initialization through - `init_method` parameters and the `reset_parameters` method. - """ - if module.__module__.startswith("transformer_engine.pytorch"): - # Notably, we need to avoid calling the parent method for TE modules, since the default _init_weights will - # assume any class with `LayerNorm` in the name should have weights initialized to 1.0; breaking - # `LayerNormLinear` and `LayerNormMLP` modules that use `weight` for the linear layer and - # `layer_norm_weight` for the layer norm. Instead, we call `reset_parameters` if the module has it and the - # weights are not in fp8. We still need to figure out why this raises an error if we're using - # `quantized_model_init`. - if hasattr(module, "reset_parameters") and not getattr(module, "primary_weights_in_fp8", False): - module.reset_parameters() - return - - super()._init_weights(module) - - def state_dict(self, *args, **kwargs): - """Override state_dict to filter out TransformerEngine's _extra_state keys. 
- - TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. - These are filtered out to ensure checkpoints can be loaded with from_pretrained(). - """ - state_dict = super().state_dict(*args, **kwargs) - # Filter out _extra_state keys which are TransformerEngine-specific and not loadable - return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} - - -class NVEsmModel(NVEsmPreTrainedModel): - """The ESM Encoder-only protein language model. - - This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. - """ - - def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): - """Initialize a NVEsmModel. - - Args: - config (NVEsmConfig): The configuration of the model. - add_pooling_layer (bool): Whether to add a pooling layer. - """ - super().__init__(config) - self.config = config - - # Ensure pad_token_id is set properly, defaulting to 0 if not specified - if not hasattr(config, "pad_token_id") or config.pad_token_id is None: - config.pad_token_id = 0 - self.embeddings = NVEsmEmbeddings(config) - self.encoder = NVEsmEncoder(config) - self.pooler = EsmPooler(config) if add_pooling_layer else None - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - """Get the input embeddings of the model.""" - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value: torch.Tensor): - """Set the input embeddings of the model. - - Args: - value (torch.Tensor): The input embeddings. - """ - self.embeddings.word_embeddings = value - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> BaseModelOutputWithPooling: - """Forward pass of the NVEsmModel. 
- - Args: - input_ids (torch.Tensor): The input ids. - attention_mask (torch.Tensor): The attention mask. - position_ids (torch.Tensor): The position ids. - inputs_embeds (torch.Tensor): The input embeddings. - **kwargs: Additional arguments, see TransformersKwargs for more details. - - Returns: - BaseModelOutputWithPooling: The output of the model. - """ - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) - input_shape = input_ids.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - batch_size, seq_length = input_shape - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length)), device=device) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) - - # TE expects a boolean attention mask, where 1s are masked and 0s are not masked - extended_attention_mask = extended_attention_mask < -1 - - embedding_output = self.embeddings( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - **kwargs, - ) - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - **kwargs, - ) - sequence_output = encoder_outputs[0] - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - - return BaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - ) - - -class NVEsmForMaskedLM(NVEsmPreTrainedModel): - """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" - - _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} - - def __init__(self, config: NVEsmConfig): - """Initialize a NVEsmForMaskedLM. - - Args: - config (NVEsmConfig): The configuration of the model. - """ - super().__init__(config) - - if config.is_decoder: - logger.warning( - "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." 
- ) - - self.esm = NVEsmModel(config, add_pooling_layer=False) - self.lm_head = NVEsmLMHead(config) - - self.post_init() - - def get_output_embeddings(self): - """Get the output embeddings of the model.""" - return self.lm_head.decoder - - def set_output_embeddings(self, new_embeddings): - """Set the output embeddings of the model.""" - self.lm_head.decoder = new_embeddings - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> MaskedLMOutput: - """Forward pass of the NVEsmForMaskedLM. - - Args: - input_ids (torch.LongTensor): The input ids. - attention_mask (torch.Tensor): The attention mask. - position_ids (torch.LongTensor): The position ids. - inputs_embeds (torch.FloatTensor): The input embeddings. - labels (torch.LongTensor): The labels. - **kwargs: Additional arguments, see TransformersKwargs for more details. - - Returns: - MaskedLMOutput: The output of the model. 
- """ - outputs = self.esm( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - **kwargs, - ) - sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) - - # Truncate logits back to original vocab_size if padding was used - if self.config.padded_vocab_size != self.config.vocab_size: - prediction_scores = prediction_scores[..., : self.config.vocab_size] - - masked_lm_loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - masked_lm_loss = loss_fct( - prediction_scores.view(-1, self.config.vocab_size), - labels.to(prediction_scores.device).view(-1), - ) - - return MaskedLMOutput( - loss=masked_lm_loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - ) - - -class NVEsmLMHead(nn.Module): - """ESM Head for masked language modeling using TransformerEngine.""" - - def __init__(self, config: NVEsmConfig): - """Initialize a NVEsmLMHead. - - Args: - config (NVEsmConfig): The configuration of the model. - """ - super().__init__() - self.dense = transformer_engine.pytorch.Linear( - config.hidden_size, - config.hidden_size, - params_dtype=config.dtype, - device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", - init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), - ) - - with transformer_engine.pytorch.fp8_model_init(enabled=False): - self.decoder = transformer_engine.pytorch.LayerNormLinear( - config.hidden_size, - config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, - bias=True, - eps=config.layer_norm_eps, - params_dtype=config.dtype, - device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", - init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), - ) - - def forward(self, features, **kwargs): - """Forward pass of the NVEsmLMHead. - - Args: - features (torch.Tensor): The features. 
- **kwargs: Additional arguments. - """ - # Keep the last layers of the network in higher precision to avoid numerical instability. - # Please see recipes/fp8_analysis/README.md for more details. - with transformer_engine.pytorch.fp8_autocast(enabled=False): - x = self.dense(features) - x = torch.nn.functional.gelu(x) - x = self.decoder(x) - return x - - -class NVEsmEmbeddings(nn.Module): - """Modified version of EsmEmbeddings to support THD inputs.""" - - def __init__(self, config): - """Initialize a NVEsmEmbeddings.""" - super().__init__() - self.word_embeddings = nn.Embedding( - config.padded_vocab_size, - config.hidden_size, - padding_idx=config.pad_token_id, - dtype=config.dtype, - ) - - self.layer_norm = ( - transformer_engine.pytorch.LayerNorm( - config.hidden_size, - eps=config.layer_norm_eps, - params_dtype=config.dtype, - device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", - ) - if config.emb_layer_norm_before - else None - ) - - if config.position_embedding_type != "rotary": - raise ValueError( - "The TE-accelerated ESM-2 model only supports rotary position embeddings, received " - f"{config.position_embedding_type}" - ) - - self.padding_idx = config.pad_token_id - self.token_dropout = config.token_dropout - self.mask_token_id = config.mask_token_id - - def forward( - self, - input_ids=None, - attention_mask=None, - inputs_embeds=None, - **kwargs: Unpack[TransformersKwargs], - ): - """Forward pass of the NVEsmEmbeddings.""" - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an - # embedding_scale factor here. 
- embeddings = inputs_embeds - - if ( - kwargs.get("cu_seq_lens_q") is not None - and kwargs.get("cu_seq_lens_k") is not None - and kwargs.get("max_length_q") is not None - and kwargs.get("max_length_k") is not None - ): - using_thd = True - attention_mask = None - else: - using_thd = False - - # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout - # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however, - # masked tokens are treated as if they were selected for input dropout and zeroed out. - # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by - # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). - # This is analogous to the way that dropout layers scale down outputs during evaluation when not - # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). - if self.token_dropout and input_ids is not None: - embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0) - mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs - - if not using_thd: - # BSHD token dropout correction - src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1] - n_masked_per_seq = (input_ids == self.mask_token_id).sum(-1).float() - mask_ratio_observed = n_masked_per_seq / src_lengths - scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) - embeddings = (embeddings * scale_factor[:, None, None]).to(embeddings.dtype) - - else: - src_lengths = torch.diff(kwargs["cu_seq_lens_q"]) - # We need to find the number of masked tokens in each sequence in the padded batch. 
- is_masked = (input_ids == self.mask_token_id).squeeze(0) - n_masked_per_seq = torch.nested.nested_tensor_from_jagged( - is_masked, offsets=kwargs["cu_seq_lens_q"] - ).sum(1) - mask_ratio_observed = n_masked_per_seq.float() / src_lengths - scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) - reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) - embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) - - if self.layer_norm is not None: - embeddings = self.layer_norm(embeddings) - - if attention_mask is not None: - embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) - - return embeddings - - -class NVEsmForTokenClassification(NVEsmPreTrainedModel): - """Adds a token classification head to the model. - - Adapted from EsmForTokenClassification in Hugging Face Transformers `modeling_esm.py`. - """ - - def __init__(self, config): - """Initialize NVEsmForTokenClassification.""" - super().__init__(config) - self.num_labels = config.num_labels - - self.esm = NVEsmModel(config, add_pooling_layer=False) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = transformer_engine.pytorch.Linear( - config.hidden_size, - config.num_labels, - params_dtype=config.dtype, - device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", - init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), - ) - - self.post_init() - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> TokenClassifierOutput: - """Forward pass for the token classification head. 
- - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. - """ - outputs = self.esm( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - inputs_embeds=inputs_embeds, - **kwargs, - ) - - sequence_output = outputs[0] - - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - - loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() - - labels = labels.to(logits.device) - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/bionemo-recipes/vllm/esm2_vllm_converted/special_tokens_map.json b/bionemo-recipes/vllm/esm2_vllm_converted/special_tokens_map.json deleted file mode 100644 index 0d2693ece..000000000 --- a/bionemo-recipes/vllm/esm2_vllm_converted/special_tokens_map.json +++ /dev/null @@ -1,44 +0,0 @@ -{ - "bos_token": { - "content": "", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - }, - "cls_token": { - "content": "", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - }, - "eos_token": { - "content": "", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - }, - "mask_token": { - "content": "", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - }, - "pad_token": { - "content": "", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - }, - "unk_token": { - "content": "", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - } -} diff --git a/bionemo-recipes/vllm/esm2_vllm_converted/tokenizer.json b/bionemo-recipes/vllm/esm2_vllm_converted/tokenizer.json deleted file mode 100644 index 
708fe8bbd..000000000 --- a/bionemo-recipes/vllm/esm2_vllm_converted/tokenizer.json +++ /dev/null @@ -1,176 +0,0 @@ -{ - "version": "1.0", - "truncation": null, - "padding": null, - "added_tokens": [ - { - "id": 0, - "content": "", - "single_word": false, - "lstrip": false, - "rstrip": false, - "normalized": false, - "special": true - }, - { - "id": 1, - "content": "", - "single_word": false, - "lstrip": false, - "rstrip": false, - "normalized": false, - "special": true - }, - { - "id": 2, - "content": "", - "single_word": false, - "lstrip": false, - "rstrip": false, - "normalized": false, - "special": true - }, - { - "id": 3, - "content": "", - "single_word": false, - "lstrip": false, - "rstrip": false, - "normalized": false, - "special": true - }, - { - "id": 32, - "content": "", - "single_word": false, - "lstrip": false, - "rstrip": false, - "normalized": false, - "special": true - } - ], - "normalizer": null, - "pre_tokenizer": { - "type": "Split", - "pattern": { - "String": "" - }, - "behavior": "Isolated", - "invert": false - }, - "post_processor": { - "type": "TemplateProcessing", - "single": [ - { - "SpecialToken": { - "id": "", - "type_id": 0 - } - }, - { - "Sequence": { - "id": "A", - "type_id": 0 - } - }, - { - "SpecialToken": { - "id": "", - "type_id": 0 - } - } - ], - "pair": [ - { - "SpecialToken": { - "id": "", - "type_id": 0 - } - }, - { - "Sequence": { - "id": "A", - "type_id": 0 - } - }, - { - "SpecialToken": { - "id": "", - "type_id": 0 - } - }, - { - "Sequence": { - "id": "B", - "type_id": 1 - } - }, - { - "SpecialToken": { - "id": "", - "type_id": 1 - } - } - ], - "special_tokens": { - "": { - "id": "", - "ids": [ - 0 - ], - "tokens": [ - "" - ] - }, - "": { - "id": "", - "ids": [ - 2 - ], - "tokens": [ - "" - ] - } - } - }, - "decoder": null, - "model": { - "type": "WordLevel", - "vocab": { - "": 0, - "": 1, - "": 2, - "": 3, - "L": 4, - "A": 5, - "G": 6, - "V": 7, - "S": 8, - "E": 9, - "R": 10, - "T": 11, - "I": 12, - "D": 13, - "P": 14, - 
"K": 15, - "Q": 16, - "N": 17, - "F": 18, - "Y": 19, - "M": 20, - "H": 21, - "W": 22, - "C": 23, - "X": 24, - "B": 25, - "U": 26, - "Z": 27, - "O": 28, - ".": 29, - "-": 30, - "": 31, - "": 32 - }, - "unk_token": "" - } -} diff --git a/bionemo-recipes/vllm/esm2_vllm_converted/tokenizer_config.json b/bionemo-recipes/vllm/esm2_vllm_converted/tokenizer_config.json deleted file mode 100644 index a8223b53e..000000000 --- a/bionemo-recipes/vllm/esm2_vllm_converted/tokenizer_config.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "backend": "tokenizers", - "bos_token": "", - "clean_up_tokenization_spaces": false, - "cls_token": "", - "eos_token": "", - "is_local": true, - "mask_token": "", - "model_input_names": [ - "input_ids", - "attention_mask" - ], - "model_max_length": 1000000000000000019884624838656, - "model_specific_special_tokens": {}, - "pad_token": "", - "tokenizer_class": "TokenizersBackend", - "unk_token": "" -} diff --git a/bionemo-recipes/vllm/esm2_vllm_converted/vllm_conversion_info.json b/bionemo-recipes/vllm/esm2_vllm_converted/vllm_conversion_info.json deleted file mode 100644 index 87114c5a6..000000000 --- a/bionemo-recipes/vllm/esm2_vllm_converted/vllm_conversion_info.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "original_model": "nvidia/esm2_t6_8M_UR50D", - "conversion": "strip_esm_prefix", - "description": "Stripped 'esm.' prefix from weight keys for vLLM compatibility (lm_head.* dropped)" -} diff --git a/bionemo-recipes/vllm/esm2_vllm_converted/vocab.txt b/bionemo-recipes/vllm/esm2_vllm_converted/vocab.txt deleted file mode 100644 index 741f755a4..000000000 --- a/bionemo-recipes/vllm/esm2_vllm_converted/vocab.txt +++ /dev/null @@ -1,33 +0,0 @@ - - - - -L -A -G -V -S -E -R -T -I -D -P -K -Q -N -F -Y -M -H -W -C -X -B -U -Z -O -. 
-- - - From 99e808750db83fa2a455370acf0c9630886073bc Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Mon, 23 Feb 2026 05:44:31 +0000 Subject: [PATCH 3/9] testing --- bionemo-recipes/vllm/vllm_test.py | 1 + .../vllm/vllm_test_without_convert.py | 60 +++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 bionemo-recipes/vllm/vllm_test_without_convert.py diff --git a/bionemo-recipes/vllm/vllm_test.py b/bionemo-recipes/vllm/vllm_test.py index 1bf35b790..523677a4c 100644 --- a/bionemo-recipes/vllm/vllm_test.py +++ b/bionemo-recipes/vllm/vllm_test.py @@ -55,6 +55,7 @@ model=MODEL_ID, runner="pooling", trust_remote_code=True, + dtype="float32", # TransformerEngine layers use pydantic (ArgsKwargs) which torch.compile # cannot trace. Use eager mode to avoid the dynamo error. enforce_eager=True, diff --git a/bionemo-recipes/vllm/vllm_test_without_convert.py b/bionemo-recipes/vllm/vllm_test_without_convert.py new file mode 100644 index 000000000..74663738b --- /dev/null +++ b/bionemo-recipes/vllm/vllm_test_without_convert.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from transformers import AutoModelForMaskedLM +from vllm import LLM + + +models = {"nvidia": "nvidia/esm2_t6_8M_UR50D", "facebook": "facebook/esm2_t6_8M_UR50D"} + +MODEL_SWITCH = "facebook" + + +def convert_to_hf(model): + """Convert the given model to a HuggingFace AutoModelForMaskedLM.""" + model_hf = AutoModelForMaskedLM.from_pretrained(models[MODEL_SWITCH]) + return model_hf + + +if __name__ == "__main__": + num_gpus = torch.cuda.device_count() + print(f"Detected {num_gpus} GPU(s)") + + model = LLM( + model=models[MODEL_SWITCH], + runner="pooling", + trust_remote_code=True, + dtype="float32", + # TransformerEngine layers use pydantic (ArgsKwargs) which torch.compile + # cannot trace. Use eager mode to avoid the dynamo error. + enforce_eager=True, + # vLLM's profiling run packs all tokens into a single batch-1 sequence. + # Cap batched tokens to max_position_embeddings (1026) so the rotary + # embeddings don't run out of positions. + max_num_batched_tokens=1026, + ) + + prompts = [ + "LKGHAMCLGCLHMLMCGLLAGAMCGLMKLLKCCGKCLMHLMKAMLGLKCACHHHHLLLHACAAKKLCLGAKLAMGLKLLGAHGKGLKMACGHHMLHLHMH", + "CLLCCMHMHAHHCHGHGHKCKCLMMGMALMCAGCCACGMKGGCHCCLLAHCAHAKAGKGKCKLMCKKKHGLHAGLHAMLLCHLGLGCGHHHKKCKKHKCA", + ] + + outputs = model.embed(prompts) + breakpoint() + + del model + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() From af0c43c467c6396e774fe81fb836d9eed46dcdbd Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Mon, 23 Feb 2026 16:46:07 +0000 Subject: [PATCH 4/9] cleanup --- .../vllm/test_vllm_golden_values.py | 210 ++++++++++++++++++ bionemo-recipes/vllm/vllm_test.py | 194 ---------------- .../vllm/vllm_test_without_convert.py | 60 ----- 3 files changed, 210 insertions(+), 254 deletions(-) create mode 100644 bionemo-recipes/vllm/test_vllm_golden_values.py delete mode 100644 bionemo-recipes/vllm/vllm_test.py delete mode 100644 bionemo-recipes/vllm/vllm_test_without_convert.py diff --git 
a/bionemo-recipes/vllm/test_vllm_golden_values.py b/bionemo-recipes/vllm/test_vllm_golden_values.py new file mode 100644 index 000000000..fde9347bf --- /dev/null +++ b/bionemo-recipes/vllm/test_vllm_golden_values.py @@ -0,0 +1,210 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end golden-value test for ESM2 vLLM compatibility. + +Performs a fresh facebook -> TE export, then cross-validates embeddings across +three backends on the same protein sequences: + +1. **vLLM** - freshly exported model loaded via ``LLM(runner="pooling")``. +2. **HF (exported)** - same exported checkpoint loaded via ``AutoModel``. +3. **HF (reference)**- nvidia Hub model loaded via ``AutoModel`` (ground truth). + +vLLM's pooling runner returns *last-token, L2-normalised* embeddings by default, +so the HuggingFace runs replicate that post-processing for an apples-to-apples comparison. +""" + +import os +import sys +from pathlib import Path + +import numpy as np +import torch +from transformers import AutoModel, AutoTokenizer +from vllm import LLM + + +# ---- Fresh export ---- +# The export script uses relative paths (modeling_esm_te.py, esm_fast_tokenizer, etc.) +# so we need to run it from the esm2 model directory. 
+ESM2_MODEL_DIR = Path(__file__).resolve().parent.parent / "models" / "esm2" +EXPORT_DIR = Path(__file__).resolve().parent / "exported_checkpoint" +EXPORT_TAG = "esm2_t6_8M_UR50D" + +sys.path.insert(0, str(ESM2_MODEL_DIR)) + + +def fresh_export() -> str: + """Run the full facebook -> TE export and return the path to the exported checkpoint.""" + from export import export_hf_checkpoint + + # export_hf_checkpoint uses relative paths, so temporarily chdir + original_cwd = os.getcwd() + os.chdir(ESM2_MODEL_DIR) + try: + EXPORT_DIR.mkdir(parents=True, exist_ok=True) + print(f"Exporting facebook/{EXPORT_TAG} -> {EXPORT_DIR / EXPORT_TAG}") + export_hf_checkpoint(EXPORT_TAG, EXPORT_DIR) + finally: + os.chdir(original_cwd) + + return str(EXPORT_DIR / EXPORT_TAG) + + +# ---- Configuration ---- +REFERENCE_MODEL_ID = "nvidia/esm2_t6_8M_UR50D" + +SEQUENCES = [ + "LKGHAMCLGCLHMLMCGLLAGAMCGLMKLLKCCGKCLMHLMKAMLGLKCACHHHHLLLHACAAKKLCLGAKLAMGLKLLGAHGKGLKMACGHHMLHLHMH", + "CLLCCMHMHAHHCHGHGHKCKCLMMGMALMCAGCCACGMKGGCHCCLLAHCAHAKAGKGKCKLMCKKKHGLHAGLHAMLLCHLGLGCGHHHKKCKKHKCA", +] + +RTOL, ATOL = 0, 0 + + +# ---- Helpers ---- + + +def last_token_l2(hidden_state: torch.Tensor) -> np.ndarray: + """Extract last-token hidden state and L2-normalise (matches vLLM pooling defaults).""" + vec = hidden_state[0, -1, :].cpu().float().numpy() + norm = np.linalg.norm(vec) + if norm > 1e-9: + vec = vec / norm + return vec + + +def hf_embed(model_id: str, sequences: list[str], dtype=torch.float32) -> np.ndarray: + """Run HuggingFace inference and return last-token L2-normalised embeddings.""" + model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to("cuda", dtype=dtype).eval() + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + + vecs = [] + with torch.no_grad(): + for seq in sequences: + inputs = tokenizer(seq, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + out = model(**inputs) + 
vecs.append(last_token_l2(out.last_hidden_state)) + + del model, tokenizer + torch.cuda.empty_cache() + return np.stack(vecs) + + +def vllm_embed(model_id: str, sequences: list[str]) -> np.ndarray: + """Run vLLM pooling inference and return embeddings.""" + engine = LLM( + model=model_id, + runner="pooling", + trust_remote_code=True, + dtype="float32", + enforce_eager=True, + max_num_batched_tokens=1026, + ) + outputs = engine.embed(sequences) + + vecs = [] + for output in outputs: + emb = output.outputs.embedding + if isinstance(emb, list): + emb = np.array(emb) + vecs.append(emb) + + del engine + return np.stack(vecs) + + +def max_abs_diff(a: np.ndarray, b: np.ndarray) -> float: + """Element-wise maximum absolute difference between two arrays.""" + return float(np.abs(a.astype(np.float64) - b.astype(np.float64)).max()) + + +def cosine_sim(a: np.ndarray, b: np.ndarray) -> float: + """Mean cosine similarity across rows.""" + sims = [] + for va, vb in zip(a, b): + dot = np.dot(va, vb) + na, nb = np.linalg.norm(va), np.linalg.norm(vb) + sims.append(dot / max(na * nb, 1e-12)) + return float(np.mean(sims)) + + +# ---- Main ---- + +if __name__ == "__main__": + print(f"GPUs: {torch.cuda.device_count()}") + + # Step 0: fresh export (facebook HF -> our TE format) + print("\n[0/3] Exporting checkpoint ...") + MODEL_ID = fresh_export() + + print(f"MODEL_ID: {MODEL_ID}") + print(f"REFERENCE_MODEL_ID: {REFERENCE_MODEL_ID}") + print(f"Sequences: {len(SEQUENCES)}") + + # 1) vLLM on exported model + print("\n[1/3] vLLM inference (exported model) ...") + emb_vllm = vllm_embed(MODEL_ID, SEQUENCES) + + # 2) HuggingFace on exported model + print("\n[2/3] HuggingFace inference (exported model) ...") + emb_hf_exported = hf_embed(MODEL_ID, SEQUENCES) + + # 3) HuggingFace on reference Hub model + print("\n[3/3] HuggingFace inference (reference model) ...") + emb_hf_reference = hf_embed(REFERENCE_MODEL_ID, SEQUENCES) + + # ---- Pairwise comparisons ---- + pairs = [ + ("vLLM (exported)", 
"HF (exported)", emb_vllm, emb_hf_exported), + ("vLLM (exported)", "HF (reference)", emb_vllm, emb_hf_reference), + ("HF (exported)", "HF (reference)", emb_hf_exported, emb_hf_reference), + ] + + # ---- Summary table ---- + header = f"{'Pair':<35} {'max |diff|':>14} {'mean |diff|':>14} {'cos sim':>12} {'exact':>7}" + sep = "-" * len(header) + print(f"\n{sep}") + print(header) + print(sep) + + for name_a, name_b, a, b in pairs: + diffs = np.abs(a.astype(np.float64) - b.astype(np.float64)) + label = f"{name_a} vs {name_b}" + exact = np.array_equal(a, b) + print( + f"{label:<35} {diffs.max():>14.8e} {diffs.mean():>14.8e} " + f"{cosine_sim(a, b):>12.10f} {'YES' if exact else 'NO':>7}" + ) + + print(sep) + print(f"Tolerance: rtol={RTOL}, atol={ATOL} (0 = exact match required)") + + # Per-sequence breakdown + short = {"vLLM (exported)": "vllm", "HF (exported)": "hf_exp", "HF (reference)": "hf_ref"} + print("\nPer-sequence max |diff|:") + for i in range(len(SEQUENCES)): + row = f" seq {i}:" + for name_a, name_b, a, b in pairs: + d = float(np.abs(a[i].astype(np.float64) - b[i].astype(np.float64)).max()) + row += f" {short[name_a]}_vs_{short[name_b]}={d:.8e}" + print(row) + + print(sep) + + # Cleanup + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() diff --git a/bionemo-recipes/vllm/vllm_test.py b/bionemo-recipes/vllm/vllm_test.py deleted file mode 100644 index 523677a4c..000000000 --- a/bionemo-recipes/vllm/vllm_test.py +++ /dev/null @@ -1,194 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""vLLM ESM2 embedding example -- loads directly from HuggingFace. - -The NVEsm model on HuggingFace follows the standard HuggingFace convention where -wrapper classes (NVEsmForMaskedLM, NVEsmForTokenClassification) store the base -model as ``self.model``, with ``base_model_prefix = "model"``. - -This means checkpoint weight keys align with vLLM's TransformersForEmbedding -wrapper out of the box: - -- Checkpoint bare keys: ``embeddings.*``, ``encoder.*`` -- vLLM mapper adds ``model.``: ``model.embeddings.*``, ``model.encoder.*`` -- Wrapper module tree: ``model.embeddings.*``, ``model.encoder.*`` - -No conversion scripts or weight renaming needed. -""" - -import numpy as np -import torch -from transformers import AutoModel, AutoTokenizer -from vllm import LLM - - -# MODEL_ID = "nvidia/esm2_t6_8M_UR50D" -# To test with a local re-exported checkpoint before pushing to HuggingFace, use a path: -MODEL_ID = "/workspace/bionemo-framework/bionemo-recipes/models/esm2/exported/esm2_t6_8M_UR50D" # after running export from models/esm2 - -# Reference: nvidia model on HuggingFace Hub (same as MODEL_ID when using Hub) — check local/conversion against it. -REFERENCE_MODEL_ID = "nvidia/esm2_t6_8M_UR50D" - - -if __name__ == "__main__": - num_gpus = torch.cuda.device_count() - print(f"Detected {num_gpus} GPU(s)") - - # Load ESM2 directly from HuggingFace as a pooling/embedding model. - # No checkpoint conversion needed -- the model code and checkpoint are - # aligned so that vLLM's generic weight mapper works out of the box. 
- print(f"\nLoading model: {MODEL_ID}") - model = LLM( - model=MODEL_ID, - runner="pooling", - trust_remote_code=True, - dtype="float32", - # TransformerEngine layers use pydantic (ArgsKwargs) which torch.compile - # cannot trace. Use eager mode to avoid the dynamo error. - enforce_eager=True, - # vLLM's profiling run packs all tokens into a single batch-1 sequence. - # Cap batched tokens to max_position_embeddings (1026) so the rotary - # embeddings don't run out of positions. - max_num_batched_tokens=1026, - ) - - # Example protein sequences - prompts = [ - "LKGHAMCLGCLHMLMCGLLAGAMCGLMKLLKCCGKCLMHLMKAMLGLKCACHHHHLLLHACAAKKLCLGAKLAMGLKLLGAHGKGLKMACGHHMLHLHMH", - "CLLCCMHMHAHHCHGHGHKCKCLMMGMALMCAGCCACGMKGGCHCCLLAHCAHAKAGKGKCKLMCKKKHGLHAGLHAMLLCHLGLGCGHHHKKCKKHKCA", - ] - - print(f"\nGenerating embeddings for {len(prompts)} sequences...") - outputs = model.embed(prompts) - - # Collect vLLM embeddings for comparison (one vector per sequence from pooling) - vllm_embeddings = [] - for i, (prompt, output) in enumerate(zip(prompts, outputs)): - embedding = output.outputs.embedding - if isinstance(embedding, list): - embedding = np.array(embedding) - vllm_embeddings.append(embedding) - print(f"\nSequence {i + 1}:") - print(f" Length: {len(prompt)} amino acids") - print(f" Embedding shape: {embedding.shape}") - print(f" First 5 dims: {embedding[:5].tolist()}") - - vllm_embeddings = np.stack(vllm_embeddings) - - print("\nSUCCESS: ESM2 embeddings generated with vLLM!") - - # ---- Native HuggingFace inference on the same sequences ---- - # Run the same model via transformers and compare outputs to vLLM. 
- print(f"\n--- HuggingFace (native) inference on same {len(prompts)} sequences ---") - hf_model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True) - tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) - # Use float16 to match vLLM's default dtype for numerical comparison - hf_model = hf_model.to("cuda", dtype=torch.float16) - hf_model.eval() - - # Exported checkpoints use add_pooling_layer=False (no pooler weights). - # vLLM pooling runner uses seq_pooling_type='LAST' (last token). Use pooler_output when - # present, else take the last non-padding token's hidden state to match vLLM. - hf_embeddings_list = [] - with torch.no_grad(): - for seq in prompts: - inputs = tokenizer(seq, return_tensors="pt") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - out = hf_model(**inputs) - if out.pooler_output is not None: - vec = out.pooler_output.cpu().numpy().squeeze(0) - else: - # Last-token hidden state to match vLLM's seq_pooling_type='LAST' - last_hidden = out.last_hidden_state # (1, seq_len, hidden_size) - vec = last_hidden[0, -1, :].cpu().float().numpy() - # vLLM pooling runner L2-normalizes the embedding; do the same for comparison - norm = np.linalg.norm(vec) - if norm > 1e-9: - vec = vec / norm - hf_embeddings_list.append(vec) - - hf_embeddings = np.stack(hf_embeddings_list) - - # Compare vLLM vs HuggingFace (same sequences, same model). - # Use relaxed tolerance for fp16/FlashAttention vs fp32/PyTorch attention differences. - rtol, atol = 1e-2, 5e-4 - match = np.allclose(vllm_embeddings, hf_embeddings, rtol=rtol, atol=atol) - max_diff = np.abs(vllm_embeddings.astype(np.float64) - hf_embeddings.astype(np.float64)).max() - print("\nComparison (vLLM vs HuggingFace embedding):") - print(f" allclose(rtol={rtol}, atol={atol}): {match}") - print(f" max |diff|: {max_diff}") - if not match: - raise AssertionError( - "vLLM and HuggingFace outputs differ. " - f"max |diff| = {max_diff}; expected allclose(rtol={rtol}, atol={atol})." 
- ) - print(" Match: vLLM and HuggingFace results are the same.") - - del hf_model - del tokenizer - - # ---- Reference: nvidia model from HuggingFace Hub (check our conversion / local export against it) ---- - # Load nvidia/esm2_t6_8M_UR50D from Hub and run same sequences; compare to MODEL_ID (local or Hub). - print(f"\n--- Reference: HuggingFace Hub {REFERENCE_MODEL_ID} (same {len(prompts)} sequences) ---") - ref_model = AutoModel.from_pretrained(REFERENCE_MODEL_ID, trust_remote_code=True) - ref_tokenizer = AutoTokenizer.from_pretrained(REFERENCE_MODEL_ID, trust_remote_code=True) - ref_model = ref_model.to("cuda", dtype=torch.float16) - ref_model.eval() - - def _embed_from_output(out, use_pooler=True, last_token_and_l2=True): - """Get one embedding per batch item: pooler_output or last-token hidden state (L2-normalized).""" - if use_pooler and out.pooler_output is not None: - return out.pooler_output.cpu().numpy().squeeze(0) - last_hidden = out.last_hidden_state # (batch, seq_len, hidden_size) - vec = last_hidden[0, -1, :].cpu().float().numpy() - if last_token_and_l2: - norm = np.linalg.norm(vec) - if norm > 1e-9: - vec = vec / norm - return vec - - ref_embeddings_list = [] - with torch.no_grad(): - for seq in prompts: - inputs = ref_tokenizer(seq, return_tensors="pt") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - out = ref_model(**inputs) - # Use last-token + L2 to match our model (no pooler); validates conversion. - ref_embeddings_list.append(_embed_from_output(out, use_pooler=False)) - - ref_embeddings = np.stack(ref_embeddings_list) - - # Compare our model (HF path) vs reference (facebook Hub): conversion should match. 
- rtol_ref, atol_ref = 1e-2, 5e-4 - match_ref = np.allclose(hf_embeddings, ref_embeddings, rtol=rtol_ref, atol=atol_ref) - max_diff_ref = np.abs(hf_embeddings.astype(np.float64) - ref_embeddings.astype(np.float64)).max() - print(f"\nComparison (our model vs reference {REFERENCE_MODEL_ID}):") - print(f" allclose(rtol={rtol_ref}, atol={atol_ref}): {match_ref}") - print(f" max |diff|: {max_diff_ref}") - if not match_ref: - raise AssertionError( - "Our model and reference (Hub) outputs differ. " - f"max |diff| = {max_diff_ref}; expected allclose(rtol={rtol_ref}, atol={atol_ref})." - ) - print(" Match: conversion matches reference HuggingFace Hub model.") - - del ref_model - del ref_tokenizer - - # Cleanup - del model - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() diff --git a/bionemo-recipes/vllm/vllm_test_without_convert.py b/bionemo-recipes/vllm/vllm_test_without_convert.py deleted file mode 100644 index 74663738b..000000000 --- a/bionemo-recipes/vllm/vllm_test_without_convert.py +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import torch -from transformers import AutoModelForMaskedLM -from vllm import LLM - - -models = {"nvidia": "nvidia/esm2_t6_8M_UR50D", "facebook": "facebook/esm2_t6_8M_UR50D"} - -MODEL_SWITCH = "facebook" - - -def convert_to_hf(model): - """Convert the given model to a HuggingFace AutoModelForMaskedLM.""" - model_hf = AutoModelForMaskedLM.from_pretrained(models[MODEL_SWITCH]) - return model_hf - - -if __name__ == "__main__": - num_gpus = torch.cuda.device_count() - print(f"Detected {num_gpus} GPU(s)") - - model = LLM( - model=models[MODEL_SWITCH], - runner="pooling", - trust_remote_code=True, - dtype="float32", - # TransformerEngine layers use pydantic (ArgsKwargs) which torch.compile - # cannot trace. Use eager mode to avoid the dynamo error. - enforce_eager=True, - # vLLM's profiling run packs all tokens into a single batch-1 sequence. - # Cap batched tokens to max_position_embeddings (1026) so the rotary - # embeddings don't run out of positions. - max_num_batched_tokens=1026, - ) - - prompts = [ - "LKGHAMCLGCLHMLMCGLLAGAMCGLMKLLKCCGKCLMHLMKAMLGLKCACHHHHLLLHACAAKKLCLGAKLAMGLKLLGAHGKGLKMACGHHMLHLHMH", - "CLLCCMHMHAHHCHGHGHKCKCLMMGMALMCAGCCACGMKGGCHCCLLAHCAHAKAGKGKCKLMCKKKHGLHAGLHAMLLCHLGLGCGHHHKKCKKHKCA", - ] - - outputs = model.embed(prompts) - breakpoint() - - del model - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() From e2b3fcd3350118df89c688175f99f2fda21e7247 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Mon, 23 Feb 2026 16:51:13 +0000 Subject: [PATCH 5/9] files from other branch --- bionemo-recipes/vllm/Dockerfile | 36 +++++++++++++ bionemo-recipes/vllm/README.md | 23 +++++++++ bionemo-recipes/vllm/launch.sh | 50 +++++++++++++++++++ ...n_values.py => test_esm2_golden_values.py} | 0 4 files changed, 109 insertions(+) create mode 100644 bionemo-recipes/vllm/Dockerfile create mode 100644 bionemo-recipes/vllm/README.md create mode 100644 bionemo-recipes/vllm/launch.sh rename 
bionemo-recipes/vllm/{test_vllm_golden_values.py => test_esm2_golden_values.py} (100%) diff --git a/bionemo-recipes/vllm/Dockerfile b/bionemo-recipes/vllm/Dockerfile new file mode 100644 index 000000000..f8b9067d6 --- /dev/null +++ b/bionemo-recipes/vllm/Dockerfile @@ -0,0 +1,36 @@ +# FROM nvcr.io/nvidia/vllm:26.01-py3 +FROM gitlab-master.nvidia.com:5005/dl/dgx/vllm:main-py3.43005406-devel +# using this because we need vllm >= 0.14 to work with Transformers v5. no released nvidia version with this yet. + +# The vLLM image has CUDA 13.1 runtime and nvcc, but missing dev headers (cusparse.h, nvtx, etc.) +# Install cuda-keyring to add NVIDIA's apt repo, then install the dev headers for transformer_engine +RUN apt-get update && apt-get install -y --no-install-recommends wget && \ + wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb && \ + dpkg -i cuda-keyring_1.1-1_all.deb && \ + rm cuda-keyring_1.1-1_all.deb && \ + apt-get update && apt-get install -y --no-install-recommends \ + cuda-nvtx-13-1 \ + cuda-cupti-dev-13-1 \ + cuda-nvml-dev-13-1 \ + libcusparse-dev-13-1 \ + libcusolver-dev-13-1 \ + libcufft-dev-13-1 \ + libnvjitlink-dev-13-1 \ + libnvjpeg-dev-13-1 \ + libcublasmp0-dev-cuda-13 \ + libcudnn9-cuda-13 \ + && rm -rf /var/lib/apt/lists/* + +# Install remaining dependencies +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=requirements.txt,target=/requirements.txt \ + pip install -r /requirements.txt + +# Install transformer_engine from source (force build for CUDA 13.1, not pre-built cu12 wheel) +RUN pip install --no-build-isolation transformer_engine[pytorch] + +RUN pip install transformers[torch]==5.0.0 + + +WORKDIR /workspace/bionemo +COPY . . 
diff --git a/bionemo-recipes/vllm/README.md b/bionemo-recipes/vllm/README.md new file mode 100644 index 000000000..7fcb25c22 --- /dev/null +++ b/bionemo-recipes/vllm/README.md @@ -0,0 +1,23 @@ +# vLLM inference for BioNeMo Models + +To build the image: + +```bash +docker build -t vllm . +``` + +Set `HF_TOKEN` in your environment to avoid getting rate limited. + +To launch a container: + +```bash +docker run -it --gpus all --network host --ipc=host -e HF_TOKEN --rm -v ${PWD}:/workspace/bionemo vllm /bin/bash +``` + +or use `launch.sh`. + +To test ESM2 inference using vLLM inside the container: + +```python +python test_esm2_golden_values.py +``` diff --git a/bionemo-recipes/vllm/launch.sh b/bionemo-recipes/vllm/launch.sh new file mode 100644 index 000000000..e52a3ea53 --- /dev/null +++ b/bionemo-recipes/vllm/launch.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +# Convenience script to launch the vLLM container with the correct mounts and flags. +# Usage: ./launch.sh [--mount_dir] [--headless] +# Example: ./launch.sh vllm --mount_dir --headless + +MOUNT_DIR=false +HEADLESS=false +CONTAINER="" + +# Parse arguments +for arg in "$@"; do + case $arg in + --mount_dir) + MOUNT_DIR=true + ;; + --headless) + HEADLESS=true + ;; + *) + # First non-flag argument is the container name + if [ -z "$CONTAINER" ]; then + CONTAINER="$arg" + fi + ;; + esac +done + +if [ -z "$CONTAINER" ]; then + echo "Usage: $0 [--mount_dir] [--headless]" + echo "Example: $0 vllm --mount_dir --headless" + exit 1 +fi + +# Build docker run command +if [ "$HEADLESS" = true ]; then + DOCKER_CMD="docker run -itd --gpus all --network host --ipc=host -e HF_TOKEN --rm --name vllm_dev" +else + DOCKER_CMD="docker run -it --gpus all --network host --ipc=host -e HF_TOKEN --rm --name vllm_dev" +fi + +if [ "$MOUNT_DIR" = true ]; then + # Mount the project root (two levels up from this script) + PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." 
&& pwd)" + DOCKER_CMD="$DOCKER_CMD -v ${PROJECT_ROOT}:/workspace/bionemo-framework" +fi + +DOCKER_CMD="$DOCKER_CMD $CONTAINER /bin/bash" + +exec $DOCKER_CMD diff --git a/bionemo-recipes/vllm/test_vllm_golden_values.py b/bionemo-recipes/vllm/test_esm2_golden_values.py similarity index 100% rename from bionemo-recipes/vllm/test_vllm_golden_values.py rename to bionemo-recipes/vllm/test_esm2_golden_values.py From c34c09b339c0872a07074bd4a43b78e288cc87da Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Mon, 23 Feb 2026 17:03:22 +0000 Subject: [PATCH 6/9] remove zombie files --- .../models/esm2/tests/test_convert.py | 185 ---------- bionemo-recipes/models/esm2/tests/test_fp8.py | 192 ---------- .../esm2/tests/test_meta_device_init.py | 294 ---------------- bionemo-recipes/models/esm2/tests/test_thd.py | 329 ------------------ 4 files changed, 1000 deletions(-) delete mode 100644 bionemo-recipes/models/esm2/tests/test_convert.py delete mode 100644 bionemo-recipes/models/esm2/tests/test_fp8.py delete mode 100644 bionemo-recipes/models/esm2/tests/test_meta_device_init.py delete mode 100644 bionemo-recipes/models/esm2/tests/test_thd.py diff --git a/bionemo-recipes/models/esm2/tests/test_convert.py b/bionemo-recipes/models/esm2/tests/test_convert.py deleted file mode 100644 index dc3c1714a..000000000 --- a/bionemo-recipes/models/esm2/tests/test_convert.py +++ /dev/null @@ -1,185 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -from transformers import AutoModelForMaskedLM - - -def test_convert_te_to_hf_roundtrip(): - """Test that converting HF -> TE -> HF produces the same model.""" - from convert import convert_esm_hf_to_te, convert_esm_te_to_hf - - model_hf_original = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f") - - model_te = convert_esm_hf_to_te(model_hf_original) - model_hf_converted = convert_esm_te_to_hf(model_te) - - original_state_dict = model_hf_original.state_dict() - converted_state_dict = model_hf_converted.state_dict() - original_keys = {k for k in original_state_dict.keys() if "contact_head" not in k} - converted_keys = set(converted_state_dict.keys()) - assert original_keys == converted_keys - - for key in original_state_dict.keys(): - if not key.endswith("_extra_state") and not key.endswith("inv_freq") and "contact_head" not in key: - torch.testing.assert_close(original_state_dict[key], converted_state_dict[key], atol=1e-5, rtol=1e-5) - - -def test_load_from_converted_checkpoint(te_model_checkpoint): - from modeling_esm_te import NVEsmForMaskedLM - - NVEsmForMaskedLM.from_pretrained(te_model_checkpoint) - - -def test_qkv_unpacking(): - """Test that QKV unpacking works correctly.""" - from convert import convert_esm_hf_to_te, convert_esm_te_to_hf - - model_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f") - model_te = convert_esm_hf_to_te(model_hf) - model_hf_converted = convert_esm_te_to_hf(model_te) - - for i in range(model_hf.config.num_hidden_layers): - hf_query = model_hf.state_dict()[f"esm.encoder.layer.{i}.attention.self.query.weight"] - hf_key = model_hf.state_dict()[f"esm.encoder.layer.{i}.attention.self.key.weight"] - hf_value = model_hf.state_dict()[f"esm.encoder.layer.{i}.attention.self.value.weight"] - - converted_query = 
model_hf_converted.state_dict()[f"esm.encoder.layer.{i}.attention.self.query.weight"] - converted_key = model_hf_converted.state_dict()[f"esm.encoder.layer.{i}.attention.self.key.weight"] - converted_value = model_hf_converted.state_dict()[f"esm.encoder.layer.{i}.attention.self.value.weight"] - - torch.testing.assert_close(hf_query, converted_query, atol=1e-5, rtol=1e-5) - torch.testing.assert_close(hf_key, converted_key, atol=1e-5, rtol=1e-5) - torch.testing.assert_close(hf_value, converted_value, atol=1e-5, rtol=1e-5) - - -def test_config_conversion(): - """Test that config conversion works correctly.""" - from convert import convert_esm_hf_to_te, convert_esm_te_to_hf - - model_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f") - model_te = convert_esm_hf_to_te(model_hf) - model_hf_converted = convert_esm_te_to_hf(model_te) - - original_config_dict = model_hf.config.to_dict() - converted_config_dict = model_hf_converted.config.to_dict() - - for key, value in original_config_dict.items(): - assert key in converted_config_dict, f"Config field '{key}' missing in converted model" - assert converted_config_dict[key] == value, ( - f"Config field '{key}' differs: original={value}, converted={converted_config_dict[key]}" - ) - - assert model_hf_converted.config.model_type == "esm" - - te_specific_fields = [ - "qkv_weight_interleaved", - "encoder_activation", - "attn_input_format", - "fuse_qkv_params", - "micro_batch_size", - "auto_map", - ] - for field in te_specific_fields: - assert not hasattr(model_hf_converted.config, field), ( - f"TE-specific field '{field}' should not be present in converted model" - ) - - -def test_padding_unpadding_operations(): - """Test that padding and unpadding operations work correctly for embeddings and decoder weights.""" - from convert import convert_esm_hf_to_te, convert_esm_te_to_hf - - model_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f") - model_te = 
convert_esm_hf_to_te(model_hf) - model_hf_converted = convert_esm_te_to_hf(model_te) - - # Test word embeddings - original_embeddings = model_hf.state_dict()["esm.embeddings.word_embeddings.weight"] - converted_embeddings = model_hf_converted.state_dict()["esm.embeddings.word_embeddings.weight"] - assert original_embeddings.shape == converted_embeddings.shape, ( - f"Embedding shapes don't match: {original_embeddings.shape} vs {converted_embeddings.shape}" - ) - torch.testing.assert_close(original_embeddings, converted_embeddings, atol=1e-5, rtol=1e-5) - - # Test decoder weights - original_decoder = model_hf.state_dict()["lm_head.decoder.weight"] - converted_decoder = model_hf_converted.state_dict()["lm_head.decoder.weight"] - assert original_decoder.shape == converted_decoder.shape, ( - f"Decoder shapes don't match: {original_decoder.shape} vs {converted_decoder.shape}" - ) - torch.testing.assert_close(original_decoder, converted_decoder, atol=1e-5, rtol=1e-5) - - # Test bias - original_bias = model_hf.state_dict()["lm_head.bias"] - converted_bias = model_hf_converted.state_dict()["lm_head.bias"] - assert original_bias.shape == converted_bias.shape, ( - f"Bias shapes don't match: {original_bias.shape} vs {converted_bias.shape}" - ) - torch.testing.assert_close(original_bias, converted_bias, atol=1e-5, rtol=1e-5) - - # Test that TE model has padded dimensions - te_embeddings = model_te.state_dict()["model.embeddings.word_embeddings.weight"] - te_decoder = model_te.state_dict()["lm_head.decoder.weight"] - assert te_embeddings.shape[0] >= original_embeddings.shape[0], "TE embeddings should be padded" - assert te_decoder.shape[0] >= original_decoder.shape[0], "TE decoder should be padded" - - # The padded parts should be zeros (for embeddings) or min values (for bias) - if te_embeddings.shape[0] > original_embeddings.shape[0]: - padding_rows = te_embeddings[original_embeddings.shape[0] :] - torch.testing.assert_close(padding_rows, torch.zeros_like(padding_rows), 
atol=1e-6, rtol=1e-6) - - -def test_weight_initialization_matches_hf(): - from transformers import AutoConfig, set_seed - from transformers.models.esm.modeling_esm import EsmForMaskedLM - - from convert import convert_esm_hf_to_te - from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM - - set_seed(42) - - config_hf = AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D", vocab_size=64, revision="c731040f") - model_hf = EsmForMaskedLM(config_hf) - model_te_converted = convert_esm_hf_to_te(model_hf) - - config = NVEsmConfig(**model_hf.config.to_dict()) - model_te = NVEsmForMaskedLM(config) - model_te.to("cuda") - model_te_converted.to("cuda") - - state_dict_hf = model_te_converted.state_dict() - state_dict_te = model_te.state_dict() - - for name in state_dict_hf.keys(): - if name.endswith("_extra_state"): - continue - - torch.testing.assert_close( - state_dict_te[name].mean(), - state_dict_hf[name].mean(), - atol=1e-3, - rtol=1e-4, - msg=lambda x: f"Mean mismatch for parameter {name}: {x}", - ) - - torch.testing.assert_close( - state_dict_te[name].std(), - state_dict_hf[name].std(), - atol=1e-3, - rtol=1e-4, - msg=lambda x: f"Std mismatch for parameter {name}: {x}", - ) diff --git a/bionemo-recipes/models/esm2/tests/test_fp8.py b/bionemo-recipes/models/esm2/tests/test_fp8.py deleted file mode 100644 index 8de7dd06a..000000000 --- a/bionemo-recipes/models/esm2/tests/test_fp8.py +++ /dev/null @@ -1,192 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import torch -import torch.distributed.checkpoint as dcp -import transformer_engine -from torch.distributed.checkpoint.state_dict import get_model_state_dict -from transformer_engine.common import recipe as recipe_module -from transformers import DataCollatorForLanguageModeling - -from collator import DataCollatorWithFlattening -from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM - - -try: - from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor -except ImportError: # TE nightly uses a new import path for QuantizedTensor - from transformer_engine.pytorch.quantized_tensor import QuantizedTensor - - -@pytest.fixture -def input_data_thd(tokenizer, tokenized_proteins): - mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15, seed=42) - data_collator = DataCollatorWithFlattening( - collator=mlm_collator, - pad_to_multiple_of=32, # MXFP8 requires the sequence length to be divisible by 32, regular FP8 requires 16. 
- ) - - return data_collator(tokenized_proteins) - - -def test_fp8_forward_and_backward_pass(te_model_checkpoint, input_data, fp8_recipe): - model_te = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) - model_te.to("cuda") - - input_data = {k: v.to("cuda") for k, v in input_data.items()} - outputs = model_te(**input_data) - - with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - outputs_fp8 = model_te(**input_data) - outputs_fp8.loss.backward() - - if isinstance(fp8_recipe, recipe_module.NVFP4BlockScaling): - atol = 0.2 - rtol = 0.05 - else: - atol = None - rtol = None - - torch.testing.assert_close(outputs_fp8.loss, outputs.loss, atol=atol, rtol=rtol) - - -def test_fp8_forward_and_backward_pass_thd(te_model_checkpoint, input_data_thd, fp8_recipe, monkeypatch): - if torch.cuda.get_device_capability() == (12, 0): - # TODO(BIONEMO-2840): On sm120, we need to set NVTE_FUSED_ATTN to 0 since TE will choose fused attn by default, - # but it's missing this THD implementation. 
- monkeypatch.setenv("NVTE_FUSED_ATTN", "0") - - model_te = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16) - model_te.to("cuda") - - input_data = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} - outputs = model_te(**input_data) - - with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - outputs_fp8 = model_te(**input_data) - outputs_fp8.loss.backward() - - if isinstance(fp8_recipe, recipe_module.NVFP4BlockScaling): - atol = 0.2 - rtol = 0.05 - elif isinstance(fp8_recipe, recipe_module.DelayedScaling): - atol = 0.1 - rtol = 0.03 - else: - atol = None - rtol = None - - torch.testing.assert_close(outputs_fp8.loss, outputs.loss, atol=atol, rtol=rtol) - - -def test_fp8_model_init_forward_and_backward(te_model_checkpoint, input_data, fp8_recipe): - config = NVEsmConfig.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) - with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe): - model_te = NVEsmForMaskedLM(config) - - assert isinstance(model_te.lm_head.dense.weight, QuantizedTensor) - - model_te.to("cuda") - input_data = {k: v.to("cuda") for k, v in input_data.items()} - - with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - outputs_fp8 = model_te(**input_data) - - outputs_fp8.loss.backward() - - -@pytest.mark.xfail(reason="BIONEMO-3055: fp8 model init and pretrained loading is not currently supported.") -def test_fp8_model_init_from_pretrained(te_model_checkpoint, fp8_recipe): - # TODO: this will be renamed to quantized_model_init in the future, fp8_model_init will be removed in 3.0 - with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe): - model_te = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) - - assert isinstance(model_te.model.encoder.layers[0].layernorm_mlp.fc2_weight, QuantizedTensor) - assert 
isinstance(model_te.lm_head.dense.weight, QuantizedTensor) - - -@pytest.mark.xfail(reason="BIONEMO-3055: fp8 model init and pretrained saving is not currently supported.") -def test_fp8_model_init_save_pretrained(te_model_checkpoint, tmp_path, fp8_recipe): - config = NVEsmConfig.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) - with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe): - model_fp8 = NVEsmForMaskedLM(config) - - assert isinstance(model_fp8.model.encoder.layers[0].layernorm_mlp.fc2_weight, QuantizedTensor) - assert isinstance(model_fp8.lm_head.dense.weight, QuantizedTensor) - - model_fp8.save_pretrained(tmp_path / "fp8_checkpoint") - del model_fp8 - NVEsmForMaskedLM.from_pretrained(tmp_path / "fp8_checkpoint", dtype=torch.bfloat16) - - -def test_fp8_model_distributed_checkpointing_save_and_load(te_model_checkpoint, tmp_path, input_data, fp8_recipe): - config = NVEsmConfig.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) - with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe): - model_fp8 = NVEsmForMaskedLM(config) - - model_fp8.to("cuda") - input_data = {k: v.to("cuda") for k, v in input_data.items()} - with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - outputs = model_fp8(**input_data) - outputs.loss.backward() - - state_dict = get_model_state_dict(model_fp8) - state_dict = {key: val for key, val in state_dict.items() if not key.endswith("_extra_state")} - dcp.save(state_dict, checkpoint_id=tmp_path / "fp8_checkpoint") - - del model_fp8, state_dict - - with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe): - model_fp8 = NVEsmForMaskedLM(config) - - state_dict = model_fp8.state_dict() - state_dict = {key: val for key, val in state_dict.items() if not key.endswith("_extra_state")} - dcp.load(state_dict, checkpoint_id=tmp_path / "fp8_checkpoint") - - -def _format_bytes(num: int, suffix: str = "B") -> str: - """Format bytes as 
a human-readable string (e.g. 1.2 MB).""" - for unit in ("", "K", "M", "G", "T", "P", "E", "Z"): - if abs(num) < 1024.0: - return f"{num:3.1f} {unit}{suffix}" - num /= 1024.0 - return f"{num:.1f} Y{suffix}" - - -@pytest.mark.xfail(reason="BIONEMO-3055: fp8 model init seems to have issues.") -def test_fp8_model_init_uses_less_memory(te_model_checkpoint, fp8_recipe): - torch.cuda.empty_cache() - - config = NVEsmConfig.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) - torch.cuda.reset_peak_memory_stats() - memory_before = torch.cuda.memory_allocated() - with transformer_engine.pytorch.fp8_model_init(enabled=True, recipe=fp8_recipe), torch.device("cuda"): - model_fp8 = NVEsmForMaskedLM(config) - peak_memory_fp8 = torch.cuda.max_memory_allocated() - memory_before - del model_fp8 - torch.cuda.empty_cache() - - torch.cuda.reset_peak_memory_stats() - memory_before = torch.cuda.memory_allocated() - with transformer_engine.pytorch.fp8_model_init(enabled=False, recipe=fp8_recipe), torch.device("cuda"): - model_bf16 = NVEsmForMaskedLM(config) - peak_memory_bf16 = torch.cuda.max_memory_allocated() - memory_before - del model_bf16 - - assert peak_memory_fp8 < peak_memory_bf16, ( - f"FP8 model init uses more memory than BF16 model init: {_format_bytes(peak_memory_fp8)} " - f"vs {_format_bytes(peak_memory_bf16)}" - ) diff --git a/bionemo-recipes/models/esm2/tests/test_meta_device_init.py b/bionemo-recipes/models/esm2/tests/test_meta_device_init.py deleted file mode 100644 index bc38efe25..000000000 --- a/bionemo-recipes/models/esm2/tests/test_meta_device_init.py +++ /dev/null @@ -1,294 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Test that parameter distributions are identical with and without meta device initialization. - -These tests verify that when using meta device initialization (creating the model on meta device, then calling -`to_empty` and `_init_weights`), the resulting parameter distributions (mean and std) match those from normal -initialization. This is important because we previously observed differences in convergence between meta-device-init and -non-meta-device-init training, which suggested that the initialization was not being applied correctly after `to_empty`. -By explicitly calling `_init_weights` after `to_empty`, we ensure that parameters are properly initialized, leading to -consistent training behavior regardless of whether meta device initialization is used. 
-""" - -import os -import subprocess - -import pytest -import torch -import transformer_engine.pytorch -from torch.distributed.fsdp import fully_shard -from torch.distributed.tensor import DTensor -from transformer_engine.pytorch.tensor import QuantizedTensor -from transformers import AutoConfig, set_seed - -from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM, NVEsmForTokenClassification - - -requires_multi_gpu = pytest.mark.skipif( - not torch.cuda.is_available() or torch.cuda.device_count() < 2, - reason="Test requires at least 2 GPUs", -) - - -def verify_model_parameters_initialized_correctly( - model: NVEsmForMaskedLM, atol=1e-3, rtol=1e-4, should_be_fp8: bool = False -): - config = model.config - - for name, parameter in model.named_parameters(): - assert str(parameter.device).startswith("cuda"), f"Parameter {name} is not on the cuda device" - - for name, module in model.named_modules(): - - def msg(x): - return f"Mismatch in module {name}: {x}" - - if isinstance(module, torch.nn.Embedding): - torch.testing.assert_close(module.weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) - torch.testing.assert_close( - module.weight.std().item(), config.initializer_range, atol=atol, rtol=rtol, msg=msg - ) - - elif name == "lm_head.decoder": - # Make sure the lm_head decoder weights are still tied to the encoder weights - assert module.weight is model.model.embeddings.word_embeddings.weight, ( - "Decoder weight tying has been broken" - ) - - elif isinstance(module, transformer_engine.pytorch.Linear): - torch.testing.assert_close(module.weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) - torch.testing.assert_close( - module.weight.std().item(), config.initializer_range, atol=atol, rtol=rtol, msg=msg - ) - torch.testing.assert_close(module.bias, torch.zeros_like(module.bias), msg=msg) - if should_be_fp8: - assert isinstance(module.weight, QuantizedTensor), f"Module {name} weight is not a QuantizedTensor" - - elif isinstance(module, 
transformer_engine.pytorch.LayerNormLinear): - torch.testing.assert_close(module.weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) - torch.testing.assert_close( - module.weight.std().item(), config.initializer_range, atol=atol, rtol=rtol, msg=msg - ) - torch.testing.assert_close(module.bias, torch.zeros_like(module.bias), msg=msg) - torch.testing.assert_close(module.layer_norm_weight, torch.ones_like(module.layer_norm_weight), msg=msg) - torch.testing.assert_close(module.layer_norm_bias, torch.zeros_like(module.layer_norm_bias), msg=msg) - if should_be_fp8: - assert isinstance(module.weight, QuantizedTensor), f"Module {name} weight is not a QuantizedTensor" - - elif isinstance(module, transformer_engine.pytorch.LayerNormMLP): - torch.testing.assert_close(module.fc1_weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) - torch.testing.assert_close( - module.fc1_weight.std().item(), config.initializer_range, atol=atol, rtol=rtol, msg=msg - ) - torch.testing.assert_close(module.fc2_weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) - torch.testing.assert_close( - module.fc2_weight.std().item(), config.initializer_range, atol=atol, rtol=rtol, msg=msg - ) - torch.testing.assert_close(module.fc1_bias, torch.zeros_like(module.fc1_bias), msg=msg) - torch.testing.assert_close(module.fc2_bias, torch.zeros_like(module.fc2_bias), msg=msg) - torch.testing.assert_close(module.layer_norm_weight, torch.ones_like(module.layer_norm_weight), msg=msg) - torch.testing.assert_close(module.layer_norm_bias, torch.zeros_like(module.layer_norm_bias), msg=msg) - if should_be_fp8: - assert isinstance(module.fc1_weight, QuantizedTensor), ( - f"Module {name} fc1_weight is not a QuantizedTensor" - ) - assert isinstance(module.fc2_weight, QuantizedTensor), ( - f"Module {name} fc2_weight is not a QuantizedTensor" - ) - - elif isinstance(module, torch.nn.LayerNorm): - torch.testing.assert_close(module.weight, torch.ones_like(module.weight), msg=msg) - 
torch.testing.assert_close(module.bias, torch.zeros_like(module.bias), msg=msg) - - elif isinstance(module, transformer_engine.pytorch.attention.rope.RotaryPositionEmbedding): - dim = config.hidden_size // config.num_attention_heads - expected_inv_freq = 1.0 / (10_000.0 ** (torch.arange(0, dim, 2, dtype=torch.float32, device="cuda") / dim)) - torch.testing.assert_close(module.inv_freq, expected_inv_freq, msg=msg) - - -def verify_pretrained_model_sanity(model: NVEsmForTokenClassification, atol=1e-2, rtol=1e-3): - for name, p in model.named_parameters(): - - def msg(x): - return f"Mismatch in parameter {name}: {x}" - - assert p.numel() > 0, f"{name} is empty" - assert torch.isfinite(p).all(), f"{name} has NaN/Inf" - - max_abs = p.abs().max().item() - assert max_abs < 1e3, f"{name} extreme values: {max_abs}" - - if name == "classifier.weight": - torch.testing.assert_close(p.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) - torch.testing.assert_close(p.std().item(), model.config.initializer_range, atol=atol, rtol=rtol, msg=msg) - - if name == "classifier.bias": - torch.testing.assert_close(p, torch.zeros_like(p), msg=msg) - - -def test_cuda_init(): - config = NVEsmConfig(**AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D").to_dict()) - - set_seed(42) - model = NVEsmForMaskedLM(config) - model.to("cuda") - - verify_model_parameters_initialized_correctly(model) - - -def test_meta_init(): - config = NVEsmConfig(**AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D").to_dict()) - - set_seed(42) - with torch.device("meta"): - model = NVEsmForMaskedLM(config) - - # Assert parameters are actually on the meta device - for name, parameter in model.named_parameters(): - assert parameter.device == torch.device("meta"), f"Parameter {name} is not on the meta device" - - # Move the model to the cuda device and initialize the parameters - model.init_empty_weights() - - verify_model_parameters_initialized_correctly(model) - - -def test_cuda_fp8_init(fp8_recipe): - config = 
NVEsmConfig(**AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f").to_dict()) - - set_seed(42) - with transformer_engine.pytorch.fp8_model_init(recipe=fp8_recipe): - model = NVEsmForMaskedLM(config) - - model.to("cuda") - - verify_model_parameters_initialized_correctly(model, atol=1e-2, should_be_fp8=True) - - -def test_meta_fp8_init(fp8_recipe): - config = NVEsmConfig(**AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D", revision="c731040f").to_dict()) - - set_seed(42) - with transformer_engine.pytorch.fp8_model_init(recipe=fp8_recipe), torch.device("meta"): - model = NVEsmForMaskedLM(config) - - # Move the model to the cuda device and initialize the parameters - model.init_empty_weights() - - verify_model_parameters_initialized_correctly(model, should_be_fp8=True) - - -def test_model_for_token_classification_init(te_model_checkpoint): - set_seed(42) - - config = NVEsmConfig.from_pretrained(te_model_checkpoint) - model = NVEsmForTokenClassification.from_pretrained(te_model_checkpoint, config=config, dtype=torch.bfloat16) - # model.classifier.reset_parameters() - model.to("cuda") - verify_pretrained_model_sanity(model) - - -@pytest.mark.parametrize("num_gpus", [1, pytest.param(2, marks=requires_multi_gpu)]) -def test_meta_device_init_after_fully_shard(num_gpus: int): - cmd = [ - "torchrun", - f"--nproc_per_node={num_gpus}", - os.path.relpath(__file__), - ] - - result = subprocess.run( - cmd, - check=False, - text=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - timeout=240, - ) - - if result.returncode != 0: - print(f"STDOUT:\n{result.stdout}") - print(f"STDERR:\n{result.stderr}") - pytest.fail(f"Command failed with exit code {result.returncode}") - - -if __name__ == "__main__": - torch.distributed.init_process_group(backend="cuda:nccl") - torch.cuda.set_device(torch.distributed.get_rank()) - - config = NVEsmConfig(**AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D").to_dict()) - - set_seed(42) - - with 
torch.device("meta"): - model_meta_device = NVEsmForMaskedLM(config) - - for layer in model_meta_device.model.encoder.layers: - fully_shard(layer) - fully_shard(model_meta_device) - - # Assert parameters are actually on the meta device - for name, parameter in model_meta_device.named_parameters(): - assert parameter.device == torch.device("meta"), f"Parameter {name} is not on the meta device" - - model_meta_device.init_empty_weights() - - # Assert parameters are actually on the cuda device after to_empty - for name, parameter in model_meta_device.named_parameters(): - assert str(parameter.device).startswith("cuda"), f"Parameter {name} is not on the cuda device" - - set_seed(42) - model_normal_init = NVEsmForMaskedLM(config) - - for layer in model_normal_init.model.encoder.layers: - fully_shard(layer) - fully_shard(model_normal_init) - - state_dict_meta_init = model_meta_device.state_dict() - state_dict_normal_init = model_normal_init.state_dict() - - for key in state_dict_meta_init.keys(): - if key.endswith("_extra_state"): - continue - - meta_tensor = state_dict_meta_init[key] - normal_tensor = state_dict_normal_init[key] - - torch.testing.assert_close( - normal_tensor.mean(), - meta_tensor.mean(), - atol=1e-3, - rtol=1e-4, - msg=lambda x: f"Mean mismatch for parameter {key}: {x}", - ) - - if isinstance(normal_tensor, DTensor) and isinstance(meta_tensor, DTensor): - torch.testing.assert_close( - normal_tensor.full_tensor().std(), - meta_tensor.full_tensor().std(), - atol=1e-2, - rtol=1e-4, - msg=lambda x: f"Std mismatch for parameter {key}: {x}", - ) - - else: - torch.testing.assert_close( - normal_tensor.std(), - meta_tensor.std(), - atol=1e-2, - rtol=1e-4, - msg=lambda x: f"Std mismatch for parameter {key}: {x}", - ) diff --git a/bionemo-recipes/models/esm2/tests/test_thd.py b/bionemo-recipes/models/esm2/tests/test_thd.py deleted file mode 100644 index 07b624389..000000000 --- a/bionemo-recipes/models/esm2/tests/test_thd.py +++ /dev/null @@ -1,329 +0,0 @@ -# 
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import os - -import pytest -import torch -from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends -from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp -from transformers import DataCollatorForLanguageModeling - -from collator import DataCollatorWithFlattening -from modeling_esm_te import NVEsmConfig, NVEsmEmbeddings, NVEsmForMaskedLM - - -compute_capability = torch.cuda.get_device_capability() - -# TODO(@jomitchell): Delete once https://nvbugspro.nvidia.com/bug/5458694 is fixed. 
-requires_datacenter_hardware = pytest.mark.skipif( - not torch.cuda.is_available() - or not any( - gpu_name in torch.cuda.get_device_name(0).upper() for gpu_name in ["H100", "H200", "B100", "B200", "B300"] - ), - reason="Test requires datacenter hardware (H100, H200, B100, B200, B300)", -) - - -@pytest.fixture -def input_data_thd(tokenizer, tokenized_proteins): - """The collator here needs to exactly match the one used in the `input_data` fixture for golden values to pass.""" - data_collator = DataCollatorWithFlattening( - collator=DataCollatorForLanguageModeling( - tokenizer=tokenizer, - mlm_probability=0.15, - pad_to_multiple_of=32, - seed=42, - ) - ) - return data_collator(tokenized_proteins) - - -@pytest.fixture -def input_data_thd_padded_from_input_data_thd(input_data_thd): - input_data_thd_padded = copy.deepcopy(input_data_thd) - input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( - input_data_thd_padded["input_ids"], - input_data_thd_padded["labels"], - input_data_thd_padded["cu_seq_lens_q"], - 16, - padding_token_id=1, - padding_label_id=-100, - ) - - input_data_thd_padded["input_ids"] = input_ids_padded.unsqueeze(0) - input_data_thd_padded["labels"] = labels_padded.unsqueeze(0) - input_data_thd_padded["cu_seq_lens_q_padded"] = cu_seqlens_padded.to(torch.int32) - input_data_thd_padded["cu_seq_lens_k_padded"] = cu_seqlens_padded.to(torch.int32) - input_data_thd_padded["pad_between_seqs"] = True - return input_data_thd_padded - - -@pytest.mark.parametrize("use_token_dropout", [True, False]) -def test_nv_esm_embeddings_random_init(te_model_checkpoint, input_data_thd, input_data, use_token_dropout): - config = NVEsmConfig.from_pretrained(te_model_checkpoint) - assert config.token_dropout is True - embedding = NVEsmEmbeddings(config) - embedding.token_dropout = use_token_dropout - embedding.to("cuda") - - input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} - 
input_data_thd.pop("labels") - outputs_thd = embedding(**input_data_thd) - - input_data_bshd = {k: v.to("cuda") for k, v in input_data.items()} - input_data_bshd.pop("labels") - outputs_bshd = embedding(**input_data_bshd) - - # Reshape outputs_bshd to match outputs_thd - outputs_bshd = outputs_bshd[input_data_bshd["attention_mask"].to(bool)].unsqueeze(0) - torch.testing.assert_close(outputs_thd, outputs_bshd, atol=1e-8, rtol=1e-8) - - -@pytest.mark.parametrize("use_token_dropout", [True, False]) -def test_nv_esm_embeddings_from_model(te_model_checkpoint, input_data_thd, input_data, use_token_dropout): - model = NVEsmForMaskedLM.from_pretrained( - te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16, token_dropout=use_token_dropout - ) - embedding = model.model.embeddings - assert embedding.token_dropout == use_token_dropout - embedding.to("cuda") - - input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} - input_data_thd.pop("labels") - outputs_thd = embedding(**input_data_thd) - - input_data_bshd = {k: v.to("cuda") for k, v in input_data.items()} - input_data_bshd.pop("labels") - outputs_bshd = embedding(**input_data_bshd) - - # Reshape outputs_bshd to match outputs_thd - outputs_bshd = outputs_bshd[input_data_bshd["attention_mask"].to(bool)].unsqueeze(0) - torch.testing.assert_close(outputs_thd, outputs_bshd, atol=1e-8, rtol=1e-8) - - -def test_thd_from_collator_output(te_model_checkpoint, input_data_thd): - model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16) - model_thd.to("cuda") - input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} - with torch.no_grad(): - outputs = model_thd(**input_data_thd, output_hidden_states=True) - - assert outputs.loss < 3.0 - - -@pytest.fixture(params=["flash_attn", "fused_attn"]) -def attn_impl(request, monkeypatch): - if request.param == "flash_attn": 
- os.environ["NVTE_FUSED_ATTN"] = "0" - os.environ["NVTE_FLASH_ATTN"] = "1" - _attention_backends["backend_selection_requires_update"] = True - - else: - os.environ["NVTE_FUSED_ATTN"] = "1" - os.environ["NVTE_FLASH_ATTN"] = "0" - _attention_backends["backend_selection_requires_update"] = True - - return request.param - - -def test_thd_losses_match(te_model_checkpoint, input_data, input_data_thd, attn_impl): - if attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 8: - pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.") - elif attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 12: - pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.") - - torch.testing.assert_close( - input_data["input_ids"][input_data["attention_mask"].to(bool)], - input_data_thd["input_ids"].flatten(0), - ) - - torch.testing.assert_close( - input_data["labels"][input_data["attention_mask"].to(bool)], - input_data_thd["labels"].flatten(0), - ) - - model_bshd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) - model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16) - model_bshd.to("cuda") - model_thd.to("cuda") - - input_data_bshd = {k: v.to("cuda") for k, v in input_data.items()} - input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} - - bshd_outputs = model_bshd(**input_data_bshd) - thd_outputs = model_thd(**input_data_thd) - - torch.testing.assert_close(bshd_outputs.loss, thd_outputs.loss) - - -def test_thd_logits_match_with_bf16_autocast(te_model_checkpoint, input_data, input_data_thd, attn_impl): - if attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 8: - pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.") - elif attn_impl == "flash_attn" and torch.cuda.get_device_capability()[0] == 8: - 
pytest.xfail("BIONEMO-2801: On Ada and Ampere, the flash attention logits don't seem to match.") - elif attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 12: - pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.") - - # Ensure the input data is the same - torch.testing.assert_close( - input_data["input_ids"][input_data["attention_mask"].to(bool)], - input_data_thd["input_ids"].flatten(0), - ) - - torch.testing.assert_close( - input_data["labels"][input_data["attention_mask"].to(bool)], - input_data_thd["labels"].flatten(0), - ) - - # Create models - model_bshd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) - model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16) - - model_bshd.to("cuda") - model_thd.to("cuda") - - input_data_bshd = {k: v.to("cuda") for k, v in input_data.items()} - input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} - - thd_outputs = model_thd(**input_data_thd, output_hidden_states=True) - bshd_outputs = model_bshd(**input_data_bshd, output_hidden_states=True) - - for i, (bshd_hidden, thd_hidden) in enumerate(zip(bshd_outputs.hidden_states, thd_outputs.hidden_states)): - torch.testing.assert_close( - bshd_hidden[input_data_bshd["attention_mask"].to(bool)], - thd_hidden.squeeze(0), - msg=lambda msg: "Hidden states do not match going into layer " + str(i + 1) + ": " + msg, - atol=1e-1 if compute_capability[0] == 8 else 1e-5, - rtol=1.6e-2, - ) - - if compute_capability[0] == 8: - break # On Ada and Ampere, we see much larger numerical errors so we stop after the first layer - - bshd_logits = bshd_outputs.logits[input_data_bshd["attention_mask"].to(bool)] - torch.testing.assert_close(bshd_logits, thd_outputs.logits, atol=1e-8, rtol=1e-8) - - -def test_thd_backwards_works(te_model_checkpoint, input_data_thd, attn_impl): - if attn_impl == 
"fused_attn" and torch.cuda.get_device_capability() == (12, 0): - pytest.xfail("BIONEMO-2840: On sm120 the THD backwards implementation is not available for fused attn.") - elif attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 8: - pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.") - - model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16) - model_thd.to("cuda") - input_data = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} - outputs = model_thd(**input_data) - outputs.loss.backward() - - -def test_thd_backwards_passes_match(te_model_checkpoint, input_data, input_data_thd, attn_impl): - if attn_impl == "fused_attn" and torch.cuda.get_device_capability() == (12, 0): - pytest.xfail("BIONEMO-2840: On sm120 the THD backwards implementation is not available for fused attn.") - elif attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 8: - pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.") - - torch.testing.assert_close( - input_data["input_ids"][input_data["attention_mask"].to(bool)], - input_data_thd["input_ids"].flatten(0), - ) - - torch.testing.assert_close( - input_data["labels"][input_data["attention_mask"].to(bool)], - input_data_thd["labels"].flatten(0), - ) - - model_bshd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, dtype=torch.bfloat16) - model_thd = NVEsmForMaskedLM.from_pretrained(te_model_checkpoint, attn_input_format="thd", dtype=torch.bfloat16) - model_bshd.to("cuda") - model_thd.to("cuda") - - input_data_bshd = {k: v.to("cuda") for k, v in input_data.items()} - input_data_thd = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} - - bshd_outputs = model_bshd(**input_data_bshd) - thd_outputs = model_thd(**input_data_thd) - - thd_outputs.loss.backward() - bshd_outputs.loss.backward() - - thd_grads = 
{name: p.grad for name, p in model_thd.named_parameters() if p.grad is not None} - bshd_grads = {name: p.grad for name, p in model_bshd.named_parameters() if p.grad is not None} - - # max_diff_by_layer = {key: (thd_grads[key] - bshd_grads[key]).abs().max().item() for key in thd_grads.keys()} - - # For some reason, the word embeddings grads have a slightly higher numerical error. - thd_word_embeddings_grad = thd_grads.pop("model.embeddings.word_embeddings.weight") - bshd_word_embeddings_grad = bshd_grads.pop("model.embeddings.word_embeddings.weight") - torch.testing.assert_close( - thd_grads, - bshd_grads, - atol=1e-2 if compute_capability[0] == 8 else 1e-5, - rtol=1.6e-2, - ) - - torch.testing.assert_close(thd_word_embeddings_grad, bshd_word_embeddings_grad, atol=1e-2, rtol=1e-5) - - -@requires_datacenter_hardware -def test_thd_vs_padded_thd_equivalence( - te_model_checkpoint, input_data_thd, input_data_thd_padded_from_input_data_thd, attn_impl -): - if attn_impl == "flash_attn": - pytest.xfail("Flash attention is not supported for padded sequences.") - - input_data_thd_padded = input_data_thd_padded_from_input_data_thd - seqlens_q = input_data_thd_padded["cu_seq_lens_q_padded"][1:] - input_data_thd_padded["cu_seq_lens_q_padded"][:-1] - max_length_q = int((seqlens_q.max().item() + 63) // 64 * 64) # TODO(@jomitchell): Not sure if I need this anymore. - max_length_k = max_length_q - input_data_thd_padded["max_length_q"] = max_length_q - input_data_thd_padded["max_length_k"] = max_length_k - - input_data_thd_gpu = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd.items()} - input_data_thd_padded_gpu = { - k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in input_data_thd_padded.items() - } - - # Run the input data thd through. 
- model_thd = NVEsmForMaskedLM.from_pretrained( - te_model_checkpoint, attn_input_format="thd", token_dropout=False, dtype=torch.bfloat16 - ) - model_thd.to("cuda") - outputs_thd = model_thd(**input_data_thd_gpu) - outputs_thd_padded = model_thd(**input_data_thd_padded_gpu) - - cu_seq_lens_q = input_data_thd["cu_seq_lens_q"] - cu_seq_lens_q_padded = input_data_thd_padded["cu_seq_lens_q_padded"] - cu_num_pads = cu_seq_lens_q_padded - cu_seq_lens_q - seq_lengths_real = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1] - - num_real_tokens = outputs_thd.logits.shape[0] # should be cu_seq_lens_q[-1] - - # How much we need to shift each sequence by. - offsets = torch.repeat_interleave(cu_num_pads[:-1], seq_lengths_real, dim=0) - - # The indices of the real tokens as appears in the padded logits. - real_idx = torch.arange(0, num_real_tokens) + offsets - - assert ( - input_data_thd["input_ids"].squeeze() - input_data_thd_padded["input_ids"].squeeze().index_select(0, real_idx) - ).abs().max().item() == 0 - - # Now index select the padded logits to get the real logits. 
- logits_unpadded = outputs_thd_padded.logits.index_select(0, real_idx.cuda()) - - torch.testing.assert_close(outputs_thd.logits, logits_unpadded, atol=1e-8, rtol=1e-5) From a67df148799b7bbf2545428e4132538a28875451 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Wed, 25 Feb 2026 15:47:44 +0000 Subject: [PATCH 7/9] addressed feedback Signed-off-by: Gagan Kaushik --- bionemo-recipes/models/esm2/README.md | 25 +++ bionemo-recipes/models/esm2/export.py | 14 +- .../models/esm2/modeling_esm_te.py | 36 +-- bionemo-recipes/models/esm2/requirements.txt | 1 + .../models/esm2/tests/test_vllm.py | 146 ++++++++++++ .../example_8m_checkpoint/esm_nv.py | 36 +-- .../example_8m_checkpoint/esm_nv.py | 36 +-- .../example_8m_checkpoint/esm_nv.py | 36 +-- bionemo-recipes/vllm/Dockerfile | 36 --- bionemo-recipes/vllm/README.md | 23 -- bionemo-recipes/vllm/launch.sh | 50 ----- .../vllm/test_esm2_golden_values.py | 210 ------------------ 12 files changed, 259 insertions(+), 390 deletions(-) create mode 100644 bionemo-recipes/models/esm2/tests/test_vllm.py delete mode 100644 bionemo-recipes/vllm/Dockerfile delete mode 100644 bionemo-recipes/vllm/README.md delete mode 100644 bionemo-recipes/vllm/launch.sh delete mode 100644 bionemo-recipes/vllm/test_esm2_golden_values.py diff --git a/bionemo-recipes/models/esm2/README.md b/bionemo-recipes/models/esm2/README.md index a0f8effbb..2f7a917dd 100644 --- a/bionemo-recipes/models/esm2/README.md +++ b/bionemo-recipes/models/esm2/README.md @@ -60,6 +60,31 @@ inputs = tokenizer(gfp_P42212, return_tensors="pt") output = model(**inputs) ``` +### Running inference with vLLM + +The exported TE checkpoints on HuggingFace Hub are directly compatible with +[vLLM](https://github.com/vllm-project/vllm) (>= 0.14) as pooling/embedding models. 
+No conversion scripts or weight renaming are needed: + +```python +from vllm import LLM + +model = LLM( + model="nvidia/esm2_t6_8M_UR50D", + runner="pooling", + trust_remote_code=True, + enforce_eager=True, + max_num_batched_tokens=1026, +) + +prompts = ["MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLK"] +outputs = model.embed(prompts) +print(outputs[0].outputs.embedding[:5]) +``` + +See [tests/test_vllm.py](tests/test_vllm.py) for a full golden-value validation across +vLLM, native HuggingFace, and the nvidia Hub reference model. + ## Recipe Links Training recipes are available in the `bionemo-recipes/recipes/` directory: diff --git a/bionemo-recipes/models/esm2/export.py b/bionemo-recipes/models/esm2/export.py index 3ac5e89f2..748e46784 100644 --- a/bionemo-recipes/models/esm2/export.py +++ b/bionemo-recipes/models/esm2/export.py @@ -72,11 +72,9 @@ def export_hf_checkpoint(tag: str, export_path: Path): model_hf = AutoModel.from_pretrained(f"facebook/{tag}") model_hf_masked_lm.esm.pooler = model_hf.pooler - # Export with padded_vocab_size=None (defaults to vocab_size) so that the checkpoint - # stores embeddings/decoder at the real vocab_size without zero-padding. Padding is - # only needed at runtime for FP8 training efficiency; users who train with FP8 pass - # padded_vocab_size explicitly. Keeping vocab_size-sized weights in the checkpoint - # avoids shape-mismatch assertions in vLLM's VocabParallelEmbedding. + # Export without vocab padding so the checkpoint stores embeddings at the real + # vocab_size. This avoids shape-mismatch errors in vLLM's VocabParallelEmbedding, + # which expects vocab_size-shaped weights. model_te = convert_esm_hf_to_te(model_hf_masked_lm, padded_vocab_size=None) model_te.save_pretrained(export_path / tag) @@ -89,12 +87,6 @@ def export_hf_checkpoint(tag: str, export_path: Path): config["auto_map"] = AUTO_MAP - # Disable pooler in the exported checkpoint. 
NVEsmForMaskedLM saves with - # add_pooling_layer=False, so pooler weights are absent. Setting this to false - # prevents vLLM from creating a pooler module and then erroring on missing weights. - # (HuggingFace tolerates missing keys via strict=False, but vLLM does not.) - config["add_pooling_layer"] = False - with open(export_path / tag / "config.json", "w") as f: json.dump(config, f, indent=2, sort_keys=True) diff --git a/bionemo-recipes/models/esm2/modeling_esm_te.py b/bionemo-recipes/models/esm2/modeling_esm_te.py index fb2e4136d..05e22b9c0 100644 --- a/bionemo-recipes/models/esm2/modeling_esm_te.py +++ b/bionemo-recipes/models/esm2/modeling_esm_te.py @@ -70,7 +70,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", - add_pooling_layer: bool = True, + add_pooling_layer: bool = False, **kwargs, ): """Initialize the NVEsmConfig with additional TE-related config options. @@ -101,9 +101,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults to vocab_size. Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. - add_pooling_layer: Whether the base model should include a pooling layer. Set to - ``False`` for exported checkpoints that are saved from ``NVEsmForMaskedLM`` - (which does not use a pooler). This avoids missing-weight errors in vLLM. + add_pooling_layer: Whether the base model should include a pooling layer. + Defaults to ``False`` because exported checkpoints do not contain pooler + weights. Set to ``True`` only if you have a checkpoint with pooler weights. **kwargs: Additional config options to pass to EsmConfig. 
""" super().__init__(**kwargs) @@ -405,6 +405,7 @@ class NVEsmForMaskedLM(NVEsmPreTrainedModel): _tied_weights_keys: ClassVar[dict[str, str]] = { "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" } + _do_not_quantize = ("lm_head.dense", "lm_head.decoder") def __init__(self, config: NVEsmConfig): """Initialize a NVEsmForMaskedLM. @@ -463,7 +464,8 @@ def forward( **kwargs, ) sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) + with transformer_engine.pytorch.autocast(enabled=False): + prediction_scores = self.lm_head(sequence_output) # Truncate logits back to original vocab_size if padding was used if self.config.padded_vocab_size != self.config.vocab_size: @@ -494,15 +496,15 @@ def __init__(self, config: NVEsmConfig): config (NVEsmConfig): The configuration of the model. """ super().__init__() - self.dense = transformer_engine.pytorch.Linear( - config.hidden_size, - config.hidden_size, - params_dtype=config.dtype, - device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", - init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), - ) + with transformer_engine.pytorch.quantized_model_init(enabled=False): + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) - with transformer_engine.pytorch.fp8_model_init(enabled=False): self.decoder = transformer_engine.pytorch.LayerNormLinear( config.hidden_size, config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, @@ -522,7 +524,7 @@ def forward(self, features, **kwargs): """ # Keep the last layers of the network in higher precision to avoid numerical instability. # Please see recipes/fp8_analysis/README.md for more details. 
- with transformer_engine.pytorch.fp8_autocast(enabled=False): + with transformer_engine.pytorch.autocast(enabled=False): x = self.dense(features) x = torch.nn.functional.gelu(x) x = self.decoder(x) @@ -617,7 +619,11 @@ def forward( ).sum(1) mask_ratio_observed = n_masked_per_seq.float() / src_lengths scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) - reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) + if "cu_seq_lens_q_padded" in kwargs: + src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) + else: + src_lengths_padded = src_lengths + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0) embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) if self.layer_norm is not None: diff --git a/bionemo-recipes/models/esm2/requirements.txt b/bionemo-recipes/models/esm2/requirements.txt index b4c358680..9f2bb0e09 100644 --- a/bionemo-recipes/models/esm2/requirements.txt +++ b/bionemo-recipes/models/esm2/requirements.txt @@ -9,3 +9,4 @@ torch torchao!=0.14.0 transformer_engine[pytorch] transformers +vllm diff --git a/bionemo-recipes/models/esm2/tests/test_vllm.py b/bionemo-recipes/models/esm2/tests/test_vllm.py new file mode 100644 index 000000000..78d03b869 --- /dev/null +++ b/bionemo-recipes/models/esm2/tests/test_vllm.py @@ -0,0 +1,146 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Golden-value tests for ESM2 vLLM compatibility. + +Performs a fresh facebook -> TE export, then cross-validates embeddings across +vLLM, HuggingFace (exported checkpoint), and HuggingFace (nvidia Hub reference). + +vLLM's pooling runner returns last-token, L2-normalised embeddings by default, +so the HuggingFace paths replicate that post-processing for comparison. +""" + +import os +from pathlib import Path + +import numpy as np +import pytest +import torch +from transformers import AutoModel, AutoTokenizer +from vllm import LLM + +from export import export_hf_checkpoint + + +EXPORT_TAG = "esm2_t6_8M_UR50D" +REFERENCE_MODEL_ID = "nvidia/esm2_t6_8M_UR50D" +ESM2_MODEL_DIR = Path(__file__).resolve().parent.parent + +SEQUENCES = [ + "LKGHAMCLGCLHMLMCGLLAGAMCGLMKLLKCCGKCLMHLMKAMLGLKCACHHHHLLLHACAAKKLCLGAKLAMGLKLLGAHGKGLKMACGHHMLHLHMH", + "CLLCCMHMHAHHCHGHGHKCKCLMMGMALMCAGCCACGMKGGCHCCLLAHCAHAKAGKGKCKLMCKKKHGLHAGLHAMLLCHLGLGCGHHHKKCKKHKCA", +] + + +def _last_token_l2(hidden_state: torch.Tensor) -> np.ndarray: + """Extract last-token hidden state and L2-normalise (matches vLLM pooling defaults).""" + vec = hidden_state[0, -1, :].cpu().float().numpy() + norm = np.linalg.norm(vec) + if norm > 1e-9: + vec = vec / norm + return vec + + +def _hf_embed(model_id: str, sequences: list[str], dtype=torch.float32) -> np.ndarray: + """Run HuggingFace inference and return last-token L2-normalised embeddings.""" + model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to("cuda", dtype=dtype).eval() + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + + vecs = [] + with torch.no_grad(): + for seq in sequences: + inputs = tokenizer(seq, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + out = model(**inputs) + vecs.append(_last_token_l2(out.last_hidden_state)) + + del model, tokenizer + torch.cuda.empty_cache() 
+ return np.stack(vecs) + + +def _vllm_embed(model_id: str, sequences: list[str]) -> np.ndarray: + """Run vLLM pooling inference and return embeddings.""" + engine = LLM( + model=model_id, + runner="pooling", + trust_remote_code=True, + dtype="float32", + enforce_eager=True, + max_num_batched_tokens=1026, + ) + outputs = engine.embed(sequences) + + vecs = [] + for output in outputs: + emb = output.outputs.embedding + if isinstance(emb, list): + emb = np.array(emb) + vecs.append(emb) + + del engine + return np.stack(vecs) + + +# ---- Fixtures ---- + + +@pytest.fixture(scope="session") +def exported_checkpoint(tmp_path_factory): + """Fresh facebook -> TE export. Session-scoped so it runs once.""" + export_dir = tmp_path_factory.mktemp("vllm_export") + original_cwd = os.getcwd() + os.chdir(ESM2_MODEL_DIR) + try: + export_hf_checkpoint(EXPORT_TAG, export_dir) + finally: + os.chdir(original_cwd) + return str(export_dir / EXPORT_TAG) + + +@pytest.fixture(scope="session") +def vllm_embeddings(exported_checkpoint): + """Embeddings from vLLM pooling runner on the exported checkpoint.""" + return _vllm_embed(exported_checkpoint, SEQUENCES) + + +@pytest.fixture(scope="session") +def hf_exported_embeddings(exported_checkpoint): + """Embeddings from HuggingFace on the exported checkpoint.""" + return _hf_embed(exported_checkpoint, SEQUENCES) + + +@pytest.fixture(scope="session") +def hf_reference_embeddings(): + """Embeddings from HuggingFace on the nvidia Hub model (ground truth).""" + return _hf_embed(REFERENCE_MODEL_ID, SEQUENCES) + + +# ---- Tests ---- + + +def test_vllm_vs_hf_exported(vllm_embeddings, hf_exported_embeddings): + """vLLM and native HuggingFace on the same exported checkpoint must match.""" + np.testing.assert_array_equal(vllm_embeddings, hf_exported_embeddings) + + +def test_vllm_vs_hf_reference(vllm_embeddings, hf_reference_embeddings): + """vLLM on exported checkpoint must match HuggingFace on nvidia Hub model.""" + 
np.testing.assert_array_equal(vllm_embeddings, hf_reference_embeddings) + + +def test_hf_exported_vs_hf_reference(hf_exported_embeddings, hf_reference_embeddings): + """Our exported checkpoint must produce identical results to the nvidia Hub model.""" + np.testing.assert_array_equal(hf_exported_embeddings, hf_reference_embeddings) diff --git a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py index fb2e4136d..05e22b9c0 100644 --- a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py @@ -70,7 +70,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", - add_pooling_layer: bool = True, + add_pooling_layer: bool = False, **kwargs, ): """Initialize the NVEsmConfig with additional TE-related config options. @@ -101,9 +101,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults to vocab_size. Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. - add_pooling_layer: Whether the base model should include a pooling layer. Set to - ``False`` for exported checkpoints that are saved from ``NVEsmForMaskedLM`` - (which does not use a pooler). This avoids missing-weight errors in vLLM. + add_pooling_layer: Whether the base model should include a pooling layer. + Defaults to ``False`` because exported checkpoints do not contain pooler + weights. Set to ``True`` only if you have a checkpoint with pooler weights. **kwargs: Additional config options to pass to EsmConfig. 
""" super().__init__(**kwargs) @@ -405,6 +405,7 @@ class NVEsmForMaskedLM(NVEsmPreTrainedModel): _tied_weights_keys: ClassVar[dict[str, str]] = { "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" } + _do_not_quantize = ("lm_head.dense", "lm_head.decoder") def __init__(self, config: NVEsmConfig): """Initialize a NVEsmForMaskedLM. @@ -463,7 +464,8 @@ def forward( **kwargs, ) sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) + with transformer_engine.pytorch.autocast(enabled=False): + prediction_scores = self.lm_head(sequence_output) # Truncate logits back to original vocab_size if padding was used if self.config.padded_vocab_size != self.config.vocab_size: @@ -494,15 +496,15 @@ def __init__(self, config: NVEsmConfig): config (NVEsmConfig): The configuration of the model. """ super().__init__() - self.dense = transformer_engine.pytorch.Linear( - config.hidden_size, - config.hidden_size, - params_dtype=config.dtype, - device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", - init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), - ) + with transformer_engine.pytorch.quantized_model_init(enabled=False): + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) - with transformer_engine.pytorch.fp8_model_init(enabled=False): self.decoder = transformer_engine.pytorch.LayerNormLinear( config.hidden_size, config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, @@ -522,7 +524,7 @@ def forward(self, features, **kwargs): """ # Keep the last layers of the network in higher precision to avoid numerical instability. # Please see recipes/fp8_analysis/README.md for more details. 
- with transformer_engine.pytorch.fp8_autocast(enabled=False): + with transformer_engine.pytorch.autocast(enabled=False): x = self.dense(features) x = torch.nn.functional.gelu(x) x = self.decoder(x) @@ -617,7 +619,11 @@ def forward( ).sum(1) mask_ratio_observed = n_masked_per_seq.float() / src_lengths scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) - reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) + if "cu_seq_lens_q_padded" in kwargs: + src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) + else: + src_lengths_padded = src_lengths + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0) embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) if self.layer_norm is not None: diff --git a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py index fb2e4136d..05e22b9c0 100644 --- a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py @@ -70,7 +70,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", - add_pooling_layer: bool = True, + add_pooling_layer: bool = False, **kwargs, ): """Initialize the NVEsmConfig with additional TE-related config options. @@ -101,9 +101,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults to vocab_size. Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. - add_pooling_layer: Whether the base model should include a pooling layer. Set to - ``False`` for exported checkpoints that are saved from ``NVEsmForMaskedLM`` - (which does not use a pooler). This avoids missing-weight errors in vLLM. + add_pooling_layer: Whether the base model should include a pooling layer. 
+ Defaults to ``False`` because exported checkpoints do not contain pooler + weights. Set to ``True`` only if you have a checkpoint with pooler weights. **kwargs: Additional config options to pass to EsmConfig. """ super().__init__(**kwargs) @@ -405,6 +405,7 @@ class NVEsmForMaskedLM(NVEsmPreTrainedModel): _tied_weights_keys: ClassVar[dict[str, str]] = { "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" } + _do_not_quantize = ("lm_head.dense", "lm_head.decoder") def __init__(self, config: NVEsmConfig): """Initialize a NVEsmForMaskedLM. @@ -463,7 +464,8 @@ def forward( **kwargs, ) sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) + with transformer_engine.pytorch.autocast(enabled=False): + prediction_scores = self.lm_head(sequence_output) # Truncate logits back to original vocab_size if padding was used if self.config.padded_vocab_size != self.config.vocab_size: @@ -494,15 +496,15 @@ def __init__(self, config: NVEsmConfig): config (NVEsmConfig): The configuration of the model. 
""" super().__init__() - self.dense = transformer_engine.pytorch.Linear( - config.hidden_size, - config.hidden_size, - params_dtype=config.dtype, - device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", - init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), - ) + with transformer_engine.pytorch.quantized_model_init(enabled=False): + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) - with transformer_engine.pytorch.fp8_model_init(enabled=False): self.decoder = transformer_engine.pytorch.LayerNormLinear( config.hidden_size, config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, @@ -522,7 +524,7 @@ def forward(self, features, **kwargs): """ # Keep the last layers of the network in higher precision to avoid numerical instability. # Please see recipes/fp8_analysis/README.md for more details. 
- with transformer_engine.pytorch.fp8_autocast(enabled=False): + with transformer_engine.pytorch.autocast(enabled=False): x = self.dense(features) x = torch.nn.functional.gelu(x) x = self.decoder(x) @@ -617,7 +619,11 @@ def forward( ).sum(1) mask_ratio_observed = n_masked_per_seq.float() / src_lengths scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) - reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) + if "cu_seq_lens_q_padded" in kwargs: + src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) + else: + src_lengths_padded = src_lengths + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0) embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) if self.layer_norm is not None: diff --git a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py index fb2e4136d..05e22b9c0 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py @@ -70,7 +70,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", - add_pooling_layer: bool = True, + add_pooling_layer: bool = False, **kwargs, ): """Initialize the NVEsmConfig with additional TE-related config options. @@ -101,9 +101,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults to vocab_size. Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. - add_pooling_layer: Whether the base model should include a pooling layer. Set to - ``False`` for exported checkpoints that are saved from ``NVEsmForMaskedLM`` - (which does not use a pooler). This avoids missing-weight errors in vLLM. + add_pooling_layer: Whether the base model should include a pooling layer. 
+ Defaults to ``False`` because exported checkpoints do not contain pooler + weights. Set to ``True`` only if you have a checkpoint with pooler weights. **kwargs: Additional config options to pass to EsmConfig. """ super().__init__(**kwargs) @@ -405,6 +405,7 @@ class NVEsmForMaskedLM(NVEsmPreTrainedModel): _tied_weights_keys: ClassVar[dict[str, str]] = { "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" } + _do_not_quantize = ("lm_head.dense", "lm_head.decoder") def __init__(self, config: NVEsmConfig): """Initialize a NVEsmForMaskedLM. @@ -463,7 +464,8 @@ def forward( **kwargs, ) sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) + with transformer_engine.pytorch.autocast(enabled=False): + prediction_scores = self.lm_head(sequence_output) # Truncate logits back to original vocab_size if padding was used if self.config.padded_vocab_size != self.config.vocab_size: @@ -494,15 +496,15 @@ def __init__(self, config: NVEsmConfig): config (NVEsmConfig): The configuration of the model. 
""" super().__init__() - self.dense = transformer_engine.pytorch.Linear( - config.hidden_size, - config.hidden_size, - params_dtype=config.dtype, - device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", - init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), - ) + with transformer_engine.pytorch.quantized_model_init(enabled=False): + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) - with transformer_engine.pytorch.fp8_model_init(enabled=False): self.decoder = transformer_engine.pytorch.LayerNormLinear( config.hidden_size, config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, @@ -522,7 +524,7 @@ def forward(self, features, **kwargs): """ # Keep the last layers of the network in higher precision to avoid numerical instability. # Please see recipes/fp8_analysis/README.md for more details. 
- with transformer_engine.pytorch.fp8_autocast(enabled=False): + with transformer_engine.pytorch.autocast(enabled=False): x = self.dense(features) x = torch.nn.functional.gelu(x) x = self.decoder(x) @@ -617,7 +619,11 @@ def forward( ).sum(1) mask_ratio_observed = n_masked_per_seq.float() / src_lengths scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) - reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths, dim=0) + if "cu_seq_lens_q_padded" in kwargs: + src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) + else: + src_lengths_padded = src_lengths + reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0) embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) if self.layer_norm is not None: diff --git a/bionemo-recipes/vllm/Dockerfile b/bionemo-recipes/vllm/Dockerfile deleted file mode 100644 index f8b9067d6..000000000 --- a/bionemo-recipes/vllm/Dockerfile +++ /dev/null @@ -1,36 +0,0 @@ -# FROM nvcr.io/nvidia/vllm:26.01-py3 -FROM gitlab-master.nvidia.com:5005/dl/dgx/vllm:main-py3.43005406-devel -# using this because we need vllm >= 0.14 to work with Transformers v5. no released nvidia version with this yet. - -# The vLLM image has CUDA 13.1 runtime and nvcc, but missing dev headers (cusparse.h, nvtx, etc.) 
-# Install cuda-keyring to add NVIDIA's apt repo, then install the dev headers for transformer_engine -RUN apt-get update && apt-get install -y --no-install-recommends wget && \ - wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb && \ - dpkg -i cuda-keyring_1.1-1_all.deb && \ - rm cuda-keyring_1.1-1_all.deb && \ - apt-get update && apt-get install -y --no-install-recommends \ - cuda-nvtx-13-1 \ - cuda-cupti-dev-13-1 \ - cuda-nvml-dev-13-1 \ - libcusparse-dev-13-1 \ - libcusolver-dev-13-1 \ - libcufft-dev-13-1 \ - libnvjitlink-dev-13-1 \ - libnvjpeg-dev-13-1 \ - libcublasmp0-dev-cuda-13 \ - libcudnn9-cuda-13 \ - && rm -rf /var/lib/apt/lists/* - -# Install remaining dependencies -RUN --mount=type=cache,target=/root/.cache/pip \ - --mount=type=bind,source=requirements.txt,target=/requirements.txt \ - pip install -r /requirements.txt - -# Install transformer_engine from source (force build for CUDA 13.1, not pre-built cu12 wheel) -RUN pip install --no-build-isolation transformer_engine[pytorch] - -RUN pip install transformers[torch]==5.0.0 - - -WORKDIR /workspace/bionemo -COPY . . diff --git a/bionemo-recipes/vllm/README.md b/bionemo-recipes/vllm/README.md deleted file mode 100644 index 7fcb25c22..000000000 --- a/bionemo-recipes/vllm/README.md +++ /dev/null @@ -1,23 +0,0 @@ -# vLLM inference for BioNeMo Models - -To build the image: - -```bash -docker build -t vllm . -``` - -Set `HF_TOKEN` in your environment to avoid getting rate limited. - -To launch a container: - -```bash -docker run -it --gpus all --network host --ipc=host -e HF_TOKEN --rm -v ${PWD}:/workspace/bionemo vllm /bin/bash -``` - -or use `launch.sh`. 
- -To test ESM2 inference using vLLM inside the container: - -```python -python test_esm2_golden_values.py -``` diff --git a/bionemo-recipes/vllm/launch.sh b/bionemo-recipes/vllm/launch.sh deleted file mode 100644 index e52a3ea53..000000000 --- a/bionemo-recipes/vllm/launch.sh +++ /dev/null @@ -1,50 +0,0 @@ -#!/bin/bash - -# Convenience script to launch the vLLM container with the correct mounts and flags. -# Usage: ./launch.sh [--mount_dir] [--headless] -# Example: ./launch.sh vllm --mount_dir --headless - -MOUNT_DIR=false -HEADLESS=false -CONTAINER="" - -# Parse arguments -for arg in "$@"; do - case $arg in - --mount_dir) - MOUNT_DIR=true - ;; - --headless) - HEADLESS=true - ;; - *) - # First non-flag argument is the container name - if [ -z "$CONTAINER" ]; then - CONTAINER="$arg" - fi - ;; - esac -done - -if [ -z "$CONTAINER" ]; then - echo "Usage: $0 [--mount_dir] [--headless]" - echo "Example: $0 vllm --mount_dir --headless" - exit 1 -fi - -# Build docker run command -if [ "$HEADLESS" = true ]; then - DOCKER_CMD="docker run -itd --gpus all --network host --ipc=host -e HF_TOKEN --rm --name vllm_dev" -else - DOCKER_CMD="docker run -it --gpus all --network host --ipc=host -e HF_TOKEN --rm --name vllm_dev" -fi - -if [ "$MOUNT_DIR" = true ]; then - # Mount the project root (two levels up from this script) - PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" - DOCKER_CMD="$DOCKER_CMD -v ${PROJECT_ROOT}:/workspace/bionemo-framework" -fi - -DOCKER_CMD="$DOCKER_CMD $CONTAINER /bin/bash" - -exec $DOCKER_CMD diff --git a/bionemo-recipes/vllm/test_esm2_golden_values.py b/bionemo-recipes/vllm/test_esm2_golden_values.py deleted file mode 100644 index fde9347bf..000000000 --- a/bionemo-recipes/vllm/test_esm2_golden_values.py +++ /dev/null @@ -1,210 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""End-to-end golden-value test for ESM2 vLLM compatibility. - -Performs a fresh facebook -> TE export, then cross-validates embeddings across -three backends on the same protein sequences: - -1. **vLLM** - freshly exported model loaded via ``LLM(runner="pooling")``. -2. **HF (exported)** - same exported checkpoint loaded via ``AutoModel``. -3. **HF (reference)**- nvidia Hub model loaded via ``AutoModel`` (ground truth). - -vLLM's pooling runner returns *last-token, L2-normalised* embeddings by default, -so the HuggingFace runs replicate that post-processing for an apples-to-apples comparison. -""" - -import os -import sys -from pathlib import Path - -import numpy as np -import torch -from transformers import AutoModel, AutoTokenizer -from vllm import LLM - - -# ---- Fresh export ---- -# The export script uses relative paths (modeling_esm_te.py, esm_fast_tokenizer, etc.) -# so we need to run it from the esm2 model directory. 
-ESM2_MODEL_DIR = Path(__file__).resolve().parent.parent / "models" / "esm2" -EXPORT_DIR = Path(__file__).resolve().parent / "exported_checkpoint" -EXPORT_TAG = "esm2_t6_8M_UR50D" - -sys.path.insert(0, str(ESM2_MODEL_DIR)) - - -def fresh_export() -> str: - """Run the full facebook -> TE export and return the path to the exported checkpoint.""" - from export import export_hf_checkpoint - - # export_hf_checkpoint uses relative paths, so temporarily chdir - original_cwd = os.getcwd() - os.chdir(ESM2_MODEL_DIR) - try: - EXPORT_DIR.mkdir(parents=True, exist_ok=True) - print(f"Exporting facebook/{EXPORT_TAG} -> {EXPORT_DIR / EXPORT_TAG}") - export_hf_checkpoint(EXPORT_TAG, EXPORT_DIR) - finally: - os.chdir(original_cwd) - - return str(EXPORT_DIR / EXPORT_TAG) - - -# ---- Configuration ---- -REFERENCE_MODEL_ID = "nvidia/esm2_t6_8M_UR50D" - -SEQUENCES = [ - "LKGHAMCLGCLHMLMCGLLAGAMCGLMKLLKCCGKCLMHLMKAMLGLKCACHHHHLLLHACAAKKLCLGAKLAMGLKLLGAHGKGLKMACGHHMLHLHMH", - "CLLCCMHMHAHHCHGHGHKCKCLMMGMALMCAGCCACGMKGGCHCCLLAHCAHAKAGKGKCKLMCKKKHGLHAGLHAMLLCHLGLGCGHHHKKCKKHKCA", -] - -RTOL, ATOL = 0, 0 - - -# ---- Helpers ---- - - -def last_token_l2(hidden_state: torch.Tensor) -> np.ndarray: - """Extract last-token hidden state and L2-normalise (matches vLLM pooling defaults).""" - vec = hidden_state[0, -1, :].cpu().float().numpy() - norm = np.linalg.norm(vec) - if norm > 1e-9: - vec = vec / norm - return vec - - -def hf_embed(model_id: str, sequences: list[str], dtype=torch.float32) -> np.ndarray: - """Run HuggingFace inference and return last-token L2-normalised embeddings.""" - model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to("cuda", dtype=dtype).eval() - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) - - vecs = [] - with torch.no_grad(): - for seq in sequences: - inputs = tokenizer(seq, return_tensors="pt") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - out = model(**inputs) - 
vecs.append(last_token_l2(out.last_hidden_state)) - - del model, tokenizer - torch.cuda.empty_cache() - return np.stack(vecs) - - -def vllm_embed(model_id: str, sequences: list[str]) -> np.ndarray: - """Run vLLM pooling inference and return embeddings.""" - engine = LLM( - model=model_id, - runner="pooling", - trust_remote_code=True, - dtype="float32", - enforce_eager=True, - max_num_batched_tokens=1026, - ) - outputs = engine.embed(sequences) - - vecs = [] - for output in outputs: - emb = output.outputs.embedding - if isinstance(emb, list): - emb = np.array(emb) - vecs.append(emb) - - del engine - return np.stack(vecs) - - -def max_abs_diff(a: np.ndarray, b: np.ndarray) -> float: - """Element-wise maximum absolute difference between two arrays.""" - return float(np.abs(a.astype(np.float64) - b.astype(np.float64)).max()) - - -def cosine_sim(a: np.ndarray, b: np.ndarray) -> float: - """Mean cosine similarity across rows.""" - sims = [] - for va, vb in zip(a, b): - dot = np.dot(va, vb) - na, nb = np.linalg.norm(va), np.linalg.norm(vb) - sims.append(dot / max(na * nb, 1e-12)) - return float(np.mean(sims)) - - -# ---- Main ---- - -if __name__ == "__main__": - print(f"GPUs: {torch.cuda.device_count()}") - - # Step 0: fresh export (facebook HF -> our TE format) - print("\n[0/3] Exporting checkpoint ...") - MODEL_ID = fresh_export() - - print(f"MODEL_ID: {MODEL_ID}") - print(f"REFERENCE_MODEL_ID: {REFERENCE_MODEL_ID}") - print(f"Sequences: {len(SEQUENCES)}") - - # 1) vLLM on exported model - print("\n[1/3] vLLM inference (exported model) ...") - emb_vllm = vllm_embed(MODEL_ID, SEQUENCES) - - # 2) HuggingFace on exported model - print("\n[2/3] HuggingFace inference (exported model) ...") - emb_hf_exported = hf_embed(MODEL_ID, SEQUENCES) - - # 3) HuggingFace on reference Hub model - print("\n[3/3] HuggingFace inference (reference model) ...") - emb_hf_reference = hf_embed(REFERENCE_MODEL_ID, SEQUENCES) - - # ---- Pairwise comparisons ---- - pairs = [ - ("vLLM (exported)", 
"HF (exported)", emb_vllm, emb_hf_exported), - ("vLLM (exported)", "HF (reference)", emb_vllm, emb_hf_reference), - ("HF (exported)", "HF (reference)", emb_hf_exported, emb_hf_reference), - ] - - # ---- Summary table ---- - header = f"{'Pair':<35} {'max |diff|':>14} {'mean |diff|':>14} {'cos sim':>12} {'exact':>7}" - sep = "-" * len(header) - print(f"\n{sep}") - print(header) - print(sep) - - for name_a, name_b, a, b in pairs: - diffs = np.abs(a.astype(np.float64) - b.astype(np.float64)) - label = f"{name_a} vs {name_b}" - exact = np.array_equal(a, b) - print( - f"{label:<35} {diffs.max():>14.8e} {diffs.mean():>14.8e} " - f"{cosine_sim(a, b):>12.10f} {'YES' if exact else 'NO':>7}" - ) - - print(sep) - print(f"Tolerance: rtol={RTOL}, atol={ATOL} (0 = exact match required)") - - # Per-sequence breakdown - short = {"vLLM (exported)": "vllm", "HF (exported)": "hf_exp", "HF (reference)": "hf_ref"} - print("\nPer-sequence max |diff|:") - for i in range(len(SEQUENCES)): - row = f" seq {i}:" - for name_a, name_b, a, b in pairs: - d = float(np.abs(a[i].astype(np.float64) - b[i].astype(np.float64)).max()) - row += f" {short[name_a]}_vs_{short[name_b]}={d:.8e}" - print(row) - - print(sep) - - # Cleanup - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() From 8e8a87c374001fd569963609e6ab913720cfe5c6 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Wed, 25 Feb 2026 15:58:13 +0000 Subject: [PATCH 8/9] remove unnecessary diff Signed-off-by: Gagan Kaushik --- bionemo-recipes/models/esm2/modeling_esm_te.py | 8 ++++---- .../esm2_accelerate_te/example_8m_checkpoint/esm_nv.py | 8 ++++---- .../esm2_native_te/example_8m_checkpoint/esm_nv.py | 8 ++++---- .../recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py | 8 ++++---- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/bionemo-recipes/models/esm2/modeling_esm_te.py b/bionemo-recipes/models/esm2/modeling_esm_te.py index 05e22b9c0..cbd5f6159 100644 --- 
a/bionemo-recipes/models/esm2/modeling_esm_te.py +++ b/bionemo-recipes/models/esm2/modeling_esm_te.py @@ -612,6 +612,10 @@ def forward( else: src_lengths = torch.diff(kwargs["cu_seq_lens_q"]) + if "cu_seq_lens_q_padded" in kwargs: + src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) + else: + src_lengths_padded = src_lengths # We need to find the number of masked tokens in each sequence in the padded batch. is_masked = (input_ids == self.mask_token_id).squeeze(0) n_masked_per_seq = torch.nested.nested_tensor_from_jagged( @@ -619,10 +623,6 @@ def forward( ).sum(1) mask_ratio_observed = n_masked_per_seq.float() / src_lengths scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) - if "cu_seq_lens_q_padded" in kwargs: - src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) - else: - src_lengths_padded = src_lengths reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0) embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) diff --git a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py index 05e22b9c0..cbd5f6159 100644 --- a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py @@ -612,6 +612,10 @@ def forward( else: src_lengths = torch.diff(kwargs["cu_seq_lens_q"]) + if "cu_seq_lens_q_padded" in kwargs: + src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) + else: + src_lengths_padded = src_lengths # We need to find the number of masked tokens in each sequence in the padded batch. 
is_masked = (input_ids == self.mask_token_id).squeeze(0) n_masked_per_seq = torch.nested.nested_tensor_from_jagged( @@ -619,10 +623,6 @@ def forward( ).sum(1) mask_ratio_observed = n_masked_per_seq.float() / src_lengths scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) - if "cu_seq_lens_q_padded" in kwargs: - src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) - else: - src_lengths_padded = src_lengths reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0) embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) diff --git a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py index 05e22b9c0..cbd5f6159 100644 --- a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py @@ -612,6 +612,10 @@ def forward( else: src_lengths = torch.diff(kwargs["cu_seq_lens_q"]) + if "cu_seq_lens_q_padded" in kwargs: + src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) + else: + src_lengths_padded = src_lengths # We need to find the number of masked tokens in each sequence in the padded batch. 
is_masked = (input_ids == self.mask_token_id).squeeze(0) n_masked_per_seq = torch.nested.nested_tensor_from_jagged( @@ -619,10 +623,6 @@ def forward( ).sum(1) mask_ratio_observed = n_masked_per_seq.float() / src_lengths scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) - if "cu_seq_lens_q_padded" in kwargs: - src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) - else: - src_lengths_padded = src_lengths reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0) embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) diff --git a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py index 05e22b9c0..cbd5f6159 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py @@ -612,6 +612,10 @@ def forward( else: src_lengths = torch.diff(kwargs["cu_seq_lens_q"]) + if "cu_seq_lens_q_padded" in kwargs: + src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) + else: + src_lengths_padded = src_lengths # We need to find the number of masked tokens in each sequence in the padded batch. 
is_masked = (input_ids == self.mask_token_id).squeeze(0) n_masked_per_seq = torch.nested.nested_tensor_from_jagged( @@ -619,10 +623,6 @@ def forward( ).sum(1) mask_ratio_observed = n_masked_per_seq.float() / src_lengths scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed) - if "cu_seq_lens_q_padded" in kwargs: - src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"]) - else: - src_lengths_padded = src_lengths reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0) embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype) From 36cdbb2798de109593bf8bcbdc5e3b7351edfe43 Mon Sep 17 00:00:00 2001 From: Gagan Kaushik Date: Wed, 25 Feb 2026 16:33:58 +0000 Subject: [PATCH 9/9] fix recipes Signed-off-by: Gagan Kaushik --- .../recipes/esm2_native_te/tests/test_stop_and_go.py | 6 ++++-- bionemo-recipes/recipes/esm2_native_te/train_ddp.py | 3 ++- bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py | 3 ++- bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py | 5 ++--- bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py | 6 +++--- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py index 271e25723..bbb9fa01d 100644 --- a/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py @@ -74,8 +74,9 @@ def test_stop_and_go_checkpointing_and_dataloader_restoration_single_gpu(tmp_pat # The huggingface model has a contact head that we don't use in masked language pre-training, so we delete it to # avoid errors with unused parameters. 
+ base = model.model if hasattr(model, "model") else model.esm try: - del model.model.contact_head + del base.contact_head except AttributeError: pass @@ -156,8 +157,9 @@ def test_stop_and_go_checkpointing_and_dataloader_restoration_single_gpu(tmp_pat config = AutoConfig.from_pretrained("example_8m_checkpoint", trust_remote_code=True, dtype=torch.bfloat16) resumed_model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + resumed_base = resumed_model.model if hasattr(resumed_model, "model") else resumed_model.esm try: - del resumed_model.model.contact_head + del resumed_base.contact_head except AttributeError: pass diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py index 3ad704916..2148d235f 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py @@ -84,8 +84,9 @@ def main(args: DictConfig) -> float | None: # The huggingface model has a contact head that we don't use in masked language pre-training, so we delete it to # avoid errors with unused parameters. 
+ base = model.model if hasattr(model, "model") else model.esm try: - del model.model.contact_head + del base.contact_head except AttributeError: pass diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py index c7150f677..fef63aa98 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py @@ -111,8 +111,9 @@ def main(args: DictConfig) -> float | None: process_group=group_fsdp_cp, ) + base = model.module.model if hasattr(model.module, "model") else model.module.esm if args.cp_size > 1: - for i, transformer_layer in enumerate(model.module.model.encoder.layers): + for i, transformer_layer in enumerate(base.encoder.layers): logger.debug(f"Rank {dist_config.rank}: Setting CP group for layer {i}") transformer_layer.set_context_parallel_group( device_mesh["cp"].get_group(), diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 7a6917de5..f9e2a1d46 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -96,9 +96,8 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models. 
- transformer_stack = ( - model.model.encoder.layers if hasattr(model.model.encoder, "layers") else model.model.encoder.layer - ) + base = model.model if hasattr(model, "model") else model.esm + transformer_stack = base.encoder.layers if hasattr(base.encoder, "layers") else base.encoder.layer mp_policy = MixedPrecisionPolicy( param_dtype=torch.bfloat16 diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py index d268441bf..d41fa509e 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py @@ -112,10 +112,10 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) + # TE models use `model.model`, facebook HF models use `model.esm`. + base = model.model if hasattr(model, "model") else model.esm # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models. - transformer_stack = ( - model.model.encoder.layers if hasattr(model.model.encoder, "layers") else model.model.encoder.layer - ) + transformer_stack = base.encoder.layers if hasattr(base.encoder, "layers") else base.encoder.layer # Fully shard takes in a DeviceMesh object, which is a 2D mesh of dimensions (CP_dimension, DP_dimension). # FSDP2 will shard the model across the DP (dim=1) dimension and then duplicate across the CP (dim=0) dimension. for layer in transformer_stack: