diff --git a/.devcontainer/recipes/Dockerfile b/.devcontainer/recipes/Dockerfile index b244006515..d267d83c66 100644 --- a/.devcontainer/recipes/Dockerfile +++ b/.devcontainer/recipes/Dockerfile @@ -6,6 +6,12 @@ FROM nvcr.io/nvidia/pytorch:26.01-py3 # Remove once bug has been addressed in the nvidia/pytorch container. RUN rm -f /usr/local/lib/python*/dist-packages/transformer_engine-*.dist-info/direct_url.json +RUN --mount=type=cache,target=/var/cache/apt \ + --mount=type=cache,target=/var/lib/apt \ + apt-get update && \ + DEBIAN_FRONTEND=noninteractive apt-get install -y tmux && \ + rm -rf /var/lib/apt/lists/* + RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=requirements.txt,target=/workspace/requirements.txt \ PIP_CONSTRAINT= pip install -r /workspace/requirements.txt diff --git a/bionemo-recipes/models/esm2/tests/common/fixtures.py b/bionemo-recipes/models/esm2/tests/common/fixtures.py index a437aae3d1..000ebd9429 100644 --- a/bionemo-recipes/models/esm2/tests/common/fixtures.py +++ b/bionemo-recipes/models/esm2/tests/common/fixtures.py @@ -102,6 +102,26 @@ def fp8_recipe(request): return request.param +RECIPE_NAME_TO_FACTORY = { + "DelayedScaling": recipe_module.DelayedScaling, + "Float8CurrentScaling": recipe_module.Float8CurrentScaling, + "Float8BlockScaling": recipe_module.Float8BlockScaling, + "MXFP8BlockScaling": recipe_module.MXFP8BlockScaling, + "NVFP4BlockScaling": lambda: recipe_module.NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True), +} + + +def recipe_to_name(recipe): + """Convert a recipe instance to its CLI-passable string name.""" + return type(recipe).__name__ + + +def recipe_from_name(name): + """Reconstruct a recipe instance from its CLI-passable string name.""" + factory = RECIPE_NAME_TO_FACTORY[name] + return factory() + + @pytest.fixture(params=["bshd", "thd"]) def input_format(request): """Fixture to parametrize the input format.""" diff --git 
a/bionemo-recipes/models/esm2/tests/common/run_distributed_dcp.py b/bionemo-recipes/models/esm2/tests/common/run_distributed_dcp.py new file mode 100644 index 0000000000..3294c90ccd --- /dev/null +++ b/bionemo-recipes/models/esm2/tests/common/run_distributed_dcp.py @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Worker script for distributed DCP (Distributed Checkpoint) tests. + +Launched by torchrun from BaseModelTest.test_dcp_output_parity / test_dcp_output_parity_fp8_init. +Verifies that a model sharded with FSDP2 produces identical outputs after a DCP save/load round-trip. 
+""" + +import argparse +import importlib.util +import os +import shutil +import sys +import tempfile +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +import transformer_engine.pytorch +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import fully_shard +from transformers import set_seed + + +def _setup_sys_path(): + """Add model root and tests directory to sys.path so model/test imports work.""" + script_dir = Path(__file__).resolve().parent # tests/common/ + tests_dir = script_dir.parent # tests/ + model_root = tests_dir.parent # model root (e.g., models/esm2/) + for p in [str(model_root), str(tests_dir)]: + if p not in sys.path: + sys.path.insert(0, p) + + +def _load_tester_class(tester_file, class_name): + """Dynamically load a tester class from a file path.""" + # Ensure the tester file's directory tree is importable + tester_dir = str(Path(tester_file).parent) + tester_parent = str(Path(tester_file).parent.parent) + for p in [tester_parent, tester_dir]: + if p not in sys.path: + sys.path.insert(0, p) + + spec = importlib.util.spec_from_file_location("_dcp_tester_module", tester_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return getattr(module, class_name) + + +def _build_and_shard_model(tester, config, recipe, device, device_mesh): + """Build a model (optionally with FP8 quantized_model_init), shard with FSDP2, and move to device.""" + model_class = tester.get_model_class() + + if recipe is not None: + with transformer_engine.pytorch.quantized_model_init(recipe=recipe): + model = model_class(config) + else: + model = model_class(config) + + # Shard each transformer layer, then the root model + for layer in tester.get_layer_path(model): + fully_shard(layer, mesh=device_mesh) + fully_shard(model, mesh=device_mesh) + + model.to(device) + return model + + +def _forward(model, input_data, recipe): + """Run a 
forward pass and return the model outputs.""" + if recipe is not None: + # torch.autocast is needed when model was built with quantized_model_init + # (weights are FP8, non-quantized ops need bf16 casting) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with transformer_engine.pytorch.autocast(recipe=recipe): + return model(**input_data) + else: + return model(**input_data) + + +def _train_one_step(model, input_data, recipe, lr=1e-4): + """Run a single training step (forward + backward + optimizer step) and return detached logits.""" + model.train() + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + optimizer.zero_grad() + + outputs = _forward(model, input_data, recipe) + loss = outputs.logits.sum() + loss.backward() + optimizer.step() + + return outputs.logits.detach().clone() + + +def _run_eval_forward(model, input_data, recipe): + """Run an eval forward pass and return detached logits.""" + model.eval() + with torch.no_grad(): + outputs = _forward(model, input_data, recipe) + return outputs.logits.detach().clone() + + +def run_dcp_output_parity(tester, fp8_recipe_name=None, seed=42): + """Core DCP round-trip test: build → train → save → rebuild → load → eval → compare.""" + from tests.common.fixtures import recipe_from_name + + rank = dist.get_rank() + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = dist.get_world_size() + device = f"cuda:{local_rank}" + torch.cuda.set_device(local_rank) + + device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,)) + + # Resolve FP8 recipe + recipe = recipe_from_name(fp8_recipe_name) if fp8_recipe_name else None + + # Build config + set_seed(seed) + config = tester.create_test_config(dtype=torch.bfloat16, attn_input_format="bshd") + + # Prepare input data + input_data = tester.get_test_input_data("bshd", pad_to_multiple_of=32) + + # --- Model A: build, shard, train one step, then eval --- + set_seed(seed) + model_a = _build_and_shard_model(tester, config, recipe, device, device_mesh) 
+ _train_one_step(model_a, input_data, recipe) + logits_a = _run_eval_forward(model_a, input_data, recipe) + + # --- DCP Save --- + # Rank 0 creates temp dir, broadcast path to all ranks + if rank == 0: + tmp_dir = tempfile.mkdtemp(prefix="dcp_test_") + else: + tmp_dir = None + tmp_dir_list = [tmp_dir] + dist.broadcast_object_list(tmp_dir_list, src=0) + tmp_dir = tmp_dir_list[0] + + checkpoint_path = os.path.join(tmp_dir, "checkpoint") + + state_dict_a = {"model": model_a.state_dict()} + dcp.save(state_dict_a, checkpoint_id=checkpoint_path) + + dist.barrier() + + # Free model_a + del model_a, state_dict_a + torch.cuda.empty_cache() + + # --- Model B: build fresh, shard, load, eval --- + set_seed(seed) + model_b = _build_and_shard_model(tester, config, recipe, device, device_mesh) + + state_dict_b = {"model": model_b.state_dict()} + dcp.load(state_dict_b, checkpoint_id=checkpoint_path) + model_b.load_state_dict(state_dict_b["model"], strict=False) + + logits_b = _run_eval_forward(model_b, input_data, recipe) + + # --- Compare --- + tolerances = tester.get_tolerances() + torch.testing.assert_close( + logits_a, + logits_b, + atol=tolerances.dcp_logits_atol, + rtol=tolerances.dcp_logits_rtol, + msg=lambda x: f"DCP round-trip logits mismatch: {x}", + ) + + # Cleanup + del model_b, state_dict_b + torch.cuda.empty_cache() + dist.barrier() + + if rank == 0: + shutil.rmtree(tmp_dir, ignore_errors=True) + + print(f"[Rank {rank}] DCP output parity test PASSED (fp8_recipe={fp8_recipe_name})") + + +if __name__ == "__main__": + _setup_sys_path() + + parser = argparse.ArgumentParser(description="DCP distributed test worker") + parser.add_argument( + "--tester-file", required=True, help="Absolute path to the test file containing the tester class" + ) + parser.add_argument("--tester-class", required=True, help="Name of the tester class (e.g., TestESM2Model)") + parser.add_argument("--fp8-recipe", default=None, help="FP8 recipe name (e.g., DelayedScaling)") + 
parser.add_argument("--seed", type=int, default=42, help="Random seed") + args = parser.parse_args() + + dist.init_process_group(backend="nccl") + + try: + tester_cls = _load_tester_class(args.tester_file, args.tester_class) + tester = tester_cls() + run_dcp_output_parity(tester, fp8_recipe_name=args.fp8_recipe, seed=args.seed) + finally: + dist.destroy_process_group() diff --git a/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py b/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py index aca85b7855..fd47da666d 100644 --- a/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py +++ b/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py @@ -17,6 +17,8 @@ import fnmatch import gc +import os +import subprocess from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path @@ -39,6 +41,12 @@ HAS_DATA_CENTER_GPU = False +_requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + + @dataclass class TestTolerances: """Model-specific test tolerances for numerical comparisons.""" @@ -65,6 +73,10 @@ class TestTolerances: fp8_logits_atol: float = 5.0 fp8_logits_rtol: float = 0.1 + # DCP (distributed checkpoint) round-trip tolerances + dcp_logits_atol: float = 0.0 + dcp_logits_rtol: float = 0.0 + # Meta device initialization tolerances init_mean_atol: float = 1e-3 init_mean_rtol: float = 1e-4 @@ -979,4 +991,69 @@ def test_meta_fp8_init(self, fp8_recipe): model.init_empty_weights() self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True) - # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc. 
+ # ==================== Distributed Checkpoint (DCP) Tests ==================== + + def _get_dcp_worker_script_path(self) -> str: + """Return the absolute path to the run_distributed_dcp.py worker script.""" + return str(Path(__file__).resolve().parent / "run_distributed_dcp.py") + + def _get_tester_file_and_class(self): + """Return (file_path, class_name) for dynamic loading in the worker subprocess.""" + import inspect + + return os.path.abspath(inspect.getfile(type(self))), type(self).__name__ + + def _run_dcp_worker(self, unused_tcp_port, fp8_recipe_name=None, nproc_per_node=2): + """Launch the DCP worker script via torchrun and assert it succeeds.""" + tester_file, class_name = self._get_tester_file_and_class() + worker_script = self._get_dcp_worker_script_path() + + cmd = [ + "torchrun", + f"--nproc_per_node={nproc_per_node}", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{unused_tcp_port}", + worker_script, + "--tester-file", + tester_file, + "--tester-class", + class_name, + ] + + if fp8_recipe_name is not None: + cmd.extend(["--fp8-recipe", fp8_recipe_name]) + + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=300, + ) + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"DCP worker failed with exit code {result.returncode}") + + def test_dcp_output_parity_single_gpu(self, unused_tcp_port): + """Test FSDP2 + DCP save/load round-trip on a single GPU.""" + self._run_dcp_worker(unused_tcp_port, nproc_per_node=1) + + def test_dcp_output_parity_fp8_init_single_gpu(self, fp8_recipe, unused_tcp_port): + """Test FSDP2 + DCP save/load with FP8 quantized_model_init on a single GPU.""" + from .fixtures import recipe_to_name + + self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_to_name(fp8_recipe), nproc_per_node=1) + + @_requires_multi_gpu + def test_dcp_output_parity(self, unused_tcp_port): + """Test that a 
model sharded with FSDP2 produces identical outputs after DCP save/load.""" + self._run_dcp_worker(unused_tcp_port) + + @_requires_multi_gpu + def test_dcp_output_parity_fp8_init(self, fp8_recipe, unused_tcp_port): + """Test DCP save/load with FP8 quantized_model_init.""" + from .fixtures import recipe_to_name + + self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_to_name(fp8_recipe)) diff --git a/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py b/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py index 57fd3f9fb0..f0b36127c2 100644 --- a/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py +++ b/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py @@ -34,9 +34,7 @@ def requires_fp8(func): ) -@pytest.mark.parametrize( - "strategy", ["ddp", "fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2999"))] -) +@pytest.mark.parametrize("strategy", ["ddp", "fsdp2", "mfsdp"]) @requires_fp8 def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port): cmd = [ @@ -63,9 +61,7 @@ def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port): pytest.fail(f"Command failed with exit code {result.returncode}") -@pytest.mark.parametrize( - "strategy", ["ddp", "fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2999"))] -) +@pytest.mark.parametrize("strategy", ["ddp", "fsdp2", "mfsdp"]) @requires_fp8 @requires_multi_gpu def test_multi_process_fp8_recipes_are_synced(strategy, unused_tcp_port): diff --git a/bionemo-recipes/models/llama3/tests/common/fixtures.py b/bionemo-recipes/models/llama3/tests/common/fixtures.py index a437aae3d1..000ebd9429 100644 --- a/bionemo-recipes/models/llama3/tests/common/fixtures.py +++ b/bionemo-recipes/models/llama3/tests/common/fixtures.py @@ -102,6 +102,26 @@ def fp8_recipe(request): return request.param +RECIPE_NAME_TO_FACTORY = { + "DelayedScaling": recipe_module.DelayedScaling, + "Float8CurrentScaling": recipe_module.Float8CurrentScaling, + 
"Float8BlockScaling": recipe_module.Float8BlockScaling, + "MXFP8BlockScaling": recipe_module.MXFP8BlockScaling, + "NVFP4BlockScaling": lambda: recipe_module.NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True), +} + + +def recipe_to_name(recipe): + """Convert a recipe instance to its CLI-passable string name.""" + return type(recipe).__name__ + + +def recipe_from_name(name): + """Reconstruct a recipe instance from its CLI-passable string name.""" + factory = RECIPE_NAME_TO_FACTORY[name] + return factory() + + @pytest.fixture(params=["bshd", "thd"]) def input_format(request): """Fixture to parametrize the input format.""" diff --git a/bionemo-recipes/models/llama3/tests/common/run_distributed_dcp.py b/bionemo-recipes/models/llama3/tests/common/run_distributed_dcp.py new file mode 100644 index 0000000000..51f5a2d399 --- /dev/null +++ b/bionemo-recipes/models/llama3/tests/common/run_distributed_dcp.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Worker script for distributed DCP (Distributed Checkpoint) tests. + +Launched by torchrun from BaseModelTest.test_dcp_output_parity / test_dcp_output_parity_fp8_init. +Verifies that a model sharded with FSDP2 produces identical outputs after a DCP save/load round-trip. 
+""" + +import argparse +import importlib.util +import os +import shutil +import sys +import tempfile +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +import transformer_engine.pytorch +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import fully_shard +from transformers import set_seed + + +def _setup_sys_path(): + """Add model root and tests directory to sys.path so model/test imports work.""" + script_dir = Path(__file__).resolve().parent # tests/common/ + tests_dir = script_dir.parent # tests/ + model_root = tests_dir.parent # model root (e.g., models/esm2/) + for p in [str(model_root), str(tests_dir)]: + if p not in sys.path: + sys.path.insert(0, p) + + +def _load_tester_class(tester_file, class_name): + """Dynamically load a tester class from a file path.""" + # Ensure the tester file's directory tree is importable + tester_dir = str(Path(tester_file).parent) + tester_parent = str(Path(tester_file).parent.parent) + for p in [tester_parent, tester_dir]: + if p not in sys.path: + sys.path.insert(0, p) + + spec = importlib.util.spec_from_file_location("_dcp_tester_module", tester_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return getattr(module, class_name) + + +def _build_and_shard_model(tester, config, recipe, device, device_mesh): + """Build a model (optionally with FP8 quantized_model_init), shard with FSDP2, and move to device.""" + model_class = tester.get_model_class() + + if recipe is not None: + with transformer_engine.pytorch.quantized_model_init(recipe=recipe): + model = model_class(config) + else: + model = model_class(config) + + # Shard each transformer layer, then the root model + for layer in tester.get_layer_path(model): + fully_shard(layer, mesh=device_mesh) + fully_shard(model, mesh=device_mesh) + + model.to(device) + return model + + +def _run_eval_forward(model, input_data, recipe): + 
"""Run an eval forward pass and return detached logits.""" + model.eval() + with torch.no_grad(): + if recipe is not None: + # torch.autocast is needed when model was built with quantized_model_init + # (weights are FP8, non-quantized ops need bf16 casting) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with transformer_engine.pytorch.autocast(recipe=recipe): + outputs = model(**input_data) + else: + outputs = model(**input_data) + return outputs.logits.detach().clone() + + +def run_dcp_output_parity(tester, fp8_recipe_name=None, seed=42): + """Core DCP round-trip test: build → forward → save → rebuild → load → forward → compare.""" + from tests.common.fixtures import recipe_from_name + + rank = dist.get_rank() + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = dist.get_world_size() + device = f"cuda:{local_rank}" + torch.cuda.set_device(local_rank) + + device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,)) + + # Resolve FP8 recipe + recipe = recipe_from_name(fp8_recipe_name) if fp8_recipe_name else None + + # Build config + set_seed(seed) + config = tester.create_test_config(dtype=torch.bfloat16, attn_input_format="bshd") + + # Prepare input data + input_data = tester.get_test_input_data("bshd", pad_to_multiple_of=32) + + # --- Model A: build, shard, forward --- + set_seed(seed) + model_a = _build_and_shard_model(tester, config, recipe, device, device_mesh) + logits_a = _run_eval_forward(model_a, input_data, recipe) + + # --- DCP Save --- + # Rank 0 creates temp dir, broadcast path to all ranks + if rank == 0: + tmp_dir = tempfile.mkdtemp(prefix="dcp_test_") + else: + tmp_dir = None + tmp_dir_list = [tmp_dir] + dist.broadcast_object_list(tmp_dir_list, src=0) + tmp_dir = tmp_dir_list[0] + + checkpoint_path = os.path.join(tmp_dir, "checkpoint") + + state_dict_a = {"model": model_a.state_dict()} + dcp.save(state_dict_a, checkpoint_id=checkpoint_path) + + dist.barrier() + + # Free model_a + del model_a, state_dict_a + 
torch.cuda.empty_cache() + + # --- Model B: build fresh, shard, load, forward --- + set_seed(seed) + model_b = _build_and_shard_model(tester, config, recipe, device, device_mesh) + + state_dict_b = {"model": model_b.state_dict()} + dcp.load(state_dict_b, checkpoint_id=checkpoint_path) + model_b.load_state_dict(state_dict_b["model"], strict=False) + + logits_b = _run_eval_forward(model_b, input_data, recipe) + + # --- Compare --- + tolerances = tester.get_tolerances() + torch.testing.assert_close( + logits_a, + logits_b, + atol=tolerances.dcp_logits_atol, + rtol=tolerances.dcp_logits_rtol, + msg=lambda x: f"DCP round-trip logits mismatch: {x}", + ) + + # Cleanup + del model_b, state_dict_b + torch.cuda.empty_cache() + dist.barrier() + + if rank == 0: + shutil.rmtree(tmp_dir, ignore_errors=True) + + print(f"[Rank {rank}] DCP output parity test PASSED (fp8_recipe={fp8_recipe_name})") + + +if __name__ == "__main__": + _setup_sys_path() + + parser = argparse.ArgumentParser(description="DCP distributed test worker") + parser.add_argument( + "--tester-file", required=True, help="Absolute path to the test file containing the tester class" + ) + parser.add_argument("--tester-class", required=True, help="Name of the tester class (e.g., TestESM2Model)") + parser.add_argument("--fp8-recipe", default=None, help="FP8 recipe name (e.g., DelayedScaling)") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + args = parser.parse_args() + + dist.init_process_group(backend="nccl") + + try: + tester_cls = _load_tester_class(args.tester_file, args.tester_class) + tester = tester_cls() + run_dcp_output_parity(tester, fp8_recipe_name=args.fp8_recipe, seed=args.seed) + finally: + dist.destroy_process_group() diff --git a/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py b/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py index aca85b7855..fd47da666d 100644 --- a/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py +++ 
b/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py @@ -17,6 +17,8 @@ import fnmatch import gc +import os +import subprocess from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path @@ -39,6 +41,12 @@ HAS_DATA_CENTER_GPU = False +_requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + + @dataclass class TestTolerances: """Model-specific test tolerances for numerical comparisons.""" @@ -65,6 +73,10 @@ class TestTolerances: fp8_logits_atol: float = 5.0 fp8_logits_rtol: float = 0.1 + # DCP (distributed checkpoint) round-trip tolerances + dcp_logits_atol: float = 0.0 + dcp_logits_rtol: float = 0.0 + # Meta device initialization tolerances init_mean_atol: float = 1e-3 init_mean_rtol: float = 1e-4 @@ -979,4 +991,69 @@ def test_meta_fp8_init(self, fp8_recipe): model.init_empty_weights() self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True) - # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc. 
+ # ==================== Distributed Checkpoint (DCP) Tests ==================== + + def _get_dcp_worker_script_path(self) -> str: + """Return the absolute path to the run_distributed_dcp.py worker script.""" + return str(Path(__file__).resolve().parent / "run_distributed_dcp.py") + + def _get_tester_file_and_class(self): + """Return (file_path, class_name) for dynamic loading in the worker subprocess.""" + import inspect + + return os.path.abspath(inspect.getfile(type(self))), type(self).__name__ + + def _run_dcp_worker(self, unused_tcp_port, fp8_recipe_name=None, nproc_per_node=2): + """Launch the DCP worker script via torchrun and assert it succeeds.""" + tester_file, class_name = self._get_tester_file_and_class() + worker_script = self._get_dcp_worker_script_path() + + cmd = [ + "torchrun", + f"--nproc_per_node={nproc_per_node}", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{unused_tcp_port}", + worker_script, + "--tester-file", + tester_file, + "--tester-class", + class_name, + ] + + if fp8_recipe_name is not None: + cmd.extend(["--fp8-recipe", fp8_recipe_name]) + + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=300, + ) + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"DCP worker failed with exit code {result.returncode}") + + def test_dcp_output_parity_single_gpu(self, unused_tcp_port): + """Test FSDP2 + DCP save/load round-trip on a single GPU.""" + self._run_dcp_worker(unused_tcp_port, nproc_per_node=1) + + def test_dcp_output_parity_fp8_init_single_gpu(self, fp8_recipe, unused_tcp_port): + """Test FSDP2 + DCP save/load with FP8 quantized_model_init on a single GPU.""" + from .fixtures import recipe_to_name + + self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_to_name(fp8_recipe), nproc_per_node=1) + + @_requires_multi_gpu + def test_dcp_output_parity(self, unused_tcp_port): + """Test that a 
model sharded with FSDP2 produces identical outputs after DCP save/load.""" + self._run_dcp_worker(unused_tcp_port) + + @_requires_multi_gpu + def test_dcp_output_parity_fp8_init(self, fp8_recipe, unused_tcp_port): + """Test DCP save/load with FP8 quantized_model_init.""" + from .fixtures import recipe_to_name + + self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_to_name(fp8_recipe)) diff --git a/bionemo-recipes/models/mixtral/tests/common/fixtures.py b/bionemo-recipes/models/mixtral/tests/common/fixtures.py index a437aae3d1..000ebd9429 100644 --- a/bionemo-recipes/models/mixtral/tests/common/fixtures.py +++ b/bionemo-recipes/models/mixtral/tests/common/fixtures.py @@ -102,6 +102,26 @@ def fp8_recipe(request): return request.param +RECIPE_NAME_TO_FACTORY = { + "DelayedScaling": recipe_module.DelayedScaling, + "Float8CurrentScaling": recipe_module.Float8CurrentScaling, + "Float8BlockScaling": recipe_module.Float8BlockScaling, + "MXFP8BlockScaling": recipe_module.MXFP8BlockScaling, + "NVFP4BlockScaling": lambda: recipe_module.NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True), +} + + +def recipe_to_name(recipe): + """Convert a recipe instance to its CLI-passable string name.""" + return type(recipe).__name__ + + +def recipe_from_name(name): + """Reconstruct a recipe instance from its CLI-passable string name.""" + factory = RECIPE_NAME_TO_FACTORY[name] + return factory() + + @pytest.fixture(params=["bshd", "thd"]) def input_format(request): """Fixture to parametrize the input format.""" diff --git a/bionemo-recipes/models/mixtral/tests/common/run_distributed_dcp.py b/bionemo-recipes/models/mixtral/tests/common/run_distributed_dcp.py new file mode 100644 index 0000000000..51f5a2d399 --- /dev/null +++ b/bionemo-recipes/models/mixtral/tests/common/run_distributed_dcp.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Worker script for distributed DCP (Distributed Checkpoint) tests. + +Launched by torchrun from BaseModelTest.test_dcp_output_parity / test_dcp_output_parity_fp8_init. +Verifies that a model sharded with FSDP2 produces identical outputs after a DCP save/load round-trip. +""" + +import argparse +import importlib.util +import os +import shutil +import sys +import tempfile +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +import transformer_engine.pytorch +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import fully_shard +from transformers import set_seed + + +def _setup_sys_path(): + """Add model root and tests directory to sys.path so model/test imports work.""" + script_dir = Path(__file__).resolve().parent # tests/common/ + tests_dir = script_dir.parent # tests/ + model_root = tests_dir.parent # model root (e.g., models/esm2/) + for p in [str(model_root), str(tests_dir)]: + if p not in sys.path: + sys.path.insert(0, p) + + +def _load_tester_class(tester_file, class_name): + """Dynamically load a tester class from a file path.""" + # Ensure the tester file's directory tree is importable + tester_dir = str(Path(tester_file).parent) + tester_parent = str(Path(tester_file).parent.parent) + for p in [tester_parent, tester_dir]: + if p not in sys.path: + 
sys.path.insert(0, p) + + spec = importlib.util.spec_from_file_location("_dcp_tester_module", tester_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return getattr(module, class_name) + + +def _build_and_shard_model(tester, config, recipe, device, device_mesh): + """Build a model (optionally with FP8 quantized_model_init), shard with FSDP2, and move to device.""" + model_class = tester.get_model_class() + + if recipe is not None: + with transformer_engine.pytorch.quantized_model_init(recipe=recipe): + model = model_class(config) + else: + model = model_class(config) + + # Shard each transformer layer, then the root model + for layer in tester.get_layer_path(model): + fully_shard(layer, mesh=device_mesh) + fully_shard(model, mesh=device_mesh) + + model.to(device) + return model + + +def _run_eval_forward(model, input_data, recipe): + """Run an eval forward pass and return detached logits.""" + model.eval() + with torch.no_grad(): + if recipe is not None: + # torch.autocast is needed when model was built with quantized_model_init + # (weights are FP8, non-quantized ops need bf16 casting) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with transformer_engine.pytorch.autocast(recipe=recipe): + outputs = model(**input_data) + else: + outputs = model(**input_data) + return outputs.logits.detach().clone() + + +def run_dcp_output_parity(tester, fp8_recipe_name=None, seed=42): + """Core DCP round-trip test: build → forward → save → rebuild → load → forward → compare.""" + from tests.common.fixtures import recipe_from_name + + rank = dist.get_rank() + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = dist.get_world_size() + device = f"cuda:{local_rank}" + torch.cuda.set_device(local_rank) + + device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,)) + + # Resolve FP8 recipe + recipe = recipe_from_name(fp8_recipe_name) if fp8_recipe_name else None + + # Build config + set_seed(seed) + config = 
tester.create_test_config(dtype=torch.bfloat16, attn_input_format="bshd") + + # Prepare input data + input_data = tester.get_test_input_data("bshd", pad_to_multiple_of=32) + + # --- Model A: build, shard, forward --- + set_seed(seed) + model_a = _build_and_shard_model(tester, config, recipe, device, device_mesh) + logits_a = _run_eval_forward(model_a, input_data, recipe) + + # --- DCP Save --- + # Rank 0 creates temp dir, broadcast path to all ranks + if rank == 0: + tmp_dir = tempfile.mkdtemp(prefix="dcp_test_") + else: + tmp_dir = None + tmp_dir_list = [tmp_dir] + dist.broadcast_object_list(tmp_dir_list, src=0) + tmp_dir = tmp_dir_list[0] + + checkpoint_path = os.path.join(tmp_dir, "checkpoint") + + state_dict_a = {"model": model_a.state_dict()} + dcp.save(state_dict_a, checkpoint_id=checkpoint_path) + + dist.barrier() + + # Free model_a + del model_a, state_dict_a + torch.cuda.empty_cache() + + # --- Model B: build fresh, shard, load, forward --- + set_seed(seed) + model_b = _build_and_shard_model(tester, config, recipe, device, device_mesh) + + state_dict_b = {"model": model_b.state_dict()} + dcp.load(state_dict_b, checkpoint_id=checkpoint_path) + model_b.load_state_dict(state_dict_b["model"], strict=False) + + logits_b = _run_eval_forward(model_b, input_data, recipe) + + # --- Compare --- + tolerances = tester.get_tolerances() + torch.testing.assert_close( + logits_a, + logits_b, + atol=tolerances.dcp_logits_atol, + rtol=tolerances.dcp_logits_rtol, + msg=lambda x: f"DCP round-trip logits mismatch: {x}", + ) + + # Cleanup + del model_b, state_dict_b + torch.cuda.empty_cache() + dist.barrier() + + if rank == 0: + shutil.rmtree(tmp_dir, ignore_errors=True) + + print(f"[Rank {rank}] DCP output parity test PASSED (fp8_recipe={fp8_recipe_name})") + + +if __name__ == "__main__": + _setup_sys_path() + + parser = argparse.ArgumentParser(description="DCP distributed test worker") + parser.add_argument( + "--tester-file", required=True, help="Absolute path to the test 
file containing the tester class" + ) + parser.add_argument("--tester-class", required=True, help="Name of the tester class (e.g., TestESM2Model)") + parser.add_argument("--fp8-recipe", default=None, help="FP8 recipe name (e.g., DelayedScaling)") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + args = parser.parse_args() + + dist.init_process_group(backend="nccl") + + try: + tester_cls = _load_tester_class(args.tester_file, args.tester_class) + tester = tester_cls() + run_dcp_output_parity(tester, fp8_recipe_name=args.fp8_recipe, seed=args.seed) + finally: + dist.destroy_process_group() diff --git a/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py b/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py index aca85b7855..fd47da666d 100644 --- a/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py +++ b/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py @@ -17,6 +17,8 @@ import fnmatch import gc +import os +import subprocess from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path @@ -39,6 +41,12 @@ HAS_DATA_CENTER_GPU = False +_requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + + @dataclass class TestTolerances: """Model-specific test tolerances for numerical comparisons.""" @@ -65,6 +73,10 @@ class TestTolerances: fp8_logits_atol: float = 5.0 fp8_logits_rtol: float = 0.1 + # DCP (distributed checkpoint) round-trip tolerances + dcp_logits_atol: float = 0.0 + dcp_logits_rtol: float = 0.0 + # Meta device initialization tolerances init_mean_atol: float = 1e-3 init_mean_rtol: float = 1e-4 @@ -979,4 +991,69 @@ def test_meta_fp8_init(self, fp8_recipe): model.init_empty_weights() self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True) - # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc. 
+ # ==================== Distributed Checkpoint (DCP) Tests ==================== + + def _get_dcp_worker_script_path(self) -> str: + """Return the absolute path to the run_distributed_dcp.py worker script.""" + return str(Path(__file__).resolve().parent / "run_distributed_dcp.py") + + def _get_tester_file_and_class(self): + """Return (file_path, class_name) for dynamic loading in the worker subprocess.""" + import inspect + + return os.path.abspath(inspect.getfile(type(self))), type(self).__name__ + + def _run_dcp_worker(self, unused_tcp_port, fp8_recipe_name=None, nproc_per_node=2): + """Launch the DCP worker script via torchrun and assert it succeeds.""" + tester_file, class_name = self._get_tester_file_and_class() + worker_script = self._get_dcp_worker_script_path() + + cmd = [ + "torchrun", + f"--nproc_per_node={nproc_per_node}", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{unused_tcp_port}", + worker_script, + "--tester-file", + tester_file, + "--tester-class", + class_name, + ] + + if fp8_recipe_name is not None: + cmd.extend(["--fp8-recipe", fp8_recipe_name]) + + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=300, + ) + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"DCP worker failed with exit code {result.returncode}") + + def test_dcp_output_parity_single_gpu(self, unused_tcp_port): + """Test FSDP2 + DCP save/load round-trip on a single GPU.""" + self._run_dcp_worker(unused_tcp_port, nproc_per_node=1) + + def test_dcp_output_parity_fp8_init_single_gpu(self, fp8_recipe, unused_tcp_port): + """Test FSDP2 + DCP save/load with FP8 quantized_model_init on a single GPU.""" + from .fixtures import recipe_to_name + + self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_to_name(fp8_recipe), nproc_per_node=1) + + @_requires_multi_gpu + def test_dcp_output_parity(self, unused_tcp_port): + """Test that a 
model sharded with FSDP2 produces identical outputs after DCP save/load.""" + self._run_dcp_worker(unused_tcp_port) + + @_requires_multi_gpu + def test_dcp_output_parity_fp8_init(self, fp8_recipe, unused_tcp_port): + """Test DCP save/load with FP8 quantized_model_init.""" + from .fixtures import recipe_to_name + + self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_to_name(fp8_recipe)) diff --git a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py index 9983191a70..f21c334d01 100644 --- a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py +++ b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py @@ -34,7 +34,9 @@ from torch.distributed.checkpoint.state_dict_saver import async_save as dcp_async_save from torch.distributed.checkpoint.state_dict_saver import save as dcp_save from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.tensor import DTensor from torchdata.stateful_dataloader import StatefulDataLoader +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from distributed_config import DistributedConfig @@ -117,8 +119,20 @@ def load_checkpoint_ddp( ckpt_path: str | os.PathLike, dist_config: DistributedConfig, dataloader: StatefulDataLoader | None = None, + weights_only: bool = True, ) -> CheckpointOutput: - """Load DDP checkpoint.""" + """Load DDP checkpoint. + + Args: + model: The model to load. + optimizer: The optimizer to load. + scheduler: The LR scheduler to load. + ckpt_path: The path to the checkpoint. + dist_config: The distributed configuration. + dataloader: The dataloader to load. + weights_only: Whether to restrict ``torch.load`` to plain tensors/containers. We have to set this to False when loading FP8 + checkpoints, since they contain TE ``Float8Tensor`` objects that ``weights_only=True`` unpickling rejects.
+ """ checkpoint_path, _ = get_latest_checkpoint(ckpt_path) if not checkpoint_path: @@ -128,7 +142,7 @@ def load_checkpoint_ddp( checkpoint = torch.load( checkpoint_path / "checkpoint.pt", map_location=f"cuda:{dist_config.local_rank}", - weights_only=True, + weights_only=weights_only, ) model.load_state_dict(checkpoint["model"]) @@ -223,6 +237,7 @@ class AppState(Stateful): def state_dict(self): """Get the state dict for the model, optimizer, scheduler, and step.""" model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer) + model_state_dict = {k: v for k, v in model_state_dict.items() if not k.endswith("_extra_state")} return { "model": model_state_dict, "optim": optimizer_state_dict, @@ -238,6 +253,7 @@ def load_state_dict(self, state_dict: dict): self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"], + options=StateDictOptions(strict=False), ) self.scheduler.load_state_dict(state_dict["scheduler"]) self.step = state_dict["step"] @@ -324,6 +340,13 @@ def save_checkpoint_fsdp2( checkpoint_path = ckpt_path / f"step_{step}" checkpoint_path.mkdir(parents=True, exist_ok=True) + model_params = (p.to_local() if isinstance(p, DTensor) else p for p in model.parameters()) + if async_save and any((isinstance(p, Float8Tensor) for p in model_params)): + logger.warning( + "Async checkpointing is not supported for FP8 models, falling back to synchronous checkpointing." + ) + async_save = False + if dataloader is not None: save_dataloader( dataloader=dataloader, diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py index 3d6f63f256..726eb19e8e 100644 --- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py @@ -44,7 +44,7 @@ class PerfLogger: min_loss: The minimum loss seen so far. 
""" - def __init__(self, dist_config: DistributedConfig, args: DictConfig): + def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step: int): """Initialize the logger.""" self._dist_config = dist_config self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True) @@ -75,7 +75,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig): if self._dist_config.is_main_process(): # Log the entire args object to wandb for experiment tracking and reproducibility. self._wandb_run = wandb.init(**args.wandb, config=self._run_config) - self._progress_bar = tqdm(total=args.num_train_steps, desc="Training") + self._progress_bar = tqdm(initial=start_step, total=args.num_train_steps, desc="Training") if args.profiler.enabled: self._profiler = NsightProfiler( diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py b/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py index 87ba309ad7..08330b12f7 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py @@ -19,6 +19,7 @@ import pytest import torch +from transformer_engine.pytorch import fp8 as te_fp8 sys.path.append(Path(__file__).parent.parent.as_posix()) @@ -61,6 +62,56 @@ def pytest_collection_modifyitems(items): items[:] = stats_tests + other_tests +# --------------------------------------------------------------------------- +# FP8 recipe parametrization +# --------------------------------------------------------------------------- + +# Each entry: (recipe_class_name, hydra_overrides, check_fn) +_FP8_RECIPE_CONFIGS = [ + ( + "DelayedScaling", + ["fp8_config.fp8_recipe=transformer_engine.common.recipe.DelayedScaling"], + te_fp8.check_fp8_support, + ), + ( + "Float8CurrentScaling", + ["fp8_config.fp8_recipe=transformer_engine.common.recipe.Float8CurrentScaling"], + te_fp8.check_fp8_support, + ), + ( + "Float8BlockScaling", + 
["fp8_config.fp8_recipe=transformer_engine.common.recipe.Float8BlockScaling"], + te_fp8.check_fp8_block_scaling_support, + ), + ( + "MXFP8BlockScaling", + ["fp8_config.fp8_recipe=transformer_engine.common.recipe.MXFP8BlockScaling"], + te_fp8.check_mxfp8_support, + ), +] + + +def _parametrize_fp8_recipes(): + """Generate pytest.param objects with xfail marks for unsupported FP8 recipes.""" + params = [] + for name, overrides, check_fn in _FP8_RECIPE_CONFIGS: + supported, reason = check_fn() + params.append( + pytest.param( + overrides, + id=name, + marks=pytest.mark.xfail(condition=not supported, reason=reason), + ) + ) + return params + + +@pytest.fixture(params=_parametrize_fp8_recipes()) +def fp_recipe(request): + """Parametrized fixture providing FP8 recipe Hydra overrides for each supported TE recipe.""" + return request.param + + @pytest.fixture(scope="session", autouse=True) def device_mesh(): """Create a re-usable torch process group for testing. diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py index 098223291d..8b63d4cacb 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_distributed_checkpointing.py @@ -19,11 +19,13 @@ distributed training configurations: - DDP (Distributed Data Parallel) with 1 and 2 processes - FSDP2 (PyTorch native Fully Sharded Data Parallel v2) with 1 and 2 processes +- FSDP2 with context parallelism +- FP8 quantized model init with checkpoint save/resume Test Strategy: 1. Phase 1: Train for N steps and save checkpoint 2. Phase 2: Resume training from checkpoint and continue -3. Validate: Checkpoints created, resuming works, training continues seamlessly +3. 
Validate: Checkpoints created, resuming works, training continues seamlessly, losses are valid Each test uses temporary directories and disables wandb logging for isolation. """ @@ -50,965 +52,390 @@ ) -def test_checkpoint_save_and_load_single_process_ddp(recipe_path, tmp_path): - """Test checkpoint save/resume functionality for DDP with single process. - - This test validates: - - DDP creates single-file checkpoints (checkpoint.pt files) - - Standard PyTorch checkpoint format (model + optimizer state) - - Single-process DDP training and resuming works correctly - - Checkpoint files contain complete model state - - Process: - 1. Train 10 steps (0-9), save checkpoint file at step 5 - 2. Resume training from checkpoint, continue to step 15 - 3. Verify checkpoint files exist at steps 5 and 10 - """ - temp_dir = str(tmp_path / "test_ckpt_ddp") +# --------------------------------------------------------------------------- +# Test Utilities +# --------------------------------------------------------------------------- - # Phase 1: Train for 10 steps, saving a checkpoint at step 5 - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase1_config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=10", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=false", # Start fresh - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ], - ) - - main_ddp(phase1_config) - gc.collect() - torch.cuda.empty_cache() - # Phase 1 creates this directory structure: - # ckpt_subdir/ - # └── step_5/ - # ├── checkpoint.pt - # └── dataloader_step_5_rank_0_num_workers_1.pt +def _compose_config(recipe_path, tmp_path, config_name, overrides): + """Compose a Hydra config with standard checkpoint-test settings. 
- # Checkpoints are saved in a subdirectory named after the script - ckpt_subdir = os.path.join(temp_dir, "train_ddp") - assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" + Every config gets ``checkpoint.ckpt_dir``, ``+wandb.dir``, and + ``dataset.use_stateful_dataloader`` set automatically so that callers + only need to supply test-specific overrides. + """ + ckpt_dir = str(tmp_path / "ckpt") + base = [ + f"checkpoint.ckpt_dir={ckpt_dir}", + f"+wandb.dir={tmp_path}", + "dataset.use_stateful_dataloader=true", + ] + with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): + return compose(config_name=config_name, overrides=base + list(overrides or [])) - # Verify step_5 checkpoint was created - step_5_dir = os.path.join(ckpt_subdir, "step_5") - # Check step_5 directory exists and contains expected files - assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" - step_5_files = os.listdir(step_5_dir) - assert len(step_5_files) == 2, f"Expected 2 files in step_5 directory, found {len(step_5_files)}: {step_5_files}" - assert "checkpoint.pt" in step_5_files, f"checkpoint.pt not found in step_5 directory. Files found: {step_5_files}" - assert any("dataloader" in f for f in step_5_files), ( - f"No dataloader file found in step_5 directory. 
Files found: {step_5_files}" - ) +def _assert_loss_valid(loss, label=""): + """Assert that a training loss is finite and not NaN.""" + tag = f" ({label})" if label else "" + assert loss is not None, f"Loss is None{tag}" + loss_val = float(loss) + assert not torch.isnan(torch.tensor(loss_val)), f"Loss is NaN{tag}" + assert torch.isfinite(torch.tensor(loss_val)), f"Loss is not finite: {loss_val}{tag}" - # Verify the actual checkpoint files are valid files - assert os.path.isfile(os.path.join(step_5_dir, "checkpoint.pt")), "step_5/checkpoint.pt is not a valid file" - # Check that only step_5 exists at this point (no step_10 yet) - all_step_dirs = [d for d in os.listdir(ckpt_subdir) if d.startswith("step_")] - assert len(all_step_dirs) == 1, ( - f"Expected only 1 checkpoint directory after phase 1, found {len(all_step_dirs)}: {all_step_dirs}" - ) - assert all_step_dirs[0] == "step_5", f"Expected only step_5 after phase 1, found: {all_step_dirs}" +def _assert_checkpoint_step(ckpt_subdir, step, num_ranks=1, is_ddp=True): + """Assert that a checkpoint step directory has the expected files. - # Phase 2: Resume training (should start from step 5, continue to step 15) - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase2_config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ], + For DDP checks for ``checkpoint.pt`` and exact file counts. + For FSDP2 (DCP format) only checks for per-rank dataloader files. 
+ """ + step_dir = os.path.join(ckpt_subdir, f"step_{step}") + assert os.path.isdir(step_dir), f"Step {step} directory not found: {step_dir}" + files = os.listdir(step_dir) + + if is_ddp: + expected_count = 1 + num_ranks # checkpoint.pt + one dataloader per rank + assert len(files) == expected_count, ( + f"Expected {expected_count} files in step_{step}, found {len(files)}: {files}" ) + assert "checkpoint.pt" in files, f"checkpoint.pt not in step_{step}: {files}" + assert os.path.isfile(os.path.join(step_dir, "checkpoint.pt")) - main_ddp(phase2_config) - gc.collect() - torch.cuda.empty_cache() - - # Phase 2 adds to the directory structure: - # ckpt_subdir/ - # ├── step_5/ - # │ ├── checkpoint.pt - # │ └── dataloader_step_5_rank_0_num_workers_1.pt - # └── step_10/ - # ├── checkpoint.pt - # └── dataloader_step_10_rank_0_num_workers_1.pt - - # Verify the checkpoint files exist in the correct directories - step_5_dir = os.path.join(ckpt_subdir, "step_5") - step_10_dir = os.path.join(ckpt_subdir, "step_10") - - # Check step_5 directory and files - assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" - step_5_files = os.listdir(step_5_dir) - assert "checkpoint.pt" in step_5_files, f"checkpoint.pt not found in step_5 directory. Files found: {step_5_files}" - assert any("dataloader" in f for f in step_5_files), ( - f"No dataloader file found in step_5 directory. Files found: {step_5_files}" - ) - - # Check step_10 directory and files - assert os.path.isdir(step_10_dir), f"Step 10 directory not found: {step_10_dir}" - step_10_files = os.listdir(step_10_dir) - assert "checkpoint.pt" in step_10_files, ( - f"checkpoint.pt not found in step_10 directory. Files found: {step_10_files}" - ) - assert any("dataloader" in f for f in step_10_files), ( - f"No dataloader file found in step_10 directory. 
Files found: {step_10_files}" + dataloader_files = [f for f in files if "dataloader" in f] + assert len(dataloader_files) >= num_ranks, ( + f"Expected >= {num_ranks} dataloader files in step_{step}, found {len(dataloader_files)}: {dataloader_files}" ) + for rank in range(num_ranks): + assert any(f"rank_{rank}" in f for f in dataloader_files), ( + f"No dataloader file for rank {rank} in step_{step}: {dataloader_files}" + ) - # Verify the actual checkpoint files are valid files - assert os.path.isfile(os.path.join(step_5_dir, "checkpoint.pt")), "step_5/checkpoint.pt is not a valid file" - assert os.path.isfile(os.path.join(step_10_dir, "checkpoint.pt")), "step_10/checkpoint.pt is not a valid file" - # Final check: we should have exactly 2 checkpoint directories (step_5 and step_10) - all_step_dirs = [d for d in os.listdir(ckpt_subdir) if d.startswith("step_")] - assert len(all_step_dirs) == 2, f"Expected 2 checkpoint directories, found {len(all_step_dirs)}: {all_step_dirs}" - assert set(all_step_dirs) == {"step_5", "step_10"}, f"Expected step_5 and step_10, found: {all_step_dirs}" +def _run_single_process_checkpoint_test( + recipe_path, + tmp_path, + main_fn, + ckpt_subdir_name, + config_name="L0_sanity", + extra_overrides=None, + is_ddp=True, +): + """Run a two-phase checkpoint save/resume test in a single process. + Phase 1 trains for 10 steps (saving at step 5), phase 2 resumes and + continues to step 15 (saving at step 10). Both phases validate that + checkpoints are created correctly and that losses are finite. -@requires_multi_gpu -def test_checkpoint_save_and_load_two_processes_ddp(recipe_path, tmp_path): - """Test checkpoint save/resume functionality for DDP with two processes. 
- - This test validates: - - Multi-process DDP checkpoint behavior (main process saves only) - - Checkpoint files can be loaded by all DDP processes - - Process synchronization during resume (all processes load same checkpoint) - - DDP training continues correctly after resume across processes - - Process: - 1. Train 10 steps (0-9) across 2 processes, main process saves checkpoint at step 5 - 2. Resume training with 2 processes, all load same checkpoint file, continue to step 15 - 3. Verify checkpoint files exist at steps 5 and 10 + Returns: + Tuple of (phase1_loss, phase2_loss). """ - temp_dir = str(tmp_path / "test_ckpt_ddp_2p") - - # Set environment for subprocess - env = os.environ.copy() - env["WANDB_MODE"] = "disabled" - - # Get the full path to train_ddp.py - train_script = recipe_path / "train_ddp.py" - - # Phase 1: Train for 10 steps with 2 processes - cmd_phase1 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=10", + ckpt_dir = str(tmp_path / "ckpt") + common = [ "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=false", # Start fresh - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing + "checkpoint.async_save=false", + *(extra_overrides or []), ] - result1 = subprocess.run(cmd_phase1, check=False, capture_output=True, text=True, env=env) - assert result1.returncode == 0, f"Phase 1 failed: {result1.stderr}" - - # Phase 1 creates this directory structure with 2 processes: - # ckpt_subdir/ - # └── step_5/ - # ├── checkpoint.pt - # ├── dataloader_step_5_rank_0_num_workers_1.pt - # └── dataloader_step_5_rank_1_num_workers_1.pt - - # Checkpoints are saved in a subdirectory named after the script - ckpt_subdir = os.path.join(temp_dir, "train_ddp") - assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" - - # Verify step_5 checkpoint was created - step_5_dir = os.path.join(ckpt_subdir, "step_5") - - 
# Check step_5 directory exists and contains expected files - assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" - step_5_files = os.listdir(step_5_dir) - - # With 2 processes, we expect: 1 checkpoint.pt + 2 dataloader files (one per rank) - assert len(step_5_files) == 3, ( - f"Expected 3 files in step_5 directory (1 checkpoint + 2 dataloaders), found {len(step_5_files)}: {step_5_files}" - ) - assert "checkpoint.pt" in step_5_files, f"checkpoint.pt not found in step_5 directory. Files found: {step_5_files}" - - # Check for dataloader files for both ranks - dataloader_files = [f for f in step_5_files if "dataloader" in f] - assert len(dataloader_files) == 2, ( - f"Expected 2 dataloader files (rank 0 and 1), found {len(dataloader_files)}: {dataloader_files}" - ) - - # Verify we have dataloader files for both rank 0 and rank 1 - assert any("rank_0" in f for f in dataloader_files), ( - f"No dataloader file for rank 0 found. Files: {dataloader_files}" - ) - assert any("rank_1" in f for f in dataloader_files), ( - f"No dataloader file for rank 1 found. 
Files: {dataloader_files}" - ) - - # Verify the actual checkpoint file is valid - assert os.path.isfile(os.path.join(step_5_dir, "checkpoint.pt")), "step_5/checkpoint.pt is not a valid file" - - # Check that only step_5 exists at this point (no step_10 yet) - all_step_dirs = [d for d in os.listdir(ckpt_subdir) if d.startswith("step_")] - assert len(all_step_dirs) == 1, ( - f"Expected only 1 checkpoint directory after phase 1, found {len(all_step_dirs)}: {all_step_dirs}" + # Phase 1: train 10 steps, checkpoint at step 5 + cfg1 = _compose_config( + recipe_path, + tmp_path, + config_name, + [ + "num_train_steps=10", + "checkpoint.resume_from_checkpoint=false", + *common, + ], ) - assert all_step_dirs[0] == "step_5", f"Expected only step_5 after phase 1, found: {all_step_dirs}" - - # Phase 2: Resume training with 2 processes - cmd_phase2 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ] - result2 = subprocess.run(cmd_phase2, check=False, capture_output=True, text=True, env=env) - assert result2.returncode == 0, f"Phase 2 failed: {result2.stderr}" + loss1 = main_fn(cfg1) + gc.collect() + torch.cuda.empty_cache() - # Phase 2 adds to the directory structure: - # ckpt_subdir/ - # ├── step_5/ - # │ ├── checkpoint.pt - # │ ├── dataloader_step_5_rank_0_num_workers_1.pt - # │ └── dataloader_step_5_rank_1_num_workers_1.pt - # └── step_10/ - # ├── checkpoint.pt - # ├── dataloader_step_10_rank_0_num_workers_1.pt - # └── dataloader_step_10_rank_1_num_workers_1.pt - - # Verify step_10 checkpoint was created - step_10_dir = os.path.join(ckpt_subdir, "step_10") - - # Check step_10 directory exists and contains expected files - assert os.path.isdir(step_10_dir), f"Step 10 directory not found: {step_10_dir}" - 
step_10_files = os.listdir(step_10_dir) - - # With 2 processes, we expect: 1 checkpoint.pt + 2 dataloader files (one per rank) - assert len(step_10_files) == 3, ( - f"Expected 3 files in step_10 directory (1 checkpoint + 2 dataloaders), found {len(step_10_files)}: {step_10_files}" - ) - assert "checkpoint.pt" in step_10_files, ( - f"checkpoint.pt not found in step_10 directory. Files found: {step_10_files}" - ) + ckpt_subdir = os.path.join(ckpt_dir, ckpt_subdir_name) + assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" + _assert_checkpoint_step(ckpt_subdir, 5, num_ranks=1, is_ddp=is_ddp) - # Check for dataloader files for both ranks - dataloader_files_10 = [f for f in step_10_files if "dataloader" in f] - assert len(dataloader_files_10) == 2, ( - f"Expected 2 dataloader files (rank 0 and 1), found {len(dataloader_files_10)}: {dataloader_files_10}" - ) + step_dirs = sorted(d for d in os.listdir(ckpt_subdir) if d.startswith("step_")) + assert step_dirs == ["step_5"], f"Expected only step_5 after phase 1, found: {step_dirs}" - # Verify we have dataloader files for both rank 0 and rank 1 - assert any("rank_0" in f for f in dataloader_files_10), ( - f"No dataloader file for rank 0 found in step_10. Files: {dataloader_files_10}" + # Phase 2: resume and continue to step 15, checkpoint at step 10 + cfg2 = _compose_config( + recipe_path, + tmp_path, + config_name, + [ + "num_train_steps=15", + "checkpoint.resume_from_checkpoint=true", + *common, + ], ) - assert any("rank_1" in f for f in dataloader_files_10), ( - f"No dataloader file for rank 1 found in step_10. 
Files: {dataloader_files_10}" - ) - - # Verify the actual checkpoint file is valid - assert os.path.isfile(os.path.join(step_10_dir, "checkpoint.pt")), "step_10/checkpoint.pt is not a valid file" - - # Final check: we should have exactly 2 checkpoint directories (step_5 and step_10) - all_step_dirs = [d for d in os.listdir(ckpt_subdir) if d.startswith("step_")] - assert len(all_step_dirs) == 2, f"Expected 2 checkpoint directories, found {len(all_step_dirs)}: {all_step_dirs}" - assert set(all_step_dirs) == {"step_5", "step_10"}, f"Expected step_5 and step_10, found: {all_step_dirs}" - -def test_checkpoint_save_and_load_single_process_fsdp2(recipe_path, tmp_path): - """Test checkpoint save/resume functionality for FSDP2 with single process. - - This test validates: - - FSDP2 creates distributed checkpoints (step_X directories by default) - - Each rank saves its shard (even with single process) - - Dataloader state is saved alongside model checkpoint - - Training can resume from latest checkpoint and continue - - Resume starts from correct step count - - Process: - 1. Train 10 steps (0-9), save checkpoint at step 5 - 2. Resume training from step 5, continue to step 15 - 3. 
Verify checkpoints exist at steps 5 and 10 - """ - temp_dir = str(tmp_path / "test_ckpt_fsdp2") - - # Phase 1: Train for 10 steps (using distributed checkpoint by default) - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase1_config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=10", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=false", # Start fresh - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - "checkpoint.async_save=false", - ], - ) - - main_fsdp2(phase1_config) + loss2 = main_fn(cfg2) gc.collect() torch.cuda.empty_cache() - # Checkpoints are saved in a subdirectory named after the script - ckpt_subdir = os.path.join(temp_dir, "train_fsdp2") - assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" - - # Verify checkpoint was created (FSDP2 creates directories by default) - checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - assert len(checkpoint_dirs) > 0, "No checkpoint directories created in phase 1" - - # Check that checkpoint at step 5 exists - expected_checkpoint = "step_5" - assert expected_checkpoint in checkpoint_dirs, f"Expected {expected_checkpoint} not found" + _assert_checkpoint_step(ckpt_subdir, 5, num_ranks=1, is_ddp=is_ddp) + _assert_checkpoint_step(ckpt_subdir, 10, num_ranks=1, is_ddp=is_ddp) - # Check dataloader file exists in step_5 directory - step_5_dir = os.path.join(ckpt_subdir, "step_5") - assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" - step_5_files = os.listdir(step_5_dir) - - # With single process, we expect dataloader file for rank 0 - dataloader_files_5 = [f for f in step_5_files if "dataloader" in f] - assert len(dataloader_files_5) >= 1, ( - f"Expected at least 1 dataloader file, found 
{len(dataloader_files_5)}: {dataloader_files_5}" - ) - assert any("rank_0" in f for f in dataloader_files_5), ( - f"No dataloader file for rank 0 found in step_5. Files: {dataloader_files_5}" - ) + step_dirs = sorted(d for d in os.listdir(ckpt_subdir) if d.startswith("step_")) + assert set(step_dirs) == {"step_5", "step_10"}, f"Expected step_5 and step_10, found: {step_dirs}" - # Phase 2: Resume training - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase2_config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - # Sometimes the checkpoint hasn't finished saving by the time we resume training, so we disable async - # save for this test. - "checkpoint.async_save=false", - ], - ) + # Validate losses are finite and not NaN + _assert_loss_valid(loss1, "phase 1") + _assert_loss_valid(loss2, "phase 2") - main_fsdp2(phase2_config) - gc.collect() - torch.cuda.empty_cache() + return loss1, loss2 - # Verify phase 2 completed and created additional checkpoints - final_checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - expected_checkpoints = ["step_5", "step_10"] - for expected in expected_checkpoints: - assert expected in final_checkpoint_dirs, f"Missing checkpoint: {expected}" - - # Check dataloader file exists in step_10 directory - step_10_dir = os.path.join(ckpt_subdir, "step_10") - assert os.path.isdir(step_10_dir), f"Step 10 directory not found: {step_10_dir}" - step_10_files = os.listdir(step_10_dir) - - # With single process, we expect dataloader file for rank 0 - dataloader_files_10 = [f for f in step_10_files if "dataloader" in f] - assert len(dataloader_files_10) 
>= 1, ( - f"Expected at least 1 dataloader file in step_10, found {len(dataloader_files_10)}: {dataloader_files_10}" - ) - assert any("rank_0" in f for f in dataloader_files_10), ( - f"No dataloader file for rank 0 found in step_10. Files: {dataloader_files_10}" - ) +def _run_multi_process_checkpoint_test( + recipe_path, + tmp_path, + train_script_name, + ckpt_subdir_name, + nproc=2, + extra_overrides=None, + is_ddp=True, +): + """Run a two-phase checkpoint save/resume test using ``torchrun``. -@requires_multi_gpu -def test_checkpoint_save_and_load_two_processes_fsdp2(recipe_path, tmp_path): - """Test checkpoint save/resume functionality for FSDP2 with two processes. - - This test validates: - - Multi-process FSDP2 distributed checkpointing (each rank saves its shard) - - Dataloader state is saved for each rank alongside model checkpoint - - All ranks participate in saving and loading - - Training resumes correctly with proper process synchronization - - Process: - 1. Train 10 steps (0-9) across 2 processes, save checkpoint at step 5 - 2. Resume training with 2 processes from step 5, continue to step 15 - 3. Verify checkpoints exist at steps 5 and 10 with dataloader files for both ranks + Same two-phase strategy as :func:`_run_single_process_checkpoint_test` + but spawns *nproc* processes via ``torchrun``. 
""" - temp_dir = str(tmp_path / "test_ckpt_fsdp2_2p") - - # Set environment for subprocess + ckpt_dir = str(tmp_path / "ckpt") env = os.environ.copy() env["WANDB_MODE"] = "disabled" - # Get the full path to train_fsdp2.py - train_script = recipe_path / "train_fsdp2.py" - - # Phase 1: Train for 10 steps with 2 processes - cmd_phase1 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=10", + train_script = recipe_path / train_script_name + common = [ + f"checkpoint.ckpt_dir={ckpt_dir}", "checkpoint.save_every_n_steps=5", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing + "dataset.use_stateful_dataloader=true", + *(extra_overrides or []), ] - result1 = subprocess.run(cmd_phase1, check=False, capture_output=True, text=True, env=env) + base_cmd = ["torchrun", "--standalone", f"--nproc_per_node={nproc}", str(train_script)] + + # Phase 1 + result1 = subprocess.run( + [*base_cmd, "num_train_steps=10", "checkpoint.resume_from_checkpoint=false", *common], + check=False, + capture_output=True, + text=True, + env=env, + ) assert result1.returncode == 0, f"Phase 1 failed: {result1.stderr}" - # Checkpoints are saved in a subdirectory named after the script - ckpt_subdir = os.path.join(temp_dir, "train_fsdp2") + ckpt_subdir = os.path.join(ckpt_dir, ckpt_subdir_name) assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" + _assert_checkpoint_step(ckpt_subdir, 5, num_ranks=nproc, is_ddp=is_ddp) - # Verify checkpoint was created (FSDP2 creates directories by default) - checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - assert len(checkpoint_dirs) > 0, "No checkpoint directories created in phase 1" - - # Check that checkpoint at step 5 exists - expected_checkpoint = "step_5" - assert expected_checkpoint in checkpoint_dirs, f"Expected {expected_checkpoint} not 
found" + step_dirs = [d for d in os.listdir(ckpt_subdir) if d.startswith("step_")] + assert len(step_dirs) == 1, f"Expected 1 checkpoint dir after phase 1, found: {step_dirs}" - # Check dataloader files exist in step_5 directory for both ranks - step_5_dir = os.path.join(ckpt_subdir, "step_5") - assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" - step_5_files = os.listdir(step_5_dir) - - # With 2 processes, we expect dataloader files for rank 0 and rank 1 - dataloader_files_5 = [f for f in step_5_files if "dataloader" in f] - assert len(dataloader_files_5) == 2, ( - f"Expected 2 dataloader files (rank 0 and 1), found {len(dataloader_files_5)}: {dataloader_files_5}" - ) - assert any("rank_0" in f for f in dataloader_files_5), ( - f"No dataloader file for rank 0 found in step_5. Files: {dataloader_files_5}" + # Phase 2 + result2 = subprocess.run( + [*base_cmd, "num_train_steps=15", "checkpoint.resume_from_checkpoint=true", *common], + check=False, + capture_output=True, + text=True, + env=env, ) - assert any("rank_1" in f for f in dataloader_files_5), ( - f"No dataloader file for rank 1 found in step_5. 
Files: {dataloader_files_5}" - ) - - # Phase 2: Resume training with 2 processes - cmd_phase2 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ] - - result2 = subprocess.run(cmd_phase2, check=False, capture_output=True, text=True, env=env) assert result2.returncode == 0, f"Phase 2 failed: {result2.stderr}" - # Verify phase 2 completed and created additional checkpoints - final_checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - expected_checkpoints = ["step_5", "step_10"] - for expected in expected_checkpoints: - assert expected in final_checkpoint_dirs, f"Missing checkpoint: {expected}" - - # Check dataloader files exist in step_10 directory for both ranks - step_10_dir = os.path.join(ckpt_subdir, "step_10") - assert os.path.isdir(step_10_dir), f"Step 10 directory not found: {step_10_dir}" - step_10_files = os.listdir(step_10_dir) - - # With 2 processes, we expect dataloader files for rank 0 and rank 1 - dataloader_files_10 = [f for f in step_10_files if "dataloader" in f] - assert len(dataloader_files_10) == 2, ( - f"Expected 2 dataloader files (rank 0 and 1) in step_10, found {len(dataloader_files_10)}: {dataloader_files_10}" - ) - assert any("rank_0" in f for f in dataloader_files_10), ( - f"No dataloader file for rank 0 found in step_10. Files: {dataloader_files_10}" - ) - assert any("rank_1" in f for f in dataloader_files_10), ( - f"No dataloader file for rank 1 found in step_10. 
Files: {dataloader_files_10}" - ) - - -def test_checkpoint_save_and_load_single_process_fsdp2_with_context_parallelism(recipe_path, tmp_path): - """Test checkpoint save/resume functionality for FSDP2 with single process and context parallelism. - - This test validates: - - FSDP2 creates distributed checkpoints (step_X directories by default) - - Each rank saves its shard (even with single process) - - Dataloader state is saved alongside model checkpoint - - Training can resume from latest checkpoint and continue - - Resume starts from correct step count - - Process: - 1. Train 10 steps (0-9), save checkpoint at step 5 - 2. Resume training from step 5, continue to step 15 - 3. Verify checkpoints exist at steps 5 and 10 - """ - temp_dir = str(tmp_path / "test_ckpt_fsdp2_cp") - - # Phase 1: Train for 10 steps (using distributed checkpoint by default) - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase1_config = compose( - config_name="L0_sanity_cp", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=10", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=false", # Start fresh - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - "checkpoint.async_save=false", - ], - ) - - main_fsdp2_cp(phase1_config) - gc.collect() - torch.cuda.empty_cache() + _assert_checkpoint_step(ckpt_subdir, 5, num_ranks=nproc, is_ddp=is_ddp) + _assert_checkpoint_step(ckpt_subdir, 10, num_ranks=nproc, is_ddp=is_ddp) - # Checkpoints are saved in a subdirectory named after the script - ckpt_subdir = os.path.join(temp_dir, "train_fsdp2") - assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" + step_dirs = sorted(d for d in os.listdir(ckpt_subdir) if d.startswith("step_")) + assert set(step_dirs) == {"step_5", "step_10"}, f"Expected step_5 and step_10, found: {step_dirs}" - # Verify checkpoint was created (FSDP2 creates 
directories by default) - checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - assert len(checkpoint_dirs) > 0, "No checkpoint directories created in phase 1" - # Check that checkpoint at step 5 exists - expected_checkpoint = "step_5" - assert expected_checkpoint in checkpoint_dirs, f"Expected {expected_checkpoint} not found" +# --------------------------------------------------------------------------- +# DDP Checkpoint Tests +# --------------------------------------------------------------------------- - # Check dataloader file exists in step_5 directory - step_5_dir = os.path.join(ckpt_subdir, "step_5") - assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" - step_5_files = os.listdir(step_5_dir) - # With single process, we expect dataloader file for rank 0 - dataloader_files_5 = [f for f in step_5_files if "dataloader" in f] - assert len(dataloader_files_5) >= 1, ( - f"Expected at least 1 dataloader file, found {len(dataloader_files_5)}: {dataloader_files_5}" - ) - assert any("rank_0" in f for f in dataloader_files_5), ( - f"No dataloader file for rank 0 found in step_5. 
Files: {dataloader_files_5}" +def test_checkpoint_save_and_load_single_process_ddp(recipe_path, tmp_path): + """Test checkpoint save/resume for DDP with a single process.""" + _run_single_process_checkpoint_test( + recipe_path, + tmp_path, + main_ddp, + ckpt_subdir_name="train_ddp", + is_ddp=True, ) - # Phase 2: Resume training - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase2_config = compose( - config_name="L0_sanity_cp", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - # Sometimes the checkpoint hasn't finished saving by the time we resume training, so we disable async - # save for this test. - "checkpoint.async_save=false", - ], - ) - main_fsdp2_cp(phase2_config) - gc.collect() - torch.cuda.empty_cache() - - # Verify phase 2 completed and created additional checkpoints - final_checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - expected_checkpoints = ["step_5", "step_10"] - for expected in expected_checkpoints: - assert expected in final_checkpoint_dirs, f"Missing checkpoint: {expected}" - - # Check dataloader file exists in step_10 directory - step_10_dir = os.path.join(ckpt_subdir, "step_10") - assert os.path.isdir(step_10_dir), f"Step 10 directory not found: {step_10_dir}" - step_10_files = os.listdir(step_10_dir) - - # With single process, we expect dataloader file for rank 0 - dataloader_files_10 = [f for f in step_10_files if "dataloader" in f] - assert len(dataloader_files_10) >= 1, ( - f"Expected at least 1 dataloader file in step_10, found {len(dataloader_files_10)}: {dataloader_files_10}" - ) - assert any("rank_0" in f for f in dataloader_files_10), ( - f"No dataloader file for rank 
0 found in step_10. Files: {dataloader_files_10}" +@requires_multi_gpu +def test_checkpoint_save_and_load_two_processes_ddp(recipe_path, tmp_path): + """Test checkpoint save/resume for DDP with two processes.""" + _run_multi_process_checkpoint_test( + recipe_path, + tmp_path, + "train_ddp.py", + ckpt_subdir_name="train_ddp", + is_ddp=True, ) -@requires_multi_gpu -def test_checkpoint_save_and_load_two_processes_fsdp2_with_context_parallelism(recipe_path, tmp_path): - """Test checkpoint save/resume functionality for FSDP2 with two processes. - - This test validates: - - Multi-process FSDP2 distributed checkpointing (each rank saves its shard) - - Dataloader state is saved for each rank alongside model checkpoint - - All ranks participate in saving and loading - - Training resumes correctly with proper process synchronization - - Process: - 1. Train 10 steps (0-9) across 2 processes with context parallelism, save checkpoint at step 5 - 2. Resume training with 2 processes from step 5, continue to step 15 - 3. 
Verify checkpoints exist at steps 5 and 10 with dataloader files for both ranks - """ - temp_dir = str(tmp_path / "test_ckpt_fsdp2_cp_2p") +# --------------------------------------------------------------------------- +# FSDP2 Checkpoint Tests +# --------------------------------------------------------------------------- - # Set environment for subprocess - env = os.environ.copy() - env["WANDB_MODE"] = "disabled" - # Get the full path to train_fsdp2.py - train_script = recipe_path / "train_fsdp2_cp.py" - - # Phase 1: Train for 10 steps with 2 processes - cmd_phase1 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=10", - "checkpoint.save_every_n_steps=5", - "checkpoint.async_save=false", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - "cp_size=2", - ] +def test_checkpoint_save_and_load_single_process_fsdp2(recipe_path, tmp_path): + """Test checkpoint save/resume for FSDP2 with a single process.""" + _run_single_process_checkpoint_test( + recipe_path, + tmp_path, + main_fsdp2, + ckpt_subdir_name="train_fsdp2", + is_ddp=False, + ) - result1 = subprocess.run(cmd_phase1, check=False, capture_output=True, text=True, env=env) - assert result1.returncode == 0, f"Phase 1 failed: {result1.stderr}" - # Checkpoints are saved in a subdirectory named after the script - ckpt_subdir = os.path.join(temp_dir, "train_fsdp2") - assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" +@requires_multi_gpu +def test_checkpoint_save_and_load_two_processes_fsdp2(recipe_path, tmp_path): + """Test checkpoint save/resume for FSDP2 with two processes.""" + _run_multi_process_checkpoint_test( + recipe_path, + tmp_path, + "train_fsdp2.py", + ckpt_subdir_name="train_fsdp2", + is_ddp=False, + ) - # Verify checkpoint was created (FSDP2 creates directories by default) - checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") 
and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - assert len(checkpoint_dirs) > 0, "No checkpoint directories created in phase 1" - # Check that checkpoint at step 5 exists - expected_checkpoint = "step_5" - assert expected_checkpoint in checkpoint_dirs, f"Expected {expected_checkpoint} not found" +# --------------------------------------------------------------------------- +# FSDP2 + Context Parallelism Checkpoint Tests +# --------------------------------------------------------------------------- - # Check dataloader files exist in step_5 directory for both ranks - step_5_dir = os.path.join(ckpt_subdir, "step_5") - assert os.path.isdir(step_5_dir), f"Step 5 directory not found: {step_5_dir}" - step_5_files = os.listdir(step_5_dir) - # With 2 processes, we expect dataloader files for rank 0 and rank 1 - dataloader_files_5 = [f for f in step_5_files if "dataloader" in f] - assert len(dataloader_files_5) == 2, ( - f"Expected 2 dataloader files (rank 0 and 1), found {len(dataloader_files_5)}: {dataloader_files_5}" - ) - assert any("rank_0" in f for f in dataloader_files_5), ( - f"No dataloader file for rank 0 found in step_5. Files: {dataloader_files_5}" - ) - assert any("rank_1" in f for f in dataloader_files_5), ( - f"No dataloader file for rank 1 found in step_5. 
Files: {dataloader_files_5}" +def test_checkpoint_save_and_load_single_process_fsdp2_with_context_parallelism(recipe_path, tmp_path): + """Test checkpoint save/resume for FSDP2 with context parallelism (single process).""" + _run_single_process_checkpoint_test( + recipe_path, + tmp_path, + main_fsdp2_cp, + ckpt_subdir_name="train_fsdp2", + config_name="L0_sanity_cp", + is_ddp=False, ) - # Phase 2: Resume training with 2 processes - cmd_phase2 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "checkpoint.async_save=false", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - "cp_size=2", - ] - - result2 = subprocess.run(cmd_phase2, check=False, capture_output=True, text=True, env=env) - assert result2.returncode == 0, f"Phase 2 failed: {result2.stderr}" - # Verify phase 2 completed and created additional checkpoints - final_checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - expected_checkpoints = ["step_5", "step_10"] - for expected in expected_checkpoints: - assert expected in final_checkpoint_dirs, f"Missing checkpoint: {expected}" - - # Check dataloader files exist in step_10 directory for both ranks - step_10_dir = os.path.join(ckpt_subdir, "step_10") - assert os.path.isdir(step_10_dir), f"Step 10 directory not found: {step_10_dir}" - step_10_files = os.listdir(step_10_dir) - - # With 2 processes, we expect dataloader files for rank 0 and rank 1 - dataloader_files_10 = [f for f in step_10_files if "dataloader" in f] - assert len(dataloader_files_10) == 2, ( - f"Expected 2 dataloader files (rank 0 and 1) in step_10, found {len(dataloader_files_10)}: {dataloader_files_10}" - ) - assert any("rank_0" in f for f in dataloader_files_10), ( - f"No dataloader file for 
rank 0 found in step_10. Files: {dataloader_files_10}" - ) - assert any("rank_1" in f for f in dataloader_files_10), ( - f"No dataloader file for rank 1 found in step_10. Files: {dataloader_files_10}" +@requires_multi_gpu +def test_checkpoint_save_and_load_two_processes_fsdp2_with_context_parallelism(recipe_path, tmp_path): + """Test checkpoint save/resume for FSDP2 with context parallelism (two processes).""" + _run_multi_process_checkpoint_test( + recipe_path, + tmp_path, + "train_fsdp2_cp.py", + ckpt_subdir_name="train_fsdp2", + extra_overrides=["checkpoint.async_save=false", "cp_size=2"], + is_ddp=False, ) -def test_scheduler_resume_single_gpu(recipe_path, tmp_path): - """Test that learning rate scheduler resumes from correct state after checkpoint load. - - This test validates: - - Scheduler state is saved in checkpoint - - Scheduler resumes with correct step count - - Learning rate continues from where it left off (not reset) - - Warmup and decay continue correctly after resume - - Process: - 1. Train for 10 steps, save checkpoint with scheduler state at step 5 - 2. Resume training, verify scheduler continues from step 6 (not step 0) - 3. 
Check that learning rate progression is continuous across resume - """ - temp_dir = str(tmp_path / "test_scheduler_resume") +# --------------------------------------------------------------------------- +# Scheduler Resume Tests +# --------------------------------------------------------------------------- - # Phase 1: Train for 10 steps with warmup - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase1_config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=10", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=false", # Start fresh - "lr_scheduler_kwargs.num_warmup_steps=20", - "lr_scheduler_kwargs.num_decay_steps=100", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ], - ) - - main_ddp(phase1_config) - gc.collect() - torch.cuda.empty_cache() - # Phase 2: Resume training for 5 more steps - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - phase2_config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "lr_scheduler_kwargs.num_warmup_steps=20", - "lr_scheduler_kwargs.num_decay_steps=100", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ], - ) +def test_scheduler_resume_single_gpu(recipe_path, tmp_path): + """Test that the LR scheduler resumes from the correct state after checkpoint load.""" + _run_single_process_checkpoint_test( + recipe_path, + tmp_path, + main_ddp, + ckpt_subdir_name="train_ddp", + extra_overrides=[ + "lr_scheduler_kwargs.num_warmup_steps=20", + "lr_scheduler_kwargs.num_decay_steps=100", + ], + is_ddp=True, + ) - main_ddp(phase2_config) - gc.collect() - torch.cuda.empty_cache() - # Verify 
checkpoints were created - ckpt_subdir = os.path.join(temp_dir, "train_ddp") - assert os.path.exists(ckpt_subdir), f"Checkpoint subdirectory {ckpt_subdir} not created" +@requires_multi_gpu +def test_scheduler_resume_two_gpu(recipe_path, tmp_path): + """Test that the LR scheduler resumes correctly with multi-GPU FSDP2 training.""" + _run_multi_process_checkpoint_test( + recipe_path, + tmp_path, + "train_fsdp2.py", + ckpt_subdir_name="train_fsdp2", + extra_overrides=[ + "lr_scheduler_kwargs.num_warmup_steps=20", + "lr_scheduler_kwargs.num_decay_steps=100", + ], + is_ddp=False, + ) - # Check that checkpoint directories exist - checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - expected_checkpoint_dirs = ["step_5", "step_10"] - for expected_dir in expected_checkpoint_dirs: - assert expected_dir in checkpoint_dirs, f"Missing checkpoint directory: {expected_dir}" - # Verify each checkpoint directory contains the checkpoint file - checkpoint_file = os.path.join(ckpt_subdir, expected_dir, "checkpoint.pt") - assert os.path.isfile(checkpoint_file), f"Missing checkpoint file: {checkpoint_file}" +# --------------------------------------------------------------------------- +# Final Model Save Tests +# --------------------------------------------------------------------------- def test_final_model_save_ddp(recipe_path, tmp_path): - """Test final model saving for DDP. 
- - Validates that DDP saves the final model correctly with: - - model.safetensors containing weights - - config.json with model configuration - - Can be loaded for inference - - This is important for: - - Exporting trained models - - HuggingFace model hub compatibility - - Inference deployment - """ - temp_dir = str(tmp_path / "test_final_ddp") - - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "checkpoint.save_final_model=true", - "num_train_steps=3", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ], - ) - - main_ddp(config) + """Test that DDP saves a final model with model.safetensors and config.json.""" + cfg = _compose_config( + recipe_path, + tmp_path, + "L0_sanity", + [ + "checkpoint.save_final_model=true", + "num_train_steps=3", + ], + ) + + loss = main_ddp(cfg) gc.collect() torch.cuda.empty_cache() - # Check final model directory - final_model_dir = os.path.join(temp_dir, "train_ddp", "final_model") - assert os.path.exists(final_model_dir), "Final model directory not created" + _assert_loss_valid(loss, "final model ddp") - # Check required files - required_files = ["model.safetensors", "config.json"] - for file in required_files: - file_path = os.path.join(final_model_dir, file) - assert os.path.exists(file_path), f"Missing required file: {file}" - assert os.path.getsize(file_path) > 0, f"File {file} is empty" + final_model_dir = os.path.join(str(tmp_path / "ckpt"), "train_ddp", "final_model") + assert os.path.exists(final_model_dir), "Final model directory not created" + for fname in ("model.safetensors", "config.json"): + fpath = os.path.join(final_model_dir, fname) + assert os.path.exists(fpath), f"Missing: {fname}" + assert os.path.getsize(fpath) > 0, f"{fname} is empty" def test_final_model_save_fsdp2(recipe_path, tmp_path): - """Test final model 
saving for FSDP2. - - Validates that FSDP2 gathers full state dict and saves the final model with: - - model.safetensors containing gathered weights - - config.json with model configuration - - This tests that FSDP2's parameter gathering works correctly: - - All shards are gathered - - Full model state is consolidated - - Model can be loaded for inference - """ - temp_dir = str(tmp_path / "test_final_fsdp2") - - with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"): - config = compose( - config_name="L0_sanity", - overrides=[ - f"checkpoint.ckpt_dir={temp_dir}", - f"+wandb.dir={tmp_path}", - "checkpoint.save_final_model=true", - "num_train_steps=3", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ], - ) - - main_fsdp2(config) + """Test that FSDP2 gathers weights and saves a final model.""" + cfg = _compose_config( + recipe_path, + tmp_path, + "L0_sanity", + [ + "checkpoint.save_final_model=true", + "num_train_steps=3", + ], + ) + + loss = main_fsdp2(cfg) gc.collect() torch.cuda.empty_cache() - # Check final model directory - final_model_dir = os.path.join(temp_dir, "train_fsdp2", "final_model") - assert os.path.exists(final_model_dir), "Final model directory not created" - - # Check required files - required_files = ["model.safetensors", "config.json"] - for file in required_files: - file_path = os.path.join(final_model_dir, file) - assert os.path.exists(file_path), f"Missing required file: {file}" - assert os.path.getsize(file_path) > 0, f"File {file} is empty" - - -@requires_multi_gpu -def test_scheduler_resume_two_gpu(recipe_path, tmp_path): - """Test that learning rate scheduler resumes correctly with multi-GPU training. - - This test validates: - - Scheduler state is synchronized across GPUs during save - - All GPUs resume with same scheduler state - - Learning rate is consistent across all processes after resume - - No divergence in LR between ranks - - Process: - 1. 
Train for 10 steps across 2 GPUs, save checkpoint at step 5 - 2. Resume training on 2 GPUs, verify scheduler continues correctly - 3. Ensure both GPUs have same learning rate progression - """ - temp_dir = str(tmp_path / "test_scheduler_resume_2gpu") - - env = os.environ.copy() - env["WANDB_MODE"] = "disabled" - - # Test with FSDP2 as it's most complex for scheduler state - train_script = recipe_path / "train_fsdp2.py" - - # Phase 1: Train for 10 steps with 2 GPUs - cmd_phase1 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=10", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=false", # Start fresh - "lr_scheduler_kwargs.num_warmup_steps=20", - "lr_scheduler_kwargs.num_decay_steps=100", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ] - - result1 = subprocess.run(cmd_phase1, check=False, capture_output=True, text=True, env=env) - assert result1.returncode == 0, f"Phase 1 failed: {result1.stderr}" + _assert_loss_valid(loss, "final model fsdp2") - # Check that checkpoint was created (FSDP2 uses distributed format by default) - ckpt_subdir = os.path.join(temp_dir, "train_fsdp2") - checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - assert "step_5" in checkpoint_dirs, "Checkpoint at step 5 not found" - - # Phase 2: Resume training with 2 GPUs - cmd_phase2 = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - str(train_script), - f"checkpoint.ckpt_dir={temp_dir}", - "num_train_steps=15", - "checkpoint.save_every_n_steps=5", - "checkpoint.resume_from_checkpoint=true", # Resume from checkpoint - "lr_scheduler_kwargs.num_warmup_steps=20", - "lr_scheduler_kwargs.num_decay_steps=100", - "dataset.use_stateful_dataloader=true", # Enable for checkpoint testing - ] - - result2 = subprocess.run(cmd_phase2, check=False, capture_output=True, 
text=True, env=env) - assert result2.returncode == 0, f"Phase 2 failed: {result2.stderr}" + final_model_dir = os.path.join(str(tmp_path / "ckpt"), "train_fsdp2", "final_model") + assert os.path.exists(final_model_dir), "Final model directory not created" + for fname in ("model.safetensors", "config.json"): + fpath = os.path.join(final_model_dir, fname) + assert os.path.exists(fpath), f"Missing: {fname}" + assert os.path.getsize(fpath) > 0, f"{fname} is empty" - # Verify training continued successfully - # The fact that it completed without errors means scheduler state was properly synchronized - # Check that final checkpoint was created (distributed format) - final_checkpoint_dirs = [ - d for d in os.listdir(ckpt_subdir) if d.startswith("step_") and os.path.isdir(os.path.join(ckpt_subdir, d)) - ] - assert "step_10" in final_checkpoint_dirs, "Checkpoint at step 10 not found" +# --------------------------------------------------------------------------- +# Checkpoint Pruning Tests +# --------------------------------------------------------------------------- def test_checkpoint_pruning(tmp_path): - """Test checkpoint pruning functionality.""" - + """Test checkpoint pruning keeps only the latest N checkpoints.""" from checkpoint import prune_checkpoints temp_dir = str(tmp_path / "test_checkpoint_pruning") @@ -1026,8 +453,7 @@ def test_checkpoint_pruning(tmp_path): def test_checkpoint_pruning_not_enough_checkpoints(tmp_path): - """Test checkpoint pruning functionality.""" - + """Test checkpoint pruning when fewer checkpoints than max exist.""" from checkpoint import prune_checkpoints temp_dir = str(tmp_path / "test_checkpoint_pruning") @@ -1040,8 +466,7 @@ def test_checkpoint_pruning_not_enough_checkpoints(tmp_path): def test_checkpoint_pruning_with_files(tmp_path): - """Test checkpoint pruning functionality.""" - + """Test checkpoint pruning with file-based checkpoints.""" from checkpoint import prune_checkpoints for i in range(11): @@ -1054,3 +479,74 @@ def 
test_checkpoint_pruning_with_files(tmp_path): assert (tmp_path / "step_8.pt").exists() assert (tmp_path / "step_9.pt").exists() assert (tmp_path / "step_10.pt").exists() + + +# --------------------------------------------------------------------------- +# FP8 Checkpoint Tests (with quantized_model_init) +# --------------------------------------------------------------------------- + +_FP8_QUANTIZED_OVERRIDES = [ + "fp8_config.enabled=true", + "fp8_config.quantized_model_init_kwargs.enabled=true", + "+dataset.pad_sequences_to_be_divisible_by=16", +] + + +def test_checkpoint_save_and_load_single_process_ddp_fp8_quantized(recipe_path, tmp_path, fp_recipe): + """Test checkpoint save/resume for DDP with FP8 quantized model init.""" + _run_single_process_checkpoint_test( + recipe_path, + tmp_path, + main_ddp, + ckpt_subdir_name="train_ddp", + config_name="L0_sanity_cp", + extra_overrides=[*_FP8_QUANTIZED_OVERRIDES, *fp_recipe], + is_ddp=True, + ) + + +def test_checkpoint_save_and_load_single_process_fsdp2_fp8_quantized(recipe_path, tmp_path, fp_recipe): + """Test checkpoint save/resume for FSDP2 with FP8 quantized model init.""" + _run_single_process_checkpoint_test( + recipe_path, + tmp_path, + main_fsdp2, + ckpt_subdir_name="train_fsdp2", + config_name="L0_sanity_cp", + extra_overrides=[*_FP8_QUANTIZED_OVERRIDES, *fp_recipe], + is_ddp=False, + ) + + +def test_checkpoint_save_and_load_single_process_fsdp2_cp_fp8_quantized(recipe_path, tmp_path, fp_recipe): + """Test checkpoint save/resume for FSDP2 with context parallelism and FP8 quantized model init.""" + _run_single_process_checkpoint_test( + recipe_path, + tmp_path, + main_fsdp2_cp, + ckpt_subdir_name="train_fsdp2", + config_name="L0_sanity_cp", + extra_overrides=[*_FP8_QUANTIZED_OVERRIDES, *fp_recipe], + is_ddp=False, + ) + + +def test_checkpoint_save_and_load_single_process_fsdp2_cp_fp8_quantized_async(recipe_path, tmp_path, fp_recipe): + """Test checkpoint save/resume for FSDP2+CP with FP8 quantized model init 
and async save. + + This reproduces the corys_config scenario where async_save=true (the default) + is used with FP8 quantized model init. + """ + _run_single_process_checkpoint_test( + recipe_path, + tmp_path, + main_fsdp2_cp, + ckpt_subdir_name="train_fsdp2", + config_name="L0_sanity_cp", + extra_overrides=[ + *_FP8_QUANTIZED_OVERRIDES, + *fp_recipe, + "checkpoint.async_save=true", + ], + is_ddp=False, + ) diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py index 370e174c6f..aebdfe17ef 100644 --- a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py +++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py @@ -71,7 +71,7 @@ def _create_perf_logger(logging_frequency, mock_wandb, mock_tqdm): """Create a PerfLogger with the given logging_frequency.""" dist_config = DistributedConfig() args = _make_args(logging_frequency=logging_frequency) - return PerfLogger(dist_config, args) + return PerfLogger(dist_config, args, start_step=0) def _run_steps(perf_logger, losses, grad_acc_steps=1): diff --git a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py index 7aed3ff6f5..0a25c02940 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py @@ -137,6 +137,7 @@ def main(args: DictConfig) -> float | None: ckpt_path=ckpt_path, dist_config=dist_config, dataloader=train_dataloader, + weights_only=not args.fp8_config.quantized_model_init_kwargs.enabled, ) logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch) else: @@ -144,7 +145,7 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) + perf_logger = PerfLogger(dist_config, args, start_step=start_step) gc.collect() torch.cuda.empty_cache() diff --git 
a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py index 558d27366d..4d88f2e0c0 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py @@ -160,7 +160,7 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) + perf_logger = PerfLogger(dist_config, args, start_step=start_step) gc.collect() torch.cuda.empty_cache() diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py index 9ad3d0e297..06fb6630ba 100644 --- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py @@ -178,7 +178,7 @@ def main(args: DictConfig) -> float | None: start_step = 0 epoch = 0 - perf_logger = PerfLogger(dist_config, args) + perf_logger = PerfLogger(dist_config, args, start_step=start_step) gc.collect() torch.cuda.empty_cache()