Skip to content

Commit c1df8b5

Browse files
benjibc and xzrderek authored
auto select evaluators correctly (#323)
* auto select evaluators correctly
* add new test to verify dataset id and fix code
* try skipping if possible
* fix

Co-authored-by: Derek Xu <xzrderek@gmail.com>
1 parent cf59d13 commit c1df8b5

File tree

2 files changed

+528
-70
lines changed

2 files changed

+528
-70
lines changed

eval_protocol/cli_commands/create_rft.py

Lines changed: 214 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,87 @@
2323
from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source
2424

2525

26+
def _last_evaluator_paths(cwd: str) -> list[str]:
27+
return [
28+
os.path.join(cwd, ".eval_protocol", "last_evaluator.json"),
29+
os.path.expanduser(os.path.join("~", ".eval_protocol", "last_evaluator.json")),
30+
]
31+
32+
def _load_last_evaluator(cwd: str) -> Optional[str]:
    """Return the evaluator id recorded by a previous run, if any.

    Scans the candidate pointer files from ``_last_evaluator_paths`` in
    priority order (project root first, then home directory) and returns the
    first ``evaluator_id`` found.  Read and parse failures are treated as
    "no pointer" so the lookup stays best-effort.

    Args:
        cwd: Project root whose local pointer file is checked first.

    Returns:
        The stored evaluator id as a string, or ``None`` when no readable
        pointer exists.
    """
    import json

    for path in _last_evaluator_paths(cwd):
        try:
            if os.path.isfile(path):
                with open(path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                if isinstance(data, dict) and data.get("evaluator_id"):
                    return str(data["evaluator_id"])
        except (OSError, ValueError):
            # Narrowed from a bare `except Exception`: only I/O errors and
            # JSON decode errors (JSONDecodeError is a ValueError) are
            # expected here; anything else should surface as a real bug.
            continue
    return None
48+
49+
def _save_last_evaluator(cwd: str, evaluator_id: str) -> None:
50+
import json
51+
52+
base = os.path.join(cwd, ".eval_protocol")
53+
try:
54+
os.makedirs(base, exist_ok=True)
55+
with open(os.path.join(base, "last_evaluator.json"), "w", encoding="utf-8") as f:
56+
json.dump({"evaluator_id": evaluator_id, "ts": time.time()}, f)
57+
except Exception:
58+
# best-effort only
59+
pass
60+
61+
62+
def _gather_evaluator_traces(cwd: str) -> list[dict]:
63+
roots = [
64+
os.path.join(cwd, ".eval_protocol", "evaluators"),
65+
os.path.expanduser(os.path.join("~", ".eval_protocol", "evaluators")),
66+
]
67+
records: list[dict] = []
68+
for root in roots:
69+
if os.path.isdir(root):
70+
for name in os.listdir(root):
71+
if name.endswith(".json"):
72+
full = os.path.join(root, name)
73+
try:
74+
mtime = os.path.getmtime(full)
75+
except Exception:
76+
mtime = 0.0
77+
records.append({"id": name[:-5], "path": full, "mtime": mtime})
78+
# dedupe by id keeping most recent mtime
79+
dedup: dict[str, dict] = {}
80+
for rec in records:
81+
cur = dedup.get(rec["id"])
82+
if not cur or rec["mtime"] > cur["mtime"]:
83+
dedup[rec["id"]] = rec
84+
return list(dedup.values())
85+
86+
87+
def _prompt_select_evaluator(candidates: list[dict]) -> Optional[str]:
88+
print("\nMultiple evaluators detected. Select one:")
89+
ordered = sorted(candidates, key=lambda x: -x["mtime"])
90+
for i, c in enumerate(ordered, start=1):
91+
print(f" {i}) {c['id']} (from {c['path']})")
92+
try:
93+
choice = input("Enter a number (or press Enter to cancel): ").strip()
94+
except KeyboardInterrupt:
95+
print("\nCancelled.")
96+
return None
97+
if not choice or not choice.isdigit():
98+
return None
99+
n = int(choice)
100+
if 1 <= n <= len(ordered):
101+
sel = ordered[n - 1]["id"]
102+
print(f"✓ Using evaluator: {sel}")
103+
return sel
104+
return None
105+
106+
26107
def _ensure_account_id() -> Optional[str]:
27108
account_id = get_fireworks_account_id()
28109
api_key = get_fireworks_api_key()
@@ -240,6 +321,8 @@ def _build_trimmed_dataset_id(evaluator_id: str) -> str:
240321
if not base:
241322
base = "dataset"
242323
# Ensure first char is a letter
324+
if not base:
325+
base = "dataset"
243326
if not base[0].isalpha():
244327
base = f"eval-{base}"
245328
if len(base) > max_base_len:
@@ -248,14 +331,27 @@ def _build_trimmed_dataset_id(evaluator_id: str) -> str:
248331
return f"{base}{suffix}"
249332

250333

251-
def _auto_select_evaluator_id(cwd: str) -> Optional[str]:
252-
# Try local traces
253-
traces_dir = os.path.join(cwd, ".eval_protocol", "evaluators")
254-
if os.path.isdir(traces_dir):
255-
candidates = [f[:-5] for f in os.listdir(traces_dir) if f.endswith(".json")]
256-
if len(candidates) == 1:
257-
return candidates[0]
258-
# Fall back to discovering a single evaluation_test
334+
def _auto_select_evaluator_id(cwd: str, *, non_interactive: bool = False) -> Optional[str]:
335+
# 1) Use last used pointer if available
336+
last = _load_last_evaluator(cwd)
337+
if last:
338+
return last
339+
340+
# 2) Look for evaluator traces in project and home
341+
traces = _gather_evaluator_traces(cwd)
342+
if len(traces) == 1:
343+
return traces[0]["id"]
344+
if len(traces) > 1:
345+
if non_interactive:
346+
sel = sorted(traces, key=lambda x: -x["mtime"])[0]["id"]
347+
print(f"⚠️ Multiple evaluators found; using most recent: {sel}. Override with --evaluator-id.")
348+
return sel
349+
chosen = _prompt_select_evaluator(traces)
350+
if chosen:
351+
return chosen
352+
return None
353+
354+
# 3) Fall back to discovering a single evaluation_test
259355
tests = _discover_tests(cwd)
260356
if len(tests) == 1:
261357
qualname, source_file_path = tests[0].qualname, tests[0].file_path
@@ -348,81 +444,129 @@ def create_rft_command(args) -> int:
348444
# Resolve evaluator id if omitted
349445
project_root = os.getcwd()
350446
if not evaluator_id:
351-
evaluator_id = _auto_select_evaluator_id(project_root)
447+
evaluator_id = _auto_select_evaluator_id(project_root, non_interactive=non_interactive)
352448
if not evaluator_id:
353449
print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
354450
return 1
355-
356451
# Resolve evaluator resource name to fully-qualified format required by API
357452
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"
358453

454+
# Optional short-circuit: if evaluator already exists and not forcing, skip upload path
455+
skip_upload = False
456+
if not force:
457+
try:
458+
headers = {
459+
"Authorization": f"Bearer {api_key}",
460+
"Content-Type": "application/json",
461+
"User-Agent": get_user_agent(),
462+
}
463+
resp = requests.get(f"{api_base}/v1/{evaluator_resource_name}", headers=headers, timeout=10)
464+
if resp.ok:
465+
state = resp.json().get("state", "STATE_UNSPECIFIED")
466+
print(f"✓ Evaluator exists (state: {state}). Skipping upload (use --force to overwrite).")
467+
# Poll for ACTIVE before proceeding
468+
print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...")
469+
if not _poll_evaluator_status(
470+
evaluator_resource_name=evaluator_resource_name,
471+
api_key=api_key,
472+
api_base=api_base,
473+
timeout_minutes=10,
474+
):
475+
app_base = _map_api_host_to_app_host(api_base)
476+
evaluator_slug = _extract_terminal_segment(evaluator_id)
477+
dashboard_url = f"{app_base}/dashboard/evaluators/{evaluator_slug}"
478+
print("\n❌ Evaluator is not ready within the timeout period.")
479+
print(f"📊 Please check the evaluator status at: {dashboard_url}")
480+
print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
481+
return 1
482+
_save_last_evaluator(project_root, evaluator_id)
483+
skip_upload = True
484+
except requests.exceptions.RequestException:
485+
pass
486+
359487
# Ensure evaluator exists by invoking the upload flow programmatically
360-
try:
361-
from .upload import upload_command
488+
if not skip_upload:
489+
try:
490+
from .upload import upload_command
362491

363-
tests = _discover_tests(project_root)
364-
selected_entry: Optional[str] = None
365-
if len(tests) == 1:
366-
func_name = tests[0].qualname.split(".")[-1]
367-
abs_path = os.path.abspath(tests[0].file_path)
368-
try:
369-
rel = os.path.relpath(abs_path, project_root)
370-
except Exception:
371-
rel = abs_path
372-
selected_entry = f"{rel}::{func_name}"
373-
else:
374-
# Try to match evaluator_id to a discovered test's normalized ID
375-
for t in tests:
376-
func_name = t.qualname.split(".")[-1]
377-
source_file_name = os.path.splitext(os.path.basename(t.file_path))[0]
378-
candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
379-
if candidate == evaluator_id:
380-
abs_path = os.path.abspath(t.file_path)
381-
try:
382-
rel = os.path.relpath(abs_path, project_root)
383-
except Exception:
384-
rel = abs_path
385-
selected_entry = f"{rel}::{func_name}"
386-
break
387-
388-
upload_args = argparse.Namespace(
389-
path=project_root,
390-
entry=selected_entry,
391-
id=evaluator_id,
392-
display_name=None,
393-
description=None,
394-
force=force, # Pass through the --force flag
395-
yes=True,
396-
env_file=None, # Add the new env_file parameter
397-
)
492+
tests = _discover_tests(project_root)
493+
selected_entry: Optional[str] = None
494+
if len(tests) == 1:
495+
func_name = tests[0].qualname.split(".")[-1]
496+
abs_path = os.path.abspath(tests[0].file_path)
497+
try:
498+
rel = os.path.relpath(abs_path, project_root)
499+
except Exception:
500+
rel = abs_path
501+
selected_entry = f"{rel}::{func_name}"
502+
else:
503+
# Try to match evaluator_id to a discovered test's normalized ID
504+
for t in tests:
505+
func_name = t.qualname.split(".")[-1]
506+
source_file_name = os.path.splitext(os.path.basename(t.file_path))[0]
507+
candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
508+
if candidate == evaluator_id:
509+
abs_path = os.path.abspath(t.file_path)
510+
try:
511+
rel = os.path.relpath(abs_path, project_root)
512+
except Exception:
513+
rel = abs_path
514+
selected_entry = f"{rel}::{func_name}"
515+
break
516+
# If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators
517+
if selected_entry is None and len(tests) > 1:
518+
print(
519+
f"Error: Multiple evaluation tests found, and the selected evaluator_id {evaluator_id} does not match any discovered test.\n"
520+
" Please re-run specifying the evaluator id.\n"
521+
" Hints:\n"
522+
" - eval-protocol create rft --evaluator-id <existing-evaluator-id>\n"
523+
)
524+
return 1
398525

399-
if force:
400-
print(f"🔄 Force flag enabled - will overwrite existing evaluator '{evaluator_id}'")
526+
upload_args = argparse.Namespace(
527+
path=project_root,
528+
entry=selected_entry,
529+
id=evaluator_id,
530+
display_name=None,
531+
description=None,
532+
force=force, # Pass through the --force flag
533+
yes=True,
534+
env_file=None, # Add the new env_file parameter
535+
)
401536

402-
rc = upload_command(upload_args)
403-
if rc == 0:
404-
print(f"✓ Uploaded/ensured evaluator: {evaluator_id}")
537+
if force:
538+
print(f"🔄 Force flag enabled - will overwrite existing evaluator '{evaluator_id}'")
405539

406-
# Poll for evaluator status
407-
print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...")
408-
is_active = _poll_evaluator_status(
409-
evaluator_resource_name=evaluator_resource_name, api_key=api_key, api_base=api_base, timeout_minutes=10
410-
)
540+
rc = upload_command(upload_args)
541+
if rc == 0:
542+
print(f"✓ Uploaded/ensured evaluator: {evaluator_id}")
411543

412-
if not is_active:
413-
# Print helpful message with dashboard link
414-
app_base = _map_api_host_to_app_host(api_base)
415-
evaluator_slug = _extract_terminal_segment(evaluator_id)
416-
dashboard_url = f"{app_base}/dashboard/evaluators/{evaluator_slug}"
544+
# Poll for evaluator status
545+
print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...")
546+
is_active = _poll_evaluator_status(
547+
evaluator_resource_name=evaluator_resource_name,
548+
api_key=api_key,
549+
api_base=api_base,
550+
timeout_minutes=10,
551+
)
417552

418-
print("\n❌ Evaluator is not ready within the timeout period.")
419-
print(f"📊 Please check the evaluator status at: {dashboard_url}")
420-
print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
421-
return 1
422-
else:
423-
print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.")
424-
except Exception as e:
425-
print(f"Warning: Failed to upload evaluator automatically: {e}")
553+
if not is_active:
554+
# Print helpful message with dashboard link
555+
app_base = _map_api_host_to_app_host(api_base)
556+
evaluator_slug = _extract_terminal_segment(evaluator_id)
557+
dashboard_url = f"{app_base}/dashboard/evaluators/{evaluator_slug}"
558+
559+
print("\n❌ Evaluator is not ready within the timeout period.")
560+
print(f"📊 Please check the evaluator status at: {dashboard_url}")
561+
print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
562+
return 1
563+
else:
564+
# Only persist last-used evaluator after successful ensure + ACTIVE
565+
_save_last_evaluator(project_root, evaluator_id)
566+
else:
567+
print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.")
568+
except Exception as e:
569+
print(f"Warning: Failed to upload evaluator automatically: {e}")
426570

427571
# Determine dataset id and materialization path
428572
dataset_id = getattr(args, "dataset_id", None)

0 commit comments

Comments
 (0)