|
     MegatronTrainingJob,
 )
 from art.megatron.lora import apply_lora_adapters
-from art.megatron.merge import merge_lora_adapter
+from art.megatron.merge import load_lora_adapter_state_dict, merge_lora_adapter
 from art.megatron.offload import (
     OffloadState,
     clear_optimizer_state,
@@ -66,7 +66,6 @@
 safetensors = importlib.import_module("safetensors")
 safetensors_torch = importlib.import_module("safetensors.torch")
 safe_open = safetensors.safe_open
-load_file = safetensors_torch.load_file
 save_file = safetensors_torch.save_file
 
 DEFAULT_MODEL_IDENTIFIER = "Qwen/Qwen3-30B-A3B-Instruct-2507"
@@ -496,6 +495,7 @@ def run_megatron_sft_job(
     grad_accumulation_sequences = resolve_global_grad_accumulation_sequences(
         job.grad_accumulation_sequences
     )
+    checkpoint_interval = job.internal_checkpoint_interval
 
     for batch_idx in range(job.num_batches):
         batch_start_time = time.perf_counter()
@@ -550,6 +550,20 @@ def run_megatron_sft_job(
         )
         batch_time = time.perf_counter() - batch_start_time
         tokens_per_second = global_tokens / batch_time if batch_time > 0 else 0.0
+        completed_batches = batch_idx + 1
+
+        if (
+            checkpoint_interval is not None
+            and completed_batches < job.num_batches
+            and completed_batches % checkpoint_interval == 0
+        ):
+            _save_lora_and_optimizer(
+                runtime,
+                adapter_model=adapter_model,
+                lora_path=job.lora_path,
+                optimizer_state_path=job.optimizer_state_path,
+            )
+            torch.distributed.barrier()  # type: ignore[possibly-missing-attribute]
 
         if runtime.rank == 0:
             with open(job.log_path, "a+", encoding="utf-8") as log_file:
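
Taken on its own, the guard added above saves an intermediate LoRA/optimizer checkpoint only after interior batches whose completed count is a multiple of `internal_checkpoint_interval`; the final batch is excluded, presumably because the end-of-job save is handled after the loop. A small standalone sketch of that cadence, using hypothetical values rather than anything from the job config:

num_batches = 10
checkpoint_interval = 5

for batch_idx in range(num_batches):
    completed_batches = batch_idx + 1
    if (
        checkpoint_interval is not None
        and completed_batches < num_batches
        and completed_batches % checkpoint_interval == 0
    ):
        # Mirrors the condition in the hunk above: fires after batch 5,
        # but not after the final batch 10, even though 10 % 5 == 0.
        print(f"intermediate checkpoint after batch {completed_batches}")
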
@@ -609,11 +623,8 @@ def _load_lora_and_optimizer(
     lora_path: str,
     optimizer_state_path: str,
 ) -> dict[str, torch.Tensor]:
-    adapter_model_path = os.path.join(lora_path, "adapter_model.safetensors")
-    if not os.path.exists(adapter_model_path):
-        raise FileNotFoundError(f"No adapter model found at {adapter_model_path}")
-    print0(runtime.rank, "Loading adapter model from", adapter_model_path)
-    adapter_model = load_file(adapter_model_path)
+    print0(runtime.rank, "Loading adapter model from", lora_path)
+    adapter_model = load_lora_adapter_state_dict(lora_path)
     load_adapter_into_model(runtime.model, adapter_model, runtime.optimizer)
 
     optimizer_shard_path = os.path.join(
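
The hunk above replaces the inline safetensors read with the `load_lora_adapter_state_dict` helper now imported from `art.megatron.merge`. Its implementation is not part of this diff; the sketch below is only a guess at what it plausibly wraps, mirroring the removed inline code (the real helper may differ, for example by validating keys or handling sharded adapters):

import os

import torch
from safetensors.torch import load_file


def load_lora_adapter_state_dict(lora_path: str) -> dict[str, torch.Tensor]:
    # Hypothetical sketch: load a PEFT-style adapter_model.safetensors from the
    # adapter directory, just like the inline code this diff removes.
    adapter_model_path = os.path.join(lora_path, "adapter_model.safetensors")
    if not os.path.exists(adapter_model_path):
        raise FileNotFoundError(f"No adapter model found at {adapter_model_path}")
    return load_file(adapter_model_path)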