|
8 | 8 | pytest eval_protocol/benchmarks/ifeval/test_ifeval.py -v |
9 | 9 | """ |
10 | 10 |
|
| 11 | +import asyncio |
11 | 12 | import json |
12 | 13 | from pathlib import Path |
13 | | -from typing import Any |
14 | 14 |
|
15 | | -from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message, MetricResult |
| 15 | +from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult |
16 | 16 | from eval_protocol.pytest import evaluation_test |
17 | 17 | from eval_protocol.pytest.default_single_turn_rollout_process import SingleTurnRolloutProcessor |
| 18 | +from eval_protocol.pytest.rollout_processor import RolloutProcessor |
| 19 | +from eval_protocol.pytest.types import RolloutProcessorConfig |
18 | 20 |
|
19 | 21 | from .reward import ifeval_partial_credit_reward |
20 | 22 |
|
@@ -46,29 +48,44 @@ def _coerce_content_to_str(content: str | list | None) -> str: |
46 | 48 | _IFBENCH_MESSAGES = _load_ifbench_messages() |
47 | 49 |
|
48 | 50 |
|
class IFEvalRolloutProcessor(SingleTurnRolloutProcessor):
    """Single-turn processor variant that lifts ``__GT__`` messages into ``ground_truth``."""

    def preprocess_row(self, row: EvaluationRow) -> EvaluationRow:
        """Strip ``__GT__:`` system messages from *row* and record their payload.

        Every message whose role is ``"system"`` and whose text starts with
        ``__GT__:`` is removed from ``row.messages``; the text after the first
        colon (whitespace-stripped) is written to ``row.ground_truth``.  When
        several such messages exist, the last one wins.  The row is modified
        in place and returned.
        """
        kept: list[Message] = []
        for message in row.messages:
            text = _coerce_content_to_str(message.content)
            carries_ground_truth = message.role == "system" and text.startswith("__GT__:")
            if not carries_ground_truth:
                kept.append(message)
                continue
            # Everything after the first colon is the ground-truth payload.
            row.ground_truth = text.partition(":")[2].strip()
        row.messages = kept
        return row
class IFEvalGroundTruthRolloutProcessor(RolloutProcessor):
    """Extract ground truth from __GT__ system messages, then run single-turn rollouts.

    System messages whose text starts with ``__GT__:`` are removed from each
    row's conversation.  The text after the first colon of the last such
    message (whitespace-stripped) is stored on ``row.ground_truth``; rows
    without any marker message keep their existing ``ground_truth``.  The
    cleaned rows are then handed to a ``SingleTurnRolloutProcessor``.
    """

    # Prefix that marks a system message as a ground-truth carrier.
    _GT_PREFIX = "__GT__:"

    def __init__(self) -> None:
        super().__init__()
        # The actual rollout is delegated to this processor.
        self.single_turn_processor = SingleTurnRolloutProcessor()

    def __call__(
        self, rows: list[EvaluationRow], config: RolloutProcessorConfig
    ) -> list[asyncio.Task[EvaluationRow]]:
        """Clean every row in place, then delegate to the single-turn processor.

        Args:
            rows: Rows to process; each is mutated in place.
            config: Passed through unchanged to the wrapped processor.

        Returns:
            Whatever the wrapped ``SingleTurnRolloutProcessor`` returns for
            the cleaned rows.
        """
        prepared = [self._strip_ground_truth(row) for row in rows]
        return self.single_turn_processor(prepared, config)

    def _strip_ground_truth(self, row: EvaluationRow) -> EvaluationRow:
        """Move the ``__GT__`` payload onto ``row.ground_truth`` in one pass."""
        kept: list[Message] = []
        ground_truth: str | None = None
        for message in row.messages:
            if message.role == "system":
                text = _coerce_content_to_str(message.content)
                if text.startswith(self._GT_PREFIX):
                    # Later markers overwrite earlier ones: the last one wins.
                    ground_truth = text[len(self._GT_PREFIX):].strip()
                    continue
            kept.append(message)
        # Only overwrite when a marker was actually present (an empty payload
        # still counts), so rows without one keep their prior ground truth.
        if ground_truth is not None:
            row.ground_truth = ground_truth
        row.messages = kept
        return row
64 | 81 |
|
65 | 82 |
|
66 | 83 | @evaluation_test( |
67 | 84 | input_messages=_IFBENCH_MESSAGES, |
68 | 85 | completion_params=[ |
69 | 86 | {"model": "fireworks_ai/accounts/fireworks/models/qwen3-8b"} |
70 | 87 | ], |
71 | | - rollout_processor=IFEvalRolloutProcessor(), |
| 88 | + rollout_processor=IFEvalGroundTruthRolloutProcessor(), |
72 | 89 | aggregation_method="mean", |
73 | 90 | passed_threshold=0.5, |
74 | 91 | num_runs=1, |
|
0 commit comments