diff --git a/bionemo-recipes/models/esm2/README.md b/bionemo-recipes/models/esm2/README.md index a0f8effbb..2f7a917dd 100644 --- a/bionemo-recipes/models/esm2/README.md +++ b/bionemo-recipes/models/esm2/README.md @@ -60,6 +60,31 @@ inputs = tokenizer(gfp_P42212, return_tensors="pt") output = model(**inputs) ``` +### Running inference with vLLM + +The exported TE checkpoints on HuggingFace Hub are directly compatible with +[vLLM](https://github.com/vllm-project/vllm) (>= 0.14) as pooling/embedding models. +No conversion scripts or weight renaming are needed: + +```python +from vllm import LLM + +model = LLM( + model="nvidia/esm2_t6_8M_UR50D", + runner="pooling", + trust_remote_code=True, + enforce_eager=True, + max_num_batched_tokens=1026, +) + +prompts = ["MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLK"] +outputs = model.embed(prompts) +print(outputs[0].outputs.embedding[:5]) +``` + +See [tests/test_vllm.py](tests/test_vllm.py) for a full golden-value validation across +vLLM, native HuggingFace, and the nvidia Hub reference model. + ## Recipe Links Training recipes are available in the `bionemo-recipes/recipes/` directory: diff --git a/bionemo-recipes/models/esm2/convert.py b/bionemo-recipes/models/esm2/convert.py index 08a1729ff..69f55de9e 100644 --- a/bionemo-recipes/models/esm2/convert.py +++ b/bionemo-recipes/models/esm2/convert.py @@ -23,18 +23,18 @@ mapping = { - "esm.encoder.layer.*.attention.output.dense.weight": "esm.encoder.layers.*.self_attention.proj.weight", - "esm.encoder.layer.*.attention.output.dense.bias": "esm.encoder.layers.*.self_attention.proj.bias", - "esm.encoder.layer.*.attention.LayerNorm.weight": "esm.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_weight", - "esm.encoder.layer.*.attention.LayerNorm.bias": "esm.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_bias", - "esm.encoder.layer.*.intermediate.dense.weight": "esm.encoder.layers.*.layernorm_mlp.fc1_weight", - "esm.encoder.layer.*.intermediate.dense.bias": "esm.encoder.layers.*.layernorm_mlp.fc1_bias", - "esm.encoder.layer.*.output.dense.weight": "esm.encoder.layers.*.layernorm_mlp.fc2_weight", - "esm.encoder.layer.*.output.dense.bias": "esm.encoder.layers.*.layernorm_mlp.fc2_bias", - "esm.encoder.layer.*.LayerNorm.weight": "esm.encoder.layers.*.layernorm_mlp.layer_norm_weight", - "esm.encoder.layer.*.LayerNorm.bias": "esm.encoder.layers.*.layernorm_mlp.layer_norm_bias", - "esm.encoder.emb_layer_norm_after.weight": "esm.encoder.emb_layer_norm_after.weight", - "esm.encoder.emb_layer_norm_after.bias": "esm.encoder.emb_layer_norm_after.bias", + "esm.encoder.layer.*.attention.output.dense.weight": "model.encoder.layers.*.self_attention.proj.weight", + "esm.encoder.layer.*.attention.output.dense.bias": "model.encoder.layers.*.self_attention.proj.bias", + "esm.encoder.layer.*.attention.LayerNorm.weight": "model.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_weight", + "esm.encoder.layer.*.attention.LayerNorm.bias": "model.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_bias", + "esm.encoder.layer.*.intermediate.dense.weight": "model.encoder.layers.*.layernorm_mlp.fc1_weight", + "esm.encoder.layer.*.intermediate.dense.bias": "model.encoder.layers.*.layernorm_mlp.fc1_bias", + "esm.encoder.layer.*.output.dense.weight": "model.encoder.layers.*.layernorm_mlp.fc2_weight", + "esm.encoder.layer.*.output.dense.bias": "model.encoder.layers.*.layernorm_mlp.fc2_bias", + "esm.encoder.layer.*.LayerNorm.weight": "model.encoder.layers.*.layernorm_mlp.layer_norm_weight", + "esm.encoder.layer.*.LayerNorm.bias": "model.encoder.layers.*.layernorm_mlp.layer_norm_bias", + "esm.encoder.emb_layer_norm_after.weight": "model.encoder.emb_layer_norm_after.weight", + "esm.encoder.emb_layer_norm_after.bias": "model.encoder.emb_layer_norm_after.bias", "lm_head.dense.weight": "lm_head.dense.weight", "lm_head.dense.bias": "lm_head.dense.bias", "lm_head.layer_norm.weight": "lm_head.decoder.layer_norm_weight", @@ -146,7 +146,7 @@ def convert_esm_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module: "esm.encoder.layer.*.attention.self.key.weight", "esm.encoder.layer.*.attention.self.value.weight", ), - target_key="esm.encoder.layers.*.self_attention.layernorm_qkv.weight", + target_key="model.encoder.layers.*.self_attention.layernorm_qkv.weight", ) def _pack_qkv_weight(ctx: io.TransformCTX, query, key, value): """Pad the embedding layer to the new input dimension.""" @@ -168,7 +168,7 @@ def _pack_qkv_weight(ctx: io.TransformCTX, query, key, value): "esm.encoder.layer.*.attention.self.key.bias", "esm.encoder.layer.*.attention.self.value.bias", ), - target_key="esm.encoder.layers.*.self_attention.layernorm_qkv.bias", + target_key="model.encoder.layers.*.self_attention.layernorm_qkv.bias", ) def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value): """Pad the embedding layer to the new input dimension.""" @@ -185,7 +185,7 @@ def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value): @io.state_transform( - source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.weight", + source_key="model.encoder.layers.*.self_attention.layernorm_qkv.weight", target_key=( "esm.encoder.layer.*.attention.self.query.weight", "esm.encoder.layer.*.attention.self.key.weight", @@ -214,7 +214,7 @@ def _unpack_qkv_weight(ctx: io.TransformCTX, qkv_weight): @io.state_transform( - source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.bias", + source_key="model.encoder.layers.*.self_attention.layernorm_qkv.bias", target_key=( "esm.encoder.layer.*.attention.self.query.bias", "esm.encoder.layer.*.attention.self.key.bias", @@ -259,7 +259,7 @@ def _pad_weights(ctx: io.TransformCTX, source_embed): _pad_embeddings = io.state_transform( source_key="esm.embeddings.word_embeddings.weight", - target_key="esm.embeddings.word_embeddings.weight", + target_key="model.embeddings.word_embeddings.weight", )(_pad_weights) _pad_decoder_weights = io.state_transform( @@ -268,7 +268,7 @@ def _pad_weights(ctx: io.TransformCTX, source_embed): )(_pad_weights) _unpad_embeddings = io.state_transform( - source_key="esm.embeddings.word_embeddings.weight", + source_key="model.embeddings.word_embeddings.weight", target_key="esm.embeddings.word_embeddings.weight", )(_unpad_weights) diff --git a/bionemo-recipes/models/esm2/export.py b/bionemo-recipes/models/esm2/export.py index 12f13e45a..748e46784 100644 --- a/bionemo-recipes/models/esm2/export.py +++ b/bionemo-recipes/models/esm2/export.py @@ -71,7 +71,11 @@ def export_hf_checkpoint(tag: str, export_path: Path): model_hf_masked_lm = AutoModelForMaskedLM.from_pretrained(f"facebook/{tag}") model_hf = AutoModel.from_pretrained(f"facebook/{tag}") model_hf_masked_lm.esm.pooler = model_hf.pooler - model_te = convert_esm_hf_to_te(model_hf_masked_lm) + + # Export without vocab padding so the checkpoint stores embeddings at the real + # vocab_size. This avoids shape-mismatch errors in vLLM's VocabParallelEmbedding, + # which expects vocab_size-shaped weights. + model_te = convert_esm_hf_to_te(model_hf_masked_lm, padded_vocab_size=None) model_te.save_pretrained(export_path / tag) tokenizer = AutoTokenizer.from_pretrained("esm_fast_tokenizer") # Use our PreTrainedTokenizerFast implementation. diff --git a/bionemo-recipes/models/esm2/modeling_esm_te.py b/bionemo-recipes/models/esm2/modeling_esm_te.py index d4ee0845e..cbd5f6159 100644 --- a/bionemo-recipes/models/esm2/modeling_esm_te.py +++ b/bionemo-recipes/models/esm2/modeling_esm_te.py @@ -70,6 +70,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", + add_pooling_layer: bool = False, **kwargs, ): """Initialize the NVEsmConfig with additional TE-related config options. @@ -100,6 +101,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults to vocab_size. Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. + add_pooling_layer: Whether the base model should include a pooling layer. + Defaults to ``False`` because exported checkpoints do not contain pooler + weights. Set to ``True`` only if you have a checkpoint with pooler weights. **kwargs: Additional config options to pass to EsmConfig. """ super().__init__(**kwargs) @@ -111,6 +115,7 @@ def __init__( self.micro_batch_size = micro_batch_size self.max_seq_length = max_seq_length self.attn_mask_type = attn_mask_type + self.add_pooling_layer = add_pooling_layer # Set padded_vocab_size with default fallback to vocab_size self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size @@ -231,7 +236,7 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel): """An abstract class to handle weights initialization and pretrained model loading.""" config_class = NVEsmConfig - base_model_prefix = "esm" + base_model_prefix = "model" supports_gradient_checkpointing = False accepts_loss_kwargs = False _no_split_modules = ( @@ -247,11 +252,11 @@ def init_empty_weights(self): if hasattr(module, "reset_parameters"): module.reset_parameters() - # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # The embeddings layer is the only non-TE layer in this model we need to deal with. We use # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard - # deviation. - self.esm.embeddings.word_embeddings.to_empty(device="cuda") - self.esm.embeddings.apply(self._init_weights) + # deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel. + self.base_model.embeddings.word_embeddings.to_empty(device="cuda") + self.base_model.embeddings.apply(self._init_weights) # Meta-device init seems to break weight tying, so we re-tie the weights here. self.tie_weights() @@ -276,14 +281,16 @@ def _init_weights(self, module): super()._init_weights(module) def state_dict(self, *args, **kwargs): - """Override state_dict to filter out TransformerEngine's _extra_state keys. + """Override state_dict to filter out non-loadable keys. - TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. - These are filtered out to ensure checkpoints can be loaded with from_pretrained(). + Filters out: + - ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5. + - ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed + in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates + over ``named_parameters``, not ``named_buffers``). """ state_dict = super().state_dict(*args, **kwargs) - # Filter out _extra_state keys which are TransformerEngine-specific and not loadable - return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")} class NVEsmModel(NVEsmPreTrainedModel): @@ -292,16 +299,20 @@ class NVEsmModel(NVEsmPreTrainedModel): This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. """ - def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None): """Initialize a NVEsmModel. Args: config (NVEsmConfig): The configuration of the model. - add_pooling_layer (bool): Whether to add a pooling layer. + add_pooling_layer (bool): Whether to add a pooling layer. If ``None``, + reads ``config.add_pooling_layer`` (defaults to ``True``). """ super().__init__(config) self.config = config + if add_pooling_layer is None: + add_pooling_layer = getattr(config, "add_pooling_layer", True) + # Ensure pad_token_id is set properly, defaulting to 0 if not specified if not hasattr(config, "pad_token_id") or config.pad_token_id is None: config.pad_token_id = 0 @@ -391,8 +402,10 @@ def forward( class NVEsmForMaskedLM(NVEsmPreTrainedModel): """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" - _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} - _do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized. + _tied_weights_keys: ClassVar[dict[str, str]] = { + "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" + } + _do_not_quantize = ("lm_head.dense", "lm_head.decoder") def __init__(self, config: NVEsmConfig): """Initialize a NVEsmForMaskedLM. @@ -408,7 +421,7 @@ def __init__(self, config: NVEsmConfig): "bi-directional self-attention." ) - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.lm_head = NVEsmLMHead(config) self.post_init() @@ -443,7 +456,7 @@ def forward( Returns: MaskedLMOutput: The output of the model. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -633,7 +646,7 @@ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = transformer_engine.pytorch.Linear( config.hidden_size, @@ -659,7 +672,7 @@ def forward( labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/bionemo-recipes/models/esm2/requirements.txt b/bionemo-recipes/models/esm2/requirements.txt index b4c358680..9f2bb0e09 100644 --- a/bionemo-recipes/models/esm2/requirements.txt +++ b/bionemo-recipes/models/esm2/requirements.txt @@ -9,3 +9,4 @@ torch torchao!=0.14.0 transformer_engine[pytorch] transformers +vllm diff --git a/bionemo-recipes/models/esm2/tests/test_cp_bshd.py b/bionemo-recipes/models/esm2/tests/test_cp_bshd.py index 2af776b88..5e9d7f96b 100644 --- a/bionemo-recipes/models/esm2/tests/test_cp_bshd.py +++ b/bionemo-recipes/models/esm2/tests/test_cp_bshd.py @@ -209,8 +209,8 @@ def test_context_parallel_equivalence_2process(): # Sample gradients from a few layers for comparison sample_layers = [ - model.esm.encoder.layers[0].self_attention.core_attention, - model.esm.encoder.layers[0].self_attention.layernorm_qkv, + model.model.encoder.layers[0].self_attention.core_attention, + model.model.encoder.layers[0].self_attention.layernorm_qkv, ] # Now grab the gradients from the sample layers @@ -262,7 +262,7 @@ def test_context_parallel_equivalence_2process(): cp_world_size = torch.distributed.get_world_size(group=cp_group) # Set up context parallelism for each layer - for i, transformer_layer in enumerate(model.module.esm.encoder.layers): + for i, transformer_layer in enumerate(model.module.model.encoder.layers): transformer_layer.set_context_parallel_group( cp_group, torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), torch.cuda.Stream() ) @@ -347,8 +347,8 @@ def test_context_parallel_equivalence_2process(): # Capture gradients from the same layers in the CP model # Note: DDP wraps the model with 'module.' prefix sample_layers_cp = [ - model.module.esm.encoder.layers[0].self_attention.core_attention, - model.module.esm.encoder.layers[0].self_attention.layernorm_qkv, + model.module.model.encoder.layers[0].self_attention.core_attention, + model.module.model.encoder.layers[0].self_attention.layernorm_qkv, ] gradients_cp = {} diff --git a/bionemo-recipes/models/esm2/tests/test_cp_thd.py b/bionemo-recipes/models/esm2/tests/test_cp_thd.py index c17618ba9..98dd62b74 100644 --- a/bionemo-recipes/models/esm2/tests/test_cp_thd.py +++ b/bionemo-recipes/models/esm2/tests/test_cp_thd.py @@ -200,8 +200,8 @@ def test_context_parallel_equivalence_2process(): # Sample gradients from a few layers for comparison sample_layers = [ - model.esm.encoder.layers[0].self_attention.core_attention, - model.esm.encoder.layers[0].self_attention.layernorm_qkv, + model.model.encoder.layers[0].self_attention.core_attention, + model.model.encoder.layers[0].self_attention.layernorm_qkv, ] # Now grab the gradients from the sample layers @@ -253,7 +253,7 @@ def test_context_parallel_equivalence_2process(): cp_world_size = torch.distributed.get_world_size(group=cp_group) # Set up context parallelism for each layer - for i, transformer_layer in enumerate(model.module.esm.encoder.layers): + for i, transformer_layer in enumerate(model.module.model.encoder.layers): transformer_layer.set_context_parallel_group( cp_group, torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), torch.cuda.Stream() ) @@ -344,8 +344,8 @@ def test_context_parallel_equivalence_2process(): # Capture gradients from the same layers in the CP model # Note: DDP wraps the model with 'module.' prefix sample_layers_cp = [ - model.module.esm.encoder.layers[0].self_attention.core_attention, - model.module.esm.encoder.layers[0].self_attention.layernorm_qkv, + model.module.model.encoder.layers[0].self_attention.core_attention, + model.module.model.encoder.layers[0].self_attention.layernorm_qkv, ] gradients_cp = {} diff --git a/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py b/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py index 57fd3f9fb..1018875e7 100644 --- a/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py +++ b/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py @@ -165,7 +165,7 @@ def is_main_process(self) -> bool: model = NVEsmForMaskedLM(config) if args.strategy is Strategy.FSDP2: - for layer in model.esm.encoder.layers: + for layer in model.model.encoder.layers: fully_shard(layer, mesh=device_mesh["dp"]) fully_shard(model, mesh=device_mesh["dp"]) model.to(device) diff --git a/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py b/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py index 8347f0f1e..ab96c6590 100644 --- a/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py +++ b/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py @@ -193,7 +193,7 @@ def run_forward_backward(use_te: bool, strategy: Strategy, input_data: dict, dis revision="c731040f", ) model = NVEsmForMaskedLM(config) - transformer_layers = model.esm.encoder.layers + transformer_layers = model.model.encoder.layers else: model = AutoModelForMaskedLM.from_pretrained( "facebook/esm2_t6_8M_UR50D", diff --git a/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py b/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py index 18e21ffcd..2099636f8 100644 --- a/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py +++ b/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py @@ -81,7 +81,7 @@ def get_upstream_model_class(self) -> Type[PreTrainedModel]: def get_layer_path(self, model: PreTrainedModel) -> List[nn.Module]: """Return the list of transformer layers.""" - return list(model.esm.encoder.layers) # type: ignore + return list(model.model.encoder.layers) # type: ignore def get_reference_model_no_weights(self) -> PreTrainedModel: """For checkpoint conversion tests to pass, we need to remove the unused contact head.""" @@ -195,7 +195,7 @@ def test_convert_state_dict_explicit_check(self): # Check packed QKV weights for i in range(model_hf.config.num_hidden_layers): - k = f"esm.encoder.layers.{i}.self_attention.layernorm_qkv.weight" + k = f"model.encoder.layers.{i}.self_attention.layernorm_qkv.weight" v = [ f"esm.encoder.layer.{i}.attention.self.query.weight", f"esm.encoder.layer.{i}.attention.self.key.weight", @@ -217,7 +217,7 @@ def test_convert_state_dict_explicit_check(self): # Check packed QKV biases for i in range(model_hf.config.num_hidden_layers): - k = f"esm.encoder.layers.{i}.self_attention.layernorm_qkv.bias" + k = f"model.encoder.layers.{i}.self_attention.layernorm_qkv.bias" v = [ f"esm.encoder.layer.{i}.attention.self.query.bias", f"esm.encoder.layer.{i}.attention.self.key.bias", @@ -243,7 +243,7 @@ def test_convert_state_dict_explicit_check(self): torch.testing.assert_close( _pad_weights(ctx_mock, model_hf.state_dict()["esm.embeddings.word_embeddings.weight"]), - model_te.state_dict()["esm.embeddings.word_embeddings.weight"], + model_te.state_dict()["model.embeddings.word_embeddings.weight"], ) torch.testing.assert_close( _pad_weights(ctx_mock, model_hf.state_dict()["lm_head.decoder.weight"]), @@ -254,7 +254,7 @@ def test_convert_state_dict_explicit_check(self): model_te.state_dict()["lm_head.decoder.bias"], ) - te_state_dict_keys.remove("esm.embeddings.word_embeddings.weight") + te_state_dict_keys.remove("model.embeddings.word_embeddings.weight") te_state_dict_keys.remove("lm_head.decoder.weight") te_state_dict_keys.remove("lm_head.decoder.bias") @@ -267,6 +267,6 @@ def test_convert_state_dict_explicit_check(self): ) assert ( - model_te.state_dict()["esm.embeddings.word_embeddings.weight"].data_ptr() + model_te.state_dict()["model.embeddings.word_embeddings.weight"].data_ptr() == model_te.state_dict()["lm_head.decoder.weight"].data_ptr() ) diff --git a/bionemo-recipes/models/esm2/tests/test_vllm.py b/bionemo-recipes/models/esm2/tests/test_vllm.py new file mode 100644 index 000000000..78d03b869 --- /dev/null +++ b/bionemo-recipes/models/esm2/tests/test_vllm.py @@ -0,0 +1,146 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Golden-value tests for ESM2 vLLM compatibility. + +Performs a fresh facebook -> TE export, then cross-validates embeddings across +vLLM, HuggingFace (exported checkpoint), and HuggingFace (nvidia Hub reference). + +vLLM's pooling runner returns last-token, L2-normalised embeddings by default, +so the HuggingFace paths replicate that post-processing for comparison. +""" + +import os +from pathlib import Path + +import numpy as np +import pytest +import torch +from transformers import AutoModel, AutoTokenizer +from vllm import LLM + +from export import export_hf_checkpoint + + +EXPORT_TAG = "esm2_t6_8M_UR50D" +REFERENCE_MODEL_ID = "nvidia/esm2_t6_8M_UR50D" +ESM2_MODEL_DIR = Path(__file__).resolve().parent.parent + +SEQUENCES = [ + "LKGHAMCLGCLHMLMCGLLAGAMCGLMKLLKCCGKCLMHLMKAMLGLKCACHHHHLLLHACAAKKLCLGAKLAMGLKLLGAHGKGLKMACGHHMLHLHMH", + "CLLCCMHMHAHHCHGHGHKCKCLMMGMALMCAGCCACGMKGGCHCCLLAHCAHAKAGKGKCKLMCKKKHGLHAGLHAMLLCHLGLGCGHHHKKCKKHKCA", +] + + +def _last_token_l2(hidden_state: torch.Tensor) -> np.ndarray: + """Extract last-token hidden state and L2-normalise (matches vLLM pooling defaults).""" + vec = hidden_state[0, -1, :].cpu().float().numpy() + norm = np.linalg.norm(vec) + if norm > 1e-9: + vec = vec / norm + return vec + + +def _hf_embed(model_id: str, sequences: list[str], dtype=torch.float32) -> np.ndarray: + """Run HuggingFace inference and return last-token L2-normalised embeddings.""" + model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to("cuda", dtype=dtype).eval() + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + + vecs = [] + with torch.no_grad(): + for seq in sequences: + inputs = tokenizer(seq, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + out = model(**inputs) + vecs.append(_last_token_l2(out.last_hidden_state)) + + del model, tokenizer + torch.cuda.empty_cache() + return np.stack(vecs) + + +def _vllm_embed(model_id: str, sequences: list[str]) -> np.ndarray: + """Run vLLM pooling inference and return embeddings.""" + engine = LLM( + model=model_id, + runner="pooling", + trust_remote_code=True, + dtype="float32", + enforce_eager=True, + max_num_batched_tokens=1026, + ) + outputs = engine.embed(sequences) + + vecs = [] + for output in outputs: + emb = output.outputs.embedding + if isinstance(emb, list): + emb = np.array(emb) + vecs.append(emb) + + del engine + return np.stack(vecs) + + +# ---- Fixtures ---- + + +@pytest.fixture(scope="session") +def exported_checkpoint(tmp_path_factory): + """Fresh facebook -> TE export. Session-scoped so it runs once.""" + export_dir = tmp_path_factory.mktemp("vllm_export") + original_cwd = os.getcwd() + os.chdir(ESM2_MODEL_DIR) + try: + export_hf_checkpoint(EXPORT_TAG, export_dir) + finally: + os.chdir(original_cwd) + return str(export_dir / EXPORT_TAG) + + +@pytest.fixture(scope="session") +def vllm_embeddings(exported_checkpoint): + """Embeddings from vLLM pooling runner on the exported checkpoint.""" + return _vllm_embed(exported_checkpoint, SEQUENCES) + + +@pytest.fixture(scope="session") +def hf_exported_embeddings(exported_checkpoint): + """Embeddings from HuggingFace on the exported checkpoint.""" + return _hf_embed(exported_checkpoint, SEQUENCES) + + +@pytest.fixture(scope="session") +def hf_reference_embeddings(): + """Embeddings from HuggingFace on the nvidia Hub model (ground truth).""" + return _hf_embed(REFERENCE_MODEL_ID, SEQUENCES) + + +# ---- Tests ---- + + +def test_vllm_vs_hf_exported(vllm_embeddings, hf_exported_embeddings): + """vLLM and native HuggingFace on the same exported checkpoint must match.""" + np.testing.assert_array_equal(vllm_embeddings, hf_exported_embeddings) + + +def test_vllm_vs_hf_reference(vllm_embeddings, hf_reference_embeddings): + """vLLM on exported checkpoint must match HuggingFace on nvidia Hub model.""" + np.testing.assert_array_equal(vllm_embeddings, hf_reference_embeddings) + + +def test_hf_exported_vs_hf_reference(hf_exported_embeddings, hf_reference_embeddings): + """Our exported checkpoint must produce identical results to the nvidia Hub model.""" + np.testing.assert_array_equal(hf_exported_embeddings, hf_reference_embeddings) diff --git a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py index d4ee0845e..cbd5f6159 100644 --- a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py @@ -70,6 +70,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", + add_pooling_layer: bool = False, **kwargs, ): """Initialize the NVEsmConfig with additional TE-related config options. @@ -100,6 +101,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults to vocab_size. Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. + add_pooling_layer: Whether the base model should include a pooling layer. + Defaults to ``False`` because exported checkpoints do not contain pooler + weights. Set to ``True`` only if you have a checkpoint with pooler weights. **kwargs: Additional config options to pass to EsmConfig. """ super().__init__(**kwargs) @@ -111,6 +115,7 @@ def __init__( self.micro_batch_size = micro_batch_size self.max_seq_length = max_seq_length self.attn_mask_type = attn_mask_type + self.add_pooling_layer = add_pooling_layer # Set padded_vocab_size with default fallback to vocab_size self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size @@ -231,7 +236,7 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel): """An abstract class to handle weights initialization and pretrained model loading.""" config_class = NVEsmConfig - base_model_prefix = "esm" + base_model_prefix = "model" supports_gradient_checkpointing = False accepts_loss_kwargs = False _no_split_modules = ( @@ -247,11 +252,11 @@ def init_empty_weights(self): if hasattr(module, "reset_parameters"): module.reset_parameters() - # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # The embeddings layer is the only non-TE layer in this model we need to deal with. We use # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard - # deviation. - self.esm.embeddings.word_embeddings.to_empty(device="cuda") - self.esm.embeddings.apply(self._init_weights) + # deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel. + self.base_model.embeddings.word_embeddings.to_empty(device="cuda") + self.base_model.embeddings.apply(self._init_weights) # Meta-device init seems to break weight tying, so we re-tie the weights here. self.tie_weights() @@ -276,14 +281,16 @@ def _init_weights(self, module): super()._init_weights(module) def state_dict(self, *args, **kwargs): - """Override state_dict to filter out TransformerEngine's _extra_state keys. + """Override state_dict to filter out non-loadable keys. - TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. - These are filtered out to ensure checkpoints can be loaded with from_pretrained(). + Filters out: + - ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5. + - ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed + in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates + over ``named_parameters``, not ``named_buffers``). """ state_dict = super().state_dict(*args, **kwargs) - # Filter out _extra_state keys which are TransformerEngine-specific and not loadable - return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")} class NVEsmModel(NVEsmPreTrainedModel): @@ -292,16 +299,20 @@ class NVEsmModel(NVEsmPreTrainedModel): This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. """ - def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None): """Initialize a NVEsmModel. Args: config (NVEsmConfig): The configuration of the model. - add_pooling_layer (bool): Whether to add a pooling layer. + add_pooling_layer (bool): Whether to add a pooling layer. If ``None``, + reads ``config.add_pooling_layer`` (defaults to ``True``). """ super().__init__(config) self.config = config + if add_pooling_layer is None: + add_pooling_layer = getattr(config, "add_pooling_layer", True) + # Ensure pad_token_id is set properly, defaulting to 0 if not specified if not hasattr(config, "pad_token_id") or config.pad_token_id is None: config.pad_token_id = 0 @@ -391,8 +402,10 @@ def forward( class NVEsmForMaskedLM(NVEsmPreTrainedModel): """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" - _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} - _do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized. + _tied_weights_keys: ClassVar[dict[str, str]] = { + "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" + } + _do_not_quantize = ("lm_head.dense", "lm_head.decoder") def __init__(self, config: NVEsmConfig): """Initialize a NVEsmForMaskedLM. @@ -408,7 +421,7 @@ def __init__(self, config: NVEsmConfig): "bi-directional self-attention." ) - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.lm_head = NVEsmLMHead(config) self.post_init() @@ -443,7 +456,7 @@ def forward( Returns: MaskedLMOutput: The output of the model. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -633,7 +646,7 @@ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = transformer_engine.pytorch.Linear( config.hidden_size, @@ -659,7 +672,7 @@ def forward( labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py index d4ee0845e..cbd5f6159 100644 --- a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py @@ -70,6 +70,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", + add_pooling_layer: bool = False, **kwargs, ): """Initialize the NVEsmConfig with additional TE-related config options. @@ -100,6 +101,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults to vocab_size. Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. + add_pooling_layer: Whether the base model should include a pooling layer. + Defaults to ``False`` because exported checkpoints do not contain pooler + weights. Set to ``True`` only if you have a checkpoint with pooler weights. **kwargs: Additional config options to pass to EsmConfig. """ super().__init__(**kwargs) @@ -111,6 +115,7 @@ def __init__( self.micro_batch_size = micro_batch_size self.max_seq_length = max_seq_length self.attn_mask_type = attn_mask_type + self.add_pooling_layer = add_pooling_layer # Set padded_vocab_size with default fallback to vocab_size self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size @@ -231,7 +236,7 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel): """An abstract class to handle weights initialization and pretrained model loading.""" config_class = NVEsmConfig - base_model_prefix = "esm" + base_model_prefix = "model" supports_gradient_checkpointing = False accepts_loss_kwargs = False _no_split_modules = ( @@ -247,11 +252,11 @@ def init_empty_weights(self): if hasattr(module, "reset_parameters"): module.reset_parameters() - # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # The embeddings layer is the only non-TE layer in this model we need to deal with. We use # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard - # deviation. - self.esm.embeddings.word_embeddings.to_empty(device="cuda") - self.esm.embeddings.apply(self._init_weights) + # deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel. + self.base_model.embeddings.word_embeddings.to_empty(device="cuda") + self.base_model.embeddings.apply(self._init_weights) # Meta-device init seems to break weight tying, so we re-tie the weights here. self.tie_weights() @@ -276,14 +281,16 @@ def _init_weights(self, module): super()._init_weights(module) def state_dict(self, *args, **kwargs): - """Override state_dict to filter out TransformerEngine's _extra_state keys. + """Override state_dict to filter out non-loadable keys. - TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. - These are filtered out to ensure checkpoints can be loaded with from_pretrained(). + Filters out: + - ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5. + - ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed + in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates + over ``named_parameters``, not ``named_buffers``). """ state_dict = super().state_dict(*args, **kwargs) - # Filter out _extra_state keys which are TransformerEngine-specific and not loadable - return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")} class NVEsmModel(NVEsmPreTrainedModel): @@ -292,16 +299,20 @@ class NVEsmModel(NVEsmPreTrainedModel): This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. """ - def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None): """Initialize a NVEsmModel. Args: config (NVEsmConfig): The configuration of the model. - add_pooling_layer (bool): Whether to add a pooling layer. + add_pooling_layer (bool): Whether to add a pooling layer. If ``None``, + reads ``config.add_pooling_layer`` (defaults to ``True``). """ super().__init__(config) self.config = config + if add_pooling_layer is None: + add_pooling_layer = getattr(config, "add_pooling_layer", True) + # Ensure pad_token_id is set properly, defaulting to 0 if not specified if not hasattr(config, "pad_token_id") or config.pad_token_id is None: config.pad_token_id = 0 @@ -391,8 +402,10 @@ def forward( class NVEsmForMaskedLM(NVEsmPreTrainedModel): """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" - _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} - _do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized. + _tied_weights_keys: ClassVar[dict[str, str]] = { + "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" + } + _do_not_quantize = ("lm_head.dense", "lm_head.decoder") def __init__(self, config: NVEsmConfig): """Initialize a NVEsmForMaskedLM. @@ -408,7 +421,7 @@ def __init__(self, config: NVEsmConfig): "bi-directional self-attention." ) - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.lm_head = NVEsmLMHead(config) self.post_init() @@ -443,7 +456,7 @@ def forward( Returns: MaskedLMOutput: The output of the model. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -633,7 +646,7 @@ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = transformer_engine.pytorch.Linear( config.hidden_size, @@ -659,7 +672,7 @@ def forward( labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py index a8b8afc6a..bbb9fa01d 100644 --- a/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py @@ -74,8 +74,9 @@ def test_stop_and_go_checkpointing_and_dataloader_restoration_single_gpu(tmp_pat # The huggingface model has a contact head that we don't use in masked language pre-training, so we delete it to # avoid errors with unused parameters. + base = model.model if hasattr(model, "model") else model.esm try: - del model.esm.contact_head + del base.contact_head except AttributeError: pass @@ -156,8 +157,9 @@ def test_stop_and_go_checkpointing_and_dataloader_restoration_single_gpu(tmp_pat config = AutoConfig.from_pretrained("example_8m_checkpoint", trust_remote_code=True, dtype=torch.bfloat16) resumed_model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + resumed_base = resumed_model.model if hasattr(resumed_model, "model") else resumed_model.esm try: - del resumed_model.esm.contact_head + del resumed_base.contact_head except AttributeError: pass diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py index 168d25b57..2148d235f 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py @@ -84,8 +84,9 @@ def main(args: DictConfig) -> float | None: # The huggingface model has a contact head that we don't use in masked language pre-training, so we delete it to # avoid errors with unused parameters. + base = model.model if hasattr(model, "model") else model.esm try: - del model.esm.contact_head + del base.contact_head except AttributeError: pass diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py index c5a8dad34..fef63aa98 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py @@ -111,8 +111,9 @@ def main(args: DictConfig) -> float | None: process_group=group_fsdp_cp, ) + base = model.module.model if hasattr(model.module, "model") else model.module.esm if args.cp_size > 1: - for i, transformer_layer in enumerate(model.module.esm.encoder.layers): + for i, transformer_layer in enumerate(base.encoder.layers): logger.debug(f"Rank {dist_config.rank}: Setting CP group for layer {i}") transformer_layer.set_context_parallel_group( device_mesh["cp"].get_group(), diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index c74a5ad6c..f9e2a1d46 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -96,7 +96,8 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models. - transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer + base = model.model if hasattr(model, "model") else model.esm + transformer_stack = base.encoder.layers if hasattr(base.encoder, "layers") else base.encoder.layer mp_policy = MixedPrecisionPolicy( param_dtype=torch.bfloat16 diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py index 6a824cc9f..d41fa509e 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py @@ -112,8 +112,10 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) + # TE models use `model.model`, facebook HF models use `model.esm`. + base = model.model if hasattr(model, "model") else model.esm # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models. - transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer + transformer_stack = base.encoder.layers if hasattr(base.encoder, "layers") else base.encoder.layer # Fully shard takes in a DeviceMesh object, which is a 2D mesh of dimensions (CP_dimension, DP_dimension). # FSDP2 will shard the model across the DP (dim=1) dimension and then duplicate across the CP (dim=0) dimension. for layer in transformer_stack: diff --git a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py index d4ee0845e..cbd5f6159 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py @@ -70,6 +70,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", + add_pooling_layer: bool = False, **kwargs, ): """Initialize the NVEsmConfig with additional TE-related config options. @@ -100,6 +101,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults to vocab_size. Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. + add_pooling_layer: Whether the base model should include a pooling layer. + Defaults to ``False`` because exported checkpoints do not contain pooler + weights. Set to ``True`` only if you have a checkpoint with pooler weights. **kwargs: Additional config options to pass to EsmConfig. """ super().__init__(**kwargs) @@ -111,6 +115,7 @@ def __init__( self.micro_batch_size = micro_batch_size self.max_seq_length = max_seq_length self.attn_mask_type = attn_mask_type + self.add_pooling_layer = add_pooling_layer # Set padded_vocab_size with default fallback to vocab_size self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size @@ -231,7 +236,7 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel): """An abstract class to handle weights initialization and pretrained model loading.""" config_class = NVEsmConfig - base_model_prefix = "esm" + base_model_prefix = "model" supports_gradient_checkpointing = False accepts_loss_kwargs = False _no_split_modules = ( @@ -247,11 +252,11 @@ def init_empty_weights(self): if hasattr(module, "reset_parameters"): module.reset_parameters() - # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # The embeddings layer is the only non-TE layer in this model we need to deal with. We use # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard - # deviation. - self.esm.embeddings.word_embeddings.to_empty(device="cuda") - self.esm.embeddings.apply(self._init_weights) + # deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel. + self.base_model.embeddings.word_embeddings.to_empty(device="cuda") + self.base_model.embeddings.apply(self._init_weights) # Meta-device init seems to break weight tying, so we re-tie the weights here. self.tie_weights() @@ -276,14 +281,16 @@ def _init_weights(self, module): super()._init_weights(module) def state_dict(self, *args, **kwargs): - """Override state_dict to filter out TransformerEngine's _extra_state keys. + """Override state_dict to filter out non-loadable keys. - TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. - These are filtered out to ensure checkpoints can be loaded with from_pretrained(). + Filters out: + - ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5. + - ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed + in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates + over ``named_parameters``, not ``named_buffers``). """ state_dict = super().state_dict(*args, **kwargs) - # Filter out _extra_state keys which are TransformerEngine-specific and not loadable - return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")} class NVEsmModel(NVEsmPreTrainedModel): @@ -292,16 +299,20 @@ class NVEsmModel(NVEsmPreTrainedModel): This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. """ - def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None): """Initialize a NVEsmModel. Args: config (NVEsmConfig): The configuration of the model. - add_pooling_layer (bool): Whether to add a pooling layer. + add_pooling_layer (bool): Whether to add a pooling layer. If ``None``, + reads ``config.add_pooling_layer`` (defaults to ``True``). """ super().__init__(config) self.config = config + if add_pooling_layer is None: + add_pooling_layer = getattr(config, "add_pooling_layer", True) + # Ensure pad_token_id is set properly, defaulting to 0 if not specified if not hasattr(config, "pad_token_id") or config.pad_token_id is None: config.pad_token_id = 0 @@ -391,8 +402,10 @@ def forward( class NVEsmForMaskedLM(NVEsmPreTrainedModel): """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" - _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} - _do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized. + _tied_weights_keys: ClassVar[dict[str, str]] = { + "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" + } + _do_not_quantize = ("lm_head.dense", "lm_head.decoder") def __init__(self, config: NVEsmConfig): """Initialize a NVEsmForMaskedLM. @@ -408,7 +421,7 @@ def __init__(self, config: NVEsmConfig): "bi-directional self-attention." ) - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.lm_head = NVEsmLMHead(config) self.post_init() @@ -443,7 +456,7 @@ def forward( Returns: MaskedLMOutput: The output of the model. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -633,7 +646,7 @@ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = transformer_engine.pytorch.Linear( config.hidden_size, @@ -659,7 +672,7 @@ def forward( labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids,