From ef39689d3491e31341ef96ca0311a5dc3283f377 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Sat, 28 Feb 2026 07:39:27 -0800 Subject: [PATCH 1/3] refactor autoregressive model tests Signed-off-by: Peter St. John --- .../esm2/tests/common/test_modeling_common.py | 136 ++++++++- .../tests/common/test_modeling_common.py | 136 ++++++++- .../llama3/tests/test_modeling_llama_te.py | 271 ++---------------- .../models/mixtral/modeling_mixtral_te.py | 6 +- .../tests/common/test_modeling_common.py | 136 ++++++++- .../mixtral/tests/test_modeling_mixtral.py | 122 +------- 6 files changed, 418 insertions(+), 389 deletions(-) diff --git a/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py b/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py index aca85b785..d04c9e5d2 100644 --- a/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py +++ b/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py @@ -20,7 +20,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Callable, Dict, List, Literal, Type +from typing import Any, Callable, Dict, List, Literal, Type import pytest import torch @@ -82,6 +82,10 @@ class BaseModelTest(ABC): Subclasses must implement all abstract methods to provide model-specific configuration, data preparation, and conversion functions. + Set ``is_autoregressive = True`` in subclasses for causal LM models to + enable generation / KV-cache smoke tests. Non-autoregressive models + (e.g. ESM2) leave the default ``False`` and those tests are skipped. + Example: ```python class ESM2ModelTester(BioNeMoModelTester): @@ -98,6 +102,8 @@ def get_upstream_model_id(self): ``` """ + is_autoregressive: bool = False + @abstractmethod def get_model_class(self) -> Type[PreTrainedModel]: """Return the TransformerEngine model class to test. 
@@ -885,13 +891,15 @@ def test_fp8_forward_and_backward_pass(self, fp8_recipe, input_format): msg=lambda x: f"FP8 loss differs too much from BF16 loss: {x}", ) - def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format): + def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format, **config_kwargs): """Test that model initialized with FP8 works correctly.""" if input_format == "thd" and not HAS_DATA_CENTER_GPU: pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.") model_class = self.get_model_class() - config = self.create_test_config(attn_input_format=input_format, self_attn_mask_type="padding_causal") + config = self.create_test_config( + attn_input_format=input_format, self_attn_mask_type="padding_causal", **config_kwargs + ) # Initialize with FP8 with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe): @@ -906,9 +914,8 @@ def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_forma input_data["labels"] = input_data["input_ids"].clone() # Forward and backward pass with FP8 - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - with transformer_engine.pytorch.autocast(recipe=fp8_recipe): - outputs = model(**input_data) + with transformer_engine.pytorch.autocast(recipe=fp8_recipe): + outputs = model(**input_data) loss = outputs.loss assert torch.isfinite(loss) @@ -979,4 +986,121 @@ def test_meta_fp8_init(self, fp8_recipe): model.init_empty_weights() self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True) + # ==================== Generation Tests (Autoregressive Models Only) ==================== + @abstractmethod + def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1) -> Any: + """Create inference params for KV-cache generation tests. + + Autoregressive model tests must override this method to provide + model-specific ``HFInferenceParams`` with allocated KV-cache memory. 
+ + Args: + config: Model configuration. + batch_size: Batch size. + max_seq_len: Maximum sequence length. + num_beams: Number of beams for beam search. + + Returns: + HFInferenceParams instance with allocated memory. + """ + pass + + def test_generate_without_cache(self): + """Test basic generation without KV-cache (BSHD, use_cache=False).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompt = "The quick brown fox jumps over" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=False) + + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache(self): + """Test single-prompt generation with KV-cache (THD format).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompt = "The quick brown fox jumps over" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + past_key_values = self.create_inference_params(config, batch_size=1) + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) + + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache_batched(self): + """Test batched generation with KV-cache (left-padded BSHD converted to THD).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config 
= self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompts = ( + "The quick brown fox jumps over the lazy dog.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + past_key_values = self.create_inference_params(config, batch_size=2) + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) + + assert output_ids.shape[0] == 2 + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache_beam_search(self): + """Test batched generation with KV-cache and beam search.""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompts = ( + "The quick brown fox jumps over the lazy dog.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + num_beams = 2 + past_key_values = self.create_inference_params(config, batch_size=2, num_beams=num_beams) + + with torch.no_grad(): + output_ids = model.generate( + **inputs, + max_new_tokens=16, + use_cache=True, + past_key_values=past_key_values, + num_beams=num_beams, + do_sample=True, + ) + + assert output_ids.shape[0] == 2 + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc. 
diff --git a/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py b/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py index aca85b785..d04c9e5d2 100644 --- a/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py +++ b/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py @@ -20,7 +20,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Callable, Dict, List, Literal, Type +from typing import Any, Callable, Dict, List, Literal, Type import pytest import torch @@ -82,6 +82,10 @@ class BaseModelTest(ABC): Subclasses must implement all abstract methods to provide model-specific configuration, data preparation, and conversion functions. + Set ``is_autoregressive = True`` in subclasses for causal LM models to + enable generation / KV-cache smoke tests. Non-autoregressive models + (e.g. ESM2) leave the default ``False`` and those tests are skipped. + Example: ```python class ESM2ModelTester(BioNeMoModelTester): @@ -98,6 +102,8 @@ def get_upstream_model_id(self): ``` """ + is_autoregressive: bool = False + @abstractmethod def get_model_class(self) -> Type[PreTrainedModel]: """Return the TransformerEngine model class to test. 
@@ -885,13 +891,15 @@ def test_fp8_forward_and_backward_pass(self, fp8_recipe, input_format): msg=lambda x: f"FP8 loss differs too much from BF16 loss: {x}", ) - def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format): + def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format, **config_kwargs): """Test that model initialized with FP8 works correctly.""" if input_format == "thd" and not HAS_DATA_CENTER_GPU: pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.") model_class = self.get_model_class() - config = self.create_test_config(attn_input_format=input_format, self_attn_mask_type="padding_causal") + config = self.create_test_config( + attn_input_format=input_format, self_attn_mask_type="padding_causal", **config_kwargs + ) # Initialize with FP8 with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe): @@ -906,9 +914,8 @@ def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_forma input_data["labels"] = input_data["input_ids"].clone() # Forward and backward pass with FP8 - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - with transformer_engine.pytorch.autocast(recipe=fp8_recipe): - outputs = model(**input_data) + with transformer_engine.pytorch.autocast(recipe=fp8_recipe): + outputs = model(**input_data) loss = outputs.loss assert torch.isfinite(loss) @@ -979,4 +986,121 @@ def test_meta_fp8_init(self, fp8_recipe): model.init_empty_weights() self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True) + # ==================== Generation Tests (Autoregressive Models Only) ==================== + @abstractmethod + def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1) -> Any: + """Create inference params for KV-cache generation tests. + + Autoregressive model tests must override this method to provide + model-specific ``HFInferenceParams`` with allocated KV-cache memory. 
+ + Args: + config: Model configuration. + batch_size: Batch size. + max_seq_len: Maximum sequence length. + num_beams: Number of beams for beam search. + + Returns: + HFInferenceParams instance with allocated memory. + """ + pass + + def test_generate_without_cache(self): + """Test basic generation without KV-cache (BSHD, use_cache=False).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompt = "The quick brown fox jumps over" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=False) + + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache(self): + """Test single-prompt generation with KV-cache (THD format).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompt = "The quick brown fox jumps over" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + past_key_values = self.create_inference_params(config, batch_size=1) + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) + + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache_batched(self): + """Test batched generation with KV-cache (left-padded BSHD converted to THD).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config 
= self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompts = ( + "The quick brown fox jumps over the lazy dog.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + past_key_values = self.create_inference_params(config, batch_size=2) + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) + + assert output_ids.shape[0] == 2 + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache_beam_search(self): + """Test batched generation with KV-cache and beam search.""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompts = ( + "The quick brown fox jumps over the lazy dog.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + num_beams = 2 + past_key_values = self.create_inference_params(config, batch_size=2, num_beams=num_beams) + + with torch.no_grad(): + output_ids = model.generate( + **inputs, + max_new_tokens=16, + use_cache=True, + past_key_values=past_key_values, + num_beams=num_beams, + do_sample=True, + ) + + assert output_ids.shape[0] == 2 + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc. 
diff --git a/bionemo-recipes/models/llama3/tests/test_modeling_llama_te.py b/bionemo-recipes/models/llama3/tests/test_modeling_llama_te.py index cd4998c17..caf0ddf24 100644 --- a/bionemo-recipes/models/llama3/tests/test_modeling_llama_te.py +++ b/bionemo-recipes/models/llama3/tests/test_modeling_llama_te.py @@ -20,17 +20,13 @@ - LLaMA-specific tests (inference, generation, THD inputs, etc.) """ -import gc -import os from typing import Callable, Dict, List, Literal, Type import pytest import torch -import transformer_engine.pytorch import transformers from torch import nn from transformers import ( - AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, PretrainedConfig, @@ -41,7 +37,7 @@ from collator import DataCollatorWithFlattening from convert import convert_llama_hf_to_te, convert_llama_te_to_hf from modeling_llama_te import HFInferenceParams, NVLlamaConfig, NVLlamaForCausalLM -from tests.common import HAS_DATA_CENTER_GPU, BaseModelTest, TestTolerances +from tests.common import BaseModelTest, TestTolerances class TestLlama3Model(BaseModelTest): @@ -50,6 +46,8 @@ class TestLlama3Model(BaseModelTest): This class provides LLaMA3-specific configuration for the common test suite. 
""" + is_autoregressive = True + def get_model_class(self) -> Type[PreTrainedModel]: """Return the LLaMA3 TE model class.""" return NVLlamaForCausalLM @@ -139,7 +137,23 @@ def get_tolerances(self) -> TestTolerances: cp_loss_rtol=0.25, ) - # ==================== LLaMA3-Specific Tests ==================== + # ==================== LLaMA3-Specific Overrides ==================== + + def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1): + """Create HFInferenceParams for the given config.""" + past_key_values = HFInferenceParams( + max_batch_size=batch_size * num_beams, + max_sequence_length=max_seq_len, + num_heads_kv=config.num_key_value_heads, + head_dim_k=config.hidden_size // config.num_attention_heads, + dtype=torch.bfloat16, + qkv_format="thd", + max_ctx_len=max_seq_len, + ) + for layer_number in range(1, config.num_hidden_layers + 1): + past_key_values.allocate_memory(layer_number) + return past_key_values + def test_golden_values(self, input_format): # pyright: ignore[reportIncompatibleMethodOverride] """For llama3, we can test both the dynamic sequence packing and native bshd attention formats.""" model_hf = self.get_reference_model(dtype=torch.bfloat16) @@ -166,247 +180,6 @@ def test_golden_values(self, input_format): # pyright: ignore[reportIncompatibl @pytest.mark.parametrize("tie_word_embeddings", [True, False]) def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format, tie_word_embeddings): # pyright: ignore[reportIncompatibleMethodOverride] """There was a weird bug in BIO-217 on tied weights with quantized model init, so we test both cases.""" - if input_format == "thd" and not HAS_DATA_CENTER_GPU: - pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.") - - model_class = self.get_model_class() - config = self.create_test_config( - attn_input_format=input_format, - self_attn_mask_type="padding_causal", - tie_word_embeddings=tie_word_embeddings, + 
super().test_quantized_model_init_forward_and_backward( + fp8_recipe, input_format, tie_word_embeddings=tie_word_embeddings ) - - # Initialize with FP8 - with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe): - model = model_class(config) - - model.to("cuda") - model.eval() - - # Prepare input data - input_data = self.get_test_input_data(input_format, pad_to_multiple_of=32) - if "labels" not in input_data: - input_data["labels"] = input_data["input_ids"].clone() - - # Forward and backward pass with FP8 - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - with transformer_engine.pytorch.autocast(recipe=fp8_recipe): - outputs = model(**input_data) - - loss = outputs.loss - assert torch.isfinite(loss) - - loss.backward() - - # Verify gradients exist - for name, param in model.named_parameters(): - if param.requires_grad: - assert param.grad is not None, f"Parameter {name} has no gradient after FP8 backward pass" - - -# NOTE: Keeping remaining LLaMA-specific tests (inference, generation, etc.) unchanged -# These tests are specific to causal LM functionality and don't have ESM2 equivalents - - -@pytest.fixture -def input_text(): - return ( - """Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License.""", - "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore " - "et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip " - "ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu " - "fugiat nulla pariatur. 
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt " - "mollit anim id est laborum.", - ) - - -@pytest.mark.skipif(os.getenv("CI", "false") == "true", reason="Skipping test in CI not download llama3 model.") -def test_llama_model_golden_values_padding_left(input_text): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") - model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", dtype=torch.bfloat16) - - model_te = convert_llama_hf_to_te(model_hf, attn_input_format="thd") - - tokenizer.pad_token = tokenizer.eos_token - inputs = tokenizer(input_text, return_tensors="pt", padding=True, padding_side="left") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - labels = inputs["input_ids"].clone() - labels[labels == tokenizer.pad_token_id] = -100 - model_hf.to("cuda") - with torch.no_grad(): - outputs_hf = model_hf(**inputs, labels=labels, output_hidden_states=True) - - del model_hf - gc.collect() - torch.cuda.empty_cache() - - model_te.to("cuda") - with torch.no_grad(): - outputs_te = model_te(**inputs, labels=labels, output_hidden_states=True) - - torch.testing.assert_close(outputs_te.loss, outputs_hf.loss, atol=0.02, rtol=0.03) # Higher than I'd like. - torch.testing.assert_close( - outputs_te.logits[inputs["attention_mask"].to(bool)], - outputs_hf.logits[inputs["attention_mask"].to(bool)], - atol=1.5, - rtol=0.01, - ) - - -@pytest.mark.skipif(os.getenv("CI", "false") == "true", reason="Skipping test in CI not download llama3 model.") -def test_hf_llama_model_generate_bshd(): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") - model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", dtype=torch.bfloat16) - - prompt = ( - """ - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at""", - "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore", - ) - - tokenizer.pad_token = tokenizer.eos_token - inputs = tokenizer(prompt, return_tensors="pt", padding=True, padding_side="left") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - model_hf.to("cuda") - - with torch.no_grad(): - output_ids = model_hf.generate(**inputs, max_new_tokens=16, use_cache=False) - - generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - assert "http://www.apache.org/licenses/LICENSE-2.0" in generated_text[0] - assert "et dolore magna aliqua. Ut enim ad minim " in generated_text[1] - - -@pytest.mark.skipif(os.getenv("CI", "false") == "true", reason="Skipping test in CI not download llama3 model.") -def test_te_llama_model_generate_with_cache(): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") - model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", dtype=torch.bfloat16) - model_te = convert_llama_hf_to_te(model_hf, attn_input_format="thd", self_attn_mask_type="padding_causal") - - prompt = """ - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at""" - - inputs = tokenizer(prompt, return_tensors="pt") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - model_te.to("cuda") - - past_key_values = HFInferenceParams( - max_batch_size=1, - max_sequence_length=256, - num_heads_kv=model_te.config.num_key_value_heads, - head_dim_k=model_te.config.hidden_size // model_te.config.num_attention_heads, - dtype=torch.bfloat16, - qkv_format="thd", - max_ctx_len=256, - ) - - for layer_number in range(1, model_te.config.num_hidden_layers + 1): - past_key_values.allocate_memory(layer_number) - - with torch.no_grad(): - output_ids = model_te.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) - - generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - assert "http://www.apache.org/licenses/LICENSE-2.0" in generated_text[0] - - -@pytest.mark.skipif(os.getenv("CI", "false") == "true", reason="Skipping test in CI not download llama3 model.") -def test_te_llama_model_generate_with_cache_bshd(): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") - model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", dtype=torch.bfloat16) - model_te = convert_llama_hf_to_te(model_hf, attn_input_format="thd") - - prompt = ( - """ - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at""", - "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore", - ) - - tokenizer.pad_token = tokenizer.eos_token - inputs = tokenizer(prompt, return_tensors="pt", padding=True, padding_side="left") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - model_te.to("cuda") - - past_key_values = HFInferenceParams( - max_batch_size=2, - max_sequence_length=256, - num_heads_kv=model_te.config.num_key_value_heads, - head_dim_k=model_te.config.hidden_size // model_te.config.num_attention_heads, - dtype=torch.bfloat16, - qkv_format="thd", - max_ctx_len=256, - ) - - for layer_number in range(1, model_te.config.num_hidden_layers + 1): - past_key_values.allocate_memory(layer_number) - - with torch.no_grad(): - output_ids = model_te.generate( - **inputs, - max_new_tokens=16, - use_cache=True, - past_key_values=past_key_values, - ) - - generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - assert "http://www.apache.org/licenses/LICENSE-2.0" in generated_text[0] - assert "et dolore magna aliqua. Ut enim ad minim " in generated_text[1] - - -@pytest.mark.skipif(os.getenv("CI", "false") == "true", reason="Skipping test in CI not download llama3 model.") -def test_te_llama_model_generate_with_cache_bshd_beam_search(): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") - model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", dtype=torch.bfloat16) - model_te = convert_llama_hf_to_te(model_hf, attn_input_format="thd") - - prompt = ( - """ - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at""", - "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore", - ) - - tokenizer.pad_token = tokenizer.eos_token - inputs = tokenizer(prompt, return_tensors="pt", padding=True, padding_side="left") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - model_te.to("cuda") - - num_beams = 2 - - past_key_values = HFInferenceParams( - max_batch_size=2 * num_beams, - max_sequence_length=256, - num_heads_kv=model_te.config.num_key_value_heads, - head_dim_k=model_te.config.hidden_size // model_te.config.num_attention_heads, - dtype=torch.bfloat16, - qkv_format="thd", - max_ctx_len=256, - ) - - for layer_number in range(1, model_te.config.num_hidden_layers + 1): - past_key_values.allocate_memory(layer_number) - - with torch.no_grad(): - output_ids = model_te.generate( - **inputs, - max_new_tokens=16, - use_cache=True, - past_key_values=past_key_values, - num_beams=num_beams, - do_sample=True, - ) - - generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - assert "http://www.apache.org/licenses/LICENSE-2.0" in generated_text[0] - assert "et dolore magna aliqua. 
Ut enim ad minim " in generated_text[1] diff --git a/bionemo-recipes/models/mixtral/modeling_mixtral_te.py b/bionemo-recipes/models/mixtral/modeling_mixtral_te.py index 4b6683813..6dd5a0a66 100644 --- a/bionemo-recipes/models/mixtral/modeling_mixtral_te.py +++ b/bionemo-recipes/models/mixtral/modeling_mixtral_te.py @@ -166,7 +166,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Permute tokens by expert using TE moe_permute permuted_hidden, row_id_map = transformer_engine.pytorch.moe_permute( - hidden_states, selected_experts, map_type="index" + hidden_states, selected_experts.to(torch.int32), map_type="index" ) # Compute m_splits: number of tokens per expert @@ -185,11 +185,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Down projection expert_output = self.experts_down(intermediate, m_splits=m_splits) # [total_tokens, H] - # Unpermute and combine with routing weights + # Unpermute and combine with routing weights (keep probs in float32 for numerical stability) output = transformer_engine.pytorch.moe_unpermute( expert_output, row_id_map, - merging_probs=routing_weights.to(expert_output.dtype), + merging_probs=routing_weights, map_type="index", ) diff --git a/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py b/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py index aca85b785..d04c9e5d2 100644 --- a/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py +++ b/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py @@ -20,7 +20,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Callable, Dict, List, Literal, Type +from typing import Any, Callable, Dict, List, Literal, Type import pytest import torch @@ -82,6 +82,10 @@ class BaseModelTest(ABC): Subclasses must implement all abstract methods to provide model-specific configuration, data preparation, and conversion functions. 
+ Set ``is_autoregressive = True`` in subclasses for causal LM models to + enable generation / KV-cache smoke tests. Non-autoregressive models + (e.g. ESM2) leave the default ``False`` and those tests are skipped. + Example: ```python class ESM2ModelTester(BioNeMoModelTester): @@ -98,6 +102,8 @@ def get_upstream_model_id(self): ``` """ + is_autoregressive: bool = False + @abstractmethod def get_model_class(self) -> Type[PreTrainedModel]: """Return the TransformerEngine model class to test. @@ -885,13 +891,15 @@ def test_fp8_forward_and_backward_pass(self, fp8_recipe, input_format): msg=lambda x: f"FP8 loss differs too much from BF16 loss: {x}", ) - def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format): + def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format, **config_kwargs): """Test that model initialized with FP8 works correctly.""" if input_format == "thd" and not HAS_DATA_CENTER_GPU: pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.") model_class = self.get_model_class() - config = self.create_test_config(attn_input_format=input_format, self_attn_mask_type="padding_causal") + config = self.create_test_config( + attn_input_format=input_format, self_attn_mask_type="padding_causal", **config_kwargs + ) # Initialize with FP8 with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe): @@ -906,9 +914,8 @@ def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_forma input_data["labels"] = input_data["input_ids"].clone() # Forward and backward pass with FP8 - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - with transformer_engine.pytorch.autocast(recipe=fp8_recipe): - outputs = model(**input_data) + with transformer_engine.pytorch.autocast(recipe=fp8_recipe): + outputs = model(**input_data) loss = outputs.loss assert torch.isfinite(loss) @@ -979,4 +986,121 @@ def test_meta_fp8_init(self, fp8_recipe): model.init_empty_weights() 
self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True) + # ==================== Generation Tests (Autoregressive Models Only) ==================== + @abstractmethod + def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1) -> Any: + """Create inference params for KV-cache generation tests. + + Autoregressive model tests must override this method to provide + model-specific ``HFInferenceParams`` with allocated KV-cache memory. + + Args: + config: Model configuration. + batch_size: Batch size. + max_seq_len: Maximum sequence length. + num_beams: Number of beams for beam search. + + Returns: + HFInferenceParams instance with allocated memory. + """ + pass + + def test_generate_without_cache(self): + """Test basic generation without KV-cache (BSHD, use_cache=False).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompt = "The quick brown fox jumps over" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=False) + + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache(self): + """Test single-prompt generation with KV-cache (THD format).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompt = "The quick brown fox jumps over" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + 
past_key_values = self.create_inference_params(config, batch_size=1) + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) + + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache_batched(self): + """Test batched generation with KV-cache (left-padded BSHD converted to THD).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompts = ( + "The quick brown fox jumps over the lazy dog.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + past_key_values = self.create_inference_params(config, batch_size=2) + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) + + assert output_ids.shape[0] == 2 + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache_beam_search(self): + """Test batched generation with KV-cache and beam search.""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompts = ( + "The quick brown fox jumps over the lazy dog.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + num_beams = 2 + past_key_values = 
self.create_inference_params(config, batch_size=2, num_beams=num_beams) + + with torch.no_grad(): + output_ids = model.generate( + **inputs, + max_new_tokens=16, + use_cache=True, + past_key_values=past_key_values, + num_beams=num_beams, + do_sample=True, + ) + + assert output_ids.shape[0] == 2 + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc. diff --git a/bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py b/bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py index 08e9f3091..e15645a3c 100644 --- a/bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py +++ b/bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py @@ -47,6 +47,8 @@ class TestMixtralModel(BaseModelTest): This class provides Mixtral-specific configuration for the common test suite. """ + is_autoregressive = True + def get_model_class(self) -> Type[PreTrainedModel]: """Return the Mixtral TE model class.""" return NVMixtralForCausalLM @@ -146,9 +148,7 @@ def get_tolerances(self) -> TestTolerances: cp_loss_rtol=0.25, ) - # ==================== Mixtral-Specific KV-Cache Tests ==================== - - def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1): + def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1): """Create HFInferenceParams for the given config.""" past_key_values = HFInferenceParams( max_batch_size=batch_size * num_beams, @@ -162,119 +162,3 @@ def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_be for layer_number in range(1, config.num_hidden_layers + 1): past_key_values.allocate_memory(layer_number) return past_key_values - - def test_generate_with_cache(self): - """Test single-prompt generation with KV-cache (THD format).""" - config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") - model = 
self.get_model_class()(config).to("cuda").to(torch.bfloat16) - model.eval() - - tokenizer = self.get_tokenizer() - prompt = "The quick brown fox jumps over" - inputs = tokenizer(prompt, return_tensors="pt") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - - past_key_values = self._create_inference_params(config, batch_size=1) - - with torch.no_grad(): - output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) - - # Verify generation produced new tokens - assert output_ids.shape[1] > inputs["input_ids"].shape[1] - - def test_generate_with_cache_batched(self): - """Test batched generation with KV-cache (left-padded BSHD converted to THD).""" - config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") - model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) - model.eval() - - tokenizer = self.get_tokenizer() - prompts = ( - "The quick brown fox jumps over the lazy dog.", - "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", - ) - inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - - past_key_values = self._create_inference_params(config, batch_size=2) - - with torch.no_grad(): - output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) - - # Verify generation produced new tokens for both sequences - assert output_ids.shape[0] == 2 - assert output_ids.shape[1] > inputs["input_ids"].shape[1] - - def test_generate_with_cache_beam_search(self): - """Test batched generation with KV-cache and beam search.""" - config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") - model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) - model.eval() - - tokenizer = self.get_tokenizer() - prompts = ( - "The quick brown fox jumps over the lazy dog.", - "Lorem ipsum dolor sit amet, 
consectetur adipiscing elit.", - ) - inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - - num_beams = 2 - past_key_values = self._create_inference_params(config, batch_size=2, num_beams=num_beams) - - with torch.no_grad(): - output_ids = model.generate( - **inputs, - max_new_tokens=16, - use_cache=True, - past_key_values=past_key_values, - num_beams=num_beams, - do_sample=True, - ) - - # Verify generation produced new tokens for both sequences - assert output_ids.shape[0] == 2 - assert output_ids.shape[1] > inputs["input_ids"].shape[1] - - # ==================== Standalone Mixtral Generation Tests ==================== - - def test_te_mixtral_model_generate_with_cache_beam_search(self): - """Test Mixtral generation with KV-cache and beam search using real model weights.""" - import gc - - model_hf = self.get_reference_model() - model_te = convert_mixtral_hf_to_te(model_hf, attn_input_format="thd", self_attn_mask_type="padding_causal") - del model_hf - gc.collect() - - model_te.to("cuda") - model_te.eval() - - tokenizer = self.get_tokenizer() - - prompts = ( - 'Licensed under the Apache License, Version 2.0 (the "License");' - " you may not use this file except in compliance with the License." 
- " You may obtain a copy of the License at", - "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore", - ) - inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - - num_beams = 2 - config = model_te.config - past_key_values = self._create_inference_params(config, batch_size=2, num_beams=num_beams) - - with torch.no_grad(): - output_ids = model_te.generate( - **inputs, - max_new_tokens=16, - use_cache=True, - past_key_values=past_key_values, - num_beams=num_beams, - do_sample=False, - ) - - generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - assert "http://www.apache.org/licenses/LICENSE-2.0" in generated_text[0] - assert "et dolore magna aliqua" in generated_text[1] From ebe0f0cd5df9f5b3b0693cbc5c19cc84f7d8f37d Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Sat, 28 Feb 2026 07:41:38 -0800 Subject: [PATCH 2/3] add qwen3 model Signed-off-by: Peter St. 
John --- bionemo-recipes/models/qwen3/.ruff.toml | 1 + bionemo-recipes/models/qwen3/collator.py | 1036 +++++++++++++++ bionemo-recipes/models/qwen3/convert.py | 211 ++++ bionemo-recipes/models/qwen3/export.py | 60 + .../models/qwen3/modeling_qwen3_te.py | 460 +++++++ bionemo-recipes/models/qwen3/requirements.txt | 5 + bionemo-recipes/models/qwen3/state.py | 724 +++++++++++ .../models/qwen3/tests/__init__.py | 14 + .../models/qwen3/tests/common/README.md | 64 + .../models/qwen3/tests/common/__init__.py | 45 + .../models/qwen3/tests/common/fixtures.py | 128 ++ .../tests/common/test_modeling_common.py | 1108 +++++++++++++++++ .../models/qwen3/tests/conftest.py | 31 + .../qwen3/tests/test_modeling_qwen3_te.py | 164 +++ .../qwen3/tests/test_te_qk_norm_dtype.py | 70 ++ ci/scripts/check_copied_files.py | 3 + 16 files changed, 4124 insertions(+) create mode 100644 bionemo-recipes/models/qwen3/.ruff.toml create mode 100644 bionemo-recipes/models/qwen3/collator.py create mode 100644 bionemo-recipes/models/qwen3/convert.py create mode 100644 bionemo-recipes/models/qwen3/export.py create mode 100644 bionemo-recipes/models/qwen3/modeling_qwen3_te.py create mode 100644 bionemo-recipes/models/qwen3/requirements.txt create mode 100644 bionemo-recipes/models/qwen3/state.py create mode 100644 bionemo-recipes/models/qwen3/tests/__init__.py create mode 100644 bionemo-recipes/models/qwen3/tests/common/README.md create mode 100644 bionemo-recipes/models/qwen3/tests/common/__init__.py create mode 100644 bionemo-recipes/models/qwen3/tests/common/fixtures.py create mode 100644 bionemo-recipes/models/qwen3/tests/common/test_modeling_common.py create mode 100644 bionemo-recipes/models/qwen3/tests/conftest.py create mode 100644 bionemo-recipes/models/qwen3/tests/test_modeling_qwen3_te.py create mode 100644 bionemo-recipes/models/qwen3/tests/test_te_qk_norm_dtype.py diff --git a/bionemo-recipes/models/qwen3/.ruff.toml b/bionemo-recipes/models/qwen3/.ruff.toml new file mode 100644 index 
000000000..7e9a31bf5 --- /dev/null +++ b/bionemo-recipes/models/qwen3/.ruff.toml @@ -0,0 +1 @@ +extend = "../.ruff.toml" diff --git a/bionemo-recipes/models/qwen3/collator.py b/bionemo-recipes/models/qwen3/collator.py new file mode 100644 index 000000000..e83d719eb --- /dev/null +++ b/bionemo-recipes/models/qwen3/collator.py @@ -0,0 +1,1036 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data collators for sequence packing and context parallel training. + +This should eventually get moved to a separate package, or possibly upstreamed into `transformers`. +""" + +import logging +import threading +from dataclasses import dataclass, field +from typing import Any, TypedDict + +import datasets +import nvtx +import torch +from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import pad_thd_sequences_for_cp +from transformers import DataCollator, DataCollatorForLanguageModeling + + +logger = logging.getLogger(__name__) + + +@dataclass +class DataCollatorWithFlattening: + """Data collator that wraps a DataCollatorForLanguageModeling and flattens inputs for flash-attention. + + This collator enables efficient training on batches containing variable-length sequences, by first flattening + (packing) multiple input sequences into a single contiguous tensor without padding between sequences. 
Then, it + applies masked language modeling (MLM) masking using the provided DataCollatorForLanguageModeling instance. + + The collator also generates metadata required for Flash Attention or context-parallel attention: + - `cu_seq_lens_q` and `cu_seq_lens_k` tensors, denoting cumulative sequence lengths so that sequence boundaries + within the packed tensor are known during attention computation. + + Optionally, the collator can: + - Pad the total number of tokens in the batch to be divisible by `pad_to_multiple_of` (by appending a mock + sequence). + - Pad each individual sequence to be divisible by `pad_sequences_to_be_divisible_by` if provided. + + Only PyTorch tensors (`return_tensors="pt"`) are supported. + + Args: + collator (DataCollatorForLanguageModeling): The collator to use for MLM masking. This is a captive + collator and should be constructed externally and passed in. + return_position_ids (bool): Whether to return position ids (default False). + pad_to_multiple_of (int, optional): If set, pads the total sequence length to be divisible by this number. + pad_sequences_to_be_divisible_by (int, optional): If set, each individual sequence is padded to this value. + separator_id (int, optional): A label to insert between sequences, typically should be -100 for causal LM. + + Example: + >>> from transformers import AutoTokenizer, DataCollatorForLanguageModeling + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") + >>> mlm_collator = DataCollatorForLanguageModeling(tokenizer) + >>> flat_collator = DataCollatorWithFlattening( + ... collator=mlm_collator, + ... pad_to_multiple_of=8, + ... ) + >>> + >>> # Input: variable length protein sequences + >>> sequences = [ + ... {"input_ids": [0, 5, 6, 7, 2]}, # 5 tokens + ... {"input_ids": [0, 8, 9, 10, 11, 2]}, # 6 tokens + ... {"input_ids": [0, 12, 13, 2]}, # 4 tokens + ... 
] # Total: 15 tokens + >>> batch = flat_collator(sequences) + >>> print(batch['input_ids'].shape) # torch.Size([1, 16]) + >>> print(batch['labels'].shape) # torch.Size([1, 16]) + >>> print(batch['cu_seq_lens_q']) # tensor([0, 5, 11, 15, 16], dtype=torch.int32) + + Note: + The output is a THD-format (Total, Height, Depth) batch, where all input sequences are packed without + inter-sequence padding. Sequence boundaries are preserved using `cu_seq_lens_q`/`cu_seq_lens_k`, enabling + Flash Attention or context-parallelism without traditional attention masks. + """ + + collator: DataCollatorForLanguageModeling + return_position_ids: bool = False + pad_to_multiple_of: int | None = None + pad_sequences_to_be_divisible_by: int | None = None + separator_id: int | None = None + + def __post_init__(self): + """Ensure padding options are not used together.""" + if self.pad_sequences_to_be_divisible_by is not None and self.pad_to_multiple_of is not None: + raise ValueError("pad_sequences_to_be_divisible_by and pad_to_multiple_of cannot be used together") + + def __call__(self, features, return_tensors=None): + """Process a batch of variable-length sequences for Flash Attention with MLM. + + This method performs the following steps: + 1. Flattens multiple sequences into a single packed tensor with Flash Attention metadata + 2. Applies MLM masking to the flattened sequence while preserving special tokens + 3. Optionally pads to a multiple of a specified number for hardware optimization + + Args: + features (List[Dict[str, List[int]]]): List of tokenized sequences, each containing + 'input_ids' and optionally 'attention_mask'. Example: + [ + {"input_ids": [0, 5, 6, 7, 2]}, # Protein sequence 1 + {"input_ids": [0, 8, 9, 10, 11, 2]}, # Protein sequence 2 + {"input_ids": [0, 12, 13, 2]} # Protein sequence 3 + ] + return_tensors (str, optional): Format for returned tensors. Only "pt" (PyTorch) + is supported. Defaults to None (uses collator default). 
+ + Returns: + Dict[str, torch.Tensor]: Batch dictionary containing: + - input_ids (torch.Tensor): Flattened and MLM-masked token sequences. + Shape: [1, total_tokens] where total_tokens = sum of all sequence lengths + (plus padding if pad_to_multiple_of is specified). + - labels (torch.Tensor): MLM labels with -100 for non-masked tokens and + original token IDs for masked positions. Same shape as input_ids. + - cu_seq_lens_q (torch.IntTensor): Cumulative sequence lengths for queries. + Shape: [num_sequences + 1] or [num_sequences + 2] if padding is added. + Example: [0, 5, 11, 15] or [0, 5, 11, 15, 16] with padding. + - cu_seq_lens_k (torch.IntTensor): Cumulative sequence lengths for keys. + Same as cu_seq_lens_q for self-attention. + - max_length_q (int): Maximum sequence length in the batch. + - max_length_k (int): Same as max_length_q for self-attention. + - attention_mask (torch.Tensor): Attention mask with 1s for actual tokens + and 0s for padding tokens (if any). + + Raises: + NotImplementedError: If return_tensors is not "pt". + + Example: + >>> # Input features + >>> features = [ + ... {"input_ids": [0, 5, 6, 7, 2]}, # 5 tokens + ... {"input_ids": [0, 8, 9, 10, 11, 2]}, # 6 tokens + ... ] + >>> + >>> batch = collator(features) + >>> + >>> # Output shapes and values + >>> batch['input_ids'].shape # torch.Size([1, 11]) or larger if padded + >>> batch['labels'].shape # torch.Size([1, 11]) or larger if padded + >>> batch['cu_seq_lens_q'] # tensor([0, 5, 11], dtype=torch.int32) or larger + + Note: + The output is in THD (Total, Height, Depth) format with batch_size=1 and + sequence_length=total_tokens, optimized for Flash Attention's variable-length + sequence processing capabilities. When pad_to_multiple_of is used, an additional + mock sequence is appended to reach the desired total length. 
+ """ + if return_tensors is not None and return_tensors != "pt": + raise NotImplementedError(f"Only return_tensors='pt' is supported, got '{return_tensors}'") + + # Perform the masking with the BSHD collator. + bshd_batch = self.collator(features, return_tensors=return_tensors) + + # Create the flattened batch to get the cu_seq_lens_q and cu_seq_lens_k values. + packed_batch = _pt_flatten_collate(features, return_position_ids=self.return_position_ids) + + # Get the masked input_ids and labels from the BSHD batch. + masked_input_ids = bshd_batch["input_ids"][bshd_batch["attention_mask"].bool()].unsqueeze(0) + masked_labels = bshd_batch["labels"][bshd_batch["attention_mask"].bool()].unsqueeze(0) + + if self.separator_id is not None: + masked_labels[:, packed_batch["cu_seq_lens_q"][1:-1]] = self.separator_id + + # Update the packed batch with the masked input_ids and labels. + packed_batch["input_ids"] = masked_input_ids + packed_batch["labels"] = masked_labels + + if self.pad_to_multiple_of is not None: + packed_batch = self._pad_batch_to_multiple_of(packed_batch) + + elif self.pad_sequences_to_be_divisible_by is not None: + packed_batch = self._pad_sequences_to_be_divisible_by(packed_batch) + + return packed_batch + + def _pad_batch_to_multiple_of(self, batch): + """Add a mock sequence to make the total number of tokens divisible by pad_to_multiple_of.""" + # Ensure token_pad is an integer, defaulting to 1 if pad_token_id is None or invalid + pad_token_id = self.collator.tokenizer.pad_token_id + if not isinstance(pad_token_id, int): + logger.warning(f"tokenizer.pad_token_id is not an integer, using 1 instead: {pad_token_id}") + pad_token_id = 1 + + assert self.pad_to_multiple_of is not None, "pad_to_multiple_of must be set" + + return _pt_pad_to_multiple_of( + batch, + self.pad_to_multiple_of, + token_pad=pad_token_id, + label_pad=-100, + ) + + def _pad_sequences_to_be_divisible_by(self, batch): + """Pad individual sequences using cu_seq_lens_*_padded for context 
parallelism.""" + pad_token_id = self.collator.tokenizer.pad_token_id + if not isinstance(pad_token_id, int): + logger.warning(f"tokenizer.pad_token_id is not an integer, using 1 instead: {pad_token_id}") + pad_token_id = 1 + + assert self.pad_sequences_to_be_divisible_by is not None, "pad_sequences_to_be_divisible_by must be set" + + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + batch["input_ids"], + batch["labels"], + batch["cu_seq_lens_q"], + self.pad_sequences_to_be_divisible_by, + padding_token_id=pad_token_id, + padding_label_id=-100, + ) + + batch["input_ids"] = input_ids_padded.unsqueeze(0) + batch["labels"] = labels_padded.unsqueeze(0) + batch["cu_seq_lens_q_padded"] = cu_seqlens_padded.to(torch.int32) + batch["cu_seq_lens_k_padded"] = cu_seqlens_padded.to(torch.int32) + batch["pad_between_seqs"] = True + return batch + + +@dataclass +class TokenPackingDataset(torch.utils.data.IterableDataset): + """Dataset that uses sequence packing to construct batches with variable length up to a maximum number of tokens.""" + + dataset: datasets.IterableDataset + """Dataset to pack.""" + max_tokens_per_batch: int + """Maximum number of tokens per batch.""" + drop_last: bool = True + """Whether to drop the last batch if it's less than max_length.""" + split_samples: bool = False + """Whether to split samples to ensure batches have exactly max_tokens_per_batch tokens.""" + pad_sequences_to_be_divisible_by: int | None = None + """If set, account for per-sequence padding when accumulating batches. + + Each sequence's contribution to the batch length is rounded up to the nearest multiple of this value, + matching the padding behavior of DataCollatorWithFlattening with the same parameter. When used with + split_samples=True, the split point is chosen so that the first part (after padding) exactly fills + the remaining batch capacity. 
+ """ + + def __post_init__(self): + """Validate padding configuration.""" + if ( + self.pad_sequences_to_be_divisible_by is not None + and self.max_tokens_per_batch % self.pad_sequences_to_be_divisible_by != 0 + ): + logger.warning( + "max_tokens_per_batch (%d) is not divisible by pad_sequences_to_be_divisible_by (%d). " + "Batches may not fill to exactly max_tokens_per_batch when split_samples=True.", + self.max_tokens_per_batch, + self.pad_sequences_to_be_divisible_by, + ) + + def _padded_len(self, length: int) -> int: + """Return the padded length of a sequence, rounding up to the nearest multiple of pad_sequences_to_be_divisible_by.""" + if self.pad_sequences_to_be_divisible_by is None: + return length + return -(-length // self.pad_sequences_to_be_divisible_by) * self.pad_sequences_to_be_divisible_by + + def __iter__(self): + """Yield batches of samples, each with a variable number of tokens up to the maximum length. + + When split_samples=True, ensures each batch has exactly max_tokens_per_batch by splitting + the final sample if needed. The remaining tokens from the split sample start the next batch. + + When pad_sequences_to_be_divisible_by is set, each sequence's padded length is used when + accumulating batch sizes, so the total padded length of the batch matches max_tokens_per_batch. + + Returns: + A generator of batches of samples, each with a variable number of tokens up to the maximum length. + """ + samples = [] + current_length = 0 + for sample in iter(self.dataset): + sample_length = len(sample["input_ids"]) + padded_len = self._padded_len(sample_length) + if padded_len > self.max_tokens_per_batch: + raise ValueError( + f"TokenPackingDataset: Padded sample length ({padded_len}) exceeds max_tokens_per_batch " + f"({self.max_tokens_per_batch}). Set truncation or a maximum length in your tokenizer or dataset to" + " ensure all samples fit within max_tokens_per_batch." 
+ ) + + current_length += padded_len + if current_length == self.max_tokens_per_batch: + yield [*samples, sample] + samples = [] + current_length = 0 + + elif current_length > self.max_tokens_per_batch: + if not self.split_samples: + # Yield the current batch (before this sample) and start a new one with this sample. + if samples: + yield samples + samples = [sample] + current_length = padded_len + else: + # Calculate how many padded tokens are already in the batch. + tokens_in_batch = current_length - padded_len + # Calculate how many tokens we can fit from this sample, ensuring the + # padded length doesn't exceed the remaining capacity. + tokens_available = self.max_tokens_per_batch - tokens_in_batch + if self.pad_sequences_to_be_divisible_by is not None: + d = self.pad_sequences_to_be_divisible_by + tokens_available = (tokens_available // d) * d + if tokens_available <= 0: + # Remaining capacity is less than pad_sequences_to_be_divisible_by; + # can't fit any tokens from this sample. Yield current batch and start fresh. + if samples: + yield samples + samples = [sample] + current_length = padded_len + else: + first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available) + yield [*samples, first_part] + samples = [remaining_part] + current_length = self._padded_len(len(samples[0]["input_ids"])) + else: + samples.append(sample) + + if not self.drop_last and samples: + yield samples + + def set_epoch(self, epoch: int): + """Set the epoch for the dataset.""" + self.dataset.set_epoch(epoch) + + +@dataclass +class DataCollatorForContextParallel: + """A collator that is aware of context parallelism. + + For the case of context parallelism, padded sequences will be returned from the wrapped collator, and then split + into shards for each context parallelism rank. + + The shards are then typically sent to the ContextParallelDataLoaderWrapper which will scatter them to the + appropriate GPUs. 
+ + Note: + When used with the ContextParallelDataLoaderWrapper and both context parallelism and tensor parallelism are + used, the collator inspects the ordering of the mesh dimensions to determine the layout of the flattened batch. + + If "cp" comes before "tp" in the mesh dimension names (CP row-major), the flattened batch will be: + [(cp0, tp0), (cp0, tp1), ..., (cp1, tp0), (cp1, tp1), ...] + + If "tp" comes before "cp" (TP row-major), the flattened batch will be: + [(tp0, cp0), (tp0, cp1), ..., (tp1, cp0), (tp1, cp1), ...] + + Args: + collator: The collator to use for the batch. + device_mesh: The device mesh with named dimensions. Must contain either a "cp" dimension for context parallelism + and/or a "tp" dimension for tensor parallelism. + qkv_format: The format of the query-key-value (QKV) tensor. + is_causal_lm: Whether the collator is for a causal language model. If True, the labels will be shifted before + being split into CP shards, and will be returned in the `shift_labels` field. + + """ + + collator: DataCollator + device_mesh: torch.distributed.device_mesh.DeviceMesh + qkv_format: str = "thd" + is_causal_lm: bool = False + + # Derived fields, initialized in __post_init__. + cp_world_size: int = field(init=False) + tp_world_size: int | None = field(init=False) + _is_cp_row_major: bool = field(init=False) + + def __post_init__(self): + """Initialize the cp_world_size, tp_world_size, and _is_cp_row_major fields based on the device mesh.""" + dim_names = self.device_mesh.mesh_dim_names + if dim_names is None: + raise ValueError("device_mesh must have mesh_dim_names") + + self.cp_world_size = self.device_mesh.size(dim_names.index("cp")) if "cp" in dim_names else 1 + self.tp_world_size = self.device_mesh.size(dim_names.index("tp")) if "tp" in dim_names else None + + # Determine whether CP is the row (outer) dimension of the 2D mesh. + # When flattened, the row-major dimension's index changes slowest. 
+ # If "cp" comes before "tp" in mesh_dim_names, CP is the row dimension. + if "cp" in dim_names and "tp" in dim_names: + self._is_cp_row_major = dim_names.index("cp") < dim_names.index("tp") + else: + self._is_cp_row_major = True + + def __call__(self, features) -> list[dict[str, Any]]: + """Process batches of data and create shards for each context parallelism rank. + + Args: + features: List of tokenized sequences, each containing 'input_ids' and optionally 'labels'. + + Returns: + A list of dictionaries, each containing a shard of the batch for a given context parallelism rank. + """ + batch = self.collator(features) + + # Remove the attention mask from the batch, it's not valid for CP. + batch.pop("attention_mask", None) + + if self.is_causal_lm: + labels = torch.nn.functional.pad(batch["labels"], (0, 1), value=-100) + batch["labels"] = labels[..., 1:].contiguous() + + combined_batch = [] + for cp_rank in range(self.cp_world_size): + input_ids_sharded, labels_sharded = _split_batch_by_cp_rank( + cu_seqlens_padded=batch.get("cu_seq_lens_q_padded", None), # This will be None for BSHD format. + input_ids_padded=batch["input_ids"], + labels_padded=batch["labels"], + qvk_format=self.qkv_format, + cp_rank=cp_rank, + cp_world_size=self.cp_world_size, + ) + batch_shard = dict(batch) + batch_shard["input_ids"] = input_ids_sharded + if self.is_causal_lm: + batch_shard["shift_labels"] = labels_sharded + batch_shard["labels"] = None + else: + batch_shard["labels"] = labels_sharded + # Now determine the max length of the sequence. 
+ if self.qkv_format == "thd": + seqlens_q = batch_shard["cu_seq_lens_q_padded"][1:] - batch_shard["cu_seq_lens_q_padded"][:-1] + max_length = seqlens_q.max().item() + elif self.qkv_format == "bshd": + max_length = batch["input_ids"].shape[1] + else: + raise ValueError(f"Unsupported qvk_format: {self.qkv_format}!") + + batch_shard["max_length_k"] = batch_shard["max_length_q"] = ((max_length + 63) // 64) * 64 + combined_batch.append(batch_shard) + + if self.tp_world_size is not None: + # Replicate each CP shard for TP ranks. The ordering depends on which dimension forms the rows in the + # flattened mesh. + if self._is_cp_row_major: + # Flattened mesh: [(cp0,tp0), (cp0,tp1), (cp1,tp0), (cp1,tp1)] + # Output: [cp0, cp0, cp1, cp1] + combined_batch = [batch for batch in combined_batch for _ in range(self.tp_world_size)] + else: + # Flattened mesh: [(tp0,cp0), (tp0,cp1), (tp1,cp0), (tp1,cp1)] + # Output: [cp0, cp1, cp0, cp1] + combined_batch = [ + combined_batch[cp_rank] for _ in range(self.tp_world_size) for cp_rank in range(self.cp_world_size) + ] + + return combined_batch + + +class ContextParallelDataLoaderWrapper: + """A dataloader that is aware of context and tensor parallelism.""" + + def __init__( + self, + dataloader: torch.utils.data.DataLoader | None, + cp_tp_mesh: torch.distributed.device_mesh.DeviceMesh, + ): + """A dataloader wrapper that distributes the data across the context and tensor parallelism groups. + + This class materializes a single dataloader for each data parallel mesh rank, and splits / replicates the data + from this dataloader across the context and tensor parallelism groups. + + Args: + dataloader: The dataloader to use. + cp_tp_mesh: The context parallel mesh, or a flattened, combined context parallel and tensor parallel mesh. + If a flattened mesh is provided, the cp / tp dimensions should be in the order they appeared in the + mesh_dim_names as passed to DataCollatorForContextParallel. 
+ """ + if cp_tp_mesh.get_local_rank() == 0: + assert dataloader is not None, "dataloader must be provided on rank 0" + self.dataloader = dataloader + + else: + assert dataloader is None, "Dataloader on non-rank 0 will not be used" + + self.cp_tp_rank = cp_tp_mesh.get_local_rank() + self.cp_tp_group = cp_tp_mesh.get_group() + self.num_cp_tp_ranks = cp_tp_mesh.size() + self._iterator = None + self._prefetch_thread: threading.Thread | None = None + self._prefetch_result: Any = None + self._cuda_device: int | None = None + + logger.debug( + "Created ContextParallelDataLoaderWrapper on global rank %s, cp rank %s", + torch.distributed.get_rank() if torch.distributed.is_initialized() else "", + self.cp_tp_rank, + ) + + def __iter__(self): + """Make the dataloader iterable.""" + if self.cp_tp_rank == 0: + self._iterator = iter(self.dataloader) # < --- collator output. + self.close() + # Capture CUDA device from main thread; torch.cuda.set_device is per-thread, + # so the background thread needs to set it explicitly. + self._cuda_device = torch.cuda.current_device() if torch.cuda.is_available() else None + self._kick_prefetch() + return self + + @nvtx.annotate("ContextParallelDataLoaderWrapper __next__", color="blue") + def __next__(self): + """Get the batch from the dataloader for the current CP rank.""" + if self._prefetch_thread is not None: + self._prefetch_thread.join() + result = self._prefetch_result + if isinstance(result, Exception): + self._prefetch_thread = None + raise result + self._kick_prefetch() + return result + + def _kick_prefetch(self): + """Start a background thread to prefetch exactly one batch via scatter.""" + self._prefetch_thread = threading.Thread(target=self._do_one_prefetch, daemon=True) + self._prefetch_thread.start() + + def _do_one_prefetch(self): + """Fetch one batch in the background. 
+
+        This function calls the _send_data_to_cp_tp_ranks function to materialize the next batches for all ranks in the
+        given CP/TP group, and uses torch.distributed.scatter_object_list to scatter these batches to their
+        corresponding ranks. The result is stored in _prefetch_result, and returned when __next__ is called.
+        """
+        if self._cuda_device is not None:
+            torch.cuda.set_device(self._cuda_device)
+        try:
+            self._prefetch_result = self._send_data_to_cp_tp_ranks()
+        except StopIteration as e:
+            self._prefetch_result = e
+        except Exception as e:
+            self._prefetch_result = e
+
+    def close(self):
+        """Stop the prefetch thread. Must be called before destroy_process_group()."""
+        if self._prefetch_thread is not None:
+            self._prefetch_thread.join(timeout=10)
+            self._prefetch_thread = None
+
+    @nvtx.annotate("ContextParallelDataLoaderWrapper _send_data_to_cp_tp_ranks", color="green")
+    def _send_data_to_cp_tp_ranks(self):
+        """Send data to all the CP/TP ranks.
+
+        This function will get the batch from the dataloader on CP rank 0, and then determine
+        the shards for all the different CP group members.
+        combined_batch = [batch_shard_0, batch_shard_1, ..., batch_shard_{cp_world_size - 1}]
+        Then it will scatter the shards to the different CP group members.
+        The shards are then combined into a single batch and returned to the caller
+        for the current CP rank.
+
+        If tensor parallelism is also being used, the combined batch will look like:
+        combined_batch = [batch_shard_0, ..., batch_shard_0, batch_shard_1, ..., batch_shard_1, ...]
+        where there are cp_world_size shards, and each shard is replicated tp_world_size times. The ordering of the
+        shards depends on which dimension forms the rows in the flattened mesh.
+
+        Scalability:
+            Rank 0's work grows linearly with CP size, but the other ranks do not need to store all the shards so they
+            do not grow linearly with CP size.
+
+        Args:
+            None
+
+        Returns:
+            batch: The batch for the current CP/TP rank.
+ + """ + try: + with nvtx.annotate("ContextParallelDataLoaderWrapper next batch", color="green"): + combined_batch = next(self._iterator) if self.cp_tp_rank == 0 else None + except StopIteration as ex: + # If we encounter a StopIteration in the dataloader, we want to raise this error on all the CP ranks, so + # that the dataloader can be restarted. + combined_batch = [ex] * self.num_cp_tp_ranks + + batch_on_this_rank = _scatter_batch_to_cp_tp_ranks(combined_batch, self.cp_tp_group) + + if isinstance(batch_on_this_rank, StopIteration): + raise batch_on_this_rank + + return batch_on_this_rank + + def state_dict(self): + """Get the state dict by delegating to the dataloader.""" + if self.cp_tp_rank != 0: + return {} + elif hasattr(self.dataloader, "state_dict"): + return {"dataloader": self.dataloader.state_dict()} + else: + logger.warning( + "Attempting to get the state dict of the dataloader, but the dataloader does not support state_dict, " + "returning empty dict" + ) + return {"dataloader": {}} + + def load_state_dict(self, state_dict): + """Load the state dict by delegating to the dataloader.""" + if self.cp_tp_rank != 0: + return + elif hasattr(self.dataloader, "load_state_dict"): + self.dataloader.load_state_dict(state_dict["dataloader"]) + else: + logger.warning( + "Attempting to load the state dict of the dataloader, but the dataloader does not support " + "load_state_dict, returning without loading the state dict." + ) + return + + @property + def num_workers(self): + """Get the number of workers of the dataloader.""" + if self.cp_tp_rank != 0: + return 0 + else: + return self.dataloader.num_workers + + +def _split_sample_by_num_tokens(sample: dict[str, Any], num_tokens: int) -> tuple[dict[str, Any], dict[str, Any]]: + """Split a sample dictionary at a specified number of tokens. + + This function splits a sample into two parts: the first part contains exactly `num_tokens` tokens, + and the second part contains the remaining tokens. 
All fields that are sequences (input_ids, attention_mask, + token_type_ids, labels, etc.) are split accordingly. + + Args: + sample: Dictionary containing sample data with fields like input_ids, attention_mask, token_type_ids, labels, etc. + num_tokens: Number of tokens to include in the first part of the split. + + Returns: + A tuple of two dictionaries: (first_part, remaining_part), where: + - first_part contains the first `num_tokens` tokens from each sequence field + - remaining_part contains the remaining tokens from each sequence field + + Example: + >>> sample = { + ... "input_ids": [0, 5, 6, 7, 8, 9, 2], + ... "attention_mask": [1, 1, 1, 1, 1, 1, 1], + ... "labels": [0, 5, 6, 7, 8, 9, 2] + ... } + >>> first, remaining = split_sample_by_num_tokens(sample, 3) + >>> first["input_ids"] # [0, 5, 6] + >>> remaining["input_ids"] # [7, 8, 9, 2] + """ + sample_length = len(sample["input_ids"]) + if num_tokens >= sample_length: + raise ValueError( + f"num_tokens ({num_tokens}) must be less than sample length ({sample_length}) to split the sample" + ) + if num_tokens <= 0: + raise ValueError(f"num_tokens ({num_tokens}) must be positive") + + first_part = {} + remaining_part = {} + + # Fields that should be split by tokens (sequence fields) + sequence_fields = ["input_ids", "attention_mask", "token_type_ids", "token_type", "labels"] + + for key, value in sample.items(): + if key in sequence_fields: + # Handle both list and tensor inputs + if isinstance(value, torch.Tensor): + first_part[key] = value[:num_tokens].clone() + remaining_part[key] = value[num_tokens:].clone() + elif isinstance(value, list): + first_part[key] = value[:num_tokens] + remaining_part[key] = value[num_tokens:] + else: + # For other types, try to slice if possible + try: + first_part[key] = value[:num_tokens] + remaining_part[key] = value[num_tokens:] + except (TypeError, IndexError): + # If slicing doesn't work, copy the value to both parts + # This handles fields that shouldn't be split (like 
metadata) + first_part[key] = value + remaining_part[key] = value + else: + # For non-sequence fields, copy to both parts + # This handles metadata fields that shouldn't be split + first_part[key] = value + remaining_part[key] = value + + return first_part, remaining_part + + +def _pt_flatten_collate(features: list[dict[str, list[int]]], return_position_ids: bool = False): + """Flatten a list of tokenized samples into a single packed batch with cumulative sequence lengths. + + Args: + features: List of tokenized samples, each containing at least ``input_ids``. + return_position_ids: Whether to return position ids for each token. + + Returns: + A dictionary with packed ``input_ids``, ``cu_seq_lens_q``/``cu_seq_lens_k``, and + ``max_length_q``/``max_length_k``. + """ + is_labels_provided = "labels" in features[0] + sample_lengths = [len(sample["input_ids"]) for sample in features] + + batch = {} + batch["max_length_q"] = batch["max_length_k"] = max(sample_lengths) + batch["input_ids"] = torch.tensor( + [[token for sample in features for token in sample["input_ids"]]], dtype=torch.int64 + ) + if is_labels_provided: + batch["labels"] = torch.tensor( + [[label for sample in features for label in sample["labels"]]], dtype=torch.int64 + ) + cu_seq_lens = torch.zeros(len(features) + 1, dtype=torch.int32) + cu_seq_lens[1:] = torch.cumsum(torch.tensor(sample_lengths), dim=0, dtype=torch.int32) + batch["cu_seq_lens_q"] = batch["cu_seq_lens_k"] = cu_seq_lens + if "attention_mask" in features[0]: + batch["attention_mask"] = torch.tensor( + [[v for sample in features for v in sample["attention_mask"]]], dtype=torch.int64 + ) + if return_position_ids: + batch["position_ids"] = torch.hstack( + [torch.arange(sample_len, dtype=torch.int64) for sample_len in sample_lengths] + ).unsqueeze(0) + + return batch + + +def _find_seq_dim(tensor: torch.Tensor, seq_len: int) -> int: + """Find which dimension of tensor matches the expected sequence length. 
+ + Args: + tensor: The tensor to inspect. + seq_len: The expected sequence length to match against tensor dimensions. + + Returns: + The dimension index that matches the sequence length. + + Raises: + ValueError: If no dimension matches the expected sequence length. + """ + if tensor.ndim == 1: + if tensor.shape[0] == seq_len: + return 0 + raise ValueError(f"1D tensor shape {tensor.shape} doesn't match sequence length {seq_len}") + elif tensor.ndim >= 2: + if tensor.shape[1] == seq_len: + return 1 + elif tensor.shape[0] == seq_len: + return 0 + raise ValueError(f"Tensor shape {tensor.shape} doesn't match sequence length {seq_len} in dim 0 or 1") + raise ValueError(f"Unexpected tensor ndim={tensor.ndim}") + + +def _process_tensor_thd( + val: torch.Tensor | None, + seq_len: int, + slice_sizes: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + cp_rank: int, + total_slices: int, +) -> torch.Tensor | None: + """Extract the THD context-parallel shard for a single tensor. + + For each sequence in the batch, selects two slices (one from the beginning and one from the end) + corresponding to the given CP rank, following the zigzag CP sharding pattern. + + Args: + val: The tensor to shard, or None (returned as-is). + seq_len: Total sequence length (from cu_seqlens_padded[-1]). + slice_sizes: Per-sequence slice sizes, computed as sequence_lengths // total_slices. + cu_seqlens_padded: Cumulative sequence lengths including padding. + cp_rank: The context parallelism rank index. + total_slices: Total number of slices per sequence (2 * cp_world_size). + + Returns: + The sharded tensor for the given CP rank, or None if val is None. 
+ """ + if val is None: + return val + + seq_dim = _find_seq_dim(val, seq_len) + + cp_rank_slices = [] + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (cp_rank * slice_size), + seq_start + ((cp_rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices - cp_rank - 1) * slice_size), + seq_start + ((total_slices - cp_rank) * slice_size), + device=val.device, + ) + ) + + return val.index_select(seq_dim, torch.cat(cp_rank_slices)) + + +def _process_tensor_bshd( + val: torch.Tensor | None, + cp_rank: int, + cp_world_size: int, +) -> torch.Tensor | None: + """Extract the BSHD context-parallel shard for a single tensor. + + Splits a BSHD-format tensor along the sequence dimension (dim=1) into 2*cp_world_size chunks, + then selects the two chunks corresponding to the given CP rank (zigzag pattern). + + Args: + val: The tensor to shard, or None (returned as-is). + cp_rank: The context parallelism rank index. + cp_world_size: Total number of context parallelism ranks. + + Returns: + The sharded tensor for the given CP rank, or None if val is None. + + Raises: + ValueError: If the tensor has fewer than 2 dimensions or its sequence length + is not divisible by 2 * cp_world_size. 
+ """ + if val is None: + return val + + if val.ndim < 2: + raise ValueError(f"BSHD format requires at least 2D tensors, got {val.ndim}D") + + seq_len = val.shape[1] + + # Calculate chunk size + total_chunks = 2 * cp_world_size + chunk_size = seq_len // total_chunks + + if seq_len % total_chunks != 0: + raise ValueError( + f"Sequence length {seq_len} must be divisible by {total_chunks} " + f"(2 * cp_world_size) for BSHD context parallelism" + ) + + # Determine which chunks this rank should get + # Rank 0 gets chunks [0, total_chunks-1] + # Rank 1 gets chunks [1, total_chunks-2] + # Rank k gets chunks [k, total_chunks-k-1] + chunk_indices = [cp_rank, total_chunks - cp_rank - 1] + + # Collect slices for this rank + rank_slices = [] + for chunk_idx in chunk_indices: + start_idx = chunk_idx * chunk_size + end_idx = start_idx + chunk_size + rank_slices.append(torch.arange(start_idx, end_idx, device=val.device)) + + # Concatenate indices for all chunks this rank should get + indices = torch.cat(rank_slices) + + # Select along sequence dimension (dim=1) + return val.index_select(1, indices) + + +def _pt_pad_to_multiple_of(batch: dict[str, Any], pad_to_multiple_of: int, token_pad: int, label_pad: int): + """Pad a batch to a multiple of pad_to_multiple_of. + + Appends a mock sequence to the end of the batch with the given token_pad and label_pad to make the total number of + tokens divisible by pad_to_multiple_of. + + Args: + batch: Input batch, possibly containing labels and/or cu_seq_lens / max_length keys. + pad_to_multiple_of: Multiple to pad to. + token_pad: Token to pad with. + label_pad: Label to pad with. + + Returns: + Batch dictionary with padded input_ids, labels, cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k. 
+ """ + # Number of tokens we need to pad to make the total number of tokens divisible by pad_to_multiple_of + remainder = -batch["input_ids"].numel() % pad_to_multiple_of + + if remainder == 0: + return batch + + batch["input_ids"] = torch.cat( + [batch["input_ids"], torch.full((1, remainder), token_pad, dtype=batch["input_ids"].dtype)], dim=1 + ) + + if "labels" in batch: + batch["labels"] = torch.cat( + [batch["labels"], torch.full((1, remainder), label_pad, dtype=batch["labels"].dtype)], dim=1 + ) + + if "cu_seq_lens_q" in batch: + batch["cu_seq_lens_q"] = torch.cat( + [ + batch["cu_seq_lens_q"], + torch.tensor([batch["cu_seq_lens_q"][-1] + remainder], dtype=batch["cu_seq_lens_q"].dtype), + ], + dim=0, + ) + batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"] + + if "max_length_q" in batch: + batch["max_length_q"] = max(batch["max_length_q"], remainder) + batch["max_length_k"] = batch["max_length_q"] + + if "attention_mask" in batch: + batch["attention_mask"] = torch.cat( + [batch["attention_mask"], torch.zeros((1, remainder), dtype=batch["attention_mask"].dtype)], dim=1 + ) + + if "position_ids" in batch: + batch["position_ids"] = torch.cat( + [batch["position_ids"], torch.arange(remainder, dtype=batch["position_ids"].dtype).unsqueeze(0)], dim=1 + ) + + return batch + + +# TODO(@jomitchell): Once this gets merged: https://github.com/NVIDIA/TransformerEngine/pull/2387 +# we can replace this with the one in TransformerEngine. +@nvtx.annotate("collator._split_batch_by_cp_rank", color="green") +def _split_batch_by_cp_rank( + cu_seqlens_padded: torch.Tensor | None, + input_ids_padded: torch.Tensor, + labels_padded: torch.Tensor, + cp_group: torch.distributed.ProcessGroup | None = None, + qvk_format: str = "thd", + cp_rank: int | None = None, + cp_world_size: int | None = None, +): + """Slice batch input along sequence dimension into multiple chunks for THD or BSHD format. + + This function is intended for use in self attention. 
It will not work for cross attention because + it does not handle the case where the sequence length of the query and key are different. + Which are parallelized across GPUs in a context parallel group. + This version works with variable-length sequences using cumulative sequence lengths for THD format, + and with padded sequences for BSHD format. + + Args: + cu_seqlens_padded: Cumulative sequence length. Required for THD format, optional for BSHD format. + input_ids_padded: Input IDs. + labels_padded: Labels. + cp_group: Context parallel group. + qvk_format: Format of the input data ("thd" or "bshd"). + cp_world_size: The size of the context parallelism group. If provided, the function will use this value to determine the rank. + cp_rank: Optional manual CP rank index. When provided, the function shards tensors as if it + were executing on that rank without querying `torch.distributed.get_rank`. + """ + if qvk_format not in ["thd", "bshd", "sbhd"]: + raise ValueError(f"Unsupported qvk_format: {qvk_format}!") + + if cp_world_size is None or cp_world_size <= 1: + # No splitting needed + return input_ids_padded, labels_padded + + if cp_rank is None: + cp_rank = torch.distributed.get_rank(group=cp_group) + elif not (0 <= cp_rank < cp_world_size): + raise ValueError(f"cp_rank must be in [0, {cp_world_size}), but received {cp_rank}.") + + if qvk_format == "thd": + if cu_seqlens_padded is None: + raise ValueError("cu_seqlens_padded is required for THD format") + + # Calculate the chunk sizes for each sequence + total_slices_of_any_sequence = 2 * cp_world_size + slice_sizes = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]) // total_slices_of_any_sequence + + # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor + last_elem = cu_seqlens_padded[-1] + seq_len_val = last_elem.item() if isinstance(last_elem, torch.Tensor) else last_elem + + input_ids_padded = _process_tensor_thd( + input_ids_padded, seq_len_val, slice_sizes, cu_seqlens_padded, cp_rank, 
total_slices_of_any_sequence + ) + labels_padded = _process_tensor_thd( + labels_padded, seq_len_val, slice_sizes, cu_seqlens_padded, cp_rank, total_slices_of_any_sequence + ) + + elif qvk_format == "bshd": + input_ids_padded = _process_tensor_bshd(input_ids_padded, cp_rank, cp_world_size) + labels_padded = _process_tensor_bshd(labels_padded, cp_rank, cp_world_size) + + else: + raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!") + + return input_ids_padded, labels_padded + + +class BatchType(TypedDict): + """The fields in the batch dictionary for THD context parallel.""" + + input_ids: torch.Tensor + labels: torch.Tensor | None + shift_labels: torch.Tensor | None + cu_seq_lens_q: torch.Tensor + cu_seq_lens_k: torch.Tensor + cu_seq_lens_q_padded: torch.Tensor + cu_seq_lens_k_padded: torch.Tensor + max_length_q: int + max_length_k: int + pad_between_seqs: bool + + +@nvtx.annotate("collator._scatter_batch_to_cp_tp_ranks", color="green") +def _scatter_batch_to_cp_tp_ranks( + all_batches: list[BatchType] | list[StopIteration], cp_tp_group: torch.distributed.ProcessGroup | None = None +) -> BatchType | StopIteration: + """Scatter a batch to all the CP ranks. + + Args: + all_batches (list[BatchType] | list[StopIteration]): A list of already-sharded batches to scatter to the CP/TP + ranks. + cp_tp_group (torch.distributed.ProcessGroup | None): The process group to scatter the batches to. + + Returns: + BatchType | StopIteration: The batch on this rank. + """ + scatter_object_output_list = [None] + # Note: This does not provide an async_op handle. Thus its blocking. 
+ torch.distributed.scatter_object_list( + scatter_object_output_list=scatter_object_output_list, + scatter_object_input_list=all_batches, + group=cp_tp_group, + group_src=0, + ) + return scatter_object_output_list[0] diff --git a/bionemo-recipes/models/qwen3/convert.py b/bionemo-recipes/models/qwen3/convert.py new file mode 100644 index 000000000..58a694d2a --- /dev/null +++ b/bionemo-recipes/models/qwen3/convert.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Conversion utilities between HuggingFace Qwen3 and TransformerEngine formats.""" + +import inspect + +import torch +from transformers import Qwen3Config, Qwen3ForCausalLM + +import state +from modeling_qwen3_te import NVQwen3Config, NVQwen3ForCausalLM + + +mapping = { + "model.embed_tokens.weight": "model.embed_tokens.weight", + "model.layers.*.input_layernorm.weight": "model.layers.*.self_attention.layernorm_qkv.layer_norm_weight", + "model.layers.*.self_attn.o_proj.weight": "model.layers.*.self_attention.proj.weight", + "model.layers.*.self_attn.q_norm.weight": "model.layers.*.self_attention.q_norm.weight", + "model.layers.*.self_attn.k_norm.weight": "model.layers.*.self_attention.k_norm.weight", + "model.layers.*.post_attention_layernorm.weight": "model.layers.*.layernorm_mlp.layer_norm_weight", + "model.layers.*.mlp.down_proj.weight": "model.layers.*.layernorm_mlp.fc2_weight", + "model.norm.weight": "model.norm.weight", + "lm_head.weight": "lm_head.weight", +} + +# Reverse mapping from TE to HF format by reversing the original mapping +reverse_mapping = {v: k for k, v in mapping.items()} + + +def _merge_qkv(ctx: state.TransformCTX, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """Merge q, k, v to interleave-concatenated qkv. + + This version uses config.head_dim instead of hidden_size // num_attention_heads, + which is necessary for Qwen3 where head_dim is independently configured. 
+ """ + target_config = ctx.target.config + + head_num = target_config.num_attention_heads + num_query_groups = target_config.num_key_value_heads + heads_per_group = head_num // num_query_groups + hidden_size = target_config.hidden_size + head_size = target_config.head_dim + + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size, *old_tensor_shape[1:]) + new_kv_tensor_shape = (num_query_groups, head_size, *old_tensor_shape[1:]) + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +def _split_qkv(ctx: state.TransformCTX, linear_qkv: torch.Tensor): + """Split interleave-concatenated qkv to q, k, v. + + This version uses config.head_dim instead of hidden_size // num_attention_heads, + which is necessary for Qwen3 where head_dim is independently configured. 
+ """ + target_config = ctx.target.config + + head_num = target_config.num_attention_heads + num_query_groups = target_config.num_key_value_heads + heads_per_group = head_num // num_query_groups + head_size = target_config.head_dim + qkv_total_dim = head_num + 2 * num_query_groups + + linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, -1]) + hidden_size = linear_qkv.size(-1) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() + k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() + v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() + + return q_proj, k_proj, v_proj + + +def convert_qwen3_hf_to_te(model_hf: Qwen3ForCausalLM, **config_kwargs) -> NVQwen3ForCausalLM: + """Convert a Hugging Face model to a Transformer Engine model. + + Args: + model_hf (nn.Module): The Hugging Face model. + **config_kwargs: Additional configuration kwargs to be passed to NVQwen3Config. + + Returns: + nn.Module: The Transformer Engine model. 
+ """ + te_config = NVQwen3Config(**model_hf.config.to_dict(), **config_kwargs) + with torch.device("meta"): + model_te = NVQwen3ForCausalLM(te_config) + + if model_hf.config.tie_word_embeddings: + state_dict_ignored_entries = ["lm_head.weight"] + else: + state_dict_ignored_entries = [] + + output_model = state.apply_transforms( + model_hf, + model_te, + mapping, + [ + state.state_transform( + source_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + target_key="model.layers.*.self_attention.layernorm_qkv.weight", + fn=_merge_qkv, + ), + state.state_transform( + source_key=( + "model.layers.*.mlp.gate_proj.weight", + "model.layers.*.mlp.up_proj.weight", + ), + target_key="model.layers.*.layernorm_mlp.fc1_weight", + fn=state.TransformFns.merge_fc1, + ), + ], + state_dict_ignored_entries=state_dict_ignored_entries, + ) + + output_model.model.rotary_emb.inv_freq = model_hf.model.rotary_emb.inv_freq.clone() + + return output_model + + +def convert_qwen3_te_to_hf(model_te: NVQwen3ForCausalLM, **config_kwargs) -> Qwen3ForCausalLM: + """Convert a Transformer Engine model to a Hugging Face model. + + Args: + model_te (nn.Module): The Transformer Engine model. + **config_kwargs: Additional configuration kwargs to be passed to Qwen3Config. + + Returns: + nn.Module: The Hugging Face model. 
+ """ + # Filter out keys from model_te.config that are not valid Qwen3Config attributes + te_config_dict = model_te.config.to_dict() + valid_keys = set(inspect.signature(Qwen3Config.__init__).parameters) + filtered_config = {k: v for k, v in te_config_dict.items() if k in valid_keys} + hf_config = Qwen3Config(**filtered_config, **config_kwargs) + + with torch.device("meta"): + model_hf = Qwen3ForCausalLM(hf_config) + + output_model = state.apply_transforms( + model_te, + model_hf, + reverse_mapping, + [ + state.state_transform( + source_key="model.layers.*.self_attention.layernorm_qkv.weight", + target_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + fn=_split_qkv, + ), + state.state_transform( + source_key="model.layers.*.layernorm_mlp.fc1_weight", + target_key=( + "model.layers.*.mlp.gate_proj.weight", + "model.layers.*.mlp.up_proj.weight", + ), + fn=state.TransformFns.split_fc1, + ), + ], + state_dict_ignored_entries=model_hf._tied_weights_keys, + ) + + output_model.model.rotary_emb.inv_freq = model_te.model.rotary_emb.inv_freq.clone() + output_model.tie_weights() + + return output_model diff --git a/bionemo-recipes/models/qwen3/export.py b/bionemo-recipes/models/qwen3/export.py new file mode 100644 index 000000000..944013408 --- /dev/null +++ b/bionemo-recipes/models/qwen3/export.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Create a Qwen3 checkpoint for export. + +This script saves a randomly initialized Qwen3 model with TransformerEngine layers. +""" + +import json +import shutil +from pathlib import Path + +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +import convert +from modeling_qwen3_te import AUTO_MAP + + +def export_hf_checkpoint(tag: str, export_path: Path): + """Export a Hugging Face checkpoint to a Transformer Engine checkpoint. + + Args: + tag: The tag of the checkpoint to export. + export_path: The parent path to export the checkpoint to. + """ + model_hf = AutoConfig.from_pretrained(tag) + model_hf = AutoModelForCausalLM.from_config(model_hf) + + model_te = convert.convert_qwen3_hf_to_te(model_hf) + model_te.save_pretrained(export_path) + + tokenizer = AutoTokenizer.from_pretrained(tag) + tokenizer.save_pretrained(export_path) + + # Patch the config + with open(export_path / "config.json", "r") as f: + config = json.load(f) + + config["auto_map"] = AUTO_MAP + + with open(export_path / "config.json", "w") as f: + json.dump(config, f, indent=2, sort_keys=True) + + shutil.copy("modeling_qwen3_te.py", export_path / "modeling_qwen3_te.py") + + +if __name__ == "__main__": + export_hf_checkpoint("Qwen/Qwen3-0.6B", Path("checkpoint_export")) diff --git a/bionemo-recipes/models/qwen3/modeling_qwen3_te.py b/bionemo-recipes/models/qwen3/modeling_qwen3_te.py new file mode 100644 index 000000000..d86435924 --- /dev/null +++ b/bionemo-recipes/models/qwen3/modeling_qwen3_te.py @@ -0,0 +1,460 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TransformerEngine-optimized Qwen3 model."""

