Skip to content

Commit fa46b85

Browse files
committed
change create rft command to use the selector
1 parent 69e53a7 commit fa46b85

File tree

2 files changed

+51
-265
lines changed

2 files changed

+51
-265
lines changed

eval_protocol/cli_commands/create_rft.py

Lines changed: 33 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -23,87 +23,6 @@
2323
from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source
2424

2525

26-
def _last_evaluator_paths(cwd: str) -> list[str]:
27-
return [
28-
os.path.join(cwd, ".eval_protocol", "last_evaluator.json"),
29-
os.path.expanduser(os.path.join("~", ".eval_protocol", "last_evaluator.json")),
30-
]
31-
32-
33-
def _load_last_evaluator(cwd: str) -> Optional[str]:
    """Return the most recently used evaluator id, if a pointer file exists.

    Checks the candidate locations from ``_last_evaluator_paths`` in order
    and returns the first valid ``evaluator_id`` found, or ``None`` when no
    readable pointer file exists.
    """
    import json

    for p in _last_evaluator_paths(cwd):
        try:
            if os.path.isfile(p):
                with open(p, "r", encoding="utf-8") as f:
                    data = json.load(f)
                if isinstance(data, dict) and data.get("evaluator_id"):
                    return str(data["evaluator_id"])
        except (OSError, ValueError):
            # Unreadable or malformed pointer file (json.JSONDecodeError is a
            # ValueError subclass). This lookup is best-effort, so move on to
            # the next candidate location. Narrowed from a bare
            # ``except Exception`` so real programming errors still surface.
            continue
    return None
47-
48-
49-
def _save_last_evaluator(cwd: str, evaluator_id: str) -> None:
50-
import json
51-
52-
base = os.path.join(cwd, ".eval_protocol")
53-
try:
54-
os.makedirs(base, exist_ok=True)
55-
with open(os.path.join(base, "last_evaluator.json"), "w", encoding="utf-8") as f:
56-
json.dump({"evaluator_id": evaluator_id, "ts": time.time()}, f)
57-
except Exception:
58-
# best-effort only
59-
pass
60-
61-
62-
def _gather_evaluator_traces(cwd: str) -> list[dict]:
63-
roots = [
64-
os.path.join(cwd, ".eval_protocol", "evaluators"),
65-
os.path.expanduser(os.path.join("~", ".eval_protocol", "evaluators")),
66-
]
67-
records: list[dict] = []
68-
for root in roots:
69-
if os.path.isdir(root):
70-
for name in os.listdir(root):
71-
if name.endswith(".json"):
72-
full = os.path.join(root, name)
73-
try:
74-
mtime = os.path.getmtime(full)
75-
except Exception:
76-
mtime = 0.0
77-
records.append({"id": name[:-5], "path": full, "mtime": mtime})
78-
# dedupe by id keeping most recent mtime
79-
dedup: dict[str, dict] = {}
80-
for rec in records:
81-
cur = dedup.get(rec["id"])
82-
if not cur or rec["mtime"] > cur["mtime"]:
83-
dedup[rec["id"]] = rec
84-
return list(dedup.values())
85-
86-
87-
def _prompt_select_evaluator(candidates: list[dict]) -> Optional[str]:
    """Interactively ask the user to pick one evaluator from *candidates*.

    Candidates are listed newest-first by mtime. Returns the chosen
    evaluator id, or ``None`` on cancel (empty/non-numeric input,
    out-of-range choice, or Ctrl-C).
    """
    print("\nMultiple evaluators detected. Select one:")
    by_recency = sorted(candidates, key=lambda rec: -rec["mtime"])
    for idx, rec in enumerate(by_recency, start=1):
        print(f" {idx}) {rec['id']} (from {rec['path']})")
    try:
        raw = input("Enter a number (or press Enter to cancel): ").strip()
    except KeyboardInterrupt:
        print("\nCancelled.")
        return None
    # "".isdigit() is False, so empty input falls through to cancel as well.
    if not raw.isdigit():
        return None
    picked = int(raw)
    if not (1 <= picked <= len(by_recency)):
        return None
    evaluator_id = by_recency[picked - 1]["id"]
    print(f"✓ Using evaluator: {evaluator_id}")
    return evaluator_id
105-
106-
10726
def _ensure_account_id() -> Optional[str]:
10827
account_id = get_fireworks_account_id()
10928
api_key = get_fireworks_api_key()
@@ -331,37 +250,6 @@ def _build_trimmed_dataset_id(evaluator_id: str) -> str:
331250
return f"{base}{suffix}"
332251

