From c69dd2d6f8c4558a83e7bac2c0960d8596475e6a Mon Sep 17 00:00:00 2001
From: John St John
Date: Sat, 20 Dec 2025 00:06:50 +0000
Subject: [PATCH] Fixing argparse argument name error

Signed-off-by: John St John
---
 .../evo2/models/megatron/hyena/hyena_utils.py      |  2 +-
 .../evo2_megatron/src/bionemo/evo2/run/train.py    | 16 +++++++++++++---
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py
index d0e452ef6..07fd29f78 100644
--- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py
+++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/megatron/hyena/hyena_utils.py
@@ -756,7 +756,7 @@ def forward(self, L, *args, **kwargs):  # noqa: N803
         """
         return self.filter(L, *args, **kwargs)
 
-    @torch.compile(mode="max-autotune")
+    @torch.compile(mode="default")
     def filter(self, L, *args, **kwargs):  # noqa: N803
         """Compute the filter as a function of h and decay for the requested sequence length."""
         h = self.h[:, :L]
diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py
index bf789b784..54011e92c 100644
--- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py
+++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py
@@ -291,6 +291,12 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
     parser.add_argument(
         "--grad-reduce-in-fp32", action="store_true", default=False, help="Gradient reduce in FP32."
     )  # DONE
+    parser.add_argument(
+        "--fsdp",
+        action="store_true",
+        default=False,
+        help="Enable FSDP training.",
+    )
     parser.add_argument("--use-megatron-comm-overlap-llama3-8k", action="store_true", default=False)  # DONE
     parser.add_argument(
         "--tp-comm-overlap-backend",
@@ -710,9 +716,9 @@ def train(args: argparse.Namespace) -> None:
         recipe_kwargs["stride"] = args.stride
         recipe_kwargs["window_min_length_threshold"] = args.window_min_length_threshold
         recipe_kwargs["rc_aug"] = args.rc_aug
-    elif args.dataset_config_path:
+    elif args.dataset_config:
         recipe_kwargs["dataset_dir"] = args.dataset_dir
-        recipe_kwargs["dataset_config_path"] = args.dataset_config_path
+        recipe_kwargs["dataset_config_path"] = args.dataset_config
         recipe_kwargs["pad_eod_loss_mask"] = args.eod_pad_in_loss_mask
 
 
@@ -747,7 +753,7 @@ def train(args: argparse.Namespace) -> None:
 
     cfg: ConfigContainer = pretrain_config(**recipe_kwargs)
     cfg.checkpoint.async_save = args.ckpt_async_save
-    cfg.checkpoint.ckpt_format = args.ckpt_format
+    cfg.checkpoint.ckpt_format = args.ckpt_format if not args.fsdp else "fsdp_dtensor"
     cfg.checkpoint.save_interval = args.eval_interval
     cfg.checkpoint.save_optim = True
     cfg.checkpoint.save_rng = True
@@ -828,6 +834,10 @@ def train(args: argparse.Namespace) -> None:
     cfg.ddp.overlap_grad_reduce = args.overlap_grad_reduce
     cfg.ddp.grad_reduce_in_fp32 = args.grad_reduce_in_fp32
     cfg.ddp.check_for_nan_in_grad = not args.no_check_for_nan_in_grad
+    if args.fsdp:
+        cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params"
+        cfg.ddp.use_megatron_fsdp = True
+        cfg.checkpoint.ckpt_format = "fsdp_dtensor"
     if args.use_megatron_comm_overlap_llama3_8k:
        # Pick the floating point appropriate config.
         fp8 = "fp8" in args.mixed_precision_recipe