# NOTE(review): SOURCE is a git patch (new files .ruff.toml + collator.py) whose newlines were
# mangled; reconstructed here as the embedded collator.py content. The one-line .ruff.toml
# ('extend = "../.ruff.toml"') and the patch metadata lines are dropped from this reconstruction.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data collator for THD input format tests.

This should eventually get moved to a separate package, or possibly upstreamed into `transformers`.
"""

import logging
import threading
from dataclasses import dataclass, field
from typing import Any, TypedDict

import datasets
import torch
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp
from transformers import DataCollator, DataCollatorForLanguageModeling


logger = logging.getLogger(__name__)


@dataclass
class DataCollatorWithFlattening:
    """Flatten variable-length sequences into one packed (THD) batch with MLM masking.

    A captive ``DataCollatorForLanguageModeling`` performs the masking; this wrapper then packs
    the masked sequences back-to-back (no inter-sequence padding) and emits the flash-attention
    metadata (``cu_seq_lens_q`` / ``cu_seq_lens_k`` and ``max_length_q`` / ``max_length_k``)
    needed to recover sequence boundaries during attention. Optionally, the total token count can
    be padded to a multiple of ``pad_to_multiple_of`` (by appending a mock sequence), or each
    individual sequence can be padded to a multiple of ``pad_sequences_to_be_divisible_by`` for
    context parallelism. Only PyTorch tensors (``return_tensors="pt"``) are supported.

    Args:
        collator: The collator used for MLM masking. This is a captive collator and should be
            constructed externally and passed in.
        return_position_ids: Whether to return position ids (default False).
        pad_to_multiple_of: If set, pads the total sequence length to be divisible by this number.
        pad_sequences_to_be_divisible_by: If set, each individual sequence is padded to a multiple
            of this value. Mutually exclusive with ``pad_to_multiple_of``.
        separator_id: A label written at each sequence boundary, typically -100 for causal LM.

    Example:
        >>> from transformers import AutoTokenizer, DataCollatorForLanguageModeling
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
        >>> mlm_collator = DataCollatorForLanguageModeling(tokenizer)
        >>> flat_collator = DataCollatorWithFlattening(collator=mlm_collator, pad_to_multiple_of=8)
        >>> sequences = [
        ...     {"input_ids": [0, 5, 6, 7, 2]},  # 5 tokens
        ...     {"input_ids": [0, 8, 9, 10, 11, 2]},  # 6 tokens
        ...     {"input_ids": [0, 12, 13, 2]},  # 4 tokens
        ... ]  # Total: 15 tokens
        >>> batch = flat_collator(sequences)
        >>> print(batch["input_ids"].shape)  # torch.Size([1, 16])
        >>> print(batch["cu_seq_lens_q"])  # tensor([0, 5, 11, 15, 16], dtype=torch.int32)
    """

    collator: DataCollatorForLanguageModeling
    return_position_ids: bool = False
    pad_to_multiple_of: int | None = None
    pad_sequences_to_be_divisible_by: int | None = None
    separator_id: int | None = None

    def __post_init__(self):
        """Reject the mutually exclusive padding options."""
        if self.pad_sequences_to_be_divisible_by is not None and self.pad_to_multiple_of is not None:
            raise ValueError("pad_sequences_to_be_divisible_by and pad_to_multiple_of cannot be used together")

    def __call__(self, features, return_tensors=None):
        """Collate ``features`` into a packed THD batch with MLM masking applied.

        Args:
            features: List of tokenized samples, each containing ``input_ids`` and optionally
                ``attention_mask`` / ``labels``.
            return_tensors: Tensor format; only PyTorch ("pt") output is produced.

        Returns:
            Dict with ``input_ids`` / ``labels`` of shape ``[1, total_tokens]``, cumulative
            sequence lengths ``cu_seq_lens_q`` / ``cu_seq_lens_k`` (int32), ``max_length_q`` /
            ``max_length_k``, plus padding-related fields when a padding option is enabled.
        """
        # Let the captive collator perform MLM masking in the padded (BSHD) layout.
        mlm_batch = self.collator(features)

        # Flatten the raw features to derive cu_seq_lens and the rest of the THD metadata.
        thd_batch = _pt_flatten_collate(features, return_position_ids=self.return_position_ids)

        # Strip the BSHD padding back out, keeping only the real tokens in packed order.
        keep = mlm_batch["attention_mask"].bool()
        flat_input_ids = mlm_batch["input_ids"][keep].unsqueeze(0)
        flat_labels = mlm_batch["labels"][keep].unsqueeze(0)

        if self.separator_id is not None:
            # Overwrite the label at the first token of every non-initial sequence.
            flat_labels[:, thd_batch["cu_seq_lens_q"][1:-1]] = self.separator_id

        thd_batch["input_ids"] = flat_input_ids
        thd_batch["labels"] = flat_labels

        if self.pad_to_multiple_of is not None:
            thd_batch = self._pad_batch_to_multiple_of(thd_batch)
        elif self.pad_sequences_to_be_divisible_by is not None:
            thd_batch = self._pad_sequences_to_be_divisible_by(thd_batch)

        return thd_batch

    def _pad_batch_to_multiple_of(self, batch):
        """Append a mock sequence so the total token count divides pad_to_multiple_of."""
        # Fall back to 1 when the tokenizer exposes no usable integer pad token id.
        pad_token_id = self.collator.tokenizer.pad_token_id
        if not isinstance(pad_token_id, int):
            logger.warning(f"tokenizer.pad_token_id is not an integer, using 1 instead: {pad_token_id}")
            pad_token_id = 1

        assert self.pad_to_multiple_of is not None, "pad_to_multiple_of must be set"

        return _pt_pad_to_multiple_of(
            batch,
            self.pad_to_multiple_of,
            token_pad=pad_token_id,
            label_pad=-100,
        )

    def _pad_sequences_to_be_divisible_by(self, batch):
        """Pad each individual sequence for context parallelism, recording cu_seq_lens_*_padded."""
        pad_token_id = self.collator.tokenizer.pad_token_id
        if not isinstance(pad_token_id, int):
            logger.warning(f"tokenizer.pad_token_id is not an integer, using 1 instead: {pad_token_id}")
            pad_token_id = 1

        assert self.pad_sequences_to_be_divisible_by is not None, "pad_sequences_to_be_divisible_by must be set"

        input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp(
            batch["input_ids"],
            batch["labels"],
            batch["cu_seq_lens_q"],
            self.pad_sequences_to_be_divisible_by,
            padding_token_id=pad_token_id,
            padding_label_id=-100,
        )

        batch["input_ids"] = input_ids_padded.unsqueeze(0)
        batch["labels"] = labels_padded.unsqueeze(0)
        batch["cu_seq_lens_q_padded"] = cu_seqlens_padded.to(torch.int32)
        batch["cu_seq_lens_k_padded"] = cu_seqlens_padded.to(torch.int32)
        batch["pad_between_seqs"] = True
        return batch
+ """ + samples = [] + current_length = 0 + for sample in iter(self.dataset): + current_length += len(sample["input_ids"]) + if current_length == self.max_tokens_per_batch: + yield [*samples, sample] + samples = [] + current_length = 0 + + elif current_length > self.max_tokens_per_batch: + if not self.split_samples: + # If we are not splitting samples, we can just yield the current batch (before this sample) and + # start a new one. + yield samples + samples = [sample] + + else: + # Calculate how many tokens are already in the batch + tokens_in_batch = current_length - len(sample["input_ids"]) + # Calculate how many tokens we can fit from this sample + tokens_available = self.max_tokens_per_batch - tokens_in_batch + first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available) + yield [*samples, first_part] + samples = [remaining_part] + + current_length = len(samples[0]["input_ids"]) + else: + samples.append(sample) + + if not self.drop_last and samples: + yield samples + + def set_epoch(self, epoch: int): + """Set the epoch for the dataset.""" + self.dataset.set_epoch(epoch) + + +@dataclass +class DataCollatorForContextParallel: + """A collator that is aware of context parallelism. + + For the case of context parallelism, padded sequences will be returned from the wrapped collator, and then split + into shards for each context parallelism rank. + + The shards are then typically sent to the ContextParallelDataLoaderWrapper which will scatter them to the + appropriate GPUs. + + Note: + When used with the ContextParallelDataLoaderWrapper and both context parallelism and tensor parallelism are + used, the collator inspects the ordering of the mesh dimensions to determine the layout of the flattened batch. + + If "cp" comes before "tp" in the mesh dimension names (CP row-major), the flattened batch will be: + [(cp0, tp0), (cp0, tp1), ..., (cp1, tp0), (cp1, tp1), ...] 
+ + If "tp" comes before "cp" (TP row-major), the flattened batch will be: + [(tp0, cp0), (tp0, cp1), ..., (tp1, cp0), (tp1, cp1), ...] + + Args: + collator: The collator to use for the batch. + device_mesh: The device mesh with named dimensions. Must contain either a "cp" dimension for context parallelism + and/or a "tp" dimension for tensor parallelism. + qkv_format: The format of the query-key-value (QKV) tensor. + is_causal_lm: Whether the collator is for a causal language model. If True, the labels will be shifted before + being split into CP shards, and will be returned in the `shift_labels` field. + + """ + + collator: DataCollator + device_mesh: torch.distributed.device_mesh.DeviceMesh + qkv_format: str = "thd" + is_causal_lm: bool = False + + # Derived fields, initialized in __post_init__. + cp_world_size: int = field(init=False) + tp_world_size: int | None = field(init=False) + _is_cp_row_major: bool = field(init=False) + + def __post_init__(self): + """Initialize the cp_world_size, tp_world_size, and _is_cp_row_major fields based on the device mesh.""" + dim_names = self.device_mesh.mesh_dim_names + if dim_names is None: + raise ValueError("device_mesh must have mesh_dim_names") + + self.cp_world_size = self.device_mesh.size(dim_names.index("cp")) if "cp" in dim_names else 1 + self.tp_world_size = self.device_mesh.size(dim_names.index("tp")) if "tp" in dim_names else None + + # Determine whether CP is the row (outer) dimension of the 2D mesh. + # When flattened, the row-major dimension's index changes slowest. + # If "cp" comes before "tp" in mesh_dim_names, CP is the row dimension. + if "cp" in dim_names and "tp" in dim_names: + self._is_cp_row_major = dim_names.index("cp") < dim_names.index("tp") + else: + self._is_cp_row_major = True + + def __call__(self, features) -> list[dict[str, Any]]: + """Process batches of data and create shards for each context parallelism rank. 
+ + Args: + features: List of tokenized sequences, each containing 'input_ids' and optionally 'labels'. + + Returns: + A list of dictionaries, each containing a shard of the batch for a given context parallelism rank. + """ + batch = self.collator(features) + + if self.is_causal_lm: + labels = torch.nn.functional.pad(batch["labels"], (0, 1), value=-100) + batch["labels"] = labels[..., 1:].contiguous() + + combined_batch = [] + for cp_rank in range(self.cp_world_size): + input_ids_sharded, labels_sharded = _split_batch_by_cp_rank( + cu_seqlens_padded=batch.get("cu_seq_lens_q_padded", None), # This will be None for BSHD format. + input_ids_padded=batch["input_ids"], + labels_padded=batch["labels"], + qvk_format=self.qkv_format, + cp_rank=cp_rank, + cp_world_size=self.cp_world_size, + ) + batch_shard = dict(batch) + batch_shard["input_ids"] = input_ids_sharded + if self.is_causal_lm: + batch_shard["shift_labels"] = labels_sharded + batch_shard["labels"] = None + else: + batch_shard["labels"] = labels_sharded + # Now determine the max length of the sequence. + if self.qkv_format == "thd": + seqlens_q = batch_shard["cu_seq_lens_q_padded"][1:] - batch_shard["cu_seq_lens_q_padded"][:-1] + max_length = seqlens_q.max().item() + elif self.qkv_format == "bshd": + max_length = batch["input_ids"].shape[1] + # For BSHD context parallelism, we can't handle padding, so we remove the attention mask. + del batch_shard["attention_mask"] + else: + raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!") + + batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length * round(max_length / 64) + combined_batch.append(batch_shard) + + if self.tp_world_size is not None: + # Replicate each CP shard for TP ranks. The ordering depends on which dimension forms the rows in the + # flattened mesh. 
+ if self._is_cp_row_major: + # Flattened mesh: [(cp0,tp0), (cp0,tp1), (cp1,tp0), (cp1,tp1)] + # Output: [cp0, cp0, cp1, cp1] + combined_batch = [batch for batch in combined_batch for _ in range(self.tp_world_size)] + else: + # Flattened mesh: [(tp0,cp0), (tp0,cp1), (tp1,cp0), (tp1,cp1)] + # Output: [cp0, cp1, cp0, cp1] + combined_batch = [ + combined_batch[cp_rank] for _ in range(self.tp_world_size) for cp_rank in range(self.cp_world_size) + ] + + return combined_batch + + +class ContextParallelDataLoaderWrapper: + """A dataloader that is aware of context and tensor parallelism.""" + + def __init__( + self, + dataloader: torch.utils.data.DataLoader | None, + cp_tp_mesh: torch.distributed.device_mesh.DeviceMesh, + ): + """A dataloader wrapper that distributes the data across the context and tensor parallelism groups. + + This class materializes a single dataloader for each data parallel mesh rank, and splits / replicates the data + from this dataloader across the context and tensor parallelism groups. + + Args: + dataloader: The dataloader to use. + cp_tp_mesh: The context parallel mesh, or a flattened, combined context parallel and tensor parallel mesh. + If a flattened mesh is provided, the cp / tp dimensions should be in the order they appeared in the + mesh_dim_names as passed to DataCollatorForContextParallel. 
+ """ + if cp_tp_mesh.get_local_rank() == 0: + assert dataloader is not None, "dataloader must be provided on rank 0" + self.dataloader = dataloader + + else: + assert dataloader is None, "Dataloader on non-rank 0 will not be used" + + self.cp_tp_rank = cp_tp_mesh.get_local_rank() + self.cp_tp_group = cp_tp_mesh.get_group() + self.num_cp_tp_ranks = cp_tp_mesh.size() + self._iterator = None + self._prefetch_thread: threading.Thread | None = None + self._prefetch_result: Any = None + self._cuda_device: int | None = None + + logger.debug( + "Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s", + torch.distributed.get_rank() if torch.distributed.is_initialized() else "", + self.cp_tp_rank, + ) + + def __iter__(self): + """Make the dataloader iterable.""" + if self.cp_tp_rank == 0: + self._iterator = iter(self.dataloader) # < --- collator output. + self.close() + # Capture CUDA device from main thread; torch.cuda.set_device is per-thread, + # so the background thread needs to set it explicitly. + self._cuda_device = torch.cuda.current_device() if torch.cuda.is_available() else None + self._kick_prefetch() + return self + + def __next__(self): + """Get the batch from the dataloader for the current CP rank.""" + self._prefetch_thread.join() + result = self._prefetch_result + if isinstance(result, StopIteration): + self._prefetch_thread = None + raise result + self._kick_prefetch() + return result + + def _kick_prefetch(self): + """Start a background thread to prefetch exactly one batch via scatter.""" + self._prefetch_thread = threading.Thread(target=self._do_one_prefetch, daemon=True) + self._prefetch_thread.start() + + def _do_one_prefetch(self): + """Fetch one batch in the background. Stores result in _prefetch_result.""" + if self._cuda_device is not None: + torch.cuda.set_device(self._cuda_device) + try: + self._prefetch_result = self._send_data_to_cp_tp_ranks() + except Exception: + # Process group may have been destroyed; signal stop. 
+ self._prefetch_result = StopIteration() + + def close(self): + """Stop the prefetch thread. Must be called before destroy_process_group().""" + if self._prefetch_thread is not None: + self._prefetch_thread.join(timeout=10) + self._prefetch_thread = None + + def _send_data_to_cp_tp_ranks(self): + """Send data to all the CP/TP ranks. + + This function will get the batch from the dataloader on CP rank 0, and then determine + the shards for all the different CP group members. + combined_batch = [, , ..., ] + Then it will scatter the shards to the different CP group members. + The shards are then combined into a single batch and returned to the caller + for the current CP rank. + + If tensor parallelism is also being used, the combined batch will look like: + combined_batch = [, , ..., , , ...] + where there are cp_world_size shards, and each shard is replicated tp_world_size times. The ordering of the + shards depends on which dimension forms the rows in the flattened mesh. + + Scalability: + Rank 0's work grows linearly with CP size, but the other ranks do not need to store all the shards so they + do not grow linearly with CP size. + + Args: + None + + Returns: + batch: The batch for the current CP/TP rank. + + """ + try: + combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None + except StopIteration as ex: + # If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so + # that the dataloader can be restarted. 
+ combined_batch = [ex] * self.num_cp_tp_ranks + + batch_on_this_rank = _scatter_batch_to_cp_tp_ranks(combined_batch, self.cp_tp_group) + + if isinstance(batch_on_this_rank, StopIteration): + raise batch_on_this_rank + + return batch_on_this_rank + + def state_dict(self): + """Get the state dict by delegating to the dataloader.""" + if self.cp_tp_rank != 0: + return {} + elif hasattr(self.dataloader, "state_dict"): + return {"dataloader": self.dataloader.state_dict()} + else: + logger.warning( + "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict, " + "returning empty dict" + ) + return {"dataloader": {}} + + def load_state_dict(self, state_dict): + """Load the state dict by delegating to the dataloader.""" + if self.cp_tp_rank != 0: + return + elif hasattr(self.dataloader, "load_state_dict"): + self.dataloader.load_state_dict(state_dict["dataloader"]) + else: + logger.warning( + "Attempting to load the state dict of the dataloader, but the dataloader does not support " + "load_state_dict, returning without loading the state dict." + ) + return + + @property + def num_workers(self): + """Get the number of workers of the dataloader.""" + if self.cp_tp_rank != 0: + return 0 + else: + return self.dataloader.num_workers + + +def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]: + """Split a sample dictionary at a specified number of tokens. + + This function splits a sample into two parts: the first part contains exactly `num_tokens` tokens, + and the second part contains the remaining tokens. All fields that are sequences (input_ids, attention_mask, + token_type_ids, labels, etc.) are split accordingly. + + Args: + sample: Dictionary containing sample data with fields like input_ids, attention_mask, token_type_ids, labels, etc. + num_tokens: Number of tokens to include in the first part of the split. 
def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]:
    """Split a sample dict into a head of exactly `num_tokens` tokens and the remaining tail.

    Sequence-valued fields (input_ids, attention_mask, token_type_ids, token_type, labels) are
    sliced at `num_tokens`; every other field is copied verbatim into both halves, since metadata
    fields should not be split.

    Args:
        sample: Dictionary with fields like input_ids, attention_mask, token_type_ids, labels.
        num_tokens: Number of tokens to include in the first part of the split; must be positive
            and strictly less than the sample length.

    Returns:
        A (first_part, remaining_part) tuple of dictionaries.

    Raises:
        ValueError: If `num_tokens` is non-positive or not smaller than the sample length.

    Example:
        >>> sample = {"input_ids": [0, 5, 6, 7, 8, 9, 2], "labels": [0, 5, 6, 7, 8, 9, 2]}
        >>> first, remaining = _split_sample_by_num_tokens(sample, 3)
        >>> first["input_ids"]  # [0, 5, 6]
        >>> remaining["input_ids"]  # [7, 8, 9, 2]
    """
    sample_length = len(sample["input_ids"])
    if num_tokens >= sample_length:
        raise ValueError(
            f"num_tokens ({num_tokens}) must be less than sample length ({sample_length}) to split the sample"
        )
    if num_tokens <= 0:
        raise ValueError(f"num_tokens ({num_tokens}) must be positive")

    # Fields that represent per-token sequences and therefore get sliced.
    splittable = {"input_ids", "attention_mask", "token_type_ids", "token_type", "labels"}

    head: dict[str, Any] = {}
    tail: dict[str, Any] = {}

    for key, value in sample.items():
        if key not in splittable:
            # Metadata fields are duplicated into both halves unchanged.
            head[key] = value
            tail[key] = value
            continue
        if isinstance(value, torch.Tensor):
            # Clone so the halves do not alias the original storage.
            head[key] = value[:num_tokens].clone()
            tail[key] = value[num_tokens:].clone()
        else:
            try:
                head[key] = value[:num_tokens]
                tail[key] = value[num_tokens:]
            except (TypeError, IndexError):
                # Unsliceable values are treated like metadata and copied to both parts.
                head[key] = value
                tail[key] = value

    return head, tail
def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False):
    """Pack tokenized samples into one THD batch with flash-attention metadata.

    Args:
        features: Tokenized samples; each needs "input_ids" and may carry "labels" and/or
            "attention_mask".
        return_position_ids: When True, emit per-sequence position ids restarting at 0.

    Returns:
        Dict with [1, total_tokens] "input_ids" (plus "labels" / "attention_mask" when present in
        the samples), int32 "cu_seq_lens_q" / "cu_seq_lens_k" of length len(features) + 1, and
        "max_length_q" / "max_length_k" set to the longest sample length.
    """
    lengths = [len(sample["input_ids"]) for sample in features]

    def _flatten(key: str) -> torch.Tensor:
        # Concatenate this field across all samples into a single [1, total_tokens] tensor.
        return torch.tensor([[v for sample in features for v in sample[key]]], dtype=torch.int64)

    batch: dict[str, Any] = {}
    batch["max_length_q"] = batch["max_length_k"] = max(lengths)
    batch["input_ids"] = _flatten("input_ids")
    if "labels" in features[0]:
        batch["labels"] = _flatten("labels")

    # Cumulative sequence boundaries: [0, len0, len0+len1, ...].
    boundaries = torch.zeros(len(features) + 1, dtype=torch.int32)
    boundaries[1:] = torch.cumsum(torch.tensor(lengths), dim=0, dtype=torch.int32)
    batch["cu_seq_lens_q"] = batch["cu_seq_lens_k"] = boundaries

    if "attention_mask" in features[0]:
        batch["attention_mask"] = _flatten("attention_mask")
    if return_position_ids:
        batch["position_ids"] = torch.hstack(
            [torch.arange(length, dtype=torch.int64) for length in lengths]
        ).unsqueeze(0)

    return batch
+ """ + # Number of tokens we need to pad to make the total number of tokens divisible by pad_to_multiple_of + remainder = -batch["input_ids"].numel() % pad_to_multiple_of + + if remainder == 0: + return batch + + batch["input_ids"] = torch.cat( + [batch["input_ids"], torch.full((1, remainder), token_pad, dtype=batch["input_ids"].dtype)], dim=1 + ) + + if "labels" in batch: + batch["labels"] = torch.cat( + [batch["labels"], torch.full((1, remainder), label_pad, dtype=batch["labels"].dtype)], dim=1 + ) + + if "cu_seq_lens_q" in batch: + batch["cu_seq_lens_q"] = torch.cat( + [ + batch["cu_seq_lens_q"], + torch.tensor([batch["cu_seq_lens_q"][-1] + remainder], dtype=batch["cu_seq_lens_q"].dtype), + ], + dim=0, + ) + batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"] + + if "max_length_q" in batch: + batch["max_length_q"] = max(batch["max_length_q"], remainder) + batch["max_length_k"] = batch["max_length_q"] + + if "attention_mask" in batch: + batch["attention_mask"] = torch.cat( + [batch["attention_mask"], torch.zeros((1, remainder), dtype=batch["attention_mask"].dtype)], dim=1 + ) + + if "position_ids" in batch: + batch["position_ids"] = torch.cat( + [batch["position_ids"], torch.arange(remainder, dtype=batch["position_ids"].dtype).unsqueeze(0)], dim=1 + ) + + return batch + + +# TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387 +# we can replace this with the one in TransformerEngine. +def _split_batch_by_cp_rank( + cu_seqlens_padded: torch.Tensor | None, + input_ids_padded: torch.Tensor, + labels_padded: torch.Tensor, + cp_group: torch.distributed.ProcessGroup | None = None, + qvk_format: str = "thd", + cp_rank: int | None = None, + cp_world_size: int | None = None, +): + """Slice batch input along sequence dimension into multiple chunks for THD or BSHD format. + + This function is intended for use in self attention. 
It will not work for cross attention because + it does not handle the case where the sequence length of the query and key are different. + Which are parallelized across GPUs in a context parallel group. + This version works with variable-length sequences using cumulative sequence lengths for THD format, + and with padded sequences for BSHD format. + + Args: + cu_seqlens_padded: Cumulative sequence length. Required for THD format, optional for BSHD format. + input_ids_padded: Input IDs. + labels_padded: Labels. + cp_group: Context parallel group. + qvk_format: Format of the input data ("thd" or "bshd"). + cp_world_size: The size of the context parallelism group. If provided, the function will use this value to determine the rank. + cp_rank: Optional manual CP rank index. When provided, the function shards tensors as if it + were executing on that rank without querying `torch.distributed.get_rank`. + """ + if qvk_format not in ["thd", "bshd", "sbhd"]: + raise ValueError(f"Unsupported qvk_format: {qvk_format}!") + + if cp_world_size is None or cp_world_size <= 1: + # No splitting needed + return input_ids_padded, labels_padded + + if cp_rank is None: + cp_rank = torch.distributed.get_rank(group=cp_group) + elif not (0 <= cp_rank < cp_world_size): + raise ValueError(f"cp_rank must be in [0, {cp_world_size}), but received {cp_rank}.") + + if qvk_format == "thd": + if cu_seqlens_padded is None: + raise ValueError("cu_seqlens_padded is required for THD format") + + # Calculate the chunk sizes for each sequence + total_slices_of_any_sequence = 2 * cp_world_size + slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence + + # Process each tensor directly instead of using keys_to_change loop + def process_tensor(val): + if val is None: + return val + # Determine which dimension is the sequence dimension + # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor + if isinstance(cu_seqlens_padded[-1], torch.Tensor): + seq_len_val 
= cu_seqlens_padded[-1].item() + else: + seq_len_val = cu_seqlens_padded[-1] + + # Handle 1D tensors (like position_ids that don't have batch dimension) + if val.ndim == 1: + if val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "1D tensor shape doesn't match expected sequence length. Make sure the" + " inputs are in THD format and padded correctly." + ) + elif val.ndim >= 2: + if val.shape[1] == seq_len_val: + current_seq_dim = 1 + elif val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError("Make sure the inputs are in THD format and padded correctly.") + else: + raise ValueError("Tensor must be at least 1D") + + # On this particular rank, for each sequence, get two slices, one from the beginning + # and one from the end. + cp_rank_slices = [] + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (cp_rank * slice_size), + seq_start + ((cp_rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size), + seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size), + device=val.device, + ) + ) + + return val.index_select(current_seq_dim, torch.cat(cp_rank_slices)) + + # Process each tensor directly + input_ids_padded = process_tensor(input_ids_padded) + labels_padded = process_tensor(labels_padded) + + elif qvk_format == "bshd": + # BSHD format: [batch, seq_len, ...] 
+ # Split along sequence dimension (dim=1) + # Each sequence is split into 2*cp_world_size chunks + # Each rank gets chunks at positions: [cp_rank, 2*cp_world_size - cp_rank - 1] + + def process_tensor_bshd(val): + if val is None: + return val + + if val.ndim < 2: + raise ValueError(f"BSHD format requires at least 2D tensors, got {val.ndim}D") + + seq_len = val.shape[1] + + # Calculate chunk size + total_chunks = 2 * cp_world_size + chunk_size = seq_len // total_chunks + + if chunk_size == 0: + raise ValueError( + f"Sequence length {seq_len} must be divisible by {total_chunks} " + f"(2 * cp_world_size) for BSHD context parallelism" + ) + + # Determine which chunks this rank should get + # Rank 0 gets chunks [0, total_chunks-1] + # Rank 1 gets chunks [1, total_chunks-2] + # Rank k gets chunks [k, total_chunks-k-1] + chunk_indices = [cp_rank, total_chunks - cp_rank - 1] + + # Collect slices for this rank + rank_slices = [] + for chunk_idx in chunk_indices: + start_idx = chunk_idx * chunk_size + end_idx = start_idx + chunk_size + rank_slices.append(torch.arange(start_idx, end_idx, device=val.device)) + + # Concatenate indices for all chunks this rank should get + indices = torch.cat(rank_slices) + + # Select along sequence dimension (dim=1) + return val.index_select(1, indices) + + input_ids_padded = process_tensor_bshd(input_ids_padded) + labels_padded = process_tensor_bshd(labels_padded) + + else: + raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!") + + return input_ids_padded, labels_padded + + +class BatchType(TypedDict): + """The fields in the batch dictionary fo THD context parallel.""" + + input_ids: torch.Tensor + labels: torch.Tensor | None + shift_labels: torch.Tensor | None + cu_seq_lens_q: torch.Tensor + cu_seq_lens_k: torch.Tensor + cu_seq_lens_q_padded: torch.Tensor + cu_seq_lens_k_padded: torch.Tensor + max_length_q: int + max_length_k: int + pad_between_seqs: bool + + +def _scatter_batch_to_cp_tp_ranks( + all_batches: 
list[BatchType] | list[StopIteration], cp_tp_group: torch.distributed.ProcessGroup | None = None +) -> BatchType | StopIteration: + """Scatter a batch to all the CP ranks. + + Args: + all_batches (list[BatchType] | list[StopIteration]): A list of already-sharded batches to scatter to the CP/TP + ranks. + cp_tp_group (torch.distributed.ProcessGroup | None): The process group to scatter the batches to. + + Returns: + BatchType | StopIteration: The batch on this rank. + """ + scatter_object_output_list = [None] + # Note: This does not provide an async_op handle. Thus its blocking. + torch.distributed.scatter_object_list( + scatter_object_output_list=scatter_object_output_list, + scatter_object_input_list=all_batches, + group=cp_tp_group, + group_src=0, + ) + return scatter_object_output_list[0] diff --git a/bionemo-recipes/models/esmc/convert.py b/bionemo-recipes/models/esmc/convert.py new file mode 100644 index 0000000000..118e4b9b98 --- /dev/null +++ b/bionemo-recipes/models/esmc/convert.py @@ -0,0 +1,206 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Weight conversion between EvolutionaryScale ESMC and NVEsmc (TransformerEngine) formats. 
+ +The ESMC reference model uses: +- QKV as a Sequential(LayerNorm, Linear) producing [Q||K||V] concatenated +- QK LayerNorm over full d_model dimension (960), bias=False +- Residue scaling: divides attn output and FFN output by sqrt(n_layers/36) +- FFN as Sequential(LayerNorm, Linear, SwiGLU, Linear) + +NVEsmc TE model uses: +- LayerNormLinear for QKV with [Q||K||V] concatenated weights (no interleaving) +- Full d_model QK LayerNorm via separate TE LayerNorm modules (exact match) +- DotProductAttention for flash/fused attention +- Residue scaling absorbed into output projection and fc2 weights +- LayerNormMLP for fused FFN +""" + +import math + +import torch + +from modeling_esmc_te import NVEsmcConfig, NVEsmcForMaskedLM + + +def convert_esmc_to_te(ref_state_dict: dict[str, torch.Tensor], config: NVEsmcConfig) -> NVEsmcForMaskedLM: + """Convert EvolutionaryScale ESMC weights to NVEsmc (TransformerEngine) format. + + This performs: + 1. Key remapping from ESMC ref format to TE format + 2. QK norm weight direct copy (both use full d_model LayerNorm) + 3. Residue scaling absorption into output projection and fc2 weights + + Args: + ref_state_dict: State dict from the EvolutionaryScale ESMC model (.pth file). + config: NVEsmcConfig for the target TE model. + + Returns: + NVEsmcForMaskedLM with converted weights. 
+ """ + num_layers = config.num_hidden_layers + hidden_size = config.hidden_size + scale_factor = math.sqrt(num_layers / 36) + + te_state_dict = {} + + # Embedding + te_state_dict["esmc.embed_tokens.weight"] = ref_state_dict["embed.weight"] + + for layer_idx in range(num_layers): + ref_prefix = f"transformer.blocks.{layer_idx}" + te_prefix = f"esmc.layers.{layer_idx}" + + # Attention LayerNorm (pre-QKV) + te_state_dict[f"{te_prefix}.layernorm_qkv.layer_norm_weight"] = ref_state_dict[ + f"{ref_prefix}.attn.layernorm_qkv.0.weight" + ] + te_state_dict[f"{te_prefix}.layernorm_qkv.layer_norm_bias"] = ref_state_dict[ + f"{ref_prefix}.attn.layernorm_qkv.0.bias" + ] + + # QKV weight: direct copy (stored as [Q||K||V] concatenated, no interleaving) + te_state_dict[f"{te_prefix}.layernorm_qkv.weight"] = ref_state_dict[ + f"{ref_prefix}.attn.layernorm_qkv.1.weight" + ] + + # QK norm: direct copy (both use full d_model LayerNorm) + # Reference has bias=False, TE LayerNorm always has bias -> set to zeros + te_state_dict[f"{te_prefix}.q_norm.weight"] = ref_state_dict[f"{ref_prefix}.attn.q_ln.weight"] + te_state_dict[f"{te_prefix}.q_norm.bias"] = torch.zeros( + hidden_size, dtype=ref_state_dict[f"{ref_prefix}.attn.q_ln.weight"].dtype + ) + te_state_dict[f"{te_prefix}.k_norm.weight"] = ref_state_dict[f"{ref_prefix}.attn.k_ln.weight"] + te_state_dict[f"{te_prefix}.k_norm.bias"] = torch.zeros( + hidden_size, dtype=ref_state_dict[f"{ref_prefix}.attn.k_ln.weight"].dtype + ) + + # Attention output projection: absorb residue scaling + out_proj_weight = ref_state_dict[f"{ref_prefix}.attn.out_proj.weight"] + te_state_dict[f"{te_prefix}.proj.weight"] = out_proj_weight / scale_factor + + # FFN LayerNorm (pre-MLP) + te_state_dict[f"{te_prefix}.layernorm_mlp.layer_norm_weight"] = ref_state_dict[f"{ref_prefix}.ffn.0.weight"] + te_state_dict[f"{te_prefix}.layernorm_mlp.layer_norm_bias"] = ref_state_dict[f"{ref_prefix}.ffn.0.bias"] + + # FFN fc1 (gate + up proj concatenated for SwiGLU) + 
te_state_dict[f"{te_prefix}.layernorm_mlp.fc1_weight"] = ref_state_dict[f"{ref_prefix}.ffn.1.weight"] + + # FFN fc2 (down proj): absorb residue scaling + fc2_weight = ref_state_dict[f"{ref_prefix}.ffn.3.weight"] + te_state_dict[f"{te_prefix}.layernorm_mlp.fc2_weight"] = fc2_weight / scale_factor + + # Final LayerNorm + te_state_dict["esmc.norm.weight"] = ref_state_dict["transformer.norm.weight"] + # ESMC final norm has bias=False, but TE LayerNorm always has bias. Set to zeros. + te_state_dict["esmc.norm.bias"] = torch.zeros(hidden_size, dtype=ref_state_dict["transformer.norm.weight"].dtype) + + # Sequence head (RegressionHead): Linear -> GELU -> LayerNorm -> Linear + # ref: sequence_head.0 = Linear(960, 960) + te_state_dict["sequence_head.dense.weight"] = ref_state_dict["sequence_head.0.weight"] + te_state_dict["sequence_head.dense.bias"] = ref_state_dict["sequence_head.0.bias"] + # ref: sequence_head.2 = LayerNorm(960), sequence_head.3 = Linear(960, 64) + # TE LayerNormLinear fuses both + te_state_dict["sequence_head.decoder.layer_norm_weight"] = ref_state_dict["sequence_head.2.weight"] + te_state_dict["sequence_head.decoder.layer_norm_bias"] = ref_state_dict["sequence_head.2.bias"] + te_state_dict["sequence_head.decoder.weight"] = ref_state_dict["sequence_head.3.weight"] + te_state_dict["sequence_head.decoder.bias"] = ref_state_dict["sequence_head.3.bias"] + + # Build the TE model and load state dict + with torch.device("meta"): + model_te = NVEsmcForMaskedLM(config) + + target_state = model_te.state_dict() + + # Directly load the pre-transformed state dict + for key in list(target_state.keys()): + if key.endswith("_extra_state"): + continue + if key in te_state_dict: + target_state[key] = te_state_dict[key] + + # Load into model + model_te.load_state_dict(target_state, strict=False, assign=True) + model_te.tie_weights() + + return model_te + + +def convert_esmc_te_to_ref(model_te: NVEsmcForMaskedLM) -> dict[str, torch.Tensor]: + """Convert NVEsmc 
(TransformerEngine) weights back to EvolutionaryScale ESMC format. + + This reverses the transformations from convert_esmc_to_te: + 1. QK norm weight direct copy (both use full d_model) + 2. Residue scaling removal from projection weights + + Args: + model_te: NVEsmcForMaskedLM model with TE weights. + + Returns: + State dict in EvolutionaryScale ESMC format. + """ + config = model_te.config + num_layers = config.num_hidden_layers + scale_factor = math.sqrt(num_layers / 36) + + te_sd = model_te.state_dict() + ref_state_dict = {} + + # Embedding + ref_state_dict["embed.weight"] = te_sd["esmc.embed_tokens.weight"] + + for layer_idx in range(num_layers): + te_prefix = f"esmc.layers.{layer_idx}" + ref_prefix = f"transformer.blocks.{layer_idx}" + + # Attention LayerNorm + ref_state_dict[f"{ref_prefix}.attn.layernorm_qkv.0.weight"] = te_sd[ + f"{te_prefix}.layernorm_qkv.layer_norm_weight" + ] + ref_state_dict[f"{ref_prefix}.attn.layernorm_qkv.0.bias"] = te_sd[f"{te_prefix}.layernorm_qkv.layer_norm_bias"] + + # QKV weight: direct copy (no deinterleaving needed) + ref_state_dict[f"{ref_prefix}.attn.layernorm_qkv.1.weight"] = te_sd[f"{te_prefix}.layernorm_qkv.weight"] + + # QK norm: direct copy (both use full d_model LayerNorm) + ref_state_dict[f"{ref_prefix}.attn.q_ln.weight"] = te_sd[f"{te_prefix}.q_norm.weight"] + ref_state_dict[f"{ref_prefix}.attn.k_ln.weight"] = te_sd[f"{te_prefix}.k_norm.weight"] + + # Attention output projection: reverse scaling + ref_state_dict[f"{ref_prefix}.attn.out_proj.weight"] = te_sd[f"{te_prefix}.proj.weight"] * scale_factor + + # FFN LayerNorm + ref_state_dict[f"{ref_prefix}.ffn.0.weight"] = te_sd[f"{te_prefix}.layernorm_mlp.layer_norm_weight"] + ref_state_dict[f"{ref_prefix}.ffn.0.bias"] = te_sd[f"{te_prefix}.layernorm_mlp.layer_norm_bias"] + + # FFN fc1 + ref_state_dict[f"{ref_prefix}.ffn.1.weight"] = te_sd[f"{te_prefix}.layernorm_mlp.fc1_weight"] + + # FFN fc2: reverse scaling + ref_state_dict[f"{ref_prefix}.ffn.3.weight"] = 
te_sd[f"{te_prefix}.layernorm_mlp.fc2_weight"] * scale_factor + + # Final LayerNorm (no bias in ref) + ref_state_dict["transformer.norm.weight"] = te_sd["esmc.norm.weight"] + + # Sequence head + ref_state_dict["sequence_head.0.weight"] = te_sd["sequence_head.dense.weight"] + ref_state_dict["sequence_head.0.bias"] = te_sd["sequence_head.dense.bias"] + ref_state_dict["sequence_head.2.weight"] = te_sd["sequence_head.decoder.layer_norm_weight"] + ref_state_dict["sequence_head.2.bias"] = te_sd["sequence_head.decoder.layer_norm_bias"] + ref_state_dict["sequence_head.3.weight"] = te_sd["sequence_head.decoder.weight"] + ref_state_dict["sequence_head.3.bias"] = te_sd["sequence_head.decoder.bias"] + + return ref_state_dict diff --git a/bionemo-recipes/models/esmc/esmc_fast_tokenizer/special_tokens_map.json b/bionemo-recipes/models/esmc/esmc_fast_tokenizer/special_tokens_map.json new file mode 100644 index 0000000000..d11a7b4028 --- /dev/null +++ b/bionemo-recipes/models/esmc/esmc_fast_tokenizer/special_tokens_map.json @@ -0,0 +1,8 @@ +{ + "cls_token": "", + "eos_token": "", + "mask_token": "", + "pad_token": "", + "unk_token": "", + "additional_special_tokens": ["|"] +} diff --git a/bionemo-recipes/models/esmc/esmc_fast_tokenizer/tokenizer.json b/bionemo-recipes/models/esmc/esmc_fast_tokenizer/tokenizer.json new file mode 100644 index 0000000000..f07d537bce --- /dev/null +++ b/bionemo-recipes/models/esmc/esmc_fast_tokenizer/tokenizer.json @@ -0,0 +1,177 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 1, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 2, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 3, + 
"content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 31, + "content": "|", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 32, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "Split", + "pattern": { + "String": "" + }, + "behavior": "Isolated", + "invert": false + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + } + ], + "pair": [ + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 1 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 1 + } + } + ], + "special_tokens": { + "": { + "id": "", + "ids": [0], + "tokens": [""] + }, + "": { + "id": "", + "ids": [2], + "tokens": [""] + } + } + }, + "decoder": null, + "model": { + "type": "WordLevel", + "vocab": { + "": 0, + "": 1, + "": 2, + "": 3, + "L": 4, + "A": 5, + "G": 6, + "V": 7, + "S": 8, + "E": 9, + "R": 10, + "T": 11, + "I": 12, + "D": 13, + "P": 14, + "K": 15, + "Q": 16, + "N": 17, + "F": 18, + "Y": 19, + "M": 20, + "H": 21, + "W": 22, + "C": 23, + "X": 24, + "B": 25, + "U": 26, + "Z": 27, + "O": 28, + ".": 29, + "-": 30, + "|": 31, + "": 32 + }, + "unk_token": "" + } +} diff --git a/bionemo-recipes/models/esmc/esmc_fast_tokenizer/tokenizer_config.json b/bionemo-recipes/models/esmc/esmc_fast_tokenizer/tokenizer_config.json new file mode 100644 index 0000000000..e778ce7e39 --- /dev/null +++ b/bionemo-recipes/models/esmc/esmc_fast_tokenizer/tokenizer_config.json @@ 
-0,0 +1,69 @@ +{ + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "3": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "31": { + "content": "|", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": ["|"], + "clean_up_tokenization_spaces": false, + "bos_token": "", + "cls_token": "", + "eos_token": "", + "mask_token": "", + "pad_token": "", + "unk_token": "", + "extra_special_tokens": {}, + "model_max_length": 1000000000000000019884624838656, + "tokenizer_class": "PreTrainedTokenizerFast", + "add_bos_token": true, + "add_eos_token": true, + "model_input_names": [ + "input_ids", + "attention_mask" + ] +} diff --git a/bionemo-recipes/models/esmc/export.py b/bionemo-recipes/models/esmc/export.py new file mode 100644 index 0000000000..2d7980d231 --- /dev/null +++ b/bionemo-recipes/models/esmc/export.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Export ESMC checkpoint to HuggingFace-compatible format with TransformerEngine layers.

This script:
1. Loads the EvolutionaryScale ESMC-300M pretrained weights
2. Converts them to TransformerEngine format
3. Saves the converted model for use with HuggingFace's `AutoModel.from_pretrained()`
"""

