Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
c2039fc
Refresh shared training refactor on top of ART main
Kovbo Mar 28, 2026
19c906b
Rename Megatron merge helper
Kovbo Mar 28, 2026
9d75910
Deduplicate local and shared training logic
Kovbo Mar 28, 2026
6d0d2ae
Fix Megatron rope theta compatibility
Kovbo Mar 28, 2026
9c474c9
Remove Megatron rope theta workaround
Kovbo Mar 28, 2026
2fa8ffb
Align Unsloth SFT weight decay defaults
Kovbo Mar 28, 2026
8cb71cc
remove apex from no-build-isolation-package
Kovbo Mar 28, 2026
3a679cb
update install script
Kovbo Mar 28, 2026
9e90c7d
Fix Megatron job finalization ordering
Kovbo Mar 30, 2026
511d72c
Share Megatron worker loop
Kovbo Mar 30, 2026
2e64da0
Default Megatron grad accumulation by DP size
Kovbo Apr 1, 2026
0cee7cf
Collapse Megatron shared API into train module
Kovbo Apr 1, 2026
911c082
Remove Megatron shared shim
Kovbo Apr 1, 2026
0fa9a2b
Collapse Unsloth shared API into train module
Kovbo Apr 1, 2026
f6cd445
Lighten Megatron orchestration imports
Kovbo Apr 1, 2026
ff28081
Merge branch 'main' of github.com:OpenPipe/ART into feat/shared-train…
Kovbo Apr 2, 2026
3116a1b
Merge branch 'feat/shared-training-code' of github.com:OpenPipe/ART i…
Kovbo Apr 2, 2026
d08f2ad
fix: normalize SFT loss by token count before backward pass
Kovbo Apr 2, 2026
21dd5a3
Revert "fix: normalize SFT loss by token count before backward pass"
Kovbo Apr 3, 2026
d68ae3d
Support Megatron SFT in local backend
Kovbo Apr 3, 2026
f8fee63
refactor: extract create_identity_lora as standalone function
Kovbo Apr 3, 2026
baac098
Fix SFT main_grad fallback in Megatron
Kovbo Apr 6, 2026
aa2fd4b
Fix ART lint and type issues
Kovbo Apr 6, 2026
497ff3c
Simplify ty-safe optimizer access
Kovbo Apr 6, 2026
2be0333
test: drop megatron sft batch unit test
Kovbo Apr 7, 2026
b322072
refactor: revert direct safetensors import in moe conversion
Kovbo Apr 7, 2026
7c5a02b
style: format megatron oracle harness
Kovbo Apr 7, 2026
82fa9d0
refactor: use direct safetensors import in routing replay
Kovbo Apr 7, 2026
40e66aa
fix: isolate megatron optimizer states and step counts
FurtherAI Apr 7, 2026
9bf7001
Add SFT oracle coverage and shared grad scheduling
FurtherAI Apr 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions scripts/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,14 @@ else
echo "Skipping git reset/clean (GIT_RESET_CLEAN is not true). Preserving synced working tree."
fi

# Install astral-uv
if ! command -v uv >/dev/null 2>&1; then
if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
echo "Failed to install uv." >&2
exit 1
fi
export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"
# Install astral-uv (standalone version)
# Always prepend standalone install path so it takes precedence over system/conda uv
export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"
if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
echo "Failed to install uv." >&2
exit 1
fi

# Update uv
uv self update

# Sync the dependencies
if [ "${INSTALL_EXTRAS:-false}" = "true" ]; then
uv sync --all-extras
Expand Down
105 changes: 105 additions & 0 deletions src/art/_backend_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from collections.abc import Iterable
import time
from typing import Literal

from . import dev
from .metrics_taxonomy import (
average_metric_samples,
build_training_summary_metrics,
summarize_trajectory_groups,
)
from .trajectories import TrajectoryGroup
from .types import TrainConfig


