Skip to content

Commit 1489b63

Browse files
committed
gsm8k math example
1 parent 680e719 commit 1489b63

File tree

4 files changed

+255
-10
lines changed

4 files changed

+255
-10
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .dynamic_data_loader import DynamicDataLoader
22
from .inline_data_loader import InlineDataLoader
3+
from .fireworks_dataset_loader import FireworksDatasetLoader
34

4-
__all__ = ["DynamicDataLoader", "InlineDataLoader"]
5+
__all__ = ["DynamicDataLoader", "InlineDataLoader", "FireworksDatasetLoader"]
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import os
5+
import shutil
6+
import subprocess
7+
import tempfile
8+
from dataclasses import dataclass
9+
from pathlib import Path
10+
from typing import Callable, Optional, cast
11+
from urllib.parse import quote
12+
13+
import requests
14+
15+
from eval_protocol.auth import get_fireworks_api_base, get_fireworks_api_key
16+
from eval_protocol.common_utils import load_jsonl
17+
from eval_protocol.data_loader.models import (
18+
DataLoaderResult,
19+
DataLoaderVariant,
20+
EvaluationDataLoader,
21+
)
22+
from eval_protocol.models import EvaluationRow, JSONType
23+
24+
25+
def _default_dataset_adapter(rows: list[dict[str, object]]) -> list[EvaluationRow]:
    """
    Convert Fireworks dataset rows into EvaluationRow.

    Preferred shape:
    - { messages: [...], ground_truth?: any }

    Fallback (legacy demo):
    - { user_query: str, ground_truth_for_eval?: any }

    Args:
        rows: Raw JSONL rows as parsed dicts.

    Returns:
        One EvaluationRow per input row.
    """
    # Defer import to avoid cycles — but hoisted out of the loop so it
    # executes once per call instead of once per row.
    from eval_protocol.models import Message

    converted: list[EvaluationRow] = []
    for row in rows:
        messages = row.get("messages")
        # Prefer the canonical key; fall back to the legacy demo key.
        ground_truth = cast(JSONType, row.get("ground_truth"))
        if ground_truth is None:
            ground_truth = cast(JSONType, row.get("ground_truth_for_eval"))

        if isinstance(messages, list) and messages:
            normalized_messages: list[Message] = []
            for m in messages:
                if isinstance(m, Message):
                    normalized_messages.append(m)
                elif isinstance(m, dict):
                    # Let Message handle content types (str or list)
                    normalized_messages.append(Message.model_validate(m))
            converted.append(EvaluationRow(messages=normalized_messages, ground_truth=ground_truth))
            continue

        # Fallback: single-turn user_query
        user_query = str(row.get("user_query", ""))
        converted.append(EvaluationRow(messages=[Message(role="user", content=user_query)], ground_truth=ground_truth))
    return converted
