commit a533dcb (parent fa46b85)

    fix bug

2 files changed, 205 insertions(+), 12 deletions(-)

eval_protocol/cli_commands/create_rft.py (2 additions, 11 deletions)

@@ -20,7 +20,7 @@
     create_dataset_from_jsonl,
     create_reinforcement_fine_tuning_job,
 )
-from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source
+from .upload import _discover_tests, _normalize_evaluator_id, _prompt_select


 def _ensure_account_id() -> Optional[str]:
@@ -331,7 +331,6 @@ def create_rft_command(args) -> int:

     # Resolve evaluator id/entry if omitted (reuse upload's selector flow)
     project_root = os.getcwd()
-    preselected_entry: Optional[str] = None
     if not evaluator_id:
         print("Scanning for evaluation tests...")
         tests = _discover_tests(project_root)
@@ -341,9 +340,7 @@ def create_rft_command(args) -> int:
             return 1
         # Always interactive selection here (no implicit quiet unless --evaluator-id was provided)
         try:
-            from .upload import _prompt_select  # reuse the same selector UX as 'upload'
-
-            selected_tests = _prompt_select(tests, non_interactive=False)
+            selected_tests = _prompt_select(tests, non_interactive=non_interactive)
         except Exception:
             print("Error: Failed to open selector UI. Please pass --evaluator-id or --entry explicitly.")
             return 1
@@ -355,12 +352,6 @@ def create_rft_command(args) -> int:
             return 1
         chosen = selected_tests[0]
         func_name = chosen.qualname.split(".")[-1]
-        abs_path = os.path.abspath(chosen.file_path)
-        try:
-            rel = os.path.relpath(abs_path, project_root)
-        except Exception:
-            rel = abs_path
-        preselected_entry = f"{rel}::{func_name}"
         source_file_name = os.path.splitext(os.path.basename(chosen.file_path))[0]
         evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
     # Resolve evaluator resource name to fully-qualified format required by API
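
Note on the fix above: the old code hardcoded non_interactive=False, so the selector UI always opened even when the caller wanted quiet mode; the new code forwards the command's own non_interactive flag (exercised via --yes in the tests below). A minimal sketch of the selector contract this relies on, assuming _prompt_select simply returns the discovered tests unchanged in quiet mode; the real implementation lives in eval_protocol/cli_commands/upload.py and may behave differently:

from types import SimpleNamespace
from typing import List

def _prompt_select_sketch(tests: List[SimpleNamespace], non_interactive: bool = False) -> List[SimpleNamespace]:
    # Assumed contract (hypothetical name, not the real helper): in quiet
    # mode, skip the UI and accept the discovered tests as-is; otherwise a
    # real implementation would open an interactive picker.
    if non_interactive:
        return tests
    raise NotImplementedError("interactive picker elided in this sketch")

# With the fix, a --yes run flows through as non_interactive=True:
tests = [SimpleNamespace(qualname="metric.test_one", file_path="metric/test_one.py")]
assert _prompt_select_sketch(tests, non_interactive=True) == tests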

tests/test_cli_create_rft_infer.py (203 additions, 1 deletion)

@@ -86,7 +86,7 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d

     # Assert dataset id derived from selected test: metric-test_single
     assert captured["dataset_id"] is not None
-    assert captured["dataset_id"].startswith("metric-test-single-dataset-")
+    assert captured["dataset_id"].startswith("test-single-test-single-dataset-")


 def test_create_rft_passes_matching_evaluator_id_and_entry_with_multiple_tests(tmp_path, monkeypatch):
@@ -184,3 +184,205 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
         + "-dataset-"
     )
     assert captured["dataset_id"].startswith(expected_prefix)
+
+
+def test_create_rft_interactive_selector_single_test(tmp_path, monkeypatch):
+    # Setup project
+    project = tmp_path / "proj"
+    project.mkdir()
+    monkeypatch.chdir(project)
+
+    # Single discovered test
+    test_file = project / "metric" / "test_one.py"
+    test_file.parent.mkdir(parents=True, exist_ok=True)
+    test_file.write_text("# one", encoding="utf-8")
+    single_disc = SimpleNamespace(qualname="metric.test_one", file_path=str(test_file))
+    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc])
+
+    # Environment
+    monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123")
+    monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
+
+    # Stub selector to return the single test; stub upload and polling
+    import eval_protocol.cli_commands.upload as upload_mod
+
+    monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
+    captured = {"id": None, "entry": None, "dataset_id": None}
+
+    def _fake_upload(ns):
+        captured["id"] = getattr(ns, "id", None)
+        captured["entry"] = getattr(ns, "entry", None)
+        return 0
+
+    monkeypatch.setattr(upload_mod, "upload_command", _fake_upload)
+    monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)
+
+    # Provide dataset jsonl
+    ds_path = project / "metric" / "dataset.jsonl"
+    ds_path.write_text('{"input":"x"}\n', encoding="utf-8")
+    monkeypatch.setattr(
+        cr,
+        "create_dataset_from_jsonl",
+        lambda account_id, api_key, api_base, dataset_id, display_name, jsonl_path: (
+            dataset_id,
+            {"name": f"accounts/{account_id}/datasets/{dataset_id}"},
+        ),
+    )
+    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})
+
+    # Run without evaluator_id; use --yes so selector returns tests directly (no UI)
+    import argparse
+
+    args = argparse.Namespace(
+        evaluator_id=None,
+        yes=True,
+        dry_run=False,
+        force=False,
+        env_file=None,
+        dataset_id=None,
+        dataset_jsonl=str(ds_path),
+        dataset_display_name=None,
+        dataset_builder=None,
+        base_model=None,
+        warm_start_from="accounts/acct123/models/ft-abc123",
+        output_model=None,
+        n=None,
+        max_tokens=None,
+        learning_rate=None,
+        batch_size=None,
+        epochs=None,
+        lora_rank=None,
+        max_context_length=None,
+        chunk_size=None,
+        eval_auto_carveout=None,
+    )
+
+    rc = cr.create_rft_command(args)
+    assert rc == 0
+    assert captured["id"] is not None
+    assert captured["entry"] is not None and captured["entry"].endswith("test_one.py::test_one")
+
+
+def test_create_rft_quiet_existing_evaluator_skips_upload(tmp_path, monkeypatch):
+    project = tmp_path / "proj"
+    project.mkdir()
+    monkeypatch.chdir(project)
+
+    # Env
+    monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123")
+    monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
+
+    # Mock evaluator exists and is ACTIVE
+    class _Resp:
+        ok = True
+
+        def json(self):
+            return {"state": "ACTIVE"}
+
+        def raise_for_status(self):
+            return None
+
+    monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp())
+
+    # Provide dataset via --dataset-jsonl so no test discovery needed
+    ds_path = project / "dataset.jsonl"
+    ds_path.write_text('{"input":"x"}\n', encoding="utf-8")
+    monkeypatch.setattr(
+        cr,
+        "create_dataset_from_jsonl",
+        lambda account_id, api_key, api_base, dataset_id, display_name, jsonl_path: (
+            dataset_id,
+            {"name": f"accounts/{account_id}/datasets/{dataset_id}"},
+        ),
+    )
+    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})
+
+    import argparse
+
+    args = argparse.Namespace(
+        evaluator_id="some-eval",
+        yes=True,
+        dry_run=False,
+        force=False,
+        env_file=None,
+        dataset_id=None,
+        dataset_jsonl=str(ds_path),
+        dataset_display_name=None,
+        dataset_builder=None,
+        base_model=None,
+        warm_start_from="accounts/acct123/models/ft-abc123",
+        output_model=None,
+        n=None,
+        max_tokens=None,
+        learning_rate=None,
+        batch_size=None,
+        epochs=None,
+        lora_rank=None,
+        max_context_length=None,
+        chunk_size=None,
+        eval_auto_carveout=None,
+    )
+
+    rc = cr.create_rft_command(args)
+    assert rc == 0
+
+
+def test_create_rft_quiet_new_evaluator_ambiguous_without_entry_errors(tmp_path, monkeypatch):
+    project = tmp_path / "proj"
+    project.mkdir()
+    monkeypatch.chdir(project)
+
+    # Env
+    monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123")
+    monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
+
+    # Evaluator does not exist (force path into upload section)
+    def _raise(*a, **k):
+        raise requests.exceptions.RequestException("nope")
+
+    import requests
+
+    monkeypatch.setattr(cr.requests, "get", _raise)
+
+    # Two discovered tests (ambiguous)
+    f1 = project / "a.py"
+    f2 = project / "b.py"
+    f1.write_text("# a", encoding="utf-8")
+    f2.write_text("# b", encoding="utf-8")
+    d1 = SimpleNamespace(qualname="a.test_one", file_path=str(f1))
+    d2 = SimpleNamespace(qualname="b.test_two", file_path=str(f2))
+    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [d1, d2])
+
+    import argparse
+
+    args = argparse.Namespace(
+        evaluator_id="some-eval",
+        yes=True,
+        dry_run=False,
+        force=False,
+        env_file=None,
+        dataset_id=None,
+        dataset_jsonl=str(project / "dataset.jsonl"),
+        dataset_display_name=None,
+        dataset_builder=None,
+        base_model=None,
+        warm_start_from="accounts/acct123/models/ft-abc123",
+        output_model=None,
+        n=None,
+        max_tokens=None,
+        learning_rate=None,
+        batch_size=None,
+        epochs=None,
+        lora_rank=None,
+        max_context_length=None,
+        chunk_size=None,
+        eval_auto_carveout=None,
+    )
+    # create the dataset file so we don't fail earlier
+    (project / "dataset.jsonl").write_text('{"input":"x"}\n', encoding="utf-8")
+
+    rc = cr.create_rft_command(args)
+    assert rc == 1
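
The three new tests cover the interactive selector path, quiet mode against an already-ACTIVE evaluator, and the ambiguous quiet-mode error. Assuming the repository's tests run under plain pytest from the repo root, they can be exercised in isolation with pytest's keyword filter, e.g.:

pytest tests/test_cli_create_rft_infer.py -k "interactive_selector or quiet" -q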