def build_rl_train_configs(
    *,
    learning_rate: float,
    advantage_balance: float = 0.0,
    scale_rewards: bool = True,
    importance_sampling_level: Literal[
        "token", "sequence", "average", "geometric_average"
    ] = "token",
    mask_prob_ratio: bool = False,
    ppo: bool = False,
    precalculate_logprobs: bool = False,
    epsilon: float | None = None,
    epsilon_high: float | None = None,
    max_negative_advantage_importance_sampling_weight: float | None = None,
    kimi_k2_tau: float | None = None,
    kl_penalty_coef: float = 0.0,
    allow_training_without_logprobs: bool | None = None,
    plot_tensors: bool | None = None,
    truncated_importance_sampling: float | None = None,
    scale_learning_rate_by_reward_std_dev: bool | None = None,
    logprob_calculation_chunk_size: int | None = None,
    num_trajectories_learning_rate_multiplier_power: float | None = None,
    kl_ref_adapter_path: str | None = None,
) -> tuple[TrainConfig, dev.TrainConfig]:
    """Build the frontend ``TrainConfig`` and backend dev-config for an RL step.

    The dev config always carries the core RL knobs; every optional keyword
    left as ``None`` is omitted from the returned dict entirely so that
    backend-side defaults apply instead of an explicit ``None``.
    """
    train_config = TrainConfig(
        learning_rate=learning_rate,
        kl_penalty_coef=kl_penalty_coef,
    )
    # Fields that are unconditionally present in the dev config.
    dev_config: dev.TrainConfig = {
        "advantage_balance": advantage_balance,
        "importance_sampling_level": importance_sampling_level,
        "kl_penalty_coef": kl_penalty_coef,
        "mask_prob_ratio": mask_prob_ratio,
        "ppo": ppo,
        "precalculate_logprobs": precalculate_logprobs,
        "scale_rewards": scale_rewards,
    }
    # Optional fields, in the same insertion order as before so the resulting
    # dict is byte-for-byte identical to the previous if-chain construction.
    optional_fields: dict[str, object] = {
        "allow_training_without_logprobs": allow_training_without_logprobs,
        "plot_tensors": plot_tensors,
        "truncated_importance_sampling": truncated_importance_sampling,
        "scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
        "logprob_calculation_chunk_size": logprob_calculation_chunk_size,
        "num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power,
        "epsilon": epsilon,
        "epsilon_high": epsilon_high,
        "max_negative_advantage_importance_sampling_weight": max_negative_advantage_importance_sampling_weight,
        "kimi_k2_tau": kimi_k2_tau,
        "kl_ref_adapter_path": kl_ref_adapter_path,
    }
    for field_name, field_value in optional_fields.items():
        if field_value is not None:
            dev_config[field_name] = field_value  # type: ignore[literal-required]

    return train_config, dev_config


def aggregate_rl_training_metrics(
    *,
    training_metrics: list[dict[str, float]],
    trajectory_groups: Iterable[TrajectoryGroup],
    trainer_started: float,
) -> dict[str, float]:
    """Merge per-batch training metrics with trajectory-group summary metrics.

    Per-batch metrics are averaged first and always win: summary metrics and
    the wall-clock trainer duration are only filled in where no averaged
    value already exists.
    """
    groups = list(trajectory_groups)
    merged = average_metric_samples(training_metrics)
    # Only record the fallback duration if the trainer didn't report its own.
    if "time/step_trainer_s" not in merged:
        merged["time/step_trainer_s"] = time.monotonic() - trainer_started
    summary_metrics = build_training_summary_metrics(
        summarize_trajectory_groups(groups),
        include_trainable_groups=True,
    )
    for key, value in summary_metrics.items():
        merged.setdefault(key, value)
    return merged
132 changes: 66 additions & 66 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,14 @@
from mp_actors import close_proxy, move_to_child_process

