src/art/megatron/jobs.py (+31, -1)
@@ -21,6 +21,31 @@ class MegatronTrainingJob(BaseModel):
     log_path: str = DEFAULT_TRAINING_LOG_PATH


+class MergedWeightTransferInitInfo(BaseModel):
+    master_address: str
+    master_port: int
+    rank_offset: int
+    world_size: int
+
+
+class MergedWeightTransferSpec(BaseModel):
+    init_info: MergedWeightTransferInitInfo
+    vllm_base_url: str
+    served_model_name: str
+
+
+class MegatronMergedTrainJob(MegatronTrainingJob):
+    job_type: Literal["merged"] = "merged"
+    merged_weight_transfer: MergedWeightTransferSpec
+
+
+class MegatronSyncJob(BaseModel):
+    job_type: Literal["sync"] = "sync"
+    lora_path: str
+    merged_weight_transfer: MergedWeightTransferSpec
+    log_path: str = DEFAULT_TRAINING_LOG_PATH
+
+
 class MegatronSFTTrainingJob(BaseModel):
     job_type: Literal["sft"] = "sft"
     lora_path: str
@@ -35,4 +60,9 @@ class MegatronSFTTrainingJob(BaseModel):
     log_path: str = DEFAULT_TRAINING_LOG_PATH


-MegatronJob = MegatronTrainingJob | MegatronSFTTrainingJob
+MegatronJob = (
+    MegatronTrainingJob
+    | MegatronMergedTrainJob
+    | MegatronSyncJob
+    | MegatronSFTTrainingJob
+)
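Every new job model carries a `job_type` Literal tag, so a job file can be re-validated into the correct member of the widened `MegatronJob` union. A minimal round-trip sketch, assuming pydantic v2, the `art.megatron.jobs` module path, and that each union member carries a distinguishing `job_type` tag; all concrete values are illustrative:

```python
# Hypothetical values throughout; only the class and field names come from the diff.
from pydantic import TypeAdapter

from art.megatron.jobs import (
    MegatronJob,
    MegatronSyncJob,
    MergedWeightTransferInitInfo,
    MergedWeightTransferSpec,
)

job = MegatronSyncJob(
    lora_path="/tmp/checkpoints/0003",
    merged_weight_transfer=MergedWeightTransferSpec(
        init_info=MergedWeightTransferInitInfo(
            master_address="127.0.0.1",
            master_port=29500,
            rank_offset=1,  # trainer takes rank 0; inference workers start at 1
            world_size=3,   # 1 trainer + 2 vLLM workers
        ),
        vllm_base_url="http://127.0.0.1:8000",
        served_model_name="my-model@3",
    ),
)

# Serialize, then re-validate against the union: the job_type tag ("sync")
# steers validation back to MegatronSyncJob.
adapter = TypeAdapter(MegatronJob)
restored = adapter.validate_json(job.model_dump_json())
assert isinstance(restored, MegatronSyncJob)
```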
src/art/megatron/provider.py (+1, -0)
@@ -328,6 +328,7 @@ def get_provider(
         )
     )
     provider = bridge.to_megatron_provider()
+    setattr(provider, "art_bridge", bridge)
     base_layer_spec = provider.transformer_layer_spec

     def _flex_attention_layer_spec(
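The one-line change stashes the bridge on the provider object so later code can reach it without rebuilding it. A hedged sketch of a downstream use site (the attribute name comes from the diff; the use site itself is an assumption):

```python
# Recover the bridge stashed by get_provider; use site assumed, not from the diff.
bridge = getattr(provider, "art_bridge", None)
if bridge is not None:
    # e.g. map Megatron-format weights back to the HF layout for export
    ...
```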
src/art/megatron/service.py (+139, -22)
@@ -8,6 +8,7 @@
 from pathlib import Path
 import shlex
 import shutil
+import signal
 import socket
 import subprocess
 import sys
@@ -28,12 +29,17 @@
 from ..unsloth.service import do_sleep, do_wake_up, gc_and_empty_cuda_cache
 from ..utils.convert_moe_lora import convert_checkpoint_if_needed
 from ..utils.get_model_step import get_step_from_dir
+from ..utils.network import find_free_tcp_port
 from ..utils.output_dirs import get_step_checkpoint_dir
 from ..vllm import get_llm, openai_server_task, run_on_workers
 from .client import create_megatron_job_paths, stream_megatron_job, write_megatron_job
 from .jobs import (
+    MegatronMergedTrainJob,
     MegatronSFTTrainingJob,
+    MegatronSyncJob,
     MegatronTrainingJob,
+    MergedWeightTransferInitInfo,
+    MergedWeightTransferSpec,
 )
 from .lora import LORA_ALPHA, LORA_RANK
 from .sft_batches import materialize_sft_batches
@@ -148,6 +154,10 @@ class MegatronService:
     _vllm_log_file: Any = field(default=None, repr=False)
     _vllm_host: str = "127.0.0.1"
     _vllm_port: int = 0
+    _merged_weight_transfer_init_info: MergedWeightTransferInitInfo | None = field(
+        default=None,
+        repr=False,
+    )

     @property
     def is_dedicated(self) -> bool:
@@ -247,17 +257,59 @@ def _ensure_lora_adapter_config(
             return
         self._default_lora_adapter_config().save_pretrained(lora_path)

+    def _build_merged_weight_transfer_spec(self, step: int) -> MergedWeightTransferSpec:
+        init_info = self._merged_weight_transfer_init_info
+        assert init_info is not None
+        return MergedWeightTransferSpec(
+            init_info=init_info,
+            vllm_base_url=self._vllm_base_url,
+            served_model_name=f"{self.model_name}@{step}",
+        )
+
     def _resolve_active_lora_path(self) -> str:
         lora_path = get_last_checkpoint_dir(self.output_dir)
         if lora_path is None:
             lora_path = get_step_checkpoint_dir(self.output_dir, 0)
             self._latest_step = 0
         else:
             self._latest_step = get_step_from_dir(self.output_dir)
-        self._ensure_identity_lora(lora_path)
+        if self.is_dedicated or self.rollout_weights_mode == "lora":
+            self._ensure_identity_lora(lora_path)
         self._ensure_lora_adapter_config(lora_path)
         return lora_path

+    async def _set_served_model_name(self, step: int) -> None:
+        import httpx
+
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                f"{self._vllm_base_url}/art/set_served_model_name",
+                json={"name": f"{self.model_name}@{step}"},
+                timeout=30.0,
+            )
+            response.raise_for_status()
+        self._latest_step = step
+
+    async def _init_merged_weight_transfer(self) -> None:
+        import httpx
+
+        if self._merged_weight_transfer_init_info is not None:
+            return
+        assert len(self.config["trainer_gpu_ids"]) == 1
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                f"{self._vllm_base_url}/get_world_size",
+                timeout=30.0,
+            )
+            response.raise_for_status()
+            inference_world_size = int(response.json()["world_size"])
+        self._merged_weight_transfer_init_info = MergedWeightTransferInitInfo(
+            master_address="127.0.0.1",
+            master_port=find_free_tcp_port(),
+            rank_offset=1,
+            world_size=inference_world_size + 1,
+        )
+
     async def _start_vllm_subprocess(
         self,
         lora_path: str,
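`_init_merged_weight_transfer` picks the rendezvous port with `find_free_tcp_port`, imported from `..utils.network`. Its implementation is not shown in this diff; the usual bind-to-port-zero trick looks like this (a sketch, not the repository's actual code):

```python
import socket

def find_free_tcp_port() -> int:
    """Ask the kernel for a free ephemeral port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))  # port 0 -> OS assigns an unused port
        return s.getsockname()[1]
```

Note the inherent race: the port is only guaranteed free at the moment of the call, so the value is consumed promptly by stuffing it into `MergedWeightTransferInitInfo`.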
Expand Down Expand Up @@ -285,8 +337,13 @@ async def _start_vllm_subprocess(
if config and "engine_args" in config:
engine_args.update(dict(config["engine_args"]))
engine_args.setdefault("generation_config", "vllm")
engine_args["enable_lora"] = True
engine_args.setdefault("max_loras", 2)
if self.rollout_weights_mode == "merged":
engine_args["weight_transfer_config"] = {"backend": "nccl"}
engine_args.pop("enable_lora", None)
engine_args.pop("max_loras", None)
else:
engine_args["enable_lora"] = True
engine_args.setdefault("max_loras", 2)
for key in ("model", "served_model_name", "enable_sleep_mode"):
engine_args.pop(key, None)

@@ -366,6 +423,25 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None:
             response.raise_for_status()
         self._latest_step = step

+    async def _sync_dedicated_merged_weights(
+        self,
+        *,
+        lora_path: str,
+        step: int,
+    ) -> None:
+        await self._ensure_megatron_running()
+        await self._init_merged_weight_transfer()
+        job_path, log_path = self._create_megatron_job_paths()
+        job = MegatronSyncJob(
+            lora_path=lora_path,
+            merged_weight_transfer=self._build_merged_weight_transfer_spec(step),
+            log_path=log_path,
+        )
+        write_megatron_job(job, job_path=job_path)
+        async for _ in stream_megatron_job(job, job_path=job_path):
+            pass
+        self._latest_step = step
+
     def _stop_vllm_subprocess(self) -> None:
         if self._vllm_process is not None:
             self._vllm_process.terminate()
@@ -378,12 +454,13 @@ def _stop_vllm_subprocess(self) -> None:
         if self._vllm_log_file is not None:
             self._vllm_log_file.close()
             self._vllm_log_file = None
+        self._merged_weight_transfer_init_info = None

     def _stop_megatron_process(self) -> None:
         if self._megatron_process is None:
             return
         if self._megatron_process.returncode is None:
-            self._megatron_process.terminate()
+            os.killpg(os.getpgid(self._megatron_process.pid), signal.SIGTERM)
         self._megatron_process = None

     async def _add_lora_aliases(
@@ -402,8 +479,10 @@ async def _add_lora_aliases(

     async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None:
         if self.is_dedicated:
-            assert self.rollout_weights_mode == "lora"
-            await self._reload_adapter(checkpoint_dir, step)
+            if self.rollout_weights_mode == "merged":
+                await self._set_served_model_name(step)
+            else:
+                await self._reload_adapter(checkpoint_dir, step)
             return
         llm = await self.llm
         await llm.pause_generation()
@@ -458,6 +537,7 @@ async def _ensure_megatron_running(self) -> None:
             command,
             cwd=str(project_root),
             env=launch_env,
+            start_new_session=True,
         )

     def _clear_pending_jobs(self) -> None:
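Two of the hunks above work together: the Megatron trainer is now launched with `start_new_session=True`, and `_stop_megatron_process` tears it down with `os.killpg` instead of `terminate()`. A standalone sketch of that pattern:

```python
import os
import signal
import subprocess

# start_new_session=True puts the child in a fresh session, so its PID doubles
# as a process-group ID covering every worker the child later forks.
proc = subprocess.Popen(["sleep", "1000"], start_new_session=True)

# SIGTERM the whole group, not just the direct child; plain proc.terminate()
# would leave forked workers (e.g. per-GPU trainer ranks) running as orphans.
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
```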
@@ -535,9 +615,15 @@ async def start_openai_server(
         lora_path = self._resolve_active_lora_path()

         if self.is_dedicated:
-            assert self.rollout_weights_mode == "lora"
             port = (config or {}).get("server_args", {}).get("port", 8000)
-            return await self._start_vllm_subprocess(lora_path, port, config)
+            location = await self._start_vllm_subprocess(lora_path, port, config)
+            if self.rollout_weights_mode == "merged":
+                self._clear_pending_jobs()
+                await self._sync_dedicated_merged_weights(
+                    lora_path=lora_path,
+                    step=self._latest_step,
+                )
+            return location

         lora_path_for_server = (
             lora_path if self._adapter_has_weights(lora_path) else None
@@ -575,7 +661,6 @@ async def train(
         verbose: bool = False,
     ) -> AsyncIterator[dict[str, float]]:
         if self.is_dedicated:
-            assert self.rollout_weights_mode == "lora"
             await self._ensure_megatron_running()

             lora_path = self._resolve_active_lora_path()
@@ -586,24 +671,56 @@
                     "MegatronService subprocess jobs must use moe_routing_replay_path."
                 )
             job_path, log_path = self._create_megatron_job_paths()
-            job = MegatronTrainingJob(
-                lora_path=lora_path,
-                optimizer_state_path=self._get_optimizer_state_path("rl"),
-                disk_packed_tensors=disk_packed_tensors,
-                config=config,
-                experimental_config=cast(dict[str, Any], _config),
-                moe_routing_replay_path=_config.get("moe_routing_replay_path"),
-                moe_routing_replay_strict=_config.get(
-                    "moe_routing_replay_strict", True
-                ),
-                log_path=log_path,
-            )
+            next_step = self._latest_step + 1
+            if self.rollout_weights_mode == "merged":
+                await self._init_merged_weight_transfer()
+                job = MegatronMergedTrainJob(
+                    lora_path=lora_path,
+                    optimizer_state_path=self._get_optimizer_state_path("rl"),
+                    disk_packed_tensors=disk_packed_tensors,
+                    config=config,
+                    experimental_config=cast(dict[str, Any], _config),
+                    moe_routing_replay_path=_config.get("moe_routing_replay_path"),
+                    moe_routing_replay_strict=_config.get(
+                        "moe_routing_replay_strict", True
+                    ),
+                    merged_weight_transfer=self._build_merged_weight_transfer_spec(
+                        next_step
+                    ),
+                    log_path=log_path,
+                )
+            else:
+                job = MegatronTrainingJob(
+                    lora_path=lora_path,
+                    optimizer_state_path=self._get_optimizer_state_path("rl"),
+                    disk_packed_tensors=disk_packed_tensors,
+                    config=config,
+                    experimental_config=cast(dict[str, Any], _config),
+                    moe_routing_replay_path=_config.get("moe_routing_replay_path"),
+                    moe_routing_replay_strict=_config.get(
+                        "moe_routing_replay_strict", True
+                    ),
+                    log_path=log_path,
+                )
             write_megatron_job(job, job_path=job_path)

             async for result in stream_megatron_job(job, job_path=job_path):
                 yield {key: float(value) for key, value in result.items()}

-            await self._publish_dedicated_training_checkpoint(lora_path=lora_path)
+            if self.rollout_weights_mode == "merged":
+                new_checkpoint_dir = get_step_checkpoint_dir(self.output_dir, next_step)
+                os.makedirs(new_checkpoint_dir, exist_ok=True)
+                shutil.copy(
+                    f"{lora_path}/adapter_model.safetensors",
+                    f"{new_checkpoint_dir}/adapter_model.safetensors",
+                )
+                self._ensure_lora_adapter_config(
+                    new_checkpoint_dir,
+                    source_path=lora_path,
+                )
+                self._latest_step = next_step
+            else:
+                await self._publish_dedicated_training_checkpoint(lora_path=lora_path)
             return
         llm, lora_path = await self._prepare_for_training()
         if _config.get("moe_routing_replay_bundle") is not None:
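Taken together, `MergedWeightTransferInitInfo` describes a single rendezvous: the trainer joins as rank 0 and each vLLM worker joins at `info.rank_offset + local_rank`, for `world_size = inference_world_size + 1` participants (hence the assert that only one trainer GPU is configured). A hedged sketch of the implied process-group setup; the function and its call sites are assumptions, not code from this PR:

```python
import torch.distributed as dist

from art.megatron.jobs import MergedWeightTransferInitInfo

def join_weight_transfer_group(
    info: MergedWeightTransferInitInfo,
    *,
    is_trainer: bool,
    local_rank: int = 0,
) -> int:
    """Join the NCCL group used to push merged weights from trainer to vLLM."""
    rank = 0 if is_trainer else info.rank_offset + local_rank
    dist.init_process_group(
        backend="nccl",  # matches weight_transfer_config={"backend": "nccl"}
        init_method=f"tcp://{info.master_address}:{info.master_port}",
        rank=rank,
        world_size=info.world_size,
    )
    return rank
```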