import json
import shutil
from pathlib import Path

import convert
from modeling_esmc_te import AUTO_MAP, NVEsmcConfig


def export_esmc_checkpoint(export_path: Path):
    """Export the ESMC-300M model to a TE checkpoint.

    Args:
        export_path: Directory to save the exported checkpoint.
    """
    from esm.pretrained import ESMC_300M_202412

    # Load the reference model on CPU so conversion does not consume GPU memory.
    reference_model = ESMC_300M_202412(device="cpu", use_flash_attn=False)
    reference_weights = reference_model.state_dict()
    del reference_model

    # ESMC-300M architecture hyper-parameters.
    te_config = NVEsmcConfig(
        vocab_size=64,
        hidden_size=960,
        num_hidden_layers=30,
        num_attention_heads=15,
        intermediate_size=2560,
    )

    # Convert to TransformerEngine format and write the checkpoint.
    te_model = convert.convert_esmc_to_te(reference_weights, te_config)
    te_model.to("cpu")
    te_model.save_pretrained(export_path)

    # Patch auto_map into the emitted config so AutoModel can resolve the custom classes.
    config_file = export_path / "config.json"
    config_json = json.loads(config_file.read_text())
    config_json["auto_map"] = AUTO_MAP
    config_file.write_text(json.dumps(config_json, indent=2, sort_keys=True))

    # Ship the modeling code next to the weights for standalone loading.
    shutil.copy(Path(__file__).parent / "modeling_esmc_te.py", export_path / "modeling_esmc_te.py")

    # Save the tokenizer alongside the model.
    from esm.tokenization import EsmSequenceTokenizer

    EsmSequenceTokenizer().save_pretrained(export_path)


