Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .devcontainer/recipes/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ FROM nvcr.io/nvidia/pytorch:26.01-py3
# Remove once bug has been addressed in the nvidia/pytorch container.
RUN rm -f /usr/local/lib/python*/dist-packages/transformer_engine-*.dist-info/direct_url.json

RUN --mount=type=cache,target=/var/cache/apt \
--mount=type=cache,target=/var/lib/apt \
apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install -y tmux && \
rm -rf /var/lib/apt/lists/*

RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=requirements.txt,target=/workspace/requirements.txt \
PIP_CONSTRAINT= pip install -r /workspace/requirements.txt
Expand Down
20 changes: 20 additions & 0 deletions bionemo-recipes/models/esm2/tests/common/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,26 @@ def fp8_recipe(request):
return request.param


RECIPE_NAME_TO_FACTORY = {
"DelayedScaling": recipe_module.DelayedScaling,
"Float8CurrentScaling": recipe_module.Float8CurrentScaling,
"Float8BlockScaling": recipe_module.Float8BlockScaling,
"MXFP8BlockScaling": recipe_module.MXFP8BlockScaling,
"NVFP4BlockScaling": lambda: recipe_module.NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True),
}


def recipe_to_name(recipe):
"""Convert a recipe instance to its CLI-passable string name."""
return type(recipe).__name__


def recipe_from_name(name):
"""Reconstruct a recipe instance from its CLI-passable string name."""
factory = RECIPE_NAME_TO_FACTORY[name]
return factory()


@pytest.fixture(params=["bshd", "thd"])
def input_format(request):
"""Fixture to parametrize the input format."""
Expand Down
216 changes: 216 additions & 0 deletions bionemo-recipes/models/esm2/tests/common/run_distributed_dcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Worker script for distributed DCP (Distributed Checkpoint) tests.

Launched by torchrun from BaseModelTest.test_dcp_output_parity / test_dcp_output_parity_fp8_init.
Verifies that a model sharded with FSDP2 produces identical outputs after a DCP save/load round-trip.
"""

import argparse
import importlib.util
import os
import shutil
import sys
import tempfile
from pathlib import Path

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import transformer_engine.pytorch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from transformers import set_seed


def _setup_sys_path():
"""Add model root and tests directory to sys.path so model/test imports work."""
script_dir = Path(__file__).resolve().parent # tests/common/
tests_dir = script_dir.parent # tests/
model_root = tests_dir.parent # model root (e.g., models/esm2/)
for p in [str(model_root), str(tests_dir)]:
if p not in sys.path:
sys.path.insert(0, p)


def _load_tester_class(tester_file, class_name):
"""Dynamically load a tester class from a file path."""
# Ensure the tester file's directory tree is importable
tester_dir = str(Path(tester_file).parent)
tester_parent = str(Path(tester_file).parent.parent)
for p in [tester_parent, tester_dir]:
if p not in sys.path:
sys.path.insert(0, p)

spec = importlib.util.spec_from_file_location("_dcp_tester_module", tester_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return getattr(module, class_name)


def _build_and_shard_model(tester, config, recipe, device, device_mesh):
"""Build a model (optionally with FP8 quantized_model_init), shard with FSDP2, and move to device."""
model_class = tester.get_model_class()

if recipe is not None:
with transformer_engine.pytorch.quantized_model_init(recipe=recipe):
model = model_class(config)
else:
model = model_class(config)

# Shard each transformer layer, then the root model
for layer in tester.get_layer_path(model):
fully_shard(layer, mesh=device_mesh)
fully_shard(model, mesh=device_mesh)

model.to(device)
return model


def _forward(model, input_data, recipe):
"""Run a forward pass and return the model outputs."""
if recipe is not None:
# torch.autocast is needed when model was built with quantized_model_init
# (weights are FP8, non-quantized ops need bf16 casting)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
with transformer_engine.pytorch.autocast(recipe=recipe):
return model(**input_data)
else:
return model(**input_data)


def _train_one_step(model, input_data, recipe, lr=1e-4):
"""Run a single training step (forward + backward + optimizer step) and return detached logits."""
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
optimizer.zero_grad()

outputs = _forward(model, input_data, recipe)
loss = outputs.logits.sum()
loss.backward()
optimizer.step()

return outputs.logits.detach().clone()


def _run_eval_forward(model, input_data, recipe):
"""Run an eval forward pass and return detached logits."""
model.eval()
with torch.no_grad():
outputs = _forward(model, input_data, recipe)
return outputs.logits.detach().clone()


def run_dcp_output_parity(tester, fp8_recipe_name=None, seed=42):
"""Core DCP round-trip test: build → train → save → rebuild → load → eval → compare."""
from tests.common.fixtures import recipe_from_name

rank = dist.get_rank()
local_rank = int(os.environ["LOCAL_RANK"])
world_size = dist.get_world_size()
device = f"cuda:{local_rank}"
torch.cuda.set_device(local_rank)

device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,))

# Resolve FP8 recipe
recipe = recipe_from_name(fp8_recipe_name) if fp8_recipe_name else None

# Build config
set_seed(seed)
config = tester.create_test_config(dtype=torch.bfloat16, attn_input_format="bshd")

# Prepare input data
input_data = tester.get_test_input_data("bshd", pad_to_multiple_of=32)

# --- Model A: build, shard, train one step, then eval ---
set_seed(seed)
model_a = _build_and_shard_model(tester, config, recipe, device, device_mesh)
_train_one_step(model_a, input_data, recipe)
logits_a = _run_eval_forward(model_a, input_data, recipe)

# --- DCP Save ---
# Rank 0 creates temp dir, broadcast path to all ranks
if rank == 0:
tmp_dir = tempfile.mkdtemp(prefix="dcp_test_")
else:
tmp_dir = None
tmp_dir_list = [tmp_dir]
dist.broadcast_object_list(tmp_dir_list, src=0)
tmp_dir = tmp_dir_list[0]

checkpoint_path = os.path.join(tmp_dir, "checkpoint")

state_dict_a = {"model": model_a.state_dict()}
dcp.save(state_dict_a, checkpoint_id=checkpoint_path)

dist.barrier()

# Free model_a
del model_a, state_dict_a
torch.cuda.empty_cache()

# --- Model B: build fresh, shard, load, eval ---
set_seed(seed)
model_b = _build_and_shard_model(tester, config, recipe, device, device_mesh)

state_dict_b = {"model": model_b.state_dict()}
dcp.load(state_dict_b, checkpoint_id=checkpoint_path)
model_b.load_state_dict(state_dict_b["model"], strict=False)

logits_b = _run_eval_forward(model_b, input_data, recipe)

# --- Compare ---
tolerances = tester.get_tolerances()
torch.testing.assert_close(
logits_a,
logits_b,
atol=tolerances.dcp_logits_atol,
rtol=tolerances.dcp_logits_rtol,
msg=lambda x: f"DCP round-trip logits mismatch: {x}",
)

# Cleanup
del model_b, state_dict_b
torch.cuda.empty_cache()
dist.barrier()

if rank == 0:
shutil.rmtree(tmp_dir, ignore_errors=True)

print(f"[Rank {rank}] DCP output parity test PASSED (fp8_recipe={fp8_recipe_name})")


if __name__ == "__main__":
_setup_sys_path()

parser = argparse.ArgumentParser(description="DCP distributed test worker")
parser.add_argument(
"--tester-file", required=True, help="Absolute path to the test file containing the tester class"
)
parser.add_argument("--tester-class", required=True, help="Name of the tester class (e.g., TestESM2Model)")
parser.add_argument("--fp8-recipe", default=None, help="FP8 recipe name (e.g., DelayedScaling)")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
args = parser.parse_args()

dist.init_process_group(backend="nccl")

try:
tester_cls = _load_tester_class(args.tester_file, args.tester_class)
tester = tester_cls()
run_dcp_output_parity(tester, fp8_recipe_name=args.fp8_recipe, seed=args.seed)
finally:
dist.destroy_process_group()
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import fnmatch
import gc
import os
import subprocess
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
Expand All @@ -39,6 +41,12 @@
HAS_DATA_CENTER_GPU = False


_requires_multi_gpu = pytest.mark.skipif(
not torch.cuda.is_available() or torch.cuda.device_count() < 2,
reason="Test requires at least 2 GPUs",
)


@dataclass
class TestTolerances:
"""Model-specific test tolerances for numerical comparisons."""
Expand All @@ -65,6 +73,10 @@ class TestTolerances:
fp8_logits_atol: float = 5.0
fp8_logits_rtol: float = 0.1

# DCP (distributed checkpoint) round-trip tolerances
dcp_logits_atol: float = 0.0
dcp_logits_rtol: float = 0.0

# Meta device initialization tolerances
init_mean_atol: float = 1e-3
init_mean_rtol: float = 1e-4
Expand Down Expand Up @@ -979,4 +991,69 @@ def test_meta_fp8_init(self, fp8_recipe):
model.init_empty_weights()
self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True)

# TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc.
# ==================== Distributed Checkpoint (DCP) Tests ====================

def _get_dcp_worker_script_path(self) -> str:
"""Return the absolute path to the run_distributed_dcp.py worker script."""
return str(Path(__file__).resolve().parent / "run_distributed_dcp.py")

def _get_tester_file_and_class(self):
"""Return (file_path, class_name) for dynamic loading in the worker subprocess."""
import inspect

return os.path.abspath(inspect.getfile(type(self))), type(self).__name__

def _run_dcp_worker(self, unused_tcp_port, fp8_recipe_name=None, nproc_per_node=2):
"""Launch the DCP worker script via torchrun and assert it succeeds."""
tester_file, class_name = self._get_tester_file_and_class()
worker_script = self._get_dcp_worker_script_path()

cmd = [
"torchrun",
f"--nproc_per_node={nproc_per_node}",
"--rdzv-backend=c10d",
f"--rdzv-endpoint=localhost:{unused_tcp_port}",
worker_script,
"--tester-file",
tester_file,
"--tester-class",
class_name,
]

if fp8_recipe_name is not None:
cmd.extend(["--fp8-recipe", fp8_recipe_name])

result = subprocess.run(
cmd,
check=False,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
timeout=300,
)
if result.returncode != 0:
print(f"STDOUT:\n{result.stdout}")
print(f"STDERR:\n{result.stderr}")
pytest.fail(f"DCP worker failed with exit code {result.returncode}")

def test_dcp_output_parity_single_gpu(self, unused_tcp_port):
"""Test FSDP2 + DCP save/load round-trip on a single GPU."""
self._run_dcp_worker(unused_tcp_port, nproc_per_node=1)

def test_dcp_output_parity_fp8_init_single_gpu(self, fp8_recipe, unused_tcp_port):
"""Test FSDP2 + DCP save/load with FP8 quantized_model_init on a single GPU."""
from .fixtures import recipe_to_name

self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_to_name(fp8_recipe), nproc_per_node=1)

@_requires_multi_gpu
def test_dcp_output_parity(self, unused_tcp_port):
"""Test that a model sharded with FSDP2 produces identical outputs after DCP save/load."""
self._run_dcp_worker(unused_tcp_port)

@_requires_multi_gpu
def test_dcp_output_parity_fp8_init(self, fp8_recipe, unused_tcp_port):
"""Test DCP save/load with FP8 quantized_model_init."""
from .fixtures import recipe_to_name

self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_to_name(fp8_recipe))
8 changes: 2 additions & 6 deletions bionemo-recipes/models/esm2/tests/test_distributed_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ def requires_fp8(func):
)


@pytest.mark.parametrize(
"strategy", ["ddp", "fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2999"))]
)
@pytest.mark.parametrize("strategy", ["ddp", "fsdp2", "mfsdp"])
@requires_fp8
def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port):
cmd = [
Expand All @@ -63,9 +61,7 @@ def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port):
pytest.fail(f"Command failed with exit code {result.returncode}")


@pytest.mark.parametrize(
"strategy", ["ddp", "fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2999"))]
)
@pytest.mark.parametrize("strategy", ["ddp", "fsdp2", "mfsdp"])
@requires_fp8
@requires_multi_gpu
def test_multi_process_fp8_recipes_are_synced(strategy, unused_tcp_port):
Expand Down
Loading