from collections import OrderedDict
from typing import ClassVar, Unpack

import torch
import torch.nn as nn
import transformer_engine.pytorch
import transformers
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention.inference import PagedKVCacheManager
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
from transformers import PreTrainedModel, Qwen3Config
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding
from transformers.utils.generic import TransformersKwargs


AUTO_MAP = {
    "AutoConfig": "modeling_qwen3_te.NVQwen3Config",
    "AutoModel": "modeling_qwen3_te.NVQwen3Model",
    "AutoModelForCausalLM": "modeling_qwen3_te.NVQwen3ForCausalLM",
}


class NVQwen3Config(Qwen3Config):
    """NVQwen3 configuration."""

    # Attention input format:
    #   "bshd" = Batch, Sequence, Head, Dimension (standard padded format)
    #   "thd"  = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
    attn_input_format: str = "thd"
    self_attn_mask_type: str = "padding_causal"


class NVQwen3PreTrainedModel(PreTrainedModel):
    """Base class for NVQwen3 models."""

    config_class = NVQwen3Config
    base_model_prefix = "model"
    _no_split_modules = ("TransformerLayer",)
    _skip_keys_device_placement = ("past_key_values",)

    def init_empty_weights(self):
        """Handles moving the model from the meta device to the cuda device and initializing the weights."""
        # For TE layers, calling `reset_parameters` is sufficient to move them to the cuda device and apply the weight
        # initialization we passed them during module creation.
        for module in self.modules():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()

        # The embed_tokens layer is the only non-TE layer in this model we need to deal with. We use
        # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard
        # deviation.
        self.model.embed_tokens.to_empty(device="cuda")
        self.model.embed_tokens.apply(self._init_weights)

        self.model.rotary_emb.inv_freq = Qwen3RotaryEmbedding(config=self.model.config).inv_freq.to("cuda")

        # Meta-device init seems to break weight tying, so we re-tie the weights here.
        self.tie_weights()

    def _init_weights(self, module):
        """Initialize module weights.

        We only use this method for standard pytorch modules, TE modules handle their own weight initialization through
        `init_method` parameters and the `reset_parameters` method.
        """
        if module.__module__.startswith("transformer_engine.pytorch"):
            # Notably, we need to avoid calling this method for TE modules, since the default _init_weights will assume
            # any class with `LayerNorm` in the name should have weights initialized to 1.0; breaking `LayerNormLinear`
            # and `LayerNormMLP` modules that use `weight` for the linear layer and `layer_norm_weight` for the layer
            # norm.
            return

        super()._init_weights(module)

    def state_dict(self, *args, **kwargs):
        """Override state_dict to filter out TransformerEngine's _extra_state keys.

        TransformerEngine layers add _extra_state attributes that are not compatible with
        standard PyTorch/HuggingFace model loading. These are filtered out to ensure
        checkpoints can be loaded with from_pretrained().
        """
        state_dict = super().state_dict(*args, **kwargs)
        # Filter out _extra_state keys which are TransformerEngine-specific and not loadable
        return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}


