|
23 | 23 | from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source |
24 | 24 |
|
25 | 25 |
|
| 26 | +def _last_evaluator_paths(cwd: str) -> list[str]: |
| 27 | + return [ |
| 28 | + os.path.join(cwd, ".eval_protocol", "last_evaluator.json"), |
| 29 | + os.path.expanduser(os.path.join("~", ".eval_protocol", "last_evaluator.json")), |
| 30 | + ] |
| 31 | + |
| 32 | + |
def _load_last_evaluator(cwd: str) -> Optional[str]:
    """Return the evaluator id recorded by a previous run, or None.

    Checks each candidate pointer file (project-local first, then the
    home-directory fallback) and returns the first valid ``evaluator_id``
    found. This is a convenience cache: missing, unreadable, or malformed
    files are skipped rather than raised.
    """
    import json

    for path in _last_evaluator_paths(cwd):
        try:
            if os.path.isfile(path):
                with open(path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                if isinstance(data, dict) and data.get("evaluator_id"):
                    return str(data["evaluator_id"])
        except (OSError, ValueError):
            # OSError: file vanished / unreadable; ValueError covers
            # json.JSONDecodeError. Narrower than the previous broad
            # `except Exception` so real programming errors still surface.
            continue
    return None
| 47 | + |
| 48 | + |
| 49 | +def _save_last_evaluator(cwd: str, evaluator_id: str) -> None: |
| 50 | + import json |
| 51 | + |
| 52 | + base = os.path.join(cwd, ".eval_protocol") |
| 53 | + try: |
| 54 | + os.makedirs(base, exist_ok=True) |
| 55 | + with open(os.path.join(base, "last_evaluator.json"), "w", encoding="utf-8") as f: |
| 56 | + json.dump({"evaluator_id": evaluator_id, "ts": time.time()}, f) |
| 57 | + except Exception: |
| 58 | + # best-effort only |
| 59 | + pass |
| 60 | + |
| 61 | + |
| 62 | +def _gather_evaluator_traces(cwd: str) -> list[dict]: |
| 63 | + roots = [ |
| 64 | + os.path.join(cwd, ".eval_protocol", "evaluators"), |
| 65 | + os.path.expanduser(os.path.join("~", ".eval_protocol", "evaluators")), |
| 66 | + ] |
| 67 | + records: list[dict] = [] |
| 68 | + for root in roots: |
| 69 | + if os.path.isdir(root): |
| 70 | + for name in os.listdir(root): |
| 71 | + if name.endswith(".json"): |
| 72 | + full = os.path.join(root, name) |
| 73 | + try: |
| 74 | + mtime = os.path.getmtime(full) |
| 75 | + except Exception: |
| 76 | + mtime = 0.0 |
| 77 | + records.append({"id": name[:-5], "path": full, "mtime": mtime}) |
| 78 | + # dedupe by id keeping most recent mtime |
| 79 | + dedup: dict[str, dict] = {} |
| 80 | + for rec in records: |
| 81 | + cur = dedup.get(rec["id"]) |
| 82 | + if not cur or rec["mtime"] > cur["mtime"]: |
| 83 | + dedup[rec["id"]] = rec |
| 84 | + return list(dedup.values()) |
| 85 | + |
| 86 | + |
| 87 | +def _prompt_select_evaluator(candidates: list[dict]) -> Optional[str]: |
| 88 | + print("\nMultiple evaluators detected. Select one:") |
| 89 | + ordered = sorted(candidates, key=lambda x: -x["mtime"]) |
| 90 | + for i, c in enumerate(ordered, start=1): |
| 91 | + print(f" {i}) {c['id']} (from {c['path']})") |
| 92 | + try: |
| 93 | + choice = input("Enter a number (or press Enter to cancel): ").strip() |
| 94 | + except KeyboardInterrupt: |
| 95 | + print("\nCancelled.") |
| 96 | + return None |
| 97 | + if not choice or not choice.isdigit(): |
| 98 | + return None |
| 99 | + n = int(choice) |
| 100 | + if 1 <= n <= len(ordered): |
| 101 | + sel = ordered[n - 1]["id"] |
| 102 | + print(f"✓ Using evaluator: {sel}") |
| 103 | + return sel |
| 104 | + return None |
| 105 | + |
| 106 | + |
26 | 107 | def _ensure_account_id() -> Optional[str]: |
27 | 108 | account_id = get_fireworks_account_id() |
28 | 109 | api_key = get_fireworks_api_key() |
@@ -248,14 +329,27 @@ def _build_trimmed_dataset_id(evaluator_id: str) -> str: |
248 | 329 | return f"{base}{suffix}" |
249 | 330 |
|
250 | 331 |
|
251 | | -def _auto_select_evaluator_id(cwd: str) -> Optional[str]: |
252 | | - # Try local traces |
253 | | - traces_dir = os.path.join(cwd, ".eval_protocol", "evaluators") |
254 | | - if os.path.isdir(traces_dir): |
255 | | - candidates = [f[:-5] for f in os.listdir(traces_dir) if f.endswith(".json")] |
256 | | - if len(candidates) == 1: |
257 | | - return candidates[0] |
258 | | - # Fall back to discovering a single evaluation_test |
| 332 | +def _auto_select_evaluator_id(cwd: str, *, non_interactive: bool = False) -> Optional[str]: |
| 333 | + # 1) Use last used pointer if available |
| 334 | + last = _load_last_evaluator(cwd) |
| 335 | + if last: |
| 336 | + return last |
| 337 | + |
| 338 | + # 2) Look for evaluator traces in project and home |
| 339 | + traces = _gather_evaluator_traces(cwd) |
| 340 | + if len(traces) == 1: |
| 341 | + return traces[0]["id"] |
| 342 | + if len(traces) > 1: |
| 343 | + if non_interactive: |
| 344 | + sel = sorted(traces, key=lambda x: -x["mtime"])[0]["id"] |
| 345 | + print(f"⚠️ Multiple evaluators found; using most recent: {sel}. Override with --evaluator-id.") |
| 346 | + return sel |
| 347 | + chosen = _prompt_select_evaluator(traces) |
| 348 | + if chosen: |
| 349 | + return chosen |
| 350 | + return None |
| 351 | + |
| 352 | + # 3) Fall back to discovering a single evaluation_test |
259 | 353 | tests = _discover_tests(cwd) |
260 | 354 | if len(tests) == 1: |
261 | 355 | qualname, source_file_path = tests[0].qualname, tests[0].file_path |
@@ -348,10 +442,12 @@ def create_rft_command(args) -> int: |
348 | 442 | # Resolve evaluator id if omitted |
349 | 443 | project_root = os.getcwd() |
350 | 444 | if not evaluator_id: |
351 | | - evaluator_id = _auto_select_evaluator_id(project_root) |
| 445 | + evaluator_id = _auto_select_evaluator_id(project_root, non_interactive=non_interactive) |
352 | 446 | if not evaluator_id: |
353 | 447 | print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.") |
354 | 448 | return 1 |
| 449 | + # Persist last selected/used evaluator for next runs |
| 450 | + _save_last_evaluator(project_root, evaluator_id) |
355 | 451 |
|
356 | 452 | # Resolve evaluator resource name to fully-qualified format required by API |
357 | 453 | evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}" |
|
0 commit comments