-
Notifications
You must be signed in to change notification settings - Fork 16
gsm8k math example #294
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
gsm8k math example #294
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| eval-protocol |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| import re | ||
| from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult, Message | ||
| from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test | ||
| from typing import List, Dict, Any, Optional | ||
|
|
||
|
|
||
def extract_answer_digits(ground_truth: str) -> Optional[str]:
    """
    Extract the first run of digits from an <answer>...</answer> span.

    Args:
        ground_truth: Text expected to contain an ``<answer>...</answer>``
            tag pair (e.g. a GSM8K-style completion or ground-truth string).

    Returns:
        The first digit group found inside the answer tags, or ``None`` when
        the tags are missing or the answer text contains no digits.
    """
    # Guard against missing tags: the unguarded split("<answer>")[1] raises
    # IndexError on inputs without an opening tag.
    if "<answer>" not in ground_truth:
        return None
    answer_string = ground_truth.split("<answer>")[1].split("</answer>")[0]
    # Guard against answers with no digits: re.search returns None and the
    # original unconditional .group(1) raised AttributeError.
    digits = re.search(r"(\d+)", answer_string)
    return digits.group(1) if digits else None
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bug: Function Fails to Handle Malformed Inputs Gracefully. The `extract_answer_digits` function raises an IndexError when the input lacks `<answer>` tags and an AttributeError when the answer text contains no digits; it should return None for malformed inputs instead of crashing. |
||
|
|
||
|
|
||
@evaluation_test(
    input_dataset=["development/gsm8k_sample.jsonl"],
    completion_params=[{"temperature": 0.0, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
    max_dataset_rows=5,
    passed_threshold=0.0,
    rollout_processor=SingleTurnRolloutProcessor(),
    mode="pointwise",
    evaluation_test_kwargs=[
        {"math_reward_kwargs": {"tolerance": 0.001, "absolute_tolerance": 1e-8, "require_units": False}}
    ],
)
def test_math_dataset(row: EvaluationRow, **kwargs) -> EvaluationRow:
    """
    Evaluate a GSM8K math rollout by exact match of extracted answer digits.

    The model response and the ground truth are both reduced to the first
    digit group inside their <answer>...</answer> tags and compared for
    equality, producing a binary score: 1 when they match, 0 when they
    differ or either side is missing its answer tags.

    Args:
        row: EvaluationRow containing the conversation messages and ground truth
        **kwargs: Additional parameters (like math_reward_kwargs)

    Returns:
        EvaluationRow with the evaluation result
    """
    #### Get predicted answer value
    # Read the LAST message: the rollout processor appends the model response
    # at the end of the conversation, so a hard-coded index (messages[2])
    # breaks whenever the input row has a different message count, e.g. when
    # trailing assistant messages are dropped before the rollout.
    prediction = extract_answer_digits(str(row.messages[-1].content))

    gt = extract_answer_digits(str(row.ground_truth))

    #### Get score
    if prediction is None or gt is None:
        score = 0
        reason = "Missing answer tags in prediction or ground truth."

    elif gt == prediction:
        score = 1
        reason = "Model answer is correct."

    else:
        score = 0
        reason = "Model answer is not correct."

    reason += f" Prediction: {prediction}, Ground Truth: {gt}"

    evaluation_result = EvaluateResult(
        score=score,  # Required: The final evaluation score
        is_score_valid=True,  # Optional: Whether the score is valid, true by default
        reason=reason,  # Optional: The reason for the score
    )
    row.evaluation_result = evaluation_result
    return row
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| import asyncio | ||
| from types import SimpleNamespace | ||
|
|
||
| import pytest | ||
|
|
||
| from eval_protocol.models import EvaluationRow, Message | ||
| from eval_protocol.pytest import SingleTurnRolloutProcessor | ||
|
|
||
|
|
||
| class _DummyConfig: | ||
| def __init__(self): | ||
| self.completion_params = {"model": "fake-model", "temperature": 0} | ||
| self.semaphore = asyncio.Semaphore(10) | ||
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_single_turn_drops_trailing_assistant_by_default(monkeypatch):
    """By default, SingleTurnRolloutProcessor strips a trailing assistant
    message before sending the conversation to the model, then appends the
    new model response as the final assistant message."""
    # Arrange dataset row with trailing assistant message
    row = EvaluationRow(
        messages=[
            Message(role="user", content="What is 2+2?"),
            Message(role="assistant", content="Old response"),
        ]
    )

    # Capture the messages payload passed to the LLM call
    captured = {}

    # Patch module-level imports in the processor module
    import eval_protocol.pytest.default_single_turn_rollout_process as mod

    class StubChoices:
        # Bare container; the `message` attribute is attached dynamically below.
        pass

    class StubModelResponse:
        # Minimal OpenAI/litellm-shaped response: .choices[0].message and .usage.
        def __init__(self, text: str):
            self.choices = [StubChoices()]
            # Emulate OpenAI-like response.message fields
            self.choices[0].message = SimpleNamespace(content=text, tool_calls=None)
            # Minimal usage payload
            self.usage = SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2)

    async def fake_acompletion(**kwargs):
        # Verify that trailing assistant was dropped before sending
        msgs = kwargs.get("messages", [])
        assert msgs, "Expected non-empty messages payload"
        captured["messages"] = msgs
        assert msgs[-1]["role"] != "assistant", "Trailing assistant should be dropped by default"
        return StubModelResponse(text="4")

    # Monkeypatch the processor module's symbols to avoid dependency on litellm types.
    # NOTE: patching must happen before the processor is invoked below.
    monkeypatch.setattr(mod, "ModelResponse", StubModelResponse, raising=True)
    monkeypatch.setattr(mod, "Choices", StubChoices, raising=True)
    monkeypatch.setattr(mod, "acompletion", fake_acompletion, raising=True)

    processor = SingleTurnRolloutProcessor()
    config = _DummyConfig()

    # Act: the processor returns one awaitable task per input row.
    tasks = processor([row], config)
    out = await tasks[0]

    # Assert: request trimmed the trailing assistant
    sent_msgs = captured["messages"]
    assert len(sent_msgs) == 1
    assert sent_msgs[0]["role"] == "user"
    assert out.messages[-1].role == "assistant"
    assert out.messages[-1].content == "4"
    # Ensure previous trailing assistant was not duplicated
    assert [m.role for m in out.messages] == ["user", "assistant"]
|
|
||
|
|
||
@pytest.mark.asyncio
async def test_single_turn_keeps_trailing_assistant_when_disabled(monkeypatch):
    """With drop_trailing_assistant_messages=False, the processor sends the
    trailing assistant message through unchanged and still appends the new
    model response at the end."""
    # Arrange dataset row with trailing assistant message
    row = EvaluationRow(
        messages=[
            Message(role="user", content="Say hi"),
            Message(role="assistant", content="Hi!"),
        ]
    )

    # Capture the messages payload passed to the LLM call
    captured = {}

    # Patch module-level imports in the processor module
    import eval_protocol.pytest.default_single_turn_rollout_process as mod

    class StubChoices:
        # Bare container; the `message` attribute is attached dynamically below.
        pass

    class StubModelResponse:
        # Minimal OpenAI/litellm-shaped response: .choices[0].message and .usage.
        def __init__(self, text: str):
            self.choices = [StubChoices()]
            self.choices[0].message = SimpleNamespace(content=text, tool_calls=None)
            self.usage = SimpleNamespace(prompt_tokens=1, completion_tokens=1, total_tokens=2)

    async def fake_acompletion(**kwargs):
        msgs = kwargs.get("messages", [])
        captured["messages"] = msgs
        # With opt-out, trailing assistant is preserved
        assert msgs[-1]["role"] == "assistant"
        return StubModelResponse(text="Hello again")

    # Patch before invoking the processor so the stubs are in effect.
    monkeypatch.setattr(mod, "ModelResponse", StubModelResponse, raising=True)
    monkeypatch.setattr(mod, "Choices", StubChoices, raising=True)
    monkeypatch.setattr(mod, "acompletion", fake_acompletion, raising=True)

    processor = SingleTurnRolloutProcessor(drop_trailing_assistant_messages=False)
    config = _DummyConfig()

    # Act: the processor returns one awaitable task per input row.
    tasks = processor([row], config)
    out = await tasks[0]

    # Assert: both original messages plus new assistant
    sent_msgs = captured["messages"]
    assert [m["role"] for m in sent_msgs] == ["user", "assistant"]
    assert [m.role for m in out.messages] == ["user", "assistant", "assistant"]
    assert out.messages[-1].content == "Hello again"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bug: Empty Payload After Message Filtering
Missing validation after filtering trailing assistant messages. If all messages in row.messages are assistant messages and drop_trailing_assistant_messages=True, the messages_for_request list becomes empty, resulting in an empty messages_payload being sent to the LLM API. This will fail with an API error rather than being caught by the existing validation on line 42-43. A check should be added after the filtering loop (lines 47-49) to ensure messages_for_request is not empty before proceeding.