Commit b2ddae1

Assert fp32 for rope embeddings, misc test fixes (#1496)
This wouldn't have caught @savitha-eng's `cast_forward_inputs=True` bug (which casts these embeddings right as they enter the TransformerLayer), but it turns out our test suite was actually casting them to bfloat16 with `model.to(bfloat16)` calls 😬. This also fixes a few other miscellaneous test failures I saw locally while making sure the esm2 & llama3 recipe and model tests pass.

Will require #1495 for tests to pass.

---------

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 5676160 commit b2ddae1

12 files changed, 41 additions & 32 deletions
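
The underlying failure mode is easy to reproduce outside TransformerEngine. Below is a minimal sketch, assuming a toy module in place of TE's rotary embedding (`ToyRope` is illustrative, not the real API): a blanket `model.to(bfloat16)` downcasts the rotary frequency buffer, which is exactly the condition the new dtype guard flags.

```python
import warnings

import torch


class ToyRope(torch.nn.Module):
    """Illustrative stand-in for a rotary-embedding module (not the TE API)."""

    def __init__(self, dim: int = 64):
        super().__init__()
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)  # fp32 frequency table

    def forward(self, max_seq_len: int) -> torch.Tensor:
        t = torch.arange(max_seq_len, dtype=self.inv_freq.dtype)
        return torch.outer(t, self.inv_freq)  # inherits the buffer's dtype


rope = ToyRope()
assert rope(max_seq_len=128).dtype == torch.float32  # fresh module: fp32 as intended

rope.to(torch.bfloat16)  # the cast the test suite was applying via model.to(bfloat16)
emb = rope(max_seq_len=128)
if emb.dtype != torch.float32:  # the guard this commit adds (a warning, not a hard assert)
    warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
```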


bionemo-recipes/models/esm2/modeling_esm_te.py

Lines changed: 4 additions & 1 deletion
@@ -22,6 +22,7 @@
 Adapted from `modeling_esm.py` in huggingface/transformers.
 """

+import warnings
 from typing import ClassVar, Literal, Optional, Unpack

 # TODO: put import guard around transformer_engine here, with an informative error message around
@@ -197,6 +198,8 @@ def forward(
         with torch.autocast(device_type="cuda", enabled=False):
             te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings)
             te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True)
+        if te_rope_emb.dtype != torch.float32:
+            warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)

         for layer_module in self.layers:
             if kwargs.get("output_hidden_states", False):
@@ -374,7 +377,7 @@ def forward(
         )
         encoder_outputs = self.encoder(
             embedding_output,
-            attention_mask=extended_attention_mask,
+            attention_mask=None if self.config.attn_input_format == "thd" else extended_attention_mask,
             **kwargs,
         )
         sequence_output = encoder_outputs[0]
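
The `attention_mask` change above reflects how the packed THD layout works: sequence boundaries are carried as cumulative sequence lengths rather than as a dense padding mask, so forwarding the extended mask when `attn_input_format == "thd"` is unnecessary. A hedged sketch of that bookkeeping (names are illustrative, not the recipe's API):

```python
import torch

# Three packed sequences of 5, 3, and 7 tokens laid out as one THD "ribbon".
seq_lens = torch.tensor([5, 3, 7])
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0)).to(torch.int32)
print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
# A kernel that consumes cu_seqlens needs no padding mask, hence attention_mask=None.
```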

bionemo-recipes/models/esm2/tests/common/test_modeling_common.py

Lines changed: 4 additions & 8 deletions
@@ -452,7 +452,6 @@ def test_smoke_forward_pass(self, input_format):
         config = self.create_test_config(attn_input_format=input_format)

         model = model_class(config)
-        model.to(torch.bfloat16)
         model.to("cuda")

         # Prepare input data
@@ -475,7 +474,6 @@ def test_smoke_backward_pass(self, input_format):
         config = self.create_test_config(attn_input_format=input_format)

         model = model_class(config)
-        model.to(torch.bfloat16)
         model.to("cuda")

         # Prepare input data
@@ -498,7 +496,6 @@ def test_smoke_model_with_loss(self, input_format):
         config = self.create_test_config(attn_input_format=input_format)

         model = model_class(config)
-        model.to(torch.bfloat16)
         model.to("cuda")

         # Prepare input data with labels
@@ -522,7 +519,6 @@ def test_forward_and_backward(self, input_format):
         config = self.create_test_config(attn_input_format=input_format)

         model = model_class(config)
-        model.to(torch.bfloat16)
         model.to("cuda")

         # Prepare input data
@@ -1011,7 +1007,7 @@ def test_generate_without_cache(self):
             pytest.skip("Not an autoregressive model")

         config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal")
-        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model = self.get_model_class()(config).to("cuda")
         model.eval()

         tokenizer = self.get_tokenizer()
@@ -1030,7 +1026,7 @@ def test_generate_with_cache(self):
             pytest.skip("Not an autoregressive model")

         config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
-        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model = self.get_model_class()(config).to("cuda")
         model.eval()

         tokenizer = self.get_tokenizer()
@@ -1051,7 +1047,7 @@ def test_generate_with_cache_batched(self):
             pytest.skip("Not an autoregressive model")

         config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
-        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model = self.get_model_class()(config).to("cuda")
         model.eval()

         tokenizer = self.get_tokenizer()
@@ -1076,7 +1072,7 @@ def test_generate_with_cache_beam_search(self):
             pytest.skip("Not an autoregressive model")

         config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
-        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model = self.get_model_class()(config).to("cuda")
         model.eval()

         tokenizer = self.get_tokenizer()
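
With the blanket `model.to(torch.bfloat16)` calls removed, these tests now exercise the models with fp32 parameters and buffers. If a test does want bf16 compute, a safer pattern is autocast, which converts activations per-op but leaves parameters and buffers (such as rotary frequency tables) in fp32. A minimal sketch, using `torch.nn.Linear` as a stand-in for the real model:

```python
import torch

model = torch.nn.Linear(8, 8).to("cuda")  # stand-in for model_class(config)
x = torch.randn(2, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = model(x)

print(y.dtype)             # torch.bfloat16 -- compute ran in bf16
print(model.weight.dtype)  # torch.float32  -- weights stayed fp32
```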

bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,7 @@

 """TransformerEngine-optimized Llama model."""

+import warnings
 from collections import OrderedDict
 from typing import ClassVar, Unpack

@@ -236,6 +237,8 @@ def forward(
         # Ensure that rotary embeddings are computed at a higher precision
         with torch.autocast(device_type="cuda", enabled=False):
             te_rope_emb = self.rotary_emb(max_seq_len=self.config.max_position_embeddings)
+        if te_rope_emb.dtype != torch.float32:
+            warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)

         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:

bionemo-recipes/models/llama3/tests/common/test_modeling_common.py

Lines changed: 4 additions & 8 deletions
@@ -452,7 +452,6 @@ def test_smoke_forward_pass(self, input_format):
         config = self.create_test_config(attn_input_format=input_format)

         model = model_class(config)
-        model.to(torch.bfloat16)
         model.to("cuda")

         # Prepare input data
@@ -475,7 +474,6 @@ def test_smoke_backward_pass(self, input_format):
         config = self.create_test_config(attn_input_format=input_format)

         model = model_class(config)
-        model.to(torch.bfloat16)
         model.to("cuda")

         # Prepare input data
@@ -498,7 +496,6 @@ def test_smoke_model_with_loss(self, input_format):
         config = self.create_test_config(attn_input_format=input_format)

         model = model_class(config)
-        model.to(torch.bfloat16)
         model.to("cuda")

         # Prepare input data with labels
@@ -522,7 +519,6 @@ def test_forward_and_backward(self, input_format):
         config = self.create_test_config(attn_input_format=input_format)

         model = model_class(config)
-        model.to(torch.bfloat16)
         model.to("cuda")

         # Prepare input data
@@ -1011,7 +1007,7 @@ def test_generate_without_cache(self):
             pytest.skip("Not an autoregressive model")

         config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal")
-        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model = self.get_model_class()(config).to("cuda")
         model.eval()

         tokenizer = self.get_tokenizer()
@@ -1030,7 +1026,7 @@ def test_generate_with_cache(self):
             pytest.skip("Not an autoregressive model")

         config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
-        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model = self.get_model_class()(config).to("cuda")
         model.eval()

         tokenizer = self.get_tokenizer()
@@ -1051,7 +1047,7 @@ def test_generate_with_cache_batched(self):
             pytest.skip("Not an autoregressive model")

         config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
-        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model = self.get_model_class()(config).to("cuda")
         model.eval()

         tokenizer = self.get_tokenizer()
@@ -1076,7 +1072,7 @@ def test_generate_with_cache_beam_search(self):
             pytest.skip("Not an autoregressive model")

         config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
-        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model = self.get_model_class()(config).to("cuda")
         model.eval()

         tokenizer = self.get_tokenizer()

bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py

Lines changed: 4 additions & 8 deletions
@@ -452,7 +452,6 @@ def test_smoke_forward_pass(self, input_format):
         config = self.create_test_config(attn_input_format=input_format)

         model = model_class(config)
-        model.to(torch.bfloat16)
         model.to("cuda")

         # Prepare input data
@@ -475,7 +474,6 @@ def test_smoke_backward_pass(self, input_format):
         config = self.create_test_config(attn_input_format=input_format)

         model = model_class(config)
-        model.to(torch.bfloat16)
         model.to("cuda")

         # Prepare input data
@@ -498,7 +496,6 @@ def test_smoke_model_with_loss(self, input_format):
         config = self.create_test_config(attn_input_format=input_format)

         model = model_class(config)
-        model.to(torch.bfloat16)
         model.to("cuda")

         # Prepare input data with labels
@@ -522,7 +519,6 @@ def test_forward_and_backward(self, input_format):
         config = self.create_test_config(attn_input_format=input_format)

         model = model_class(config)
-        model.to(torch.bfloat16)
         model.to("cuda")

         # Prepare input data
@@ -1011,7 +1007,7 @@ def test_generate_without_cache(self):
             pytest.skip("Not an autoregressive model")

         config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal")
-        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model = self.get_model_class()(config).to("cuda")
         model.eval()

         tokenizer = self.get_tokenizer()
@@ -1030,7 +1026,7 @@ def test_generate_with_cache(self):
             pytest.skip("Not an autoregressive model")

         config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
-        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model = self.get_model_class()(config).to("cuda")
         model.eval()

         tokenizer = self.get_tokenizer()
@@ -1051,7 +1047,7 @@ def test_generate_with_cache_batched(self):
             pytest.skip("Not an autoregressive model")

         config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
-        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model = self.get_model_class()(config).to("cuda")
         model.eval()

         tokenizer = self.get_tokenizer()
@@ -1076,7 +1072,7 @@ def test_generate_with_cache_beam_search(self):
             pytest.skip("Not an autoregressive model")

         config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
-        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model = self.get_model_class()(config).to("cuda")
         model.eval()

         tokenizer = self.get_tokenizer()

bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py

Lines changed: 4 additions & 1 deletion
@@ -22,6 +22,7 @@
 Adapted from `modeling_esm.py` in huggingface/transformers.
 """

+import warnings
 from typing import ClassVar, Literal, Optional, Unpack

 # TODO: put import guard around transformer_engine here, with an informative error message around
@@ -197,6 +198,8 @@ def forward(
         with torch.autocast(device_type="cuda", enabled=False):
             te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings)
             te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True)
+        if te_rope_emb.dtype != torch.float32:
+            warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)

         for layer_module in self.layers:
             if kwargs.get("output_hidden_states", False):
@@ -374,7 +377,7 @@ def forward(
         )
         encoder_outputs = self.encoder(
             embedding_output,
-            attention_mask=extended_attention_mask,
+            attention_mask=None if self.config.attn_input_format == "thd" else extended_attention_mask,
             **kwargs,
         )
         sequence_output = encoder_outputs[0]

bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py

Lines changed: 4 additions & 1 deletion
@@ -22,6 +22,7 @@
 Adapted from `modeling_esm.py` in huggingface/transformers.
 """

+import warnings
 from typing import ClassVar, Literal, Optional, Unpack

 # TODO: put import guard around transformer_engine here, with an informative error message around
@@ -197,6 +198,8 @@ def forward(
         with torch.autocast(device_type="cuda", enabled=False):
             te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings)
             te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True)
+        if te_rope_emb.dtype != torch.float32:
+            warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)

         for layer_module in self.layers:
             if kwargs.get("output_hidden_states", False):
@@ -374,7 +377,7 @@ def forward(
         )
         encoder_outputs = self.encoder(
             embedding_output,
-            attention_mask=extended_attention_mask,
+            attention_mask=None if self.config.attn_input_format == "thd" else extended_attention_mask,
             **kwargs,
         )
         sequence_output = encoder_outputs[0]

bionemo-recipes/recipes/esm2_native_te/tests/test_dataset.py

Lines changed: 3 additions & 0 deletions
@@ -964,3 +964,6 @@ def test_cp_dataloader(recipe_path):
         f"Expected at most {expected_tokens_per_rank + 100} tokens, got {actual_shape}"
     )
     assert batch["labels"].shape[1] == actual_shape
+
+    dataloader.close()
+    torch.distributed.destroy_process_group()
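
As a general pattern for distributed dataloader tests, the cleanup is safer in a `finally` block so a failing assertion cannot leak an NCCL process group into later tests. A hedged sketch (`dataloader.close()` comes from this recipe; the rest is standard `torch.distributed`):

```python
import torch


def exercise_cp_dataloader(dataloader):
    try:
        batch = next(iter(dataloader))
        assert batch["labels"].shape[1] == batch["input_ids"].shape[1]
    finally:
        dataloader.close()  # release the recipe's background workers
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()  # avoid NCCL state leaking across tests
```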

bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py

Lines changed: 4 additions & 1 deletion
@@ -22,6 +22,7 @@
 Adapted from `modeling_esm.py` in huggingface/transformers.
 """

+import warnings
 from typing import ClassVar, Literal, Optional, Unpack

 # TODO: put import guard around transformer_engine here, with an informative error message around
@@ -197,6 +198,8 @@ def forward(
         with torch.autocast(device_type="cuda", enabled=False):
             te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings)
             te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True)
+        if te_rope_emb.dtype != torch.float32:
+            warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)

         for layer_module in self.layers:
             if kwargs.get("output_hidden_states", False):
@@ -374,7 +377,7 @@ def forward(
         )
         encoder_outputs = self.encoder(
             embedding_output,
-            attention_mask=extended_attention_mask,
+            attention_mask=None if self.config.attn_input_format == "thd" else extended_attention_mask,
             **kwargs,
         )
         sequence_output = encoder_outputs[0]

bionemo-recipes/recipes/esm2_peft_te/tests/test_infer.py

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@ def peft_model(recipe_path):
     config.id2label = SS3_ID2LABEL
     config.label2id = SS3_LABEL2ID

-    base_model = AutoModelForTokenClassification.from_config(config, trust_remote_code=True)
+    base_model = AutoModelForTokenClassification.from_config(config, trust_remote_code=True, dtype=torch.bfloat16)

     lora_config = peft.LoraConfig(
         task_type=peft.TaskType.TOKEN_CLS,
@@ -45,7 +45,7 @@ def peft_model(recipe_path):
     )

     model = peft.get_peft_model(base_model, lora_config)
-    model.to(device="cuda", dtype=torch.bfloat16)
+    model.to(device="cuda")
     model.eval()
     return model

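The net effect of this fixture change: the base model is built in bf16 via `from_config(..., dtype=...)`, and `.to()` is reserved for device placement only, instead of casting the assembled PEFT model after the fact. A hedged sketch of the same pattern (the checkpoint path is illustrative):

```python
import torch
from transformers import AutoConfig, AutoModelForTokenClassification

config = AutoConfig.from_pretrained("./example_8m_checkpoint", trust_remote_code=True)  # illustrative path
base_model = AutoModelForTokenClassification.from_config(config, trust_remote_code=True, dtype=torch.bfloat16)
base_model.to(device="cuda")  # move only; dtype was already fixed at construction
```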