Skip to content

Commit d8477be

Browse files
authored
upload from input dataset as well (#306)
1 parent 8ec17d1 commit d8477be

File tree

2 files changed

+80
-4
lines changed

2 files changed

+80
-4
lines changed

eval_protocol/cli_commands/create_rft.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,72 @@ def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) ->
151151
return None
152152

153153

154+
def _extract_jsonl_from_input_dataset(test_file_path: str, test_func_name: str) -> Optional[str]:
155+
"""Import the test module and extract a JSONL path from input_dataset (dataset_path) param if present.
156+
157+
Looks for a pytest.mark.parametrize with argnames containing 'dataset_path' and extracts the
158+
first dataset path value. If a relative path is found, it is resolved relative to the directory
159+
of the test file.
160+
"""
161+
try:
162+
import importlib.util
163+
from pathlib import Path
164+
165+
spec = importlib.util.spec_from_file_location(Path(test_file_path).stem, test_file_path)
166+
if not spec or not spec.loader:
167+
return None
168+
module = importlib.util.module_from_spec(spec)
169+
sys.modules[spec.name] = module
170+
spec.loader.exec_module(module) # type: ignore[attr-defined]
171+
if not hasattr(module, test_func_name):
172+
return None
173+
wrapper = getattr(module, test_func_name)
174+
marks = getattr(wrapper, "pytestmark", [])
175+
for m in marks:
176+
if getattr(m, "name", "") == "parametrize":
177+
kwargs = getattr(m, "kwargs", {})
178+
argnames = kwargs.get("argnames", (m.args[0] if m.args else []))
179+
argvalues = kwargs.get("argvalues", (m.args[1] if len(m.args) > 1 else []))
180+
# Normalize argnames to list
181+
if isinstance(argnames, str):
182+
names_list = [n.strip() for n in argnames.split(",") if n.strip()]
183+
else:
184+
names_list = list(argnames)
185+
if "dataset_path" not in names_list:
186+
continue
187+
idx = names_list.index("dataset_path")
188+
# argvalues is a list of tuples/values aligned with argnames
189+
# Get the first value (first test case)
190+
if argvalues:
191+
val = argvalues[0]
192+
# Normalize to tuple
193+
if not isinstance(val, (tuple, list)):
194+
params = (val,)
195+
else:
196+
params = tuple(val)
197+
if idx < len(params):
198+
dataset_path = params[idx]
199+
# dataset_path is typically a string, but could be a list if combine_datasets=True
200+
if isinstance(dataset_path, (list, tuple)) and len(dataset_path) > 0:
201+
dataset_path = dataset_path[0]
202+
if isinstance(dataset_path, str) and dataset_path:
203+
if os.path.isabs(dataset_path):
204+
return dataset_path
205+
base_dir = os.path.dirname(os.path.abspath(test_file_path))
206+
resolved = os.path.abspath(os.path.join(base_dir, dataset_path))
207+
if os.path.isfile(resolved):
208+
return resolved
209+
# Try resolving from project root if relative to test file doesn't work
210+
if not os.path.isabs(dataset_path):
211+
# Try resolving from current working directory
212+
cwd_path = os.path.abspath(os.path.join(os.getcwd(), dataset_path))
213+
if os.path.isfile(cwd_path):
214+
return cwd_path
215+
return None
216+
except Exception:
217+
return None
218+
219+
154220
def _build_trimmed_dataset_id(evaluator_id: str) -> str:
155221
"""Build a dataset id derived from evaluator_id, trimmed to 63 chars.
156222
@@ -277,11 +343,12 @@ def create_rft_command(args) -> int:
277343
dataset_builder = getattr(args, "dataset_builder", None) # accepted but unused in simplified flow
278344

279345
if not dataset_id:
280-
# Prefer explicit --dataset-jsonl, else attempt to extract from data loader of the single discovered test
346+
# Prefer explicit --dataset-jsonl, else attempt to extract from data loader or input_dataset of the single discovered test
281347
if not dataset_jsonl:
282348
tests = _discover_tests(project_root)
283349
if len(tests) == 1:
284350
func_name = tests[0].qualname.split(".")[-1]
351+
# Try data_loaders first (existing behavior)
285352
dataset_jsonl = _extract_jsonl_from_dataloader(tests[0].file_path, func_name)
286353
if dataset_jsonl:
287354
# Display relative path for readability
@@ -290,9 +357,19 @@ def create_rft_command(args) -> int:
290357
except Exception:
291358
rel = dataset_jsonl
292359
print(f"✓ Using JSONL from data loader: {rel}")
360+
else:
361+
# Fall back to input_dataset (dataset_path)
362+
dataset_jsonl = _extract_jsonl_from_input_dataset(tests[0].file_path, func_name)
363+
if dataset_jsonl:
364+
# Display relative path for readability
365+
try:
366+
rel = os.path.relpath(dataset_jsonl, project_root)
367+
except Exception:
368+
rel = dataset_jsonl
369+
print(f"✓ Using JSONL from input_dataset: {rel}")
293370
if not dataset_jsonl:
294371
print(
295-
"Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader is used in your single discovered test."
372+
"Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader or input_dataset is used in your single discovered test."
296373
)
297374
return 1
298375

tests/pytest/gsm8k/test_pytest_math_example.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult, Message
33
from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
44
import os
5-
from eval_protocol.data_loader.jsonl_data_loader import EvaluationRowJsonlDataLoader
65
from typing import List, Dict, Any, Optional
76
import logging
87

@@ -31,7 +30,7 @@ def extract_answer_digits(ground_truth: str) -> Optional[str]:
3130

3231

3332
@evaluation_test(
34-
data_loaders=EvaluationRowJsonlDataLoader(jsonl_path=JSONL_PATH),
33+
input_dataset=[JSONL_PATH],
3534
completion_params=[{"temperature": 0.0, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
3635
max_dataset_rows=5,
3736
passed_threshold=0.0,

0 commit comments

Comments
 (0)