25 changes: 25 additions & 0 deletions bionemo-recipes/models/esm2/README.md
@@ -60,6 +60,31 @@ inputs = tokenizer(gfp_P42212, return_tensors="pt")
output = model(**inputs)
```

### Running inference with vLLM

The exported TE checkpoints on HuggingFace Hub are directly compatible with
[vLLM](https://github.com/vllm-project/vllm) (>= 0.14) as pooling/embedding models.
No conversion scripts or weight renaming are needed:

```python
from vllm import LLM

model = LLM(
model="nvidia/esm2_t6_8M_UR50D",
runner="pooling",
trust_remote_code=True,
enforce_eager=True,
max_num_batched_tokens=1026,
)

prompts = ["MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLK"]
outputs = model.embed(prompts)
print(outputs[0].outputs.embedding[:5])
```
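
The `enforce_eager` and `max_num_batched_tokens` values above are conservative starting points rather than requirements; `1026` appears to mirror ESM-2's maximum input length (1024 residues plus the two special tokens), and both settings can be tuned for throughput once the basic setup works.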

See [tests/test_vllm.py](tests/test_vllm.py) for a full golden-value validation across
vLLM, native HuggingFace, and the nvidia Hub reference model.
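
For a quick sanity check outside the test suite, the vLLM embedding can be compared against the same checkpoint loaded through `transformers`. The sketch below is illustrative rather than canonical: it assumes the exported checkpoint exposes per-token hidden states via `AutoModel` with `trust_remote_code=True`, and it uses mean pooling as a stand-in for whatever pooling vLLM applies (see `tests/test_vllm.py` for the exact configuration used in the golden-value comparison).

```python
import torch
from transformers import AutoModel, AutoTokenizer
from vllm import LLM

checkpoint = "nvidia/esm2_t6_8M_UR50D"
sequence = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLK"

# Pooled embedding from vLLM (same call pattern as the example above).
llm = LLM(model=checkpoint, runner="pooling", trust_remote_code=True, enforce_eager=True)
vllm_embedding = torch.tensor(llm.embed([sequence])[0].outputs.embedding)

# Per-token hidden states from the native HF/TE model loaded from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True).cuda().eval()
inputs = tokenizer(sequence, return_tensors="pt").to("cuda")
with torch.no_grad():
    token_states = model(**inputs).last_hidden_state[0]

# Mean pooling is an assumption here; substitute whatever pooling tests/test_vllm.py uses.
hf_embedding = token_states.float().mean(dim=0).cpu()
print(torch.nn.functional.cosine_similarity(vllm_embedding, hf_embedding, dim=0))
```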

## Recipe Links

Training recipes are available in the `bionemo-recipes/recipes/` directory:
36 changes: 18 additions & 18 deletions bionemo-recipes/models/esm2/convert.py
@@ -23,18 +23,18 @@


mapping = {
"esm.encoder.layer.*.attention.output.dense.weight": "esm.encoder.layers.*.self_attention.proj.weight",
"esm.encoder.layer.*.attention.output.dense.bias": "esm.encoder.layers.*.self_attention.proj.bias",
"esm.encoder.layer.*.attention.LayerNorm.weight": "esm.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
"esm.encoder.layer.*.attention.LayerNorm.bias": "esm.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_bias",
"esm.encoder.layer.*.intermediate.dense.weight": "esm.encoder.layers.*.layernorm_mlp.fc1_weight",
"esm.encoder.layer.*.intermediate.dense.bias": "esm.encoder.layers.*.layernorm_mlp.fc1_bias",
"esm.encoder.layer.*.output.dense.weight": "esm.encoder.layers.*.layernorm_mlp.fc2_weight",
"esm.encoder.layer.*.output.dense.bias": "esm.encoder.layers.*.layernorm_mlp.fc2_bias",
"esm.encoder.layer.*.LayerNorm.weight": "esm.encoder.layers.*.layernorm_mlp.layer_norm_weight",
"esm.encoder.layer.*.LayerNorm.bias": "esm.encoder.layers.*.layernorm_mlp.layer_norm_bias",
"esm.encoder.emb_layer_norm_after.weight": "esm.encoder.emb_layer_norm_after.weight",
"esm.encoder.emb_layer_norm_after.bias": "esm.encoder.emb_layer_norm_after.bias",
"esm.encoder.layer.*.attention.output.dense.weight": "model.encoder.layers.*.self_attention.proj.weight",
"esm.encoder.layer.*.attention.output.dense.bias": "model.encoder.layers.*.self_attention.proj.bias",
"esm.encoder.layer.*.attention.LayerNorm.weight": "model.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
"esm.encoder.layer.*.attention.LayerNorm.bias": "model.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_bias",
"esm.encoder.layer.*.intermediate.dense.weight": "model.encoder.layers.*.layernorm_mlp.fc1_weight",
"esm.encoder.layer.*.intermediate.dense.bias": "model.encoder.layers.*.layernorm_mlp.fc1_bias",
"esm.encoder.layer.*.output.dense.weight": "model.encoder.layers.*.layernorm_mlp.fc2_weight",
"esm.encoder.layer.*.output.dense.bias": "model.encoder.layers.*.layernorm_mlp.fc2_bias",
"esm.encoder.layer.*.LayerNorm.weight": "model.encoder.layers.*.layernorm_mlp.layer_norm_weight",
"esm.encoder.layer.*.LayerNorm.bias": "model.encoder.layers.*.layernorm_mlp.layer_norm_bias",
"esm.encoder.emb_layer_norm_after.weight": "model.encoder.emb_layer_norm_after.weight",
"esm.encoder.emb_layer_norm_after.bias": "model.encoder.emb_layer_norm_after.bias",
"lm_head.dense.weight": "lm_head.dense.weight",
"lm_head.dense.bias": "lm_head.dense.bias",
"lm_head.layer_norm.weight": "lm_head.decoder.layer_norm_weight",
@@ -146,7 +146,7 @@ def convert_esm_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module:
"esm.encoder.layer.*.attention.self.key.weight",
"esm.encoder.layer.*.attention.self.value.weight",
),
target_key="esm.encoder.layers.*.self_attention.layernorm_qkv.weight",
target_key="model.encoder.layers.*.self_attention.layernorm_qkv.weight",
)
def _pack_qkv_weight(ctx: io.TransformCTX, query, key, value):
"""Pad the embedding layer to the new input dimension."""
@@ -168,7 +168,7 @@ def _pack_qkv_weight(ctx: io.TransformCTX, query, key, value):
"esm.encoder.layer.*.attention.self.key.bias",
"esm.encoder.layer.*.attention.self.value.bias",
),
target_key="esm.encoder.layers.*.self_attention.layernorm_qkv.bias",
target_key="model.encoder.layers.*.self_attention.layernorm_qkv.bias",
)
def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value):
"""Pad the embedding layer to the new input dimension."""
@@ -185,7 +185,7 @@ def _pack_qkv_bias(ctx: io.TransformCTX, query, key, value):


@io.state_transform(
source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.weight",
source_key="model.encoder.layers.*.self_attention.layernorm_qkv.weight",
target_key=(
"esm.encoder.layer.*.attention.self.query.weight",
"esm.encoder.layer.*.attention.self.key.weight",
@@ -214,7 +214,7 @@ def _unpack_qkv_weight(ctx: io.TransformCTX, qkv_weight):


@io.state_transform(
source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.bias",
source_key="model.encoder.layers.*.self_attention.layernorm_qkv.bias",
target_key=(
"esm.encoder.layer.*.attention.self.query.bias",
"esm.encoder.layer.*.attention.self.key.bias",
@@ -259,7 +259,7 @@ def _pad_weights(ctx: io.TransformCTX, source_embed):

_pad_embeddings = io.state_transform(
source_key="esm.embeddings.word_embeddings.weight",
target_key="esm.embeddings.word_embeddings.weight",
target_key="model.embeddings.word_embeddings.weight",
)(_pad_weights)

_pad_decoder_weights = io.state_transform(
@@ -268,7 +268,7 @@ def _pad_weights(ctx: io.TransformCTX, source_embed):
)(_pad_weights)

_unpad_embeddings = io.state_transform(
source_key="esm.embeddings.word_embeddings.weight",
source_key="model.embeddings.word_embeddings.weight",
target_key="esm.embeddings.word_embeddings.weight",
)(_unpad_weights)

6 changes: 5 additions & 1 deletion bionemo-recipes/models/esm2/export.py
@@ -71,7 +71,11 @@ def export_hf_checkpoint(tag: str, export_path: Path):
model_hf_masked_lm = AutoModelForMaskedLM.from_pretrained(f"facebook/{tag}")
model_hf = AutoModel.from_pretrained(f"facebook/{tag}")
model_hf_masked_lm.esm.pooler = model_hf.pooler
model_te = convert_esm_hf_to_te(model_hf_masked_lm)

# Export without vocab padding so the checkpoint stores embeddings at the real
# vocab_size. This avoids shape-mismatch errors in vLLM's VocabParallelEmbedding,
# which expects vocab_size-shaped weights.
model_te = convert_esm_hf_to_te(model_hf_masked_lm, padded_vocab_size=None)
model_te.save_pretrained(export_path / tag)

tokenizer = AutoTokenizer.from_pretrained("esm_fast_tokenizer") # Use our PreTrainedTokenizerFast implementation.
49 changes: 31 additions & 18 deletions bionemo-recipes/models/esm2/modeling_esm_te.py
@@ -70,6 +70,7 @@ def __init__(
max_seq_length: Optional[int] = None,
padded_vocab_size: Optional[int] = 64,
attn_mask_type: str = "padding",
add_pooling_layer: bool = False,
**kwargs,
):
"""Initialize the NVEsmConfig with additional TE-related config options.
@@ -100,6 +101,9 @@
padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults
to vocab_size. Must be greater than or equal to vocab_size.
attn_mask_type: The type of attention mask to use.
add_pooling_layer: Whether the base model should include a pooling layer.
Defaults to ``False`` because exported checkpoints do not contain pooler
weights. Set to ``True`` only if you have a checkpoint with pooler weights.
**kwargs: Additional config options to pass to EsmConfig.
"""
super().__init__(**kwargs)
@@ -111,6 +115,7 @@
self.micro_batch_size = micro_batch_size
self.max_seq_length = max_seq_length
self.attn_mask_type = attn_mask_type
self.add_pooling_layer = add_pooling_layer

# Set padded_vocab_size with default fallback to vocab_size
self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size
@@ -231,7 +236,7 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel):
"""An abstract class to handle weights initialization and pretrained model loading."""

config_class = NVEsmConfig
base_model_prefix = "esm"
base_model_prefix = "model"
supports_gradient_checkpointing = False
accepts_loss_kwargs = False
_no_split_modules = (
@@ -247,11 +252,11 @@ def init_empty_weights(self):
if hasattr(module, "reset_parameters"):
module.reset_parameters()

# The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use
# The embeddings layer is the only non-TE layer in this model we need to deal with. We use
# `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard
# deviation.
self.esm.embeddings.word_embeddings.to_empty(device="cuda")
self.esm.embeddings.apply(self._init_weights)
# deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel.
self.base_model.embeddings.word_embeddings.to_empty(device="cuda")
self.base_model.embeddings.apply(self._init_weights)

# Meta-device init seems to break weight tying, so we re-tie the weights here.
self.tie_weights()
@@ -276,14 +281,16 @@ def _init_weights(self, module):
super()._init_weights(module)

def state_dict(self, *args, **kwargs):
"""Override state_dict to filter out TransformerEngine's _extra_state keys.
"""Override state_dict to filter out non-loadable keys.

TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading.
These are filtered out to ensure checkpoints can be loaded with from_pretrained().
Filters out:
- ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5.
- ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed
in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates
over ``named_parameters``, not ``named_buffers``).
"""
state_dict = super().state_dict(*args, **kwargs)
# Filter out _extra_state keys which are TransformerEngine-specific and not loadable
return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}
return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")}


class NVEsmModel(NVEsmPreTrainedModel):
Expand All @@ -292,16 +299,20 @@ class NVEsmModel(NVEsmPreTrainedModel):
This model uses NVIDIA's TransformerEngine to optimize attention layer training and inference.
"""

def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True):
def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None):
"""Initialize a NVEsmModel.

Args:
config (NVEsmConfig): The configuration of the model.
add_pooling_layer (bool): Whether to add a pooling layer.
add_pooling_layer (bool): Whether to add a pooling layer. If ``None``,
reads ``config.add_pooling_layer`` (defaults to ``True``).
"""
super().__init__(config)
self.config = config

if add_pooling_layer is None:
add_pooling_layer = getattr(config, "add_pooling_layer", True)

# Ensure pad_token_id is set properly, defaulting to 0 if not specified
if not hasattr(config, "pad_token_id") or config.pad_token_id is None:
config.pad_token_id = 0
@@ -391,8 +402,10 @@ def forward(
class NVEsmForMaskedLM(NVEsmPreTrainedModel):
"""NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling."""

_tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"}
_do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized.
> Collaborator comment: you're deleting _do_not_quantize? we need that
_tied_weights_keys: ClassVar[dict[str, str]] = {
"lm_head.decoder.weight": "model.embeddings.word_embeddings.weight"
}
_do_not_quantize = ("lm_head.dense", "lm_head.decoder")

def __init__(self, config: NVEsmConfig):
"""Initialize a NVEsmForMaskedLM.
Expand All @@ -408,7 +421,7 @@ def __init__(self, config: NVEsmConfig):
"bi-directional self-attention."
)

self.esm = NVEsmModel(config, add_pooling_layer=False)
self.model = NVEsmModel(config, add_pooling_layer=False)
self.lm_head = NVEsmLMHead(config)

self.post_init()
@@ -443,7 +456,7 @@ def forward(
Returns:
MaskedLMOutput: The output of the model.
"""
outputs = self.esm(
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -633,7 +646,7 @@ def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels

self.esm = NVEsmModel(config, add_pooling_layer=False)
self.model = NVEsmModel(config, add_pooling_layer=False)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = transformer_engine.pytorch.Linear(
config.hidden_size,
@@ -659,7 +672,7 @@ def forward(
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
outputs = self.esm(
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
1 change: 1 addition & 0 deletions bionemo-recipes/models/esm2/requirements.txt
@@ -9,3 +9,4 @@ torch
torchao!=0.14.0
transformer_engine[pytorch]
transformers
vllm
10 changes: 5 additions & 5 deletions bionemo-recipes/models/esm2/tests/test_cp_bshd.py
@@ -209,8 +209,8 @@ def test_context_parallel_equivalence_2process():

# Sample gradients from a few layers for comparison
sample_layers = [
model.esm.encoder.layers[0].self_attention.core_attention,
model.esm.encoder.layers[0].self_attention.layernorm_qkv,
model.model.encoder.layers[0].self_attention.core_attention,
> Collaborator comment: why the rename here from esm -> models?
model.model.encoder.layers[0].self_attention.layernorm_qkv,
]

# Now grab the gradients from the sample layers
@@ -262,7 +262,7 @@ def test_context_parallel_equivalence_2process():
cp_world_size = torch.distributed.get_world_size(group=cp_group)

# Set up context parallelism for each layer
for i, transformer_layer in enumerate(model.module.esm.encoder.layers):
for i, transformer_layer in enumerate(model.module.model.encoder.layers):
transformer_layer.set_context_parallel_group(
cp_group, torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), torch.cuda.Stream()
)
@@ -347,8 +347,8 @@ def test_context_parallel_equivalence_2process():
# Capture gradients from the same layers in the CP model
# Note: DDP wraps the model with 'module.' prefix
sample_layers_cp = [
model.module.esm.encoder.layers[0].self_attention.core_attention,
model.module.esm.encoder.layers[0].self_attention.layernorm_qkv,
model.module.model.encoder.layers[0].self_attention.core_attention,
model.module.model.encoder.layers[0].self_attention.layernorm_qkv,
]

gradients_cp = {}
10 changes: 5 additions & 5 deletions bionemo-recipes/models/esm2/tests/test_cp_thd.py
@@ -200,8 +200,8 @@ def test_context_parallel_equivalence_2process():

# Sample gradients from a few layers for comparison
sample_layers = [
model.esm.encoder.layers[0].self_attention.core_attention,
model.esm.encoder.layers[0].self_attention.layernorm_qkv,
model.model.encoder.layers[0].self_attention.core_attention,
model.model.encoder.layers[0].self_attention.layernorm_qkv,
]

# Now grab the gradients from the sample layers
@@ -253,7 +253,7 @@ def test_context_parallel_equivalence_2process():
cp_world_size = torch.distributed.get_world_size(group=cp_group)

# Set up context parallelism for each layer
for i, transformer_layer in enumerate(model.module.esm.encoder.layers):
for i, transformer_layer in enumerate(model.module.model.encoder.layers):
transformer_layer.set_context_parallel_group(
cp_group, torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), torch.cuda.Stream()
)
@@ -344,8 +344,8 @@ def test_context_parallel_equivalence_2process():
# Capture gradients from the same layers in the CP model
# Note: DDP wraps the model with 'module.' prefix
sample_layers_cp = [
model.module.esm.encoder.layers[0].self_attention.core_attention,
model.module.esm.encoder.layers[0].self_attention.layernorm_qkv,
model.module.model.encoder.layers[0].self_attention.core_attention,
model.module.model.encoder.layers[0].self_attention.layernorm_qkv,
]

gradients_cp = {}
@@ -165,7 +165,7 @@ def is_main_process(self) -> bool:
model = NVEsmForMaskedLM(config)

if args.strategy is Strategy.FSDP2:
for layer in model.esm.encoder.layers:
for layer in model.model.encoder.layers:
fully_shard(layer, mesh=device_mesh["dp"])
fully_shard(model, mesh=device_mesh["dp"])
model.to(device)
@@ -193,7 +193,7 @@ def run_forward_backward(use_te: bool, strategy: Strategy, input_data: dict, dis
revision="c731040f",
)
model = NVEsmForMaskedLM(config)
transformer_layers = model.esm.encoder.layers
transformer_layers = model.model.encoder.layers
else:
model = AutoModelForMaskedLM.from_pretrained(
"facebook/esm2_t6_8M_UR50D",
Expand Down