Skip to content

Commit 64efd1c

Browse files
committed
auto select evaluators correctly
1 parent cf59d13 commit 64efd1c

File tree

2 files changed

+231
-9
lines changed

2 files changed

+231
-9
lines changed

eval_protocol/cli_commands/create_rft.py

Lines changed: 105 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,87 @@
2323
from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source
2424

2525

26+
def _last_evaluator_paths(cwd: str) -> list[str]:
27+
return [
28+
os.path.join(cwd, ".eval_protocol", "last_evaluator.json"),
29+
os.path.expanduser(os.path.join("~", ".eval_protocol", "last_evaluator.json")),
30+
]
31+
32+
33+
def _load_last_evaluator(cwd: str) -> Optional[str]:
34+
import json
35+
36+
for p in _last_evaluator_paths(cwd):
37+
try:
38+
if os.path.isfile(p):
39+
with open(p, "r", encoding="utf-8") as f:
40+
data = json.load(f)
41+
if isinstance(data, dict) and data.get("evaluator_id"):
42+
return str(data["evaluator_id"])
43+
except Exception:
44+
# ignore and continue
45+
pass
46+
return None
47+
48+
49+
def _save_last_evaluator(cwd: str, evaluator_id: str) -> None:
50+
import json
51+
52+
base = os.path.join(cwd, ".eval_protocol")
53+
try:
54+
os.makedirs(base, exist_ok=True)
55+
with open(os.path.join(base, "last_evaluator.json"), "w", encoding="utf-8") as f:
56+
json.dump({"evaluator_id": evaluator_id, "ts": time.time()}, f)
57+
except Exception:
58+
# best-effort only
59+
pass
60+
61+
62+
def _gather_evaluator_traces(cwd: str) -> list[dict]:
63+
roots = [
64+
os.path.join(cwd, ".eval_protocol", "evaluators"),
65+
os.path.expanduser(os.path.join("~", ".eval_protocol", "evaluators")),
66+
]
67+
records: list[dict] = []
68+
for root in roots:
69+
if os.path.isdir(root):
70+
for name in os.listdir(root):
71+
if name.endswith(".json"):
72+
full = os.path.join(root, name)
73+
try:
74+
mtime = os.path.getmtime(full)
75+
except Exception:
76+
mtime = 0.0
77+
records.append({"id": name[:-5], "path": full, "mtime": mtime})
78+
# dedupe by id keeping most recent mtime
79+
dedup: dict[str, dict] = {}
80+
for rec in records:
81+
cur = dedup.get(rec["id"])
82+
if not cur or rec["mtime"] > cur["mtime"]:
83+
dedup[rec["id"]] = rec
84+
return list(dedup.values())
85+
86+
87+
def _prompt_select_evaluator(candidates: list[dict]) -> Optional[str]:
88+
print("\nMultiple evaluators detected. Select one:")
89+
ordered = sorted(candidates, key=lambda x: -x["mtime"])
90+
for i, c in enumerate(ordered, start=1):
91+
print(f" {i}) {c['id']} (from {c['path']})")
92+
try:
93+
choice = input("Enter a number (or press Enter to cancel): ").strip()
94+
except KeyboardInterrupt:
95+
print("\nCancelled.")
96+
return None
97+
if not choice or not choice.isdigit():
98+
return None
99+
n = int(choice)
100+
if 1 <= n <= len(ordered):
101+
sel = ordered[n - 1]["id"]
102+
print(f"✓ Using evaluator: {sel}")
103+
return sel
104+
return None
105+
106+
26107
def _ensure_account_id() -> Optional[str]:
27108
account_id = get_fireworks_account_id()
28109
api_key = get_fireworks_api_key()
@@ -248,14 +329,27 @@ def _build_trimmed_dataset_id(evaluator_id: str) -> str:
248329
return f"{base}{suffix}"
249330

250331

251-
def _auto_select_evaluator_id(cwd: str) -> Optional[str]:
252-
# Try local traces
253-
traces_dir = os.path.join(cwd, ".eval_protocol", "evaluators")
254-
if os.path.isdir(traces_dir):
255-
candidates = [f[:-5] for f in os.listdir(traces_dir) if f.endswith(".json")]
256-
if len(candidates) == 1:
257-
return candidates[0]
258-
# Fall back to discovering a single evaluation_test
332+
def _auto_select_evaluator_id(cwd: str, *, non_interactive: bool = False) -> Optional[str]:
333+
# 1) Use last used pointer if available
334+
last = _load_last_evaluator(cwd)
335+
if last:
336+
return last
337+
338+
# 2) Look for evaluator traces in project and home
339+
traces = _gather_evaluator_traces(cwd)
340+
if len(traces) == 1:
341+
return traces[0]["id"]
342+
if len(traces) > 1:
343+
if non_interactive:
344+
sel = sorted(traces, key=lambda x: -x["mtime"])[0]["id"]
345+
print(f"⚠️ Multiple evaluators found; using most recent: {sel}. Override with --evaluator-id.")
346+
return sel
347+
chosen = _prompt_select_evaluator(traces)
348+
if chosen:
349+
return chosen
350+
return None
351+
352+
# 3) Fall back to discovering a single evaluation_test
259353
tests = _discover_tests(cwd)
260354
if len(tests) == 1:
261355
qualname, source_file_path = tests[0].qualname, tests[0].file_path
@@ -348,10 +442,12 @@ def create_rft_command(args) -> int:
348442
# Resolve evaluator id if omitted
349443
project_root = os.getcwd()
350444
if not evaluator_id:
351-
evaluator_id = _auto_select_evaluator_id(project_root)
445+
evaluator_id = _auto_select_evaluator_id(project_root, non_interactive=non_interactive)
352446
if not evaluator_id:
353447
print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
354448
return 1
449+
# Persist last selected/used evaluator for next runs
450+
_save_last_evaluator(project_root, evaluator_id)
355451

