Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions eval_protocol/cli_commands/create_rft.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,72 @@ def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) ->
return None


def _extract_jsonl_from_input_dataset(test_file_path: str, test_func_name: str) -> Optional[str]:
    """Import the test module and extract a JSONL path from its ``dataset_path`` parametrization.

    The module at ``test_file_path`` is executed, then the ``pytestmark`` list attached to
    ``test_func_name`` is scanned for ``parametrize`` marks whose argnames include
    ``dataset_path``. Only the first parameter set of each such mark is inspected; when the
    value is a list/tuple of paths, its first element is used.

    An absolute path is returned unchanged (its existence is not checked). A relative path is
    resolved against the directory of the test file first and the current working directory
    second, and is returned only if it points at an existing file.

    Args:
        test_file_path: Path of the Python test file to import.
        test_func_name: Name of the module-level test function to inspect.

    Returns:
        The dataset path, or ``None`` when nothing usable is found or when anything goes
        wrong while importing/inspecting the module (this lookup is best-effort).

    Side effects:
        Executes the test module and registers it in ``sys.modules`` under the file's stem.
        If execution fails, that registration is rolled back.
    """
    try:
        import importlib.util
        from pathlib import Path

        spec = importlib.util.spec_from_file_location(Path(test_file_path).stem, test_file_path)
        if not spec or not spec.loader:
            return None
        module = importlib.util.module_from_spec(spec)
        # Register before executing so the module can be found by name while its body runs.
        sys.modules[spec.name] = module
        try:
            spec.loader.exec_module(module)  # type: ignore[attr-defined]
        except BaseException:
            # Mirror the import system: never leave a half-initialised module behind.
            if sys.modules.get(spec.name) is module:
                del sys.modules[spec.name]
            raise

        wrapper = getattr(module, test_func_name, None)
        if wrapper is None:
            return None

        for mark in getattr(wrapper, "pytestmark", []):
            candidate = _first_dataset_path_from_mark(mark)
            if candidate is None:
                continue
            resolved = _resolve_input_dataset_path(candidate, test_file_path)
            if resolved is not None:
                return resolved
        return None
    except Exception:
        # Best-effort discovery: the caller falls back to an explicit error message.
        return None


def _first_dataset_path_from_mark(mark) -> Optional[str]:
    """Return the ``dataset_path`` string of a parametrize mark's first case, else ``None``."""
    if getattr(mark, "name", "") != "parametrize":
        return None

    kwargs = getattr(mark, "kwargs", {})
    args = getattr(mark, "args", ())
    # argnames/argvalues may be given by keyword or as the first two positional arguments.
    argnames = kwargs.get("argnames", args[0] if args else [])
    argvalues = kwargs.get("argvalues", args[1] if len(args) > 1 else [])

    if isinstance(argnames, str):
        names = [name.strip() for name in argnames.split(",") if name.strip()]
    else:
        names = list(argnames)
    if "dataset_path" not in names or not argvalues:
        return None

    # Each argvalues entry is one test case: a tuple aligned with argnames, or a bare value.
    first_case = argvalues[0]
    params = tuple(first_case) if isinstance(first_case, (tuple, list)) else (first_case,)
    position = names.index("dataset_path")
    if position >= len(params):
        return None

    value = params[position]
    # dataset_path can itself be a list of paths; use the first one.
    if isinstance(value, (list, tuple)) and value:
        value = value[0]
    if isinstance(value, str) and value:
        return value
    return None


def _resolve_input_dataset_path(dataset_path: str, test_file_path: str) -> Optional[str]:
    """Resolve ``dataset_path`` to an absolute path, or ``None`` if a relative path is not found."""
    if os.path.isabs(dataset_path):
        return dataset_path

    # Relative paths: prefer the test file's directory, then the current working directory.
    search_roots = (os.path.dirname(os.path.abspath(test_file_path)), os.getcwd())
    for root in search_roots:
        candidate = os.path.abspath(os.path.join(root, dataset_path))
        if os.path.isfile(candidate):
            return candidate
    return None


def _build_trimmed_dataset_id(evaluator_id: str) -> str:
"""Build a dataset id derived from evaluator_id, trimmed to 63 chars.

Expand Down Expand Up @@ -277,11 +343,12 @@ def create_rft_command(args) -> int:
dataset_builder = getattr(args, "dataset_builder", None) # accepted but unused in simplified flow

if not dataset_id:
# Prefer explicit --dataset-jsonl, else attempt to extract from data loader of the single discovered test
# Prefer explicit --dataset-jsonl, else attempt to extract from data loader or input_dataset of the single discovered test
if not dataset_jsonl:
tests = _discover_tests(project_root)
if len(tests) == 1:
func_name = tests[0].qualname.split(".")[-1]
# Try data_loaders first (existing behavior)
dataset_jsonl = _extract_jsonl_from_dataloader(tests[0].file_path, func_name)
if dataset_jsonl:
# Display relative path for readability
Expand All @@ -290,9 +357,19 @@ def create_rft_command(args) -> int:
except Exception:
rel = dataset_jsonl
print(f"✓ Using JSONL from data loader: {rel}")
else:
# Fall back to input_dataset (dataset_path)
dataset_jsonl = _extract_jsonl_from_input_dataset(tests[0].file_path, func_name)
if dataset_jsonl:
# Display relative path for readability
try:
rel = os.path.relpath(dataset_jsonl, project_root)
except Exception:
rel = dataset_jsonl
print(f"✓ Using JSONL from input_dataset: {rel}")
if not dataset_jsonl:
print(
"Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader is used in your single discovered test."
"Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader or input_dataset is used in your single discovered test."
)
return 1

Expand Down
3 changes: 1 addition & 2 deletions tests/pytest/gsm8k/test_pytest_math_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult, Message
from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
import os
from eval_protocol.data_loader.jsonl_data_loader import EvaluationRowJsonlDataLoader
from typing import List, Dict, Any, Optional
import logging

Expand Down Expand Up @@ -31,7 +30,7 @@ def extract_answer_digits(ground_truth: str) -> Optional[str]:


@evaluation_test(
data_loaders=EvaluationRowJsonlDataLoader(jsonl_path=JSONL_PATH),
input_dataset=[JSONL_PATH],
completion_params=[{"temperature": 0.0, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
max_dataset_rows=5,
passed_threshold=0.0,
Expand Down
Loading