from .. import dev
from .._backend_training import (
aggregate_rl_training_metrics,
build_rl_train_configs,
)
from ..backend import AnyTrainableModel, Backend
from ..costs import build_cost_calculator, get_model_pricing
from ..metrics_taxonomy import (
TRAIN_GRADIENT_STEPS_KEY,
average_metric_samples,
build_training_summary_metrics,
summarize_trajectory_groups,
)
Expand Down Expand Up @@ -642,45 +645,36 @@ async def train( # type: ignore[override]
if adam_params is not None:
raise ValueError("LocalBackend requires adam_params=None.")

# Build config objects from explicit kwargs
config = TrainConfig(
learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef
)
dev_config: dev.TrainConfig = {
"advantage_balance": advantage_balance,
"allow_training_without_logprobs": allow_training_without_logprobs,
"importance_sampling_level": importance_sampling_level,
"kl_penalty_coef": kl_penalty_coef,
"mask_prob_ratio": mask_prob_ratio,
"plot_tensors": plot_tensors,
"ppo": loss_fn == "ppo",
"precalculate_logprobs": precalculate_logprobs,
"scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
"scale_rewards": scale_rewards,
"logprob_calculation_chunk_size": logprob_calculation_chunk_size,
"num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power,
}
# Only include optional fields if they're set
if epsilon is not None:
dev_config["epsilon"] = epsilon
if epsilon_high is not None:
dev_config["epsilon_high"] = epsilon_high
if max_negative_advantage_importance_sampling_weight is not None:
dev_config["max_negative_advantage_importance_sampling_weight"] = (
max_negative_advantage_importance_sampling_weight
)
if kimi_k2_tau is not None:
dev_config["kimi_k2_tau"] = kimi_k2_tau
if truncated_importance_sampling is not None:
dev_config["truncated_importance_sampling"] = truncated_importance_sampling
if kl_ref_adapter_path is not None:
dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path
elif kl_penalty_reference_step is not None:
ref_checkpoint_dir = get_step_checkpoint_dir(
resolved_kl_ref_adapter_path = kl_ref_adapter_path
if (
resolved_kl_ref_adapter_path is None
and kl_penalty_reference_step is not None
):
resolved_kl_ref_adapter_path = get_step_checkpoint_dir(
get_model_dir(model=model, art_path=self._path),
kl_penalty_reference_step,
)
dev_config["kl_ref_adapter_path"] = ref_checkpoint_dir
config, dev_config = build_rl_train_configs(
learning_rate=learning_rate,
advantage_balance=advantage_balance,
scale_rewards=scale_rewards,
importance_sampling_level=importance_sampling_level,
mask_prob_ratio=mask_prob_ratio,
ppo=loss_fn == "ppo",
precalculate_logprobs=precalculate_logprobs,
epsilon=epsilon,
epsilon_high=epsilon_high,
max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight,
kimi_k2_tau=kimi_k2_tau,
kl_penalty_coef=kl_penalty_coef,
allow_training_without_logprobs=allow_training_without_logprobs,
plot_tensors=plot_tensors,
truncated_importance_sampling=truncated_importance_sampling,
scale_learning_rate_by_reward_std_dev=scale_learning_rate_by_reward_std_dev,
logprob_calculation_chunk_size=logprob_calculation_chunk_size,
num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power,
kl_ref_adapter_path=resolved_kl_ref_adapter_path,
)

# Collect metrics from training
training_metrics: list[dict[str, float]] = []
Expand All @@ -690,21 +684,10 @@ async def train( # type: ignore[override]
):
training_metrics.append(metrics)

# Aggregate metrics
avg_metrics = average_metric_samples(training_metrics)
summary = summarize_trajectory_groups(groups_list)
avg_metrics.setdefault(
"time/step_trainer_s", time.monotonic() - trainer_started
)
avg_metrics.update(
{
key: value
for key, value in build_training_summary_metrics(
summary,
include_trainable_groups=True,
).items()
if key not in avg_metrics
}
avg_metrics = aggregate_rl_training_metrics(
training_metrics=training_metrics,
trajectory_groups=groups_list,
trainer_started=trainer_started,
)

# Get step and checkpoint path
Expand Down Expand Up @@ -822,20 +805,31 @@ async def _train_model(
packed_tensors, f"{get_model_dir(model=model, art_path=self._path)}/tensors"
)
# Note: scale_learning_rate_by_reward_std_dev is now handled by the frontend (Model.train())
grad_accumulation_sequences = max(1, int(config.grad_accumulation_sequences))
estimated_gradient_steps = math.ceil(
grad_accumulation_sequences = max(
1, int(config.grad_accumulation_sequences or 1)
)
fallback_gradient_steps = math.ceil(
disk_packed_tensors["num_sequences"] / grad_accumulation_sequences
)
pbar = tqdm.tqdm(total=estimated_gradient_steps, desc="train")
pbar = tqdm.tqdm(total=fallback_gradient_steps, desc="train")
reported_gradient_steps: int | None = None
async for result in service.train(
disk_packed_tensors, config, dev_config, verbose
):
num_gradient_steps = int(
result.pop(TRAIN_GRADIENT_STEPS_KEY, estimated_gradient_steps)
)
assert num_gradient_steps == estimated_gradient_steps, (
f"num_gradient_steps {num_gradient_steps} != estimated_gradient_steps {estimated_gradient_steps}"
)
raw_num_gradient_steps = result.pop(TRAIN_GRADIENT_STEPS_KEY, None)
if raw_num_gradient_steps is not None:
num_gradient_steps = int(raw_num_gradient_steps)
if reported_gradient_steps is None:
reported_gradient_steps = num_gradient_steps
if pbar.total != num_gradient_steps:
pbar.total = num_gradient_steps
pbar.refresh()
else:
assert num_gradient_steps == reported_gradient_steps, (
f"num_gradient_steps {num_gradient_steps} != reported_gradient_steps {reported_gradient_steps}"
)
else:
num_gradient_steps = reported_gradient_steps or fallback_gradient_steps
yield {
**base_metrics,
**result,
Expand Down Expand Up @@ -882,10 +876,13 @@ async def _train_sft(
)
tokenizer = self._tokenizers[model.base_model]

# Determine batch_size
batch_size = config.batch_size
if batch_size == "auto":
batch_size = 2 # Default to 2 for SFT
from ..utils.sft import resolve_sft_batch_size

batch_size = resolve_sft_batch_size(
batch_size=config.batch_size,
default_batch_size=self._default_sft_batch_size(),
)
service_config = config.model_copy(update={"batch_size": batch_size})

# Auto-detect instruction/response parts from model
from ..utils.model_config import get_instruction_response_parts
Expand Down Expand Up @@ -931,7 +928,7 @@ async def _train_sft(
total_trajectories = len(trajectory_list)
batch_count = 0

async for result in service.train_sft(batches, verbose):
async for result in service.train_sft(batches, service_config, verbose):
pbar.update(1)
pbar.set_postfix({"loss": f"{result.get('loss/train', 0):.4f}"})
batch_count += 1
Expand All @@ -953,6 +950,9 @@ async def _train_sft(
if verbose:
print("_train_sft complete")

def _default_sft_batch_size(self) -> int:
return 2

# ------------------------------------------------------------------
# Experimental support for S3
# ------------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions src/art/local/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ def train(
def train_sft(
self,
batches: list[SFTBatch],
config: types.TrainSFTConfig,
verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
"""Train using SFT on pre-computed batches.

Args:
batches: List of SFTBatch objects to train on.
config: SFT batch/grad-accumulation configuration.
verbose: Whether to print detailed logs.

Yields:
Expand Down
2 changes: 1 addition & 1 deletion src/art/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from . import dev

if TYPE_CHECKING:
from art.unsloth.service import TrainInputs
from art.preprocessing.inputs import TrainInputs


class Loss(BaseModel):
Expand Down
7 changes: 7 additions & 0 deletions src/art/megatron/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,10 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
process_name="megatron-service",
)
return self._services[model.name]

def _default_sft_batch_size(self) -> int:
import torch

num_gpus = max(int(torch.cuda.device_count()), 1)
tensor_parallel_size = min(2, num_gpus)
return max(num_gpus // tensor_parallel_size, 1)
Loading
Loading