Skip to content

Commit 7e75003

Browse files
committed
Add BSHD packed dataloader toggle and FP8 test
- Update train_fsdp2.py and train_ddp.py to toggle between dataloaders:
  - use_sequence_packing=true + attn_input_format=bshd -> BSHD packed
  - use_sequence_packing=true + attn_input_format=thd -> THD packed
  - use_sequence_packing=false -> BSHD unpacked
- Add test_train_fsdp2_fp8_bshd_packed test for FP8 with BSHD packing

Signed-off-by: Savitha Srinivasan <savithas@nvidia.com>
1 parent 5c5e8ac commit 7e75003

3 files changed

Lines changed: 38 additions & 4 deletions

File tree

bionemo-recipes/recipes/llama3_native_te/tests/test_train.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,28 @@ def test_train_fsdp2_fp8_first_last_bf16(tmp_path, recipe_path):
431431
assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0"
432432

433433

434+
def test_train_fsdp2_fp8_bshd_packed(tmp_path, recipe_path):
    """Test that FSDP2 training works with FP8 enabled and BSHD packed dataloader."""
    # Hydra overrides: enable FP8 and sequence packing, select the BSHD
    # attention input format, and route all run artifacts into tmp_path.
    run_overrides = [
        f"+wandb.dir={tmp_path}",
        f"checkpoint.ckpt_dir={tmp_path}",
        "fp8_config.enabled=true",
        "use_sequence_packing=true",
        "config_kwargs.attn_input_format=bshd",
        # NOTE(review): presumably required so FP8 kernels see dims divisible
        # by 16 — confirm against the dataset collator.
        "+dataset.pad_to_multiple_of=16",
    ]
    with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
        cfg = compose(config_name="L0_sanity", overrides=run_overrides)

        final_loss = main_fsdp2(cfg)
        # Release references and cached CUDA memory so subsequent tests in the
        # same session start from a clean GPU state.
        gc.collect()
        torch.cuda.empty_cache()

        assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0"
434456
@requires_datacenter_hardware
435457
def test_sanity_fsdp2_cp(tmp_path, recipe_path):
436458
# Run the training script with Hydra configuration overrides

bionemo-recipes/recipes/llama3_native_te/train_ddp.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from transformers.models.llama.modeling_llama import LlamaForCausalLM
2929

3030
from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint
31-
from dataset import create_bshd_dataloader, create_thd_dataloader
31+
from dataset import create_bshd_dataloader, create_bshd_packed_dataloader, create_thd_dataloader
3232
from distributed_config import DistributedConfig
3333
from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
3434
from perf_logger import PerfLogger
@@ -93,8 +93,14 @@ def main(args: DictConfig) -> float | None:
9393
)
9494

9595
if args.use_sequence_packing:
96-
train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
96+
if args.config_kwargs.attn_input_format == "bshd":
97+
# BSHD with full packing (cross-boundary attention, no cu_seqlens)
98+
train_dataloader, dataset_or_sampler = create_bshd_packed_dataloader(dist_config, **args.dataset)
99+
else:
100+
# THD with packing (respects boundaries via cu_seqlens)
101+
train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
97102
else:
103+
# Standard BSHD with windowing (no packing)
98104
train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset)
99105

100106
if args.use_torch_compile:

bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
save_final_model_fsdp2,
3737
should_save_checkpoint,
3838
)
39-
from dataset import create_bshd_dataloader, create_thd_dataloader
39+
from dataset import create_bshd_dataloader, create_bshd_packed_dataloader, create_thd_dataloader
4040
from distributed_config import DistributedConfig
4141
from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
4242
from perf_logger import PerfLogger
@@ -110,8 +110,14 @@ def main(args: DictConfig) -> float | None:
110110
scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)
111111

112112
if args.use_sequence_packing:
113-
train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
113+
if args.config_kwargs.attn_input_format == "bshd":
114+
# BSHD with full packing (cross-boundary attention, no cu_seqlens)
115+
train_dataloader, dataset_or_sampler = create_bshd_packed_dataloader(dist_config, **args.dataset)
116+
else:
117+
# THD with packing (respects boundaries via cu_seqlens)
118+
train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
114119
else:
120+
# Standard BSHD with windowing (no packing)
115121
train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset)
116122

117123
if args.use_torch_compile:

0 commit comments

Comments
 (0)