60+
61+
62+
def _download_fireworks_dataset_jsonl(
    dataset_ref: str,
    *,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> Path:
    """
    Download a Fireworks dataset to a temporary file and return its path.

    This mirrors `firectl download dataset <ref>` behavior using HTTP APIs.
    We expect a single JSONL file under dataset/<name>/dataset_with_ground_truth_column_*.jsonl

    Args:
        dataset_ref: Full dataset reference, e.g.
            "accounts/<acct>/datasets/<name>".
        api_key: Optional override; otherwise resolved via
            get_fireworks_api_key() (HTTP path only).
        api_base: Optional override; otherwise resolved via
            get_fireworks_api_base() (HTTP path only).

    Returns:
        Path to a downloaded JSONL file inside a fresh temp directory.

    Raises:
        RuntimeError: if firectl fails, no JSONL file is found, the API key
            is missing, or no signed download URL is returned.
        requests.HTTPError: via raise_for_status() on any failed HTTP call.
    """
    # Prefer firectl if available, as in user's example
    firectl_bin = shutil.which("firectl")
    if firectl_bin:
        # NOTE(review): tmp_root is never removed on failure — confirm
        # whether leaked temp dirs are acceptable here.
        tmp_root = Path(tempfile.mkdtemp(prefix="ep_fw_ds_"))
        # firectl requires an explicit --output-dir
        cmd = [firectl_bin, "download", "dataset", dataset_ref, "--output-dir", str(tmp_root)]
        proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        if proc.returncode != 0:
            raise RuntimeError(f"firectl failed: {proc.stderr or proc.stdout}")
        # Expected structure: <tmp_root>/dataset/<name>/*.jsonl
        name_part = dataset_ref.split("/datasets/")[-1] if "/datasets/" in dataset_ref else None
        candidate_dir = tmp_root / "dataset"
        if name_part:
            candidate_dir = candidate_dir / name_part
        # Fall back to scanning the whole temp root if the expected
        # directory layout is absent.
        jsonl_files = (
            list(candidate_dir.rglob("*.jsonl")) if candidate_dir.exists() else list((tmp_root).rglob("*.jsonl"))
        )
        if not jsonl_files:
            raise RuntimeError("No JSONL files found after firectl download")
        # Prefer ground_truth jsonl
        jsonl_files.sort(key=lambda p: (0 if "ground_truth" in p.name else 1, 0 if p.suffix == ".jsonl" else 1))
        return jsonl_files[0]

    # Fallback to HTTP API
    resolved_key = api_key or get_fireworks_api_key()
    if not resolved_key:
        raise RuntimeError("FIREWORKS_API_KEY is required to download Fireworks datasets")

    base = (api_base or get_fireworks_api_base()).rstrip("/")
    headers = {"Authorization": f"Bearer {resolved_key}"}

    # safe="" percent-encodes the slashes in the dataset ref so the whole
    # ref becomes a single URL path segment.
    encoded_ref = quote(dataset_ref, safe="")
    list_url = f"{base}/v1/datasets/{encoded_ref}/files"
    resp = requests.get(list_url, headers=headers, timeout=60)
    resp.raise_for_status()
    payload = resp.json()
    files = payload.get("files", []) if isinstance(payload, dict) else []
    if not files:
        raise RuntimeError(f"No files found for dataset {dataset_ref}")

    def _score(name: str) -> tuple[int, int]:
        # Sort key: ground-truth files first, then .jsonl files.
        name_lower = name.lower()
        return (
            0 if "ground_truth" in name_lower else 1,
            0 if name_lower.endswith(".jsonl") else 1,
        )

    # Pick the best-ranked file that is actually a JSONL file.
    files_sorted = sorted(files, key=lambda f: _score(str(f.get("name", ""))))
    chosen = None
    for f in files_sorted:
        name = str(f.get("name", ""))
        if name.endswith(".jsonl"):
            chosen = f
            break
    if not chosen:
        raise RuntimeError(f"No JSONL file found for dataset {dataset_ref}")

    # The file identifier key varies; try the known alternatives in order.
    file_id = chosen.get("id") or chosen.get("file_id") or chosen.get("name")
    encoded_file = quote(str(file_id), safe="")
    # The :download endpoint returns a short-lived signed URL rather than
    # the file bytes themselves.
    dl_url = f"{base}/v1/datasets/{encoded_ref}/files/{encoded_file}:download"
    dl_resp = requests.get(dl_url, headers=headers, timeout=60)
    dl_resp.raise_for_status()
    dl_payload = dl_resp.json()
    signed_url = dl_payload.get("url") or dl_payload.get("signed_url")
    if not signed_url:
        raise RuntimeError("Failed to obtain signed URL for dataset file download")

    # Stream the signed URL to disk in 64 KiB chunks to avoid holding the
    # whole dataset in memory.
    tmp_dir = Path(tempfile.mkdtemp(prefix="ep_fw_ds_"))
    out_path = tmp_dir / Path(str(chosen.get("name", "dataset.jsonl"))).name
    with requests.get(str(signed_url), stream=True, timeout=300) as r:
        r.raise_for_status()
        with open(out_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 16):
                if chunk:
                    f.write(chunk)

    return out_path
151+
152+
153+
@dataclass(kw_only=True)
class FireworksDatasetLoader(EvaluationDataLoader):
    """
    Data loader that downloads a dataset from Fireworks and emits `EvaluationRow`s.

    - dataset_ref: e.g. "accounts/fireworks/datasets/demo-gsm8k-math-dataset-1000"
    - dataset_adapter: function to convert list[dict] -> list[EvaluationRow]. If not provided,
      defaults to an adapter that expects OpenAI-style `messages` rows and falls back to legacy demo shape.
    - max_rows: optional limit on number of rows to emit.
    - api_key/api_base: override resolution from environment if needed.
    """

    dataset_ref: str
    dataset_adapter: Callable[[list[dict[str, object]]], list[EvaluationRow]] | None = None
    max_rows: Optional[int] = None
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    id: str = "fireworks"
    description: Optional[str] = None

    def variants(self) -> list[DataLoaderVariant]:
        """Return a single lazy variant that downloads, adapts, and cleans up."""

        def _load() -> DataLoaderResult:
            # Download happens only when the variant is actually invoked.
            jsonl_path = _download_fireworks_dataset_jsonl(
                self.dataset_ref, api_key=self.api_key, api_base=self.api_base
            )
            try:
                raw_rows = load_jsonl(str(jsonl_path))
                if self.max_rows is not None:
                    raw_rows = raw_rows[: self.max_rows]
                adapter = self.dataset_adapter or _default_dataset_adapter
                rows = adapter(raw_rows)
                return DataLoaderResult(
                    rows=rows,
                    type=self.__class__.__name__,
                    variant_id=self.id,
                    variant_description=self.description or f"Fireworks dataset {self.dataset_ref}",
                )
            finally:
                # Best-effort cleanup of the downloaded file and its temp dir;
                # never let cleanup failures mask the load result.
                # NOTE(review): in the firectl path the file can live under a
                # deeper temp root (ep_fw_ds_*/dataset/<name>/), so removing
                # only the immediate parent leaves that root behind — confirm
                # whether that leak is acceptable.
                try:
                    p = Path(jsonl_path)
                    parent = p.parent
                    # missing_ok=True already covers the not-exists case, so
                    # no separate exists() check is needed.
                    p.unlink(missing_ok=True)
                    try:
                        # Only succeeds if the directory is now empty.
                        parent.rmdir()
                    except OSError:
                        pass
                except Exception:
                    pass

        return [_load]

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,16 @@
2121
class SingleTurnRolloutProcessor(RolloutProcessor):
2222
"""Single turn rollout processor for direct LLM calls."""
2323

24+
    def __init__(self, *, drop_trailing_assistant_messages: bool = True) -> None:
        """
        Configure the single-turn rollout processor.

        Args:
            drop_trailing_assistant_messages: When True (default), strip any trailing
                assistant messages from the input conversation before calling the model.
                This helps when datasets include previous assistant turns and you want
                the model to answer the latest user query.
        """
        # Stored for use by __call__ when building the request payload.
        self.drop_trailing_assistant_messages = drop_trailing_assistant_messages
33+
2434
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
2535
"""Generate single turn rollout tasks and return them for external handling."""
2636
# Do not modify global LiteLLM cache. Disable caching per-request instead.
@@ -32,7 +42,13 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
3242
if len(row.messages) == 0:
3343
raise ValueError("Messages is empty. Please provide a non-empty dataset")
3444

35-
messages_payload = [message.model_dump() for message in row.messages]
45+
# Optionally drop trailing assistant messages for single-turn prompts
46+
messages_for_request: List[Message] = list(row.messages)
47+
if self.drop_trailing_assistant_messages:
48+
while messages_for_request and messages_for_request[-1].role == "assistant":
49+
messages_for_request.pop()
50+
51+
messages_payload = [message.model_dump() for message in messages_for_request]
3652

3753
request_params = {"messages": messages_payload, **config.completion_params}
3854
# Ensure caching is disabled only for this request (review feedback)
@@ -114,7 +130,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
114130
except Exception:
115131
pass
116132

117-
messages = list(row.messages) + [
133+
messages = list(messages_for_request) + [
118134
Message(
119135
role="assistant",
120136
content=assistant_content,

tests/pytest/test_pytest_math_example.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult
22
from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
3+
from eval_protocol.data_loader import FireworksDatasetLoader
34
from eval_protocol.rewards.math import math_reward
45
from examples.math_example.main import check_think_answer_format
5-
from tests.pytest.helper.gsm8k_to_evaluation_row import gsm8k_to_evaluation_row
66

77

88
@evaluation_test(
9-
input_dataset=["development/gsm8k_sample.jsonl"],
10-
dataset_adapter=gsm8k_to_evaluation_row,
9+
data_loaders=FireworksDatasetLoader(
10+
dataset_ref="accounts/fireworks/datasets/demo-gsm8k-math-dataset-1000",
11+
id="fw-gsm8k-demo",
12+
description="Fireworks demo GSM8K 1k dataset",
13+
),
1114
completion_params=[{"temperature": 0.0, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
1215
max_dataset_rows=5,
1316
passed_threshold=0.0,
@@ -32,15 +35,34 @@ def test_math_dataset(row: EvaluationRow, **kwargs) -> EvaluationRow:
3235
Returns:
3336
EvaluationRow with the evaluation result
3437
"""
35-
# Get the assistant's response
38+
# Get the assistant's response (coerce to text)
3639
assistant_message = row.messages[-1]
3740
if isinstance(assistant_message, dict):
38-
assistant_response = assistant_message.get("content", "")
41+
content = assistant_message.get("content", "")
3942
else:
40-
assistant_response = assistant_message.content or ""
43+
content = assistant_message.content or ""
44+
45+
def _to_text(val):
46+
if isinstance(val, str):
47+
return val
48+
if isinstance(val, list):
49+
parts = []
50+
for part in val:
51+
if isinstance(part, dict):
52+
t = part.get("text") or part.get("content")
53+
if isinstance(t, str):
54+
parts.append(t)
55+
return "".join(parts)
56+
return str(val) if val is not None else ""
57+
58+
assistant_response = _to_text(content)
4159

4260
# Evaluate numerical accuracy using built-in function
43-
accuracy_result = math_reward(messages=row.messages, ground_truth=row.ground_truth, **kwargs["math_reward_kwargs"])
61+
accuracy_result = math_reward(
62+
messages=row.messages,
63+
ground_truth=str(row.ground_truth) if row.ground_truth is not None else "",
64+
**kwargs["math_reward_kwargs"],
65+
)
4466

4567
# Evaluate format compliance (looking for <think>...</think><answer>...</answer> format)
4668
format_correct = check_think_answer_format(assistant_response)

0 commit comments

Comments
 (0)