12 changes: 12 additions & 0 deletions miles/backends/training_utils/data.py
@@ -115,11 +115,23 @@ def get_batch(
parallel_state = get_parallel_state()

assert "tokens" in keys
has_domains = "domains" in data_iterator.rollout_data
if has_domains and "domains" not in keys:
keys = list(keys) + ["domains"]
batch = data_iterator.get_next(keys)

if "dynamic_global_batch_size" in data_iterator.rollout_data:
batch["dynamic_global_batch_size"] = data_iterator.rollout_data["dynamic_global_batch_size"]

# Canonical domain set, cached on the iterator so every microbatch emits the
# same list (aggregate_train_losses keys positionally on the first microbatch).
if has_domains:
if not hasattr(data_iterator, "_all_domains_cache"):
data_iterator._all_domains_cache = sorted(
{d for d in data_iterator.rollout_data["domains"] if d}
)
batch["all_domains"] = data_iterator._all_domains_cache

tokens = batch["tokens"]
# use 0 as the pad token id should be fine?
pad_token_id = 0
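A minimal standalone sketch (not part of this diff) of the caching pattern above, with SimpleIterator as a hypothetical stand-in for the real data iterator: the sorted domain list is computed once and memoized on the iterator, so every microbatch reports the identical key set that positional aggregation downstream depends on.

class SimpleIterator:
    """Hypothetical stand-in for DataIterator; only rollout_data matters here."""
    def __init__(self, rollout_data):
        self.rollout_data = rollout_data

def canonical_domains(it):
    # Compute once and memoize on the iterator so every microbatch sees the same
    # sorted list; None / empty domains are dropped before sorting.
    if not hasattr(it, "_all_domains_cache"):
        it._all_domains_cache = sorted({d for d in it.rollout_data["domains"] if d})
    return it._all_domains_cache

it = SimpleIterator({"domains": ["math", None, "code", "math"]})
assert canonical_domains(it) == ["code", "math"]
assert canonical_domains(it) is canonical_domains(it)  # cached: the same list object each call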
32 changes: 27 additions & 5 deletions miles/backends/training_utils/log_utils.py
@@ -51,6 +51,13 @@
"returns": "reward",
}

# Cumulative train-step counter across all rollouts. The previous formula
# `rollout_id * num_steps_per_rollout + step_id` collides (and decreases) when
# `num_steps_per_rollout` shrinks across rollouts under dynamic batching, since
# each rollout uses its own current num_steps_per_rollout as a scaling factor.
# A simple monotone counter is invariant to that jitter.
_TRAIN_STEP_COUNTER = 0


def gather_log_data(
metric_name: str,
@@ -77,7 +84,15 @@ def gather_log_data(
# dict to the union of keys with NaN so every rank sends the same shape.
# Cost is one all_gather_object on a tiny key list.
all_keys: list = [None] * dp_size
logger.info(
f"[rank={pg.rank}/{dp_size}] gather_log_data({metric_name}) "
f"rollout={rollout_id} entering all_gather_object, keys={len(log_dict)}"
)
dist.all_gather_object(all_keys, sorted(log_dict.keys()), group=pg.gloo_group)
logger.info(
f"[rank={pg.rank}/{dp_size}] gather_log_data({metric_name}) "
f"rollout={rollout_id} all_gather_object returned"
)
union_keys: set = set()
for ks in all_keys:
if ks:
@@ -170,6 +185,7 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc
"dynamic_global_batch_size",
"weight_versions",
"metadata",
"domains",
]:
continue
# Upload per sample mean for each rollout value
@@ -492,7 +508,9 @@ def log_train_step(
Returns:
The formatted log_dict (for CI tests or other uses).
"""
accumulated_step_id = rollout_id * num_steps_per_rollout + step_id
global _TRAIN_STEP_COUNTER
accumulated_step_id = _TRAIN_STEP_COUNTER
_TRAIN_STEP_COUNTER += 1
role_tag = "" if role == "actor" else f"{role}-"

log_dict_out = {
@@ -525,7 +543,7 @@
# cross-plotted against rollout-side axes in the wandb UI.
log_dict_out["train/rollout_id"] = rollout_id
log_dict_out["train/step_in_rollout"] = step_id
log_dict_out["rollout/step"] = compute_rollout_step(args, rollout_id)
log_dict_out["train_step"] = accumulated_step_id

# Emit top-level grouped copies for W&B panel organization (existing train/ keys unchanged)
grouped_additions = {}
@@ -534,9 +552,13 @@
if not full_key.startswith(prefix):
continue
bare_key = full_key[len(prefix):]
if bare_key in _TRAIN_METRIC_GROUPS:
for group in _TRAIN_METRIC_GROUPS[bare_key]:
grouped_additions[f"{group}/{bare_key}"] = val
# Per-domain keys arrive as "<metric>/<domain>" — route to "<group>/<domain>/<metric>".
metric_name, sep, domain = bare_key.rpartition("/")
lookup = metric_name if (sep and metric_name in _TRAIN_METRIC_GROUPS) else bare_key
if lookup in _TRAIN_METRIC_GROUPS:
suffix = f"{domain}/{metric_name}" if lookup == metric_name else bare_key
for group in _TRAIN_METRIC_GROUPS[lookup]:
grouped_additions[f"{group}/{suffix}"] = val
elif bare_key.startswith("lr-pg_"):
grouped_additions[f"optimization/{bare_key}"] = val
log_dict_out.update(grouped_additions)
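A small worked example of the collision the _TRAIN_STEP_COUNTER comment describes, using hypothetical per-rollout step counts: when dynamic batching shrinks num_steps_per_rollout between rollouts, the old formula repeats (and can decrease) step ids, while the monotone counter keeps increasing.

steps_per_rollout = [4, 2, 3]  # hypothetical: dynamic batching changes the step count per rollout

old_ids, new_ids, counter = [], [], 0
for rollout_id, num_steps in enumerate(steps_per_rollout):
    for step_id in range(num_steps):
        old_ids.append(rollout_id * num_steps + step_id)  # previous formula
        new_ids.append(counter)                           # cumulative counter
        counter += 1

print(old_ids)  # [0, 1, 2, 3, 2, 3, 6, 7, 8]  ids 2 and 3 repeat, then the axis jumps to 6
print(new_ids)  # [0, 1, 2, 3, 4, 5, 6, 7, 8]  strictly increasing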
55 changes: 52 additions & 3 deletions miles/backends/training_utils/loss.py
@@ -652,6 +652,11 @@ def policy_loss_function(
else:
pg_loss_reducer = sum_of_sample_mean

# Saved for per-domain fan-out (reducers below overwrite these names with scalars).
_pg_loss_per_token = pg_loss
_pg_clipfrac_per_token = pg_clipfrac
_ppo_kl_per_token = ppo_kl

pg_loss = pg_loss_reducer(pg_loss)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac)
ppo_kl = sum_of_sample_mean(ppo_kl)
@@ -690,18 +695,24 @@
# Train-inference mismatch: compare inference engine vs FSDP at rollout time
train_rollout_logprob_abs_diff = None
train_rollout_logprob_diff = None
_train_rollout_logprob_abs_per_token = None
_train_rollout_logprob_signed_per_token = None
if "rollout_log_probs" in batch and batch["rollout_log_probs"]:
rollout_log_probs_cat = torch.cat(batch["rollout_log_probs"], dim=0)
log_probs_batch_cat = torch.cat(batch["log_probs"], dim=0)
train_rollout_logprob_abs_diff = sum_of_sample_mean((old_log_probs - rollout_log_probs_cat).abs()).clone().detach()
_train_rollout_logprob_abs_per_token = (old_log_probs - rollout_log_probs_cat).abs()
# signed: log π(inf) − log π(fsdp rollout)
train_rollout_logprob_diff = sum_of_sample_mean(rollout_log_probs_cat - log_probs_batch_cat).clone().detach()
_train_rollout_logprob_signed_per_token = rollout_log_probs_cat - log_probs_batch_cat
train_rollout_logprob_abs_diff = sum_of_sample_mean(_train_rollout_logprob_abs_per_token).clone().detach()
train_rollout_logprob_diff = sum_of_sample_mean(_train_rollout_logprob_signed_per_token).clone().detach()

# KL vs reference model — always log when ref present, regardless of use_kl_loss
ref_kl_metric = None
_ref_kl_per_token = None
if "ref_log_probs" in batch and batch["ref_log_probs"]:
ref_log_probs_cat = torch.cat(batch["ref_log_probs"], dim=0)
ref_kl_metric = sum_of_sample_mean(log_probs - ref_log_probs_cat).clone().detach()
_ref_kl_per_token = log_probs - ref_log_probs_cat
ref_kl_metric = sum_of_sample_mean(_ref_kl_per_token).clone().detach()

reported_loss = {
"loss": loss.clone().detach(),
Expand Down Expand Up @@ -735,6 +746,44 @@ def policy_loss_function(
if args.use_opsm:
reported_loss["opsm_clipfrac"] = opsm_clipfrac

# Per-domain fan-out: activated by batch["domains"] (set when samples carry
# metadata["domain"]). batch["all_domains"] is cached on DataIterator so every
# microbatch emits the same key set (aggregate_train_losses keys positionally).
# grad_norm isn't split: backward() has already mixed gradients.
if batch.get("domains") and batch.get("all_domains"):
per_token = {
"log_probs": log_probs,
"old_log_probs": old_log_probs,
"pg_loss": _pg_loss_per_token,
"pg_clipfrac": _pg_clipfrac_per_token,
"ppo_kl": _ppo_kl_per_token,
"entropy_loss": entropy,
}
if _ref_kl_per_token is not None:
per_token["ref_kl"] = _ref_kl_per_token
if _train_rollout_logprob_signed_per_token is not None:
per_token["train_rollout_logprob_diff"] = _train_rollout_logprob_signed_per_token
per_token["train_rollout_logprob_abs_diff"] = _train_rollout_logprob_abs_per_token
if args.get_mismatch_metrics or args.use_tis:
per_token["ois"] = ois
per_token.update(tis_metrics)

for d in batch["all_domains"]:
masked = [
lm if dd == d else torch.zeros_like(lm)
for dd, lm in zip(batch["domains"], batch["loss_masks"], strict=False)
]
reducer = get_sum_of_sample_mean(
total_lengths, response_lengths, masked,
args.calculate_per_token_loss, args.qkv_format, max_seq_lens,
loss_agg_mode=getattr(args, "loss_agg_mode", None),
)
for name, t in per_token.items():
reported_loss[f"{name}/{d}"] = reducer(t).clone().detach()
reported_loss[f"loss/{d}"] = (
reported_loss[f"pg_loss/{d}"] - args.entropy_coef * reported_loss[f"entropy_loss/{d}"]
)

return loss, reported_loss


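A simplified illustration of the per-domain fan-out (independent of this diff; sample_mean_sum below is a toy stand-in for the reducer returned by get_sum_of_sample_mean): zeroing the loss masks of samples from other domains makes the shared reducer report each domain's own per-sample mean.

import torch

def sample_mean_sum(per_token, masks):
    # Toy reducer: per-sample mean over unmasked tokens, summed across samples;
    # fully masked samples contribute zero.
    total = per_token.new_zeros(())
    offset = 0
    for m in masks:
        seg = per_token[offset : offset + m.numel()]
        if m.sum() > 0:
            total = total + (seg * m).sum() / m.sum()
        offset += m.numel()
    return total

domains = ["math", "code"]                    # one label per sample
loss_masks = [torch.ones(3), torch.ones(2)]   # two samples with 3 and 2 response tokens
per_token_loss = torch.tensor([1.0, 1.0, 1.0, 5.0, 5.0])

for d in ["math", "code"]:
    masked = [m if dd == d else torch.zeros_like(m) for dd, m in zip(domains, loss_masks)]
    print(d, sample_mean_sum(per_token_loss, masked).item())  # math -> 1.0, code -> 5.0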
60 changes: 48 additions & 12 deletions miles/ray/rollout.py
@@ -38,7 +38,13 @@
from miles.utils.iter_utils import group_by
from miles.utils.logging_utils import configure_logger
from miles.utils.metric_checker import MetricChecker
from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_statistics, dict_add_prefix
from miles.utils.metric_utils import (
compute_pass_rate,
compute_rollout_step,
compute_samples_seen,
compute_statistics,
dict_add_prefix,
)
from miles.utils.misc import load_function
from miles.utils.ray_utils import Box
from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions
@@ -728,6 +734,11 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
if samples[0].train_metadata is not None:
train_data["metadata"] = [sample.train_metadata for sample in samples]

# Presence of metadata["domain"] activates per-domain metric fan-out in policy_loss_function.
domains = [s.metadata.get("domain") for s in samples]
if any(domains):
train_data["domains"] = domains

if any(sample.multimodal_train_inputs is not None for sample in samples):
train_data["multimodal_train_inputs"] = [sample.multimodal_train_inputs for sample in samples]

@@ -782,6 +793,7 @@ def _split_train_data_by_dp(self, data, dp_size):
"prompt",
"teacher_log_probs",
"weight_versions",
"domains",
]:
if key not in data:
continue
@@ -1205,10 +1217,16 @@ def _log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_
log_dict = {**(rollout_extra_metrics or {})}
log_dict |= dict_add_prefix(compute_metrics_from_samples(args, samples), "rollout/")
log_dict |= dict_add_prefix(compute_perf_metrics_from_samples(args, samples, rollout_time), "perf/")
# Mirror reward/* and response_stats/* as top-level wandb panels.
for full_key, val in list(log_dict.items()):
if full_key.startswith(("rollout/reward/", "rollout/response_stats/")):
log_dict[full_key[len("rollout/"):]] = val
logger.info(f"perf {rollout_id}: {log_dict}")
step = compute_rollout_step(args, rollout_id)
log_dict["rollout/step"] = step
log_dict["train/rollout_id"] = rollout_id
log_dict["samples_seen"] = compute_samples_seen(args, rollout_id)
log_dict["rollout_step"] = step
tracking_utils.log(args, log_dict, step_key="rollout/step")


@@ -1256,8 +1274,9 @@ def compute_metrics_from_samples(args, samples):
log_dict |= _compute_group_outcome_metrics(args, samples, prefix="reward")

# per-correctness (no count_frac: for binary rewards = mean reward = already in reward/raw_reward)
correct = [s for s in samples if s.get_reward_value(args) > 0]
incorrect = [s for s in samples if s.get_reward_value(args) <= 0]
correct = [s for s in samples if _correctness(s, args)]
incorrect = [s for s in samples if not _correctness(s, args)]
log_dict["reward/correctness"] = len(correct) / n
for label, grp in [("correct", correct), ("incorrect", incorrect)]:
if grp:
log_dict |= _compute_grouped_reward_metrics(args, grp, f"reward/{label}", n, include_count_frac=False)
@@ -1272,10 +1291,10 @@
log_dict |= _compute_grouped_reward_metrics(args, cat_grp, f"reward/{cat}", n)
log_dict |= _compute_grouped_response_metrics(args, cat_grp, f"response_stats/{cat}")
log_dict |= _compute_group_outcome_metrics(args, cat_grp, prefix=f"reward/{cat}")
for label, grp in [
("correct", [s for s in cat_grp if s.get_reward_value(args) > 0]),
("incorrect", [s for s in cat_grp if s.get_reward_value(args) <= 0]),
]:
cat_correct = [s for s in cat_grp if _correctness(s, args)]
cat_incorrect = [s for s in cat_grp if not _correctness(s, args)]
log_dict[f"reward/{cat}/correctness"] = len(cat_correct) / len(cat_grp)
for label, grp in [("correct", cat_correct), ("incorrect", cat_incorrect)]:
if grp:
log_dict |= _compute_grouped_reward_metrics(args, grp, f"reward/{cat}/{label}", n)
log_dict |= _compute_grouped_response_metrics(args, grp, f"response_stats/{cat}/{label}")
@@ -1320,6 +1339,9 @@ def _compute_zero_std_metrics(args, all_samples: list[Sample]):
# only compute in GRPO-like algorithms where one prompt has multiple responses
if args.advantage_estimator == "ppo":
return {}
# Skip non-scalar rewards (round() and zero-std comparison are ill-defined on dicts).
if all_samples and not isinstance(all_samples[0].get_reward_value(args), (int, float)):
return {}

def _is_zero_std(samples: list[Sample]):
rewards = [sample.get_reward_value(args) for sample in samples]
@@ -1364,8 +1386,11 @@ def _compute_reward_cat_metrics(args, all_samples: list[Sample]):
return {f"error_cat/{reward_cat}": len(s) / len(all_samples) for reward_cat, s in samples_of_reward_cat.items()}


# Candidate metadata keys to auto-detect problem category (checked in order)
_CANDIDATE_CATEGORY_KEYS = ["category", "type", "subject", "domain", "problem_type"]
# Candidate metadata keys to auto-detect problem category (checked in order).
# `domain` is first because it's the routing key in multi-teacher setups; if a sample
# carries both `domain` and one of the legacy fields like `category`, group by domain
# so per-cat correctness panels match the per-domain loss panels.
_CANDIDATE_CATEGORY_KEYS = ["domain", "category", "type", "subject", "problem_type"]


def _get_problem_category_key(args, all_samples: list[Sample]) -> str | None:
Expand All @@ -1381,11 +1406,22 @@ def _get_problem_category_key(args, all_samples: list[Sample]) -> str | None:
return None


def _correctness(sample: Sample, args) -> bool:
"""Non-scalar reward fns set metadata["correctness_reward"]; scalars fall back to sign."""
if "correctness_reward" in sample.metadata:
return sample.metadata["correctness_reward"] > 0
val = sample.get_reward_value(args)
return isinstance(val, (int, float)) and val > 0


def _compute_grouped_reward_metrics(
args, group: list[Sample], prefix: str, n_total: int, include_count_frac: bool = True
) -> dict:
"""Reward/outcome metrics for a split — emitted under reward/ sections."""
result = {f"{prefix}/raw_reward": np.mean([s.get_reward_value(args) for s in group]).item()}
result = {}
# Skip raw_reward when reward is non-scalar (e.g. dict-valued OPD rewards).
if group and isinstance(group[0].get_reward_value(args), (int, float)):
result[f"{prefix}/raw_reward"] = np.mean([s.get_reward_value(args) for s in group]).item()
if include_count_frac:
result[f"{prefix}/count_frac"] = len(group) / n_total
return result
Expand All @@ -1410,8 +1446,8 @@ def _compute_group_outcome_metrics(
n_groups = len(groups)
if n_groups == 0:
return {}
all_correct = sum(1 for g in groups if all(s.get_reward_value(args) > 0 for s in g))
all_incorrect = sum(1 for g in groups if all(s.get_reward_value(args) <= 0 for s in g))
all_correct = sum(1 for g in groups if all(_correctness(s, args) for s in g))
all_incorrect = sum(1 for g in groups if all(not _correctness(s, args) for s in g))
return {
f"{prefix}/all_correct_group_frac": all_correct / n_groups,
f"{prefix}/all_incorrect_group_frac": all_incorrect / n_groups,
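A small self-contained illustration of the _correctness fallback order, with FakeSample as a hypothetical stand-in for the real Sample class: an explicit metadata["correctness_reward"] wins when present, a scalar reward falls back to its sign, and a non-scalar reward without that metadata counts as incorrect.

class FakeSample:
    """Hypothetical stand-in for Sample, exposing only what _correctness touches."""
    def __init__(self, reward, metadata=None):
        self._reward = reward
        self.metadata = metadata or {}
    def get_reward_value(self, args):
        return self._reward

def correctness(sample, args=None):
    # Mirrors _correctness above.
    if "correctness_reward" in sample.metadata:
        return sample.metadata["correctness_reward"] > 0
    val = sample.get_reward_value(args)
    return isinstance(val, (int, float)) and val > 0

assert correctness(FakeSample(1.0)) is True                                     # positive scalar
assert correctness(FakeSample(-0.5)) is False                                   # non-positive scalar
assert correctness(FakeSample({"opd": 0.3})) is False                           # dict reward, no metadata
assert correctness(FakeSample({"opd": 0.3}, {"correctness_reward": 1.0})) is True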
5 changes: 5 additions & 0 deletions miles/utils/metric_utils.py
@@ -118,3 +118,8 @@ def compute_rollout_step(args, rollout_id):
if args.wandb_always_use_train_step:
return rollout_id * args.rollout_batch_size * args.n_samples_per_prompt // args.global_batch_size
return rollout_id


def compute_samples_seen(args, rollout_id: int) -> int:
"""Cumulative samples through (and including) rollout `rollout_id` (0-indexed)."""
return args.rollout_batch_size * args.n_samples_per_prompt * (rollout_id + 1)
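For instance, with hypothetical config values rollout_batch_size=32 and n_samples_per_prompt=8, the counter already includes the current rollout's own samples:

from types import SimpleNamespace

args = SimpleNamespace(rollout_batch_size=32, n_samples_per_prompt=8)  # hypothetical values

def samples_seen(args, rollout_id):
    # Mirrors compute_samples_seen above.
    return args.rollout_batch_size * args.n_samples_per_prompt * (rollout_id + 1)

assert samples_seen(args, 0) == 256  # rollout 0 counts its own 32 * 8 samples
assert samples_seen(args, 1) == 512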
5 changes: 5 additions & 0 deletions miles/utils/wandb_utils.py
@@ -175,3 +175,8 @@ def _init_wandb_common():
# rollout counter. Declared after the "train/*" wildcard so the specific name
# isn't inadvertently treated as step-metric'd against train/step.
wandb.define_metric("train/rollout_id")
# Bare step counters — co-logged so one panel can plot all three as time series
# (useful for spotting non-monotone train/step jumps under dynamic batching).
wandb.define_metric("samples_seen")
wandb.define_metric("train_step")
wandb.define_metric("rollout_step")
19 changes: 19 additions & 0 deletions scripts/models/xllm-8B.sh
@@ -0,0 +1,19 @@
# xllm 8B dense GQA
MODEL_ARGS=(
--swiglu
--num-layers 36
--hidden-size 4096
--ffn-hidden-size 12288
--num-attention-heads 32
--group-query-attention
--num-query-groups 8
--kv-channels 128
--disable-bias-linear
--normalization RMSNorm
--norm-epsilon 1e-6
--position-embedding-type rope
--rotary-percent 1.0
--rotary-base 10000000
--untie-embeddings-and-output-weights
--vocab-size 250624
)