diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py
index 2a75d492b9..af11accfc3 100644
--- a/miles/ray/rollout.py
+++ b/miles/ray/rollout.py
@@ -1260,10 +1260,18 @@ def _is_zero_std(samples: list[Sample]):
         rewards = [sample.get_reward_value(args) for sample in samples]
         return len(rewards) == 0 or all(rewards[0] == r for r in rewards)
 
+    def _reward_label(sample: Sample) -> str:
+        # Metric-key label for a sample's reward: the reward rounded to one decimal.
+        # A None reward (e.g. an aborted sample) cannot be rounded, so it maps to "none".
+        reward = sample.get_reward_value(args)
+        if reward is None:
+            return "none"
+        return str(round(reward, 1))
+
     all_sample_groups = group_by(all_samples, lambda s: s.group_index)
     interesting_sample_groups = [g for g in all_sample_groups.values() if _is_zero_std(g)]
 
-    interesting_rewards = [str(round(g[0].get_reward_value(args), 1)) for g in interesting_sample_groups]
+    interesting_rewards = [_reward_label(g[0]) for g in interesting_sample_groups]
 
     return {f"zero_std/count_{reward}": len(items) for reward, items in group_by(interesting_rewards).items()}
 