if __name__ == "__main__":
    export_esmc_checkpoint(Path("checkpoint_export"))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TransformerEngine-optimized ESMC (EvolutionaryScale ESM-Cambrian) model. + +This module provides HuggingFace-compatible ESMC model classes using NVIDIA's TransformerEngine +for optimized attention and MLP computation. The model is an encoder-only protein language model +with bidirectional attention, RoPE, SwiGLU activation, and full d_model QK LayerNorm. + +Unlike models that use TE's TransformerLayer (which applies per-head QK norm), this implementation +uses lower-level TE components (LayerNormLinear, DotProductAttention, LayerNormMLP) to apply QK +LayerNorm across the full hidden dimension, exactly matching the reference ESMC model. + +Reference: EvolutionaryScale's ESMC-300M (esm PyPI package). 
+""" + +from typing import ClassVar, Literal, Optional, TypedDict, Unpack + +import torch +import transformer_engine.pytorch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding, apply_rotary_pos_emb +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput +from transformers.modeling_utils import PreTrainedModel + + +AUTO_MAP = { + "AutoConfig": "modeling_esmc_te.NVEsmcConfig", + "AutoModel": "modeling_esmc_te.NVEsmcModel", + "AutoModelForMaskedLM": "modeling_esmc_te.NVEsmcForMaskedLM", +} + + +class TransformersKwargs(TypedDict): + """Transformers v4 does not export a TransformersKwargs class, so we define our own.""" + + cu_seq_lens_q: Optional[torch.Tensor] + cu_seq_lens_k: Optional[torch.Tensor] + max_length_q: Optional[int] + max_length_k: Optional[int] + pad_between_seqs: Optional[int] + cu_seqlens_q_padded: Optional[torch.Tensor] + cu_seqlens_k_padded: Optional[torch.Tensor] + + +class NVEsmcConfig(PretrainedConfig): + """Configuration for the NVEsmc TransformerEngine model.""" + + model_type: str = "nv_esmc" + + def __init__( + self, + vocab_size: int = 64, + hidden_size: int = 960, + num_hidden_layers: int = 30, + num_attention_heads: int = 15, + intermediate_size: int = 2560, + layer_norm_eps: float = 1e-5, + position_embedding_type: str = "rotary", + initializer_range: float = 0.02, + pad_token_id: int = 0, + # TE-specific options + attn_input_format: Literal["bshd", "thd"] = "bshd", + self_attn_mask_type: str = "padding", + tie_word_embeddings: bool = False, + **kwargs, + ): + """Initialize NVEsmcConfig. + + Args: + vocab_size: Vocabulary size (padded to 64 from real vocab of 33). + hidden_size: Dimension of hidden representations. + num_hidden_layers: Number of transformer layers. + num_attention_heads: Number of attention heads. 
+ intermediate_size: FFN intermediate dimension (SwiGLU corrected). + layer_norm_eps: Layer normalization epsilon. + position_embedding_type: Type of position embedding (only "rotary" supported). + initializer_range: Standard deviation for weight initialization. + pad_token_id: Padding token ID. + attn_input_format: Attention input format for TE ("bshd" or "thd"). + self_attn_mask_type: Attention mask type ("padding" for bidirectional). + tie_word_embeddings: Whether to tie input/output embeddings. + **kwargs: Additional config options. + """ + super().__init__( + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.initializer_range = initializer_range + self.attn_input_format = attn_input_format + self.self_attn_mask_type = self_attn_mask_type + + +class EsmcTransformerBlock(nn.Module): + """Custom ESMC transformer block using lower-level TE components. + + This block implements full d_model QK LayerNorm (matching the reference ESMC model) + by using individual TE components instead of TE's TransformerLayer which only supports + per-head QK norm. + + Architecture: + 1. LayerNormLinear: pre-attention LayerNorm + QKV projection + 2. LayerNorm(d_model): full-dimension Q normalization + 3. LayerNorm(d_model): full-dimension K normalization + 4. RoPE application + 5. DotProductAttention: flash/fused attention + 6. Linear: output projection (residue scaling absorbed in weights) + 7. 
LayerNormMLP: pre-FFN LayerNorm + SwiGLU MLP (residue scaling absorbed in fc2) + """ + + def __init__(self, config: NVEsmcConfig, layer_idx: int): + """Initialize EsmcTransformerBlock.""" + super().__init__() + hidden_size = config.hidden_size + num_heads = config.num_attention_heads + head_dim = hidden_size // num_heads + device = "meta" if torch.get_default_device() == torch.device("meta") else "cuda" + + def _init_method(x): + torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) + + # Pre-attention LayerNorm + QKV projection (fused) + self.layernorm_qkv = transformer_engine.pytorch.LayerNormLinear( + hidden_size, + 3 * hidden_size, + bias=False, + eps=config.layer_norm_eps, + params_dtype=config.torch_dtype, + device=device, + init_method=_init_method, + ) + + # Full d_model QK LayerNorm (matching reference model exactly) + self.q_norm = transformer_engine.pytorch.LayerNorm( + hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.torch_dtype, + device=device, + ) + self.k_norm = transformer_engine.pytorch.LayerNorm( + hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.torch_dtype, + device=device, + ) + + # Attention computation (flash/fused attention backends) + self.core_attention = transformer_engine.pytorch.DotProductAttention( + num_attention_heads=num_heads, + kv_channels=head_dim, + num_gqa_groups=num_heads, + attention_dropout=0, + qkv_format=config.attn_input_format, + attn_mask_type=config.self_attn_mask_type, + layer_number=layer_idx + 1, + ) + + # Output projection + self.proj = transformer_engine.pytorch.Linear( + hidden_size, + hidden_size, + bias=False, + params_dtype=config.torch_dtype, + device=device, + init_method=_init_method, + ) + + # FFN: pre-LayerNorm + SwiGLU MLP (fused) + self.layernorm_mlp = transformer_engine.pytorch.LayerNormMLP( + hidden_size, + config.intermediate_size, + bias=False, + eps=config.layer_norm_eps, + activation="swiglu", + params_dtype=config.torch_dtype, + device=device, + 
init_method=_init_method, + output_layer_init_method=_init_method, + ) + + self.num_heads = num_heads + self.head_dim = head_dim + self.attn_input_format = config.attn_input_format + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + rotary_pos_emb: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + cu_seqlens_q_padded: Optional[torch.Tensor] = None, + cu_seqlens_kv_padded: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + pad_between_seqs: Optional[bool] = None, + ) -> torch.Tensor: + """Forward pass for a single transformer block. + + Args: + hidden_states: Input tensor [B, S, D] (BSHD) or [T, D] (THD). + attention_mask: Attention mask for BSHD format. + rotary_pos_emb: Precomputed rotary position embeddings. + cu_seqlens_q: Cumulative sequence lengths for queries (THD format). + cu_seqlens_kv: Cumulative sequence lengths for keys/values (THD format). + cu_seqlens_q_padded: Padded cumulative sequence lengths for queries. + cu_seqlens_kv_padded: Padded cumulative sequence lengths for keys/values. + max_seqlen_q: Maximum query sequence length (THD format). + max_seqlen_kv: Maximum key/value sequence length (THD format). + pad_between_seqs: Whether there is padding between sequences. + + Returns: + Output tensor with same shape as input. 
+ """ + residual = hidden_states + + # Pre-attention LayerNorm + QKV projection + qkv = self.layernorm_qkv(hidden_states) # [*, 3*D] + q, k, v = qkv.chunk(3, dim=-1) # each [*, D] + + # Full d_model QK LayerNorm (matching reference model) + q = self.q_norm(q) + k = self.k_norm(k) + + # Reshape to multi-head format: [B, S, H, d_head] or [T, H, d_head] + head_shape = (*q.shape[:-1], self.num_heads, self.head_dim) + q = q.view(head_shape) + k = k.view(head_shape) + v = v.view(head_shape) + + # Apply RoPE + if rotary_pos_emb is not None: + tensor_format = "thd" if self.attn_input_format == "thd" else "bshd" + q = apply_rotary_pos_emb( + q, + rotary_pos_emb, + tensor_format=tensor_format, + cu_seqlens=cu_seqlens_q if tensor_format == "thd" else None, + ) + k = apply_rotary_pos_emb( + k, + rotary_pos_emb, + tensor_format=tensor_format, + cu_seqlens=cu_seqlens_kv if tensor_format == "thd" else None, + ) + + # Attention + attn_output = self.core_attention( + q, + k, + v, + attention_mask=attention_mask, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + pad_between_seqs=pad_between_seqs, + ) # [B, S, D] or [T, D] (DotProductAttention folds heads internally) + + # Output projection + attn_output = self.proj(attn_output) + + # Residual connection (residue scaling absorbed in proj weights) + hidden_states = residual + attn_output + + # FFN with pre-LayerNorm and residual (residue scaling absorbed in fc2 weights) + residual = hidden_states + hidden_states = residual + self.layernorm_mlp(hidden_states) + + return hidden_states + + +class NVEsmcPreTrainedModel(PreTrainedModel): + """Base class for NVEsmc models.""" + + config_class = NVEsmcConfig + base_model_prefix = "esmc" + _no_split_modules = ("EsmcTransformerBlock",) + _tied_weights_keys: ClassVar[dict[str, str]] = {} + + def init_empty_weights(self): + """Move model 
from meta device to CUDA and initialize weights.""" + for module in self.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + self.esmc.embed_tokens.to_empty(device="cuda") + self.esmc.embed_tokens.apply(self._init_weights) + + # Meta-device init breaks weight tying, so re-tie. + self.tie_weights() + + def _init_weights(self, module): + """Initialize weights for standard pytorch modules. + + TE modules handle their own initialization through `init_method` and `reset_parameters`. + """ + if module.__module__.startswith("transformer_engine.pytorch"): + return + + if isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + def state_dict(self, *args, **kwargs): + """Override to filter out TE's _extra_state keys.""" + state_dict = super().state_dict(*args, **kwargs) + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + + +class NVEsmcModel(NVEsmcPreTrainedModel): + """ESMC encoder model with TransformerEngine layers.""" + + def __init__(self, config: NVEsmcConfig): + """Initialize the NVEsmc model.""" + super().__init__(config) + self.config = config + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, config.pad_token_id, dtype=config.torch_dtype + ) + + self.layers = nn.ModuleList( + [EsmcTransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + # Final LayerNorm (no bias in reference, but TE LayerNorm always has bias; set to zeros) + self.norm = transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.torch_dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else 
"cuda", + ) + + self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + + self.gradient_checkpointing = False + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: + """Forward pass for the NVEsmc encoder model. + + Args: + input_ids: Input token IDs. + attention_mask: Attention mask. + inputs_embeds: Pre-computed input embeddings. + **kwargs: Additional keyword arguments (THD params, output_hidden_states, etc.). + + Returns: + BaseModelOutput with last_hidden_state and optional hidden_states. + """ + all_hidden_states: tuple[torch.Tensor, ...] = () + output_hidden_states = kwargs.get("output_hidden_states", False) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # Handle THD format conversion + has_thd_input = [x in kwargs for x in ["cu_seq_lens_q", "cu_seq_lens_k", "max_length_q", "max_length_k"]] + should_pack_inputs = not any(has_thd_input) and self.config.attn_input_format == "thd" + + if should_pack_inputs: + assert attention_mask is not None, "Attention mask is required when packing BSHD inputs." 
+ batch_size = hidden_states.size(0) + padded_seq_len = input_ids.size(1) + hidden_states, indices, cu_seqlens, max_seqlen, _ = _unpad_input(hidden_states, attention_mask) + kwargs["cu_seq_lens_q"] = kwargs["cu_seq_lens_k"] = cu_seqlens + kwargs["max_length_q"] = kwargs["max_length_k"] = max_seqlen + + if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1: + hidden_states = hidden_states.squeeze(0) + + if self.config.attn_input_format == "bshd" and attention_mask is not None and attention_mask.dim() == 2: + # Convert 2D HF mask (1=attend, 0=pad) to 4D TE mask (True=masked, False=attend) + attention_mask = attention_mask[:, None, None, :] == 0 + + with torch.autocast(device_type="cuda", enabled=False): + te_rope_emb = self.rotary_emb(max_seq_len=4096) + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = (*all_hidden_states, hidden_states) + + hidden_states = layer( + hidden_states, + attention_mask=None if self.config.attn_input_format == "thd" else attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states = (*all_hidden_states, hidden_states) + + if should_pack_inputs: + hidden_states = _pad_input(hidden_states, indices, batch_size, padded_seq_len) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states if all_hidden_states else None, + ) + + +class NVEsmcForMaskedLM(NVEsmcPreTrainedModel): + """ESMC model with masked language modeling head.""" + + _tied_weights_keys: ClassVar[dict[str, 
str]] = {} + _do_not_quantize: ClassVar[list[str]] = ["sequence_head.dense", "sequence_head.decoder"] + + def __init__(self, config: NVEsmcConfig): + """Initialize NVEsmcForMaskedLM.""" + super().__init__(config) + self.esmc = NVEsmcModel(config) + self.sequence_head = NVEsmcLMHead(config) + self.post_init() + + def get_output_embeddings(self): + """Get the output embeddings.""" + return self.sequence_head.decoder + + def set_output_embeddings(self, new_embeddings): + """Set the output embeddings.""" + self.sequence_head.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MaskedLMOutput: + """Forward pass with masked language modeling loss. + + Args: + input_ids: Input token IDs. + attention_mask: Attention mask. + inputs_embeds: Pre-computed input embeddings. + labels: Labels for masked token prediction. + **kwargs: Additional keyword arguments. + + Returns: + MaskedLMOutput with loss, logits, and optional hidden states. + """ + outputs = self.esmc( + input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + sequence_output = outputs.last_hidden_state + + with transformer_engine.pytorch.autocast(enabled=False): + prediction_scores = self.sequence_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.to(prediction_scores.device).view(-1), + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + ) + + +class NVEsmcLMHead(nn.Module): + """ESMC language modeling head: Linear -> GELU -> LayerNorm -> Linear. 

    This matches the EvolutionaryScale `RegressionHead(d_model, output_dim)` architecture.
    """

    def __init__(self, config: NVEsmcConfig):
        """Initialize NVEsmcLMHead.

        Both projections are created with quantization disabled so the LM head
        stays in high precision (see `NVEsmcForMaskedLM._do_not_quantize`).
        """
        super().__init__()
        with transformer_engine.pytorch.quantized_model_init(enabled=False):
            self.dense = transformer_engine.pytorch.Linear(
                config.hidden_size,
                config.hidden_size,
                bias=True,
                params_dtype=config.torch_dtype,
                device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
                init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
            )

            # LayerNormLinear fuses the pre-decoder LayerNorm with the vocab projection.
            self.decoder = transformer_engine.pytorch.LayerNormLinear(
                config.hidden_size,
                config.vocab_size,
                bias=True,
                eps=config.layer_norm_eps,
                params_dtype=config.torch_dtype,
                device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
                init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
            )

    def forward(self, features):
        """Forward pass: Dense -> GELU -> LayerNormLinear."""
        # Run outside TE autocast so logits are produced in full precision.
        with transformer_engine.pytorch.autocast(enabled=False):
            x = self.dense(features)
            x = torch.nn.functional.gelu(x)
            x = self.decoder(x)
        return x


