diff --git a/bionemo-recipes/recipes/evo2_megatron/README.md b/bionemo-recipes/recipes/evo2_megatron/README.md index bd0e891e4c..5bc855f331 100644 --- a/bionemo-recipes/recipes/evo2_megatron/README.md +++ b/bionemo-recipes/recipes/evo2_megatron/README.md @@ -19,35 +19,142 @@ uv pip install -c pip-constraints.txt -e . --no-build-isolation ## Usage +### Example job + ``` # 3. Run an example job ## 2. if on a6000s, you may need to disable p2p to avoid crashing export NCCL_P2P_DISABLE=1 ## 3. Run the job: -torchrun --nproc-per-node 8 --no-python \ +torchrun --nproc-per-node 2 --no-python \ train_evo2 \ --hf-tokenizer-model-path tokenizers/nucleotide_fast_tokenizer_256 \ --model-size striped_hyena_1b_nv_parallel --max-steps 12 --eval-interval 10 \ --eval-iters 3 --mock-data \ - --micro-batch-size 32 --global-batch-size 256 --seq-length 1024 \ + --micro-batch-size 16 --global-batch-size 32 --seq-length 1024 \ --tensor-model-parallel 1 \ --use-precision-aware-optimizer --dataset-seed 33 \ - --seed 41 --ckpt-async-save --spike-no-more-embedding-init \ + --seed 41 --spike-no-more-embedding-init \ --no-weight-decay-embeddings --cross-entropy-loss-fusion \ --align-param-gather --overlap-param-gather --grad-reduce-in-fp32 \ --decay-steps 100 --warmup-steps 10 \ - --mixed-precision-recipe bf16-mixed \ + --mixed-precision-recipe bf16_with_fp8_current_scaling_mixed \ --no-fp32-residual-connection --activation-checkpoint-recompute-num-layers 1 \ --attention-dropout 0.001 --hidden-dropout 0.001 \ --eod-pad-in-loss-mask --enable-preemption \ --log-interval 5 --debug-ddp-parity-freq 10 \ - --wandb-project evo2-recipes-verification-tmp \ - --wandb-run-name tmp_workstation_run_mock_data \ - --result-dir tmpbf16 --no-renormalize-loss + --result-dir tmpfp8 --no-renormalize-loss +``` + +### Example fine-tune from an existing checkpoint + +First convert the checkpoint from nemo2 format (temporary step until we upload the new files) + +Good checkpoint names to try are: + +- 
evo2/1b-8k-bf16:1.0 (model_size: 1b) +- evo2/7b-1m:1.0 (model_size: 7b_arc_longcontext) +- evo2/40b-1m-fp8-bf16:1.0 (model_size: 40b_arc_longcontext) + +Other than the 7b version, the other two are checkpoints fine-tuned by the BioNeMo team to support both FP8 and BF16 +precision. The 7b version worked well on both FP8 and BF16 out of the box so it was not fine-tuned further. If you do +want to use one of the FP8 sensitive checkpoints, like `evo2/40b-1m` then be sure to add the `--vortex-style-fp8` +option to the checkpoint conversion step below. Also note that although 8k versions of the 7b and 40b checkpoints exist, +it is advisable to use the longer context versions since they were trained further and still run on shorter inputs. + +See `download_bionemo_data --list-resources` for other checkpoint options and a list of available +downloadable resources. + ``` +CKPT_NAME=evo2/1b-8k-bf16:1.0 +CKPT_OUT_DIR=evo2_1b_8k_bf16_mbridge +evo2_convert_nemo2_to_mbridge \ + --mixed-precision-recipe bf16_with_fp8_current_scaling_mixed \ + --tokenizer-path tokenizers/nucleotide_fast_tokenizer_512 \ + --model-size 1b \ + --seq-length 8192 \ + --nemo2-ckpt-dir $(download_bionemo_data $CKPT_NAME) \ + --mbridge-ckpt-dir $CKPT_OUT_DIR + +``` + +Now run like before, but include the fine-tuned checkpoint directory you converted in the previous step with +`--finetune-ckpt-dir $CKPT_OUT_DIR`. Also if you have problems with `bf16_with_fp8_current_scaling_mixed` try +`bf16_mixed`. 
+ +``` +torchrun --nproc-per-node 2 --no-python \ + train_evo2 \ + --hf-tokenizer-model-path tokenizers/nucleotide_fast_tokenizer_512 \ + --model-size 1b --max-steps 12 --eval-interval 10 \ + --eval-iters 3 --mock-data \ + --micro-batch-size 16 --global-batch-size 32 --seq-length 1024 \ + --tensor-model-parallel 1 \ + --use-precision-aware-optimizer --dataset-seed 33 \ + --seed 41 \ + --cross-entropy-loss-fusion \ + --align-param-gather --overlap-param-gather --grad-reduce-in-fp32 \ + --decay-steps 100 --warmup-steps 10 \ + --mixed-precision-recipe bf16_with_fp8_current_scaling_mixed \ + --no-fp32-residual-connection --activation-checkpoint-recompute-num-layers 1 \ + --attention-dropout 0.001 --hidden-dropout 0.001 \ + --eod-pad-in-loss-mask --enable-preemption \ + --log-interval 5 --debug-ddp-parity-freq 10 \ + --result-dir tmpfp8-ft-example --no-renormalize-loss \ + --finetune-ckpt-dir $CKPT_OUT_DIR +``` + +## Where do the custom command line programs come from? + +See `pyproject.toml` for where runnable programs like `train_evo2` and `evo2_convert_nemo2_to_mbridge` are implemented +in code. ## Docker build ``` docker build -t evo2_megatron_recipe-$(git rev-parse --short HEAD) . ``` + +## Performance and accuracy comparisons + +NOTE: this section is largely a work in progress. This reflects the most updated information, but may not reflect the +current state of the code base at any given time. + +### Training accuracy convergence + +We ran a 12 hour 48 H100 GPU training run to compare megatron bridge with nemo2. We found that FP8 current scaling +converges by around the 5,000th step to the bf16 lines. And that bf16 is comparable with nemo2. Interestingly in nemo2 +bf16 and fp8 followed nearly identical trajectories for the first 5k steps as well. Note that in a typical training run +we are performing over 100k steps, so different behavior in the first 5k steps is less worrisome if the endpoints are +comparable. 
+ +![Training Convergence Comparison](assets/mbridge_to_nemo_training_convergence_7ksteps.png) + +### Training performance comparisons + +FP8 current scaling which is supposed to have better convergence properties than delayed scaling, performs nearly as +well as delayed scaling in mbridge. Even leaving multiple transformer layers in bf16 precision trains faster than fp8 +delayed scaling in nemo2. + +| Evo2 1B Run | Seconds per step (lower is better) | Tokens/sec/GPU | Global Batch Size | Number of GPUs | Vocab Size | +| :----------------------------------------------: | :--------------------------------: | :------------: | :---------------: | :------------: | :--------: | +| MBridge BF16 | 6.10 | 26,859 | 960 | 48 | 256 | +| MBridge FP8 (delayed) | 5.38 | 30,453 | 960 | 48 | 256 | +| MBridge FP8 (current) | 5.44 | 28,755 | 960 | 48 | 512 | +| MBridge FP8 (current first/last two layers bf16) | 5.47 | 28,598 | 960 | 48 | 512 | +| Nemo2 FP8 (delayed) | 6.18 | 26,511 | 960 | 48 | 512 | + +Activation memory optimizations have enabled context parallelism to work better with evo2 style models in our mbridge +implementation than the previous nemo2 implementation. Since TP requires more node to node communication, you generally +want to limit TP to your fastest interconnects, which are typically configured in nodes of 8 GPUs. Evo2 would previously +OOM with these more ideal configurations, requiring much larger than typical levels of TP to handle long context +training. With our latest changes to the evo2 forward pass, we can now handle more typical TP vs CP configurations. +This enables significantly faster step timing at long context, as well as demonstrating up to 2M context length. We +have currently demonstrated small training runs at 2M context on only 512 H100 GPUs for the 40b parameter model. 
+ +| Configuration | Precision | TP | CP | Number of Nodes | Number of GPUs | Context Length | Global Batch Size | Seconds per Step | +| :---------------: | :---------: | :-: | :-: | :-------------: | :------------: | :------------: | :---------------: | :--------------: | +| NeMo2 | fp8-delayed | 64 | 2 | 32 | 256 | 1M | 2 | 44 | +| NeMo2 | fp8-delayed | 8 | 16 | 32 | 256 | 1M | 2 | OOM | +| MBridge Optimized | bf16 | 8 | 16 | 32 | 256 | 1M | 2 | 30 | +| 2M Stress Test | bf16 | 8 | 32 | 64 | 512 | 2M | 2 | 48 | diff --git a/bionemo-recipes/recipes/evo2_megatron/assets/mbridge_to_nemo_training_convergence_7ksteps.png b/bionemo-recipes/recipes/evo2_megatron/assets/mbridge_to_nemo_training_convergence_7ksteps.png new file mode 100644 index 0000000000..d289fdf7af Binary files /dev/null and b/bionemo-recipes/recipes/evo2_megatron/assets/mbridge_to_nemo_training_convergence_7ksteps.png differ diff --git a/bionemo-recipes/recipes/evo2_megatron/pyproject.toml b/bionemo-recipes/recipes/evo2_megatron/pyproject.toml index 8f6cbc9b52..355d808460 100644 --- a/bionemo-recipes/recipes/evo2_megatron/pyproject.toml +++ b/bionemo-recipes/recipes/evo2_megatron/pyproject.toml @@ -40,6 +40,7 @@ train_evo2 = "bionemo.evo2.run.train:main" #predict_evo2 = "bionemo.evo2.run.predict:main" preprocess_evo2 = "bionemo.evo2.data.preprocess:main" splice_evo2 = "bionemo.evo2.data.transcript_extraction:main" +evo2_convert_nemo2_to_mbridge = "bionemo.evo2.utils.checkpoint.nemo2_to_mbridge:main" #evo2_convert_to_nemo2 = "bionemo.evo2.utils.checkpoint.convert_to_nemo:main" #evo2_nemo2_to_hf = "bionemo.evo2.utils.checkpoint.nemo2_to_hf:main" #evo2_remove_optimizer = "bionemo.evo2.utils.checkpoint.evo2_remove_optimizer:main" diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/recipes/evo2.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/recipes/evo2.py index 6bc42c9e35..947d00e15b 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/recipes/evo2.py +++ 
b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/recipes/evo2.py @@ -267,7 +267,6 @@ def _evo2_common( ), tokenizer=TokenizerConfig( tokenizer_type="HuggingFaceTokenizer", - hf_tokenizer_kwargs={"trust_remote_code": True}, tokenizer_model=hf_tokenizer_model_or_path or "EleutherAI/gpt-neox-20b", ), checkpoint=CheckpointConfig( diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py index 4dd32b7adf..46289038e2 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py @@ -710,9 +710,9 @@ def train(args: argparse.Namespace) -> None: recipe_kwargs["stride"] = args.stride recipe_kwargs["window_min_length_threshold"] = args.window_min_length_threshold recipe_kwargs["rc_aug"] = args.rc_aug - elif args.dataset_config_path: + elif args.dataset_config: recipe_kwargs["dataset_dir"] = args.dataset_dir - recipe_kwargs["dataset_config_path"] = args.dataset_config_path + recipe_kwargs["dataset_config_path"] = args.dataset_config recipe_kwargs["pad_eod_loss_mask"] = args.eod_pad_in_loss_mask @@ -918,6 +918,7 @@ def train(args: argparse.Namespace) -> None: if args.finetune_ckpt_dir: cfg.checkpoint.finetune = True cfg.checkpoint.pretrained_checkpoint = args.finetune_ckpt_dir + cfg.checkpoint.dist_ckpt_strictness = "ignore_all" # necessary unfortunately to avoid extra_state issues. 
if args.nvidia_fault_tolerance: cfg.ft = FaultToleranceConfig( enable_ft_package=True, diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/nemo2_to_mbridge.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/nemo2_to_mbridge.py new file mode 100644 index 0000000000..dcc079c0c3 --- /dev/null +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/nemo2_to_mbridge.py @@ -0,0 +1,286 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import logging +import os +from pathlib import Path +from typing import Any + +import torch +import torch.distributed.checkpoint as dcp +from megatron.bridge.training.checkpointing import save_tokenizer_assets +from megatron.bridge.training.config import ConfigContainer +from megatron.bridge.training.mixed_precision import MIXED_PRECISION_RECIPES +from megatron.bridge.training.tokenizers.tokenizer import build_tokenizer +from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter +from torch.distributed.checkpoint.metadata import BytesStorageMetadata + +from bionemo.evo2.models.evo2_provider import HYENA_MODEL_OPTIONS, HyenaModelProvider +from bionemo.evo2.recipes.evo2 import evo2_1b_pretrain_config as pretrain_config + + +logger = logging.getLogger(__name__) + + +def convert_nemo2_dcp_to_megatron( + src_path: str | Path, + dest_path: str | Path, +): + """Convert a torch_dist format checkpoint with nemo2 style names to one with megatron bridge style names. + + Args: + src_path: Path to the source DCP checkpoint. + dest_path: Path to the destination DCP checkpoint. + """ + logger.info(f"Reading metadata from {src_path}...") + reader = FileSystemReader(str(src_path)) + metadata = reader.read_metadata() + + # 1. Pre-allocate state_dict based on metadata + # We need to construct the state_dict so dcp.load knows what to load. 
+ state_dict = {} + total_size_bytes = 0 + + for key, item_meta in metadata.state_dict_metadata.items(): + if isinstance(item_meta, BytesStorageMetadata): + # Skip or handle non-tensor data if necessary + continue + + # Create empty tensor on CPU with correct shape/dtype + # DCP will load data into these tensors in-place + state_dict[key] = torch.empty(item_meta.size, dtype=item_meta.properties.dtype, device="cpu") + + # Track size to calculate shard count later + total_size_bytes += state_dict[key].numel() * state_dict[key].element_size() + + print(f"Loading {len(state_dict)} tensors into memory (Approx {total_size_bytes / 1e9:.2f} GB)...") + + # 2. Load directly from DCP to memory (no_dist=True for single process) + dcp.load(state_dict=state_dict, storage_reader=reader, no_dist=True) + + # 3. Munge Keys + # Removing "module." prefix as requested + prefix_len = len("module.") + new_state_dict = {} + for k, v in state_dict.items(): + # Safety check: ensure key actually has the prefix before slicing + if k.startswith("module."): + new_key = k[prefix_len:] + else: + new_key = k + new_state_dict[new_key] = v + + logger.info(f"Keys munged. saving to {dest_path}...") + + # 4. 
Save to DCP with Sharding + # Calculate required threads to achieve target shard size + # DCP FileSystemWriter writes one file per thread when single_file_per_rank=False + + writer = FileSystemWriter( + dest_path, + single_file_per_rank=False, # roughly one file per parameter + thread_count=os.cpu_count(), + ) + + dcp.save(state_dict=new_state_dict, storage_writer=writer, no_dist=True) + logger.info("Conversion complete.") + + +def _dummy_train_state() -> dict[str, torch.Tensor]: + """Use for train_state.pt file, and latest_train_state.pt file in mbridge checkpoint.""" + return { + "step": torch.tensor(1, dtype=torch.int32), + "consumed_train_samples": torch.tensor(0, dtype=torch.int32), + "skipped_train_samples": torch.tensor(0, dtype=torch.int32), + "consumed_valid_samples": torch.tensor(0, dtype=torch.int32), + "floating_point_operations_so_far": torch.tensor(0, dtype=torch.float64), + "do_train": torch.tensor(True, dtype=torch.bool), + "do_valid": torch.tensor(True, dtype=torch.bool), + "do_test": torch.tensor(True, dtype=torch.bool), + } + + +def _dummy_common_pt_dict() -> dict[str, Any]: + """Use for common.pt file in mbridge checkpoint.""" + return { + "checkpoint_version": 3.0, + "iteration": 1, + "optimizer": {"param_state_sharding_type": "dp_reshardable"}, + "opt_param_scheduler": { + "max_lr": 0.0003, + "lr_warmup_steps": 10, + "num_steps": 2560, + "lr_decay_style": "cosine", + "lr_decay_steps": 25600, + "min_lr": 3e-05, + "start_wd": 0.01, + "end_wd": 0.01, + "wd_incr_style": "constant", + "wd_incr_steps": 3072, + }, + "content_metadata": { + "singleton_local_shards": False, + "distrib_optim_sharding_type": "dp_reshardable", + "chained_optim_avoid_prefix": True, + }, + } + + +def _dummy_format_metadata() -> dict[str, Any]: + """Use for metadata.json file in mbridge checkpoint.""" + return { + "sharded_backend": "torch_dist", + "sharded_backend_version": 1, + "common_backend": "torch", + "common_backend_version": 1, + } + + +def nemo2_to_mbridge( + 
nemo2_ckpt_dir: Path, + tokenizer_path: Path, + mbridge_ckpt_dir: Path, + model_provider: HyenaModelProvider, + mixed_precision_recipe: str, + vortex_style_fp8: bool, +) -> Path: + """Convert a Nemo2 checkpoint to a Megatron Bridge checkpoint. + + Args: + nemo2_ckpt_dir: Path to the Nemo2 checkpoint directory. + tokenizer_path: Path to the tokenizer directory. + mbridge_ckpt_dir: Path to the Megatron Bridge checkpoint directory. + model_provider: Model provider to use for the model. + mixed_precision_recipe: Mixed precision recipe to use for the model. + vortex_style_fp8: Whether to use vortex style fp8? This is needed for the fp8 sensitive checkpoints from the + original evo2 training. For example the 1b model and the 40b models (not the nvidia bf16 finetuned + checkpoints). In general leave this as False though because it will only put a small number of layers in fp8. + + Returns: + Path to the Megatron Bridge checkpoint directory. + + Structure of a megatron bridge checkpoint: + + |-- latest_checkpointed_iteration.txt # the older megatron way of communicating the latest checkpointed iteration + |-- latest_train_state.pt # a copy of train_state.pt from the latest iteration, used by megatron bridge + ├── iter_0000001 + | ├── __*_*.distcp # distcp checkpoint files for each shard (sometiems rank sometimes arbitrary shards) + | ├── .metadata # metadata for the distcp checkpoint files + | ├── common.pt # common metadata (training configuration related) + | ├── metadata.json # metadata for the checkpoint format etc + | ├── run_config.yaml # training configuration + | ├── tokenizer # tokenizer assets + | ├── train_state.pt # training state, eg current step, etc. 
+ """ + assert not mbridge_ckpt_dir.exists(), f"Checkpoint directory {mbridge_ckpt_dir} already exists" + mbridge_ckpt_dir.mkdir(parents=True, exist_ok=True) + mbridge_ckpt_iter_dir = mbridge_ckpt_dir / "iter_0000001" + nemo2_model_path = nemo2_ckpt_dir / "weights" + convert_nemo2_dcp_to_megatron(nemo2_model_path, mbridge_ckpt_iter_dir) + assert mbridge_ckpt_iter_dir.exists(), f"Checkpoint directory {mbridge_ckpt_iter_dir} does not exist" + with open(mbridge_ckpt_dir / "latest_checkpointed_iteration.txt", "w") as f: + f.write("1\n") + train_state = _dummy_train_state() + torch.save(train_state, mbridge_ckpt_iter_dir / "train_state.pt") + torch.save(train_state, mbridge_ckpt_dir / "latest_train_state.pt") + + common_pt_dict = _dummy_common_pt_dict() + torch.save(common_pt_dict, mbridge_ckpt_iter_dir / "common.pt") + format_metadata = _dummy_format_metadata() + with open(mbridge_ckpt_iter_dir / "metadata.json", "w") as f: + json.dump(format_metadata, f) + config_container: ConfigContainer = pretrain_config( + precision_config=mixed_precision_recipe, hf_tokenizer_model_or_path=tokenizer_path, mock=True + ) + tokenizer = build_tokenizer(config_container.tokenizer) + model_provider.vocab_size = tokenizer.vocab_size + model_provider.vortex_style_fp8 = vortex_style_fp8 + config_container.model = model_provider + config_container.to_yaml(str(mbridge_ckpt_iter_dir / "run_config.yaml")) + save_tokenizer_assets(tokenizer, config_container.tokenizer, str(mbridge_ckpt_iter_dir)) + return mbridge_ckpt_dir + + +def run_nemo2_to_mbridge( + nemo2_ckpt_dir: Path, + tokenizer_path: Path, + mbridge_ckpt_dir: Path, + model_size: str, + seq_length: int, + mixed_precision_recipe: str, + vortex_style_fp8: bool, +) -> Path: + """Convert a Nemo2 checkpoint to a Megatron Bridge checkpoint. + + Args: + nemo2_ckpt_dir: Path to the Nemo2 checkpoint directory. + tokenizer_path: Path to the tokenizer directory. + mbridge_ckpt_dir: Path to the Megatron Bridge checkpoint directory. 
+ model_size: Model size to use for the model. + seq_length: Sequence length to use for the model. + mixed_precision_recipe: Mixed precision recipe to use for the model. + vortex_style_fp8: Whether to use vortex style fp8? This is needed for the fp8 sensitive checkpoints from the + original evo2 training. For example the 1b model and the 40b models (not the nvidia bf16 finetuned + checkpoints). In general leave this as False though because it will only put a small number of layers in fp8. + + Returns: + Path to the Megatron Bridge checkpoint directory. + """ + model_provider = HYENA_MODEL_OPTIONS[model_size](seq_length=seq_length) + res_dir = nemo2_to_mbridge( + nemo2_ckpt_dir, tokenizer_path, mbridge_ckpt_dir, model_provider, mixed_precision_recipe, vortex_style_fp8 + ) + logger.info(f"Megatron Bridge checkpoint saved to {res_dir}") + return res_dir + + +def main(): + """Main function for handling cli args and running the conversion.""" + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument("--nemo2-ckpt-dir", type=Path, required=True) + parser.add_argument("--tokenizer-path", type=Path, required=True) + parser.add_argument("--mbridge-ckpt-dir", type=Path, required=True) + parser.add_argument("--model-size", type=str, choices=list(HYENA_MODEL_OPTIONS.keys()), required=True) + parser.add_argument("--seq-length", type=int, required=True) + parser.add_argument("--vortex-style-fp8", action="store_true", default=False) + parser.add_argument( + "--mixed-precision-recipe", + type=str, + choices=list(MIXED_PRECISION_RECIPES.keys()), + default="bf16_mixed", + help="Mixed precision recipe to use for training.", + ) + args = parser.parse_args() + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + run_nemo2_to_mbridge( + args.nemo2_ckpt_dir, + args.tokenizer_path, + args.mbridge_ckpt_dir, + args.model_size, + args.seq_length, + 
args.mixed_precision_recipe, + args.vortex_style_fp8, + ) + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_evo2.py b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_evo2.py index f1f346830b..700419333b 100644 --- a/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_evo2.py +++ b/bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_evo2.py @@ -30,8 +30,12 @@ import pandas as pd import pytest import torch +from megatron.bridge.training.checkpointing import ( + _load_model_weights_from_checkpoint, +) +from megatron.bridge.training.model_load_save import load_model_config from megatron.bridge.training.tokenizers.config import TokenizerConfig -from megatron.bridge.training.tokenizers.tokenizer import build_tokenizer +from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer, build_tokenizer from megatron.core import dist_checkpointing, parallel_state from megatron.core.dist_checkpointing.mapping import ShardedTensor @@ -50,13 +54,14 @@ from pytest import MonkeyPatch from bionemo.core.data.load import load -from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH +from bionemo.evo2.data.dataset_tokenizer import DEFAULT_HF_TOKENIZER_MODEL_PATH, DEFAULT_HF_TOKENIZER_MODEL_PATH_512 from bionemo.evo2.models.evo2_provider import ( Hyena1bModelProvider, Hyena7bARCLongContextModelProvider, Hyena7bModelProvider, HyenaInferenceContext, ) +from bionemo.evo2.utils.checkpoint.nemo2_to_mbridge import run_nemo2_to_mbridge logger = logging.getLogger(__name__) @@ -277,6 +282,18 @@ def determine_memory_requirement_and_skip_if_not_met(ckpt_name: str, test_name: "seq_len_cap": 4000, "memory_needed_by_test": 21, }, # checked both variants in isolation + { + "test_name": "test_forward_ckpt_conversion", + "model_size": "1b", + "seq_len_cap": 6000, + "memory_needed_by_test": 18, + }, # checked both variants in isolation + { + "test_name": 
"test_forward_ckpt_conversion", + "model_size": "7b", + "seq_len_cap": 4000, + "memory_needed_by_test": 21, + }, # checked both variants in isolation { "test_name": "test_batch_generate", "model_size": "1b", @@ -755,6 +772,102 @@ def test_forward_manual(sequences: list[str], ckpt_name: str, expected_matchperc ) +@pytest.mark.parametrize( + "ckpt_name,expected_matchpercents,flash_decode", + [ + # Try flash decode with one and not the other to verify that both paths work. + ("evo2/1b-8k-bf16:1.0", [96.27, 67.93, 77.50, 80.30], True), + ("evo2/1b-8k:1.0", [96.27, 67.93, 77.50, 80.30], False), + ("evo2/7b-8k:1.0", [97.60, 89.63, 80.03, 84.57], False), + ("evo2/7b-1m:1.0", [97.60, 89.63, 80.03, 84.57], False), + ], +) +def test_forward_ckpt_conversion( + tmp_path: Path, sequences: list[str], ckpt_name: str, expected_matchpercents: list[float], flash_decode: bool +): + """Test the forward pass of the megatron model.""" + assert len(sequences) > 0 + seq_len_cap = determine_memory_requirement_and_skip_if_not_met( + ckpt_name, test_name=inspect.currentframe().f_code.co_name + ) + + is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device()) + skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported + + # vortex_style_fp8 = is_fp8_supported and "bf16" not in ckpt_name + if skip: + # This checkpoint is sensitive to FP8, so we skip it if it is not supported on the current device. 
+ pytest.skip(f"Skipping {ckpt_name} because it is not supported on {device_info} ({compute_capability})") + with distributed_model_parallel_state(), torch.no_grad(): + ckpt_path: Path = load(ckpt_name) + + mbridge_ckpt_dir = run_nemo2_to_mbridge( + nemo2_ckpt_dir=ckpt_path, + tokenizer_path=DEFAULT_HF_TOKENIZER_MODEL_PATH_512, + mbridge_ckpt_dir=tmp_path / "mbridge_checkpoint", + model_size="1b" if "1b" in ckpt_name else "7b" if "7b-8k" in ckpt_name else "7b_arc_longcontext", + seq_length=1048576 if "1m" in ckpt_name else 8192, + mixed_precision_recipe="bf16_mixed" if not is_fp8_supported else "bf16_with_fp8_current_scaling_mixed", + # The checkpoints from the original evo2 training that are "fp8 sensitive" require vortex_style_fp8=True + # to run correctly. If we set it in the config going into the conversion then at load time users will + # get this setting without having to think about it. + vortex_style_fp8=is_fp8_supported and "evo2/1b-8k:" in ckpt_name, + ) + + mbridge_ckpt_path = mbridge_ckpt_dir / "iter_0000001" + + model_config, mtron_args = load_model_config(mbridge_ckpt_path) + assert mtron_args is None, "mtron_args should be None since this is a Megatron Bridge checkpoint" + if flash_decode: + model_config.flash_decode = flash_decode + model_config.attention_backend = AttnBackend.flash + tokenizer = _HuggingFaceTokenizer(mbridge_ckpt_path / "tokenizer") + # FIXME replace above with below once bug is fixed https://github.com/NVIDIA-NeMo/Megatron-Bridge/issues/1900 + # tokenizer = load_tokenizer(mbridge_ckpt_path) + model_config.finalize() # important to call finalize before providing the model, this does post_init etc. 
+ raw_megatron_model = model_config.provide(pre_process=True, post_process=True).eval().cuda() + device = raw_megatron_model.parameters().__next__().device + _load_model_weights_from_checkpoint( + checkpoint_path=mbridge_ckpt_path, model=[raw_megatron_model], dist_ckpt_strictness="ignore_all" + ) + model = Float16Module(model_config, raw_megatron_model) + + if flash_decode: + inference_context = HyenaInferenceContext(max_batch_size=1, max_sequence_length=8192) + # Ensure full-sequence logits are materialized for tests expecting [B, S, V] + inference_context.materialize_only_last_token_logits = False + forward_kwargs = {"runtime_gather_output": True, "inference_context": inference_context} + else: + forward_kwargs = {} + matchrates = [] + for seq in sequences: + # TODO: artificial limit, megatron uses more memory. Vortex can process full sequences + partial_seq = seq[:seq_len_cap] + with torch.no_grad(): + # tokens = torch.tensor([tokenizer.tokenize(seq)], device=device) + input_ids = torch.tensor(tokenizer.text_to_ids(partial_seq)).int().unsqueeze(0).to(device) + attention_mask = None + # when labels is None, the model returns logits + logits = model( + input_ids=input_ids, + position_ids=None, + attention_mask=attention_mask, + labels=None, + **forward_kwargs, + ) + if flash_decode: + forward_kwargs["inference_context"].reset() + matchrate = _calc_matchrate(tokenizer=tokenizer, in_seq=partial_seq, logits=logits) + matchrates.append(matchrate) + _check_matchrate(ckpt_name=ckpt_name, matchrate=matchrate, assert_matchrate=False) + assert len(matchrates) == len(expected_matchpercents) + matchperc_print = [f"{m * 100.0:.1f}%" for m in matchrates] + matchperc_print_expected = [f"{ep:.1f}%" for ep in expected_matchpercents] + assert all(m * 100.0 >= 0.95 * ep for m, ep in zip(matchrates, expected_matchpercents)), ( + f"Expected at least 95% of {matchperc_print_expected=}, got {matchperc_print=}" + ) + + # def mid_point_split(*, seq, num_tokens: int | None = None, 
@pytest.mark.slow
def test_fine_tuning(
    tmp_path: Path,
    tp_size: int = 1,
    cp_size: int = 1,
    dp_size: int = 1,
    dp_rank_check: bool = True,
    precision_recipe: str = "bf16_mixed",
    pp_size: int = 1,
):
    """Test fine-tuning: run a short pretraining job, then fine-tune from its checkpoint.

    Unlike stop-and-go resumption, fine-tuning (``--finetune-ckpt-dir``) loads model weights
    from the checkpoint but resets the optimizer, dataloader, and training step, so the second
    run starts again from step 1 (in a fresh result dir) rather than continuing from step 6.

    Args:
        tmp_path: pytest-provided temporary directory for run outputs.
        tp_size: Tensor-model-parallel size for both runs.
        cp_size: Context-parallel size for both runs.
        dp_size: Data-parallel size; scales the global batch size.
        dp_rank_check: If True, enable periodic DDP parity debugging in the trained model.
        precision_recipe: Mixed-precision recipe name passed to the trainer.
        pp_size: Pipeline-model-parallel size for both runs.
    """
    world_size = tp_size * pp_size * cp_size * dp_size
    mbs = 32
    gbs = mbs * dp_size  # one optimizer step per iteration (no gradient accumulation)
    num_gpus = torch.cuda.device_count()
    if world_size > num_gpus:
        pytest.skip(f"World size {world_size} is greater than the number of GPUs {num_gpus}")
    if "nvfp4" in precision_recipe and not is_fp4_supported():
        pytest.skip("NVFP4 is not supported on this device")
    if "mxfp8" in precision_recipe and not is_mxfp8_supported():
        pytest.skip("MXFP8 is not supported on this device")
    if "fp8" in precision_recipe and not is_fp8_supported():
        pytest.skip("FP8 is not supported on this device")
    if "bf16_with_fp8_delayed_scaling_mixed" == precision_recipe and is_fp8_supported():
        pytest.xfail(reason="FP8 delayed scaling is not currently working with Evo2, use another FP8 recipe.")
    if "bf16_with_fp8_subchannel_scaling_mixed" == precision_recipe and is_fp8_supported():
        pytest.xfail(reason="FP8 subchannel scaling is not currently working with Evo2 on some GPUs.")
    run_dir = tmp_path / f"run_tp{tp_size}_pp{pp_size}_cp{cp_size}_dp{dp_size}_rc{dp_rank_check}_pr{precision_recipe}"
    run_dir.mkdir(parents=True, exist_ok=True)
    master_port = find_free_network_port()
    dp_rank_check_str = "--debug-ddp-parity-freq 5" if dp_rank_check else ""
    cmd1 = f"""torchrun --nproc-per-node {world_size} --no-python --master_port {master_port} \
        train_evo2 \
        --hf-tokenizer-model-path {DEFAULT_HF_TOKENIZER_MODEL_PATH} \
        --model-size striped_hyena_1b_nv_parallel --num-layers 4 --hybrid-override-pattern SDH* \
        --max-steps 5 --eval-interval 5 \
        --eval-iters 3 --mock-data --result-dir {run_dir} \
        --micro-batch-size {mbs} --global-batch-size {gbs} --seq-length 512 \
        --tensor-model-parallel {tp_size} \
        --pipeline-model-parallel {pp_size} \
        --context-parallel {cp_size} \
        --mixed-precision-recipe {precision_recipe} \
        --overlap-param-gather \
        --overlap-grad-reduce \
        {dp_rank_check_str} \
        --use-precision-aware-optimizer --dataset-seed 33 \
        --seed 41 --spike-no-more-embedding-init \
        --no-weight-decay-embeddings --cross-entropy-loss-fusion \
        --grad-reduce-in-fp32 \
        --decay-steps 1000 --warmup-steps 10 \
        --eod-pad-in-loss-mask \
        --log-interval 1 \
        """

    # Split the command and run it; NCCL P2P is disabled to avoid crashes on some workstation GPUs.
    cmd_parts = shlex.split(cmd1)
    env = copy.deepcopy(PRETEST_ENV)
    env["NCCL_P2P_DISABLE"] = "1"
    result = subprocess.run(cmd_parts, check=False, capture_output=True, text=True, cwd=run_dir, env=env)

    stdout = result.stdout
    stderr = result.stderr
    returncode = result.returncode

    # For debugging, print the output
    print(f"Return code: {returncode}")
    print(f"STDOUT:\n{stdout}")
    print(f"STDERR:\n{stderr}")

    # Assert the command succeeded
    assert returncode == 0, f"Command failed with return code {returncode}\nSTDERR:\n{stderr}"
    result_dir = run_dir / "evo2"
    ckpt_dir = result_dir / "checkpoints"
    tb_log_dir = result_dir / "tb_logs"
    assert ckpt_dir.exists() and ckpt_dir.is_dir(), "Checkpoints directory not found"
    assert tb_log_dir.exists() and tb_log_dir.is_dir(), "TensorBoard logs directory not found"
    iter_5_dir = ckpt_dir / "iter_0000005"
    assert iter_5_dir.exists() and iter_5_dir.is_dir(), f"No iterations 5 checkpoint found in {ckpt_dir}"
    assert len(list(ckpt_dir.glob("iter_*"))) == 1, f"Expected 1 iterations, found {list(ckpt_dir.glob('iter_*'))}"
    # Load tensorboard logs to verify they were written correctly

    # Find the events file(s) in tb_log_dir
    event_files = list(tb_log_dir.rglob("events.out.*"))
    assert len(event_files) > 0, f"No tensorboard event files found in {tb_log_dir}"

    # Load events from the event files
    event_acc = EventAccumulator(str(tb_log_dir))
    event_acc.Reload()

    # 1. Collect run 1's "lm loss" series and confirm it ran to completion (step 5).
    # Note: EventAccumulator.Scalars returns a list of ScalarEvent(wall_time, step, value)
    lm_loss_events = event_acc.Scalars("lm loss")

    assert len(lm_loss_events) > 0, "No 'lm loss' events found in run 1"
    last_lm_loss_step = lm_loss_events[-1].step

    assert last_lm_loss_step == 5, f"Expected run 1 to end at step 5, but got {last_lm_loss_step}"

    # 2. Run the same training command a second time, this time fine-tuning from the run 1
    #    checkpoint into a fresh result dir. Fine-tuning resets the training state, so this
    #    run also covers steps 1-5. Move run 1's tb_logs aside so the runs' logs stay separate.
    tb_log_dir_run1 = result_dir / "tb_logs_run1"
    if tb_log_dir.exists():
        shutil.move(str(tb_log_dir), str(tb_log_dir_run1))

    # Fresh result dir for the fine-tune run; the model weights come from run 1's checkpoint.
    ft_run_dir = (
        tmp_path / f"ft_run_tp{tp_size}_pp{pp_size}_cp{cp_size}_dp{dp_size}_rc{dp_rank_check}_pr{precision_recipe}"
    )
    ft_run_dir.mkdir(parents=True, exist_ok=True)
    cmd2 = cmd1.replace(f"--result-dir {run_dir}", f"--result-dir {ft_run_dir}")
    # Drop the trailing line-continuation backslash before appending new flags. A bare
    # rstrip() would leave "... 1 \" so the appended " --finetune-ckpt-dir" becomes an
    # escaped-space token (" --finetune-ckpt-dir") after shlex.split, breaking arg parsing.
    cmd2 = cmd2.rstrip().rstrip("\\").rstrip()
    cmd2 += f" --finetune-ckpt-dir {ckpt_dir}"
    cmd_parts_2 = shlex.split(cmd2)

    print("Starting Run 2 (fine-tuning from the run 1 checkpoint)...")
    # NOTE(review): cwd stays run_dir here; output paths are absolute (tmp_path-based) so
    # results still land in ft_run_dir — confirm nothing writes relative to cwd.
    result_2 = subprocess.run(cmd_parts_2, check=False, capture_output=True, text=True, cwd=run_dir, env=env)

    print(f"Run 2 Return code: {result_2.returncode}")
    if result_2.returncode != 0:
        print(f"Run 2 STDERR:\n{result_2.stderr}")

    assert result_2.returncode == 0, f"Run 2 failed with return code {result_2.returncode}"

    # 3. Load run 2's tb logs and sanity check the fine-tuning semantics.
    ft_result_dir = ft_run_dir / "evo2"
    ft_tb_log_dir = ft_result_dir / "tb_logs"
    assert ft_tb_log_dir.exists(), "TensorBoard logs directory not found after Run 2"

    event_acc_2 = EventAccumulator(str(ft_tb_log_dir))
    event_acc_2.Reload()

    lm_loss_events_2 = event_acc_2.Scalars("lm loss")
    assert len(lm_loss_events_2) > 0, "No 'lm loss' events found in run 2"

    first_step_run2 = lm_loss_events_2[0].step
    first_step_run1 = lm_loss_events[0].step
    last_step_run2 = lm_loss_events_2[-1].step

    # Sanity checks:
    # 1. Reset: fine-tuning resets the training state, so run 2 starts at the same step as
    #    run 1 (it does NOT resume at step 6 the way a stop-and-go restart would).
    assert first_step_run2 == first_step_run1, (
        f"Run 2 FT steps should match run 1, but started at {first_step_run2} vs {first_step_run1}"
    )

    # 2. Completion: Should reach step 5 like run 1
    assert last_step_run2 == 5, f"Run 2 should reach step 5, but ended at {last_step_run2}"

    # 3. Loss continuity: run 2 starts from run 1's final weights, so its first loss should be
    #    below run 1's first loss and close to run 1's last loss (small fluctuations allowed).
    first_loss_run1 = lm_loss_events[0].value
    first_loss_run2 = lm_loss_events_2[0].value
    last_loss_run1 = lm_loss_events[-1].value
    assert first_loss_run1 > last_loss_run1, (
        f"Run 1 first loss {first_loss_run1} is less than run 1 last loss {last_loss_run1}"
    )
    assert first_loss_run2 < first_loss_run1, (
        f"Run 2 first loss {first_loss_run2} is greater than run 1 first loss {first_loss_run1}"
    )
    assert abs(first_loss_run2 - first_loss_run1) > abs(last_loss_run1 - first_loss_run2), (
        f"Run 2 beginning {first_loss_run2} should be closer to end of run 1 {last_loss_run1} than beginning {first_loss_run1}."
    )
    assert first_loss_run2 - last_loss_run1 < 0.1, (
        f"Run 2 first loss {first_loss_run2} should exceed run 1 last loss {last_loss_run1} by less than 0.1"
    )