class NVQwen3Model(NVQwen3PreTrainedModel):
    """Qwen3 model implemented in Transformer Engine."""

    def __init__(self, config: Qwen3Config):
        """Initialize the NVQwen3 model."""
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)

        def _init_method(x):
            torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)

        self.layers = nn.ModuleList(
            [
                transformer_engine.pytorch.TransformerLayer(
                    hidden_size=config.hidden_size,
                    ffn_hidden_size=config.intermediate_size,
                    num_attention_heads=config.num_attention_heads,
                    bias=False,
                    layernorm_epsilon=config.rms_norm_eps,
                    hidden_dropout=0,
                    attention_dropout=0,
                    fuse_qkv_params=True,
                    qkv_weight_interleaved=True,
                    normalization="RMSNorm",
                    activation="swiglu",
                    attn_input_format=config.attn_input_format,
                    self_attn_mask_type=config.self_attn_mask_type,
                    num_gqa_groups=config.num_key_value_heads,
                    kv_channels=config.head_dim,
                    qk_norm_type="RMSNorm",
                    qk_norm_eps=config.rms_norm_eps,
                    qk_norm_before_rope=True,
                    window_size=(config.sliding_window, config.sliding_window)
                    if config.layer_types[layer_idx] == "sliding_attention"
                    else None,
                    layer_number=layer_idx + 1,
                    params_dtype=config.dtype,
                    device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
                    init_method=_init_method,
                    output_layer_init_method=_init_method,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = transformer_engine.pytorch.RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            dtype=config.dtype,
            device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
        )

        # Patch for TE not passing the correct dtype to the RMSNorm layer. To be fixed in NVIDIA/TransformerEngine#2718.
        for layer in self.layers:
            layer.self_attention.q_norm.weight.data = layer.self_attention.q_norm.weight.data.to(config.dtype)
            layer.self_attention.k_norm.weight.data = layer.self_attention.k_norm.weight.data.to(config.dtype)

        # We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
        # Qwen3RotaryEmbedding. Use head_dim (not hidden_size // num_attention_heads) since Qwen3 has
        # independently configured head_dim.
        self.rotary_emb = RotaryPositionEmbedding(config.head_dim)
        self.rotary_emb.inv_freq = Qwen3RotaryEmbedding(config=config).inv_freq

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_values: InferenceParams | None = None,
        inputs_embeds: torch.Tensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Forward pass for the NVQwen3 model.

        Args:
            input_ids (torch.Tensor): The input ids.
            attention_mask (torch.Tensor): The attention mask.
            position_ids (torch.Tensor): The position ids.
            past_key_values (tuple[tuple[torch.Tensor, ...], ...]): The past key values.
            inputs_embeds (torch.Tensor): The inputs embeds.
            use_cache (bool): Whether to use cache.
            **kwargs: Additional keyword arguments.

        Returns:
            BaseModelOutputWithPast: The output of the model.
        """
        all_hidden_states = []
        output_hidden_states = kwargs.get("output_hidden_states", False)

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # TE-specific input handling.
        has_thd_input = [x in kwargs for x in ["cu_seq_lens_q", "cu_seq_lens_k", "max_length_q", "max_length_k"]]
        should_pack_inputs = not any(has_thd_input) and self.config.attn_input_format == "thd"

        if should_pack_inputs:
            # Left-side padding is not supported in TE layers, so to make huggingface-style generation work with TE we
            # dynamically convert to THD-style inputs in our forward pass, and then convert back to BSHD for the output.
            # This lets the entire transformer stack run in THD mode. This might be slower for BSHD + padding with fused
            # attention backend, but it should be faster for the flash attention backend.
            assert attention_mask is not None, "Attention mask is required when packing BSHD inputs."
            batch_size = hidden_states.size(0)
            padded_seq_len = input_ids.size(1)
            hidden_states, indices, cu_seqlens, max_seqlen, _ = _unpad_input(hidden_states, attention_mask)
            kwargs["cu_seq_lens_q"] = kwargs["cu_seq_lens_k"] = cu_seqlens
            kwargs["max_length_q"] = kwargs["max_length_k"] = max_seqlen

        if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1:
            # For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE
            # expects a 2-dimensional tensor with shape [total_tokens, hidden_size].
            hidden_states = hidden_states.squeeze(0)

        if self.config.attn_input_format == "bshd" and attention_mask is not None and attention_mask.dim() == 2:
            # Convert HF mask (1=attend, 0=pad) to TE boolean mask (True=masked, False=attend)
            attention_mask = ~attention_mask[:, None, None, :].bool()

        if isinstance(past_key_values, InferenceParams):  # InferenceParams is TE's way of managing kv-caching.
            # In generation mode, we set the length to 1 for each batch index. Otherwise, we use the attention mask to
            # compute the lengths of each sequence in the batch.
            lengths = (
                attention_mask.sum(dim=1).tolist()
                if attention_mask.shape == input_ids.shape
                else [1] * input_ids.shape[0]
            )
            past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths)))

        # Ensure that rotary embeddings are computed at a higher precision
        with torch.autocast(device_type="cuda", enabled=False):
            te_rope_emb = self.rotary_emb(max_seq_len=self.config.max_position_embeddings)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states = (*all_hidden_states, hidden_states)

            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
                rotary_pos_emb=te_rope_emb,
                inference_params=past_key_values,
                cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
                cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
                cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
                cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
                max_seqlen_q=kwargs.get("max_length_q", None),
                max_seqlen_kv=kwargs.get("max_length_k", None),
                pad_between_seqs=kwargs.get("pad_between_seqs", None),
            )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer. Note that these will be in THD format; we could possibly pad
        # these with the same _pad_input call as below if we wanted them returned in BSHD format.
        if output_hidden_states:
            all_hidden_states = (*all_hidden_states, hidden_states)

        if should_pack_inputs:
            # If we've converted BSHD to THD for our TE layers, we need to convert back to BSHD for the output.
            hidden_states = _pad_input(hidden_states, indices, batch_size, padded_seq_len)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states if output_hidden_states else None,
        )


class NVQwen3ForCausalLM(NVQwen3PreTrainedModel, transformers.GenerationMixin):
    """Qwen3 model with causal language head."""

    _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config):
        """Initialize the NVQwen3ForCausalLM model."""
        super().__init__(config)
        self.model = NVQwen3Model(config)
        self.vocab_size = config.vocab_size
        with transformer_engine.pytorch.quantized_model_init(enabled=False):
            self.lm_head = transformer_engine.pytorch.Linear(
                config.hidden_size,
                config.vocab_size,
                bias=False,
                params_dtype=config.dtype,
                device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
                init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
            )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_values: tuple[tuple[torch.Tensor, ...], ...] | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        shift_labels: torch.Tensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.Tensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        """Forward pass for the NVQwen3ForCausalLM model.

        Args:
            input_ids (torch.Tensor): The input ids.
            attention_mask (torch.Tensor): The attention mask.
            position_ids (torch.Tensor): The position ids.
            past_key_values (tuple[tuple[torch.Tensor, ...], ...]): The past key values.
            inputs_embeds (torch.Tensor): The inputs embeds.
            labels (torch.Tensor): The labels.
            shift_labels (torch.Tensor): Labels that have already been shifted by the dataloader, to be used instead of
                labels for the loss function. For context parallelism, it is more reliable to shift the labels before
                splitting the batch into shards.
            use_cache (bool): Whether to use cache.
            cache_position (torch.Tensor): The cache position.
            logits_to_keep (int | torch.Tensor): Whether to keep only the last logits to reduce the memory footprint of
                the model during generation.
            **kwargs: Additional keyword arguments.

        Returns:
            CausalLMOutputWithPast: The output of the model.
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep

        with transformer_engine.pytorch.autocast(enabled=False):
            if hidden_states.ndim == 3:
                logits = self.lm_head(hidden_states[:, slice_indices, :])
            else:  # With THD inputs, batch and sequence dimensions are collapsed in the first dimension.
                logits = self.lm_head(hidden_states[slice_indices, :])

        loss = None
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, shift_labels=shift_labels, vocab_size=self.config.vocab_size, **kwargs
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


torch._dynamo.config.capture_scalar_outputs = True


@torch.compile
def _pad_input(hidden_states, indices, batch, seqlen):
    """Convert a THD tensor to a BSHD equivalent tensor.

    Adapted from huggingface/transformers/modeling_flash_attention_utils.py

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.

    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[1:]
    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
    output[indices] = hidden_states
    return output.view(batch, seqlen, *dim)


@torch.compile
def _unpad_input(hidden_states, attention_mask, unused_mask=None):
    """Convert a BSHD tensor to a THD equivalent tensor.

    Adapted from huggingface/transformers/modeling_flash_attention_utils.py

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.

    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
    """
    batch_size = hidden_states.size(0)
    seq_length = hidden_states.size(1)

    if attention_mask.shape[1] != seq_length:  # Likely in generation mode with kv-caching
        return (
            hidden_states.squeeze(1),  # hidden_states
            torch.arange(batch_size, dtype=torch.int64, device=hidden_states.device),  # indices
            torch.arange(batch_size + 1, dtype=torch.int32, device=hidden_states.device),  # cu_seqlens
            1,  # max_seqlen
            1,  # seqused
        )

    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

    return (
        hidden_states.reshape(-1, *hidden_states.shape[2:])[indices],
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )


class HFInferenceParams(InferenceParams):
    """Extension of the InferenceParams class to support beam search."""

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorder the cache based on the beam indices."""
        if isinstance(self.cache_manager, PagedKVCacheManager):
            raise NotImplementedError("Beam search is not supported for paged cache manager.")
        for layer_number, (key_cache, value_cache) in self.cache_manager.cache.items():
            updated_key_cache = key_cache.index_select(0, beam_idx)
            updated_value_cache = value_cache.index_select(0, beam_idx)
            self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache)
b/bionemo-recipes/models/qwen3/requirements.txt new file mode 100644 index 000000000..ec6a547cb --- /dev/null +++ b/bionemo-recipes/models/qwen3/requirements.txt @@ -0,0 +1,5 @@ +lm-eval # For testing +torch +torchao!=0.14.0 +transformer_engine[pytorch] +transformers diff --git a/bionemo-recipes/models/qwen3/state.py b/bionemo-recipes/models/qwen3/state.py new file mode 100644 index 000000000..bda08c4d7 --- /dev/null +++ b/bionemo-recipes/models/qwen3/state.py @@ -0,0 +1,724 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""State dict conversion utilities adapted from nemo.lightning.io.state. + +This module provides the transform system used by convert.py to map state dicts between model formats: + +- ``mapping``: A dict of simple key renames (source_key -> target_key). Each source key is copied directly + to the corresponding target key with no modification to the tensor values. + +- ``transforms``: A list of ``StateDictTransform`` objects for multi-key merges and splits. These handle + cases where multiple source keys must be combined into one target key (e.g., merging Q/K/V into fused QKV), + or one source key must be split into multiple target keys. + + Important: When ``source_key`` is a tuple (many-to-one merge), the transform function's parameter names + are used to map each source key to a function argument. 
This means ``*args`` style parameters do not work; + each parameter must be explicitly named (e.g., ``def fn(q, k, v)`` not ``def fn(*args)``). +""" + +import inspect +import logging +import re +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, overload + +import numpy as np +import torch +from torch import nn + + +logger = logging.getLogger(__name__) + +SourceModuleT = TypeVar("SourceModuleT", bound=nn.Module) +TargetModuleT = TypeVar("TargetModuleT", bound=nn.Module) +F = TypeVar("F", bound=Callable[..., Any]) + + +@dataclass +class TransformCTX: + """Transform Data class Definition.""" + + source: nn.Module + source_state: dict + target: nn.Module + target_state: dict + + +class _ModelState: + """Helper class for used for to modify state dict of a source model during model conversion.""" + + def __init__(self, state_dict, config=None): + self._state_dict = state_dict + self.config = config + + def state_dict(self): + # pylint: disable=C0115,C0116 + return self._state_dict + + def to(self, dtype): + # pylint: disable=C0115,C0116 + for k, v in self._state_dict.items(): + if v.dtype != dtype: + logger.warning(f"Converting {k} from {v.dtype} (source model) to {dtype} (target model)") + self._state_dict[k] = v.to(dtype) + + +@torch.no_grad +def apply_transforms( + source: Union[nn.Module, _ModelState], + target: TargetModuleT, + mapping: Dict[str, str], + transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None, + state_dict_ignored_entries: Optional[List] = None, + cast_dtype: Optional[torch.dtype] = None, +) -> TargetModuleT: + """Transform the state dictionary of a source module to match the structure of a target module's state dictionary. + + This function renames keys according to a provided mapping and modifies values using a list + of transformation functions. Each transformation function typically is decorated + with `io.state_transform`. 
+ + Args: + source (nn.Module): The source module from which parameters and buffers are taken. + target (TargetModuleT): The target module to which parameters and buffers are adapted. + mapping (Dict[str, str]): Key-value pairs where each key from the source state dictionary + is mapped to a corresponding key in the target state dictionary. + transforms (Optional[List[Callable[[TransformCTX], TransformCTX]]]): A list of functions + that modify the `TransformCTX` object. If None, no transformations beyond key renaming + are applied. Defaults to None. + state_dict_ignored_entries: List of entries to ignore in _target.state_dict(). There are cases + where multiple entries in model's state_dict point to one entry in model's named_parameter. + E.g., model has multiple pointers pointing to one shared parameters (`encoder.embed_tokens.weight`, + `decoder.embed_tokens.weight` and `shared.weight` all points to `shared.weight + in T5 Huggingface implementation.). In these cases, ignore redundant entries. + cast_dtype: case the output state dict to a certain precision. + + Returns: + TargetModuleT: The modified target module with its state dictionary adjusted according to + the specified mappings and transformations. + + Raises: + ValueError: If there's a mismatch in shape between corresponding source and target parameters + or buffers. + RuntimeError: If the target state dictionary contains keys that are not present in the source + state dictionary after all transformations. + + Examples: + >>> source_module = nn.Linear(10, 5) + >>> target_module = nn.Linear(10, 5) + >>> mapping = {'weight': 'weights', 'bias': 'biases'} + @io.state_transform( + source_key="weight", + target_key="weights" + ) + def scale_weights(ctx): + ctx.target_state['weights'] = ctx.source_state['weight'] * 2 + return ctx + >>> transformed_target = apply_transforms( + ... source_module, target_module, mapping, [scale_weights] + ... 
) + >>> print(transformed_target.state_dict()['weights']) + + See Also: + - `TransformCTX`: For more details on the context object used in transformations. + - `StateDictTransform`: For creating complex transformations. + + Note: + This function is particularly useful when adapting models from different frameworks or + when consolidating models with different architectural changes. + """ + if transforms is None: + transforms = [] + if state_dict_ignored_entries is None: + state_dict_ignored_entries = [] + + # Track dtypes to make sure they weren't modified during conversion. + target_orig_dtypes = extract_dtypes(target.named_parameters()) + + target_state = target.state_dict() + ctx = TransformCTX( + source=source, + source_state=source.state_dict(), + target=target, + target_state=target_state, + ) + + for key, val in mapping.items(): + logger.debug(f"Mapping {key} -> {val}") + ctx = StateDictTransform(key, val)(ctx) + + for transform in transforms: + logger.debug(f"Transforming {transform.source_key} -> {transform.target_key}") + ctx = transform(ctx) + + _params: Dict[str, nn.Parameter] = {} + for name, param in target.named_parameters(): + if name in target_state: + target_param = target_state[name] + if param.data.shape != target_param.shape: + raise ValueError( + f"Shape mismatch for parameter {name}: target shape {param.shape} vs " + f"converted source shape {target_param.shape}" + ) + + _params[name] = nn.Parameter(target_param, requires_grad=param.requires_grad) + target_state.pop(name) + else: + print(f"Unexpected key: {name} not in target model but is in source model.") + + for key, val in _params.items(): + _module, _key = target, key + if "." 
in key: + for part in key.split(".")[:-1]: + _module = getattr(_module, part) + _key = key.split(".")[-1] + + _module.register_parameter(_key, val) + + _buffers = {} + for name, buffer in target.named_buffers(): + if name in target_state: + if buffer.shape != target_state[name].shape: + raise ValueError(f"Shape mismatch for buffer {name}: {buffer.shape} vs {target_state[name].shape}") + + _buffers[name] = nn.Parameter(target_state[name], requires_grad=False) + target_state.pop(name) + + for key, val in _buffers.items(): + _module, _key = target, key + if "." in key: + for part in key.split(".")[:-1]: + _module = getattr(_module, part) + _key = key.split(".")[-1] + + _module.register_buffer(_key, val) + + keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys())) + keys = [key for key in keys if key not in state_dict_ignored_entries] + if len(keys) != 0: + raise RuntimeError(f"Additional keys: {keys} in target model but not in source model.") + + if hasattr(target, "tie_weights"): + target.tie_weights() + + meta_tensor_keys = [] + for name, param in target.named_parameters(): + if param.is_meta: + meta_tensor_keys.append(name) + + assert not meta_tensor_keys, ( + f"{meta_tensor_keys}\nThere are meta tensors in the model after conversion." + f"Did you forget to include these parameters in the mapping or transforms in `convert_state`?" + ) + + if cast_dtype: + logger.info(f"Casting model to {cast_dtype}...") + target.to(cast_dtype) + logger.info(f"Casting model to {cast_dtype} complete.") + else: + target_new_dtypes = extract_dtypes(target.named_parameters()) + for key in target_orig_dtypes.keys(): + if key in target_new_dtypes: # For tied weights, these parameters may disappear. 
+ assert target_orig_dtypes[key] == target_new_dtypes[key], ( + f"dtype mismatch for key {key}: {target_orig_dtypes[key]} vs {target_new_dtypes[key]}" + ) + + return target + + +def _default_transform(inp): + return inp + + +class StateDictTransform(Generic[F]): + """A transformation class for state dictionaries. + + Allows for flexible key matching and transformation of values between source and target state dictionaries. + + Attributes: + source_key: A string, tuple of strings, or a dictionary specifying the keys in the source + state dictionary to match. Wildcards (*) are supported. + target_key: A string or tuple of strings specifying the keys in the target state dictionary + to match. Wildcards (*) are supported. + transform: A callable that performs the transformation on matched keys' values. + + Examples: + >>> def example_transform(ctx, *args): + ... return sum(args) + >>> transform = StateDictTransform( + ... source_key="model.layers.*.self_attn.*_proj.weight", + ... target_key="decoder.layers.*.self_attention.linear_qkv.weight", + ... transform=example_transform + ... 
) + """ + + def __init__( + self, + source_key: Union[str, Tuple[str, ...], Dict[str, str]], + target_key: Union[str, Tuple[str, ...]], + transform: F = _default_transform, + ): + """Initialize the StateDictTransform.""" + self.source_key = source_key + self.target_key = target_key + self.transform = transform + + def __call__(self, ctx: TransformCTX) -> TransformCTX: + """Perform the transformation on the given context.""" + source_key = self.source_key + target_key = self.target_key + source_dict, target_dict = ctx.source_state, ctx.target_state + np.set_printoptions(threshold=10) + fn_params = dict(inspect.signature(self.transform).parameters) + fn_params.pop("ctx", None) + matched = False + if isinstance(source_key, (dict, tuple)): + if isinstance(source_key, tuple): + source_key_dict = {param: source_key[i] for i, param in enumerate(fn_params)} + else: + source_key_dict = source_key + source_matches_dict = {k: _match_keys(list(source_dict.keys()), v) for k, v in source_key_dict.items()} + target_matches = _match_keys(list(target_dict.keys()), target_key) + param_names = list(filter(lambda x: x in source_matches_dict, fn_params)) + source_matches = [ + source_matches_dict[v] if source_matches_dict[v].ndim > 0 else [source_matches_dict[v].item()] + for v in param_names + ] + target_matches = [target_matches if target_matches.ndim > 0 else [target_matches.item()]] + for layer_names_group in zip(*(source_matches + target_matches)): + # Wrap in a list if it's a single layer (ie non-expert) + if isinstance(layer_names_group[0], str): + layer_names_group = [[x] for x in layer_names_group] # noqa: PLW2901 + for layer_names in zip(*layer_names_group): + target_dict[layer_names[-1]] = self.call_transform( + ctx, **dict(zip(param_names, [source_dict[x] for x in layer_names[:-1]])) + ) + logger.debug(f"Matched (transform)! 
{layer_names_group=}") + matched = True + else: + source_keys = list(source_dict.keys()) + target_keys = list(target_dict.keys()) + + source_matches = _match_keys(source_keys, source_key) + if source_matches.size == 1 and source_matches == np.array(None): + raise ValueError(f"No matches found for source key: {source_key}") + + if isinstance(target_key, str): + target_matches = _match_keys(target_keys, target_key) + if target_matches.size == 1 and target_matches == np.array(None): + raise ValueError(f"No matches found for target key: {target_key}") + else: + if isinstance(target_key, dict): + raise ValueError("Target key must be a string or a tuple of strings.") + _matches = [_match_keys(target_keys, key) for key in target_key] + target_matches = np.stack(_matches, axis=-1) + + # Determine if we are dealing with multiple source matches or multiple target matches + multiple_sources = source_matches.ndim >= target_matches.ndim + accepts_var_args = any( + param.kind == param.VAR_POSITIONAL for param in inspect.signature(self.transform).parameters.values() + ) + + if multiple_sources: + for target_index, target_match in np.ndenumerate(target_matches): + try: + source_match = source_matches[target_index] + except IndexError as e: + logger.error(f"Encountered IndexError during transform.\n{source_matches=}\n{target_matches=}") + raise e + if accepts_var_args: + source_values = [source_dict[k] for k in source_match] + target_dict[target_match] = self.call_transform(ctx, *source_values) + else: + _source_match_list = [source_match] if isinstance(source_match, str) else list(source_match) + if len(fn_params) != len(_source_match_list): + raise ValueError( + f"Mismatch between source and target keys: {source_match} vs {target_match}" + ) + + kwargs = {param: source_dict[k] for param, k in zip(fn_params, _source_match_list)} + target_dict[target_match] = self.call_transform(ctx, **kwargs) + logger.debug(f"Matched (multi source)! 
{target_match=} {source_match=}") + matched = True + else: + for source_index, source_match in np.ndenumerate(source_matches): + target_match = target_matches[source_index] + source_values = ( + [source_dict[source_match]] + if np.isscalar(source_match) + else [source_dict[k] for k in source_match] + ) + if accepts_var_args: + outputs = self.call_transform(ctx, *source_values) + else: + kwargs = dict(zip(fn_params, source_values)) + outputs = self.call_transform(ctx, **kwargs) + + if isinstance(target_match, str): + target_dict[target_match] = outputs + else: + for i, t in enumerate(outputs): + target_dict[target_match[i]] = t + logger.debug(f"Matched (single source)! {target_match=} {source_match=}") + matched = True + if not matched: + logger.warning(f"No matches found for source key: {source_key=} {target_key=}") + return ctx + + def call_transform(self, ctx: TransformCTX, *args, **kwargs): + """Perform transform and check if the given args valid.""" + func_params = inspect.signature(self.transform).parameters + expected_num_args = len([p for p in func_params if p not in ["self", "ctx"]]) + provided_num_args = len(args) + len(kwargs) + accepts_var_args = any(param.kind == param.VAR_POSITIONAL for param in func_params.values()) + + if not accepts_var_args and provided_num_args != expected_num_args: + raise ValueError( + f"Expected {expected_num_args} arguments for the transformation function, but got {provided_num_args}." 
+ ) + + if "ctx" in func_params: + return self.transform(ctx, *args, **kwargs) + + return self.transform(*args, **kwargs) + + +def _match_keys(keys: List[str], pattern: str) -> np.ndarray: + escaped_pattern = "" + i = 0 + wildcard_positions = [] + while i < len(pattern): + if pattern[i : i + 2] == "**": + escaped_pattern += r"(.+)" # Match any characters including dots + wildcard_positions.append("**") + i += 2 + elif pattern[i] == "*": + escaped_pattern += r"([^.]+)" # Match any characters except dots + wildcard_positions.append("*") + i += 1 + else: + if pattern[i] == ".": + escaped_pattern += r"\." # Escape the dot + else: + escaped_pattern += pattern[i] + i += 1 + + regex_pattern = re.compile("^" + escaped_pattern + "$") + num_wildcards = len(wildcard_positions) + wildcard_matches = [[] for _ in range(num_wildcards)] + + for key in filter(lambda x: x is not None, keys): + match = regex_pattern.match(key) + if match: + for i, group in enumerate(match.groups()): + if group not in wildcard_matches[i]: + wildcard_matches[i].append(group) + + # Sort the wildcard matches to maintain consistent ordering + for i in range(len(wildcard_matches)): + wildcard_matches[i].sort(key=lambda x: int(x) if x.isdigit() else x) + + # Determine the shape of the output array based on the unique matches for each wildcard + shape = [len(matches) for matches in wildcard_matches] + + if len(wildcard_matches) == 0: + # If there is no wildcard matches, assuming it is a single match + shape = [1] + # Initialize an empty array with the determined shape + output_array = np.empty(shape, dtype=object) + + # Populate the array with the keys, now that we have the correct shape and ordering + for key in filter(lambda x: x is not None, keys): + match = regex_pattern.match(key) + if match: + # Convert match groups to indices based on their position in wildcard_matches + indices = [wildcard_matches[i].index(group) for i, group in enumerate(match.groups())] + output_array[tuple(indices)] = key # Place 
the key in the array based on the indices + + return output_array + + +@overload +def state_transform( + source_key: Union[str, Tuple[str, ...], Dict[str, str]], + target_key: Union[str, Tuple[str, ...]], +) -> Callable[[F], StateDictTransform[F]]: ... + + +@overload +def state_transform( + source_key: Union[str, Tuple[str, ...], Dict[str, str]], target_key: Union[str, Tuple[str, ...]], fn: F +) -> StateDictTransform[F]: ... + + +def state_transform( + source_key: Union[str, Tuple[str, ...], Dict[str, str]], + target_key: Union[str, Tuple[str, ...]], + fn: Optional[F] = None, +): + """Create a StateDictTransform instance with specified source and target keys, and a transformation function. + + Args: + source_key: A string, tuple of strings, or a dictionary specifying the keys in the source + state dictionary to match. Wildcards (*) are supported. + target_key: A string or tuple of strings specifying the keys in the target state dictionary + to match. Wildcards (*) are supported. + fn: An optional callable that performs the transformation on matched keys' values. If not + provided, the decorator can be used to wrap a function definition. + + Returns: + ------- + A StateDictTransform instance if `fn` is provided, otherwise returns a decorator that + takes a function and returns a StateDictTransform instance. + + Examples: + -------- + >>> @state_transform( + ... source_key="model.layers.*.self_attn.*_proj.weight", + ... target_key="decoder.layers.*.self_attention.linear_qkv.weight" + ... ) + ... def sum_transform(ctx, *args): + ... return sum(args) + """ + + def wrapper(fn) -> StateDictTransform: + return StateDictTransform(source_key, target_key, fn) + + if fn is None: + return wrapper + + return wrapper(fn) + + +class TransformFns: + """A collection of common functions used in state dict transformation.""" + + @staticmethod + def split_qkv(ctx: TransformCTX, linear_qkv: torch.Tensor): + """Split interleave-concatenated qkv to q, k, v. 
+ + Example: export layer linear_qkv to HF {q|k|v}_proj + """ + target_config = ctx.target.config + + head_num = target_config.num_attention_heads + num_query_groups = target_config.num_key_value_heads + heads_per_group = head_num // num_query_groups + hidden_size = target_config.hidden_size + head_size = hidden_size // head_num + qkv_total_dim = head_num + 2 * num_query_groups + + linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, -1]) + # when converting base model (linear_qkv), hidden size = megatron_config.hidden_size + # when converting lora (linear_qkv.adapter.linear_out), hidden size = lora_r + hidden_size = linear_qkv.size(-1) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() + k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() + v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() + + return q_proj, k_proj, v_proj + + @staticmethod + def split_qkv_bias(ctx: TransformCTX, qkv_bias: torch.Tensor): + """Split interleave-concatenated qkv bias to separate q, k, v bias. 
+ + Example: export layer linear_qkv bias to HF {q|k|v}_proj bias + """ + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + head_size = megatron_config.kv_channels + qkv_total_dim = head_num + 2 * num_query_groups + + qkv_bias = qkv_bias.reshape([qkv_total_dim, head_size]) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_bias = qkv_bias[q_slice].reshape(-1).cpu() + k_bias = qkv_bias[k_slice].reshape(-1).cpu() + v_bias = qkv_bias[v_slice].reshape(-1).cpu() + + return q_bias, k_bias, v_bias + + @staticmethod + def merge_qkv_concat(ctx: TransformCTX, qkv: torch.Tensor): + """Merge naively concatenated q, k, v to interleave-concatenated qkv. + + Example: import HF qkv to layer linear_qkv + """ + megatron_config = ctx.target.config + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + head_size = megatron_config.kv_channels + q, k, v = qkv.split([head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0) + return TransformFns.merge_qkv(ctx, q, k, v) + + @staticmethod + def merge_qkv(ctx: TransformCTX, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """Merge q, k, v to interleave-concatenated qkv. 
+ + Example: import HF {q|k|v}_proj to layer linear_qkv + """ + target_config = ctx.target.config + + head_num = target_config.num_attention_heads + num_query_groups = target_config.num_key_value_heads + heads_per_group = head_num // num_query_groups + hidden_size = target_config.hidden_size + head_size = hidden_size // head_num + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size, *old_tensor_shape[1:]) + new_kv_tensor_shape = (num_query_groups, head_size, *old_tensor_shape[1:]) + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + @staticmethod + def merge_qkv_bias_concat(ctx: TransformCTX, qkv_bias: torch.Tensor): + """Merge naively concatenated q, k, v bias to interleave-concatenated qkv bias. 
+ + Example: import HF qkv bias to layer linear_qkv bias + """ + megatron_config = ctx.target.config + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + head_size = megatron_config.kv_channels + qb, kb, vb = qkv_bias.split( + [head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0 + ) + return TransformFns.merge_qkv_bias(ctx, qb, kb, vb) + + @staticmethod + def merge_qkv_bias(ctx: TransformCTX, qb: torch.Tensor, kb: torch.Tensor, vb: torch.Tensor): + """Merge q, k, v bias to interleave-concatenated qkv bias. + + Example: import HF {q|k|v}_proj bias to layer linear_qkv bias + """ + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + head_size = megatron_config.kv_channels + + new_q_tensor_shape = (head_num, head_size) + new_kv_tensor_shape = (num_query_groups, head_size) + + qb = qb.view(*new_q_tensor_shape) + kb = kb.view(*new_kv_tensor_shape) + vb = vb.view(*new_kv_tensor_shape) + + qkv_bias = torch.empty((0, head_size)).type_as(qb) + for i in range(num_query_groups): + qkv_bias = torch.cat((qkv_bias, qb[i * heads_per_group : (i + 1) * heads_per_group, :])) + qkv_bias = torch.cat((qkv_bias, kb[i : i + 1, :])) + qkv_bias = torch.cat((qkv_bias, vb[i : i + 1, :])) + qkv_bias = qkv_bias.reshape( + [ + head_size * (head_num + 2 * num_query_groups), + ] + ) + return qkv_bias + + @staticmethod + def merge_fc1(gate: torch.Tensor, up: torch.Tensor): + """Merge gate and up proj into concatenated fc1. + + Example: import HF {gate|up}_proj to layer linear_fc1 + """ + return torch.cat((gate, up), dim=0) + + @staticmethod + def split_fc1(linear_fc1: torch.Tensor): + """Split concatenated fc1 to gate and up proj. 
+ + Example: export layer linear_fc1 to HF {gate|up}_proj + """ + gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0) + return gate_proj, up_proj + + @staticmethod + def duplicate2(param: torch.Tensor): + """Duplicate the source parameter to two target parameters. + + Example: export Performant LoRA linear_fc1.adapter.linear_in to HF {gate|up}_proj.lora_A + """ + return param, param + + @staticmethod + def duplicate3(param: torch.Tensor): + """Duplicate the source parameter to three target parameters. + + Example: export Performant LoRA linear_qkv.adapter.linear_in to HF {q|k|v}_proj.lora_A + """ + return param, param, param + + @staticmethod + def prune_padding(ctx: TransformCTX, embedding: torch.Tensor): + """Prune the embedding size to vocab size. + + Example: export embedding/output layer to HF with non-padded vocab size + """ + megatron_config = ctx.target.config + return embedding[: megatron_config.vocab_size, :] + + +def extract_dtypes(ckpt): + """Extract dtype from the input iterator. + + ckpt can be module.named_parameters or module.state_dict().items() + """ + dtypes = {} + for key, val in ckpt: + if hasattr(val, "dtype"): + dtypes[key] = val.dtype + elif hasattr(val, "data") and hasattr(val.data, "dtype"): + # if it's ShardedTensor populated with data. + dtypes[key] = val.data.dtype + return dtypes diff --git a/bionemo-recipes/models/qwen3/tests/__init__.py b/bionemo-recipes/models/qwen3/tests/__init__.py new file mode 100644 index 000000000..1dd47a63c --- /dev/null +++ b/bionemo-recipes/models/qwen3/tests/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/bionemo-recipes/models/qwen3/tests/common/README.md b/bionemo-recipes/models/qwen3/tests/common/README.md new file mode 100644 index 000000000..bed8deaea --- /dev/null +++ b/bionemo-recipes/models/qwen3/tests/common/README.md @@ -0,0 +1,64 @@ +# BioNeMo Common Test Library + +Shared test infrastructure for BioNeMo models. One base class, **BaseModelTest**: inherit and implement the abstract methods to get the full test suite (golden values, conversion, FP8, meta init, smoke tests). + +## Structure + +```text +tests/common/ +├── __init__.py # Public API exports +├── test_modeling_common.py # BaseModelTest, TestTolerances +├── fixtures.py # input_format, fp8_recipe, te_attn_backend, etc. +└── README.md +``` + +**Required:** In your top-level `tests/conftest.py` (e.g. `bionemo-recipes/models/esm2/tests/conftest.py`), add: + +```python +pytest_plugins = ["tests.common.fixtures"] +``` + +Without this, parametrized fixtures will not load. 
+ +## BaseModelTest + +Inherit from `BaseModelTest` and implement: + +| Method | Returns | Description | +| ------------------------------------------------- | ------------------------- | ----------------------------------------------- | +| `get_model_class()` | `Type[PreTrainedModel]` | TE model class | +| `get_tokenizer()` | `PreTrainedTokenizer` | Tokenizer | +| `get_config_class()` | `Type[PretrainedConfig]` | Config class | +| `get_upstream_model_id()` | `str` | HF model ID | +| `get_upstream_model_revision()` | `Optional[str]` | Revision or None | +| `get_upstream_model_class()` | `Type[PreTrainedModel]` | HF model class | +| `get_layer_path(model)` | `List[nn.Module]` | Transformer layers | +| `get_test_input_data(format, pad_to_multiple_of)` | `Dict[str, torch.Tensor]` | Inputs on CUDA; `format` is `"bshd"` or `"thd"` | +| `get_hf_to_te_converter()` | `Callable` | HF → TE | +| `get_te_to_hf_converter()` | `Callable` | TE → HF | + +**Optional overrides:** `get_tolerances()` → `TestTolerances`, `get_attn_input_formats()`, `get_reference_model_no_weights()`. + +**Helpers:** `create_test_config()`, `get_reference_model()`, `get_reference_model_no_weights()`, `compare_outputs()`, `verify_model_parameters_initialized_correctly()`, `get_converted_te_model_checkpoint()`, `get_converted_te_model()`. + +**Tests included:** Meta/CUDA init (`test_cuda_init`, `test_meta_init`, …), smoke (parametrized by `input_format`), conversion, golden values (BSHD + THD), FP8 (parametrized by `fp8_recipe`, `input_format`). + +## TestTolerances + +Dataclass in `test_modeling_common.py`. Override `get_tolerances()` to return a custom instance. Fields: `golden_value_*`, `cp_*`, `fp8_*`, `init_*` (see class definition). 
+
+## Fixtures (fixtures.py)
+
+| Fixture           | Description                         |
+| ----------------- | ----------------------------------- |
+| `input_format`    | `"bshd"` / `"thd"`                  |
+| `fp8_recipe`      | FP8 recipe (skipped if unsupported) |
+| `te_attn_backend` | `"flash_attn"` / `"fused_attn"`     |
+| `unused_tcp_port` | For distributed tests               |
+| `use_te_debug`    | Autouse: `NVTE_DEBUG=1`             |
+
+## Usage
+
+1. Create a class inheriting from `BaseModelTest` and implement the abstract methods (see `esm2/tests/test_modeling_esm_te.py` for a full example).
+2. Add `pytest_plugins = ["tests.common.fixtures"]` to `tests/conftest.py`.
+3. Run `pytest tests/test_modeling_<model>_te.py -v`.
diff --git a/bionemo-recipes/models/qwen3/tests/common/__init__.py b/bionemo-recipes/models/qwen3/tests/common/__init__.py
new file mode 100644
index 000000000..a2570bee9
--- /dev/null
+++ b/bionemo-recipes/models/qwen3/tests/common/__init__.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common test utilities for BioNeMo models.
+ +This package provides reusable test infrastructure following HuggingFace +transformers patterns, including: + +- BaseModelTest: Base test class with all common test methods +- TestTolerances: Dataclass for model-specific numerical tolerances +- Distributed testing utilities for multi-GPU tests +- Shared fixtures for common test requirements + +Example usage: + + ```python + from tests.common import BaseModelTest, TestTolerances + + class ESM2ModelTester(BaseModelTest): + def get_model_class(self): + return NVEsmForMaskedLM + # ... implement other abstract methods + ``` +""" + +from .test_modeling_common import HAS_DATA_CENTER_GPU, BaseModelTest, TestTolerances + + +__all__ = [ + "HAS_DATA_CENTER_GPU", + "BaseModelTest", + "TestTolerances", +] diff --git a/bionemo-recipes/models/qwen3/tests/common/fixtures.py b/bionemo-recipes/models/qwen3/tests/common/fixtures.py new file mode 100644 index 000000000..a437aae3d --- /dev/null +++ b/bionemo-recipes/models/qwen3/tests/common/fixtures.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Shared test fixtures for BioNeMo models.""" + +import os +import socket + +import pytest +from transformer_engine.common import recipe as recipe_module +from transformer_engine.pytorch import fp8 +from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends + + +@pytest.fixture +def unused_tcp_port() -> int: + """Get an unused TCP port for distributed testing. + + Returns: + An available TCP port number. + """ + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +@pytest.fixture(autouse=True) +def use_te_debug(): + """Auto-use fixture to enable TransformerEngine debugging. + + This fixture automatically enables debug mode for TransformerEngine + in all tests for better error messages. + """ + import os + + os.environ["NVTE_DEBUG"] = "1" + yield + os.environ.pop("NVTE_DEBUG", None) + + +ALL_RECIPES = [ + recipe_module.DelayedScaling(), + recipe_module.Float8CurrentScaling(), + recipe_module.Float8BlockScaling(), + recipe_module.MXFP8BlockScaling(), + recipe_module.NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True), +] + + +def _check_recipe_support(recipe: recipe_module.Recipe): + """Check if a recipe is supported and return (supported, reason).""" + if isinstance(recipe, recipe_module.DelayedScaling): + recipe_supported, reason = fp8.check_fp8_support() + elif isinstance(recipe, recipe_module.Float8CurrentScaling): + recipe_supported, reason = fp8.check_fp8_support() + elif isinstance(recipe, recipe_module.Float8BlockScaling): + recipe_supported, reason = fp8.check_fp8_block_scaling_support() + elif isinstance(recipe, recipe_module.MXFP8BlockScaling): + recipe_supported, reason = fp8.check_mxfp8_support() + elif isinstance(recipe, recipe_module.NVFP4BlockScaling): + recipe_supported, reason = fp8.check_nvfp4_support() + else: + recipe_supported = False + reason = "Unsupported recipe" + return 
recipe_supported, reason
+
+
+def parametrize_recipes_with_support(recipes):
+    """Generate pytest.param objects with xfail marks for unsupported recipes."""
+    parametrized_recipes = []
+    for recipe in recipes:
+        recipe_supported, reason = _check_recipe_support(recipe)
+        parametrized_recipes.append(
+            pytest.param(
+                recipe,
+                id=recipe.__class__.__name__,
+                marks=pytest.mark.xfail(
+                    condition=not recipe_supported,
+                    reason=reason,
+                ),
+            )
+        )
+    return parametrized_recipes
+
+
+@pytest.fixture(params=parametrize_recipes_with_support(ALL_RECIPES))
+def fp8_recipe(request):
+    """Fixture to parametrize the FP8 recipe."""
+    return request.param
+
+
+@pytest.fixture(params=["bshd", "thd"])
+def input_format(request):
+    """Fixture to parametrize the input format."""
+    return request.param
+
+
+@pytest.fixture(params=["flash_attn", "fused_attn"])
+def te_attn_backend(request):
+    """Fixture to parametrize the attention implementation."""
+    if request.param == "flash_attn":
+        os.environ["NVTE_FUSED_ATTN"] = "0"
+        os.environ["NVTE_FLASH_ATTN"] = "1"
+        _attention_backends["backend_selection_requires_update"] = True
+
+    else:
+        os.environ["NVTE_FUSED_ATTN"] = "1"
+        os.environ["NVTE_FLASH_ATTN"] = "0"
+        _attention_backends["backend_selection_requires_update"] = True
+
+    yield request.param
+
+    os.environ.pop("NVTE_FUSED_ATTN", None)
+    os.environ.pop("NVTE_FLASH_ATTN", None)
+    _attention_backends["backend_selection_requires_update"] = True
diff --git a/bionemo-recipes/models/qwen3/tests/common/test_modeling_common.py b/bionemo-recipes/models/qwen3/tests/common/test_modeling_common.py
new file mode 100644
index 000000000..1912148ad
--- /dev/null
+++ b/bionemo-recipes/models/qwen3/tests/common/test_modeling_common.py
@@ -0,0 +1,1108 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common test class for BioNeMo models, following HuggingFace transformers patterns.""" + +import fnmatch +import gc +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Dict, List, Literal, Type + +import pytest +import torch +import transformer_engine.pytorch +from torch import nn +from transformer_engine.pytorch import QuantizedTensor +from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, set_seed + + +try: + HAS_DATA_CENTER_GPU = torch.cuda.is_available() and any( + gpu_name in torch.cuda.get_device_name(0).upper() for gpu_name in ["H100", "H200", "B100", "B200", "B300"] + ) +except (RuntimeError, AssertionError): + HAS_DATA_CENTER_GPU = False + + +@dataclass +class TestTolerances: + """Model-specific test tolerances for numerical comparisons.""" + + # Golden value test tolerances + golden_value_loss_atol: float = 1e-2 + golden_value_loss_rtol: float = 1e-3 + golden_value_logits_atol: float = 2.0 + golden_value_logits_rtol: float = 1e-4 + golden_value_hidden_states_atol: float = 0.1 + golden_value_hidden_states_rtol: float = 0.05 + + # Context parallel test tolerances + cp_loss_atol: float = 0.1 + cp_loss_rtol: float = 0.05 + cp_logits_atol: float = 1.0 + cp_logits_rtol: float = 0.1 
+ cp_gradients_atol: float = 0.1 + cp_gradients_rtol: float = 0.1 + + # FP8 test tolerances + fp8_loss_atol: float = 0.1 + fp8_loss_rtol: float = 0.05 + fp8_logits_atol: float = 5.0 + fp8_logits_rtol: float = 0.1 + + # Meta device initialization tolerances + init_mean_atol: float = 1e-3 + init_mean_rtol: float = 1e-4 + init_std_atol: float = 1e-3 + init_std_rtol: float = 1e-4 + + +class BaseModelTest(ABC): + """Abstract base class for testing BioNeMo models. + + This class provides common test utilities and defines the interface that + model-specific testers must implement. It follows the pattern used in + HuggingFace transformers for model testing. + + Subclasses must implement all abstract methods to provide model-specific + configuration, data preparation, and conversion functions. + + Set ``is_autoregressive = True`` in subclasses for causal LM models to + enable generation / KV-cache smoke tests. Non-autoregressive models + (e.g. ESM2) leave the default ``False`` and those tests are skipped. + + Example: + ```python + class ESM2ModelTester(BioNeMoModelTester): + def get_model_class(self): + return NVEsmForMaskedLM + + def get_config_class(self): + return NVEsmConfig + + def get_upstream_model_id(self): + return "facebook/esm2_t6_8M_UR50D" + + # ... implement other abstract methods + ``` + """ + + is_autoregressive: bool = False + + @abstractmethod + def get_model_class(self) -> Type[PreTrainedModel]: + """Return the TransformerEngine model class to test. + + Returns: + The model class (e.g., NVEsmForMaskedLM, NVLlamaForCausalLM). + """ + pass + + @abstractmethod + def get_tokenizer(self) -> PreTrainedTokenizer: + """Return the tokenizer for the model. + + Returns: + The tokenizer (e.g., AutoTokenizer). + """ + pass + + @abstractmethod + def get_config_class(self) -> Type[PretrainedConfig]: + """Return the config class for the model. + + Returns: + The config class (e.g., NVEsmConfig, NVLlamaConfig). 
+ """ + pass + + @abstractmethod + def get_upstream_model_id(self) -> str: + """Return the HuggingFace model ID for the reference model. + + Returns: + Model ID string (e.g., "facebook/esm2_t6_8M_UR50D"). + """ + pass + + @abstractmethod + def get_upstream_model_revision(self) -> str: + """Return the specific revision/commit hash for the upstream model. + + Returns: + Revision string or 'main' for latest. + """ + pass + + @abstractmethod + def get_upstream_model_class(self) -> Type[PreTrainedModel]: + """Return the HuggingFace reference model class. + + Returns: + The HF model class (e.g., AutoModelForMaskedLM, AutoModelForCausalLM). + """ + pass + + @abstractmethod + def get_layer_path(self, model: PreTrainedModel) -> List[nn.Module]: + """Return the list of transformer layers in the model. + + Args: + model: The model instance. + + Returns: + List of transformer layer modules. + + Example: + For ESM2: model.esm.encoder.layers + For LLaMA3: model.model.layers + """ + pass + + @abstractmethod + def get_test_input_data( + self, + format: Literal["bshd", "thd"] = "bshd", + pad_to_multiple_of: int | None = None, + ) -> Dict[str, torch.Tensor]: + """Prepare test input data for the model. + + Args: + format: Whether to use sequence packing (THD) or bshd format. + + Returns: + Dictionary of input tensors (input_ids, attention_mask, etc.). + """ + pass + + @abstractmethod + def get_hf_to_te_converter(self) -> Callable: + """Return the function that converts HF model to TE model. + + Returns: + Conversion function with signature: (hf_model, **kwargs) -> te_model + """ + pass + + @abstractmethod + def get_te_to_hf_converter(self) -> Callable: + """Return the function that converts TE model to HF model. + + Returns: + Conversion function with signature: (te_model, **kwargs) -> hf_model + """ + pass + + def get_tolerances(self) -> TestTolerances: + """Return test tolerances for this model. + + Override this method to provide model-specific tolerances. 
+ + Returns: + TestTolerances instance with appropriate values. + """ + return TestTolerances() + + def get_attn_input_formats(self) -> List[str]: + """Return supported attention input formats. + + Returns: + List of format strings (e.g., ["bshd", "thd"]). + """ + return ["bshd"] + + def verify_model_parameters_initialized_correctly( + self, + model: PreTrainedModel, + atol: float | None = None, + rtol: float | None = None, + should_be_fp8: bool = False, + ) -> None: + """Verify that model parameters are initialized correctly. + + This can be overridden for models that use non-standard weight initialization. + + This checks that: + 1. All parameters are on CUDA device + 2. Embeddings have correct mean and std + 3. Linear layers have correct weight/bias initialization + 4. LayerNorm parameters are initialized correctly + 5. FP8 quantization is applied if requested + + Args: + model: The model to verify. + atol: Absolute tolerance for comparisons (uses default if None). + rtol: Relative tolerance for comparisons (uses default if None). + should_be_fp8: Whether to expect FP8 quantized weights. 
+ """ + config = model.config + tolerances = self.get_tolerances() + + if atol is None: + atol = tolerances.init_mean_atol + if rtol is None: + rtol = tolerances.init_mean_rtol + + # Verify all parameters are on CUDA + for name, parameter in model.named_parameters(): + assert str(parameter.device).startswith("cuda"), f"Parameter {name} is not on the cuda device" + + # Verify initialization for each module type + for name, module in model.named_modules(): + + def msg(x): + return f"Mismatch in module {name}: {x}" + + if isinstance(module, torch.nn.Embedding): + torch.testing.assert_close(module.weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) + torch.testing.assert_close( + module.weight.std().item(), + config.initializer_range, + atol=tolerances.init_std_atol, + rtol=tolerances.init_std_rtol, + msg=msg, + ) + + elif isinstance(module, transformer_engine.pytorch.Linear): + torch.testing.assert_close(module.weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg) + torch.testing.assert_close( + module.weight.std().item(), + config.initializer_range, + atol=tolerances.init_std_atol, + rtol=tolerances.init_std_rtol, + msg=msg, + ) + if module.bias is not None: + torch.testing.assert_close(module.bias, torch.zeros_like(module.bias), msg=msg) + + if should_be_fp8: + if f"{name}.weight" in set(model._tied_weights_keys): + continue # Skip tied weights + elif hasattr(model, "_do_not_quantize") and any( + fnmatch.fnmatch(name, pattern) for pattern in model._do_not_quantize + ): + continue # Skip weights that should be kept in bf16 + assert isinstance(module.weight, QuantizedTensor), f"Module {name} weight is not a Float8Tensor" + + elif isinstance(module, transformer_engine.pytorch.LayerNorm): + torch.testing.assert_close(module.weight, torch.ones_like(module.weight), msg=msg) + torch.testing.assert_close(module.bias, torch.zeros_like(module.bias), msg=msg) + + elif isinstance(module, torch.nn.LayerNorm): + torch.testing.assert_close(module.weight, 
torch.ones_like(module.weight), msg=msg) + if module.bias is not None: + torch.testing.assert_close(module.bias, torch.zeros_like(module.bias), msg=msg) + + def create_test_config(self, **kwargs) -> PretrainedConfig: + """Create a test configuration with optional overrides. + + Args: + **kwargs: Configuration parameters to override. + + Returns: + Configuration instance. + """ + config_class = self.get_config_class() + upstream_id = self.get_upstream_model_id() + revision = self.get_upstream_model_revision() + return config_class.from_pretrained(upstream_id, revision=revision, **kwargs) + + def get_reference_model( + self, + dtype: torch.dtype = torch.bfloat16, + attn_implementation: str = "flash_attention_2", + ) -> PreTrainedModel: + """Load the reference HuggingFace model. + + Args: + dtype: Data type for the model. + device: Device to load model on. + attn_implementation: Attention implementation to use. + + Returns: + The loaded reference model. + """ + upstream_class = self.get_upstream_model_class() + upstream_id = self.get_upstream_model_id() + revision = self.get_upstream_model_revision() + + kwargs = { + "dtype": dtype, + "attn_implementation": attn_implementation, + } + if revision is not None: + kwargs["revision"] = revision + + model = upstream_class.from_pretrained(upstream_id, **kwargs) + model.to("cuda") + return model + + def get_reference_model_no_weights( + self, dtype: torch.dtype = torch.float32, revision: str | None = None, **kwargs + ) -> PreTrainedModel: + """Load the reference HuggingFace model with random weights.""" + if revision is None: + revision = self.get_upstream_model_revision() + return self.get_upstream_model_class()( + AutoConfig.from_pretrained( + self.get_upstream_model_id(), + dtype=dtype, + revision=revision, + **kwargs, + ) + ) + + def compare_outputs( + self, + te_outputs, + hf_outputs, + input_data: Dict[str, torch.Tensor], + compare_loss: bool = True, + compare_logits: bool = True, + compare_hidden_states: bool = False, 
+ ) -> None: + """Compare outputs from TE and HF models. + + Args: + te_outputs: Outputs from TransformerEngine model. + hf_outputs: Outputs from HuggingFace model. + input_data: Input data dictionary (for attention mask). + compare_loss: Whether to compare loss values. + compare_logits: Whether to compare logits. + compare_hidden_states: Whether to compare hidden states. + """ + tolerances = self.get_tolerances() + + if compare_loss and hasattr(te_outputs, "loss") and hasattr(hf_outputs, "loss"): + torch.testing.assert_close( + te_outputs.loss, + hf_outputs.loss, + atol=tolerances.golden_value_loss_atol, + rtol=tolerances.golden_value_loss_rtol, + msg=lambda x: f"Loss mismatch between TE and HF models: {x}", + ) + + if compare_logits and hasattr(te_outputs, "logits") and hasattr(hf_outputs, "logits"): + # Only compare logits where attention mask is True + if "attention_mask" in input_data: + mask = input_data["attention_mask"].to(bool) + torch.testing.assert_close( + te_outputs.logits[mask], + hf_outputs.logits[mask], + atol=tolerances.golden_value_logits_atol, + rtol=tolerances.golden_value_logits_rtol, + msg=lambda x: f"Logits mismatch between TE and HF models: {x}", + ) + else: + torch.testing.assert_close( + te_outputs.logits, + hf_outputs.logits, + atol=tolerances.golden_value_logits_atol, + rtol=tolerances.golden_value_logits_rtol, + msg=lambda x: f"Logits mismatch between TE and HF models: {x}", + ) + + if compare_hidden_states and hasattr(te_outputs, "hidden_states") and hasattr(hf_outputs, "hidden_states"): + for i, (te_hidden, hf_hidden) in enumerate(zip(te_outputs.hidden_states, hf_outputs.hidden_states)): + torch.testing.assert_close( + te_hidden, + hf_hidden, + atol=tolerances.golden_value_hidden_states_atol, + rtol=tolerances.golden_value_hidden_states_rtol, + msg=lambda x: f"Hidden states mismatch at layer {i}: {x}", + ) + + @pytest.fixture(autouse=True, scope="function") + def clear_gpu_memory(self): + """Clear GPU memory before and after each test 
to prevent OOM from fragmentation.""" + gc.collect() + torch.cuda.empty_cache() + yield + gc.collect() + torch.cuda.empty_cache() + + @pytest.fixture(autouse=True, scope="function") + def set_seed(self): + set_seed(42) + + @pytest.fixture(autouse=True, scope="function") + def reset_fp8_context(self): + """Make sure we clean up the FP8 context after each test.""" + FP8GlobalStateManager.reset() + + # ==================== Forward and Backward Smoke Tests ==================== + + def test_smoke_forward_pass(self, input_format): + model_class = self.get_model_class() + config = self.create_test_config(attn_input_format=input_format) + + model = model_class(config) + model.to(torch.bfloat16) + model.to("cuda") + + # Prepare input data + input_data = self.get_test_input_data(input_format) + + # Forward pass with output_hidden_states + with torch.no_grad(): + outputs = model(**input_data, output_hidden_states=True) + + # Verify outputs + assert outputs.logits is not None, "Model should output logits" + assert outputs.hidden_states is not None, "Model should output hidden states when requested" + assert len(outputs.hidden_states) == config.num_hidden_layers + 1, ( + f"Expected {config.num_hidden_layers + 1} hidden states, got {len(outputs.hidden_states)}" + ) + + def test_smoke_backward_pass(self, input_format): + """Smoke test: backward pass.""" + model_class = self.get_model_class() + config = self.create_test_config(attn_input_format=input_format) + + model = model_class(config) + model.to(torch.bfloat16) + model.to("cuda") + + # Prepare input data + input_data = self.get_test_input_data(input_format) + + # Forward pass + outputs = model(**input_data, output_hidden_states=True) + + # Backward pass + outputs.logits.mean().backward() + + # Verify all parameters have gradients + for param in model.parameters(): + if param.requires_grad: + assert param.grad is not None, "All trainable parameters should have gradients after backward pass" + + def 
test_smoke_model_with_loss(self, input_format): + """Smoke test: model forward pass with labels produces loss.""" + model_class = self.get_model_class() + config = self.create_test_config(attn_input_format=input_format) + + model = model_class(config) + model.to(torch.bfloat16) + model.to("cuda") + + # Prepare input data with labels + input_data = self.get_test_input_data(input_format) + + # Ensure labels are present + if "labels" not in input_data: + input_data["labels"] = input_data["input_ids"].clone() + + # Forward pass + with torch.no_grad(): + outputs = model(**input_data) + + # Verify loss is computed + assert outputs.loss is not None, "Model should compute loss when labels are provided" + assert outputs.loss.item() > 0, "Loss should be positive" + + def test_forward_and_backward(self, input_format): + """Test that model can perform forward and backward passes.""" + model_class = self.get_model_class() + config = self.create_test_config(attn_input_format=input_format) + + model = model_class(config) + model.to(torch.bfloat16) + model.to("cuda") + + # Prepare input data + input_data = self.get_test_input_data(input_format) + + # Add labels for loss computation + if "labels" not in input_data: + input_data["labels"] = input_data["input_ids"].clone() + + # Forward pass + outputs = model(**input_data) + loss = outputs.loss + + # Backward pass + loss.backward() + + # Verify gradients exist + for name, param in model.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"Parameter {name} has no gradient" + + # ==================== Conversion Tests ==================== + + def test_convert_hf_to_te(self): + """Test that HF model can be converted to TE format.""" + # Load reference HF model + model_hf_original = self.get_reference_model_no_weights() + # Convert to TE + convert_fn = self.get_hf_to_te_converter() + model_te = convert_fn(model_hf_original) + + # Verify model structure + assert model_te is not None + assert isinstance(model_te, 
self.get_model_class()) + + def test_convert_te_to_hf(self): + """Test that TE model can be converted back to HF format.""" + # Load reference HF model + model_hf_original = self.get_reference_model_no_weights() + # Convert to TE + hf_to_te_fn = self.get_hf_to_te_converter() + model_te = hf_to_te_fn(model_hf_original) + + # Convert back to HF + te_to_hf_fn = self.get_te_to_hf_converter() + model_hf_converted = te_to_hf_fn(model_te) + + # Verify model structure + assert model_hf_converted is not None + assert isinstance(model_hf_converted, self.get_upstream_model_class()) + + def test_convert_te_to_hf_roundtrip(self): + """Test that HF → TE → HF conversion preserves weights.""" + # Load reference HF model + model_hf_original = self.get_reference_model_no_weights() + original_state_dict = model_hf_original.state_dict() + + # Convert to TE and back + hf_to_te_fn = self.get_hf_to_te_converter() + te_to_hf_fn = self.get_te_to_hf_converter() + + model_te = hf_to_te_fn(model_hf_original) + model_hf_converted = te_to_hf_fn(model_te) + converted_state_dict = model_hf_converted.state_dict() + + # Compare state dicts + assert set(original_state_dict.keys()) == set(converted_state_dict.keys()), "State dict keys don't match" + + for key in original_state_dict.keys(): + original_param = original_state_dict[key] + converted_param = converted_state_dict[key] + + # Convert both to the same dtype for comparison (use the original dtype) + if original_param.dtype != converted_param.dtype: + converted_param = converted_param.to(original_param.dtype) + + torch.testing.assert_close( + original_param, + converted_param, + atol=1e-5, + rtol=1e-5, + msg=f"Mismatch in parameter {key} after roundtrip conversion", + ) + + def test_convert_config(self): + """Test that config can be converted between HF and TE formats.""" + upstream_id = self.get_upstream_model_id() + revision = self.get_upstream_model_revision() + + # Load HF config + from transformers import AutoConfig + + kwargs = {} + if 
revision is not None: + kwargs["revision"] = revision + hf_config = AutoConfig.from_pretrained(upstream_id, **kwargs) + + # Get TE config class + te_config_class = self.get_config_class() + + # Convert to TE config + te_config = te_config_class(**hf_config.to_dict()) + + # Verify key attributes match + assert te_config.hidden_size == hf_config.hidden_size + assert te_config.num_hidden_layers == hf_config.num_hidden_layers + assert te_config.num_attention_heads == hf_config.num_attention_heads + + @pytest.fixture(scope="class", autouse=True) + def _set_tmpdir(self, tmp_path_factory): + """Make sure we can see the saved te checkpoint as a class-scoped fixture.""" + # set on the class, visible as self._tmp_dir + type(self)._tmp_dir = tmp_path_factory.mktemp(self.__class__.__name__) + + def get_converted_te_model_checkpoint(self) -> Path: + """Get the path to the converted TE model checkpoint. + + This method manages GPU memory carefully to support large models: + 1. Load and convert the HF model + 2. Free the HF model before saving + 3. Move TE model to CPU before saving (save_pretrained clones state dict internally) + """ + model_hf = self.get_reference_model() + convert_fn = self.get_hf_to_te_converter() + model_te = convert_fn(model_hf) + + # Free source model to reduce peak GPU memory + del model_hf + gc.collect() + torch.cuda.empty_cache() + + # Move to CPU before saving - save_pretrained internally clones the state dict, + # which would double GPU memory usage and OOM for large models. + model_te.to("cpu") + + checkpoint_path: Path = self._tmp_dir / "converted_te_model" + model_te.save_pretrained(checkpoint_path) + + del model_te + gc.collect() + + return checkpoint_path + + def get_converted_te_model(self, **kwargs) -> PreTrainedModel: + """Get the converted TE model. + + This shouldn't get called before the checkpoint tests are run in case they're broken. 
+ """ + checkpoint_path = self.get_converted_te_model_checkpoint() + model_te = self.get_model_class().from_pretrained(checkpoint_path, **kwargs) + model_te.to("cuda") + return model_te + + # ==================== Golden Value Tests ==================== + + def test_golden_values(self): + """Test that TE model outputs match HF reference model. + + Models are run sequentially and freed between runs to support large models + that cannot fit two copies on a single GPU simultaneously. + """ + input_data = self.get_test_input_data("bshd") + + # Run HF model first, then free it + model_hf = self.get_reference_model(dtype=torch.bfloat16) + model_hf.eval() + with torch.no_grad(): + hf_outputs = model_hf(**input_data) + hf_loss = hf_outputs.loss.detach().clone() + hf_logits = hf_outputs.logits.detach().clone() + del model_hf, hf_outputs + gc.collect() + torch.cuda.empty_cache() + + # Load and run TE model + model_te = self.get_converted_te_model(dtype=torch.bfloat16) + model_te.eval() + with torch.no_grad(): + te_outputs = model_te(**input_data) + del model_te + gc.collect() + torch.cuda.empty_cache() + + # Compare outputs + self.compare_outputs( + te_outputs, + type("HFOutputs", (), {"loss": hf_loss, "logits": hf_logits})(), + input_data, + compare_loss=True, + compare_logits=True, + compare_hidden_states=False, + ) + + def test_golden_values_thd(self, te_attn_backend): + """Test the model outputs the same results with THD and BSHD input formats.""" + + if te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 8: + pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.") + elif te_attn_backend == "fused_attn" and torch.cuda.get_device_capability()[0] == 12: + pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.") + + input_data_bshd = self.get_test_input_data(format="bshd") + input_data_thd = self.get_test_input_data(format="thd") + tolerances = self.get_tolerances() + + 
torch.testing.assert_close( + input_data_bshd["input_ids"][input_data_bshd["attention_mask"].to(bool)], + input_data_thd["input_ids"].flatten(0), + ) + + # The THD labels will have some extra -100 items due to the separator token, so we need to filter them out. + labels_bshd = input_data_bshd["labels"][input_data_bshd["attention_mask"].to(bool)] + labels_thd = input_data_thd["labels"].flatten(0) + torch.testing.assert_close(labels_bshd[labels_thd != -100], labels_thd[labels_thd != -100]) + + # Run models sequentially to support large models that cannot fit two copies on GPU + model_bshd = self.get_converted_te_model(attn_input_format="bshd", dtype=torch.bfloat16) + model_bshd.eval() + with torch.inference_mode(): + outputs_bshd = model_bshd(**input_data_bshd) + bshd_loss = outputs_bshd.loss.detach().clone() + bshd_logits = outputs_bshd.logits[input_data_bshd["attention_mask"].to(bool)].detach().clone() + del model_bshd, outputs_bshd + gc.collect() + torch.cuda.empty_cache() + + model_thd = self.get_converted_te_model(attn_input_format="thd", dtype=torch.bfloat16) + model_thd.eval() + with torch.inference_mode(): + outputs_thd = model_thd(**input_data_thd) + + # Compare logits + torch.testing.assert_close( + bshd_logits, + outputs_thd.logits, + atol=tolerances.golden_value_logits_atol, + rtol=tolerances.golden_value_logits_rtol, + ) + + # Compare losses + torch.testing.assert_close( + bshd_loss, + outputs_thd.loss, + atol=tolerances.golden_value_loss_atol, + rtol=tolerances.golden_value_loss_rtol, + ) + + def test_thd_padding_input_data_equivalence(self): + """Test that the THD input data is the same before and after padding.""" + + input_data_thd = self.get_test_input_data(format="thd") + input_data_thd_padded = self.get_test_input_data(format="thd", pad_to_multiple_of=32) + + cu_seq_lens_q = input_data_thd["cu_seq_lens_q"] + cu_seq_lens_q_padded = input_data_thd_padded["cu_seq_lens_q_padded"] + cu_num_pads = cu_seq_lens_q_padded - cu_seq_lens_q + seq_lengths_real 
= cu_seq_lens_q[1:] - cu_seq_lens_q[:-1] + + num_real_tokens = cu_seq_lens_q[-1] + + # How much we need to shift each sequence by. + offsets = torch.repeat_interleave(cu_num_pads[:-1], seq_lengths_real, dim=0) + + # The indices of the real tokens as they appear in the padded logits. + real_idx = torch.arange(0, num_real_tokens, device="cuda") + offsets + + torch.testing.assert_close( + input_data_thd["input_ids"], + input_data_thd_padded["input_ids"].index_select(1, real_idx), + ) + + torch.testing.assert_close( + input_data_thd["labels"], + input_data_thd_padded["labels"].index_select(1, real_idx), + ) + assert input_data_thd_padded["pad_between_seqs"] is True + + @pytest.mark.xfail( + condition=not HAS_DATA_CENTER_GPU, + reason="Padded THD sequences are not supported on non-datacenter hardware.", + ) + def test_golden_values_thd_padded(self): + """Test that the model outputs the same results with padded input data.""" + + input_data_thd = self.get_test_input_data(format="thd") + input_data_thd_padded = self.get_test_input_data(format="thd", pad_to_multiple_of=32) + tolerances = self.get_tolerances() + + model_thd = self.get_converted_te_model(attn_input_format="thd", dtype=torch.bfloat16) + model_thd.eval() + + with torch.inference_mode(): + outputs_thd = model_thd(**input_data_thd) + outputs_thd_padded = model_thd(**input_data_thd_padded) + + cu_seq_lens_q = input_data_thd["cu_seq_lens_q"] + cu_seq_lens_q_padded = input_data_thd_padded["cu_seq_lens_q_padded"] + cu_num_pads = cu_seq_lens_q_padded - cu_seq_lens_q + seq_lengths_real = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1] + num_real_tokens = cu_seq_lens_q[-1] + offsets = torch.repeat_interleave(cu_num_pads[:-1], seq_lengths_real, dim=0) + + # The indices of the real tokens as they appear in the padded logits. 
+ real_idx = torch.arange(0, num_real_tokens, device="cuda") + offsets + logits_unpadded = outputs_thd_padded.logits.index_select(0, real_idx.cuda()) + + torch.testing.assert_close( + outputs_thd.logits, + logits_unpadded, + atol=tolerances.golden_value_logits_atol, + rtol=tolerances.golden_value_logits_rtol, + ) + + torch.testing.assert_close( + outputs_thd.loss, + outputs_thd_padded.loss, + atol=tolerances.golden_value_loss_atol, + rtol=tolerances.golden_value_loss_rtol, + ) + + # ==================== FP8 Tests ==================== + def test_fp8_forward_and_backward_pass(self, fp8_recipe, input_format): + """Test that model works with FP8 autocast.""" + if input_format == "thd" and not HAS_DATA_CENTER_GPU: + pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.") + + model_class = self.get_model_class() + config = self.create_test_config( + dtype=torch.bfloat16, attn_input_format=input_format, self_attn_mask_type="padding_causal" + ) + + model = model_class(config) + model.to("cuda") + model.eval() + + # Prepare input data + input_data = self.get_test_input_data(input_format, pad_to_multiple_of=32) + + # Run without FP8 + with torch.no_grad(): + outputs = model(**input_data) + loss_bf16 = outputs.loss + + # Run with FP8 + with transformer_engine.pytorch.autocast(recipe=fp8_recipe): + outputs_fp8 = model(**input_data) + loss_fp8 = outputs_fp8.loss + + assert torch.isfinite(loss_fp8) + + # Backward pass + loss_fp8.backward() + + # Verify gradients exist + for name, param in model.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"Parameter {name} has no gradient after FP8 backward pass" + + # Compare losses (should be close but not identical due to quantization) + tolerances = self.get_tolerances() + torch.testing.assert_close( + loss_fp8, + loss_bf16, + atol=tolerances.fp8_loss_atol, + rtol=tolerances.fp8_loss_rtol, + msg=lambda x: f"FP8 loss differs too much from BF16 loss: {x}", + ) + + def 
test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format, **config_kwargs): + """Test that model initialized with FP8 works correctly.""" + if input_format == "thd" and not HAS_DATA_CENTER_GPU: + pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.") + + model_class = self.get_model_class() + config = self.create_test_config( + attn_input_format=input_format, self_attn_mask_type="padding_causal", **config_kwargs + ) + + # Initialize with FP8 + with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe): + model = model_class(config) + + model.to("cuda") + model.eval() + + # Prepare input data + input_data = self.get_test_input_data(input_format, pad_to_multiple_of=32) + if "labels" not in input_data: + input_data["labels"] = input_data["input_ids"].clone() + + # Forward and backward pass with FP8 + with transformer_engine.pytorch.autocast(recipe=fp8_recipe): + outputs = model(**input_data) + + loss = outputs.loss + assert torch.isfinite(loss) + + loss.backward() + + # Verify gradients exist + for name, param in model.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"Parameter {name} has no gradient after FP8 backward pass" + + # ==================== Meta Device Initialization Tests ==================== + + def test_cuda_init(self): + """Test that model can be initialized directly on CUDA device.""" + model_class = self.get_model_class() + config = self.create_test_config() + + model = model_class(config) + model.to("cuda") + + self.verify_model_parameters_initialized_correctly(model) + + def test_meta_init(self): + """Test that model can be initialized on meta device and moved to CUDA.""" + model_class = self.get_model_class() + config = self.create_test_config() + + # Initialize on meta device + with torch.device("meta"): + model = model_class(config) + + # Assert parameters are actually on the meta device + for name, parameter in model.named_parameters(): + assert 
parameter.device == torch.device("meta"), f"Parameter {name} is not on the meta device" + + # Move to CUDA (this will materialize the parameters) + model.init_empty_weights() + self.verify_model_parameters_initialized_correctly(model) + + def test_cuda_fp8_init(self, fp8_recipe): + """Test that model can be initialized on CUDA with FP8.""" + model_class = self.get_model_class() + config = self.create_test_config() + + with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe): + model = model_class(config) + + model.to("cuda") + + self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True) + + def test_meta_fp8_init(self, fp8_recipe): + """Test that model can be initialized on meta device with FP8 and moved to CUDA.""" + model_class = self.get_model_class() + config = self.create_test_config() + + # Initialize on meta device with FP8 + with torch.device("meta"): + with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe): + model = model_class(config) + + # Assert parameters are actually on the meta device + for name, parameter in model.named_parameters(): + assert parameter.device == torch.device("meta"), f"Parameter {name} is not on the meta device" + + # Move to CUDA + model.init_empty_weights() + self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True) + + # ==================== Generation Tests (Autoregressive Models Only) ==================== + + def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1): + """Create inference params for KV-cache generation tests. + + Autoregressive model tests must override this method to provide + model-specific ``HFInferenceParams`` with allocated KV-cache memory. + + Args: + config: Model configuration. + batch_size: Batch size. + max_seq_len: Maximum sequence length. + num_beams: Number of beams for beam search. + + Returns: + HFInferenceParams instance with allocated memory. 
+ """ + raise NotImplementedError( + "Autoregressive models must override _create_inference_params to provide model-specific HFInferenceParams." + ) + + def test_generate_without_cache(self): + """Test basic generation without KV-cache (BSHD, use_cache=False).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompt = "The quick brown fox jumps over" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=False) + + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache(self): + """Test single-prompt generation with KV-cache (THD format).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompt = "The quick brown fox jumps over" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + past_key_values = self._create_inference_params(config, batch_size=1) + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) + + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache_batched(self): + """Test batched generation with KV-cache (left-padded BSHD converted to THD).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", 
self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompts = ( + "The quick brown fox jumps over the lazy dog.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + past_key_values = self._create_inference_params(config, batch_size=2) + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) + + assert output_ids.shape[0] == 2 + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache_beam_search(self): + """Test batched generation with KV-cache and beam search.""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompts = ( + "The quick brown fox jumps over the lazy dog.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + num_beams = 2 + past_key_values = self._create_inference_params(config, batch_size=2, num_beams=num_beams) + + with torch.no_grad(): + output_ids = model.generate( + **inputs, + max_new_tokens=16, + use_cache=True, + past_key_values=past_key_values, + num_beams=num_beams, + do_sample=True, + ) + + assert output_ids.shape[0] == 2 + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc. 
diff --git a/bionemo-recipes/models/qwen3/tests/conftest.py b/bionemo-recipes/models/qwen3/tests/conftest.py new file mode 100644 index 000000000..31c2a4549 --- /dev/null +++ b/bionemo-recipes/models/qwen3/tests/conftest.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from pathlib import Path + +import pytest + + +sys.path.append(Path(__file__).parent.parent.as_posix()) +sys.path.append(Path(__file__).parent.as_posix()) + +pytest_plugins = ["tests.common.fixtures"] + + +@pytest.fixture(scope="session") +def recipe_path() -> Path: + """Return the root directory of the recipe.""" + return Path(__file__).parent.parent diff --git a/bionemo-recipes/models/qwen3/tests/test_modeling_qwen3_te.py b/bionemo-recipes/models/qwen3/tests/test_modeling_qwen3_te.py new file mode 100644 index 000000000..bfe7b9235 --- /dev/null +++ b/bionemo-recipes/models/qwen3/tests/test_modeling_qwen3_te.py @@ -0,0 +1,164 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Qwen3 model. + +This file provides comprehensive tests for the Qwen3 model including: +- Common tests from the test library (meta device init, golden values, conversion, FP8) +- Qwen3-specific tests (inference, generation with KV-cache) +""" + +from typing import Callable, Dict, List, Literal, Type + +import pytest +import torch +import transformers +from torch import nn +from transformers import ( + AutoTokenizer, + DataCollatorForLanguageModeling, + PretrainedConfig, + PreTrainedModel, + PreTrainedTokenizer, +) + +from collator import DataCollatorWithFlattening +from convert import convert_qwen3_hf_to_te, convert_qwen3_te_to_hf +from modeling_qwen3_te import HFInferenceParams, NVQwen3Config, NVQwen3ForCausalLM +from tests.common import BaseModelTest, TestTolerances + + +class TestQwen3Model(BaseModelTest): + """Model tester for Qwen3. + + This class provides Qwen3-specific configuration for the common test suite. 
+ """ + + is_autoregressive = True + + def get_model_class(self) -> Type[PreTrainedModel]: + """Return the Qwen3 TE model class.""" + return NVQwen3ForCausalLM + + def get_config_class(self) -> Type[PretrainedConfig]: + """Return the Qwen3 config class.""" + return NVQwen3Config + + def get_upstream_model_id(self) -> str: + """Return the upstream HuggingFace model ID.""" + return "Qwen/Qwen3-0.6B" + + def get_upstream_model_revision(self) -> str: + """Return the specific revision for the upstream model.""" + return "c1899de" + + def get_tokenizer(self) -> PreTrainedTokenizer: + """Return the Qwen3 tokenizer.""" + tokenizer = AutoTokenizer.from_pretrained(self.get_upstream_model_id()) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + def get_upstream_model_class(self) -> Type[PreTrainedModel]: + """Return the upstream HuggingFace model class.""" + + return transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM + + def create_test_config(self, **kwargs) -> PretrainedConfig: + # Limit the number of hidden layers to 2 for faster tests. 
+ return super().create_test_config(num_hidden_layers=2, **kwargs) + + def get_layer_path(self, model: PreTrainedModel) -> List[nn.Module]: + """Return the list of transformer layers.""" + return list(model.model.layers) # type: ignore + + def get_test_input_data( + self, format: Literal["bshd", "thd"] = "bshd", pad_to_multiple_of: int | None = None + ) -> Dict[str, torch.Tensor]: + """Prepare test input data (text sequences).""" + tokenizer = self.get_tokenizer() + test_texts = [ + "Unless required by applicable law or agreed to in writing, software distributed under the License.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt.", + "The quick brown fox jumps over the lazy dog.", + ] + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + pad_to_multiple_of=pad_to_multiple_of, + mlm=False, + ) + + if format == "thd": + data_collator = DataCollatorWithFlattening( + collator=data_collator, + pad_sequences_to_be_divisible_by=pad_to_multiple_of, + separator_id=-100, + ) + + batch = data_collator([tokenizer(text) for text in test_texts]) + return {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + + def get_hf_to_te_converter(self) -> Callable: + """Return the HF to TE conversion function.""" + return convert_qwen3_hf_to_te + + def get_te_to_hf_converter(self) -> Callable: + """Return the TE to HF conversion function.""" + return convert_qwen3_te_to_hf + + def get_tolerances(self) -> TestTolerances: + """Return Qwen3-specific test tolerances.""" + return TestTolerances( + golden_value_loss_atol=0.05, + golden_value_loss_rtol=0.02, + golden_value_logits_atol=2.0, + golden_value_logits_rtol=0.01, + cp_loss_atol=0.5, + cp_loss_rtol=0.25, + ) + + # ==================== Qwen3 Overrides ==================== + + @pytest.mark.parametrize("tie_word_embeddings", [True, False]) + def 
test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format, tie_word_embeddings): + """Test FP8 forward and backward pass with both tied and untied word embeddings.""" + super().test_quantized_model_init_forward_and_backward( + fp8_recipe, input_format, tie_word_embeddings=tie_word_embeddings + ) + + # ==================== Qwen3-Specific Overrides ==================== + + def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1): + """Create HFInferenceParams for the given config. + + Uses config.head_dim (not hidden_size // num_attention_heads) since Qwen3 + has independently configured head_dim. + """ + past_key_values = HFInferenceParams( + max_batch_size=batch_size * num_beams, + max_sequence_length=max_seq_len, + num_heads_kv=config.num_key_value_heads, + head_dim_k=config.head_dim, + dtype=torch.bfloat16, + qkv_format="thd", + max_ctx_len=max_seq_len, + ) + for layer_number in range(1, config.num_hidden_layers + 1): + past_key_values.allocate_memory(layer_number) + return past_key_values diff --git a/bionemo-recipes/models/qwen3/tests/test_te_qk_norm_dtype.py b/bionemo-recipes/models/qwen3/tests/test_te_qk_norm_dtype.py new file mode 100644 index 000000000..172875e38 --- /dev/null +++ b/bionemo-recipes/models/qwen3/tests/test_te_qk_norm_dtype.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal reproducer for TransformerEngine qk_norm dtype mismatch. + +When qk_norm_type='RMSNorm' is enabled and params_dtype=bfloat16, the RMSNorm +weight is created in float32. During forward, Q and K are cast to float32 by the +norm while V stays in bfloat16, causing DotProductAttention to fail with: + + AssertionError: Queries, keys and values must have the same data type! + +torch.autocast masks the issue by casting everything back to bfloat16. +""" + +import pytest +import torch +import transformer_engine.pytorch as te + + +@pytest.fixture +def qk_norm_layer(): + """Minimal TransformerLayer with qk_norm enabled.""" + return te.TransformerLayer( + hidden_size=64, + ffn_hidden_size=128, + num_attention_heads=4, + num_gqa_groups=2, + normalization="RMSNorm", + activation="swiglu", + bias=False, + attn_input_format="bshd", + self_attn_mask_type="causal", + qk_norm_type="RMSNorm", + qk_norm_before_rope=True, + hidden_dropout=0, + attention_dropout=0, + layer_number=1, + params_dtype=torch.bfloat16, + device="cuda", + ) + + +@pytest.fixture +def input_tensor(): + return torch.randn(1, 8, 64, dtype=torch.bfloat16, device="cuda") + + +@pytest.mark.xfail(reason="qk_norm RMSNorm casts Q/K to float32 while V stays bfloat16", strict=True) +def test_qk_norm_forward_without_autocast(qk_norm_layer, input_tensor): + """Forward pass without torch.autocast fails due to Q/K vs V dtype mismatch.""" + qk_norm_layer(input_tensor) + + +def test_qk_norm_forward_with_autocast(qk_norm_layer, input_tensor): + """Forward pass with torch.autocast works (masks the dtype issue).""" + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + out = qk_norm_layer(input_tensor) + assert out.dtype == torch.bfloat16 diff --git a/ci/scripts/check_copied_files.py b/ci/scripts/check_copied_files.py index 6e017b1fc..a4310e3fe 100755 --- a/ci/scripts/check_copied_files.py +++ 
b/ci/scripts/check_copied_files.py @@ -38,6 +38,7 @@ "bionemo-recipes/models/esm2/collator.py": [ "bionemo-recipes/models/llama3/collator.py", "bionemo-recipes/models/mixtral/collator.py", + "bionemo-recipes/models/qwen3/collator.py", "bionemo-recipes/recipes/esm2_native_te/collator.py", "bionemo-recipes/recipes/llama3_native_te/collator.py", "bionemo-recipes/recipes/esm2_peft_te/collator.py", @@ -46,6 +47,7 @@ "bionemo-recipes/models/amplify/src/amplify/state.py", "bionemo-recipes/models/llama3/state.py", "bionemo-recipes/models/mixtral/state.py", + "bionemo-recipes/models/qwen3/state.py", ], "bionemo-recipes/models/llama3/modeling_llama_te.py": [ "bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py", @@ -60,6 +62,7 @@ "bionemo-recipes/models/esm2/tests/common": [ "bionemo-recipes/models/llama3/tests/common", "bionemo-recipes/models/mixtral/tests/common", + "bionemo-recipes/models/qwen3/tests/common", ], } From 323b8a147b618994b5ead64c0df3741c41989423 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Sun, 1 Mar 2026 06:36:22 -0800 Subject: [PATCH 3/3] add qwen2.5 Signed-off-by: Peter St. 
John --- .../models/{qwen3 => qwen}/.ruff.toml | 0 .../models/{qwen3 => qwen}/collator.py | 0 bionemo-recipes/models/qwen/convert_qwen2.py | 250 ++++++++++ .../convert.py => qwen/convert_qwen3.py} | 0 .../models/{qwen3 => qwen}/export.py | 4 +- .../models/qwen/modeling_qwen2_te.py | 453 ++++++++++++++++++ .../{qwen3 => qwen}/modeling_qwen3_te.py | 0 .../models/{qwen3 => qwen}/requirements.txt | 0 .../models/{qwen3 => qwen}/state.py | 0 .../models/{qwen3 => qwen}/tests/__init__.py | 0 .../{qwen3 => qwen}/tests/common/README.md | 0 .../{qwen3 => qwen}/tests/common/__init__.py | 0 .../{qwen3 => qwen}/tests/common/fixtures.py | 0 .../tests/common/test_modeling_common.py | 0 .../models/{qwen3 => qwen}/tests/conftest.py | 0 .../qwen/tests/test_modeling_qwen2_te.py | 173 +++++++ .../tests/test_modeling_qwen3_te.py | 11 +- .../qwen3/tests/test_te_qk_norm_dtype.py | 70 --- 18 files changed, 888 insertions(+), 73 deletions(-) rename bionemo-recipes/models/{qwen3 => qwen}/.ruff.toml (100%) rename bionemo-recipes/models/{qwen3 => qwen}/collator.py (100%) create mode 100644 bionemo-recipes/models/qwen/convert_qwen2.py rename bionemo-recipes/models/{qwen3/convert.py => qwen/convert_qwen3.py} (100%) rename bionemo-recipes/models/{qwen3 => qwen}/export.py (95%) create mode 100644 bionemo-recipes/models/qwen/modeling_qwen2_te.py rename bionemo-recipes/models/{qwen3 => qwen}/modeling_qwen3_te.py (100%) rename bionemo-recipes/models/{qwen3 => qwen}/requirements.txt (100%) rename bionemo-recipes/models/{qwen3 => qwen}/state.py (100%) rename bionemo-recipes/models/{qwen3 => qwen}/tests/__init__.py (100%) rename bionemo-recipes/models/{qwen3 => qwen}/tests/common/README.md (100%) rename bionemo-recipes/models/{qwen3 => qwen}/tests/common/__init__.py (100%) rename bionemo-recipes/models/{qwen3 => qwen}/tests/common/fixtures.py (100%) rename bionemo-recipes/models/{qwen3 => qwen}/tests/common/test_modeling_common.py (100%) rename bionemo-recipes/models/{qwen3 => qwen}/tests/conftest.py 
(100%) create mode 100644 bionemo-recipes/models/qwen/tests/test_modeling_qwen2_te.py rename bionemo-recipes/models/{qwen3 => qwen}/tests/test_modeling_qwen3_te.py (92%) delete mode 100644 bionemo-recipes/models/qwen3/tests/test_te_qk_norm_dtype.py diff --git a/bionemo-recipes/models/qwen3/.ruff.toml b/bionemo-recipes/models/qwen/.ruff.toml similarity index 100% rename from bionemo-recipes/models/qwen3/.ruff.toml rename to bionemo-recipes/models/qwen/.ruff.toml diff --git a/bionemo-recipes/models/qwen3/collator.py b/bionemo-recipes/models/qwen/collator.py similarity index 100% rename from bionemo-recipes/models/qwen3/collator.py rename to bionemo-recipes/models/qwen/collator.py diff --git a/bionemo-recipes/models/qwen/convert_qwen2.py b/bionemo-recipes/models/qwen/convert_qwen2.py new file mode 100644 index 000000000..21afbb067 --- /dev/null +++ b/bionemo-recipes/models/qwen/convert_qwen2.py @@ -0,0 +1,250 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Conversion utilities between HuggingFace Qwen2 and TransformerEngine formats."""

import inspect

import torch
from transformers import Qwen2Config, Qwen2ForCausalLM

import state
from modeling_qwen2_te import NVQwen2Config, NVQwen2ForCausalLM


# HF parameter name -> TE parameter name, for parameters that map 1:1.
# "*" is a wildcard over the layer index; fused/merged parameters (QKV, fc1)
# are handled by the state_transform entries in the converter functions below.
mapping = {
    "model.embed_tokens.weight": "model.embed_tokens.weight",
    "model.layers.*.input_layernorm.weight": "model.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
    "model.layers.*.self_attn.o_proj.weight": "model.layers.*.self_attention.proj.weight",
    "model.layers.*.post_attention_layernorm.weight": "model.layers.*.layernorm_mlp.layer_norm_weight",
    "model.layers.*.mlp.down_proj.weight": "model.layers.*.layernorm_mlp.fc2_weight",
    "model.norm.weight": "model.norm.weight",
    "lm_head.weight": "lm_head.weight",
}

# Reverse mapping from TE to HF format by reversing the original mapping
reverse_mapping = {v: k for k, v in mapping.items()}


def _merge_qkv_bias(ctx: state.TransformCTX, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Merge separate q, k, v biases into interleave-concatenated qkv bias.

    TE's fused layernorm_qkv expects the bias laid out per GQA group as
    [q_heads_of_group, k_head, v_head] repeated for each group, then flattened.
    """
    target_config = ctx.target.config

    head_num = target_config.num_attention_heads
    num_query_groups = target_config.num_key_value_heads
    heads_per_group = head_num // num_query_groups
    head_size = target_config.hidden_size // head_num

    # Reshape flat biases to (heads, head_size) so rows can be interleaved.
    q = q.view(head_num, head_size)
    k = k.view(num_query_groups, head_size)
    v = v.view(num_query_groups, head_size)

    qkv_bias_l = []
    for i in range(num_query_groups):
        qkv_bias_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :])
        qkv_bias_l.append(k[i : i + 1, :])
        qkv_bias_l.append(v[i : i + 1, :])
    qkv_bias = torch.cat(qkv_bias_l)

    return qkv_bias.reshape(-1)


