Skip to content

Commit c589cbb

Browse files
committed
refactor autoregressive model tests
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent bd72d88 commit c589cbb

6 files changed — 421 additions & 130 deletions

File tree

bionemo-recipes/models/esm2/tests/common/test_modeling_common.py

Lines changed: 131 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ class BaseModelTest(ABC):
8282
Subclasses must implement all abstract methods to provide model-specific
8383
configuration, data preparation, and conversion functions.
8484
85+
Set ``is_autoregressive = True`` in subclasses for causal LM models to
86+
enable generation / KV-cache smoke tests. Non-autoregressive models
87+
(e.g. ESM2) leave the default ``False`` and those tests are skipped.
88+
8589
Example:
8690
```python
8791
class ESM2ModelTester(BioNeMoModelTester):
@@ -98,6 +102,8 @@ def get_upstream_model_id(self):
98102
```
99103
"""
100104

105+
is_autoregressive: bool = False
106+
101107
@abstractmethod
102108
def get_model_class(self) -> Type[PreTrainedModel]:
103109
"""Return the TransformerEngine model class to test.
@@ -885,13 +891,15 @@ def test_fp8_forward_and_backward_pass(self, fp8_recipe, input_format):
885891
msg=lambda x: f"FP8 loss differs too much from BF16 loss: {x}",
886892
)
887893

888-
def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format):
894+
def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format, **config_kwargs):
889895
"""Test that model initialized with FP8 works correctly."""
890896
if input_format == "thd" and not HAS_DATA_CENTER_GPU:
891897
pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.")
892898

893899
model_class = self.get_model_class()
894-
config = self.create_test_config(attn_input_format=input_format, self_attn_mask_type="padding_causal")
900+
config = self.create_test_config(
901+
attn_input_format=input_format, self_attn_mask_type="padding_causal", **config_kwargs
902+
)
895903

896904
# Initialize with FP8
897905
with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe):
@@ -906,9 +914,8 @@ def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_forma
906914
input_data["labels"] = input_data["input_ids"].clone()
907915

908916
# Forward and backward pass with FP8
909-
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
910-
with transformer_engine.pytorch.autocast(recipe=fp8_recipe):
911-
outputs = model(**input_data)
917+
with transformer_engine.pytorch.autocast(recipe=fp8_recipe):
918+
outputs = model(**input_data)
912919

913920
loss = outputs.loss
914921
assert torch.isfinite(loss)
@@ -979,4 +986,123 @@ def test_meta_fp8_init(self, fp8_recipe):
979986
model.init_empty_weights()
980987
self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True)
981988