356452
# Resolve evaluator resource name to fully-qualified format required by API
357453
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"

tests/test_cli_create_rft_infer.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import json
2+
import os
3+
import time
4+
from types import SimpleNamespace
5+
from unittest.mock import patch
6+
7+
import pytest
8+
9+
from eval_protocol.cli_commands import create_rft as cr
10+
11+
12+
def _write_json(path: str, data: dict) -> None:
13+
os.makedirs(os.path.dirname(path), exist_ok=True)
14+
with open(path, "w", encoding="utf-8") as f:
15+
json.dump(data, f)
16+
17+
18+
def test_load_and_save_last_evaluator(tmp_path, monkeypatch):
19+
# Force HOME to temp so expanduser paths remain inside tmp
20+
monkeypatch.setenv("HOME", str(tmp_path / "home"))
21+
project = tmp_path / "proj"
22+
project.mkdir()
23+
24+
# Initially none
25+
assert cr._load_last_evaluator(str(project)) is None
26+
27+
# Save and load
28+
cr._save_last_evaluator(str(project), "evaluator-abc")
29+
assert cr._load_last_evaluator(str(project)) == "evaluator-abc"
30+
31+
32+
def test_auto_select_uses_last_pointer(tmp_path, monkeypatch):
33+
monkeypatch.setenv("HOME", str(tmp_path / "home"))
34+
project = tmp_path / "proj"
35+
project.mkdir()
36+
37+
# Write last pointer under project
38+
last_path = project / ".eval_protocol" / "last_evaluator.json"
39+
_write_json(str(last_path), {"evaluator_id": "chosen-id"})
40+
41+
eid = cr._auto_select_evaluator_id(str(project))
42+
assert eid == "chosen-id"
43+
44+
45+
def test_auto_select_single_trace(tmp_path, monkeypatch):
46+
monkeypatch.setenv("HOME", str(tmp_path / "home"))
47+
project = tmp_path / "proj"
48+
project.mkdir()
49+
50+
# Single evaluator trace under project
51+
trace = project / ".eval_protocol" / "evaluators" / "only-one.json"
52+
_write_json(str(trace), {"dummy": True})
53+
54+
eid = cr._auto_select_evaluator_id(str(project))
55+
assert eid == "only-one"
56+
57+
58+
def test_auto_select_multiple_traces_non_interactive_most_recent(tmp_path, monkeypatch):
59+
monkeypatch.setenv("HOME", str(tmp_path / "home"))
60+
project = tmp_path / "proj"
61+
project.mkdir()
62+
63+
# Two traces with different mtimes
64+
older = project / ".eval_protocol" / "evaluators" / "older.json"
65+
newer = project / ".eval_protocol" / "evaluators" / "newer.json"
66+
_write_json(str(older), {})
67+
_write_json(str(newer), {})
68+
# Set older then newer mtime
69+
t0 = time.time() - 100
70+
os.utime(str(older), (t0, t0))
71+
t1 = time.time()
72+
os.utime(str(newer), (t1, t1))
73+
74+
eid = cr._auto_select_evaluator_id(str(project), non_interactive=True)
75+
assert eid == "newer"
76+
77+
78+
def test_auto_select_multiple_traces_interactive_prompt(tmp_path, monkeypatch):
79+
monkeypatch.setenv("HOME", str(tmp_path / "home"))
80+
project = tmp_path / "proj"
81+
project.mkdir()
82+
83+
# Two traces with different mtimes to force ordering: newer first, older second
84+
older = project / ".eval_protocol" / "evaluators" / "older.json"
85+
newer = project / ".eval_protocol" / "evaluators" / "newer.json"
86+
_write_json(str(older), {})
87+
_write_json(str(newer), {})
88+
t0 = time.time() - 100
89+
os.utime(str(older), (t0, t0))
90+
t1 = time.time()
91+
os.utime(str(newer), (t1, t1))
92+
93+
with patch("builtins.input", return_value="2"):
94+
eid = cr._auto_select_evaluator_id(str(project), non_interactive=False)
95+
# Choosing "2" should pick the second item by recency => "older"
96+
assert eid == "older"
97+
98+
99+
def test_auto_select_falls_back_to_single_discovered_test(tmp_path, monkeypatch):
100+
monkeypatch.setenv("HOME", str(tmp_path / "home"))
101+
project = tmp_path / "proj"
102+
project.mkdir()
103+
104+
# No traces; provide exactly one discovered test
105+
test_file = project / "metric" / "test_calendar.py"
106+
test_file.parent.mkdir(parents=True, exist_ok=True)
107+
test_file.write_text("# dummy", encoding="utf-8")
108+
109+
dummy = SimpleNamespace(qualname="calendar_agent.test_calendar_agent_evaluation", file_path=str(test_file))
110+
monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [dummy])
111+
112+
eid = cr._auto_select_evaluator_id(str(project))
113+
assert eid is not None
114+
# Should incorporate function name suffix
115+
assert "test_calendar_agent_evaluation".split("_")[-1] in eid or "test-calendar-agent-evaluation" in eid
116+
117+
118+
def test_auto_select_returns_none_when_no_candidates(tmp_path, monkeypatch):
119+
monkeypatch.setenv("HOME", str(tmp_path / "home"))
120+
project = tmp_path / "proj"
121+
project.mkdir()
122+
123+
# No traces, no tests
124+
monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [])
125+
eid = cr._auto_select_evaluator_id(str(project))
126+
assert eid is None

0 commit comments

Comments
 (0)