def _split_qkv_bias(ctx: state.TransformCTX, qkv_bias: torch.Tensor):
    """Split interleave-concatenated qkv bias into separate q, k, v biases.

    Inverse of `_merge_qkv_bias`: rows are grouped as (heads_per_group + 2)
    per GQA group, with the K row at offset heads_per_group and the V row
    immediately after it.
    """
    target_config = ctx.target.config

    head_num = target_config.num_attention_heads
    num_query_groups = target_config.num_key_value_heads
    heads_per_group = head_num // num_query_groups
    head_size = target_config.hidden_size // head_num
    qkv_total_dim = head_num + 2 * num_query_groups

    qkv_bias = qkv_bias.reshape(qkv_total_dim, head_size)
    # Q rows: the first heads_per_group rows of each (heads_per_group + 2)-row group.
    q_slice = torch.cat(
        [
            torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
            for i in range(num_query_groups)
        ]
    )
    k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
    v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

    q_bias = qkv_bias[q_slice].reshape(-1).cpu()
    k_bias = qkv_bias[k_slice].reshape(-1).cpu()
    v_bias = qkv_bias[v_slice].reshape(-1).cpu()

    return q_bias, k_bias, v_bias


def _zero_bias_from_weight(ctx: state.TransformCTX, weight: torch.Tensor):
    """Create a zero bias with dimension matching the weight's first axis."""
    return torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)


