Skip to content

Commit c2039fc

Browse files
committed
Refresh shared training refactor on top of ART main
1 parent 1905677 commit c2039fc

File tree

14 files changed

+1233
-242
lines changed

14 files changed

+1233
-242
lines changed

src/art/_backend_training.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from collections.abc import Iterable
2+
import time
3+
from typing import Literal
4+
5+
from . import dev
6+
from .metrics_taxonomy import (
7+
average_metric_samples,
8+
build_training_summary_metrics,
9+
summarize_trajectory_groups,
10+
)
11+
from .trajectories import TrajectoryGroup
12+
from .types import TrainConfig
13+
14+
15+
def build_rl_train_configs(
    *,
    learning_rate: float,
    advantage_balance: float = 0.0,
    scale_rewards: bool = True,
    importance_sampling_level: Literal[
        "token", "sequence", "average", "geometric_average"
    ] = "token",
    mask_prob_ratio: bool = False,
    ppo: bool = False,
    precalculate_logprobs: bool = False,
    epsilon: float | None = None,
    epsilon_high: float | None = None,
    max_negative_advantage_importance_sampling_weight: float | None = None,
    kimi_k2_tau: float | None = None,
    kl_penalty_coef: float = 0.0,
    allow_training_without_logprobs: bool | None = None,
    plot_tensors: bool | None = None,
    truncated_importance_sampling: float | None = None,
    scale_learning_rate_by_reward_std_dev: bool | None = None,
    logprob_calculation_chunk_size: int | None = None,
    num_trajectories_learning_rate_multiplier_power: float | None = None,
    kl_ref_adapter_path: str | None = None,
) -> tuple[TrainConfig, dev.TrainConfig]:
    """Assemble the (public, experimental) RL training config pair.

    Returns a ``TrainConfig`` carrying the core settings (learning rate and
    KL penalty coefficient) plus a ``dev.TrainConfig`` dict with the
    remaining knobs. Optional knobs passed as ``None`` are left out of the
    dict entirely rather than stored as ``None``.
    """
    train_config = TrainConfig(
        learning_rate=learning_rate,
        kl_penalty_coef=kl_penalty_coef,
    )
    # Settings that are always present in the experimental config.
    dev_config: dev.TrainConfig = {
        "advantage_balance": advantage_balance,
        "importance_sampling_level": importance_sampling_level,
        "kl_penalty_coef": kl_penalty_coef,
        "mask_prob_ratio": mask_prob_ratio,
        "ppo": ppo,
        "precalculate_logprobs": precalculate_logprobs,
        "scale_rewards": scale_rewards,
    }
    # Optional settings: include a key only when the caller supplied a
    # non-None value (same order of insertion as the original chain of
    # `if x is not None` assignments).
    optional_settings = {
        "allow_training_without_logprobs": allow_training_without_logprobs,
        "plot_tensors": plot_tensors,
        "truncated_importance_sampling": truncated_importance_sampling,
        "scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
        "logprob_calculation_chunk_size": logprob_calculation_chunk_size,
        "num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power,
        "epsilon": epsilon,
        "epsilon_high": epsilon_high,
        "max_negative_advantage_importance_sampling_weight": max_negative_advantage_importance_sampling_weight,
        "kimi_k2_tau": kimi_k2_tau,
        "kl_ref_adapter_path": kl_ref_adapter_path,
    }
    for key, value in optional_settings.items():
        if value is not None:
            dev_config[key] = value  # type: ignore[literal-required]

    return train_config, dev_config
83+
84+
85+
def aggregate_rl_training_metrics(
    *,
    training_metrics: list[dict[str, float]],
    trajectory_groups: Iterable[TrajectoryGroup],
    trainer_started: float,
) -> dict[str, float]:
    """Combine per-step trainer metrics with trajectory-group summary metrics.

    Averages the collected trainer metric samples, records elapsed trainer
    wall-clock time under ``time/step_trainer_s`` unless the trainer already
    reported it, and then fills in summary metrics derived from the
    trajectory groups — without overwriting any key the averaged metrics
    already contain.
    """
    groups = list(trajectory_groups)
    aggregated = average_metric_samples(training_metrics)
    group_summary = summarize_trajectory_groups(groups)
    aggregated.setdefault("time/step_trainer_s", time.monotonic() - trainer_started)
    summary_metrics = build_training_summary_metrics(
        group_summary,
        include_trainable_groups=True,
    )
    # Averaged trainer metrics take precedence over summary-derived ones.
    for key, value in summary_metrics.items():
        if key not in aggregated:
            aggregated[key] = value
    return aggregated

src/art/local/backend.py

Lines changed: 35 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,14 @@
4343
from mp_actors import close_proxy, move_to_child_process
4444