989+
# ==================== Generation Tests (Autoregressive Models Only) ====================
990+
991+
def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
992+
"""Create inference params for KV-cache generation tests.
993+
994+
Autoregressive model tests must override this method to provide
995+
model-specific ``HFInferenceParams`` with allocated KV-cache memory.
996+
997+
Args:
998+
config: Model configuration.
999+
batch_size: Batch size.
1000+
max_seq_len: Maximum sequence length.
1001+
num_beams: Number of beams for beam search.
1002+
1003+
Returns:
1004+
HFInferenceParams instance with allocated memory.
1005+
"""
1006+
raise NotImplementedError(
1007+
"Autoregressive models must override _create_inference_params to provide model-specific HFInferenceParams."
1008+
)
1009+
1010+
def test_generate_without_cache(self):
1011+
"""Test basic generation without KV-cache (BSHD, use_cache=False)."""
1012+
if not self.is_autoregressive:
1013+
pytest.skip("Not an autoregressive model")
1014+
1015+
config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal")
1016+
model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
1017+
model.eval()
1018+
1019+
tokenizer = self.get_tokenizer()
1020+
prompt = "The quick brown fox jumps over"
1021+
inputs = tokenizer(prompt, return_tensors="pt")
1022+
inputs = {k: v.to("cuda") for k, v in inputs.items()}
1023+
1024+
with torch.no_grad():
1025+
output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=False)
1026+
1027+
assert output_ids.shape[1] > inputs["input_ids"].shape[1]
1028+
1029+
def test_generate_with_cache(self):
1030+
"""Test single-prompt generation with KV-cache (THD format)."""
1031+
if not self.is_autoregressive:
1032+
pytest.skip("Not an autoregressive model")
1033+
1034+
config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
1035+
model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
1036+
model.eval()
1037+
1038+
tokenizer = self.get_tokenizer()
1039+
prompt = "The quick brown fox jumps over"
1040+
inputs = tokenizer(prompt, return_tensors="pt")
1041+
inputs = {k: v.to("cuda") for k, v in inputs.items()}
1042+
1043+
past_key_values = self._create_inference_params(config, batch_size=1)
1044+
1045+
with torch.no_grad():
1046+
output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
1047+
1048+
assert output_ids.shape[1] > inputs["input_ids"].shape[1]
1049+
1050+
def test_generate_with_cache_batched(self):
1051+
"""Test batched generation with KV-cache (left-padded BSHD converted to THD)."""
1052+
if not self.is_autoregressive:
1053+
pytest.skip("Not an autoregressive model")
1054+
1055+
config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
1056+
model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
1057+
model.eval()
1058+
1059+
tokenizer = self.get_tokenizer()
1060+
prompts = (
1061+
"The quick brown fox jumps over the lazy dog.",
1062+
"Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
1063+
)
1064+
inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
1065+
inputs = {k: v.to("cuda") for k, v in inputs.items()}
1066+
1067+
past_key_values = self._create_inference_params(config, batch_size=2)
1068+
1069+
with torch.no_grad():
1070+
output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
1071+
1072+
assert output_ids.shape[0] == 2
1073+
assert output_ids.shape[1] > inputs["input_ids"].shape[1]
1074+
1075+
def test_generate_with_cache_beam_search(self):
1076+
"""Test batched generation with KV-cache and beam search."""
1077+
if not self.is_autoregressive:
1078+
pytest.skip("Not an autoregressive model")
1079+
1080+
config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
1081+
model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
1082+
model.eval()
1083+
1084+
tokenizer = self.get_tokenizer()
1085+
prompts = (
1086+
"The quick brown fox jumps over the lazy dog.",
1087+
"Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
1088+
)
1089+
inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
1090+
inputs = {k: v.to("cuda") for k, v in inputs.items()}
1091+
1092+
num_beams = 2
1093+
past_key_values = self._create_inference_params(config, batch_size=2, num_beams=num_beams)
1094+
1095+
with torch.no_grad():
1096+
output_ids = model.generate(
1097+
**inputs,
1098+
max_new_tokens=16,
1099+
use_cache=True,
1100+
past_key_values=past_key_values,
1101+
num_beams=num_beams,
1102+
do_sample=True,
1103+
)
1104+
1105+
assert output_ids.shape[0] == 2
1106+
assert output_ids.shape[1] > inputs["input_ids"].shape[1]
1107+
9821108
# TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc.

bionemo-recipes/models/llama3/tests/common/test_modeling_common.py

Lines changed: 131 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ class BaseModelTest(ABC):
8282
Subclasses must implement all abstract methods to provide model-specific
8383
configuration, data preparation, and conversion functions.
8484
85+
Set ``is_autoregressive = True`` in subclasses for causal LM models to
86+
enable generation / KV-cache smoke tests. Non-autoregressive models
87+
(e.g. ESM2) leave the default ``False`` and those tests are skipped.
88+
8589
Example:
8690
```python
8791
class ESM2ModelTester(BioNeMoModelTester):
@@ -98,6 +102,8 @@ def get_upstream_model_id(self):
98102
```
99103
"""
100104

105+
is_autoregressive: bool = False
106+
101107
@abstractmethod
102108
def get_model_class(self) -> Type[PreTrainedModel]:
103109
"""Return the TransformerEngine model class to test.
@@ -885,13 +891,15 @@ def test_fp8_forward_and_backward_pass(self, fp8_recipe, input_format):
885891
msg=lambda x: f"FP8 loss differs too much from BF16 loss: {x}",
886892
)
887893

888-
def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format):
894+
def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format, **config_kwargs):
889895
"""Test that model initialized with FP8 works correctly."""
890896
if input_format == "thd" and not HAS_DATA_CENTER_GPU:
891897
pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.")
892898

893899
model_class = self.get_model_class()
894-
config = self.create_test_config(attn_input_format=input_format, self_attn_mask_type="padding_causal")
900+
config = self.create_test_config(
901+
attn_input_format=input_format, self_attn_mask_type="padding_causal", **config_kwargs
902+
)
895903

896904
# Initialize with FP8
897905
with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe):
@@ -906,9 +914,8 @@ def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_forma
906914
input_data["labels"] = input_data["input_ids"].clone()
907915

908916
# Forward and backward pass with FP8
909-
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
910-
with transformer_engine.pytorch.autocast(recipe=fp8_recipe):
911-
outputs = model(**input_data)
917+
with transformer_engine.pytorch.autocast(recipe=fp8_recipe):
918+
outputs = model(**input_data)
912919

