diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py
index c6ed8b6ffb..06f568fcc6 100644
--- a/miles/backends/training_utils/log_utils.py
+++ b/miles/backends/training_utils/log_utils.py
@@ -19,6 +19,38 @@
 logger = logging.getLogger(__name__)
 
 
+# Maps bare metric names to their W&B top-level section(s).
+# Keys appearing in multiple sections (e.g. pg_loss) are emitted under each.
+_TRAIN_METRIC_GROUPS: dict[str, list[str]] = {
+    "ppo_kl": ["policy_shift"],
+    "ois": ["policy_shift"],
+    "pg_clipfrac": ["policy_shift"],
+    "pg_loss": ["policy_shift", "optimization"],
+    "log_probs": ["policy_shift"],  # current policy (training forward pass)
+    "old_log_probs": ["policy_shift"],  # old policy (rollout or FSDP rollout)
+    "ref_kl": ["policy_shift"],
+    "train_rollout_logprob_abs_diff": ["train_inference_mismatch"],
+    "train_rollout_logprob_diff": ["train_inference_mismatch"],
+    "tis": ["train_inference_mismatch"],
+    "tis_abs": ["train_inference_mismatch"],
+    "tis_clipfrac": ["train_inference_mismatch"],
+    "loss": ["optimization"],
+    "entropy_loss": ["optimization"],
+    "kl_loss": ["optimization"],
+    "grad_norm": ["optimization"],
+}
+
+# Maps rollout batch field names to their W&B top-level section.
+_ROLLOUT_DATA_METRIC_GROUPS: dict[str, str] = {
+    "log_probs": "train_inference_mismatch",  # FSDP log probs at rollout time
+    "rollout_log_probs": "train_inference_mismatch",  # inference engine log probs
+    "ref_log_probs": "policy_shift",  # reference model log probs
+    "rewards": "reward",
+    "raw_reward": "reward",
+    "advantages": "reward",
+    "returns": "reward",
+}
+
 
 def gather_log_data(
     metric_name: str,
@@ -185,6 +217,17 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc
     if "rollout/entropy" in reduced_log_dict:
         assert 0 < reduced_log_dict["rollout/entropy"] < 0.7
 
+    # Emit top-level grouped keys from reduced values (only on DP source rank)
+    if reduced_log_dict is not None:
+        top_level = {}
+        for key, group in _ROLLOUT_DATA_METRIC_GROUPS.items():
+            rollout_key = f"rollout/{key}"
+            if rollout_key in reduced_log_dict:
+                top_level[f"{group}/{key}"] = reduced_log_dict[rollout_key]
+        if top_level:
+            step = compute_rollout_step(args, rollout_id)
+            top_level["rollout/step"] = step
+            tracking_utils.log(args, top_level, step_key="rollout/step")
     if args.ci_test and args.true_on_policy_mode:
         assert log_dict["log_probs"] == log_dict["rollout_log_probs"], (
             f"CI check failed: true_on_policy_mode is enabled, but log_probs "
@@ -436,6 +479,20 @@ def log_train_step(
 
     log_dict_out["train/step"] = accumulated_step_id
 
+    # Emit top-level grouped copies for W&B panel organization (existing train/ keys unchanged)
+    grouped_additions = {}
+    prefix = f"train/{role_tag}"
+    for full_key, val in log_dict_out.items():
+        if not full_key.startswith(prefix):
+            continue
+        bare_key = full_key[len(prefix):]
+        if bare_key in _TRAIN_METRIC_GROUPS:
+            for group in _TRAIN_METRIC_GROUPS[bare_key]:
+                grouped_additions[f"{group}/{bare_key}"] = val
+        elif bare_key.startswith("lr-pg_"):
+            grouped_additions[f"optimization/{bare_key}"] = val
+    log_dict_out.update(grouped_additions)
+
     if should_log is None:
         should_log = dist.get_rank() == 0
 
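Reviewer note, not part of the patch: a minimal self-contained sketch of the grouped-key fan-out that log_train_step performs above. The trimmed _GROUPS mapping, the empty default role_tag, and the fan_out helper are illustrative assumptions, not repo code.

# Sketch: fan bare train metrics out into W&B section keys, as in log_train_step.
_GROUPS = {
    "pg_loss": ["policy_shift", "optimization"],
    "ppo_kl": ["policy_shift"],
    "grad_norm": ["optimization"],
}

def fan_out(log_dict: dict[str, float], role_tag: str = "") -> dict[str, float]:
    # Returns the grouped copies that would be merged back into log_dict_out.
    prefix = f"train/{role_tag}"
    grouped = {}
    for full_key, val in log_dict.items():
        if not full_key.startswith(prefix):
            continue
        bare_key = full_key[len(prefix):]
        if bare_key in _GROUPS:
            for group in _GROUPS[bare_key]:
                grouped[f"{group}/{bare_key}"] = val
        elif bare_key.startswith("lr-pg_"):  # learning-rate keys go to optimization
            grouped[f"optimization/{bare_key}"] = val
    return grouped

print(fan_out({"train/pg_loss": 0.12, "train/lr-pg_0": 3e-06, "train/step": 7}))
# {'policy_shift/pg_loss': 0.12, 'optimization/pg_loss': 0.12, 'optimization/lr-pg_0': 3e-06}

Note pg_loss lands under both policy_shift and optimization, matching its multi-section entry; the original train/ keys are left untouched.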
diff --git a/miles/backends/training_utils/loss.py b/miles/backends/training_utils/loss.py
index 8b69d18146..3e5f78495c 100644
--- a/miles/backends/training_utils/loss.py
+++ b/miles/backends/training_utils/loss.py
@@ -682,10 +682,25 @@ def policy_loss_function(
     if log_probs.numel() == 0:
         loss += 0 * logits.sum()
 
+    # Current and old policy log probs for the policy_shift panel
+    log_probs_metric = sum_of_sample_mean(log_probs).clone().detach()
+    old_log_probs_metric = sum_of_sample_mean(old_log_probs).clone().detach()
+
+    # Train-inference mismatch: compare the inference engine vs FSDP at rollout time
     train_rollout_logprob_abs_diff = None
+    train_rollout_logprob_diff = None
     if "rollout_log_probs" in batch and batch["rollout_log_probs"]:
-        rollout_log_probs = torch.cat(batch["rollout_log_probs"], dim=0)
-        train_rollout_logprob_abs_diff = sum_of_sample_mean((old_log_probs - rollout_log_probs).abs())
+        rollout_log_probs_cat = torch.cat(batch["rollout_log_probs"], dim=0)
+        log_probs_batch_cat = torch.cat(batch["log_probs"], dim=0)
+        train_rollout_logprob_abs_diff = sum_of_sample_mean((old_log_probs - rollout_log_probs_cat).abs()).clone().detach()
+        # signed: log pi(inference engine) - log pi(FSDP at rollout time)
+        train_rollout_logprob_diff = sum_of_sample_mean(rollout_log_probs_cat - log_probs_batch_cat).clone().detach()
+
+    # KL vs the reference model (k1 estimate); logged whenever ref_log_probs is present, regardless of use_kl_loss
+    ref_kl_metric = None
+    if "ref_log_probs" in batch and batch["ref_log_probs"]:
+        ref_log_probs_cat = torch.cat(batch["ref_log_probs"], dim=0)
+        ref_kl_metric = sum_of_sample_mean(log_probs - ref_log_probs_cat).clone().detach()
 
     reported_loss = {
         "loss": loss.clone().detach(),
@@ -693,10 +708,16 @@
         "entropy_loss": entropy_loss.clone().detach(),
         "pg_clipfrac": pg_clipfrac.clone().detach(),
         "ppo_kl": ppo_kl.clone().detach(),
+        "log_probs": log_probs_metric,
+        "old_log_probs": old_log_probs_metric,
     }
 
     if train_rollout_logprob_abs_diff is not None:
-        reported_loss["train_rollout_logprob_abs_diff"] = train_rollout_logprob_abs_diff.clone().detach()
+        reported_loss["train_rollout_logprob_abs_diff"] = train_rollout_logprob_abs_diff
+        reported_loss["train_rollout_logprob_diff"] = train_rollout_logprob_diff
+
+    if ref_kl_metric is not None:
+        reported_loss["ref_kl"] = ref_kl_metric
 
     if args.use_kl_loss:
         reported_loss["kl_loss"] = kl_loss.clone().detach()
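Reviewer note, not part of the patch: a toy check of the three metrics policy_loss_function now reports. sum_of_sample_mean is stubbed as a plain mean here, an assumption; the real helper aggregates per sample before reducing.

import torch

def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the repo helper; assumed here to reduce to a scalar mean.
    return x.mean()

log_probs = torch.tensor([-1.12, -0.48, -2.02])          # current policy (training forward)
old_log_probs = torch.tensor([-1.10, -0.50, -2.00])      # old policy
rollout_log_probs = torch.tensor([-1.00, -0.60, -2.10])  # inference engine at rollout
fsdp_log_probs = torch.tensor([-1.05, -0.55, -2.05])     # FSDP recompute at rollout
ref_log_probs = torch.tensor([-1.20, -0.40, -2.00])      # reference model

abs_diff = sum_of_sample_mean((old_log_probs - rollout_log_probs).abs())  # ~0.100
signed_diff = sum_of_sample_mean(rollout_log_probs - fsdp_log_probs)      # ~-0.017
ref_kl = sum_of_sample_mean(log_probs - ref_log_probs)                    # k1 estimate, ~-0.007
print(abs_diff.item(), signed_diff.item(), ref_kl.item())

The signed diff keeps direction (is the inference engine systematically more or less confident than FSDP?), while the absolute diff measures magnitude only; that is why the patch logs both under train_inference_mismatch.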
diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py
index 2a75d492b9..14719fdd4b 100644
--- a/miles/ray/rollout.py
+++ b/miles/ray/rollout.py
@@ -1186,8 +1186,10 @@ def _log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_
 def compute_metrics_from_samples(args, samples):
     response_lengths = [sample.effective_response_length for sample in samples]
+    n = len(samples)
 
     log_dict = {}
+    # existing keys (unchanged)
     log_dict |= dict_add_prefix(compute_statistics(response_lengths), "response_len/")
     log_dict |= _compute_zero_std_metrics(args, samples)
     log_dict |= _compute_spec_metrics(args, samples)
 
@@ -1196,6 +1198,35 @@
     log_dict["repetition_frac"] = np.mean([int(has_repetition(s.response)) for s in samples]).item()
     log_dict["truncated_ratio"] = np.mean([int(s.status == Sample.Status.TRUNCATED) for s in samples]).item()
+    # new top-level grouped keys: global
+    log_dict |= _compute_grouped_reward_metrics(args, samples, "reward", n, include_count_frac=False)
+    log_dict |= _compute_grouped_response_metrics(args, samples, "response_stats")
+    log_dict |= _compute_group_outcome_metrics(args, samples, prefix="reward")
+
+    # per-correctness splits (count_frac omitted: for binary rewards it equals the mean reward, already in raw_reward)
+    correct = [s for s in samples if s.get_reward_value(args) > 0]
+    incorrect = [s for s in samples if s.get_reward_value(args) <= 0]
+    for label, grp in [("correct", correct), ("incorrect", incorrect)]:
+        if grp:
+            log_dict |= _compute_grouped_reward_metrics(args, grp, f"reward/{label}", n, include_count_frac=False)
+            log_dict |= _compute_grouped_response_metrics(args, grp, f"response_stats/{label}")
+
+    # per-category and per-category-per-correctness splits (only if category metadata is present)
+    cat_key = _get_problem_category_key(args, samples)
+    if cat_key is not None:
+        for cat, cat_grp in group_by(samples, lambda s: s.metadata.get(cat_key)).items():
+            if cat is None or not cat_grp:
+                continue
+            log_dict |= _compute_grouped_reward_metrics(args, cat_grp, f"reward/{cat}", n)
+            log_dict |= _compute_grouped_response_metrics(args, cat_grp, f"response_stats/{cat}")
+            log_dict |= _compute_group_outcome_metrics(args, cat_grp, prefix=f"reward/{cat}")
+            for label, grp in [
+                ("correct", [s for s in cat_grp if s.get_reward_value(args) > 0]),
+                ("incorrect", [s for s in cat_grp if s.get_reward_value(args) <= 0]),
+            ]:
+                if grp:
+                    log_dict |= _compute_grouped_reward_metrics(args, grp, f"reward/{cat}/{label}", n)
+                    log_dict |= _compute_grouped_response_metrics(args, grp, f"response_stats/{cat}/{label}")
 
     tito_vals = [s.metadata.get("tito_session_mismatch") for s in samples]
     tito_vals = [v for v in tito_vals if v is not None]
     if tito_vals:
@@ -1297,3 +1328,57 @@
 
     samples_of_reward_cat = group_by(all_samples, lambda s: s.reward[reward_cat_key])
     return {f"error_cat/{reward_cat}": len(s) / len(all_samples) for reward_cat, s in samples_of_reward_cat.items()}
+
+
+# Candidate metadata keys to auto-detect the problem category (checked in order)
+_CANDIDATE_CATEGORY_KEYS = ["category", "type", "subject", "domain", "problem_type"]
+
+
+def _get_problem_category_key(args, all_samples: list[Sample]) -> str | None:
+    """Return the metadata key to use for problem-category grouping, or None if not available."""
+    explicit = getattr(args, "log_problem_category", None)
+    if explicit:
+        return explicit
+    for sample in all_samples:
+        if sample.metadata:
+            for key in _CANDIDATE_CATEGORY_KEYS:
+                if key in sample.metadata:
+                    return key
+    return None
+
+
+def _compute_grouped_reward_metrics(
+    args, group: list[Sample], prefix: str, n_total: int, include_count_frac: bool = True
+) -> dict:
+    """Reward/outcome metrics for a split, emitted under the reward/ sections."""
+    result = {f"{prefix}/raw_reward": np.mean([s.get_reward_value(args) for s in group]).item()}
+    if include_count_frac:
+        result[f"{prefix}/count_frac"] = len(group) / n_total
+    return result
+
+
+def _compute_grouped_response_metrics(args, group: list[Sample], prefix: str) -> dict:
+    """Response shape metrics for a split, emitted under the response_stats/ sections."""
+    return {
+        f"{prefix}/response_len": np.mean([s.effective_response_length for s in group]).item(),
+        f"{prefix}/truncated_frac": np.mean([int(s.status == Sample.Status.TRUNCATED) for s in group]).item(),
+        f"{prefix}/repetition_frac": np.mean([int(has_repetition(s.response)) for s in group]).item(),
+    }
+
+
+def _compute_group_outcome_metrics(
+    args, all_samples: list[Sample], prefix: str = "reward"
+) -> dict:
+    """Fraction of prompt groups that are unanimously correct or unanimously incorrect (GRPO-style grouping only; skipped for PPO)."""
+    if args.advantage_estimator == "ppo":
+        return {}
+    groups = list(group_by(all_samples, lambda s: s.group_index).values())
+    n_groups = len(groups)
+    if n_groups == 0:
+        return {}
+    all_correct = sum(1 for g in groups if all(s.get_reward_value(args) > 0 for s in g))
+    all_incorrect = sum(1 for g in groups if all(s.get_reward_value(args) <= 0 for s in g))
+    return {
+        f"{prefix}/all_correct_group_frac": all_correct / n_groups,
+        f"{prefix}/all_incorrect_group_frac": all_incorrect / n_groups,
+    }
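Reviewer note, not part of the patch: a standalone toy version of the all-correct/all-incorrect counting in _compute_group_outcome_metrics. ToySample and outcome_fractions are hypothetical stand-ins for Sample and the repo helper.

from collections import defaultdict
from dataclasses import dataclass

@dataclass
class ToySample:
    group_index: int
    reward: float

def outcome_fractions(samples: list[ToySample]) -> dict[str, float]:
    # Bucket samples by prompt group, then count unanimous groups.
    groups = defaultdict(list)
    for s in samples:
        groups[s.group_index].append(s)
    n_groups = len(groups)
    all_correct = sum(1 for g in groups.values() if all(s.reward > 0 for s in g))
    all_incorrect = sum(1 for g in groups.values() if all(s.reward <= 0 for s in g))
    return {
        "reward/all_correct_group_frac": all_correct / n_groups,
        "reward/all_incorrect_group_frac": all_incorrect / n_groups,
    }

# Group 0 is unanimously correct, group 1 is mixed:
print(outcome_fractions([ToySample(0, 1.0), ToySample(0, 1.0),
                         ToySample(1, 1.0), ToySample(1, 0.0)]))
# {'reward/all_correct_group_frac': 0.5, 'reward/all_incorrect_group_frac': 0.0}

A group that is neither all-correct nor all-incorrect contributes to neither fraction, so the two numbers need not sum to 1; the gap is the share of mixed groups, which are the ones GRPO-style advantages can actually learn from.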