Skip to content

Commit dd3c18d

Browse files
committed
feat: add dedicated merged mode to Megatron backend
1 parent a02920e commit dd3c18d

File tree

5 files changed

+1037
-247
lines changed

5 files changed

+1037
-247
lines changed

src/art/megatron/job_protocol.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import Annotated, Literal, TypeAlias
2+
3+
from pydantic import BaseModel, Field, TypeAdapter
4+
5+
from art import dev, types
6+
from art.megatron.routing_replay import MoeRoutingReplayBundle
7+
from art.preprocessing.pack import DiskPackedTensors
8+
9+
10+
class MergedWeightTransferInitInfo(BaseModel):
11+
master_address: str
12+
master_port: int
13+
rank_offset: int
14+
world_size: int
15+
16+
17+
class MergedWeightTransferSpec(BaseModel):
18+
init_info: MergedWeightTransferInitInfo
19+
vllm_base_url: str
20+
served_model_name: str
21+
22+
23+
class MegatronSyncJob(BaseModel):
24+
kind: Literal["sync"]
25+
lora_path: str
26+
merged_weight_transfer: MergedWeightTransferSpec
27+
28+
29+
class _MegatronTrainJobBase(BaseModel):
30+
lora_path: str
31+
optimizer_state_path: str
32+
disk_packed_tensors: DiskPackedTensors
33+
config: types.TrainConfig
34+
experimental_config: dev.TrainConfig
35+
moe_routing_replay_path: str | None = None
36+
moe_routing_replay_strict: bool = True
37+
38+
39+
class MegatronLoraTrainJob(_MegatronTrainJobBase):
40+
kind: Literal["train_lora"]
41+
42+
43+
class MegatronMergedTrainJob(_MegatronTrainJobBase):
44+
kind: Literal["train_merged"]
45+
merged_weight_transfer: MergedWeightTransferSpec
46+
47+
48+
MegatronLoraTrainJob.model_rebuild(
49+
force=True,
50+
_types_namespace={"MoeRoutingReplayBundle": MoeRoutingReplayBundle},
51+
)
52+
MegatronMergedTrainJob.model_rebuild(
53+
force=True,
54+
_types_namespace={"MoeRoutingReplayBundle": MoeRoutingReplayBundle},
55+
)
56+
57+
MegatronJob: TypeAlias = Annotated[
58+
MegatronSyncJob | MegatronLoraTrainJob | MegatronMergedTrainJob,
59+
Field(discriminator="kind"),
60+
]
61+
62+
63+
def dump_megatron_job(job: MegatronJob) -> str:
64+
return TypeAdapter(MegatronJob).dump_json(job).decode()
65+
66+
67+
def load_megatron_job(raw: str | bytes) -> MegatronJob:
68+
return TypeAdapter(MegatronJob).validate_json(raw)

src/art/megatron/provider.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def get_provider(
8080
)
8181
)
8282
provider = bridge.to_megatron_provider()
83+
setattr(provider, "art_bridge", bridge)
8384
base_layer_spec = provider.transformer_layer_spec
8485

8586
def _flex_attention_layer_spec(

0 commit comments

Comments
 (0)