Skip to content

Commit 38c4758

Browse files
committed
support dataset adapter
1 parent 8a5d7da commit 38c4758

File tree

4 files changed

+218
-35
lines changed

4 files changed

+218
-35
lines changed

eval_protocol/cli_commands/create_rft.py

Lines changed: 95 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
import os
66
import sys
77
import time
8-
from typing import Any, Dict, Optional
8+
from typing import Any, Callable, Dict, Optional
99
import inspect
1010
import requests
11+
import tempfile
1112
from pydantic import ValidationError
1213

1314
from ..auth import get_fireworks_api_base, get_fireworks_api_key
14-
from ..common_utils import get_user_agent
15+
from ..common_utils import get_user_agent, load_jsonl
1516
from ..fireworks_rft import (
16-
build_default_output_model,
1717
create_dataset_from_jsonl,
1818
detect_dataset_builder,
1919
materialize_dataset_via_builder,
@@ -31,12 +31,88 @@
3131
_normalize_evaluator_id,
3232
_print_links,
3333
_resolve_selected_test,
34+
load_module_from_file_path,
3435
)
3536
from .local_test import run_evaluator_test
3637

3738
from fireworks import Fireworks
3839

3940

41+
def _extract_dataset_adapter(
42+
test_file_path: str, test_func_name: str
43+
) -> Optional[Callable[[list[dict[str, Any]]], Any]]:
44+
"""Extract dataset_adapter from an @evaluation_test wrapper via __ep_params__."""
45+
try:
46+
module = load_module_from_file_path(test_file_path)
47+
wrapper = getattr(module, test_func_name, None)
48+
if wrapper is None:
49+
return None
50+
ep_params = getattr(wrapper, "__ep_params__", None)
51+
if ep_params is None:
52+
return None
53+
adapter = getattr(ep_params, "dataset_adapter", None)
54+
if callable(adapter):
55+
return adapter
56+
return None
57+
except Exception:
58+
return None
59+
60+
61+
def _maybe_transform_dataset_jsonl_via_adapter(
62+
project_root: str,
63+
dataset_jsonl: str,
64+
test_file_path: Optional[str],
65+
test_func_name: Optional[str],
66+
) -> str:
67+
"""Transform dataset_jsonl via the test's dataset_adapter (when available).
68+
69+
For RFT dataset uploads, we want the uploaded dataset to match what evaluation-time
70+
would run on. If the selected evaluation test provides a dataset_adapter, that
71+
adapter is treated as the source of truth for constructing EvaluationRows.
72+
"""
73+
if not dataset_jsonl:
74+
return dataset_jsonl
75+
76+
if not test_file_path or not test_func_name:
77+
return dataset_jsonl
78+
79+
adapter = _extract_dataset_adapter(test_file_path, test_func_name)
80+
if not adapter:
81+
return dataset_jsonl
82+
83+
raw_rows: list[dict[str, Any]] = load_jsonl(dataset_jsonl) # type: ignore[assignment]
84+
adapted = adapter(raw_rows)
85+
if not isinstance(adapted, list):
86+
raise ValueError("dataset_adapter must return a list of EvaluationRow (or dicts parseable as EvaluationRow).")
87+
88+
eval_rows: list[EvaluationRow] = []
89+
for item in adapted:
90+
if isinstance(item, EvaluationRow):
91+
eval_rows.append(item)
92+
else:
93+
eval_rows.append(EvaluationRow.model_validate(item))
94+
95+
output_dir = os.path.join(project_root, ".ep_tmp")
96+
os.makedirs(output_dir, exist_ok=True)
97+
with tempfile.NamedTemporaryFile(
98+
mode="w",
99+
encoding="utf-8",
100+
delete=False,
101+
suffix=".jsonl",
102+
prefix="ep_rft_dataset_",
103+
dir=output_dir,
104+
) as f:
105+
for row in eval_rows:
106+
f.write(json.dumps(row.model_dump(mode="json"), ensure_ascii=False) + "\n")
107+
out_path = os.path.abspath(f.name)
108+
try:
109+
rel = os.path.relpath(out_path, project_root)
110+
except Exception:
111+
rel = out_path
112+
print(f"✓ Transformed dataset via dataset_adapter into EvaluationRow JSONL: {rel} ({len(eval_rows)} rows)")
113+
return out_path
114+
115+
40116
def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) -> Optional[str]:
41117
"""Import the test module and extract a JSONL path from data_loaders param if present.
42118
@@ -45,18 +121,10 @@ def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) ->
45121
relative to the directory of the test file.
46122
"""
47123
try:
48-
import importlib.util
49-
from pathlib import Path
50-
51-
spec = importlib.util.spec_from_file_location(Path(test_file_path).stem, test_file_path)
52-
if not spec or not spec.loader:
124+
module = load_module_from_file_path(test_file_path)
125+
wrapper = getattr(module, test_func_name, None)
126+
if wrapper is None:
53127
return None
54-
module = importlib.util.module_from_spec(spec)
55-
sys.modules[spec.name] = module
56-
spec.loader.exec_module(module) # type: ignore[attr-defined]
57-
if not hasattr(module, test_func_name):
58-
return None
59-
wrapper = getattr(module, test_func_name)
60128
marks = getattr(wrapper, "pytestmark", [])
61129
for m in marks:
62130
if getattr(m, "name", "") == "parametrize":
@@ -105,18 +173,10 @@ def _extract_jsonl_from_input_dataset(test_file_path: str, test_func_name: str)
105173
of the test file.
106174
"""
107175
try:
108-
import importlib.util
109-
from pathlib import Path
110-
111-
spec = importlib.util.spec_from_file_location(Path(test_file_path).stem, test_file_path)
112-
if not spec or not spec.loader:
113-
return None
114-
module = importlib.util.module_from_spec(spec)
115-
sys.modules[spec.name] = module
116-
spec.loader.exec_module(module) # type: ignore[attr-defined]
117-
if not hasattr(module, test_func_name):
176+
module = load_module_from_file_path(test_file_path)
177+
wrapper = getattr(module, test_func_name, None)
178+
if wrapper is None:
118179
return None
119-
wrapper = getattr(module, test_func_name)
120180
marks = getattr(wrapper, "pytestmark", [])
121181
for m in marks:
122182
if getattr(m, "name", "") == "parametrize":
@@ -719,6 +779,16 @@ def create_rft_command(args) -> int:
719779
if dataset_jsonl is None and not dataset_id:
720780
return 1
721781

782+
# 2.5) If the selected evaluation test provides a dataset_adapter, always use it to
783+
# construct the EvaluationRow dataset that we upload for RFT.
784+
if dataset_jsonl is not None:
785+
dataset_jsonl = _maybe_transform_dataset_jsonl_via_adapter(
786+
project_root=project_root,
787+
dataset_jsonl=dataset_jsonl,
788+
test_file_path=selected_test_file_path,
789+
test_func_name=selected_test_func_name,
790+
)
791+
722792
# 3) Optional local validation
723793
if not skip_validation:
724794
# Dataset validation (JSONL must be EvaluationRow-compatible when present)

eval_protocol/cli_commands/upload.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import argparse
22
from eval_protocol.cli_commands.utils import DiscoveredTest
3-
import importlib.util
43
import os
54
import re
65
import sys
@@ -18,6 +17,7 @@
1817
_discover_tests,
1918
_ensure_account_id,
2019
_get_questionary_style,
20+
load_module_from_file_path,
2121
_normalize_evaluator_id,
2222
_prompt_select,
2323
)
@@ -120,13 +120,8 @@ def _resolve_entry_to_qual_and_source(entry: str, cwd: str) -> tuple[str, str]:
120120
source_file_path = os.path.join(cwd, dotted_as_path)
121121

122122
# Load the module from the file path
123-
spec = importlib.util.spec_from_file_location(Path(source_file_path).stem, source_file_path)
124-
if not spec or not spec.loader:
125-
raise ValueError(f"Unable to load module from path: {source_file_path}")
126-
module = importlib.util.module_from_spec(spec)
127-
sys.modules[spec.name] = module
128-
spec.loader.exec_module(module) # type: ignore[attr-defined]
129-
module_name = spec.name
123+
module = load_module_from_file_path(source_file_path)
124+
module_name = getattr(module, "__name__", Path(source_file_path).stem)
130125

131126
if not hasattr(module, func):
132127
raise ValueError(f"Function '{func}' not found in module '{module_name}'")

eval_protocol/cli_commands/utils.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from types import ModuleType
2+
3+
14
import os
25
import ast
36
import sys
@@ -6,23 +9,46 @@
69
import argparse
710
import typing
811
import types
12+
import importlib.util
913
from dataclasses import dataclass
1014
from pathlib import Path
11-
from typing import Any, List, Optional, is_typeddict
15+
from typing import Any, List, Optional
1216
import typing_extensions
1317
import inspect
1418
from collections.abc import Callable
1519
import pytest
1620

1721
from ..auth import (
18-
get_fireworks_account_id,
1922
get_fireworks_api_base,
2023
get_fireworks_api_key,
2124
verify_api_key_and_get_account_id,
2225
)
2326
from ..fireworks_rft import _map_api_host_to_app_host
2427

2528

29+
def load_module_from_file_path(source_file_path: str) -> ModuleType:
    """Load a Python module from an absolute/relative filesystem path.

    This mirrors the CLI behavior used by `upload.py` and `create_rft.py`:
    - module name is derived from the file stem (e.g. /a/b/foo.py -> foo)
    - the module is inserted into sys.modules under that name before exec

    Args:
        source_file_path: Path to a ``.py`` file on disk.

    Returns:
        The fully executed module object.

    Raises:
        ValueError: If the path does not exist, is not a ``.py`` file, or no
            import spec/loader could be created for it.
        Exception: Whatever the module body itself raises during execution.
    """
    abs_path = os.path.abspath(source_file_path)
    if not os.path.isfile(abs_path):
        raise ValueError(f"File not found: {abs_path}")
    if not abs_path.endswith(".py"):
        raise ValueError(f"Expected a .py file path, got: {abs_path}")

    module_name = Path(abs_path).stem
    spec = importlib.util.spec_from_file_location(module_name, abs_path)
    if not spec or not spec.loader:
        raise ValueError(f"Unable to load module from path: {abs_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before exec so the module can reference itself during import
    # (e.g. dataclasses, pickling, decorators that look up sys.modules).
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)  # type: ignore[attr-defined]
    except BaseException:
        # Don't leave a half-initialized module in sys.modules on failure;
        # this follows importlib's recommended manual-loading pattern and
        # prevents a later `import <stem>` from silently picking it up.
        sys.modules.pop(spec.name, None)
        raise
    return module
50+
51+
2652
def _get_questionary_style():
2753
"""Get the shared questionary style for CLI prompts - minimal and clean."""
2854
try:

tests/test_cli_create_rft.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,3 +1206,95 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
12061206
assert captured["jsonl_path"] != str(inferred_jsonl)
12071207
# And because --dataset-jsonl was provided, we should never call the input_dataset extractor
12081208
assert calls["input_dataset"] == 0
1209+
1210+
1211+
def test_create_rft_transforms_raw_input_dataset_via_dataset_adapter_before_upload(rft_test_harness, monkeypatch):
    """End-to-end check that `create_rft` runs the test's dataset_adapter.

    Writes a raw (non-EvaluationRow) JSONL plus a real @evaluation_test module
    whose adapter converts raw rows into EvaluationRows, then asserts that the
    file handed to `create_dataset_from_jsonl` is the transformed JSONL rather
    than the raw input.
    """
    # NOTE(review): rft_test_harness presumably yields a tmp project root
    # (pathlib.Path) with network stubs installed — confirm against conftest.
    project = rft_test_harness

    # Create a real @evaluation_test-decorated module so create_rft can extract __ep_params__.dataset_adapter
    metric_dir = project / "metric"
    metric_dir.mkdir(parents=True, exist_ok=True)

    # Raw rows deliberately do NOT match the EvaluationRow schema.
    raw_jsonl = metric_dir / "raw.jsonl"
    raw_jsonl.write_text('{"q":"hi","a":"ok"}\n{"q":"yo","a":"ok2"}\n', encoding="utf-8")

    test_file = metric_dir / "test_adapt.py"
    test_file.write_text(
        """