def _zero_fc1_bias(ctx: state.TransformCTX, gate: torch.Tensor, up: torch.Tensor):
    """Create a zero fc1 bias for the merged gate+up projection."""
    return torch.zeros(gate.shape[0] + up.shape[0], device=gate.device, dtype=gate.dtype)


def convert_qwen2_hf_to_te(model_hf: Qwen2ForCausalLM, **config_kwargs) -> NVQwen2ForCausalLM:
    """Convert a Hugging Face Qwen2 model to a Transformer Engine model.

    Args:
        model_hf (nn.Module): The Hugging Face model.
        **config_kwargs: Additional configuration kwargs to be passed to NVQwen2Config.

    Returns:
        nn.Module: The Transformer Engine model.
    """
    config_dict = model_hf.config.to_dict()
    # Ensure layer_types is consistent with num_hidden_layers (from_pretrained can leave stale layer_types)
    if len(config_dict.get("layer_types", [])) != config_dict.get("num_hidden_layers", 0):
        config_dict["layer_types"] = config_dict["layer_types"][: config_dict["num_hidden_layers"]]
    te_config = NVQwen2Config(**config_dict, **config_kwargs)
    # Build on the meta device; apply_transforms materializes the real parameters.
    with torch.device("meta"):
        model_te = NVQwen2ForCausalLM(te_config)

    if model_hf.config.tie_word_embeddings:
        # lm_head shares storage with embed_tokens, so it has no independent entry to copy.
        state_dict_ignored_entries = ["lm_head.weight"]
    else:
        state_dict_ignored_entries = []

    output_model = state.apply_transforms(
        model_hf,
        model_te,
        mapping,
        [
            # Merge Q/K/V weights into fused QKV
            state.state_transform(
                source_key=(
                    "model.layers.*.self_attn.q_proj.weight",
                    "model.layers.*.self_attn.k_proj.weight",
                    "model.layers.*.self_attn.v_proj.weight",
                ),
                target_key="model.layers.*.self_attention.layernorm_qkv.weight",
                fn=state.TransformFns.merge_qkv,
            ),
            # Merge Q/K/V biases into fused QKV bias
            state.state_transform(
                source_key=(
                    "model.layers.*.self_attn.q_proj.bias",
                    "model.layers.*.self_attn.k_proj.bias",
                    "model.layers.*.self_attn.v_proj.bias",
                ),
                target_key="model.layers.*.self_attention.layernorm_qkv.bias",
                fn=_merge_qkv_bias,
            ),
            # Merge gate/up projections into fc1
            state.state_transform(
                source_key=(
                    "model.layers.*.mlp.gate_proj.weight",
                    "model.layers.*.mlp.up_proj.weight",
                ),
                target_key="model.layers.*.layernorm_mlp.fc1_weight",
                fn=state.TransformFns.merge_fc1,
            ),
            # TE bias=True creates biases for all linear layers, but Qwen2 only has bias on QKV.
            # Initialize the extra TE biases (output projection, MLP) to zero.
            state.state_transform(
                source_key="model.layers.*.self_attn.o_proj.weight",
                target_key="model.layers.*.self_attention.proj.bias",
                fn=_zero_bias_from_weight,
            ),
            state.state_transform(
                source_key=(
                    "model.layers.*.mlp.gate_proj.weight",
                    "model.layers.*.mlp.up_proj.weight",
                ),
                target_key="model.layers.*.layernorm_mlp.fc1_bias",
                fn=_zero_fc1_bias,
            ),
            state.state_transform(
                source_key="model.layers.*.mlp.down_proj.weight",
                target_key="model.layers.*.layernorm_mlp.fc2_bias",
                fn=_zero_bias_from_weight,
            ),
        ],
        state_dict_ignored_entries=state_dict_ignored_entries,
    )

    # inv_freq is a computed buffer, not part of the state dict; copy it explicitly.
    output_model.model.rotary_emb.inv_freq = model_hf.model.rotary_emb.inv_freq.clone()

    return output_model


