55import os
66import sys
77import time
8- from typing import Any , Dict , Optional
8+ from typing import Any , Callable , Dict , Optional
99import inspect
1010import requests
11+ import tempfile
1112from pydantic import ValidationError
1213
1314from ..auth import get_fireworks_api_base , get_fireworks_api_key
14- from ..common_utils import get_user_agent
15+ from ..common_utils import get_user_agent , load_jsonl
1516from ..fireworks_rft import (
16- build_default_output_model ,
1717 create_dataset_from_jsonl ,
1818 detect_dataset_builder ,
1919 materialize_dataset_via_builder ,
3131 _normalize_evaluator_id ,
3232 _print_links ,
3333 _resolve_selected_test ,
34+ load_module_from_file_path ,
3435)
3536from .local_test import run_evaluator_test
3637
3738from fireworks import Fireworks
3839
3940
41+ def _extract_dataset_adapter (
42+ test_file_path : str , test_func_name : str
43+ ) -> Optional [Callable [[list [dict [str , Any ]]], Any ]]:
44+ """Extract dataset_adapter from an @evaluation_test wrapper via __ep_params__."""
45+ try :
46+ module = load_module_from_file_path (test_file_path )
47+ wrapper = getattr (module , test_func_name , None )
48+ if wrapper is None :
49+ return None
50+ ep_params = getattr (wrapper , "__ep_params__" , None )
51+ if ep_params is None :
52+ return None
53+ adapter = getattr (ep_params , "dataset_adapter" , None )
54+ if callable (adapter ):
55+ return adapter
56+ return None
57+ except Exception :
58+ return None
59+
60+
61+ def _maybe_transform_dataset_jsonl_via_adapter (
62+ project_root : str ,
63+ dataset_jsonl : str ,
64+ test_file_path : Optional [str ],
65+ test_func_name : Optional [str ],
66+ ) -> str :
67+ """Transform dataset_jsonl via the test's dataset_adapter (when available).
68+
69+ For RFT dataset uploads, we want the uploaded dataset to match what evaluation-time
70+ would run on. If the selected evaluation test provides a dataset_adapter, that
71+ adapter is treated as the source of truth for constructing EvaluationRows.
72+ """
73+ if not dataset_jsonl :
74+ return dataset_jsonl
75+
76+ if not test_file_path or not test_func_name :
77+ return dataset_jsonl
78+
79+ adapter = _extract_dataset_adapter (test_file_path , test_func_name )
80+ if not adapter :
81+ return dataset_jsonl
82+
83+ raw_rows : list [dict [str , Any ]] = load_jsonl (dataset_jsonl ) # type: ignore[assignment]
84+ adapted = adapter (raw_rows )
85+ if not isinstance (adapted , list ):
86+ raise ValueError ("dataset_adapter must return a list of EvaluationRow (or dicts parseable as EvaluationRow)." )
87+
88+ eval_rows : list [EvaluationRow ] = []
89+ for item in adapted :
90+ if isinstance (item , EvaluationRow ):
91+ eval_rows .append (item )
92+ else :
93+ eval_rows .append (EvaluationRow .model_validate (item ))
94+
95+ output_dir = os .path .join (project_root , ".ep_tmp" )
96+ os .makedirs (output_dir , exist_ok = True )
97+ with tempfile .NamedTemporaryFile (
98+ mode = "w" ,
99+ encoding = "utf-8" ,
100+ delete = False ,
101+ suffix = ".jsonl" ,
102+ prefix = "ep_rft_dataset_" ,
103+ dir = output_dir ,
104+ ) as f :
105+ for row in eval_rows :
106+ f .write (json .dumps (row .model_dump (mode = "json" ), ensure_ascii = False ) + "\n " )
107+ out_path = os .path .abspath (f .name )
108+ try :
109+ rel = os .path .relpath (out_path , project_root )
110+ except Exception :
111+ rel = out_path
112+ print (f"✓ Transformed dataset via dataset_adapter into EvaluationRow JSONL: { rel } ({ len (eval_rows )} rows)" )
113+ return out_path
114+
115+
40116def _extract_jsonl_from_dataloader (test_file_path : str , test_func_name : str ) -> Optional [str ]:
41117 """Import the test module and extract a JSONL path from data_loaders param if present.
42118
@@ -45,18 +121,10 @@ def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) ->
45121 relative to the directory of the test file.
46122 """
47123 try :
48- import importlib .util
49- from pathlib import Path
50-
51- spec = importlib .util .spec_from_file_location (Path (test_file_path ).stem , test_file_path )
52- if not spec or not spec .loader :
124+ module = load_module_from_file_path (test_file_path )
125+ wrapper = getattr (module , test_func_name , None )
126+ if wrapper is None :
53127 return None
54- module = importlib .util .module_from_spec (spec )
55- sys .modules [spec .name ] = module
56- spec .loader .exec_module (module ) # type: ignore[attr-defined]
57- if not hasattr (module , test_func_name ):
58- return None
59- wrapper = getattr (module , test_func_name )
60128 marks = getattr (wrapper , "pytestmark" , [])
61129 for m in marks :
62130 if getattr (m , "name" , "" ) == "parametrize" :
@@ -105,18 +173,10 @@ def _extract_jsonl_from_input_dataset(test_file_path: str, test_func_name: str)
105173 of the test file.
106174 """
107175 try :
108- import importlib .util
109- from pathlib import Path
110-
111- spec = importlib .util .spec_from_file_location (Path (test_file_path ).stem , test_file_path )
112- if not spec or not spec .loader :
113- return None
114- module = importlib .util .module_from_spec (spec )
115- sys .modules [spec .name ] = module
116- spec .loader .exec_module (module ) # type: ignore[attr-defined]
117- if not hasattr (module , test_func_name ):
176+ module = load_module_from_file_path (test_file_path )
177+ wrapper = getattr (module , test_func_name , None )
178+ if wrapper is None :
118179 return None
119- wrapper = getattr (module , test_func_name )
120180 marks = getattr (wrapper , "pytestmark" , [])
121181 for m in marks :
122182 if getattr (m , "name" , "" ) == "parametrize" :
@@ -719,6 +779,16 @@ def create_rft_command(args) -> int:
719779 if dataset_jsonl is None and not dataset_id :
720780 return 1
721781
782+ # 2.5) If the selected evaluation test provides a dataset_adapter, always use it to
783+ # construct the EvaluationRow dataset that we upload for RFT.
784+ if dataset_jsonl is not None :
785+ dataset_jsonl = _maybe_transform_dataset_jsonl_via_adapter (
786+ project_root = project_root ,
787+ dataset_jsonl = dataset_jsonl ,
788+ test_file_path = selected_test_file_path ,
789+ test_func_name = selected_test_func_name ,
790+ )
791+
722792 # 3) Optional local validation
723793 if not skip_validation :
724794 # Dataset validation (JSONL must be EvaluationRow-compatible when present)
0 commit comments