from typing import Any
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest import evaluation_test

def my_adapter(rows: list[dict[str, Any]]) -> list[EvaluationRow]:
    return [
        EvaluationRow(messages=[Message(role="user", content=r["q"])], ground_truth=r.get("a"))
        for r in rows
    ]

@evaluation_test(
    input_dataset=["raw.jsonl"],
    dataset_adapter=my_adapter,
    num_runs=1,
    max_dataset_rows=2,
    mode="pointwise",
)
def test_adapt(row: EvaluationRow) -> EvaluationRow:
    return row
""".lstrip(),
        encoding="utf-8",
    )

    # Discovery: exactly one test, and resolve_selected_test points to our module/function
    single_disc = SimpleNamespace(qualname="metric.test_adapt.test_adapt", file_path=str(test_file))
    monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc])
    monkeypatch.setattr(
        cr,
        "_resolve_selected_test",
        lambda project_root, evaluator_id, selected_tests=None: (str(test_file), "test_adapt"),
    )

    # Capture the JSONL path that would be uploaded instead of hitting the API.
    captured = {"jsonl_path": None}

    def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, display_name, jsonl_path):
        captured["jsonl_path"] = jsonl_path
        return dataset_id, {"name": f"accounts/{account_id}/datasets/{dataset_id}", "state": "UPLOADING"}

    monkeypatch.setattr(cr, "create_dataset_from_jsonl", _fake_create_dataset_from_jsonl)

    # Ensure upload path doesn't touch the network; job creation via stub_fireworks fixture
    args = argparse.Namespace(
        evaluator=None,
        yes=True,
        dry_run=False,
        force=False,
        env_file=None,
        dataset=None,
        dataset_jsonl=None,  # no explicit dataset: forces inference from input_dataset
        dataset_display_name=None,
        dataset_builder=None,
        base_model=None,
        warm_start_from="accounts/acct123/models/ft-abc123",
        output_model=None,
        n=None,
        max_tokens=None,
        learning_rate=None,
        batch_size=None,
        epochs=None,
        lora_rank=None,
        max_context_length=None,
        chunk_size=None,
        eval_auto_carveout=None,
        skip_validation=True,
        ignore_docker=False,
        docker_build_extra="",
        docker_run_extra="",
    )

    rc = cr.create_rft_command(args)
    assert rc == 0
    assert captured["jsonl_path"] is not None
    # Raw JSONL should NOT be uploaded; transformed EvaluationRow JSONL should be.
    assert os.path.abspath(captured["jsonl_path"]) != os.path.abspath(str(raw_jsonl))
    assert os.path.basename(captured["jsonl_path"]).endswith(".jsonl")
    # The transformed file should validate as EvaluationRow JSONL
    assert cr._validate_dataset_jsonl(captured["jsonl_path"])

0 commit comments

Comments
 (0)