
Commit 953577a

Add Megatron SFT internal checkpoints (#645)

* Add Megatron SFT internal checkpoints
* Simplify Megatron SFT checkpoint flow
* lint fix

1 parent 8ad8e50 · commit 953577a

3 files changed: 52 additions & 11 deletions


src/art/megatron/jobs.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -1,6 +1,6 @@
 from typing import Any, Literal
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from .. import types
 from ..preprocessing.pack import DiskPackedTensors
@@ -31,6 +31,7 @@ class MegatronSFTTrainingJob(BaseModel):
     grad_accumulation_sequences: int | None = None
     weight_decay: float = 0.0
     max_grad_norm: float = 1.0
+    internal_checkpoint_interval: int | None = Field(default=None, ge=1)
     log_path: str = DEFAULT_TRAINING_LOG_PATH
 
 
```
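The new field is validated by pydantic. The sketch below uses a stand-in model (not the real `MegatronSFTTrainingJob`, whose other fields are omitted) to illustrate the intended semantics: `None` keeps the previous behaviour of no mid-job checkpoints, any integer >= 1 enables them, and the `ge=1` constraint rejects zero or negative intervals.

```python
# Minimal sketch with a hypothetical stand-in model, mirroring only the new field.
from pydantic import BaseModel, Field, ValidationError


class _CheckpointConfigSketch(BaseModel):
    internal_checkpoint_interval: int | None = Field(default=None, ge=1)


print(_CheckpointConfigSketch().internal_checkpoint_interval)    # None: no internal checkpoints
print(_CheckpointConfigSketch(internal_checkpoint_interval=10))  # checkpoint every 10 batches

try:
    _CheckpointConfigSketch(internal_checkpoint_interval=0)
except ValidationError:
    print("rejected: interval must be >= 1")  # ge=1 rejects 0 and negatives
```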

src/art/megatron/merge.py

Lines changed: 32 additions & 3 deletions

```diff
@@ -11,11 +11,16 @@
 save_file = safetensors_torch.save_file
 
 
-def merge_lora_adapter(lora_path: str) -> None:
-    base_dir = Path(lora_path)
+def _load_adapter_shards(
+    base_dir: Path,
+) -> tuple[
+    dict[str, torch.Tensor],
+    list[Path],
+    list[Path],
+]:
     shard_filenames = sorted(base_dir.glob("adapter_model-*-of-*.safetensors"))
     if not shard_filenames:
-        return
+        raise FileNotFoundError(f"No adapter shards found in {base_dir}")
 
     shard_files_by_suffix = {
         path.name.removeprefix("adapter_model-").removesuffix(".safetensors"): path
@@ -93,6 +98,30 @@ def merge_lora_adapter(lora_path: str) -> None:
         concat_dim = 1 if "lora_A" in key else 0
         tensor = torch.cat(ordered_shards, dim=concat_dim)
         adapter_model[key] = tensor
+    return adapter_model, shard_filenames, manifest_filenames
+
+
+def load_lora_adapter_state_dict(lora_path: str) -> dict[str, torch.Tensor]:
+    base_dir = Path(lora_path)
+    adapter_model_path = base_dir / "adapter_model.safetensors"
+    if adapter_model_path.exists():
+        with safe_open(adapter_model_path, framework="pt") as file:
+            return {key: file.get_tensor(key) for key in file.keys()}
+
+    adapter_model, _shard_filenames, _manifest_filenames = _load_adapter_shards(
+        base_dir
+    )
+    return adapter_model
+
+
+def merge_lora_adapter(lora_path: str) -> None:
+    base_dir = Path(lora_path)
+    try:
+        adapter_model, shard_filenames, manifest_filenames = _load_adapter_shards(
+            base_dir
+        )
+    except FileNotFoundError:
+        return
 
     adapter_model_path = base_dir / "adapter_model.safetensors"
     save_file(adapter_model, adapter_model_path)
```
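A hedged usage sketch of the new loader follows. The import path and function names match those referenced in train.py below; the checkpoint directory path is hypothetical, and it is assumed to contain either a merged `adapter_model.safetensors` or `adapter_model-*-of-*.safetensors` shards, per the logic above.

```python
# Usage sketch for the adapter loading/merging helpers added in this commit.
from art.megatron.merge import load_lora_adapter_state_dict, merge_lora_adapter

lora_path = "/tmp/checkpoints/lora"  # hypothetical checkpoint directory

# Prefers the merged adapter_model.safetensors when present; otherwise
# reassembles the shards in memory via _load_adapter_shards. Raises
# FileNotFoundError if neither a merged file nor shards exist.
adapter_state = load_lora_adapter_state_dict(lora_path)
for name, tensor in list(adapter_state.items())[:3]:
    print(name, tuple(tensor.shape), tensor.dtype)

# Optionally persist the merged adapter back to adapter_model.safetensors;
# this is a silent no-op when no shards are found.
merge_lora_adapter(lora_path)
```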

src/art/megatron/train.py

Lines changed: 18 additions & 7 deletions

```diff
@@ -44,7 +44,7 @@
     MegatronTrainingJob,
 )
 from art.megatron.lora import apply_lora_adapters
-from art.megatron.merge import merge_lora_adapter
+from art.megatron.merge import load_lora_adapter_state_dict, merge_lora_adapter
 from art.megatron.offload import (
     OffloadState,
     clear_optimizer_state,
@@ -66,7 +66,6 @@
 safetensors = importlib.import_module("safetensors")
 safetensors_torch = importlib.import_module("safetensors.torch")
 safe_open = safetensors.safe_open
-load_file = safetensors_torch.load_file
 save_file = safetensors_torch.save_file
 
 DEFAULT_MODEL_IDENTIFIER = "Qwen/Qwen3-30B-A3B-Instruct-2507"
@@ -496,6 +495,7 @@ def run_megatron_sft_job(
     grad_accumulation_sequences = resolve_global_grad_accumulation_sequences(
         job.grad_accumulation_sequences
     )
+    checkpoint_interval = job.internal_checkpoint_interval
 
     for batch_idx in range(job.num_batches):
         batch_start_time = time.perf_counter()
@@ -550,6 +550,20 @@ def run_megatron_sft_job(
         )
         batch_time = time.perf_counter() - batch_start_time
         tokens_per_second = global_tokens / batch_time if batch_time > 0 else 0.0
+        completed_batches = batch_idx + 1
+
+        if (
+            checkpoint_interval is not None
+            and completed_batches < job.num_batches
+            and completed_batches % checkpoint_interval == 0
+        ):
+            _save_lora_and_optimizer(
+                runtime,
+                adapter_model=adapter_model,
+                lora_path=job.lora_path,
+                optimizer_state_path=job.optimizer_state_path,
+            )
+            torch.distributed.barrier()  # type: ignore[possibly-missing-attribute]
 
         if runtime.rank == 0:
             with open(job.log_path, "a+", encoding="utf-8") as log_file:
@@ -609,11 +623,8 @@ def _load_lora_and_optimizer(
     lora_path: str,
     optimizer_state_path: str,
 ) -> dict[str, torch.Tensor]:
-    adapter_model_path = os.path.join(lora_path, "adapter_model.safetensors")
-    if not os.path.exists(adapter_model_path):
-        raise FileNotFoundError(f"No adapter model found at {adapter_model_path}")
-    print0(runtime.rank, "Loading adapter model from", adapter_model_path)
-    adapter_model = load_file(adapter_model_path)
+    print0(runtime.rank, "Loading adapter model from", lora_path)
+    adapter_model = load_lora_adapter_state_dict(lora_path)
     load_adapter_into_model(runtime.model, adapter_model, runtime.optimizer)
 
     optimizer_shard_path = os.path.join(
```
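The checkpoint trigger added to `run_megatron_sft_job` can be read in isolation; the sketch below restates it as a small predicate (an illustration, not code from the commit). A checkpoint is written after every `checkpoint_interval` completed batches, and the final batch is excluded, presumably because the existing end-of-job save path already persists the adapter and optimizer state.

```python
# Restatement of the internal-checkpoint condition from run_megatron_sft_job,
# factored into a standalone predicate for illustration.
def should_checkpoint(
    completed_batches: int,
    num_batches: int,
    checkpoint_interval: int | None,
) -> bool:
    return (
        checkpoint_interval is not None
        and completed_batches < num_batches
        and completed_batches % checkpoint_interval == 0
    )


# With num_batches=10 and interval=4, checkpoints land after batches 4 and 8,
# never after the final batch, and never when the interval is None.
print([b for b in range(1, 11) if should_checkpoint(b, 10, 4)])     # [4, 8]
print([b for b in range(1, 11) if should_checkpoint(b, 10, None)])  # []
```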