# ===================== Utility Functions for THD Packing =====================

# Needed so torch.compile can trace `.item()` calls in _unpad_input.
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile
def _pad_input(hidden_states, indices, batch, seqlen):
    """Convert a THD tensor to BSHD format.

    Scatters the packed tokens back into a zero-initialized (batch, seqlen, ...)
    tensor at the flat positions recorded by `_unpad_input`.
    """
    dim = hidden_states.shape[1:]
    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
    output[indices] = hidden_states
    return output.view(batch, seqlen, *dim)


@torch.compile
def _unpad_input(hidden_states, attention_mask, unused_mask=None):
    """Convert a BSHD tensor to THD format.

    Returns (packed_hidden_states, flat_indices, cu_seqlens, max_seqlen,
    used_seqlens_per_batch).
    """
    batch_size = hidden_states.size(0)
    seq_length = hidden_states.size(1)

    # Mask/sequence length mismatch: treat each batch element as a length-1 sequence.
    if attention_mask.shape[1] != seq_length:
        return (
            hidden_states.squeeze(1),
            torch.arange(batch_size, dtype=torch.int64, device=hidden_states.device),
            torch.arange(batch_size + 1, dtype=torch.int32, device=hidden_states.device),
            1,
            1,
        )

    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat positions of all kept tokens within the (batch*seqlen) layout.
    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Cumulative sequence lengths, prefixed with 0, as flash-attention expects.
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

    return (
        hidden_states.reshape(-1, *hidden_states.shape[2:])[indices],
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )
diff --git a/bionemo-recipes/models/esmc/requirements.txt b/bionemo-recipes/models/esmc/requirements.txt
new file mode 100644
index 0000000000..23fdd7b647
--- /dev/null
+++ b/bionemo-recipes/models/esmc/requirements.txt
@@ -0,0 +1,6 @@
+esm @ git+https://github.com/evolutionaryscale/esm.git
+transformer_engine
+transformers
+torch
+accelerate
+pytest
diff --git a/bionemo-recipes/models/esmc/state.py b/bionemo-recipes/models/esmc/state.py
new file mode 100644
index 0000000000..d69b1cab12
--- /dev/null
+++ b/bionemo-recipes/models/esmc/state.py
@@ -0,0 +1,705 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""State dict conversion utilities adapted from nemo.lightning.io.state.""" + +import inspect +import logging +import re +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, overload + +import numpy as np +import torch +from torch import nn + + +logger = logging.getLogger(__name__) + +SourceModuleT = TypeVar("SourceModuleT", bound=nn.Module) +TargetModuleT = TypeVar("TargetModuleT", bound=nn.Module) +F = TypeVar("F", bound=Callable[..., Any]) + + +@dataclass +class TransformCTX: + """Transform Data class Definition.""" + + source: nn.Module + source_state: dict + target: nn.Module + target_state: dict + + +class _ModelState: + """Helper class for used for to modify state dict of a source model during model conversion.""" + + def __init__(self, state_dict, config=None): + self._state_dict = state_dict + self.config = config + + def state_dict(self): + # pylint: disable=C0115,C0116 + return self._state_dict + + def to(self, dtype): + # pylint: disable=C0115,C0116 + for k, v in self._state_dict.items(): + if v.dtype != dtype: + logger.warning(f"Converting {k} from {v.dtype} (source model) to {dtype} (target model)") + self._state_dict[k] = v.to(dtype) + + +@torch.no_grad +def apply_transforms( + source: Union[nn.Module, _ModelState], + target: TargetModuleT, + mapping: Dict[str, str], + transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = [], + state_dict_ignored_entries: List = [], + cast_dtype: Optional[torch.dtype] = None, +) -> TargetModuleT: + """Transform the state dictionary of a source module to match the structure of a target module's state dictionary. + + This function renames keys according to a provided mapping and modifies values using a list + of transformation functions. Each transformation function typically is decorated + with `io.state_transform`. 
+ + Args: + source (nn.Module): The source module from which parameters and buffers are taken. + target (TargetModuleT): The target module to which parameters and buffers are adapted. + mapping (Dict[str, str]): Key-value pairs where each key from the source state dictionary + is mapped to a corresponding key in the target state dictionary. + transforms (Optional[List[Callable[[TransformCTX], TransformCTX]]]): A list of functions + that modify the `TransformCTX` object. If None, no transformations beyond key renaming + are applied. Defaults to None. + state_dict_ignored_entries: List of entries to ignore in _target.state_dict(). There are cases + where multiple entries in model's state_dict point to one entry in model's named_parameter. + E.g., model has multiple pointers pointing to one shared parameters (`encoder.embed_tokens.weight`, + `decoder.embed_tokens.weight` and `shared.weight` all points to `shared.weight + in T5 Huggingface implementation.). In these cases, ignore redundant entries. + cast_dtype: case the output state dict to a certain precision. + + Returns: + TargetModuleT: The modified target module with its state dictionary adjusted according to + the specified mappings and transformations. + + Raises: + ValueError: If there's a mismatch in shape between corresponding source and target parameters + or buffers. + RuntimeError: If the target state dictionary contains keys that are not present in the source + state dictionary after all transformations. + + Examples: + >>> source_module = nn.Linear(10, 5) + >>> target_module = nn.Linear(10, 5) + >>> mapping = {'weight': 'weights', 'bias': 'biases'} + @io.state_transform( + source_key="weight", + target_key="weights" + ) + def scale_weights(ctx): + ctx.target_state['weights'] = ctx.source_state['weight'] * 2 + return ctx + >>> transformed_target = apply_transforms( + ... source_module, target_module, mapping, [scale_weights] + ... 
) + >>> print(transformed_target.state_dict()['weights']) + + See Also: + - `TransformCTX`: For more details on the context object used in transformations. + - `StateDictTransform`: For creating complex transformations. + + Note: + This function is particularly useful when adapting models from different frameworks or + when consolidating models with different architectural changes. + """ + # Track dtypes to make sure they weren't modified during conversion. + target_orig_dtypes = extract_dtypes(target.named_parameters()) + + target_state = target.state_dict() + ctx = TransformCTX( + source=source, + source_state=source.state_dict(), + target=target, + target_state=target_state, + ) + + for key, val in mapping.items(): + logger.debug(f"Mapping {key} -> {val}") + ctx = StateDictTransform(key, val)(ctx) + + for transform in transforms: + logger.debug(f"Transforming {transform.source_key} -> {transform.target_key}") + ctx = transform(ctx) + + _params: Dict[str, nn.Parameter] = {} + for name, param in target.named_parameters(): + if name in target_state: + target_param = target_state[name] + if param.data.shape != target_param.shape: + raise ValueError( + f"Shape mismatch for parameter {name}: target shape {param.shape} vs " + f"converted source shape {target_param.shape}" + ) + + _params[name] = nn.Parameter(target_param, requires_grad=param.requires_grad) + target_state.pop(name) + else: + print(f"Unexpected key: {name} not in target model but is in source model.") + + for key, val in _params.items(): + _module, _key = target, key + if "." 
in key: + for part in key.split(".")[:-1]: + _module = getattr(_module, part) + _key = key.split(".")[-1] + + _module.register_parameter(_key, val) + + _buffers = {} + for name, buffer in target.named_buffers(): + if name in target_state: + if buffer.shape != target_state[name].shape: + raise ValueError(f"Shape mismatch for buffer {name}: {buffer.shape} vs {target_state[name].shape}") + + _buffers[name] = nn.Parameter(target_state[name], requires_grad=False) + target_state.pop(name) + + for key, val in _buffers.items(): + _module, _key = target, key + if "." in key: + for part in key.split(".")[:-1]: + _module = getattr(_module, part) + _key = key.split(".")[-1] + + _module.register_buffer(_key, val) + + keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys())) + keys = [key for key in keys if key not in state_dict_ignored_entries] + if len(keys) != 0: + raise RuntimeError(f"Additional keys: {keys} in target model but not in source model.") + + if hasattr(target, "tie_weights"): + target.tie_weights() + + meta_tensor_keys = [] + for name, param in target.named_parameters(): + if param.is_meta: + meta_tensor_keys.append(name) + + assert not meta_tensor_keys, ( + f"{meta_tensor_keys}\nThere are meta tensors in the model after conversion." + f"Did you forget to include these parameters in the mapping or transforms in `convert_state`?" + ) + + if cast_dtype: + logger.info(f"Casting model to {cast_dtype}...") + target.to(cast_dtype) + logger.info(f"Casting model to {cast_dtype} complete.") + else: + target_new_dtypes = extract_dtypes(target.named_parameters()) + for key in target_orig_dtypes.keys(): + if key in target_new_dtypes: # For tied weights, these parameters may disappear. 
+ assert target_orig_dtypes[key] == target_new_dtypes[key], ( + f"dtype mismatch for key {key}: {target_orig_dtypes[key]} vs {target_new_dtypes[key]}" + ) + + return target + + +def _default_transform(inp): + return inp + + +class StateDictTransform(Generic[F]): + """A transformation class for state dictionaries. + + Allows for flexible key matching and transformation of values between source and target state dictionaries. + + Attributes: + source_key: A string, tuple of strings, or a dictionary specifying the keys in the source + state dictionary to match. Wildcards (*) are supported. + target_key: A string or tuple of strings specifying the keys in the target state dictionary + to match. Wildcards (*) are supported. + transform: A callable that performs the transformation on matched keys' values. + + Examples: + >>> def example_transform(ctx, *args): + ... return sum(args) + >>> transform = StateDictTransform( + ... source_key="model.layers.*.self_attn.*_proj.weight", + ... target_key="decoder.layers.*.self_attention.linear_qkv.weight", + ... transform=example_transform + ... 
) + """ + + def __init__( + self, + source_key: Union[str, Tuple[str, ...], Dict[str, str]], + target_key: Union[str, Tuple[str, ...]], + transform: F = _default_transform, + ): + """Initialize the StateDictTransform.""" + self.source_key = source_key + self.target_key = target_key + self.transform = transform + + def __call__(self, ctx: TransformCTX) -> TransformCTX: + """Perform the transformation on the given context.""" + source_key = self.source_key + target_key = self.target_key + source_dict, target_dict = ctx.source_state, ctx.target_state + np.set_printoptions(threshold=10) + fn_params = dict(inspect.signature(self.transform).parameters) + fn_params.pop("ctx", None) + matched = False + if isinstance(source_key, (dict, tuple)): + if isinstance(source_key, tuple): + source_key_dict = {param: source_key[i] for i, param in enumerate(fn_params)} + else: + source_key_dict = source_key + source_matches_dict = {k: _match_keys(list(source_dict.keys()), v) for k, v in source_key_dict.items()} + target_matches = _match_keys(list(target_dict.keys()), target_key) + param_names = list(filter(lambda x: x in source_matches_dict, fn_params)) + source_matches = [ + source_matches_dict[v] if source_matches_dict[v].ndim > 0 else [source_matches_dict[v].item()] + for v in param_names + ] + target_matches = [target_matches if target_matches.ndim > 0 else [target_matches.item()]] + for layer_names_group in zip(*(source_matches + target_matches)): + # Wrap in a list if it's a single layer (ie non-expert) + if isinstance(layer_names_group[0], str): + layer_names_group = [[x] for x in layer_names_group] # noqa: PLW2901 + for layer_names in zip(*layer_names_group): + target_dict[layer_names[-1]] = self.call_transform( + ctx, **dict(zip(param_names, [source_dict[x] for x in layer_names[:-1]])) + ) + logger.debug(f"Matched (transform)! 
{layer_names_group=}") + matched = True + else: + source_keys = list(source_dict.keys()) + target_keys = list(target_dict.keys()) + + source_matches = _match_keys(source_keys, source_key) + if source_matches.size == 1 and source_matches == np.array(None): + raise ValueError(f"No matches found for source key: {source_key}") + + if isinstance(target_key, str): + target_matches = _match_keys(target_keys, target_key) + if target_matches.size == 1 and target_matches == np.array(None): + raise ValueError(f"No matches found for target key: {target_key}") + else: + if isinstance(target_key, dict): + raise ValueError("Target key must be a string or a tuple of strings.") + _matches = [_match_keys(target_keys, key) for key in target_key] + target_matches = np.stack(_matches, axis=-1) + + # Determine if we are dealing with multiple source matches or multiple target matches + multiple_sources = source_matches.ndim >= target_matches.ndim + accepts_var_args = any( + param.kind == param.VAR_POSITIONAL for param in inspect.signature(self.transform).parameters.values() + ) + + if multiple_sources: + for target_index, target_match in np.ndenumerate(target_matches): + try: + source_match = source_matches[target_index] + except IndexError as e: + logger.error(f"Enountered IndexError during transform.\n{source_matches=}\n{target_matches=}") + raise e + if accepts_var_args: + source_values = [source_dict[k] for k in source_match] + target_dict[target_match] = self.call_transform(ctx, *source_values) + else: + _source_match_list = [source_match] if isinstance(source_match, str) else list(source_match) + if len(fn_params) != len(_source_match_list): + raise ValueError( + f"Mismatch between source and target keys: {source_match} vs {target_match}" + ) + + kwargs = {param: source_dict[k] for param, k in zip(fn_params, _source_match_list)} + target_dict[target_match] = self.call_transform(ctx, **kwargs) + logger.debug(f"Matched (multi source)! 
{target_match=} {source_match=}") + matched = True + else: + for source_index, source_match in np.ndenumerate(source_matches): + target_match = target_matches[source_index] + source_values = ( + [source_dict[source_match]] + if np.isscalar(source_match) + else [source_dict[k] for k in source_match] + ) + if accepts_var_args: + outputs = self.call_transform(ctx, *source_values) + else: + kwargs = dict(zip(fn_params, source_values)) + outputs = self.call_transform(ctx, **kwargs) + + if isinstance(target_match, str): + target_dict[target_match] = outputs + else: + for i, t in enumerate(outputs): + target_dict[target_match[i]] = t + logger.debug(f"Matched (single source)! {target_match=} {source_match=}") + matched = True + if not matched: + logger.warning(f"No matches found for source key: {source_key=} {target_key=}") + return ctx + + def call_transform(self, ctx: TransformCTX, *args, **kwargs): + """Perform transform and check if the given args valid.""" + func_params = inspect.signature(self.transform).parameters + expected_num_args = len([p for p in func_params if p not in ["self", "ctx"]]) + provided_num_args = len(args) + len(kwargs) + accepts_var_args = any(param.kind == param.VAR_POSITIONAL for param in func_params.values()) + + if not accepts_var_args and provided_num_args != expected_num_args: + raise ValueError( + f"Expected {expected_num_args} arguments for the transformation function, but got {provided_num_args}." 
+ ) + + if "ctx" in func_params: + return self.transform(ctx, *args, **kwargs) + + return self.transform(*args, **kwargs) + + +def _match_keys(keys: List[str], pattern: str) -> np.ndarray: + escaped_pattern = "" + i = 0 + wildcard_positions = [] + while i < len(pattern): + if pattern[i : i + 2] == "**": + escaped_pattern += r"(.+)" # Match any characters including dots + wildcard_positions.append("**") + i += 2 + elif pattern[i] == "*": + escaped_pattern += r"([^.]+)" # Match any characters except dots + wildcard_positions.append("*") + i += 1 + else: + if pattern[i] == ".": + escaped_pattern += r"\." # Escape the dot + else: + escaped_pattern += pattern[i] + i += 1 + + regex_pattern = re.compile("^" + escaped_pattern + "$") + num_wildcards = len(wildcard_positions) + wildcard_matches = [[] for _ in range(num_wildcards)] + + for key in filter(lambda x: x is not None, keys): + match = regex_pattern.match(key) + if match: + for i, group in enumerate(match.groups()): + if group not in wildcard_matches[i]: + wildcard_matches[i].append(group) + + # Sort the wildcard matches to maintain consistent ordering + for i in range(len(wildcard_matches)): + wildcard_matches[i].sort(key=lambda x: int(x) if x.isdigit() else x) + + # Determine the shape of the output array based on the unique matches for each wildcard + shape = [len(matches) for matches in wildcard_matches] + + if len(wildcard_matches) == 0: + # If there is no wildcard matches, assuming it is a single match + shape = [1] + # Initialize an empty array with the determined shape + output_array = np.empty(shape, dtype=object) + + # Populate the array with the keys, now that we have the correct shape and ordering + for key in filter(lambda x: x is not None, keys): + match = regex_pattern.match(key) + if match: + # Convert match groups to indices based on their position in wildcard_matches + indices = [wildcard_matches[i].index(group) for i, group in enumerate(match.groups())] + output_array[tuple(indices)] = key # Place 
the key in the array based on the indices + + return output_array + + +@overload +def state_transform( + source_key: Union[str, Tuple[str, ...], Dict[str, str]], + target_key: Union[str, Tuple[str, ...]], +) -> Callable[[F], StateDictTransform[F]]: ... + + +@overload +def state_transform( + source_key: Union[str, Tuple[str, ...], Dict[str, str]], target_key: Union[str, Tuple[str, ...]], fn: F +) -> StateDictTransform[F]: ... + + +def state_transform( + source_key: Union[str, Tuple[str, ...], Dict[str, str]], + target_key: Union[str, Tuple[str, ...]], + fn: Optional[F] = None, +): + """Create a StateDictTransform instance with specified source and target keys, and a transformation function. + + Args: + source_key: A string, tuple of strings, or a dictionary specifying the keys in the source + state dictionary to match. Wildcards (*) are supported. + target_key: A string or tuple of strings specifying the keys in the target state dictionary + to match. Wildcards (*) are supported. + fn: An optional callable that performs the transformation on matched keys' values. If not + provided, the decorator can be used to wrap a function definition. + + Returns: + ------- + A StateDictTransform instance if `fn` is provided, otherwise returns a decorator that + takes a function and returns a StateDictTransform instance. + + Examples: + -------- + >>> @state_transform( + ... source_key="model.layers.*.self_attn.*_proj.weight", + ... target_key="decoder.layers.*.self_attention.linear_qkv.weight" + ... ) + ... def sum_transform(ctx, *args): + ... return sum(args) + """ + + def wrapper(fn) -> StateDictTransform: + return StateDictTransform(source_key, target_key, fn) + + if fn is None: + return wrapper + + return wrapper(fn) + + +class TransformFns: + """A collection of common functions used in state dict transformation.""" + + @staticmethod + def split_qkv(ctx: TransformCTX, linear_qkv: torch.Tensor): + """Split interleave-concatenated qkv to q, k, v. 

        Example: export layer linear_qkv to HF {q|k|v}_proj
        """
        target_config = ctx.target.config

        head_num = target_config.num_attention_heads
        num_query_groups = target_config.num_key_value_heads
        heads_per_group = head_num // num_query_groups
        hidden_size = target_config.hidden_size
        head_size = hidden_size // head_num
        # Interleaved layout: per group, heads_per_group query rows then 1 K row and 1 V row.
        qkv_total_dim = head_num + 2 * num_query_groups

        linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, -1])
        # when converting base model (linear_qkv), hidden size = megatron_config.hidden_size
        # when converting lora (linear_qkv.adapter.linear_out), hidden size = lora_r
        hidden_size = linear_qkv.size(-1)
        q_slice = torch.cat(
            [
                torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
                for i in range(num_query_groups)
            ]
        )
        k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
        v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

        q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu()
        k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu()
        v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu()

        return q_proj, k_proj, v_proj

    @staticmethod
    def split_qkv_bias(ctx: TransformCTX, qkv_bias: torch.Tensor):
        """Split interleave-concatenated qkv bias to separate q, k, v bias.

        Example: export layer linear_qkv bias to HF {q|k|v}_proj bias
        """
        # NOTE(review): this reads ctx.source.config (num_query_groups/kv_channels)
        # while split_qkv reads ctx.target.config (num_key_value_heads) — confirm the
        # intended direction/config schema for each conversion path.
        megatron_config = ctx.source.config

        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        heads_per_group = head_num // num_query_groups
        head_size = megatron_config.kv_channels
        qkv_total_dim = head_num + 2 * num_query_groups

        qkv_bias = qkv_bias.reshape([qkv_total_dim, head_size])
        q_slice = torch.cat(
            [
                torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
                for i in range(num_query_groups)
            ]
        )
        k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
        v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

        q_bias = qkv_bias[q_slice].reshape(-1).cpu()
        k_bias = qkv_bias[k_slice].reshape(-1).cpu()
        v_bias = qkv_bias[v_slice].reshape(-1).cpu()

        return q_bias, k_bias, v_bias

    @staticmethod
    def merge_qkv_concat(ctx: TransformCTX, qkv: torch.Tensor):
        """Merge naively concatenated q, k, v to interleave-concatenated qkv.

        Example: import HF qkv to layer linear_qkv
        """
        megatron_config = ctx.target.config
        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        head_size = megatron_config.kv_channels
        # Split the naive [Q; K; V] concatenation, then re-pack interleaved per group.
        q, k, v = qkv.split([head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0)
        return TransformFns.merge_qkv(ctx, q, k, v)

    @staticmethod
    def merge_qkv(ctx: TransformCTX, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Merge q, k, v to interleave-concatenated qkv.

        Example: import HF {q|k|v}_proj to layer linear_qkv
        """
        target_config = ctx.target.config

        head_num = target_config.num_attention_heads
        num_query_groups = target_config.num_key_value_heads
        heads_per_group = head_num // num_query_groups
        hidden_size = target_config.hidden_size
        head_size = hidden_size // head_num
        old_tensor_shape = q.size()
        new_q_tensor_shape = (head_num, head_size, *old_tensor_shape[1:])
        new_kv_tensor_shape = (num_query_groups, head_size, *old_tensor_shape[1:])

        q = q.view(*new_q_tensor_shape)
        k = k.view(*new_kv_tensor_shape)
        v = v.view(*new_kv_tensor_shape)

        # Interleave per query group: heads_per_group Q heads, then the group's K and V head.
        qkv_weights_l = []
        for i in range(num_query_groups):
            qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :])
            qkv_weights_l.append(k[i : i + 1, :, :])
            qkv_weights_l.append(v[i : i + 1, :, :])
        qkv_weights = torch.cat(qkv_weights_l)
        assert qkv_weights.ndim == 3, qkv_weights.shape
        assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape
        assert qkv_weights.shape[1] == head_size, qkv_weights.shape
        assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape

        qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size])

        return qkv_weights

    @staticmethod
    def merge_qkv_bias_concat(ctx: TransformCTX, qkv_bias: torch.Tensor):
        """Merge naively concatenated q, k, v bias to interleave-concatenated qkv bias.

        Example: import HF qkv bias to layer linear_qkv bias
        """
        megatron_config = ctx.target.config
        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        head_size = megatron_config.kv_channels
        qb, kb, vb = qkv_bias.split(
            [head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0
        )
        return TransformFns.merge_qkv_bias(ctx, qb, kb, vb)

    @staticmethod
    def merge_qkv_bias(ctx: TransformCTX, qb: torch.Tensor, kb: torch.Tensor, vb: torch.Tensor):
        """Merge q, k, v bias to interleave-concatenated qkv bias.

        Example: import HF {q|k|v}_proj bias to layer linear_qkv bias
        """
        megatron_config = ctx.target.config

        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        heads_per_group = head_num // num_query_groups
        head_size = megatron_config.kv_channels

        new_q_tensor_shape = (head_num, head_size)
        new_kv_tensor_shape = (num_query_groups, head_size)

        qb = qb.view(*new_q_tensor_shape)
        kb = kb.view(*new_kv_tensor_shape)
        vb = vb.view(*new_kv_tensor_shape)

        # Same per-group interleaving as merge_qkv, built up by repeated concatenation.
        qkv_bias = torch.empty((0, head_size)).type_as(qb)
        for i in range(num_query_groups):
            qkv_bias = torch.cat((qkv_bias, qb[i * heads_per_group : (i + 1) * heads_per_group, :]))
            qkv_bias = torch.cat((qkv_bias, kb[i : i + 1, :]))
            qkv_bias = torch.cat((qkv_bias, vb[i : i + 1, :]))
        qkv_bias = qkv_bias.reshape(
            [
                head_size * (head_num + 2 * num_query_groups),
            ]
        )
        return qkv_bias

    @staticmethod
    def merge_fc1(gate: torch.Tensor, up: torch.Tensor):
        """Merge gate and up proj into concatenated fc1.

        Example: import HF {gate|up}_proj to layer linear_fc1
        """
        return torch.cat((gate, up), dim=0)

    @staticmethod
    def split_fc1(linear_fc1: torch.Tensor):
        """Split concatenated fc1 to gate and up proj.
+ + Example: export layer linear_fc1 to HF {gate|up}_proj + """ + gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0) + return gate_proj, up_proj + + @staticmethod + def duplicate2(param: torch.Tensor): + """Duplicate the source parameter to two target parameters. + + Example: export Performant LoRA linear_fc1.adapter.linear_in to HF {gate|up}_proj.lora_A + """ + return param, param + + @staticmethod + def duplicate3(param: torch.Tensor): + """Duplicate the source parameter to three target parameters. + + Example: export Performant LoRA linear_qkv.adapter.linear_in to HF {q|k|v}_proj.lora_A + """ + return param, param, param + + @staticmethod + def prune_padding(ctx: TransformCTX, embedding: torch.Tensor): + """Prune the embedding size to vocab size. + + Example: export embedding/output layer to HF with non-padded vocab size + """ + megatron_config = ctx.target.config + return embedding[: megatron_config.vocab_size, :] + + +def extract_dtypes(ckpt): + """Extract dtype from the input iterator. + + ckpt can be module.named_parameters or module.state_dict().items() + """ + dtypes = {} + for key, val in ckpt: + if hasattr(val, "dtype"): + dtypes[key] = val.dtype + elif hasattr(val, "data") and hasattr(val.data, "dtype"): + # if it's ShardedTensor populated with data. + dtypes[key] = val.data.dtype + return dtypes diff --git a/bionemo-recipes/models/esmc/tests/common/README.md b/bionemo-recipes/models/esmc/tests/common/README.md new file mode 100644 index 0000000000..1259fd297c --- /dev/null +++ b/bionemo-recipes/models/esmc/tests/common/README.md @@ -0,0 +1,64 @@ +# BioNeMo Common Test Library + +Shared test infrastructure for BioNeMo models. One base class, **BaseModelTest**: inherit and implement the abstract methods to get the full test suite (golden values, conversion, FP8, meta init, smoke tests). 
+ +## Structure + +``` +tests/common/ +├── __init__.py # Public API exports +├── test_modeling_common.py # BaseModelTest, TestTolerances +├── fixtures.py # input_format, fp8_recipe, te_attn_backend, etc. +└── README.md +``` + +**Required:** In your top-level `tests/conftest.py` (e.g. `bionemo-recipes/models/esm2/tests/conftest.py`), add: + +```python +pytest_plugins = ["tests.common.fixtures"] +``` + +Without this, parametrized fixtures will not load. + +## BaseModelTest + +Inherit from `BaseModelTest` and implement: + +| Method | Returns | Description | +| ------------------------------------------------- | ------------------------- | ----------------------------------------------- | +| `get_model_class()` | `Type[PreTrainedModel]` | TE model class | +| `get_tokenizer()` | `PreTrainedTokenizer` | Tokenizer | +| `get_config_class()` | `Type[PretrainedConfig]` | Config class | +| `get_upstream_model_id()` | `str` | HF model ID | +| `get_upstream_model_revision()` | `Optional[str]` | Revision or None | +| `get_upstream_model_class()` | `Type[PreTrainedModel]` | HF model class | +| `get_layer_path(model)` | `List[nn.Module]` | Transformer layers | +| `get_test_input_data(format, pad_to_multiple_of)` | `Dict[str, torch.Tensor]` | Inputs on CUDA; `format` is `"bshd"` or `"thd"` | +| `get_hf_to_te_converter()` | `Callable` | HF → TE | +| `get_te_to_hf_converter()` | `Callable` | TE → HF | + +**Optional overrides:** `get_tolerances()` → `TestTolerances`, `get_attn_input_formats()`, `get_reference_model_no_weights()`. + +**Helpers:** `create_test_config()`, `get_reference_model()`, `get_reference_model_no_weights()`, `compare_outputs()`, `verify_model_parameters_initialized_correctly()`, `get_converted_te_model_checkpoint()`, `get_converted_te_model()`. + +**Tests included:** Meta/CUDA init (`test_cuda_init`, `test_meta_init`, …), smoke (parametrized by `input_format`), conversion, golden values (BSHD + THD), FP8 (parametrized by `fp8_recipe`, `input_format`). 
+ +## TestTolerances + +Dataclass in `test_modeling_common.py`. Override `get_tolerances()` to return a custom instance. Fields: `golden_value_*`, `cp_*`, `fp8_*`, `init_*` (see class definition). + +## Fixtures (fixtures.py) + +| Fixture | Description | +| ----------------- | ----------------------------------- | +| `input_format` | `"bshd"` / `"thd"` | +| `fp8_recipe` | FP8 recipe (skipped if unsupported) | +| `te_attn_backend` | `"flash_attn"` / `"fused_attn"` | +| `unused_tcp_port` | For distributed tests | +| `use_te_debug` | Autouse: `NVTE_DEBUG=1` | + +## Usage + +1. Create a class inheriting from `BaseModelTest` and implement the abstract methods (see `esm2/tests/test_modeling_esm_te.py` for a full example). +2. Add `pytest_plugins = ["tests.common.fixtures"]` to `tests/conftest.py`. +3. Run `pytest tests/test_modeling__te.py -v`. diff --git a/bionemo-recipes/models/esmc/tests/common/__init__.py b/bionemo-recipes/models/esmc/tests/common/__init__.py new file mode 100644 index 0000000000..dee29e4297 --- /dev/null +++ b/bionemo-recipes/models/esmc/tests/common/__init__.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common test utilities for BioNeMo models.
+
+This package provides reusable test infrastructure following HuggingFace
+transformers patterns, including:
+
+- BaseModelTest: Abstract base class for model-specific test configuration
+  that also carries all of the common test methods
+- TestTolerances: Dataclass for model-specific numerical tolerances
+- Distributed testing utilities for multi-GPU tests
+- Shared fixtures for common test requirements
+
+Example usage:
+
+    ```python
+    from tests.common import BaseModelTest, TestTolerances
+
+    class ESM2ModelTester(BaseModelTest):
+        def get_model_class(self):
+            return NVEsmForMaskedLM
+        # ... implement other abstract methods
+    ```
+"""
+
+from .test_modeling_common import HAS_DATA_CENTER_GPU, BaseModelTest, TestTolerances
+
+
+__all__ = [
+    "HAS_DATA_CENTER_GPU",
+    "BaseModelTest",
+    "TestTolerances",
+]
diff --git a/bionemo-recipes/models/esmc/tests/common/fixtures.py b/bionemo-recipes/models/esmc/tests/common/fixtures.py
new file mode 100644
index 0000000000..d786d76506
--- /dev/null
+++ b/bionemo-recipes/models/esmc/tests/common/fixtures.py
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared test fixtures for BioNeMo models.""" + +import os +import socket + +import pytest +from transformer_engine.common import recipe as recipe_module +from transformer_engine.pytorch import fp8 +from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends + + +@pytest.fixture +def unused_tcp_port() -> int: + """Get an unused TCP port for distributed testing. + + Returns: + An available TCP port number. 
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        return s.getsockname()[1]
+
+
+@pytest.fixture(autouse=True)
+def use_te_debug():
+    """Auto-use fixture to enable TransformerEngine debugging.
+
+    This fixture automatically enables debug mode for TransformerEngine
+    in all tests for better error messages.
+    """
+    import os
+
+    os.environ["NVTE_DEBUG"] = "1"
+    yield
+    del os.environ["NVTE_DEBUG"]
+
+
+ALL_RECIPES = [
+    recipe_module.DelayedScaling(),
+    recipe_module.Float8CurrentScaling(),
+    recipe_module.Float8BlockScaling(),
+    recipe_module.MXFP8BlockScaling(),
+    recipe_module.NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True),
+]
+
+
+def _check_recipe_support(recipe: recipe_module.Recipe):
+    """Check if a recipe is supported and return (supported, reason)."""
+    if isinstance(recipe, recipe_module.DelayedScaling):
+        recipe_supported, reason = fp8.check_fp8_support()
+    elif isinstance(recipe, recipe_module.Float8CurrentScaling):
+        recipe_supported, reason = fp8.check_fp8_support()
+    elif isinstance(recipe, recipe_module.Float8BlockScaling):
+        recipe_supported, reason = fp8.check_fp8_block_scaling_support()
+    elif isinstance(recipe, recipe_module.MXFP8BlockScaling):
+        recipe_supported, reason = fp8.check_mxfp8_support()
+    elif isinstance(recipe, recipe_module.NVFP4BlockScaling):
+        recipe_supported, reason = fp8.check_nvfp4_support()
+    else:
+        recipe_supported = False
+        reason = "Unsupported recipe"
+    return recipe_supported, reason
+
+
+def parametrize_recipes_with_support(recipes):
+    """Generate pytest.param objects with xfail marks for unsupported recipes."""
+    parametrized_recipes = []
+    for recipe in recipes:
+        recipe_supported, reason = _check_recipe_support(recipe)
+        parametrized_recipes.append(
+            pytest.param(
+                recipe,
+                id=recipe.__class__.__name__,
+                marks=pytest.mark.xfail(
+                    condition=not recipe_supported,
+                    reason=reason,
+                ),
+            )
+        )
+
+    return parametrized_recipes
+
+
+@pytest.fixture(params=parametrize_recipes_with_support(ALL_RECIPES))
+def fp8_recipe(request):
+    """Fixture to parametrize the FP8 recipe."""
+    return request.param
+
+
+@pytest.fixture(params=["bshd", "thd"])
+def input_format(request):
+    """Fixture to parametrize the input format."""
+    return request.param
+
+
+@pytest.fixture(params=["flash_attn", "fused_attn"])
+def te_attn_backend(request):
+    """Fixture to parametrize the attention implementation."""
+    if request.param == "flash_attn":
+        os.environ["NVTE_FUSED_ATTN"] = "0"
+        os.environ["NVTE_FLASH_ATTN"] = "1"
+        _attention_backends["backend_selection_requires_update"] = True
+
+    else:
+        os.environ["NVTE_FUSED_ATTN"] = "1"
+        os.environ["NVTE_FLASH_ATTN"] = "0"
+        _attention_backends["backend_selection_requires_update"] = True
+
+    yield request.param
+
+    del os.environ["NVTE_FUSED_ATTN"]
+    del os.environ["NVTE_FLASH_ATTN"]
+    _attention_backends["backend_selection_requires_update"] = True
diff --git a/bionemo-recipes/models/esmc/tests/common/test_modeling_common.py b/bionemo-recipes/models/esmc/tests/common/test_modeling_common.py
new file mode 100644
index 0000000000..d7541c4d28
--- /dev/null
+++ b/bionemo-recipes/models/esmc/tests/common/test_modeling_common.py
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Common test class for BioNeMo models, following HuggingFace transformers patterns.""" + +import gc +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Dict, List, Literal, Type + +import pytest +import torch +import transformer_engine.pytorch +from torch import nn +from transformer_engine.pytorch import QuantizedTensor +from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, set_seed + + +HAS_DATA_CENTER_GPU = any( + gpu_name in torch.cuda.get_device_name(0).upper() for gpu_name in ["H100", "H200", "B100", "B200", "B300"] +) + + +@dataclass +class TestTolerances: + """Model-specific test tolerances for numerical comparisons.""" + + # Golden value test tolerances + golden_value_loss_atol: float = 1e-2 + golden_value_loss_rtol: float = 1e-3 + golden_value_logits_atol: float = 2.0 + golden_value_logits_rtol: float = 1e-4 + golden_value_hidden_states_atol: float = 0.1 + golden_value_hidden_states_rtol: float = 0.05 + + # Context parallel test tolerances + cp_loss_atol: float = 0.1 + cp_loss_rtol: float = 0.05 + cp_logits_atol: float = 1.0 + cp_logits_rtol: float = 0.1 + cp_gradients_atol: float = 0.1 + cp_gradients_rtol: float = 0.1 + + # FP8 test tolerances + fp8_loss_atol: float = 0.1 + fp8_loss_rtol: float = 0.05 + fp8_logits_atol: float = 5.0 + fp8_logits_rtol: float = 0.1 + + # Meta device initialization tolerances + init_mean_atol: float = 1e-3 + init_mean_rtol: float = 1e-4 + init_std_atol: float = 1e-3 + init_std_rtol: float = 1e-4 + + +class BaseModelTest(ABC): + """Abstract base class for testing BioNeMo models. + + This class provides common test utilities and defines the interface that + model-specific testers must implement. It follows the pattern used in + HuggingFace transformers for model testing. 
+
+    Subclasses must implement all abstract methods to provide model-specific
+    configuration, data preparation, and conversion functions.
+
+    Example:
+        ```python
+        class ESM2ModelTester(BaseModelTest):
+            def get_model_class(self):
+                return NVEsmForMaskedLM
+
+            def get_config_class(self):
+                return NVEsmConfig
+
+            def get_upstream_model_id(self):
+                return "facebook/esm2_t6_8M_UR50D"
+
+            # ... implement other abstract methods
+        ```
+    """
+
+    @abstractmethod
+    def get_model_class(self) -> Type[PreTrainedModel]:
+        """Return the TransformerEngine model class to test.
+
+        Returns:
+            The model class (e.g., NVEsmForMaskedLM, NVLlamaForCausalLM).
+        """
+        pass
+
+    @abstractmethod
+    def get_tokenizer(self) -> PreTrainedTokenizer:
+        """Return the tokenizer for the model.
+
+        Returns:
+            The tokenizer (e.g., AutoTokenizer).
+        """
+        pass
+
+    @abstractmethod
+    def get_config_class(self) -> Type[PretrainedConfig]:
+        """Return the config class for the model.
+
+        Returns:
+            The config class (e.g., NVEsmConfig, NVLlamaConfig).
+        """
+        pass
+
+    @abstractmethod
+    def get_upstream_model_id(self) -> str:
+        """Return the HuggingFace model ID for the reference model.
+
+        Returns:
+            Model ID string (e.g., "facebook/esm2_t6_8M_UR50D").
+        """
+        pass
+
+    @abstractmethod
+    def get_upstream_model_revision(self) -> str:
+        """Return the specific revision/commit hash for the upstream model.
+
+        Returns:
+            Revision string, or None to use the latest revision.
+        """
+        pass
+
+    @abstractmethod
+    def get_upstream_model_class(self) -> Type[PreTrainedModel]:
+        """Return the HuggingFace reference model class.
+
+        Returns:
+            The HF model class (e.g., AutoModelForMaskedLM, AutoModelForCausalLM).
+        """
+        pass
+
+    @abstractmethod
+    def get_layer_path(self, model: PreTrainedModel) -> List[nn.Module]:
+        """Return the list of transformer layers in the model.
+
+        Args:
+            model: The model instance.
+
+        Returns:
+            List of transformer layer modules.
+ + Example: + For ESM2: model.esm.encoder.layers + For LLaMA3: model.model.layers + """ + pass + + @abstractmethod + def get_test_input_data( + self, + format: Literal["bshd", "thd"] = "bshd", + pad_to_multiple_of: int | None = None, + ) -> Dict[str, torch.Tensor]: + """Prepare test input data for the model. + + Args: + format: Whether to use sequence packing (THD) or bshd format. + + Returns: + Dictionary of input tensors (input_ids, attention_mask, etc.). + """ + pass + + @abstractmethod + def get_hf_to_te_converter(self) -> Callable: + """Return the function that converts HF model to TE model. + + Returns: + Conversion function with signature: (hf_model, **kwargs) -> te_model + """ + pass + + @abstractmethod + def get_te_to_hf_converter(self) -> Callable: + """Return the function that converts TE model to HF model. + + Returns: + Conversion function with signature: (te_model, **kwargs) -> hf_model + """ + pass + + def get_tolerances(self) -> TestTolerances: + """Return test tolerances for this model. + + Override this method to provide model-specific tolerances. + + Returns: + TestTolerances instance with appropriate values. + """ + return TestTolerances() + + def get_attn_input_formats(self) -> List[str]: + """Return supported attention input formats. + + Returns: + List of format strings (e.g., ["bshd", "thd"]). + """ + return ["bshd"] + + def verify_model_parameters_initialized_correctly( + self, + model: PreTrainedModel, + atol: float | None = None, + rtol: float | None = None, + should_be_fp8: bool = False, + ) -> None: + """Verify that model parameters are initialized correctly. + + This can be overridden for models that use non-standard weight initialization. + + This checks that: + 1. All parameters are on CUDA device + 2. Embeddings have correct mean and std + 3. Linear layers have correct weight/bias initialization + 4. LayerNorm parameters are initialized correctly + 5. FP8 quantization is applied if requested + + Args: + model: The model to verify. 
+ atol: Absolute tolerance for comparisons (uses default if None). + rtol: Relative tolerance for comparisons (uses default if None). + should_be_fp8: Whether to expect FP8 quantized weights. + """ + config = model.config + tolerances = self.get_tolerances() + + if atol is None: + atol = tolerances.init_mean_atol + if rtol is None: + rtol = tolerances.init_mean_rtol + + # Verify all parameters are on CUDA + for name, parameter in model.named_parameters(): + assert str(parameter.device).startswith("cuda"), f"Parameter {name} is not on the cuda device" + + # Verify initialization for each module type + for name, module in model.named_modules(): + + def msg(x): + return f"Mismatch in module {name}: {x}" + + if isinstance(module, torch.nn.Embedding): + torch.testing.assert_close(module.weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) + torch.testing.assert_close( + module.weight.std().item(), + config.initializer_range, + atol=tolerances.init_std_atol, + rtol=tolerances.init_std_rtol, + msg=msg, + ) + + elif isinstance(module, transformer_engine.pytorch.Linear): + torch.testing.assert_close(module.weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) + torch.testing.assert_close( + module.weight.std().item(), + config.initializer_range, + atol=tolerances.init_std_atol, + rtol=tolerances.init_std_rtol, + msg=msg, + ) + if module.bias is not None: + torch.testing.assert_close(module.bias, torch.zeros_like(module.bias), msg=msg) + + if should_be_fp8: + if f"{name}.weight" in set(model._tied_weights_keys): + continue # Skip tied weights + elif hasattr(model, "_do_not_quantize") and name in model._do_not_quantize: + continue # Skip weights that should be kept in bf16 + assert isinstance(module.weight, QuantizedTensor), f"Module {name} weight is not a Float8Tensor" + + elif isinstance(module, transformer_engine.pytorch.LayerNorm): + torch.testing.assert_close(module.weight, torch.ones_like(module.weight), msg=msg) + torch.testing.assert_close(module.bias, 
torch.zeros_like(module.bias), msg=msg) + + elif isinstance(module, torch.nn.LayerNorm): + torch.testing.assert_close(module.weight, torch.ones_like(module.weight), msg=msg) + if module.bias is not None: + torch.testing.assert_close(module.bias, torch.zeros_like(module.bias), msg=msg) + + def create_test_config(self, **kwargs) -> PretrainedConfig: + """Create a test configuration with optional overrides. + + Args: + **kwargs: Configuration parameters to override. + + Returns: + Configuration instance. + """ + config_class = self.get_config_class() + upstream_id = self.get_upstream_model_id() + revision = self.get_upstream_model_revision() + return config_class.from_pretrained(upstream_id, revision=revision, **kwargs) + + def get_reference_model( + self, + dtype: torch.dtype = torch.bfloat16, + attn_implementation: str = "flash_attention_2", + ) -> PreTrainedModel: + """Load the reference HuggingFace model. + + Args: + dtype: Data type for the model. + device: Device to load model on. + attn_implementation: Attention implementation to use. + + Returns: + The loaded reference model. 
+ """ + upstream_class = self.get_upstream_model_class() + upstream_id = self.get_upstream_model_id() + revision = self.get_upstream_model_revision() + + kwargs = { + "dtype": dtype, + "attn_implementation": attn_implementation, + } + if revision is not None: + kwargs["revision"] = revision + + model = upstream_class.from_pretrained(upstream_id, **kwargs) + model.to("cuda") + return model + + def get_reference_model_no_weights(self) -> PreTrainedModel: + """Load the reference HuggingFace model with random weights.""" + return self.get_upstream_model_class()( + AutoConfig.from_pretrained( + self.get_upstream_model_id(), + dtype=torch.float32, + revision=self.get_upstream_model_revision(), + ) + ) + + def compare_outputs( + self, + te_outputs, + hf_outputs, + input_data: Dict[str, torch.Tensor], + compare_loss: bool = True, + compare_logits: bool = True, + compare_hidden_states: bool = False, + ) -> None: + """Compare outputs from TE and HF models. + + Args: + te_outputs: Outputs from TransformerEngine model. + hf_outputs: Outputs from HuggingFace model. + input_data: Input data dictionary (for attention mask). + compare_loss: Whether to compare loss values. + compare_logits: Whether to compare logits. + compare_hidden_states: Whether to compare hidden states. 
+ """ + tolerances = self.get_tolerances() + + if compare_loss and hasattr(te_outputs, "loss") and hasattr(hf_outputs, "loss"): + torch.testing.assert_close( + te_outputs.loss, + hf_outputs.loss, + atol=tolerances.golden_value_loss_atol, + rtol=tolerances.golden_value_loss_rtol, + msg=lambda x: f"Loss mismatch between TE and HF models: {x}", + ) + + if compare_logits and hasattr(te_outputs, "logits") and hasattr(hf_outputs, "logits"): + # Only compare logits where attention mask is True + if "attention_mask" in input_data: + mask = input_data["attention_mask"].to(bool) + torch.testing.assert_close( + te_outputs.logits[mask], + hf_outputs.logits[mask], + atol=tolerances.golden_value_logits_atol, + rtol=tolerances.golden_value_logits_rtol, + msg=lambda x: f"Logits mismatch between TE and HF models: {x}", + ) + else: + torch.testing.assert_close( + te_outputs.logits, + hf_outputs.logits, + atol=tolerances.golden_value_logits_atol, + rtol=tolerances.golden_value_logits_rtol, + msg=lambda x: f"Logits mismatch between TE and HF models: {x}", + ) + + if compare_hidden_states and hasattr(te_outputs, "hidden_states") and hasattr(hf_outputs, "hidden_states"): + for i, (te_hidden, hf_hidden) in enumerate(zip(te_outputs.hidden_states, hf_outputs.hidden_states)): + torch.testing.assert_close( + te_hidden, + hf_hidden, + atol=tolerances.golden_value_hidden_states_atol, + rtol=tolerances.golden_value_hidden_states_rtol, + msg=lambda x: f"Hidden states mismatch at layer {i}: {x}", + ) + + @pytest.fixture(autouse=True, scope="function") + def clear_gpu_memory(self): + """Clear GPU memory before and after each test to prevent OOM from fragmentation.""" + gc.collect() + torch.cuda.empty_cache() + yield + gc.collect() + torch.cuda.empty_cache() + + @pytest.fixture(autouse=True, scope="function") + def set_seed(self): + set_seed(42) + + @pytest.fixture(autouse=True, scope="function") + def reset_fp8_context(self): + """Make sure we clean up the FP8 context after each test.""" + 
FP8GlobalStateManager.reset() + + # ==================== Forward and Backward Smoke Tests ==================== + + def test_smoke_forward_pass(self, input_format): + model_class = self.get_model_class() + config = self.create_test_config(attn_input_format=input_format) + + model = model_class(config) + model.to(torch.bfloat16) + model.to("cuda") + + # Prepare input data + input_data = self.get_test_input_data(input_format) + + # Forward pass with output_hidden_states + with torch.no_grad(): + outputs = model(**input_data, output_hidden_states=True) + + # Verify outputs + assert outputs.logits is not None, "Model should output logits" + assert outputs.hidden_states is not None, "Model should output hidden states when requested" + assert len(outputs.hidden_states) == config.num_hidden_layers + 1, ( + f"Expected {config.num_hidden_layers + 1} hidden states, got {len(outputs.hidden_states)}" + ) + + def test_smoke_backward_pass(self, input_format): + """Smoke test: backward pass.""" + model_class = self.get_model_class() + config = self.create_test_config(attn_input_format=input_format) + + model = model_class(config) + model.to(torch.bfloat16) + model.to("cuda") + + # Prepare input data + input_data = self.get_test_input_data(input_format) + + # Forward pass + outputs = model(**input_data, output_hidden_states=True) + + # Backward pass + outputs.logits.mean().backward() + + # Verify all parameters have gradients + for param in model.parameters(): + if param.requires_grad: + assert param.grad is not None, "All trainable parameters should have gradients after backward pass" + + def test_smoke_model_with_loss(self, input_format): + """Smoke test: model forward pass with labels produces loss.""" + model_class = self.get_model_class() + config = self.create_test_config(attn_input_format=input_format) + + model = model_class(config) + model.to(torch.bfloat16) + model.to("cuda") + + # Prepare input data with labels + input_data = self.get_test_input_data(input_format) + + # 
Ensure labels are present + if "labels" not in input_data: + input_data["labels"] = input_data["input_ids"].clone() + + # Forward pass + with torch.no_grad(): + outputs = model(**input_data) + + # Verify loss is computed + assert outputs.loss is not None, "Model should compute loss when labels are provided" + assert outputs.loss.item() > 0, "Loss should be positive" + + def test_forward_and_backward(self, input_format): + """Test that model can perform forward and backward passes.""" + model_class = self.get_model_class() + config = self.create_test_config(attn_input_format=input_format) + + model = model_class(config) + model.to(torch.bfloat16) + model.to("cuda") + + # Prepare input data + input_data = self.get_test_input_data(input_format) + + # Add labels for loss computation + if "labels" not in input_data: + input_data["labels"] = input_data["input_ids"].clone() + + # Forward pass + outputs = model(**input_data) + loss = outputs.loss + + # Backward pass + loss.backward() + + # Verify gradients exist + for name, param in model.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"Parameter {name} has no gradient" + + # ==================== Conversion Tests ==================== + + def test_convert_hf_to_te(self): + """Test that HF model can be converted to TE format.""" + # Load reference HF model + model_hf_original = self.get_reference_model_no_weights() + # Convert to TE + convert_fn = self.get_hf_to_te_converter() + model_te = convert_fn(model_hf_original) + + # Verify model structure + assert model_te is not None + assert isinstance(model_te, self.get_model_class()) + + def test_convert_te_to_hf(self): + """Test that TE model can be converted back to HF format.""" + # Load reference HF model + model_hf_original = self.get_reference_model_no_weights() + # Convert to TE + hf_to_te_fn = self.get_hf_to_te_converter() + model_te = hf_to_te_fn(model_hf_original) + + # Convert back to HF + te_to_hf_fn = self.get_te_to_hf_converter() + 
model_hf_converted = te_to_hf_fn(model_te) + + # Verify model structure + assert model_hf_converted is not None + assert isinstance(model_hf_converted, self.get_upstream_model_class()) + + def test_convert_te_to_hf_roundtrip(self): + """Test that HF → TE → HF conversion preserves weights.""" + # Load reference HF model + model_hf_original = self.get_reference_model_no_weights() + original_state_dict = model_hf_original.state_dict() + + # Convert to TE and back + hf_to_te_fn = self.get_hf_to_te_converter() + te_to_hf_fn = self.get_te_to_hf_converter() + + model_te = hf_to_te_fn(model_hf_original) + model_hf_converted = te_to_hf_fn(model_te) + converted_state_dict = model_hf_converted.state_dict() + + # Compare state dicts + assert set(original_state_dict.keys()) == set(converted_state_dict.keys()), "State dict keys don't match" + + for key in original_state_dict.keys(): + original_param = original_state_dict[key] + converted_param = converted_state_dict[key] + + # Convert both to the same dtype for comparison (use the original dtype) + if original_param.dtype != converted_param.dtype: + converted_param = converted_param.to(original_param.dtype) + + torch.testing.assert_close( + original_param, + converted_param, + atol=1e-5, + rtol=1e-5, + msg=f"Mismatch in parameter {key} after roundtrip conversion", + ) + + def test_convert_config(self): + """Test that config can be converted between HF and TE formats.""" + upstream_id = self.get_upstream_model_id() + revision = self.get_upstream_model_revision() + + # Load HF config + from transformers import AutoConfig + + kwargs = {} + if revision is not None: + kwargs["revision"] = revision + hf_config = AutoConfig.from_pretrained(upstream_id, **kwargs) + + # Get TE config class + te_config_class = self.get_config_class() + + # Convert to TE config + te_config = te_config_class(**hf_config.to_dict()) + + # Verify key attributes match + assert te_config.hidden_size == hf_config.hidden_size + assert te_config.num_hidden_layers 
== hf_config.num_hidden_layers + assert te_config.num_attention_heads == hf_config.num_attention_heads + + @pytest.fixture(scope="class", autouse=True) + def _set_tmpdir(self, tmp_path_factory): + """Make sure we can see the saved te checkpoint as a class-scoped fixture.""" + # set on the class, visible as self._tmp_dir + type(self)._tmp_dir = tmp_path_factory.mktemp(self.__class__.__name__) + + def get_converted_te_model_checkpoint(self) -> Path: + """Get the path to the converted TE model checkpoint. + + This method manages GPU memory carefully to support large models: + 1. Load and convert the HF model + 2. Free the HF model before saving + 3. Move TE model to CPU before saving (save_pretrained clones state dict internally) + """ + model_hf = self.get_reference_model() + convert_fn = self.get_hf_to_te_converter() + model_te = convert_fn(model_hf) + + # Free source model to reduce peak GPU memory + del model_hf + gc.collect() + torch.cuda.empty_cache() + + # Move to CPU before saving - save_pretrained internally clones the state dict, + # which would double GPU memory usage and OOM for large models. + model_te.to("cpu") + + checkpoint_path: Path = self._tmp_dir / "converted_te_model" + model_te.save_pretrained(checkpoint_path) + + del model_te + gc.collect() + + return checkpoint_path + + def get_converted_te_model(self, **kwargs) -> PreTrainedModel: + """Get the converted TE model. + + This shouldn't get called before the checkpoint tests are run in case they're broken. + """ + checkpoint_path = self.get_converted_te_model_checkpoint() + model_te = self.get_model_class().from_pretrained(checkpoint_path, **kwargs) + model_te.to("cuda") + return model_te + + # ==================== Golden Value Tests ==================== + + def test_golden_values(self): + """Test that TE model outputs match HF reference model. + + Models are run sequentially and freed between runs to support large models + that cannot fit two copies on a single GPU simultaneously. 
+ """ + input_data = self.get_test_input_data("bshd") + + # Run HF model first, then free it + model_hf = self.get_reference_model(dtype=torch.bfloat16) + model_hf.eval() + with torch.no_grad(): + hf_outputs = model_hf(**input_data) + hf_loss = hf_outputs.loss.detach().clone() + hf_logits = hf_outputs.logits.detach().clone() + del model_hf, hf_outputs + gc.collect() + torch.cuda.empty_cache() + + # Load and run TE model + model_te = self.get_converted_te_model(dtype=torch.bfloat16) + model_te.eval() + with torch.no_grad(): + te_outputs = model_te(**input_data) + del model_te + gc.collect() + torch.cuda.empty_cache() + + # Compare outputs + self.compare_outputs( + te_outputs, + type("HFOutputs", (), {"loss": hf_loss, "logits": hf_logits})(), + input_data, + compare_loss=True, + compare_logits=True, + compare_hidden_states=False, + ) + + def test_golden_values_thd(self, te_attn_backend): + """Test the model outputs the same results with THD and BSHD input formats.""" + + if te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 8: + pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.") + elif te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 12: + pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.") + + input_data_bshd = self.get_test_input_data(format="bshd") + input_data_thd = self.get_test_input_data(format="thd") + tolerances = self.get_tolerances() + + torch.testing.assert_close( + input_data_bshd["input_ids"][input_data_bshd["attention_mask"].to(bool)], + input_data_thd["input_ids"].flatten(0), + ) + + # The THD labels will have some extra -100 items due to the separator token, so we need to filter them out. 
+ labels_bshd = input_data_bshd["labels"][input_data_bshd["attention_mask"].to(bool)] + labels_thd = input_data_thd["labels"].flatten(0) + torch.testing.assert_close(labels_bshd[labels_thd != -100], labels_thd[labels_thd != -100]) + + # Run models sequentially to support large models that cannot fit two copies on GPU + model_bshd = self.get_converted_te_model(attn_input_format="bshd", dtype=torch.bfloat16) + model_bshd.eval() + with torch.inference_mode(): + outputs_bshd = model_bshd(**input_data_bshd) + bshd_loss = outputs_bshd.loss.detach().clone() + bshd_logits = outputs_bshd.logits[input_data_bshd["attention_mask"].to(bool)].detach().clone() + del model_bshd, outputs_bshd + gc.collect() + torch.cuda.empty_cache() + + model_thd = self.get_converted_te_model(attn_input_format="thd", dtype=torch.bfloat16) + model_thd.eval() + with torch.inference_mode(): + outputs_thd = model_thd(**input_data_thd) + + # Compare logits + torch.testing.assert_close( + bshd_logits, + outputs_thd.logits, + atol=tolerances.golden_value_logits_atol, + rtol=tolerances.golden_value_logits_rtol, + ) + + # Compare losses + torch.testing.assert_close( + bshd_loss, + outputs_thd.loss, + atol=tolerances.golden_value_loss_atol, + rtol=tolerances.golden_value_loss_rtol, + ) + + def test_thd_padding_input_data_equivalence(self): + """Test that the THD input data is the same before and after padding.""" + + input_data_thd = self.get_test_input_data(format="thd") + input_data_thd_padded = self.get_test_input_data(format="thd", pad_to_multiple_of=32) + + cu_seq_lens_q = input_data_thd["cu_seq_lens_q"] + cu_seq_lens_q_padded = input_data_thd_padded["cu_seq_lens_q_padded"] + cu_num_pads = cu_seq_lens_q_padded - cu_seq_lens_q + seq_lengths_real = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1] + + num_real_tokens = cu_seq_lens_q[-1] + + # How much we need to shift each sequence by. 
+        offsets = torch.repeat_interleave(cu_num_pads[:-1], seq_lengths_real, dim=0)
+
+        # The indices of the real tokens as they appear in the padded logits.
+        real_idx = torch.arange(0, num_real_tokens, device="cuda") + offsets
+
+        torch.testing.assert_close(
+            input_data_thd["input_ids"],
+            input_data_thd_padded["input_ids"].index_select(1, real_idx),
+        )
+
+        torch.testing.assert_close(
+            input_data_thd["labels"],
+            input_data_thd_padded["labels"].index_select(1, real_idx),
+        )
+        assert input_data_thd_padded["pad_between_seqs"] is True
+
+    @pytest.mark.xfail(
+        condition=not HAS_DATA_CENTER_GPU,
+        reason="Padded THD sequences are not supported on non-datacenter hardware.",
+    )
+    def test_golden_values_thd_padded(self):
+        """Test that the model outputs the same results with padded input data."""
+
+        input_data_thd = self.get_test_input_data(format="thd")
+        input_data_thd_padded = self.get_test_input_data(format="thd", pad_to_multiple_of=32)
+        tolerances = self.get_tolerances()
+
+        model_thd = self.get_converted_te_model(attn_input_format="thd", dtype=torch.bfloat16)
+        model_thd.eval()
+
+        with torch.inference_mode():
+            outputs_thd = model_thd(**input_data_thd)
+            outputs_thd_padded = model_thd(**input_data_thd_padded)
+
+        cu_seq_lens_q = input_data_thd["cu_seq_lens_q"]
+        cu_seq_lens_q_padded = input_data_thd_padded["cu_seq_lens_q_padded"]
+        cu_num_pads = cu_seq_lens_q_padded - cu_seq_lens_q
+        seq_lengths_real = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
+        num_real_tokens = cu_seq_lens_q[-1]
+        offsets = torch.repeat_interleave(cu_num_pads[:-1], seq_lengths_real, dim=0)
+
+        # The indices of the real tokens as they appear in the padded logits.
+ real_idx = torch.arange(0, num_real_tokens, device="cuda") + offsets + logits_unpadded = outputs_thd_padded.logits.index_select(0, real_idx.cuda()) + + torch.testing.assert_close( + outputs_thd.logits, + logits_unpadded, + atol=tolerances.golden_value_logits_atol, + rtol=tolerances.golden_value_logits_rtol, + ) + + torch.testing.assert_close( + outputs_thd.loss, + outputs_thd_padded.loss, + atol=tolerances.golden_value_loss_atol, + rtol=tolerances.golden_value_loss_rtol, + ) + + # ==================== FP8 Tests ==================== + def test_fp8_forward_and_backward_pass(self, fp8_recipe, input_format): + """Test that model works with FP8 autocast.""" + if input_format == "thd" and not HAS_DATA_CENTER_GPU: + pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.") + + model_class = self.get_model_class() + config = self.create_test_config( + dtype=torch.bfloat16, attn_input_format=input_format, self_attn_mask_type="padding_causal" + ) + + model = model_class(config) + model.to("cuda") + model.eval() + + # Prepare input data + input_data = self.get_test_input_data(input_format, pad_to_multiple_of=32) + + # Run without FP8 + with torch.no_grad(): + outputs = model(**input_data) + loss_bf16 = outputs.loss + + # Run with FP8 + with transformer_engine.pytorch.autocast(recipe=fp8_recipe): + outputs_fp8 = model(**input_data) + loss_fp8 = outputs_fp8.loss + + assert torch.isfinite(loss_fp8) + + # Backward pass + loss_fp8.backward() + + # Verify gradients exist + for name, param in model.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"Parameter {name} has no gradient after FP8 backward pass" + + # Compare losses (should be close but not identical due to quantization) + tolerances = self.get_tolerances() + torch.testing.assert_close( + loss_fp8, + loss_bf16, + atol=tolerances.fp8_loss_atol, + rtol=tolerances.fp8_loss_rtol, + msg=lambda x: f"FP8 loss differs too much from BF16 loss: {x}", + ) + + def 
test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format): + """Test that model initialized with FP8 works correctly.""" + if input_format == "thd" and not HAS_DATA_CENTER_GPU: + pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.") + + model_class = self.get_model_class() + config = self.create_test_config(attn_input_format=input_format, self_attn_mask_type="padding_causal") + + # Initialize with FP8 + with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe): + model = model_class(config) + + model.to("cuda") + model.eval() + + # Prepare input data + input_data = self.get_test_input_data(input_format, pad_to_multiple_of=32) + if "labels" not in input_data: + input_data["labels"] = input_data["input_ids"].clone() + + # Forward and backward pass with FP8 + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with transformer_engine.pytorch.autocast(recipe=fp8_recipe): + outputs = model(**input_data) + + loss = outputs.loss + assert torch.isfinite(loss) + + loss.backward() + + # Verify gradients exist + for name, param in model.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"Parameter {name} has no gradient after FP8 backward pass" + + # ==================== Meta Device Initialization Tests ==================== + + def test_cuda_init(self): + """Test that model can be initialized directly on CUDA device.""" + model_class = self.get_model_class() + config = self.create_test_config() + + model = model_class(config) + model.to("cuda") + + self.verify_model_parameters_initialized_correctly(model) + + def test_meta_init(self): + """Test that model can be initialized on meta device and moved to CUDA.""" + model_class = self.get_model_class() + config = self.create_test_config() + + # Initialize on meta device + with torch.device("meta"): + model = model_class(config) + + # Assert parameters are actually on the meta device + for name, parameter in 
model.named_parameters(): + assert parameter.device == torch.device("meta"), f"Parameter {name} is not on the meta device" + + # Move to CUDA (this will materialize the parameters) + model.init_empty_weights() + self.verify_model_parameters_initialized_correctly(model) + + def test_cuda_fp8_init(self, fp8_recipe): + """Test that model can be initialized on CUDA with FP8.""" + model_class = self.get_model_class() + config = self.create_test_config() + + with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe): + model = model_class(config) + + model.to("cuda") + + self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True) + + def test_meta_fp8_init(self, fp8_recipe): + """Test that model can be initialized on meta device with FP8 and moved to CUDA.""" + model_class = self.get_model_class() + config = self.create_test_config() + + # Initialize on meta device with FP8 + with torch.device("meta"): + with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe): + model = model_class(config) + + # Assert parameters are actually on the meta device + for name, parameter in model.named_parameters(): + assert parameter.device == torch.device("meta"), f"Parameter {name} is not on the meta device" + + # Move to CUDA + model.init_empty_weights() + self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True) + + # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc. diff --git a/bionemo-recipes/models/esmc/tests/conftest.py b/bionemo-recipes/models/esmc/tests/conftest.py new file mode 100644 index 0000000000..c68bc09885 --- /dev/null +++ b/bionemo-recipes/models/esmc/tests/conftest.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from pathlib import Path + + +sys.path.append(Path(__file__).parent.parent.as_posix()) +sys.path.append(Path(__file__).parent.as_posix()) + +pytest_plugins = ["tests.common.fixtures"] diff --git a/bionemo-recipes/models/esmc/tests/test_modeling_esmc_te.py b/bionemo-recipes/models/esmc/tests/test_modeling_esmc_te.py new file mode 100644 index 0000000000..457b885a7a --- /dev/null +++ b/bionemo-recipes/models/esmc/tests/test_modeling_esmc_te.py @@ -0,0 +1,348 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the ESMC TransformerEngine model. 
+ +This file provides tests extending the common BaseModelTest for ESMC, including: +- Forward/backward smoke tests +- Golden value tests against the EvolutionaryScale reference model +- FP8 tests +- Meta device initialization tests +- Conversion roundtrip tests +""" + +import gc +from pathlib import Path +from typing import Callable, Dict, List, Literal, Type + +import torch +from torch import nn +from transformers import AutoTokenizer, DataCollatorForLanguageModeling, PretrainedConfig, PreTrainedModel + +from collator import DataCollatorWithFlattening +from convert import convert_esmc_te_to_ref, convert_esmc_to_te +from modeling_esmc_te import NVEsmcConfig, NVEsmcForMaskedLM +from tests.common import BaseModelTest, TestTolerances + + +TOKENIZER_DIR = str(Path(__file__).resolve().parent.parent / "esmc_fast_tokenizer") + + +class TestEsmcModel(BaseModelTest): + """Model tester for ESMC. + + ESMC uses the EvolutionaryScale library (not standard HF), so we override + several methods to handle custom model loading and conversion. + """ + + def get_model_class(self) -> Type[PreTrainedModel]: + return NVEsmcForMaskedLM + + def get_config_class(self) -> Type[PretrainedConfig]: + return NVEsmcConfig + + def get_upstream_model_id(self) -> str: + return "EvolutionaryScale/esmc-300m-2024-12" + + def get_upstream_model_revision(self) -> str: + return "main" + + def get_upstream_model_class(self) -> Type[PreTrainedModel]: + # ESMC doesn't have a standard HF model class; we skip HF-specific tests. 
+ return PreTrainedModel # Placeholder, not used directly + + def get_tokenizer(self): + return AutoTokenizer.from_pretrained(TOKENIZER_DIR) + + def get_layer_path(self, model: PreTrainedModel) -> List[nn.Module]: + return list(model.esmc.layers) + + def create_test_config(self, **kwargs) -> PretrainedConfig: + """Create test config for ESMC - use full architecture params but limit layers for speed.""" + num_hidden_layers = kwargs.pop("num_hidden_layers", 2) + # Shim: the base test class passes `dtype=` which works on transformers v5, but the esm + # package pins transformers<4.53 where PretrainedConfig only accepts `torch_dtype=`. + # This can be dropped when esm updates its transformers version constraint. + if "dtype" in kwargs: + kwargs["torch_dtype"] = kwargs.pop("dtype") + return NVEsmcConfig( + vocab_size=64, + hidden_size=960, + num_hidden_layers=num_hidden_layers, + num_attention_heads=15, + intermediate_size=2560, + **kwargs, + ) + + def get_test_input_data( + self, + format: Literal["bshd", "thd"] = "bshd", + pad_to_multiple_of: int | None = None, + ) -> Dict[str, torch.Tensor]: + """Prepare test input data with protein sequences.""" + tokenizer = self.get_tokenizer() + + # Short protein sequences for testing + sequences = [ + "MKTVRQERLKSIVRILERSKEPV", + "KALTARQQEVFDLIRDHISQTGMPPTRA", + "MFKVYGYDSNIHKCV", + ] + + # Tokenize + tokenized = [tokenizer(seq) for seq in sequences] + + # Use data collator for MLM + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + pad_to_multiple_of=pad_to_multiple_of if format == "bshd" else None, + ) + + if format == "thd": + data_collator = DataCollatorWithFlattening( + collator=data_collator, + pad_sequences_to_be_divisible_by=pad_to_multiple_of, + ) + + batch = data_collator(tokenized) + + # Move to device + return {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + + def get_hf_to_te_converter(self) -> Callable: + """Return the ESMC ref -> TE 
conversion function. + + Wraps convert_esmc_to_te to accept a model and return a model. + """ + + def _converter(model_ref, **kwargs): + """Convert a reference ESMC model to TE format.""" + ref_state_dict = model_ref.state_dict() + config = NVEsmcConfig( + vocab_size=64, + hidden_size=model_ref.embed.weight.shape[1], + num_hidden_layers=len(model_ref.transformer.blocks), + num_attention_heads=model_ref.transformer.blocks[0].attn.n_heads, + intermediate_size=model_ref.transformer.blocks[0].ffn[1].weight.shape[0] // 2, + **kwargs, + ) + return convert_esmc_to_te(ref_state_dict, config) + + return _converter + + def get_te_to_hf_converter(self) -> Callable: + """Return the TE -> ESMC ref conversion function.""" + return convert_esmc_te_to_ref + + def get_tolerances(self) -> TestTolerances: + """Return ESMC-specific tolerances. + + With full d_model QK LayerNorm (matching the reference model exactly), the TE model + closely reproduces reference outputs. These tolerances are comparable to LLaMA3. 
+ """ + return TestTolerances( + golden_value_loss_atol=5e-3, + golden_value_loss_rtol=0.01, + golden_value_logits_atol=1.5, + golden_value_logits_rtol=0.01, + ) + + # ==================== Override methods for non-HF reference model ==================== + + def get_reference_model(self, dtype=torch.bfloat16, attn_implementation="flash_attention_2"): + """Load the EvolutionaryScale ESMC reference model.""" + from esm.models.esmc import ESMC + from esm.utils.constants.models import ESMC_300M + + model = ESMC.from_pretrained(ESMC_300M, device=torch.device("cuda")) + model.to(dtype) + model.eval() + return model + + def get_reference_model_no_weights(self): + """Create a reference ESMC model with random weights for conversion tests.""" + from esm.models.esmc import ESMC + from esm.tokenization import EsmSequenceTokenizer + + model = ESMC( + d_model=960, + n_heads=15, + n_layers=30, + tokenizer=EsmSequenceTokenizer(), + use_flash_attn=False, + ) + return model + + def get_converted_te_model_checkpoint(self) -> Path: + """Load ESMC, convert to TE, and save checkpoint. + + We override this to handle the non-HF model loading and to work on CPU + for memory efficiency. + """ + ref_model = self.get_reference_model(dtype=torch.bfloat16) + ref_state_dict = {k: v.cpu() for k, v in ref_model.state_dict().items()} + + del ref_model + gc.collect() + torch.cuda.empty_cache() + + config = NVEsmcConfig( + vocab_size=64, + hidden_size=960, + num_hidden_layers=30, + num_attention_heads=15, + intermediate_size=2560, + torch_dtype="bfloat16", + ) + + model_te = convert_esmc_to_te(ref_state_dict, config) + model_te.to("cpu") + + checkpoint_path: Path = self._tmp_dir / "converted_te_model" + model_te.save_pretrained(checkpoint_path) + + del model_te + gc.collect() + + return checkpoint_path + + def get_converted_te_model(self, **kwargs) -> PreTrainedModel: + """Get the converted TE model. 
+ + Shim: the base class passes `dtype=` which works on transformers v5, but the esm + package pins transformers<4.53 where `from_pretrained` only accepts `torch_dtype=`. + This can be dropped when esm updates its transformers version constraint. + """ + if "dtype" in kwargs: + kwargs["torch_dtype"] = kwargs.pop("dtype") + return super().get_converted_te_model(**kwargs) + + # ==================== Override tests that don't apply to ESMC ==================== + + def test_convert_hf_to_te(self): + """Test conversion from ESMC ref to TE format.""" + model_ref = self.get_reference_model_no_weights() + converter = self.get_hf_to_te_converter() + model_te = converter(model_ref) + + assert model_te is not None + assert isinstance(model_te, NVEsmcForMaskedLM) + + def test_convert_te_to_hf(self): + """Test conversion from TE to ESMC ref format.""" + model_ref = self.get_reference_model_no_weights() + converter = self.get_hf_to_te_converter() + model_te = converter(model_ref) + + ref_state_dict = convert_esmc_te_to_ref(model_te) + assert ref_state_dict is not None + assert "embed.weight" in ref_state_dict + + def test_convert_te_to_hf_roundtrip(self): + """Test roundtrip conversion ESMC ref -> TE -> ESMC ref. + + With full d_model QK LayerNorm, all weights should roundtrip exactly. + The only non-exact weights are output projection and fc2 (due to residue + scaling absorption/removal via float division/multiplication). 
+ """ + model_ref = self.get_reference_model_no_weights() + original_state_dict = {k: v.clone() for k, v in model_ref.state_dict().items()} + + # Forward: ref -> TE + converter = self.get_hf_to_te_converter() + model_te = converter(model_ref) + + # Reverse: TE -> ref + converted_state_dict = convert_esmc_te_to_ref(model_te) + + # Compare - all weights should roundtrip with high precision + for key in original_state_dict: + if key not in converted_state_dict: + continue + original = original_state_dict[key] + converted = converted_state_dict[key] + + torch.testing.assert_close( + original.float(), + converted.float(), + atol=1e-5, + rtol=1e-5, + msg=f"Roundtrip mismatch for {key}", + ) + + def test_convert_config(self): + """Test that ESMC config can be created properly.""" + config = NVEsmcConfig( + vocab_size=64, + hidden_size=960, + num_hidden_layers=30, + num_attention_heads=15, + intermediate_size=2560, + ) + assert config.hidden_size == 960 + assert config.num_hidden_layers == 30 + assert config.num_attention_heads == 15 + + def test_golden_values(self): + """Test that TE model produces matching outputs compared to ESMC reference model. + + Overrides the base class because the ESMC ref model has a non-HF API: + it takes (sequence_tokens, sequence_id) and returns .sequence_logits. 
+ """ + tokenizer = self.get_tokenizer() + sequences = ["MKTVRQERLKSIVRILERSKEPV", "KALTARQQEVFDLIRDHISQTGMPPTRA"] + encodings = tokenizer(sequences, return_tensors="pt", padding=True) + input_ids = encodings["input_ids"].to("cuda") + attention_mask = encodings["attention_mask"].to("cuda") + + # Run reference model + ref_model = self.get_reference_model(dtype=torch.bfloat16) + ref_model.eval() + with torch.no_grad(): + sequence_id = input_ids != tokenizer.pad_token_id + ref_output = ref_model(sequence_tokens=input_ids, sequence_id=sequence_id) + ref_logits = ref_output.sequence_logits.detach().clone() + + del ref_model, ref_output + gc.collect() + torch.cuda.empty_cache() + + # Run TE model + model_te = self.get_converted_te_model(dtype=torch.bfloat16) + model_te.eval() + labels = input_ids.clone() + labels[attention_mask == 0] = -100 + with torch.no_grad(): + te_output = model_te(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + + del model_te + gc.collect() + torch.cuda.empty_cache() + + # Verify outputs are finite + mask = attention_mask.bool() + assert torch.isfinite(te_output.logits[mask]).all(), "TE model produced non-finite logits" + assert torch.isfinite(te_output.loss), "TE model produced non-finite loss" + assert torch.isfinite(ref_logits[mask]).all(), "Reference model produced non-finite logits" + + # Compare logits + tolerances = self.get_tolerances() + torch.testing.assert_close( + te_output.logits[mask], + ref_logits[mask], + atol=tolerances.golden_value_logits_atol, + rtol=tolerances.golden_value_logits_rtol, + ) diff --git a/bionemo-recipes/models/esmc/tests/test_tokenizer.py b/bionemo-recipes/models/esmc/tests/test_tokenizer.py new file mode 100644 index 0000000000..44b7434961 --- /dev/null +++ b/bionemo-recipes/models/esmc/tests/test_tokenizer.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the ESMC HuggingFace-compatible tokenizer. + +Verifies that the custom PreTrainedTokenizerFast produces identical token IDs +to the EvolutionaryScale EsmSequenceTokenizer, while using standard HF +model_input_names (input_ids, attention_mask) for compatibility with +DataCollatorForLanguageModeling. +""" + +from pathlib import Path + +import pytest +from transformers import AutoTokenizer, DataCollatorForLanguageModeling, PreTrainedTokenizerFast + + +TOKENIZER_DIR = str(Path(__file__).resolve().parent.parent / "esmc_fast_tokenizer") + + +@pytest.fixture() +def hf_tokenizer(): + return AutoTokenizer.from_pretrained(TOKENIZER_DIR) + + +@pytest.fixture() +def ref_tokenizer(): + from esm.tokenization import EsmSequenceTokenizer + + return EsmSequenceTokenizer() + + +class TestEsmcTokenizer: + """Tests comparing the HF tokenizer against the ESM reference tokenizer.""" + + def test_loads_as_pretrained_tokenizer_fast(self, hf_tokenizer): + assert isinstance(hf_tokenizer, PreTrainedTokenizerFast) + + def test_model_input_names(self, hf_tokenizer): + assert hf_tokenizer.model_input_names == ["input_ids", "attention_mask"] + + def test_vocab_size(self, hf_tokenizer, ref_tokenizer): + assert hf_tokenizer.vocab_size == len(ref_tokenizer.vocab) + + def test_special_token_ids_match(self, hf_tokenizer, ref_tokenizer): + assert hf_tokenizer.pad_token_id == 
ref_tokenizer.pad_token_id + assert hf_tokenizer.cls_token_id == ref_tokenizer.cls_token_id + assert hf_tokenizer.eos_token_id == ref_tokenizer.eos_token_id + assert hf_tokenizer.unk_token_id == ref_tokenizer.unk_token_id + assert hf_tokenizer.mask_token_id == ref_tokenizer.mask_token_id + + def test_chain_break_token(self, hf_tokenizer, ref_tokenizer): + """Verify the chain break token | maps to the same ID.""" + hf_id = hf_tokenizer.convert_tokens_to_ids("|") + ref_id = ref_tokenizer.convert_tokens_to_ids("|") + assert hf_id == ref_id == 31 + + def test_single_sequence_tokenization(self, hf_tokenizer, ref_tokenizer): + seq = "MKTVRQERLKSIVRILERSKEPV" + hf_out = hf_tokenizer(seq) + ref_out = ref_tokenizer(seq) + assert hf_out["input_ids"] == ref_out["input_ids"] + + def test_batch_tokenization_with_padding(self, hf_tokenizer, ref_tokenizer): + sequences = [ + "MKTVRQERLKSIVRILERSKEPV", + "KALTARQQEVFDLIRDHISQTGMPPTRA", + "MFKVYGYDSNIHKCV", + ] + hf_out = hf_tokenizer(sequences, padding=True) + ref_out = ref_tokenizer(sequences, padding=True) + assert hf_out["input_ids"] == ref_out["input_ids"] + assert hf_out["attention_mask"] == ref_out["attention_mask"] + + def test_all_amino_acids(self, hf_tokenizer, ref_tokenizer): + """Verify all standard amino acid characters produce matching IDs.""" + amino_acids = "LAGVSERTIDPKQNFYMHWCXBUZO" # pragma: allowlist secret + for aa in amino_acids: + hf_id = hf_tokenizer.convert_tokens_to_ids(aa) + ref_id = ref_tokenizer.convert_tokens_to_ids(aa) + assert hf_id == ref_id, f"Mismatch for amino acid {aa}: HF={hf_id}, ref={ref_id}" + + def test_data_collator_compatibility(self, hf_tokenizer): + """Verify DataCollatorForLanguageModeling works without errors.""" + sequences = ["MKTVRQERLK", "KALTARQQEV", "MFKVYGYD"] + tokenized = [hf_tokenizer(seq) for seq in sequences] + + collator = DataCollatorForLanguageModeling(tokenizer=hf_tokenizer, mlm=False) + batch = collator(tokenized) + + assert "input_ids" in batch + assert "labels" in 
batch + assert "attention_mask" in batch + assert batch["input_ids"].shape[0] == 3 diff --git a/ci/scripts/check_copied_files.py b/ci/scripts/check_copied_files.py index a4281dc3eb..3b5302a908 100755 --- a/ci/scripts/check_copied_files.py +++ b/ci/scripts/check_copied_files.py @@ -37,12 +37,14 @@ ], "bionemo-recipes/models/esm2/src/esm/collator.py": [ "bionemo-recipes/models/llama3/collator.py", + "bionemo-recipes/models/esmc/collator.py", "bionemo-recipes/recipes/esm2_native_te/collator.py", "bionemo-recipes/recipes/llama3_native_te/collator.py", ], "bionemo-recipes/models/esm2/src/esm/state.py": [ "bionemo-recipes/models/amplify/src/amplify/state.py", "bionemo-recipes/models/llama3/state.py", + "bionemo-recipes/models/esmc/state.py", ], "bionemo-recipes/models/llama3/modeling_llama_te.py": [ "bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py", @@ -56,6 +58,7 @@ # Common test library - synced between models "bionemo-recipes/models/esm2/tests/common": [ "bionemo-recipes/models/llama3/tests/common", + "bionemo-recipes/models/esmc/tests/common", ], }