913920
loss = outputs.loss
914921
assert torch.isfinite(loss)
@@ -979,4 +986,123 @@ def test_meta_fp8_init(self, fp8_recipe):
979986
model.init_empty_weights()
980987
self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True)
981988

989+
# ==================== Generation Tests (Autoregressive Models Only) ====================
990+
991+
def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
992+
"""Create inference params for KV-cache generation tests.
993+
994+
Autoregressive model tests must override this method to provide
995+
model-specific ``HFInferenceParams`` with allocated KV-cache memory.
996+
997+
Args:
998+
config: Model configuration.
999+
batch_size: Batch size.
1000+
max_seq_len: Maximum sequence length.
1001+
num_beams: Number of beams for beam search.
1002+
1003+
Returns:
1004+
HFInferenceParams instance with allocated memory.
1005+
"""
1006+
raise NotImplementedError(
1007+
"Autoregressive models must override _create_inference_params to provide model-specific HFInferenceParams."
1008+
)
1009+
1010+
def test_generate_without_cache(self):
1011+
"""Test basic generation without KV-cache (BSHD, use_cache=False)."""
1012+
if not self.is_autoregressive:
1013+
pytest.skip("Not an autoregressive model")
1014+
1015+
config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal")
1016+
model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
1017+
model.eval()
1018+
1019+
tokenizer = self.get_tokenizer()
1020+
prompt = "The quick brown fox jumps over"
1021+
inputs = tokenizer(prompt, return_tensors="pt")
1022+
inputs = {k: v.to("cuda") for k, v in inputs.items()}
1023+
1024+
with torch.no_grad():
1025+
output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=False)
1026+
1027+
assert output_ids.shape[1] > inputs["input_ids"].shape[1]
1028+
1029+
def test_generate_with_cache(self):
1030+
"""Test single-prompt generation with KV-cache (THD format)."""
1031+
if not self.is_autoregressive:
1032+
pytest.skip("Not an autoregressive model")
1033+
1034+
config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
1035+
model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
1036+
model.eval()
1037+
1038+
tokenizer = self.get_tokenizer()
1039+
prompt = "The quick brown fox jumps over"
1040+
inputs = tokenizer(prompt, return_tensors="pt")
1041+
inputs = {k: v.to("cuda") for k, v in inputs.items()}
1042+
1043+
past_key_values = self._create_inference_params(config, batch_size=1)
1044+
1045+
with torch.no_grad():
1046+
output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
1047+
1048+
assert output_ids.shape[1] > inputs["input_ids"].shape[1]
1049+
1050+
def test_generate_with_cache_batched(self):
1051+
"""Test batched generation with KV-cache (left-padded BSHD converted to THD)."""
1052+
if not self.is_autoregressive:
1053+
pytest.skip("Not an autoregressive model")
1054+
1055+
config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
1056+
model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
1057+
model.eval()
1058+
1059+
tokenizer = self.get_tokenizer()
1060+
prompts = (
1061+
"The quick brown fox jumps over the lazy dog.",
1062+
"Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
1063+
)
1064+
inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
1065+
inputs = {k: v.to("cuda") for k, v in inputs.items()}
1066+
1067+
past_key_values = self._create_inference_params(config, batch_size=2)
1068+
1069+
with torch.no_grad():
1070+
output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
1071+
1072+
assert output_ids.shape[0] == 2
1073+
assert output_ids.shape[1] > inputs["input_ids"].shape[1]
1074+
1075+
def test_generate_with_cache_beam_search(self):
1076+
"""Test batched generation with KV-cache and beam search."""
1077+
if not self.is_autoregressive:
1078+
pytest.skip("Not an autoregressive model")
1079+
1080+
config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
1081+
model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
1082+
model.eval()
1083+
1084+
tokenizer = self.get_tokenizer()
1085+
prompts = (
1086+
"The quick brown fox jumps over the lazy dog.",
1087+
"Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
1088+
)
1089+
inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
1090+
inputs = {k: v.to("cuda") for k, v in inputs.items()}
1091+
1092+
num_beams = 2
1093+
past_key_values = self._create_inference_params(config, batch_size=2, num_beams=num_beams)
1094+
1095+
with torch.no_grad():
1096+
output_ids = model.generate(
1097+
**inputs,
1098+
max_new_tokens=16,
1099+
use_cache=True,
1100+
past_key_values=past_key_values,
1101+
num_beams=num_beams,
1102+
do_sample=True,
1103+
)
1104+
1105+
assert output_ids.shape[0] == 2
1106+
assert output_ids.shape[1] > inputs["input_ids"].shape[1]
1107+
9821108
# TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc.

0 commit comments

Comments (0)