def convert_qwen2_te_to_hf(model_te: NVQwen2ForCausalLM, **config_kwargs) -> Qwen2ForCausalLM:
    """Convert a Transformer Engine Qwen2 model to a Hugging Face model.

    Args:
        model_te (nn.Module): The Transformer Engine model.
        **config_kwargs: Additional configuration kwargs to be passed to Qwen2Config.

    Returns:
        nn.Module: The Hugging Face model.
    """
    # Filter out keys from model_te.config that are not valid Qwen2Config attributes
    te_config_dict = model_te.config.to_dict()
    valid_keys = set(inspect.signature(Qwen2Config.__init__).parameters)
    filtered_config = {k: v for k, v in te_config_dict.items() if k in valid_keys}
    # Ensure layer_types is consistent with num_hidden_layers
    if len(filtered_config.get("layer_types", [])) != filtered_config.get("num_hidden_layers", 0):
        filtered_config["layer_types"] = filtered_config["layer_types"][: filtered_config["num_hidden_layers"]]
    hf_config = Qwen2Config(**filtered_config, **config_kwargs)

    with torch.device("meta"):
        model_hf = Qwen2ForCausalLM(hf_config)

    # NOTE(review): the zero TE biases added in hf_to_te (proj/fc1/fc2) have no HF
    # counterpart and are presumably dropped here by apply_transforms — confirm.
    output_model = state.apply_transforms(
        model_te,
        model_hf,
        reverse_mapping,
        [
            # Split fused QKV weight into separate Q/K/V
            state.state_transform(
                source_key="model.layers.*.self_attention.layernorm_qkv.weight",
                target_key=(
                    "model.layers.*.self_attn.q_proj.weight",
                    "model.layers.*.self_attn.k_proj.weight",
                    "model.layers.*.self_attn.v_proj.weight",
                ),
                fn=state.TransformFns.split_qkv,
            ),
            # Split fused QKV bias into separate Q/K/V biases
            state.state_transform(
                source_key="model.layers.*.self_attention.layernorm_qkv.bias",
                target_key=(
                    "model.layers.*.self_attn.q_proj.bias",
                    "model.layers.*.self_attn.k_proj.bias",
                    "model.layers.*.self_attn.v_proj.bias",
                ),
                fn=_split_qkv_bias,
            ),
            # Split fc1 into gate/up projections
            state.state_transform(
                source_key="model.layers.*.layernorm_mlp.fc1_weight",
                target_key=(
                    "model.layers.*.mlp.gate_proj.weight",
                    "model.layers.*.mlp.up_proj.weight",
                ),
                fn=state.TransformFns.split_fc1,
            ),
        ],
        state_dict_ignored_entries=model_hf._tied_weights_keys,
    )

    output_model.model.rotary_emb.inv_freq = model_te.model.rotary_emb.inv_freq.clone()
    # Re-establish lm_head <-> embed_tokens sharing if the config ties them.
    output_model.tie_weights()

    return output_model
