diff --git a/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py b/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py index aca85b7855..d04c9e5d24 100644 --- a/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py +++ b/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py @@ -20,7 +20,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Callable, Dict, List, Literal, Type +from typing import Any, Callable, Dict, List, Literal, Type import pytest import torch @@ -82,6 +82,10 @@ class BaseModelTest(ABC): Subclasses must implement all abstract methods to provide model-specific configuration, data preparation, and conversion functions. + Set ``is_autoregressive = True`` in subclasses for causal LM models to + enable generation / KV-cache smoke tests. Non-autoregressive models + (e.g. ESM2) leave the default ``False`` and those tests are skipped. + Example: ```python class ESM2ModelTester(BioNeMoModelTester): @@ -98,6 +102,8 @@ def get_upstream_model_id(self): ``` """ + is_autoregressive: bool = False + @abstractmethod def get_model_class(self) -> Type[PreTrainedModel]: """Return the TransformerEngine model class to test. @@ -885,13 +891,15 @@ def test_fp8_forward_and_backward_pass(self, fp8_recipe, input_format): msg=lambda x: f"FP8 loss differs too much from BF16 loss: {x}", ) - def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format): + def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format, **config_kwargs): """Test that model initialized with FP8 works correctly.""" if input_format == "thd" and not HAS_DATA_CENTER_GPU: pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.") model_class = self.get_model_class() - config = self.create_test_config(attn_input_format=input_format, self_attn_mask_type="padding_causal") + config = self.create_test_config( + attn_input_format=input_format, self_attn_mask_type="padding_causal", **config_kwargs + ) # Initialize with FP8 with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe): @@ -906,9 +914,8 @@ def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_forma input_data["labels"] = input_data["input_ids"].clone() # Forward and backward pass with FP8 - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - with transformer_engine.pytorch.autocast(recipe=fp8_recipe): - outputs = model(**input_data) + with transformer_engine.pytorch.autocast(recipe=fp8_recipe): + outputs = model(**input_data) loss = outputs.loss assert torch.isfinite(loss) @@ -979,4 +986,121 @@ def test_meta_fp8_init(self, fp8_recipe): model.init_empty_weights() self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True) + # ==================== Generation Tests (Autoregressive Models Only) ==================== + @abstractmethod + def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1) -> Any: + """Create inference params for KV-cache generation tests. + + Autoregressive model tests must override this method to provide + model-specific ``HFInferenceParams`` with allocated KV-cache memory. + + Args: + config: Model configuration. + batch_size: Batch size. + max_seq_len: Maximum sequence length. + num_beams: Number of beams for beam search. + + Returns: + HFInferenceParams instance with allocated memory. + """ + pass + + def test_generate_without_cache(self): + """Test basic generation without KV-cache (BSHD, use_cache=False).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompt = "The quick brown fox jumps over" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=False) + + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache(self): + """Test single-prompt generation with KV-cache (THD format).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompt = "The quick brown fox jumps over" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + past_key_values = self.create_inference_params(config, batch_size=1) + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) + + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache_batched(self): + """Test batched generation with KV-cache (left-padded BSHD converted to THD).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompts = ( + "The quick brown fox jumps over the lazy dog.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + past_key_values = self.create_inference_params(config, batch_size=2) + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) + + assert output_ids.shape[0] == 2 + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache_beam_search(self): + """Test batched generation with KV-cache and beam search.""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompts = ( + "The quick brown fox jumps over the lazy dog.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + num_beams = 2 + past_key_values = self.create_inference_params(config, batch_size=2, num_beams=num_beams) + + with torch.no_grad(): + output_ids = model.generate( + **inputs, + max_new_tokens=16, + use_cache=True, + past_key_values=past_key_values, + num_beams=num_beams, + do_sample=True, + ) + + assert output_ids.shape[0] == 2 + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc. diff --git a/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py b/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py index 18e21ffcd4..72e0a206a2 100644 --- a/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py +++ b/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py @@ -270,3 +270,7 @@ def test_convert_state_dict_explicit_check(self): model_te.state_dict()["esm.embeddings.word_embeddings.weight"].data_ptr() == model_te.state_dict()["lm_head.decoder.weight"].data_ptr() ) + + def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1): + """These are unused for non-autoregressive models.""" + pass diff --git a/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py b/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py index aca85b7855..d04c9e5d24 100644 --- a/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py +++ b/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py @@ -20,7 +20,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Callable, Dict, List, Literal, Type +from typing import Any, Callable, Dict, List, Literal, Type import pytest import torch @@ -82,6 +82,10 @@ class BaseModelTest(ABC): Subclasses must implement all abstract methods to provide model-specific configuration, data preparation, and conversion functions. + Set ``is_autoregressive = True`` in subclasses for causal LM models to + enable generation / KV-cache smoke tests. Non-autoregressive models + (e.g. ESM2) leave the default ``False`` and those tests are skipped. + Example: ```python class ESM2ModelTester(BioNeMoModelTester): @@ -98,6 +102,8 @@ def get_upstream_model_id(self): ``` """ + is_autoregressive: bool = False + @abstractmethod def get_model_class(self) -> Type[PreTrainedModel]: """Return the TransformerEngine model class to test. @@ -885,13 +891,15 @@ def test_fp8_forward_and_backward_pass(self, fp8_recipe, input_format): msg=lambda x: f"FP8 loss differs too much from BF16 loss: {x}", ) - def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format): + def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format, **config_kwargs): """Test that model initialized with FP8 works correctly.""" if input_format == "thd" and not HAS_DATA_CENTER_GPU: pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.") model_class = self.get_model_class() - config = self.create_test_config(attn_input_format=input_format, self_attn_mask_type="padding_causal") + config = self.create_test_config( + attn_input_format=input_format, self_attn_mask_type="padding_causal", **config_kwargs + ) # Initialize with FP8 with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe): @@ -906,9 +914,8 @@ def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_forma input_data["labels"] = input_data["input_ids"].clone() # Forward and backward pass with FP8 - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - with transformer_engine.pytorch.autocast(recipe=fp8_recipe): - outputs = model(**input_data) + with transformer_engine.pytorch.autocast(recipe=fp8_recipe): + outputs = model(**input_data) loss = outputs.loss assert torch.isfinite(loss) @@ -979,4 +986,121 @@ def test_meta_fp8_init(self, fp8_recipe): model.init_empty_weights() self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True) + # ==================== Generation Tests (Autoregressive Models Only) ==================== + @abstractmethod + def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1) -> Any: + """Create inference params for KV-cache generation tests. + + Autoregressive model tests must override this method to provide + model-specific ``HFInferenceParams`` with allocated KV-cache memory. + + Args: + config: Model configuration. + batch_size: Batch size. + max_seq_len: Maximum sequence length. + num_beams: Number of beams for beam search. + + Returns: + HFInferenceParams instance with allocated memory. + """ + pass + + def test_generate_without_cache(self): + """Test basic generation without KV-cache (BSHD, use_cache=False).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompt = "The quick brown fox jumps over" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=False) + + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache(self): + """Test single-prompt generation with KV-cache (THD format).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompt = "The quick brown fox jumps over" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + past_key_values = self.create_inference_params(config, batch_size=1) + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) + + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache_batched(self): + """Test batched generation with KV-cache (left-padded BSHD converted to THD).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompts = ( + "The quick brown fox jumps over the lazy dog.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + past_key_values = self.create_inference_params(config, batch_size=2) + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) + + assert output_ids.shape[0] == 2 + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache_beam_search(self): + """Test batched generation with KV-cache and beam search.""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompts = ( + "The quick brown fox jumps over the lazy dog.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + num_beams = 2 + past_key_values = self.create_inference_params(config, batch_size=2, num_beams=num_beams) + + with torch.no_grad(): + output_ids = model.generate( + **inputs, + max_new_tokens=16, + use_cache=True, + past_key_values=past_key_values, + num_beams=num_beams, + do_sample=True, + ) + + assert output_ids.shape[0] == 2 + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc. diff --git a/bionemo-recipes/models/llama3/tests/test_modeling_llama_te.py b/bionemo-recipes/models/llama3/tests/test_modeling_llama_te.py index cd4998c17b..caf0ddf242 100644 --- a/bionemo-recipes/models/llama3/tests/test_modeling_llama_te.py +++ b/bionemo-recipes/models/llama3/tests/test_modeling_llama_te.py @@ -20,17 +20,13 @@ - LLaMA-specific tests (inference, generation, THD inputs, etc.) """ -import gc -import os from typing import Callable, Dict, List, Literal, Type import pytest import torch -import transformer_engine.pytorch import transformers from torch import nn from transformers import ( - AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, PretrainedConfig, @@ -41,7 +37,7 @@ from collator import DataCollatorWithFlattening from convert import convert_llama_hf_to_te, convert_llama_te_to_hf from modeling_llama_te import HFInferenceParams, NVLlamaConfig, NVLlamaForCausalLM -from tests.common import HAS_DATA_CENTER_GPU, BaseModelTest, TestTolerances +from tests.common import BaseModelTest, TestTolerances class TestLlama3Model(BaseModelTest): @@ -50,6 +46,8 @@ class TestLlama3Model(BaseModelTest): This class provides LLaMA3-specific configuration for the common test suite. """ + is_autoregressive = True + def get_model_class(self) -> Type[PreTrainedModel]: """Return the LLaMA3 TE model class.""" return NVLlamaForCausalLM @@ -139,7 +137,23 @@ def get_tolerances(self) -> TestTolerances: cp_loss_rtol=0.25, ) - # ==================== LLaMA3-Specific Tests ==================== + # ==================== LLaMA3-Specific Overrides ==================== + + def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1): + """Create HFInferenceParams for the given config.""" + past_key_values = HFInferenceParams( + max_batch_size=batch_size * num_beams, + max_sequence_length=max_seq_len, + num_heads_kv=config.num_key_value_heads, + head_dim_k=config.hidden_size // config.num_attention_heads, + dtype=torch.bfloat16, + qkv_format="thd", + max_ctx_len=max_seq_len, + ) + for layer_number in range(1, config.num_hidden_layers + 1): + past_key_values.allocate_memory(layer_number) + return past_key_values + def test_golden_values(self, input_format): # pyright: ignore[reportIncompatibleMethodOverride] """For llama3, we can test both the dynamic sequence packing and native bshd attention formats.""" model_hf = self.get_reference_model(dtype=torch.bfloat16) @@ -166,247 +180,6 @@ def test_golden_values(self, input_format): # pyright: ignore[reportIncompatibl @pytest.mark.parametrize("tie_word_embeddings", [True, False]) def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format, tie_word_embeddings): # pyright: ignore[reportIncompatibleMethodOverride] """There was a weird bug in BIO-217 on tied weights with quantized model init, so we test both cases.""" - if input_format == "thd" and not HAS_DATA_CENTER_GPU: - pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.") - - model_class = self.get_model_class() - config = self.create_test_config( - attn_input_format=input_format, - self_attn_mask_type="padding_causal", - tie_word_embeddings=tie_word_embeddings, + super().test_quantized_model_init_forward_and_backward( + fp8_recipe, input_format, tie_word_embeddings=tie_word_embeddings ) - - # Initialize with FP8 - with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe): - model = model_class(config) - - model.to("cuda") - model.eval() - - # Prepare input data - input_data = self.get_test_input_data(input_format, pad_to_multiple_of=32) - if "labels" not in input_data: - input_data["labels"] = input_data["input_ids"].clone() - - # Forward and backward pass with FP8 - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - with transformer_engine.pytorch.autocast(recipe=fp8_recipe): - outputs = model(**input_data) - - loss = outputs.loss - assert torch.isfinite(loss) - - loss.backward() - - # Verify gradients exist - for name, param in model.named_parameters(): - if param.requires_grad: - assert param.grad is not None, f"Parameter {name} has no gradient after FP8 backward pass" - - -# NOTE: Keeping remaining LLaMA-specific tests (inference, generation, etc.) unchanged -# These tests are specific to causal LM functionality and don't have ESM2 equivalents - - -@pytest.fixture -def input_text(): - return ( - """Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License.""", - "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore " - "et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip " - "ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu " - "fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt " - "mollit anim id est laborum.", - ) - - -@pytest.mark.skipif(os.getenv("CI", "false") == "true", reason="Skipping test in CI not download llama3 model.") -def test_llama_model_golden_values_padding_left(input_text): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") - model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", dtype=torch.bfloat16) - - model_te = convert_llama_hf_to_te(model_hf, attn_input_format="thd") - - tokenizer.pad_token = tokenizer.eos_token - inputs = tokenizer(input_text, return_tensors="pt", padding=True, padding_side="left") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - labels = inputs["input_ids"].clone() - labels[labels == tokenizer.pad_token_id] = -100 - model_hf.to("cuda") - with torch.no_grad(): - outputs_hf = model_hf(**inputs, labels=labels, output_hidden_states=True) - - del model_hf - gc.collect() - torch.cuda.empty_cache() - - model_te.to("cuda") - with torch.no_grad(): - outputs_te = model_te(**inputs, labels=labels, output_hidden_states=True) - - torch.testing.assert_close(outputs_te.loss, outputs_hf.loss, atol=0.02, rtol=0.03) # Higher than I'd like. - torch.testing.assert_close( - outputs_te.logits[inputs["attention_mask"].to(bool)], - outputs_hf.logits[inputs["attention_mask"].to(bool)], - atol=1.5, - rtol=0.01, - ) - - -@pytest.mark.skipif(os.getenv("CI", "false") == "true", reason="Skipping test in CI not download llama3 model.") -def test_hf_llama_model_generate_bshd(): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") - model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", dtype=torch.bfloat16) - - prompt = ( - """ - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at""", - "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore", - ) - - tokenizer.pad_token = tokenizer.eos_token - inputs = tokenizer(prompt, return_tensors="pt", padding=True, padding_side="left") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - model_hf.to("cuda") - - with torch.no_grad(): - output_ids = model_hf.generate(**inputs, max_new_tokens=16, use_cache=False) - - generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - assert "http://www.apache.org/licenses/LICENSE-2.0" in generated_text[0] - assert "et dolore magna aliqua. Ut enim ad minim " in generated_text[1] - - -@pytest.mark.skipif(os.getenv("CI", "false") == "true", reason="Skipping test in CI not download llama3 model.") -def test_te_llama_model_generate_with_cache(): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") - model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", dtype=torch.bfloat16) - model_te = convert_llama_hf_to_te(model_hf, attn_input_format="thd", self_attn_mask_type="padding_causal") - - prompt = """ - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at""" - - inputs = tokenizer(prompt, return_tensors="pt") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - model_te.to("cuda") - - past_key_values = HFInferenceParams( - max_batch_size=1, - max_sequence_length=256, - num_heads_kv=model_te.config.num_key_value_heads, - head_dim_k=model_te.config.hidden_size // model_te.config.num_attention_heads, - dtype=torch.bfloat16, - qkv_format="thd", - max_ctx_len=256, - ) - - for layer_number in range(1, model_te.config.num_hidden_layers + 1): - past_key_values.allocate_memory(layer_number) - - with torch.no_grad(): - output_ids = model_te.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) - - generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - assert "http://www.apache.org/licenses/LICENSE-2.0" in generated_text[0] - - -@pytest.mark.skipif(os.getenv("CI", "false") == "true", reason="Skipping test in CI not download llama3 model.") -def test_te_llama_model_generate_with_cache_bshd(): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") - model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", dtype=torch.bfloat16) - model_te = convert_llama_hf_to_te(model_hf, attn_input_format="thd") - - prompt = ( - """ - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at""", - "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore", - ) - - tokenizer.pad_token = tokenizer.eos_token - inputs = tokenizer(prompt, return_tensors="pt", padding=True, padding_side="left") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - model_te.to("cuda") - - past_key_values = HFInferenceParams( - max_batch_size=2, - max_sequence_length=256, - num_heads_kv=model_te.config.num_key_value_heads, - head_dim_k=model_te.config.hidden_size // model_te.config.num_attention_heads, - dtype=torch.bfloat16, - qkv_format="thd", - max_ctx_len=256, - ) - - for layer_number in range(1, model_te.config.num_hidden_layers + 1): - past_key_values.allocate_memory(layer_number) - - with torch.no_grad(): - output_ids = model_te.generate( - **inputs, - max_new_tokens=16, - use_cache=True, - past_key_values=past_key_values, - ) - - generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - assert "http://www.apache.org/licenses/LICENSE-2.0" in generated_text[0] - assert "et dolore magna aliqua. Ut enim ad minim " in generated_text[1] - - -@pytest.mark.skipif(os.getenv("CI", "false") == "true", reason="Skipping test in CI not download llama3 model.") -def test_te_llama_model_generate_with_cache_bshd_beam_search(): - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") - model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", dtype=torch.bfloat16) - model_te = convert_llama_hf_to_te(model_hf, attn_input_format="thd") - - prompt = ( - """ - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at""", - "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore", - ) - - tokenizer.pad_token = tokenizer.eos_token - inputs = tokenizer(prompt, return_tensors="pt", padding=True, padding_side="left") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - model_te.to("cuda") - - num_beams = 2 - - past_key_values = HFInferenceParams( - max_batch_size=2 * num_beams, - max_sequence_length=256, - num_heads_kv=model_te.config.num_key_value_heads, - head_dim_k=model_te.config.hidden_size // model_te.config.num_attention_heads, - dtype=torch.bfloat16, - qkv_format="thd", - max_ctx_len=256, - ) - - for layer_number in range(1, model_te.config.num_hidden_layers + 1): - past_key_values.allocate_memory(layer_number) - - with torch.no_grad(): - output_ids = model_te.generate( - **inputs, - max_new_tokens=16, - use_cache=True, - past_key_values=past_key_values, - num_beams=num_beams, - do_sample=True, - ) - - generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - assert "http://www.apache.org/licenses/LICENSE-2.0" in generated_text[0] - assert "et dolore magna aliqua. Ut enim ad minim " in generated_text[1] diff --git a/bionemo-recipes/models/mixtral/modeling_mixtral_te.py b/bionemo-recipes/models/mixtral/modeling_mixtral_te.py index 4b66838135..6dd5a0a66c 100644 --- a/bionemo-recipes/models/mixtral/modeling_mixtral_te.py +++ b/bionemo-recipes/models/mixtral/modeling_mixtral_te.py @@ -166,7 +166,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Permute tokens by expert using TE moe_permute permuted_hidden, row_id_map = transformer_engine.pytorch.moe_permute( - hidden_states, selected_experts, map_type="index" + hidden_states, selected_experts.to(torch.int32), map_type="index" ) # Compute m_splits: number of tokens per expert @@ -185,11 +185,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Down projection expert_output = self.experts_down(intermediate, m_splits=m_splits) # [total_tokens, H] - # Unpermute and combine with routing weights + # Unpermute and combine with routing weights (keep probs in float32 for numerical stability) output = transformer_engine.pytorch.moe_unpermute( expert_output, row_id_map, - merging_probs=routing_weights.to(expert_output.dtype), + merging_probs=routing_weights, map_type="index", ) diff --git a/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py b/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py index aca85b7855..d04c9e5d24 100644 --- a/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py +++ b/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py @@ -20,7 +20,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import Callable, Dict, List, Literal, Type +from typing import Any, Callable, Dict, List, Literal, Type import pytest import torch @@ -82,6 +82,10 @@ class BaseModelTest(ABC): Subclasses must implement all abstract methods to provide model-specific configuration, data preparation, and conversion functions. + Set ``is_autoregressive = True`` in subclasses for causal LM models to + enable generation / KV-cache smoke tests. Non-autoregressive models + (e.g. ESM2) leave the default ``False`` and those tests are skipped. + Example: ```python class ESM2ModelTester(BioNeMoModelTester): @@ -98,6 +102,8 @@ def get_upstream_model_id(self): ``` """ + is_autoregressive: bool = False + @abstractmethod def get_model_class(self) -> Type[PreTrainedModel]: """Return the TransformerEngine model class to test. @@ -885,13 +891,15 @@ def test_fp8_forward_and_backward_pass(self, fp8_recipe, input_format): msg=lambda x: f"FP8 loss differs too much from BF16 loss: {x}", ) - def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format): + def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format, **config_kwargs): """Test that model initialized with FP8 works correctly.""" if input_format == "thd" and not HAS_DATA_CENTER_GPU: pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.") model_class = self.get_model_class() - config = self.create_test_config(attn_input_format=input_format, self_attn_mask_type="padding_causal") + config = self.create_test_config( + attn_input_format=input_format, self_attn_mask_type="padding_causal", **config_kwargs + ) # Initialize with FP8 with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe): @@ -906,9 +914,8 @@ def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_forma input_data["labels"] = input_data["input_ids"].clone() # Forward and backward pass with FP8 - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - with transformer_engine.pytorch.autocast(recipe=fp8_recipe): - outputs = model(**input_data) + with transformer_engine.pytorch.autocast(recipe=fp8_recipe): + outputs = model(**input_data) loss = outputs.loss assert torch.isfinite(loss) @@ -979,4 +986,121 @@ def test_meta_fp8_init(self, fp8_recipe): model.init_empty_weights() self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True) + # ==================== Generation Tests (Autoregressive Models Only) ==================== + @abstractmethod + def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1) -> Any: + """Create inference params for KV-cache generation tests. + + Autoregressive model tests must override this method to provide + model-specific ``HFInferenceParams`` with allocated KV-cache memory. + + Args: + config: Model configuration. + batch_size: Batch size. + max_seq_len: Maximum sequence length. + num_beams: Number of beams for beam search. + + Returns: + HFInferenceParams instance with allocated memory. + """ + pass + + def test_generate_without_cache(self): + """Test basic generation without KV-cache (BSHD, use_cache=False).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompt = "The quick brown fox jumps over" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=False) + + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache(self): + """Test single-prompt generation with KV-cache (THD format).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompt = "The quick brown fox jumps over" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + past_key_values = self.create_inference_params(config, batch_size=1) + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) + + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache_batched(self): + """Test batched generation with KV-cache (left-padded BSHD converted to THD).""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompts = ( + "The quick brown fox jumps over the lazy dog.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + past_key_values = self.create_inference_params(config, batch_size=2) + + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) + + assert output_ids.shape[0] == 2 + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + + def test_generate_with_cache_beam_search(self): + """Test batched generation with KV-cache and beam search.""" + if not self.is_autoregressive: + pytest.skip("Not an autoregressive model") + + config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") + model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) + model.eval() + + tokenizer = self.get_tokenizer() + prompts = ( + "The quick brown fox jumps over the lazy dog.", + "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + + num_beams = 2 + past_key_values = self.create_inference_params(config, batch_size=2, num_beams=num_beams) + + with torch.no_grad(): + output_ids = model.generate( + **inputs, + max_new_tokens=16, + use_cache=True, + past_key_values=past_key_values, + num_beams=num_beams, + do_sample=True, + ) + + assert output_ids.shape[0] == 2 + assert output_ids.shape[1] > inputs["input_ids"].shape[1] + # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc. diff --git a/bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py b/bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py index 08e9f30917..e15645a3c1 100644 --- a/bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py +++ b/bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py @@ -47,6 +47,8 @@ class TestMixtralModel(BaseModelTest): This class provides Mixtral-specific configuration for the common test suite. """ + is_autoregressive = True + def get_model_class(self) -> Type[PreTrainedModel]: """Return the Mixtral TE model class.""" return NVMixtralForCausalLM @@ -146,9 +148,7 @@ def get_tolerances(self) -> TestTolerances: cp_loss_rtol=0.25, ) - # ==================== Mixtral-Specific KV-Cache Tests ==================== - - def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1): + def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1): """Create HFInferenceParams for the given config.""" past_key_values = HFInferenceParams( max_batch_size=batch_size * num_beams, @@ -162,119 +162,3 @@ def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_be for layer_number in range(1, config.num_hidden_layers + 1): past_key_values.allocate_memory(layer_number) return past_key_values - - def test_generate_with_cache(self): - """Test single-prompt generation with KV-cache (THD format).""" - config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") - model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) - model.eval() - - tokenizer = self.get_tokenizer() - prompt = "The quick brown fox jumps over" - inputs = tokenizer(prompt, return_tensors="pt") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - - past_key_values = self._create_inference_params(config, batch_size=1) - - with torch.no_grad(): - output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) - - # Verify generation produced new tokens - assert output_ids.shape[1] > inputs["input_ids"].shape[1] - - def test_generate_with_cache_batched(self): - """Test batched generation with KV-cache (left-padded BSHD converted to THD).""" - config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") - model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) - model.eval() - - tokenizer = self.get_tokenizer() - prompts = ( - "The quick brown fox jumps over the lazy dog.", - "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", - ) - inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - - past_key_values = self._create_inference_params(config, batch_size=2) - - with torch.no_grad(): - output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values) - - # Verify generation produced new tokens for both sequences - assert output_ids.shape[0] == 2 - assert output_ids.shape[1] > inputs["input_ids"].shape[1] - - def test_generate_with_cache_beam_search(self): - """Test batched generation with KV-cache and beam search.""" - config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal") - model = self.get_model_class()(config).to("cuda").to(torch.bfloat16) - model.eval() - - tokenizer = self.get_tokenizer() - prompts = ( - "The quick brown fox jumps over the lazy dog.", - "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", - ) - inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - - num_beams = 2 - past_key_values = self._create_inference_params(config, batch_size=2, num_beams=num_beams) - - with torch.no_grad(): - output_ids = model.generate( - **inputs, - max_new_tokens=16, - use_cache=True, - past_key_values=past_key_values, - num_beams=num_beams, - do_sample=True, - ) - - # Verify generation produced new tokens for both sequences - assert output_ids.shape[0] == 2 - assert output_ids.shape[1] > inputs["input_ids"].shape[1] - - # ==================== Standalone Mixtral Generation Tests ==================== - - def test_te_mixtral_model_generate_with_cache_beam_search(self): - """Test Mixtral generation with KV-cache and beam search using real model weights.""" - import gc - - model_hf = self.get_reference_model() - model_te = convert_mixtral_hf_to_te(model_hf, attn_input_format="thd", self_attn_mask_type="padding_causal") - del model_hf - gc.collect() - - model_te.to("cuda") - model_te.eval() - - tokenizer = self.get_tokenizer() - - prompts = ( - 'Licensed under the Apache License, Version 2.0 (the "License");' - " you may not use this file except in compliance with the License." - " You may obtain a copy of the License at", - "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore", - ) - inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left") - inputs = {k: v.to("cuda") for k, v in inputs.items()} - - num_beams = 2 - config = model_te.config - past_key_values = self._create_inference_params(config, batch_size=2, num_beams=num_beams) - - with torch.no_grad(): - output_ids = model_te.generate( - **inputs, - max_new_tokens=16, - use_cache=True, - past_key_values=past_key_values, - num_beams=num_beams, - do_sample=False, - ) - - generated_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - assert "http://www.apache.org/licenses/LICENSE-2.0" in generated_text[0] - assert "et dolore magna aliqua" in generated_text[1]