4545
from .. import dev
46+
from .._backend_training import (
47+
aggregate_rl_training_metrics,
48+
build_rl_train_configs,
49+
)
4650
from ..backend import AnyTrainableModel, Backend
4751
from ..costs import build_cost_calculator, get_model_pricing
4852
from ..metrics_taxonomy import (
4953
TRAIN_GRADIENT_STEPS_KEY,
50-
average_metric_samples,
5154
build_training_summary_metrics,
5255
summarize_trajectory_groups,
5356
)
@@ -642,45 +645,36 @@ async def train( # type: ignore[override]
642645
if adam_params is not None:
643646
raise ValueError("LocalBackend requires adam_params=None.")
644647

645-
# Build config objects from explicit kwargs
646-
config = TrainConfig(
647-
learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef
648-
)
649-
dev_config: dev.TrainConfig = {
650-
"advantage_balance": advantage_balance,
651-
"allow_training_without_logprobs": allow_training_without_logprobs,
652-
"importance_sampling_level": importance_sampling_level,
653-
"kl_penalty_coef": kl_penalty_coef,
654-
"mask_prob_ratio": mask_prob_ratio,
655-
"plot_tensors": plot_tensors,
656-
"ppo": loss_fn == "ppo",
657-
"precalculate_logprobs": precalculate_logprobs,
658-
"scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
659-
"scale_rewards": scale_rewards,
660-
"logprob_calculation_chunk_size": logprob_calculation_chunk_size,
661-
"num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power,
662-
}
663-
# Only include optional fields if they're set
664-
if epsilon is not None:
665-
dev_config["epsilon"] = epsilon
666-
if epsilon_high is not None:
667-
dev_config["epsilon_high"] = epsilon_high
668-
if max_negative_advantage_importance_sampling_weight is not None:
669-
dev_config["max_negative_advantage_importance_sampling_weight"] = (
670-
max_negative_advantage_importance_sampling_weight
671-
)
672-
if kimi_k2_tau is not None:
673-
dev_config["kimi_k2_tau"] = kimi_k2_tau
674-
if truncated_importance_sampling is not None:
675-
dev_config["truncated_importance_sampling"] = truncated_importance_sampling
676-
if kl_ref_adapter_path is not None:
677-
dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path
678-
elif kl_penalty_reference_step is not None:
679-
ref_checkpoint_dir = get_step_checkpoint_dir(
648+
resolved_kl_ref_adapter_path = kl_ref_adapter_path
649+
if (
650+
resolved_kl_ref_adapter_path is None
651+
and kl_penalty_reference_step is not None
652+
):
653+
resolved_kl_ref_adapter_path = get_step_checkpoint_dir(
680654
get_model_dir(model=model, art_path=self._path),
681655
kl_penalty_reference_step,
682656
)
683-
dev_config["kl_ref_adapter_path"] = ref_checkpoint_dir
657+
config, dev_config = build_rl_train_configs(
658+
learning_rate=learning_rate,
659+
advantage_balance=advantage_balance,
660+
scale_rewards=scale_rewards,
661+
importance_sampling_level=importance_sampling_level,
662+
mask_prob_ratio=mask_prob_ratio,
663+
ppo=loss_fn == "ppo",
664+
precalculate_logprobs=precalculate_logprobs,
665+
epsilon=epsilon,
666+
epsilon_high=epsilon_high,
667+
max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight,
668+
kimi_k2_tau=kimi_k2_tau,
669+
kl_penalty_coef=kl_penalty_coef,
670+
allow_training_without_logprobs=allow_training_without_logprobs,
671+
plot_tensors=plot_tensors,
672+
truncated_importance_sampling=truncated_importance_sampling,
673+
scale_learning_rate_by_reward_std_dev=scale_learning_rate_by_reward_std_dev,
674+
logprob_calculation_chunk_size=logprob_calculation_chunk_size,
675+
num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power,
676+
kl_ref_adapter_path=resolved_kl_ref_adapter_path,
677+
)
684678

685679
# Collect metrics from training
686680
training_metrics: list[dict[str, float]] = []
@@ -690,21 +684,10 @@ async def train( # type: ignore[override]
690684
):
691685
training_metrics.append(metrics)
692686

693-
# Aggregate metrics
694-
avg_metrics = average_metric_samples(training_metrics)
695-
summary = summarize_trajectory_groups(groups_list)
696-
avg_metrics.setdefault(
697-
"time/step_trainer_s", time.monotonic() - trainer_started
698-
)
699-
avg_metrics.update(
700-
{
701-
key: value
702-
for key, value in build_training_summary_metrics(
703-
summary,
704-
include_trainable_groups=True,
705-
).items()
706-
if key not in avg_metrics
707-
}
687+
avg_metrics = aggregate_rl_training_metrics(
688+
training_metrics=training_metrics,
689+
trajectory_groups=groups_list,
690+
trainer_started=trainer_started,
708691
)
709692

710693
# Get step and checkpoint path

src/art/megatron/jobs.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from typing import Literal
2+
3+
from pydantic import BaseModel
4+
5+
from .. import dev, types
6+
from ..preprocessing.pack import DiskPackedTensors
7+
from .routing_replay import MoeRoutingReplayBundle
8+
9+
DEFAULT_TRAINING_LOG_PATH = "/tmp/megatron_training_log.jsonl"
10+
DEFAULT_JOBS_DIR = "/tmp/megatron_training_jobs"
11+
DEFAULT_VLLM_WAKE_LOCK_PATH = "/tmp/megatron_vllm_waking"
12+
13+
14+
class MegatronTrainingJob(BaseModel):
    """One RL training job handed to the Megatron trainer.

    A plain pydantic model so the job can be serialized and passed between
    processes; path fields refer to artifacts on shared disk.
    """

    # Path to the LoRA adapter weights to train.
    lora_path: str
    # Path to the optimizer state accompanying the adapter.
    optimizer_state_path: str
    # Handle to pre-packed training tensors stored on disk.
    disk_packed_tensors: DiskPackedTensors
    # Core training config (learning rate, KL penalty coefficient).
    config: types.TrainConfig
    # Experimental/dev training options.
    experimental_config: dev.TrainConfig
    # Optional path to recorded MoE routing data to replay — presumably
    # produced by routing_replay; confirm against the trainer side.
    moe_routing_replay_path: str | None = None
    # NOTE(review): assumed to mean "fail rather than skip when replay
    # cannot be applied" — semantics not visible here, confirm.
    moe_routing_replay_strict: bool = True
    # JSONL file the trainer writes progress to (see DEFAULT_TRAINING_LOG_PATH).
    log_path: str = DEFAULT_TRAINING_LOG_PATH


# Re-resolve the model's type references with MoeRoutingReplayBundle in the
# namespace. NOTE(review): no visible field annotation references
# MoeRoutingReplayBundle — verify this rebuild is still required.
MegatronTrainingJob.model_rebuild(
    force=True,
    _types_namespace={"MoeRoutingReplayBundle": MoeRoutingReplayBundle},
)
29+
30+
31+
class MegatronSFTTrainingJob(BaseModel):
    """One supervised fine-tuning (SFT) job for the Megatron trainer."""

    # Fixed tag distinguishing SFT jobs from other job payloads.
    job_type: Literal["sft"] = "sft"
    # Path to the LoRA adapter weights to train.
    lora_path: str
    # Path to the optimizer state accompanying the adapter.
    optimizer_state_path: str
    # Directory containing the prepared SFT training data.
    sft_data_dir: str
    # Number of batches to train on.
    num_batches: int
    # Learning-rate schedule — presumably one entry per batch; confirm
    # against the consumer of this job.
    learning_rates: list[float]
    # Optimizer weight decay (disabled by default).
    weight_decay: float = 0.0
    # Gradient clipping threshold.
    max_grad_norm: float = 1.0
    # JSONL file the trainer writes progress to (see DEFAULT_TRAINING_LOG_PATH).
    log_path: str = DEFAULT_TRAINING_LOG_PATH

src/art/megatron/routing_replay.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from collections import defaultdict
4+
import importlib
45
import json
56
from pathlib import Path
67
import re
@@ -13,9 +14,12 @@
1314
)
1415
from megatron.core.transformer.moe.moe_utils import permute, sort_chunks_by_idxs
1516
from pydantic import BaseModel, ConfigDict, model_validator
16-
from safetensors.torch import load_file, save_file
1717
import torch
1818

19+
safetensors_torch = importlib.import_module("safetensors.torch")
20+
load_file = safetensors_torch.load_file
21+
save_file = safetensors_torch.save_file
22+
1923
ROUTER_NAME_TOKEN = ".mlp.router"
2024
ROUTER_KEY_FORMAT_VERSION = "moe_routing_replay_v1"
2125
GLOBAL_TOKEN_UIDS_KEY = "global_token_uids"

src/art/megatron/runtime_env.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import os
2+
3+
4+
def _set_cache_dir(env_var: str, default_path: str) -> None:
    """Point *env_var* at a cache directory and ensure the directory exists.

    An existing non-empty value of the variable is respected; otherwise
    *default_path* (with ``~`` expanded) is installed. Either way the
    resulting directory is created if missing.
    """
    if not os.environ.get(env_var):
        os.environ[env_var] = os.path.expanduser(default_path)
    os.makedirs(os.environ[env_var], exist_ok=True)


def configure_megatron_runtime_env() -> None:
    """Set process environment variables for Megatron training runs.

    ``CUDA_DEVICE_MAX_CONNECTIONS`` and ``PYTORCH_CUDA_ALLOC_CONF`` are
    forced unconditionally (presumably required exact values for the
    Megatron/vLLM stack — preserved from the original setup). The CUDA
    arch list defaults to ``9.0`` (Hopper) but an already-set value now
    survives, so deployments on other GPU generations can override it via
    the environment. TorchInductor and Triton cache directories are
    resolved and created.
    """
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    # Fix: the original unconditionally overwrote TORCH_CUDA_ARCH_LIST with
    # "9.0", clobbering any caller-provided arch list and pinning extension
    # builds to Hopper-only. Honor an explicit override instead.
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "9.0")
    _set_cache_dir("TORCHINDUCTOR_CACHE_DIR", "~/.cache/torchinductor")
    _set_cache_dir("TRITON_CACHE_DIR", "~/.triton/cache")

0 commit comments

Comments
 (0)