index 100% rename from bionemo-recipes/models/qwen3/convert.py rename to bionemo-recipes/models/qwen/convert_qwen3.py diff --git a/bionemo-recipes/models/qwen3/export.py b/bionemo-recipes/models/qwen/export.py similarity index 95% rename from bionemo-recipes/models/qwen3/export.py rename to bionemo-recipes/models/qwen/export.py index 944013408..bda3aafe9 100644 --- a/bionemo-recipes/models/qwen3/export.py +++ b/bionemo-recipes/models/qwen/export.py @@ -24,7 +24,7 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer -import convert +import convert_qwen3 from modeling_qwen3_te import AUTO_MAP @@ -38,7 +38,7 @@ def export_hf_checkpoint(tag: str, export_path: Path): model_hf = AutoConfig.from_pretrained(tag) model_hf = AutoModelForCausalLM.from_config(model_hf) - model_te = convert.convert_qwen3_hf_to_te(model_hf) + model_te = convert_qwen3.convert_qwen3_hf_to_te(model_hf) model_te.save_pretrained(export_path) tokenizer = AutoTokenizer.from_pretrained(tag) diff --git a/bionemo-recipes/models/qwen/modeling_qwen2_te.py b/bionemo-recipes/models/qwen/modeling_qwen2_te.py new file mode 100644 index 000000000..e76ba8d25 --- /dev/null +++ b/bionemo-recipes/models/qwen/modeling_qwen2_te.py @@ -0,0 +1,453 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""TransformerEngine-optimized Qwen2 model."""

from collections import OrderedDict
from typing import ClassVar, Unpack

import torch
import torch.nn as nn
import transformer_engine.pytorch
import transformers
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention.inference import PagedKVCacheManager
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
from transformers import PreTrainedModel, Qwen2Config
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.qwen2.modeling_qwen2 import Qwen2RotaryEmbedding
from transformers.utils.generic import TransformersKwargs


# Registration map so AutoModel / AutoConfig can resolve these classes from a
# checkpoint saved with save_pretrained.
AUTO_MAP = {
    "AutoConfig": "modeling_qwen2_te.NVQwen2Config",
    "AutoModel": "modeling_qwen2_te.NVQwen2Model",
    "AutoModelForCausalLM": "modeling_qwen2_te.NVQwen2ForCausalLM",
}


class NVQwen2Config(Qwen2Config):
    """NVQwen2 configuration."""

    # Attention input format:
    # "bshd" = Batch, Sequence, Head, Dimension (standard padded format)
    # "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
    attn_input_format: str = "thd"
    self_attn_mask_type: str = "padding_causal"


class NVQwen2PreTrainedModel(PreTrainedModel):
    """Base class for NVQwen2 models."""

    config_class = NVQwen2Config
    base_model_prefix = "model"
    _no_split_modules = ("TransformerLayer",)
    _skip_keys_device_placement = ("past_key_values",)

    def init_empty_weights(self):
        """Handles moving the model from the meta device to the cuda device and initializing the weights."""
        # For TE layers, calling `reset_parameters` is sufficient to move them to the cuda device and apply the weight
        # initialization we passed them during module creation.
        for module in self.modules():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()

        # The embed_tokens layer is the only non-TE layer in this model we need to deal with. We use
        # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard
        # deviation.
        self.model.embed_tokens.to_empty(device="cuda")
        self.model.embed_tokens.apply(self._init_weights)

        # Recompute inv_freq on a real device; the meta-device value is unusable.
        self.model.rotary_emb.inv_freq = Qwen2RotaryEmbedding(config=self.model.config).inv_freq.to("cuda")

        # Meta-device init seems to break weight tying, so we re-tie the weights here.
        self.tie_weights()

    def _init_weights(self, module):
        """Initialize module weights.

        We only use this method for standard pytorch modules, TE modules handle their own weight initialization through
        `init_method` parameters and the `reset_parameters` method.
        """
        if module.__module__.startswith("transformer_engine.pytorch"):
            # Notably, we need to avoid calling this method for TE modules, since the default _init_weights will assume
            # any class with `LayerNorm` in the name should have weights initialized to 1.0; breaking `LayerNormLinear`
            # and `LayerNormMLP` modules that use `weight` for the linear layer and `layer_norm_weight` for the layer
            # norm.
            return

        super()._init_weights(module)

    def state_dict(self, *args, **kwargs):
        """Override state_dict to filter out TransformerEngine's _extra_state keys.

        TransformerEngine layers add _extra_state attributes that are not compatible with
        standard PyTorch/HuggingFace model loading. These are filtered out to ensure
        checkpoints can be loaded with from_pretrained().
        """
        state_dict = super().state_dict(*args, **kwargs)
        # Filter out _extra_state keys which are TransformerEngine-specific and not loadable
        return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}


class NVQwen2Model(NVQwen2PreTrainedModel):
    """Qwen2 model implemented in Transformer Engine."""

    def __init__(self, config: Qwen2Config):
        """Initialize the NVQwen2 model."""
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Qwen2 derives head_dim from hidden_size (unlike Qwen3's independent head_dim).
        head_dim = config.hidden_size // config.num_attention_heads

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)

        def _init_method(x):
            torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)

        self.layers = nn.ModuleList(
            [
                transformer_engine.pytorch.TransformerLayer(
                    hidden_size=config.hidden_size,
                    ffn_hidden_size=config.intermediate_size,
                    num_attention_heads=config.num_attention_heads,
                    # Qwen2 has bias only on QKV; TE bias=True adds biases everywhere, so the
                    # checkpoint converter zero-fills the extra proj/fc1/fc2 biases.
                    bias=True,
                    layernorm_epsilon=config.rms_norm_eps,
                    hidden_dropout=0,
                    attention_dropout=0,
                    fuse_qkv_params=True,
                    qkv_weight_interleaved=True,
                    normalization="RMSNorm",
                    activation="swiglu",
                    attn_input_format=config.attn_input_format,
                    self_attn_mask_type=config.self_attn_mask_type,
                    num_gqa_groups=config.num_key_value_heads,
                    kv_channels=head_dim,
                    # NOTE(review): TE window_size is (left, right); a causal sliding window is
                    # usually (sliding_window - 1, 0). Confirm the symmetric tuple is intended.
                    window_size=(config.sliding_window, config.sliding_window)
                    if config.layer_types[layer_idx] == "sliding_attention" and config.sliding_window is not None
                    else None,
                    layer_number=layer_idx + 1,
                    params_dtype=config.dtype,
                    device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
                    init_method=_init_method,
                    output_layer_init_method=_init_method,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = transformer_engine.pytorch.RMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
            dtype=config.dtype,
            device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
        )

        # We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
        # Qwen2RotaryEmbedding.
        self.rotary_emb = RotaryPositionEmbedding(head_dim)
        self.rotary_emb.inv_freq = Qwen2RotaryEmbedding(config=config).inv_freq

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_values: InferenceParams | None = None,
        inputs_embeds: torch.Tensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        """Forward pass for the NVQwen2 model.

        Args:
            input_ids (torch.Tensor): The input ids.
            attention_mask (torch.Tensor): The attention mask (HF convention: 1=attend, 0=pad).
            position_ids (torch.Tensor): The position ids.
            past_key_values (InferenceParams): TE inference params used as the KV cache.
            inputs_embeds (torch.Tensor): The inputs embeds.
            use_cache (bool): Whether to use cache.
            **kwargs: Additional keyword arguments, including THD packing metadata
                (cu_seq_lens_q/k, max_length_q/k) when the caller packs inputs itself.

        Returns:
            BaseModelOutputWithPast: The output of the model.
        """
        all_hidden_states = []
        output_hidden_states = kwargs.get("output_hidden_states", False)

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # TE-specific input handling.
        has_thd_input = [x in kwargs for x in ["cu_seq_lens_q", "cu_seq_lens_k", "max_length_q", "max_length_k"]]
        should_pack_inputs = not any(has_thd_input) and self.config.attn_input_format == "thd"

        if should_pack_inputs:
            # Left-side padding is not supported in TE layers, so to make huggingface-style generation work with TE we
            # dynamically convert to THD-style inputs in our forward pass, and then convert back to BSHD for the output.
            # This lets the entire transformer stack run in THD mode. This might be slower for BSHD + padding with fused
            # attention backend, but it should be faster for the flash attention backend.
            assert attention_mask is not None, "Attention mask is required when packing BSHD inputs."
            batch_size = hidden_states.size(0)
            # NOTE(review): assumes input_ids is provided on this path; an inputs_embeds-only
            # call with packing enabled would fail here — confirm against callers.
            padded_seq_len = input_ids.size(1)
            hidden_states, indices, cu_seqlens, max_seqlen, _ = _unpad_input(hidden_states, attention_mask)
            kwargs["cu_seq_lens_q"] = kwargs["cu_seq_lens_k"] = cu_seqlens
            kwargs["max_length_q"] = kwargs["max_length_k"] = max_seqlen

        if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1:
            # For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE
            # expects a 2-dimensional tensor with shape [total_tokens, hidden_size].
            hidden_states = hidden_states.squeeze(0)

        if self.config.attn_input_format == "bshd" and attention_mask is not None and attention_mask.dim() == 2:
            # Convert HF mask (1=attend, 0=pad) to TE boolean mask (True=masked, False=attend)
            attention_mask = ~attention_mask[:, None, None, :].bool()

        if isinstance(past_key_values, InferenceParams):  # InferenceParams is TE's way of managing kv-caching.
            # In generation mode, we set the length to 1 for each batch index. Otherwise, we use the attention mask to
            # compute the lengths of each sequence in the batch.
            # NOTE(review): assumes attention_mask is not None whenever a KV cache is passed — confirm.
            lengths = (
                attention_mask.sum(dim=1).tolist()
                if attention_mask.shape == input_ids.shape
                else [1] * input_ids.shape[0]
            )
            past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths)))

        # Ensure that rotary embeddings are computed with at a higher precision
        with torch.autocast(device_type="cuda", enabled=False):
            te_rope_emb = self.rotary_emb(max_seq_len=self.config.max_position_embeddings)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states = (*all_hidden_states, hidden_states)

            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
                rotary_pos_emb=te_rope_emb,
                inference_params=past_key_values,
                cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
                cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
                cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
                cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
                max_seqlen_q=kwargs.get("max_length_q", None),
                max_seqlen_kv=kwargs.get("max_length_k", None),
                pad_between_seqs=kwargs.get("pad_between_seqs", None),
            )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer. Note that these will be in THD format; we could possibly pad
        # these with the same _pad_input call as below if we wanted them returned in BSHD format.
        if output_hidden_states:
            all_hidden_states = (*all_hidden_states, hidden_states)

        if should_pack_inputs:
            # If we've converted BSHD to THD for our TE layers, we need to convert back to BSHD for the output.
            hidden_states = _pad_input(hidden_states, indices, batch_size, padded_seq_len)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states if output_hidden_states else None,
        )


class NVQwen2ForCausalLM(NVQwen2PreTrainedModel, transformers.GenerationMixin):
    """Qwen2 model with causal language head."""

    _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config):
        """Initialize the NVQwen2ForCausalLM model."""
        super().__init__(config)
        self.model = NVQwen2Model(config)
        self.vocab_size = config.vocab_size
        # Keep the LM head out of FP8 quantization; logits are computed in higher precision.
        with transformer_engine.pytorch.quantized_model_init(enabled=False):
            self.lm_head = transformer_engine.pytorch.Linear(
                config.hidden_size,
                config.vocab_size,
                bias=False,
                params_dtype=config.dtype,
                device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
                init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
            )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_values: tuple[tuple[torch.Tensor, ...], ...] | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        shift_labels: torch.Tensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.Tensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        """Forward pass for the NVQwen2ForCausalLM model.

        Args:
            input_ids (torch.Tensor): The input ids.
            attention_mask (torch.Tensor): The attention mask.
            position_ids (torch.Tensor): The position ids.
            past_key_values (tuple[tuple[torch.Tensor, ...], ...]): The past key values.
            inputs_embeds (torch.Tensor): The inputs embeds.
            labels (torch.Tensor): The labels.
            shift_labels (torch.Tensor): Labels that have already been shifted by the dataloader, to be used instead of
                labels for the loss function. For context parallelism, it is more reliable to shift the labels before
                splitting the batch into shards.
            use_cache (bool): Whether to use cache.
            cache_position (torch.Tensor): The cache position.
            logits_to_keep (int | torch.Tensor): Whether to keep only the last logits to reduce the memory footprint of
                the model during generation.
            **kwargs: Additional keyword arguments.

        Returns:
            CausalLMOutputWithPast: The output of the model.
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep

        # Run the LM head outside any active FP8 autocast region.
        with transformer_engine.pytorch.autocast(enabled=False):
            if hidden_states.ndim == 3:
                logits = self.lm_head(hidden_states[:, slice_indices, :])
            else:  # With THD inputs, batch and sequence dimensions are collapsed in the first dimension.
                logits = self.lm_head(hidden_states[slice_indices, :])

        loss = None
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, shift_labels=shift_labels, vocab_size=self.config.vocab_size, **kwargs
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# Required so torch.compile can trace the .item()/nonzero shapes in the pad/unpad helpers.
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile
def _pad_input(hidden_states, indices, batch, seqlen):
    """Convert a THD tensor to a BSHD equivalent tensor.

    Adapted from huggingface/transformers/modeling_flash_attention_utils.py

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.

    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[1:]
    # Scatter the packed tokens back into a zero-padded flat buffer, then reshape to BSHD.
    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
    output[indices] = hidden_states
    return output.view(batch, seqlen, *dim)


@torch.compile
def _unpad_input(hidden_states, attention_mask, unused_mask=None):
    """Convert a BSHD tensor to a THD equivalent tensor.

    Adapted from huggingface/transformers/modeling_flash_attention_utils.py

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.

    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
    """
    batch_size = hidden_states.size(0)
    seq_length = hidden_states.size(1)

    if attention_mask.shape[1] != seq_length:  # Likely in generation mode with kv-caching
        # Single-token decode step: every sequence contributes exactly one token.
        return (
            hidden_states.squeeze(1),  # hidden_states
            torch.arange(batch_size, dtype=torch.int64, device=hidden_states.device),  # indices
            torch.arange(batch_size + 1, dtype=torch.int32, device=hidden_states.device),  # cu_seqlens
            1,  # max_seqlen
            1,  # seqused
        )

    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

    return (
        hidden_states.reshape(-1, *hidden_states.shape[2:])[indices],
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )


class HFInferenceParams(InferenceParams):
    """Extension of the InferenceParams class to support beam search."""

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorder the cache based on the beam indices."""
        if isinstance(self.cache_manager, PagedKVCacheManager):
            raise NotImplementedError("Beam search is not supported for paged cache manager.")
        for layer_number, (key_cache, value_cache) in self.cache_manager.cache.items():
            updated_key_cache = key_cache.index_select(0, beam_idx)
            updated_value_cache = value_cache.index_select(0, beam_idx)
            self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache)
