diff --git a/miles/backends/training_utils/data.py b/miles/backends/training_utils/data.py
index 3bda887b26..2733c93442 100644
--- a/miles/backends/training_utils/data.py
+++ b/miles/backends/training_utils/data.py
@@ -115,11 +115,23 @@ def get_batch(
 
     parallel_state = get_parallel_state()
     assert "tokens" in keys
+    has_domains = "domains" in data_iterator.rollout_data
+    if has_domains and "domains" not in keys:
+        keys = list(keys) + ["domains"]
 
     batch = data_iterator.get_next(keys)
     if "dynamic_global_batch_size" in data_iterator.rollout_data:
         batch["dynamic_global_batch_size"] = data_iterator.rollout_data["dynamic_global_batch_size"]
 
+    # Canonical domain set, cached on the iterator so every microbatch emits the
+    # same list (aggregate_train_losses keys positionally on the first microbatch).
+    if has_domains:
+        if not hasattr(data_iterator, "_all_domains_cache"):
+            data_iterator._all_domains_cache = sorted(
+                {d for d in data_iterator.rollout_data["domains"] if d}
+            )
+        batch["all_domains"] = data_iterator._all_domains_cache
+
     tokens = batch["tokens"]
     # use 0 as the pad token id should be fine?
     pad_token_id = 0
diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py
index 4096a4f8ac..120560cd41 100644
--- a/miles/backends/training_utils/log_utils.py
+++ b/miles/backends/training_utils/log_utils.py
@@ -51,6 +51,13 @@
     "returns": "reward",
 }
 
+# Cumulative train-step counter across all rollouts. The previous formula
+# `rollout_id * num_steps_per_rollout + step_id` collides (and decreases) when
+# `num_steps_per_rollout` shrinks across rollouts under dynamic batching, since
+# each rollout uses its own current num_steps_per_rollout as a scaling factor.
+# A simple monotone counter is invariant to that jitter.
+_TRAIN_STEP_COUNTER = 0
+
 
 def gather_log_data(
     metric_name: str,
@@ -77,7 +84,15 @@ def gather_log_data(
     # dict to the union of keys with NaN so every rank sends the same shape.
     # Cost is one all_gather_object on a tiny key list.
     all_keys: list = [None] * dp_size
+    logger.info(
+        f"[rank={pg.rank}/{dp_size}] gather_log_data({metric_name}) "
+        f"rollout={rollout_id} entering all_gather_object, keys={len(log_dict)}"
+    )
     dist.all_gather_object(all_keys, sorted(log_dict.keys()), group=pg.gloo_group)
+    logger.info(
+        f"[rank={pg.rank}/{dp_size}] gather_log_data({metric_name}) "
+        f"rollout={rollout_id} all_gather_object returned"
+    )
     union_keys: set = set()
     for ks in all_keys:
         if ks:
@@ -170,6 +185,7 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc
             "dynamic_global_batch_size",
             "weight_versions",
             "metadata",
+            "domains",
         ]:
             continue
         # Upload per sample mean for each rollout value
@@ -492,7 +508,9 @@ def log_train_step(
     Returns:
         The formatted log_dict (for CI tests or other uses).
     """
-    accumulated_step_id = rollout_id * num_steps_per_rollout + step_id
+    global _TRAIN_STEP_COUNTER
+    accumulated_step_id = _TRAIN_STEP_COUNTER
+    _TRAIN_STEP_COUNTER += 1
     role_tag = "" if role == "actor" else f"{role}-"
 
     log_dict_out = {
@@ -525,7 +543,7 @@ def log_train_step(
     # cross-plotted against rollout-side axes in the wandb UI.
     log_dict_out["train/rollout_id"] = rollout_id
     log_dict_out["train/step_in_rollout"] = step_id
-    log_dict_out["rollout/step"] = compute_rollout_step(args, rollout_id)
+    log_dict_out["train_step"] = accumulated_step_id
 
     # Emit top-level grouped copies for W&B panel organization (existing train/ keys unchanged)
     grouped_additions = {}
@@ -534,9 +552,13 @@ for full_key, val in log_dict_out.items():
         if not full_key.startswith(prefix):
             continue
         bare_key = full_key[len(prefix):]
-        if bare_key in _TRAIN_METRIC_GROUPS:
-            for group in _TRAIN_METRIC_GROUPS[bare_key]:
-                grouped_additions[f"{group}/{bare_key}"] = val
+        # Per-domain keys arrive as "<metric>/<domain>" — route to "<group>/<domain>/<metric>".
+        metric_name, sep, domain = bare_key.rpartition("/")
+        lookup = metric_name if (sep and metric_name in _TRAIN_METRIC_GROUPS) else bare_key
+        if lookup in _TRAIN_METRIC_GROUPS:
+            suffix = f"{domain}/{metric_name}" if lookup == metric_name else bare_key
+            for group in _TRAIN_METRIC_GROUPS[lookup]:
+                grouped_additions[f"{group}/{suffix}"] = val
         elif bare_key.startswith("lr-pg_"):
             grouped_additions[f"optimization/{bare_key}"] = val
     log_dict_out.update(grouped_additions)
diff --git a/miles/backends/training_utils/loss.py b/miles/backends/training_utils/loss.py
index e0eccde8b4..c676c46195 100644
--- a/miles/backends/training_utils/loss.py
+++ b/miles/backends/training_utils/loss.py
@@ -652,6 +652,11 @@ def policy_loss_function(
     else:
         pg_loss_reducer = sum_of_sample_mean
 
+    # Saved for per-domain fan-out (reducers below overwrite these names with scalars).
+    _pg_loss_per_token = pg_loss
+    _pg_clipfrac_per_token = pg_clipfrac
+    _ppo_kl_per_token = ppo_kl
+
     pg_loss = pg_loss_reducer(pg_loss)
     pg_clipfrac = sum_of_sample_mean(pg_clipfrac)
     ppo_kl = sum_of_sample_mean(ppo_kl)
@@ -690,18 +695,24 @@
     # Train-inference mismatch: compare inference engine vs FSDP at rollout time
     train_rollout_logprob_abs_diff = None
     train_rollout_logprob_diff = None
+    _train_rollout_logprob_abs_per_token = None
+    _train_rollout_logprob_signed_per_token = None
     if "rollout_log_probs" in batch and batch["rollout_log_probs"]:
         rollout_log_probs_cat = torch.cat(batch["rollout_log_probs"], dim=0)
         log_probs_batch_cat = torch.cat(batch["log_probs"], dim=0)
-        train_rollout_logprob_abs_diff = sum_of_sample_mean((old_log_probs - rollout_log_probs_cat).abs()).clone().detach()
+        _train_rollout_logprob_abs_per_token = (old_log_probs - rollout_log_probs_cat).abs()
         # signed: log π(inf) − log π(fsdp rollout)
-        train_rollout_logprob_diff = sum_of_sample_mean(rollout_log_probs_cat - log_probs_batch_cat).clone().detach()
+        _train_rollout_logprob_signed_per_token = rollout_log_probs_cat - log_probs_batch_cat
+        train_rollout_logprob_abs_diff = sum_of_sample_mean(_train_rollout_logprob_abs_per_token).clone().detach()
+        train_rollout_logprob_diff = sum_of_sample_mean(_train_rollout_logprob_signed_per_token).clone().detach()
 
     # KL vs reference model — always log when ref present, regardless of use_kl_loss
     ref_kl_metric = None
+    _ref_kl_per_token = None
     if "ref_log_probs" in batch and batch["ref_log_probs"]:
         ref_log_probs_cat = torch.cat(batch["ref_log_probs"], dim=0)
-        ref_kl_metric = sum_of_sample_mean(log_probs - ref_log_probs_cat).clone().detach()
+        _ref_kl_per_token = log_probs - ref_log_probs_cat
+        ref_kl_metric = sum_of_sample_mean(_ref_kl_per_token).clone().detach()
 
     reported_loss = {
         "loss": loss.clone().detach(),
@@ -735,6 +746,44 @@
     if args.use_opsm:
         reported_loss["opsm_clipfrac"] = opsm_clipfrac
 
+    # Per-domain fan-out: activated by batch["domains"] (set when samples carry
+    # metadata["domain"]). batch["all_domains"] is cached on DataIterator so every
+    # microbatch emits the same key set (aggregate_train_losses keys positionally).
+    # grad_norm isn't split: backward() has already mixed gradients.
+    if batch.get("domains") and batch.get("all_domains"):
+        per_token = {
+            "log_probs": log_probs,
+            "old_log_probs": old_log_probs,
+            "pg_loss": _pg_loss_per_token,
+            "pg_clipfrac": _pg_clipfrac_per_token,
+            "ppo_kl": _ppo_kl_per_token,
+            "entropy_loss": entropy,
+        }
+        if _ref_kl_per_token is not None:
+            per_token["ref_kl"] = _ref_kl_per_token
+        if _train_rollout_logprob_signed_per_token is not None:
+            per_token["train_rollout_logprob_diff"] = _train_rollout_logprob_signed_per_token
+            per_token["train_rollout_logprob_abs_diff"] = _train_rollout_logprob_abs_per_token
+        if args.get_mismatch_metrics or args.use_tis:
+            per_token["ois"] = ois
+            per_token.update(tis_metrics)
+
+        for d in batch["all_domains"]:
+            masked = [
+                lm if dd == d else torch.zeros_like(lm)
+                for dd, lm in zip(batch["domains"], batch["loss_masks"], strict=False)
+            ]
+            reducer = get_sum_of_sample_mean(
+                total_lengths, response_lengths, masked,
+                args.calculate_per_token_loss, args.qkv_format, max_seq_lens,
+                loss_agg_mode=getattr(args, "loss_agg_mode", None),
+            )
+            for name, t in per_token.items():
+                reported_loss[f"{name}/{d}"] = reducer(t).clone().detach()
+            reported_loss[f"loss/{d}"] = (
+                reported_loss[f"pg_loss/{d}"] - args.entropy_coef * reported_loss[f"entropy_loss/{d}"]
+            )
+
     return loss, reported_loss
 
 
diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py
index 17905500a3..18bc26b416 100644
--- a/miles/ray/rollout.py
+++ b/miles/ray/rollout.py
@@ -38,7 +38,13 @@
 from miles.utils.iter_utils import group_by
 from miles.utils.logging_utils import configure_logger
 from miles.utils.metric_checker import MetricChecker
-from miles.utils.metric_utils import compute_pass_rate, compute_rollout_step, compute_statistics, dict_add_prefix
+from miles.utils.metric_utils import (
+    compute_pass_rate,
+    compute_rollout_step,
+    compute_samples_seen,
+    compute_statistics,
+    dict_add_prefix,
+)
 from miles.utils.misc import load_function
 from miles.utils.ray_utils import Box
 from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions
@@ -728,6 +734,11 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
         if samples[0].train_metadata is not None:
             train_data["metadata"] = [sample.train_metadata for sample in samples]
 
+        # Presence of metadata["domain"] activates per-domain metric fan-out in policy_loss_function.
+        domains = [s.metadata.get("domain") for s in samples]
+        if any(domains):
+            train_data["domains"] = domains
+
         if any(sample.multimodal_train_inputs is not None for sample in samples):
             train_data["multimodal_train_inputs"] = [sample.multimodal_train_inputs for sample in samples]
 
@@ -782,6 +793,7 @@ def _split_train_data_by_dp(self, data, dp_size):
             "prompt",
             "teacher_log_probs",
             "weight_versions",
+            "domains",
         ]:
             if key not in data:
                 continue
@@ -1205,10 +1217,16 @@ def _log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_
     log_dict = {**(rollout_extra_metrics or {})}
     log_dict |= dict_add_prefix(compute_metrics_from_samples(args, samples), "rollout/")
     log_dict |= dict_add_prefix(compute_perf_metrics_from_samples(args, samples, rollout_time), "perf/")
+    # Mirror reward/* and response_stats/* as top-level wandb panels.
+    for full_key, val in list(log_dict.items()):
+        if full_key.startswith(("rollout/reward/", "rollout/response_stats/")):
+            log_dict[full_key[len("rollout/"):]] = val
     logger.info(f"perf {rollout_id}: {log_dict}")
 
     step = compute_rollout_step(args, rollout_id)
     log_dict["rollout/step"] = step
     log_dict["train/rollout_id"] = rollout_id
+    log_dict["samples_seen"] = compute_samples_seen(args, rollout_id)
+    log_dict["rollout_step"] = step
 
     tracking_utils.log(args, log_dict, step_key="rollout/step")
@@ -1256,8 +1274,9 @@ def compute_metrics_from_samples(args, samples):
     log_dict |= _compute_group_outcome_metrics(args, samples, prefix="reward")
 
     # per-correctness (no count_frac: for binary rewards = mean reward = already in reward/raw_reward)
-    correct = [s for s in samples if s.get_reward_value(args) > 0]
-    incorrect = [s for s in samples if s.get_reward_value(args) <= 0]
+    correct = [s for s in samples if _correctness(s, args)]
+    incorrect = [s for s in samples if not _correctness(s, args)]
+    log_dict["reward/correctness"] = len(correct) / n
     for label, grp in [("correct", correct), ("incorrect", incorrect)]:
         if grp:
             log_dict |= _compute_grouped_reward_metrics(args, grp, f"reward/{label}", n, include_count_frac=False)
@@ -1272,10 +1291,10 @@
         log_dict |= _compute_grouped_reward_metrics(args, cat_grp, f"reward/{cat}", n)
         log_dict |= _compute_grouped_response_metrics(args, cat_grp, f"response_stats/{cat}")
         log_dict |= _compute_group_outcome_metrics(args, cat_grp, prefix=f"reward/{cat}")
-        for label, grp in [
-            ("correct", [s for s in cat_grp if s.get_reward_value(args) > 0]),
-            ("incorrect", [s for s in cat_grp if s.get_reward_value(args) <= 0]),
-        ]:
+        cat_correct = [s for s in cat_grp if _correctness(s, args)]
+        cat_incorrect = [s for s in cat_grp if not _correctness(s, args)]
+        log_dict[f"reward/{cat}/correctness"] = len(cat_correct) / len(cat_grp)
+        for label, grp in [("correct", cat_correct), ("incorrect", cat_incorrect)]:
             if grp:
                 log_dict |= _compute_grouped_reward_metrics(args, grp, f"reward/{cat}/{label}", n)
                 log_dict |= _compute_grouped_response_metrics(args, grp, f"response_stats/{cat}/{label}")
@@ -1320,6 +1339,9 @@ def _compute_zero_std_metrics(args, all_samples: list[Sample]):
     # only compute in GRPO-like algorithms where one prompt has multiple responses
     if args.advantage_estimator == "ppo":
         return {}
+    # Skip non-scalar rewards (round() and zero-std comparison are ill-defined on dicts).
+    if all_samples and not isinstance(all_samples[0].get_reward_value(args), (int, float)):
+        return {}
 
     def _is_zero_std(samples: list[Sample]):
         rewards = [sample.get_reward_value(args) for sample in samples]
@@ -1364,8 +1386,11 @@ def _compute_reward_cat_metrics(args, all_samples: list[Sample]):
     return {f"error_cat/{reward_cat}": len(s) / len(all_samples) for reward_cat, s in samples_of_reward_cat.items()}
 
 
-# Candidate metadata keys to auto-detect problem category (checked in order)
-_CANDIDATE_CATEGORY_KEYS = ["category", "type", "subject", "domain", "problem_type"]
+# Candidate metadata keys to auto-detect problem category (checked in order).
+# `domain` is first because it's the routing key in multi-teacher setups; if a sample
+# carries both `domain` and one of the legacy fields like `category`, group by domain
+# so per-cat correctness panels match the per-domain loss panels.
+_CANDIDATE_CATEGORY_KEYS = ["domain", "category", "type", "subject", "problem_type"]
 
 
 def _get_problem_category_key(args, all_samples: list[Sample]) -> str | None:
@@ -1381,11 +1406,22 @@ def _get_problem_category_key(args, all_samples: list[Sample]) -> str | None:
     return None
 
 
+def _correctness(sample: Sample, args) -> bool:
+    """Non-scalar reward fns set metadata["correctness_reward"]; scalars fall back to sign."""
+    if "correctness_reward" in sample.metadata:
+        return sample.metadata["correctness_reward"] > 0
+    val = sample.get_reward_value(args)
+    return isinstance(val, (int, float)) and val > 0
+
+
 def _compute_grouped_reward_metrics(
     args, group: list[Sample], prefix: str, n_total: int, include_count_frac: bool = True
 ) -> dict:
     """Reward/outcome metrics for a split — emitted under reward/ sections."""
-    result = {f"{prefix}/raw_reward": np.mean([s.get_reward_value(args) for s in group]).item()}
+    result = {}
+    # Skip raw_reward when reward is non-scalar (e.g. dict-valued OPD rewards).
+    if group and isinstance(group[0].get_reward_value(args), (int, float)):
+        result[f"{prefix}/raw_reward"] = np.mean([s.get_reward_value(args) for s in group]).item()
     if include_count_frac:
         result[f"{prefix}/count_frac"] = len(group) / n_total
     return result
@@ -1410,8 +1446,8 @@ def _compute_group_outcome_metrics(
     n_groups = len(groups)
     if n_groups == 0:
         return {}
-    all_correct = sum(1 for g in groups if all(s.get_reward_value(args) > 0 for s in g))
-    all_incorrect = sum(1 for g in groups if all(s.get_reward_value(args) <= 0 for s in g))
+    all_correct = sum(1 for g in groups if all(_correctness(s, args) for s in g))
+    all_incorrect = sum(1 for g in groups if all(not _correctness(s, args) for s in g))
     return {
         f"{prefix}/all_correct_group_frac": all_correct / n_groups,
         f"{prefix}/all_incorrect_group_frac": all_incorrect / n_groups,
diff --git a/miles/utils/metric_utils.py b/miles/utils/metric_utils.py
index 66292c79e7..839b50e25f 100644
--- a/miles/utils/metric_utils.py
+++ b/miles/utils/metric_utils.py
@@ -118,3 +118,8 @@ def compute_rollout_step(args, rollout_id):
     if args.wandb_always_use_train_step:
         return rollout_id * args.rollout_batch_size * args.n_samples_per_prompt // args.global_batch_size
     return rollout_id
+
+
+def compute_samples_seen(args, rollout_id: int) -> int:
+    """Cumulative samples through (and including) rollout `rollout_id` (0-indexed)."""
+    return args.rollout_batch_size * args.n_samples_per_prompt * (rollout_id + 1)
diff --git a/miles/utils/wandb_utils.py b/miles/utils/wandb_utils.py
index c29fcc7eaa..f9524c91af 100644
--- a/miles/utils/wandb_utils.py
+++ b/miles/utils/wandb_utils.py
@@ -175,3 +175,8 @@ def _init_wandb_common():
     # rollout counter. Declared after the "train/*" wildcard so the specific name
     # isn't inadvertently treated as step-metric'd against train/step.
     wandb.define_metric("train/rollout_id")
+    # Bare step counters — co-logged so one panel can plot all three as time series
+    # (useful for spotting non-monotone train/step jumps under dynamic batching).
+    wandb.define_metric("samples_seen")
+    wandb.define_metric("train_step")
+    wandb.define_metric("rollout_step")
diff --git a/scripts/models/xllm-8B.sh b/scripts/models/xllm-8B.sh
new file mode 100644
index 0000000000..d0ed77694e
--- /dev/null
+++ b/scripts/models/xllm-8B.sh
@@ -0,0 +1,19 @@
+# xllm 8B dense GQA
+MODEL_ARGS=(
+    --swiglu
+    --num-layers 36
+    --hidden-size 4096
+    --ffn-hidden-size 12288
+    --num-attention-heads 32
+    --group-query-attention
+    --num-query-groups 8
+    --kv-channels 128
+    --disable-bias-linear
+    --normalization RMSNorm
+    --norm-epsilon 1e-6
+    --position-embedding-type rope
+    --rotary-percent 1.0
+    --rotary-base 10000000
+    --untie-embeddings-and-output-weights
+    --vocab-size 250624
+)
diff --git a/tests/fast/backends/training_utils/test_metric_domains.py b/tests/fast/backends/training_utils/test_metric_domains.py
new file mode 100644
index 0000000000..79f7aea67b
--- /dev/null
+++ b/tests/fast/backends/training_utils/test_metric_domains.py
@@ -0,0 +1,245 @@
+"""Tests for per-domain metric fan-out and unified correctness signal.
+
+These exercise the math/logic added when upstreaming OPD-specific metrics
+into miles. The goal is to verify:
+
+1. The masked-loss-mask trick used for per-domain reductions: when we zero
+   out non-target samples' loss_masks, the resulting `sum_of_sample_mean`
+   reducer produces the same value as `sum_of_sample_mean` over only the
+   target-domain samples.
+
+2. Per-domain reductions partition the global reduction: when domains
+   partition the batch, the per-domain sums add up to the global sum of
+   per-sample means.
+
+3. The `_correctness(s, args)` helper unifies scalar GRPO reward sign and
+   non-scalar OPD `metadata["correctness_reward"]` into one path.
+
+4. `compute_samples_seen` returns the cumulative per-rollout sample count.
+
+5. `batch["all_domains"]` is built as a sorted, de-duplicated, None-filtered
+   list, matching the DataIterator cache construction in get_batch.
+
+6. The per-domain key routing in `log_train_step` ("<metric>/<domain>" keys
+   are regrouped as "<group>/<domain>/<metric>"), exercised against a
+   vendored copy of the routing logic.
+"""
+from __future__ import annotations
+
+from argparse import Namespace
+from dataclasses import dataclass, field
+from typing import Any
+
+import torch
+
+from miles.backends.training_utils.cp_utils import get_sum_of_sample_mean
+from miles.utils.metric_utils import compute_samples_seen
+
+
+# ---------------------------------------------------------------------------
+# Stub miles.backends.training_utils.parallel.get_parallel_state so the
+# reducers built by get_sum_of_sample_mean (which call it) see CP=1.
+# ---------------------------------------------------------------------------
+
+class _FakePG:
+    size = 1
+    rank = 0
+
+
+class _FakeParallelState:
+    cp = _FakePG()
+
+
+def _patch_parallel_state(monkeypatch):
+    import miles.backends.training_utils.cp_utils as cp_utils
+    monkeypatch.setattr(cp_utils, "get_parallel_state", lambda: _FakeParallelState())
+
+
+# ---------------------------------------------------------------------------
+# 1. Masked-loss-mask trick: per-domain reducer ignores non-target samples
+# ---------------------------------------------------------------------------
+
+def test_domain_filtered_reducer_matches_per_domain_subset(monkeypatch):
+    _patch_parallel_state(monkeypatch)
+
+    # 3 samples: math, code, math. Each sample has 4 response tokens.
+    # Per-token "values" tensor `x` is the per-sample value broadcast over tokens.
+    domains = ["math", "code", "math"]
+    response_lengths = [4, 4, 4]
+    total_lengths = [4, 4, 4]
+    loss_masks = [torch.ones(4) for _ in range(3)]
+
+    # x: per-sample mean is [1.0, 2.0, 3.0]
+    x = torch.tensor([1.0]*4 + [2.0]*4 + [3.0]*4)
+
+    # Build a "math"-filtered reducer
+    masked_for_math = [
+        lm if d == "math" else torch.zeros_like(lm)
+        for d, lm in zip(domains, loss_masks)
+    ]
+    math_reducer = get_sum_of_sample_mean(
+        total_lengths, response_lengths, masked_for_math,
+        calculate_per_token_loss=False, qkv_format="thd", max_seq_lens=None,
+    )
+
+    # sum_of_sample_mean: per-sample token-mean, then sum across samples.
+    # For math: sample 0 contributes 1.0, sample 2 contributes 3.0, sample 1
+    # contributes 0 (its mask is all zeros, clamp_min returns 1 in denominator
+    # but numerator is also 0). Result should be 1.0 + 0 + 3.0 = 4.0.
+    result = math_reducer(x).item()
+    assert result == 4.0, f"expected 4.0, got {result}"
+
+    # Same for "code": only sample 1 contributes, value 2.0
+    masked_for_code = [
+        lm if d == "code" else torch.zeros_like(lm)
+        for d, lm in zip(domains, loss_masks)
+    ]
+    code_reducer = get_sum_of_sample_mean(
+        total_lengths, response_lengths, masked_for_code,
+        calculate_per_token_loss=False, qkv_format="thd", max_seq_lens=None,
+    )
+    result = code_reducer(x).item()
+    assert result == 2.0, f"expected 2.0, got {result}"
+
+
+def test_domain_reducer_returns_zero_for_absent_domain(monkeypatch):
+    """When no sample matches the target domain, the reducer must return 0
+    (not NaN, not error). aggregate_train_losses requires every microbatch
+    emit the same key set — a domain with zero samples in this microbatch
+    must contribute 0 to the positional aggregation."""
+    _patch_parallel_state(monkeypatch)
+
+    domains = ["math", "math"]  # no "code" samples
+    response_lengths = [3, 3]
+    total_lengths = [3, 3]
+    loss_masks = [torch.ones(3), torch.ones(3)]
+    x = torch.tensor([5.0, 5.0, 5.0, 7.0, 7.0, 7.0])
+
+    masked_for_code = [torch.zeros_like(lm) for lm in loss_masks]  # all zero
+    code_reducer = get_sum_of_sample_mean(
+        total_lengths, response_lengths, masked_for_code,
+        calculate_per_token_loss=False, qkv_format="thd", max_seq_lens=None,
+    )
+    result = code_reducer(x).item()
+    assert result == 0.0, f"expected 0 for absent domain, got {result}"
+    assert not torch.isnan(torch.tensor(result))
+
+
+def test_per_domain_reductions_sum_to_global(monkeypatch):
+    """When domains partition the batch and we use sum-mode (sum_of_sample_mean
+    sums per-sample means), summing per-domain reductions equals the global one."""
+    _patch_parallel_state(monkeypatch)
+
+    domains = ["math", "code", "math", "code"]
+    response_lengths = [2, 2, 2, 2]
+    total_lengths = [2, 2, 2, 2]
+    loss_masks = [torch.ones(2) for _ in range(4)]
+    # per-sample means: [1, 2, 3, 4]
+    x = torch.tensor([1.0]*2 + [2.0]*2 + [3.0]*2 + [4.0]*2)
+
+    global_reducer = get_sum_of_sample_mean(
+        total_lengths, response_lengths, loss_masks,
+        calculate_per_token_loss=False, qkv_format="thd", max_seq_lens=None,
+    )
+    global_value = global_reducer(x).item()
+    assert global_value == 1.0 + 2.0 + 3.0 + 4.0  # = 10.0
+
+    # Sum per-domain reductions. {math}=samples 0,2 -> 1+3=4; {code}=1,3 -> 2+4=6.
+    per_domain = 0.0
+    for target in ["math", "code"]:
+        masked = [lm if d == target else torch.zeros_like(lm) for d, lm in zip(domains, loss_masks)]
+        red = get_sum_of_sample_mean(
+            total_lengths, response_lengths, masked,
+            calculate_per_token_loss=False, qkv_format="thd", max_seq_lens=None,
+        )
+        per_domain += red(x).item()
+
+    assert per_domain == global_value, f"per-domain sum {per_domain} != global {global_value}"
+
+
+# ---------------------------------------------------------------------------
+# 2. Unified correctness signal: scalar fallback + metadata override
+# ---------------------------------------------------------------------------
+#
+# The actual `_correctness` helper lives in miles.ray.rollout but importing
+# that module pulls in ray. Vendor the implementation here as a fixture and
+# verify the *semantic* contract — the implementation in rollout.py is a
+# verbatim copy of this snippet (kept in sync by code review).
+
+
+def _vendored_correctness(sample, args) -> bool:
+    """Mirror of miles.ray.rollout._correctness — keep in sync."""
+    if "correctness_reward" in sample.metadata:
+        return sample.metadata["correctness_reward"] > 0
+    val = sample.get_reward_value(args)
+    return isinstance(val, (int, float)) and val > 0
+
+
+@dataclass
+class _FakeSample:
+    reward: Any
+    metadata: dict = field(default_factory=dict)
+
+    def get_reward_value(self, args):
+        return self.reward if not args.reward_key else self.reward[args.reward_key]
+
+
+def test_correctness_scalar_fallback():
+    args = Namespace(reward_key=None)
+    assert _vendored_correctness(_FakeSample(1.0), args) is True
+    assert _vendored_correctness(_FakeSample(0.5), args) is True
+    assert _vendored_correctness(_FakeSample(0.0), args) is False
+    assert _vendored_correctness(_FakeSample(-0.3), args) is False
+
+
+def test_correctness_metadata_override_takes_precedence():
+    """metadata['correctness_reward'] wins even when reward is also scalar."""
+    args = Namespace(reward_key=None)
+    s = _FakeSample(1.0, metadata={"correctness_reward": 0.0})  # scalar says correct, metadata says wrong
+    assert _vendored_correctness(s, args) is False
+    s = _FakeSample(0.0, metadata={"correctness_reward": 1.0})
+    assert _vendored_correctness(s, args) is True
+
+
+def test_correctness_non_scalar_reward_with_metadata():
+    """OPD path: reward is a dict, correctness comes from metadata."""
+    args = Namespace(reward_key=None)
+    s = _FakeSample({"kl_a": 0.5, "kl_b": 0.3}, metadata={"correctness_reward": 1.0})
+    assert _vendored_correctness(s, args) is True
+    s = _FakeSample({"kl_a": 0.5}, metadata={"correctness_reward": 0.0})
+    assert _vendored_correctness(s, args) is False
+
+
+def test_correctness_non_scalar_reward_no_metadata_returns_false():
+    """Without correctness_reward metadata and a non-scalar reward, the
+    helper returns False (the isinstance check fails, so `val > 0` is
+    short-circuited and never evaluated)."""
+    args = Namespace(reward_key=None)
+    s = _FakeSample({"kl_a": 0.5})
+    assert _vendored_correctness(s, args) is False
+
+
+# ---------------------------------------------------------------------------
+# 3. compute_samples_seen
+# ---------------------------------------------------------------------------
+
+def test_compute_samples_seen_first_rollout():
+    args = Namespace(rollout_batch_size=8, n_samples_per_prompt=4)
+    # rollout_id=0 means the first rollout has finished -> 32 samples seen.
+    assert compute_samples_seen(args, 0) == 32
+
+
+def test_compute_samples_seen_monotone():
+    args = Namespace(rollout_batch_size=8, n_samples_per_prompt=4)
+    seen = [compute_samples_seen(args, i) for i in range(5)]
+    # Strictly monotone, increment of 32 per rollout.
+    assert seen == [32, 64, 96, 128, 160]
+
+
+# ---------------------------------------------------------------------------
+# 4. Activation-by-presence: domains list is sorted-unique (matches the
+#    DataIterator._all_domains_cache construction).
+# ---------------------------------------------------------------------------
+
+def test_all_domains_cache_construction():
+    """Mirror the construction in get_batch: sorted({d for d in domains if d})."""
+    domains = ["math", "code", None, "math", "code", "science", None]
+    all_domains = sorted({d for d in domains if d})
+    assert all_domains == ["code", "math", "science"]
+    # None values are filtered out (samples without a domain don't add a key).
+    assert None not in all_domains
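+
+
+# ---------------------------------------------------------------------------
+# 5. Per-domain key routing in log_train_step (vendored string logic)
+# ---------------------------------------------------------------------------
+#
+# Sketch of the grouped_additions routing added to log_utils.log_train_step,
+# vendored for the same reason as `_vendored_correctness` above: importing
+# log_utils pulls in the training stack. The group table used below is a
+# stand-in fixture, not the real _TRAIN_METRIC_GROUPS; only the routing
+# logic is meant to stay in sync with log_utils.
+
+
+def _vendored_route(bare_key, val, groups):
+    """Mirror of the grouped_additions routing in log_train_step; keep in sync."""
+    out = {}
+    # Per-domain keys arrive as "<metric>/<domain>" and are routed to
+    # "<group>/<domain>/<metric>"; plain keys keep their "<group>/<metric>" form.
+    metric_name, sep, domain = bare_key.rpartition("/")
+    lookup = metric_name if (sep and metric_name in groups) else bare_key
+    if lookup in groups:
+        suffix = f"{domain}/{metric_name}" if lookup == metric_name else bare_key
+        for group in groups[lookup]:
+            out[f"{group}/{suffix}"] = val
+    elif bare_key.startswith("lr-pg_"):
+        out[f"optimization/{bare_key}"] = val
+    return out
+
+
+def test_per_domain_key_routing():
+    groups = {"pg_loss": ["losses"]}  # fixture; the real table lives in log_utils
+    # Plain metric: routed unchanged under its group.
+    assert _vendored_route("pg_loss", 1.0, groups) == {"losses/pg_loss": 1.0}
+    # Per-domain metric: the domain is hoisted ahead of the metric name.
+    assert _vendored_route("pg_loss/math", 1.0, groups) == {"losses/math/pg_loss": 1.0}
+    # Unknown metrics are dropped; lr-pg_* keys go to optimization/.
+    assert _vendored_route("unknown", 1.0, groups) == {}
+    assert _vendored_route("lr-pg_0", 1.0, groups) == {"optimization/lr-pg_0": 1.0}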