333252

334-
def _auto_select_evaluator_id(cwd: str, *, non_interactive: bool = False) -> Optional[str]:
    """Infer an evaluator id when --evaluator-id was not given.

    Resolution order: (1) the last-used pointer file, (2) evaluator traces
    found on disk — prompting the user, or silently picking the most recent
    when *non_interactive* — then (3) a uniquely discovered
    ``@evaluation_test``. Returns ``None`` when no unambiguous choice exists.
    """
    # 1) Last-used pointer wins outright.
    remembered = _load_last_evaluator(cwd)
    if remembered:
        return remembered

    # 2) Evaluator traces written by previous runs (project and home).
    traces = _gather_evaluator_traces(cwd)
    if len(traces) == 1:
        return traces[0]["id"]
    if len(traces) > 1:
        if non_interactive:
            most_recent = sorted(traces, key=lambda rec: -rec["mtime"])[0]["id"]
            print(f"⚠️ Multiple evaluators found; using most recent: {most_recent}. Override with --evaluator-id.")
            return most_recent
        # Interactive: a cancelled prompt means "no selection".
        return _prompt_select_evaluator(traces) or None

    # 3) Fall back to a single discovered @evaluation_test.
    tests = _discover_tests(cwd)
    if len(tests) != 1:
        return None
    only = tests[0]
    func_name = only.qualname.split(".")[-1]
    file_stem = os.path.splitext(os.path.basename(only.file_path))[0]
    return _normalize_evaluator_id(f"{file_stem}-{func_name}")
363-
364-
365253
def _poll_evaluator_status(
366254
evaluator_resource_name: str, api_key: str, api_base: str, timeout_minutes: int = 10
367255
) -> bool:
@@ -441,13 +329,40 @@ def create_rft_command(args) -> int:
441329

442330
api_base = get_fireworks_api_base()
443331

444-
# Resolve evaluator id if omitted
332+
# Resolve evaluator id/entry if omitted (reuse upload's selector flow)
445333
project_root = os.getcwd()
334+
preselected_entry: Optional[str] = None
446335
if not evaluator_id:
447-
evaluator_id = _auto_select_evaluator_id(project_root, non_interactive=non_interactive)
448-
if not evaluator_id:
449-
print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
336+
print("Scanning for evaluation tests...")
337+
tests = _discover_tests(project_root)
338+
if not tests:
339+
print("No evaluation tests found.")
340+
print("\nHint: Make sure your tests use the @evaluation_test decorator.")
341+
return 1
342+
# Always interactive selection here (no implicit quiet unless --evaluator-id was provided)
343+
try:
344+
from .upload import _prompt_select # reuse the same selector UX as 'upload'
345+
346+
selected_tests = _prompt_select(tests, non_interactive=False)
347+
except Exception:
348+
print("Error: Failed to open selector UI. Please pass --evaluator-id or --entry explicitly.")
450349
return 1
350+
if not selected_tests:
351+
print("No tests selected.")
352+
return 1
353+
if len(selected_tests) != 1:
354+
print("Error: Please select exactly one evaluation test for 'create rft'.")
355+
return 1
356+
chosen = selected_tests[0]
357+
func_name = chosen.qualname.split(".")[-1]
358+
abs_path = os.path.abspath(chosen.file_path)
359+
try:
360+
rel = os.path.relpath(abs_path, project_root)
361+
except Exception:
362+
rel = abs_path
363+
preselected_entry = f"{rel}::{func_name}"
364+
source_file_name = os.path.splitext(os.path.basename(chosen.file_path))[0]
365+
evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
451366
# Resolve evaluator resource name to fully-qualified format required by API
452367
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"
453368

@@ -479,7 +394,6 @@ def create_rft_command(args) -> int:
479394
print(f"📊 Please check the evaluator status at: {dashboard_url}")
480395
print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
481396
return 1
482-
_save_last_evaluator(project_root, evaluator_id)
483397
skip_upload = True
484398
except requests.exceptions.RequestException:
485399
pass
@@ -561,8 +475,8 @@ def create_rft_command(args) -> int:
561475
print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
562476
return 1
563477
else:
564-
# Only persist last-used evaluator after successful ensure + ACTIVE
565-
_save_last_evaluator(project_root, evaluator_id)
478+
# Evaluator ACTIVE; proceed
479+
pass
566480
else:
567481
print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.")
568482
except Exception as e:

0 commit comments

Comments
 (0)