b/bionemo-recipes/models/qwen/modeling_qwen3_te.py similarity index 100% rename from bionemo-recipes/models/qwen3/modeling_qwen3_te.py rename to bionemo-recipes/models/qwen/modeling_qwen3_te.py diff --git a/bionemo-recipes/models/qwen3/requirements.txt b/bionemo-recipes/models/qwen/requirements.txt similarity index 100% rename from bionemo-recipes/models/qwen3/requirements.txt rename to bionemo-recipes/models/qwen/requirements.txt diff --git a/bionemo-recipes/models/qwen3/state.py b/bionemo-recipes/models/qwen/state.py similarity index 100% rename from bionemo-recipes/models/qwen3/state.py rename to bionemo-recipes/models/qwen/state.py diff --git a/bionemo-recipes/models/qwen3/tests/__init__.py b/bionemo-recipes/models/qwen/tests/__init__.py similarity index 100% rename from bionemo-recipes/models/qwen3/tests/__init__.py rename to bionemo-recipes/models/qwen/tests/__init__.py diff --git a/bionemo-recipes/models/qwen3/tests/common/README.md b/bionemo-recipes/models/qwen/tests/common/README.md similarity index 100% rename from bionemo-recipes/models/qwen3/tests/common/README.md rename to bionemo-recipes/models/qwen/tests/common/README.md diff --git a/bionemo-recipes/models/qwen3/tests/common/__init__.py b/bionemo-recipes/models/qwen/tests/common/__init__.py similarity index 100% rename from bionemo-recipes/models/qwen3/tests/common/__init__.py rename to bionemo-recipes/models/qwen/tests/common/__init__.py diff --git a/bionemo-recipes/models/qwen3/tests/common/fixtures.py b/bionemo-recipes/models/qwen/tests/common/fixtures.py similarity index 100% rename from bionemo-recipes/models/qwen3/tests/common/fixtures.py rename to bionemo-recipes/models/qwen/tests/common/fixtures.py diff --git a/bionemo-recipes/models/qwen3/tests/common/test_modeling_common.py b/bionemo-recipes/models/qwen/tests/common/test_modeling_common.py similarity index 100% rename from bionemo-recipes/models/qwen3/tests/common/test_modeling_common.py rename to 
bionemo-recipes/models/qwen/tests/common/test_modeling_common.py diff --git a/bionemo-recipes/models/qwen3/tests/conftest.py b/bionemo-recipes/models/qwen/tests/conftest.py similarity index 100% rename from bionemo-recipes/models/qwen3/tests/conftest.py rename to bionemo-recipes/models/qwen/tests/conftest.py diff --git a/bionemo-recipes/models/qwen/tests/test_modeling_qwen2_te.py b/bionemo-recipes/models/qwen/tests/test_modeling_qwen2_te.py new file mode 100644 index 000000000..3850ccf97 --- /dev/null +++ b/bionemo-recipes/models/qwen/tests/test_modeling_qwen2_te.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Qwen2 model. 
+ +This file provides comprehensive tests for the Qwen2 model including: +- Common tests from the test library (meta device init, golden values, conversion, FP8) +- Qwen2-specific tests (inference, generation with KV-cache) +""" + +import os +from typing import Callable, Dict, List, Literal, Type + +import pytest +import torch +import transformers +from torch import nn +from transformers import ( + AutoTokenizer, + DataCollatorForLanguageModeling, + PretrainedConfig, + PreTrainedModel, + PreTrainedTokenizer, +) + +from collator import DataCollatorWithFlattening +from convert_qwen2 import convert_qwen2_hf_to_te, convert_qwen2_te_to_hf +from modeling_qwen2_te import HFInferenceParams, NVQwen2Config, NVQwen2ForCausalLM +from tests.common import BaseModelTest, TestTolerances + + +class TestQwen2Model(BaseModelTest): + """Model tester for Qwen2. + + This class provides Qwen2-specific configuration for the common test suite. + """ + + is_autoregressive = True + + def get_model_class(self) -> Type[PreTrainedModel]: + """Return the Qwen2 TE model class.""" + return NVQwen2ForCausalLM + + def get_config_class(self) -> Type[PretrainedConfig]: + """Return the Qwen2 config class.""" + return NVQwen2Config + + def get_upstream_model_id(self) -> str: + """Return the upstream HuggingFace model ID.""" + return "Qwen/Qwen2.5-0.5B" + + def get_upstream_model_revision(self) -> str: + """Return the specific revision for the upstream model.""" + return "060db64" + + def get_tokenizer(self) -> PreTrainedTokenizer: + """Return the Qwen2 tokenizer.""" + tokenizer = AutoTokenizer.from_pretrained(self.get_upstream_model_id()) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + def get_upstream_model_class(self) -> Type[PreTrainedModel]: + """Return the upstream HuggingFace model class.""" + return transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM + + def create_test_config(self, **kwargs) -> PretrainedConfig: + # Limit the number of 
hidden layers to 2 for faster tests. + return super().create_test_config(num_hidden_layers=2, **kwargs) + + def get_reference_model( + self, dtype: torch.dtype = torch.bfloat16, attn_implementation: str = "flash_attention_2" + ) -> PreTrainedModel: + """Return the reference HuggingFace model.""" + if os.environ.get("CI") == "true": + pytest.skip("Skipping Qwen2 reference model test in CI, requires Qwen2.5-0.5B download ~1GB") + return super().get_reference_model(dtype=dtype, attn_implementation=attn_implementation) + + def get_layer_path(self, model: PreTrainedModel) -> List[nn.Module]: + """Return the list of transformer layers.""" + return list(model.model.layers) # type: ignore + + def get_test_input_data( + self, format: Literal["bshd", "thd"] = "bshd", pad_to_multiple_of: int | None = None + ) -> Dict[str, torch.Tensor]: + """Prepare test input data (text sequences).""" + tokenizer = self.get_tokenizer() + test_texts = [ + "Unless required by applicable law or agreed to in writing, software distributed under the License.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt.", + "The quick brown fox jumps over the lazy dog.", + ] + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + pad_to_multiple_of=pad_to_multiple_of, + mlm=False, + ) + + if format == "thd": + data_collator = DataCollatorWithFlattening( + collator=data_collator, + pad_sequences_to_be_divisible_by=pad_to_multiple_of, + separator_id=-100, + ) + + batch = data_collator([tokenizer(text) for text in test_texts]) + return {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + + def get_hf_to_te_converter(self) -> Callable: + """Return the HF to TE conversion function.""" + return convert_qwen2_hf_to_te + + def get_te_to_hf_converter(self) -> Callable: + """Return the TE to HF conversion function.""" + return 
convert_qwen2_te_to_hf
+
+    def get_tolerances(self) -> TestTolerances:
+        """Return Qwen2-specific test tolerances."""
+        return TestTolerances(
+            golden_value_loss_atol=0.05,
+            golden_value_loss_rtol=0.02,
+            golden_value_logits_atol=2.0,
+            golden_value_logits_rtol=0.01,
+            cp_loss_atol=0.5,
+            cp_loss_rtol=0.25,
+        )
+
+    # ==================== Qwen2 Overrides ====================
+
+    @pytest.mark.parametrize("tie_word_embeddings", [True, False])
+    def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format, tie_word_embeddings):
+        """Test FP8 forward and backward pass with both tied and untied word embeddings."""
+        super().test_quantized_model_init_forward_and_backward(
+            fp8_recipe, input_format, tie_word_embeddings=tie_word_embeddings
+        )
+
+    # ==================== Generation Test Support ====================
+
+    def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
+        """Create HFInferenceParams for the given config.
+
+        Uses hidden_size // num_attention_heads for head_dim since Qwen2 does not
+        independently configure head_dim.
+ """ + head_dim = config.hidden_size // config.num_attention_heads + past_key_values = HFInferenceParams( + max_batch_size=batch_size * num_beams, + max_sequence_length=max_seq_len, + num_heads_kv=config.num_key_value_heads, + head_dim_k=head_dim, + dtype=torch.bfloat16, + qkv_format="thd", + max_ctx_len=max_seq_len, + ) + for layer_number in range(1, config.num_hidden_layers + 1): + past_key_values.allocate_memory(layer_number) + return past_key_values diff --git a/bionemo-recipes/models/qwen3/tests/test_modeling_qwen3_te.py b/bionemo-recipes/models/qwen/tests/test_modeling_qwen3_te.py similarity index 92% rename from bionemo-recipes/models/qwen3/tests/test_modeling_qwen3_te.py rename to bionemo-recipes/models/qwen/tests/test_modeling_qwen3_te.py index bfe7b9235..89d2ffa60 100644 --- a/bionemo-recipes/models/qwen3/tests/test_modeling_qwen3_te.py +++ b/bionemo-recipes/models/qwen/tests/test_modeling_qwen3_te.py @@ -20,6 +20,7 @@ - Qwen3-specific tests (inference, generation with KV-cache) """ +import os from typing import Callable, Dict, List, Literal, Type import pytest @@ -35,7 +36,7 @@ ) from collator import DataCollatorWithFlattening -from convert import convert_qwen3_hf_to_te, convert_qwen3_te_to_hf +from convert_qwen3 import convert_qwen3_hf_to_te, convert_qwen3_te_to_hf from modeling_qwen3_te import HFInferenceParams, NVQwen3Config, NVQwen3ForCausalLM from tests.common import BaseModelTest, TestTolerances @@ -84,6 +85,14 @@ def get_layer_path(self, model: PreTrainedModel) -> List[nn.Module]: """Return the list of transformer layers.""" return list(model.model.layers) # type: ignore + def get_reference_model( + self, dtype: torch.dtype = torch.bfloat16, attn_implementation: str = "flash_attention_2" + ) -> PreTrainedModel: + """Return the reference HuggingFace model.""" + if os.environ.get("CI") == "true": + pytest.skip("Skipping Qwen3 reference model test in CI, requires Qwen3-0.6B download ~1.5GB") + return super().get_reference_model(dtype=dtype, 
attn_implementation=attn_implementation) + def get_test_input_data( self, format: Literal["bshd", "thd"] = "bshd", pad_to_multiple_of: int | None = None ) -> Dict[str, torch.Tensor]: diff --git a/bionemo-recipes/models/qwen3/tests/test_te_qk_norm_dtype.py b/bionemo-recipes/models/qwen3/tests/test_te_qk_norm_dtype.py deleted file mode 100644 index 172875e38..000000000 --- a/bionemo-recipes/models/qwen3/tests/test_te_qk_norm_dtype.py +++ /dev/null @@ -1,70 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Minimal reproducer for TransformerEngine qk_norm dtype mismatch. - -When qk_norm_type='RMSNorm' is enabled and params_dtype=bfloat16, the RMSNorm -weight is created in float32. During forward, Q and K are cast to float32 by the -norm while V stays in bfloat16, causing DotProductAttention to fail with: - - AssertionError: Queries, keys and values must have the same data type! - -torch.autocast masks the issue by casting everything back to bfloat16. 
-""" - -import pytest -import torch -import transformer_engine.pytorch as te - - -@pytest.fixture -def qk_norm_layer(): - """Minimal TransformerLayer with qk_norm enabled.""" - return te.TransformerLayer( - hidden_size=64, - ffn_hidden_size=128, - num_attention_heads=4, - num_gqa_groups=2, - normalization="RMSNorm", - activation="swiglu", - bias=False, - attn_input_format="bshd", - self_attn_mask_type="causal", - qk_norm_type="RMSNorm", - qk_norm_before_rope=True, - hidden_dropout=0, - attention_dropout=0, - layer_number=1, - params_dtype=torch.bfloat16, - device="cuda", - ) - - -@pytest.fixture -def input_tensor(): - return torch.randn(1, 8, 64, dtype=torch.bfloat16, device="cuda") - - -@pytest.mark.xfail(reason="qk_norm RMSNorm casts Q/K to float32 while V stays bfloat16", strict=True) -def test_qk_norm_forward_without_autocast(qk_norm_layer, input_tensor): - """Forward pass without torch.autocast fails due to Q/K vs V dtype mismatch.""" - qk_norm_layer(input_tensor) - - -def test_qk_norm_forward_with_autocast(qk_norm_layer, input_tensor): - """Forward pass with torch.autocast works (masks the dtype issue).""" - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - out = qk_norm_layer(input_tensor) - assert out.dtype == torch.bfloat16