|
23 | 23 | from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source |
24 | 24 |
|
25 | 25 |
|
26 | | -def _last_evaluator_paths(cwd: str) -> list[str]: |
27 | | - return [ |
28 | | - os.path.join(cwd, ".eval_protocol", "last_evaluator.json"), |
29 | | - os.path.expanduser(os.path.join("~", ".eval_protocol", "last_evaluator.json")), |
30 | | - ] |
31 | | - |
32 | | - |
33 | | -def _load_last_evaluator(cwd: str) -> Optional[str]: |
34 | | - import json |
35 | | - |
36 | | - for p in _last_evaluator_paths(cwd): |
37 | | - try: |
38 | | - if os.path.isfile(p): |
39 | | - with open(p, "r", encoding="utf-8") as f: |
40 | | - data = json.load(f) |
41 | | - if isinstance(data, dict) and data.get("evaluator_id"): |
42 | | - return str(data["evaluator_id"]) |
43 | | - except Exception: |
44 | | - # ignore and continue |
45 | | - pass |
46 | | - return None |
47 | | - |
48 | | - |
49 | | -def _save_last_evaluator(cwd: str, evaluator_id: str) -> None: |
50 | | - import json |
51 | | - |
52 | | - base = os.path.join(cwd, ".eval_protocol") |
53 | | - try: |
54 | | - os.makedirs(base, exist_ok=True) |
55 | | - with open(os.path.join(base, "last_evaluator.json"), "w", encoding="utf-8") as f: |
56 | | - json.dump({"evaluator_id": evaluator_id, "ts": time.time()}, f) |
57 | | - except Exception: |
58 | | - # best-effort only |
59 | | - pass |
60 | | - |
61 | | - |
62 | | -def _gather_evaluator_traces(cwd: str) -> list[dict]: |
63 | | - roots = [ |
64 | | - os.path.join(cwd, ".eval_protocol", "evaluators"), |
65 | | - os.path.expanduser(os.path.join("~", ".eval_protocol", "evaluators")), |
66 | | - ] |
67 | | - records: list[dict] = [] |
68 | | - for root in roots: |
69 | | - if os.path.isdir(root): |
70 | | - for name in os.listdir(root): |
71 | | - if name.endswith(".json"): |
72 | | - full = os.path.join(root, name) |
73 | | - try: |
74 | | - mtime = os.path.getmtime(full) |
75 | | - except Exception: |
76 | | - mtime = 0.0 |
77 | | - records.append({"id": name[:-5], "path": full, "mtime": mtime}) |
78 | | - # dedupe by id keeping most recent mtime |
79 | | - dedup: dict[str, dict] = {} |
80 | | - for rec in records: |
81 | | - cur = dedup.get(rec["id"]) |
82 | | - if not cur or rec["mtime"] > cur["mtime"]: |
83 | | - dedup[rec["id"]] = rec |
84 | | - return list(dedup.values()) |
85 | | - |
86 | | - |
87 | | -def _prompt_select_evaluator(candidates: list[dict]) -> Optional[str]: |
88 | | - print("\nMultiple evaluators detected. Select one:") |
89 | | - ordered = sorted(candidates, key=lambda x: -x["mtime"]) |
90 | | - for i, c in enumerate(ordered, start=1): |
91 | | - print(f" {i}) {c['id']} (from {c['path']})") |
92 | | - try: |
93 | | - choice = input("Enter a number (or press Enter to cancel): ").strip() |
94 | | - except KeyboardInterrupt: |
95 | | - print("\nCancelled.") |
96 | | - return None |
97 | | - if not choice or not choice.isdigit(): |
98 | | - return None |
99 | | - n = int(choice) |
100 | | - if 1 <= n <= len(ordered): |
101 | | - sel = ordered[n - 1]["id"] |
102 | | - print(f"✓ Using evaluator: {sel}") |
103 | | - return sel |
104 | | - return None |
105 | | - |
106 | | - |
107 | 26 | def _ensure_account_id() -> Optional[str]: |
108 | 27 | account_id = get_fireworks_account_id() |
109 | 28 | api_key = get_fireworks_api_key() |
@@ -331,37 +250,6 @@ def _build_trimmed_dataset_id(evaluator_id: str) -> str: |
331 | 250 | return f"{base}{suffix}" |
332 | 251 |
|
333 | 252 |
|
334 | | -def _auto_select_evaluator_id(cwd: str, *, non_interactive: bool = False) -> Optional[str]: |
335 | | - # 1) Use last used pointer if available |
336 | | - last = _load_last_evaluator(cwd) |
337 | | - if last: |
338 | | - return last |
339 | | - |
340 | | - # 2) Look for evaluator traces in project and home |
341 | | - traces = _gather_evaluator_traces(cwd) |
342 | | - if len(traces) == 1: |
343 | | - return traces[0]["id"] |
344 | | - if len(traces) > 1: |
345 | | - if non_interactive: |
346 | | - sel = sorted(traces, key=lambda x: -x["mtime"])[0]["id"] |
347 | | - print(f"⚠️ Multiple evaluators found; using most recent: {sel}. Override with --evaluator-id.") |
348 | | - return sel |
349 | | - chosen = _prompt_select_evaluator(traces) |
350 | | - if chosen: |
351 | | - return chosen |
352 | | - return None |
353 | | - |
354 | | - # 3) Fall back to discovering a single evaluation_test |
355 | | - tests = _discover_tests(cwd) |
356 | | - if len(tests) == 1: |
357 | | - qualname, source_file_path = tests[0].qualname, tests[0].file_path |
358 | | - test_func_name = qualname.split(".")[-1] |
359 | | - source_file_name = os.path.splitext(os.path.basename(source_file_path))[0] |
360 | | - evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{test_func_name}") |
361 | | - return evaluator_id |
362 | | - return None |
363 | | - |
364 | | - |
365 | 253 | def _poll_evaluator_status( |
366 | 254 | evaluator_resource_name: str, api_key: str, api_base: str, timeout_minutes: int = 10 |
367 | 255 | ) -> bool: |
@@ -441,13 +329,40 @@ def create_rft_command(args) -> int: |
441 | 329 |
|
442 | 330 | api_base = get_fireworks_api_base() |
443 | 331 |
|
444 | | - # Resolve evaluator id if omitted |
| 332 | + # Resolve evaluator id/entry if omitted (reuse upload's selector flow) |
445 | 333 | project_root = os.getcwd() |
| 334 | + preselected_entry: Optional[str] = None |
446 | 335 | if not evaluator_id: |
447 | | - evaluator_id = _auto_select_evaluator_id(project_root, non_interactive=non_interactive) |
448 | | - if not evaluator_id: |
449 | | - print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.") |
| 336 | + print("Scanning for evaluation tests...") |
| 337 | + tests = _discover_tests(project_root) |
| 338 | + if not tests: |
| 339 | + print("No evaluation tests found.") |
| 340 | + print("\nHint: Make sure your tests use the @evaluation_test decorator.") |
| 341 | + return 1 |
| 342 | + # Always interactive selection here (no implicit quiet unless --evaluator-id was provided) |
| 343 | + try: |
| 344 | + from .upload import _prompt_select # reuse the same selector UX as 'upload' |
| 345 | + |
| 346 | + selected_tests = _prompt_select(tests, non_interactive=False) |
| 347 | + except Exception: |
| 348 | + print("Error: Failed to open selector UI. Please pass --evaluator-id or --entry explicitly.") |
450 | 349 | return 1 |
| 350 | + if not selected_tests: |
| 351 | + print("No tests selected.") |
| 352 | + return 1 |
| 353 | + if len(selected_tests) != 1: |
| 354 | + print("Error: Please select exactly one evaluation test for 'create rft'.") |
| 355 | + return 1 |
| 356 | + chosen = selected_tests[0] |
| 357 | + func_name = chosen.qualname.split(".")[-1] |
| 358 | + abs_path = os.path.abspath(chosen.file_path) |
| 359 | + try: |
| 360 | + rel = os.path.relpath(abs_path, project_root) |
| 361 | + except Exception: |
| 362 | + rel = abs_path |
| 363 | + preselected_entry = f"{rel}::{func_name}" |
| 364 | + source_file_name = os.path.splitext(os.path.basename(chosen.file_path))[0] |
| 365 | + evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{func_name}") |
451 | 366 | # Resolve evaluator resource name to fully-qualified format required by API |
452 | 367 | evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}" |
453 | 368 |
|
@@ -479,7 +394,6 @@ def create_rft_command(args) -> int: |
479 | 394 | print(f"📊 Please check the evaluator status at: {dashboard_url}") |
480 | 395 | print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") |
481 | 396 | return 1 |
482 | | - _save_last_evaluator(project_root, evaluator_id) |
483 | 397 | skip_upload = True |
484 | 398 | except requests.exceptions.RequestException: |
485 | 399 | pass |
@@ -561,8 +475,8 @@ def create_rft_command(args) -> int: |
561 | 475 | print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") |
562 | 476 | return 1 |
563 | 477 | else: |
564 | | - # Only persist last-used evaluator after successful ensure + ACTIVE |
565 | | - _save_last_evaluator(project_root, evaluator_id) |
| 478 | + # Evaluator ACTIVE; proceed |
| 479 | + pass |
566 | 480 | else: |
567 | 481 | print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.") |
568 | 482 | except Exception as e: |
|
0 commit comments