From 715beeac9162d72b4757b75c3a15bec6b9a1592d Mon Sep 17 00:00:00 2001 From: Benny Chen Date: Fri, 31 Oct 2025 23:38:05 -0700 Subject: [PATCH] upload from input dataset as well --- eval_protocol/cli_commands/create_rft.py | 81 ++++++++++++++++++- .../pytest/gsm8k/test_pytest_math_example.py | 3 +- 2 files changed, 80 insertions(+), 4 deletions(-) diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index 7a2fe8c4..d8e921d2 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -151,6 +151,72 @@ def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) -> return None +def _extract_jsonl_from_input_dataset(test_file_path: str, test_func_name: str) -> Optional[str]: + """Import the test module and extract a JSONL path from input_dataset (dataset_path) param if present. + + Looks for a pytest.mark.parametrize with argnames containing 'dataset_path' and extracts the + first dataset path value. If a relative path is found, it is resolved relative to the directory + of the test file. 
+ """ + try: + import importlib.util + from pathlib import Path + + spec = importlib.util.spec_from_file_location(Path(test_file_path).stem, test_file_path) + if not spec or not spec.loader: + return None + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) # type: ignore[attr-defined] + if not hasattr(module, test_func_name): + return None + wrapper = getattr(module, test_func_name) + marks = getattr(wrapper, "pytestmark", []) + for m in marks: + if getattr(m, "name", "") == "parametrize": + kwargs = getattr(m, "kwargs", {}) + argnames = kwargs.get("argnames", (m.args[0] if m.args else [])) + argvalues = kwargs.get("argvalues", (m.args[1] if len(m.args) > 1 else [])) + # Normalize argnames to list + if isinstance(argnames, str): + names_list = [n.strip() for n in argnames.split(",") if n.strip()] + else: + names_list = list(argnames) + if "dataset_path" not in names_list: + continue + idx = names_list.index("dataset_path") + # argvalues is a list of tuples/values aligned with argnames + # Get the first value (first test case) + if argvalues: + val = argvalues[0] + # Normalize to tuple + if not isinstance(val, (tuple, list)): + params = (val,) + else: + params = tuple(val) + if idx < len(params): + dataset_path = params[idx] + # dataset_path is typically a string, but could be a list if combine_datasets=True + if isinstance(dataset_path, (list, tuple)) and len(dataset_path) > 0: + dataset_path = dataset_path[0] + if isinstance(dataset_path, str) and dataset_path: + if os.path.isabs(dataset_path): + return dataset_path + base_dir = os.path.dirname(os.path.abspath(test_file_path)) + resolved = os.path.abspath(os.path.join(base_dir, dataset_path)) + if os.path.isfile(resolved): + return resolved + # Fall back: resolve relative to the current working directory if resolving relative to the test file didn't find a file + if not os.path.isabs(dataset_path): + # Try resolving from current working directory + cwd_path = 
os.path.abspath(os.path.join(os.getcwd(), dataset_path)) + if os.path.isfile(cwd_path): + return cwd_path + return None + except Exception: + return None + + def _build_trimmed_dataset_id(evaluator_id: str) -> str: """Build a dataset id derived from evaluator_id, trimmed to 63 chars. @@ -277,11 +343,12 @@ def create_rft_command(args) -> int: dataset_builder = getattr(args, "dataset_builder", None) # accepted but unused in simplified flow if not dataset_id: - # Prefer explicit --dataset-jsonl, else attempt to extract from data loader of the single discovered test + # Prefer explicit --dataset-jsonl, else attempt to extract from data loader or input_dataset of the single discovered test if not dataset_jsonl: tests = _discover_tests(project_root) if len(tests) == 1: func_name = tests[0].qualname.split(".")[-1] + # Try data_loaders first (existing behavior) dataset_jsonl = _extract_jsonl_from_dataloader(tests[0].file_path, func_name) if dataset_jsonl: # Display relative path for readability @@ -290,9 +357,19 @@ def create_rft_command(args) -> int: except Exception: rel = dataset_jsonl print(f"✓ Using JSONL from data loader: {rel}") + else: + # Fall back to input_dataset (dataset_path) + dataset_jsonl = _extract_jsonl_from_input_dataset(tests[0].file_path, func_name) + if dataset_jsonl: + # Display relative path for readability + try: + rel = os.path.relpath(dataset_jsonl, project_root) + except Exception: + rel = dataset_jsonl + print(f"✓ Using JSONL from input_dataset: {rel}") if not dataset_jsonl: print( - "Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader is used in your single discovered test." + "Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader or input_dataset is used in your single discovered test." 
) return 1 diff --git a/tests/pytest/gsm8k/test_pytest_math_example.py b/tests/pytest/gsm8k/test_pytest_math_example.py index ec940f5c..63bc7e99 100644 --- a/tests/pytest/gsm8k/test_pytest_math_example.py +++ b/tests/pytest/gsm8k/test_pytest_math_example.py @@ -2,7 +2,6 @@ from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult, Message from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test import os -from eval_protocol.data_loader.jsonl_data_loader import EvaluationRowJsonlDataLoader from typing import List, Dict, Any, Optional import logging @@ -31,7 +30,7 @@ def extract_answer_digits(ground_truth: str) -> Optional[str]: @evaluation_test( - data_loaders=EvaluationRowJsonlDataLoader(jsonl_path=JSONL_PATH), + input_dataset=[JSONL_PATH], completion_params=[{"temperature": 0.0, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], max_dataset_rows=5, passed_threshold=0.0,