Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 105 additions & 9 deletions eval_protocol/cli_commands/create_rft.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,87 @@
from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source


def _last_evaluator_paths(cwd: str) -> list[str]:
return [
os.path.join(cwd, ".eval_protocol", "last_evaluator.json"),
os.path.expanduser(os.path.join("~", ".eval_protocol", "last_evaluator.json")),
]


def _load_last_evaluator(cwd: str) -> Optional[str]:
    """Read the most recently used evaluator id from a pointer file, if any.

    Candidate paths are tried in priority order; unreadable or malformed
    files are skipped silently so a stale cache can never break the CLI.
    """
    import json

    for path in _last_evaluator_paths(cwd):
        try:
            if not os.path.isfile(path):
                continue
            with open(path, "r", encoding="utf-8") as fh:
                payload = json.load(fh)
        except Exception:
            # Corrupt or unreadable pointer file — move on to the next candidate.
            continue
        if isinstance(payload, dict) and payload.get("evaluator_id"):
            return str(payload["evaluator_id"])
    return None


def _save_last_evaluator(cwd: str, evaluator_id: str) -> None:
import json

base = os.path.join(cwd, ".eval_protocol")
try:
os.makedirs(base, exist_ok=True)
with open(os.path.join(base, "last_evaluator.json"), "w", encoding="utf-8") as f:
json.dump({"evaluator_id": evaluator_id, "ts": time.time()}, f)
except Exception:
# best-effort only
pass


def _gather_evaluator_traces(cwd: str) -> list[dict]:
roots = [
os.path.join(cwd, ".eval_protocol", "evaluators"),
os.path.expanduser(os.path.join("~", ".eval_protocol", "evaluators")),
]
records: list[dict] = []
for root in roots:
if os.path.isdir(root):
for name in os.listdir(root):
if name.endswith(".json"):
full = os.path.join(root, name)
try:
mtime = os.path.getmtime(full)
except Exception:
mtime = 0.0
records.append({"id": name[:-5], "path": full, "mtime": mtime})
# dedupe by id keeping most recent mtime
dedup: dict[str, dict] = {}
for rec in records:
cur = dedup.get(rec["id"])
if not cur or rec["mtime"] > cur["mtime"]:
dedup[rec["id"]] = rec
return list(dedup.values())


def _prompt_select_evaluator(candidates: list[dict]) -> Optional[str]:
print("\nMultiple evaluators detected. Select one:")
ordered = sorted(candidates, key=lambda x: -x["mtime"])
for i, c in enumerate(ordered, start=1):
print(f" {i}) {c['id']} (from {c['path']})")
try:
choice = input("Enter a number (or press Enter to cancel): ").strip()
except KeyboardInterrupt:
print("\nCancelled.")
return None
if not choice or not choice.isdigit():
return None
n = int(choice)
if 1 <= n <= len(ordered):
sel = ordered[n - 1]["id"]
print(f"✓ Using evaluator: {sel}")
return sel
return None


def _ensure_account_id() -> Optional[str]:
account_id = get_fireworks_account_id()
api_key = get_fireworks_api_key()
Expand Down Expand Up @@ -248,14 +329,27 @@ def _build_trimmed_dataset_id(evaluator_id: str) -> str:
return f"{base}{suffix}"


def _auto_select_evaluator_id(cwd: str) -> Optional[str]:
# Try local traces
traces_dir = os.path.join(cwd, ".eval_protocol", "evaluators")
if os.path.isdir(traces_dir):
candidates = [f[:-5] for f in os.listdir(traces_dir) if f.endswith(".json")]
if len(candidates) == 1:
return candidates[0]
# Fall back to discovering a single evaluation_test
def _auto_select_evaluator_id(cwd: str, *, non_interactive: bool = False) -> Optional[str]:
Comment thread
xzrderek marked this conversation as resolved.
# 1) Use last used pointer if available
last = _load_last_evaluator(cwd)
if last:
return last

# 2) Look for evaluator traces in project and home
traces = _gather_evaluator_traces(cwd)
if len(traces) == 1:
return traces[0]["id"]
if len(traces) > 1:
if non_interactive:
sel = sorted(traces, key=lambda x: -x["mtime"])[0]["id"]
print(f"⚠️ Multiple evaluators found; using most recent: {sel}. Override with --evaluator-id.")
return sel
chosen = _prompt_select_evaluator(traces)
if chosen:
return chosen
return None

# 3) Fall back to discovering a single evaluation_test
tests = _discover_tests(cwd)
if len(tests) == 1:
qualname, source_file_path = tests[0].qualname, tests[0].file_path
Expand Down Expand Up @@ -348,10 +442,12 @@ def create_rft_command(args) -> int:
# Resolve evaluator id if omitted
project_root = os.getcwd()
if not evaluator_id:
evaluator_id = _auto_select_evaluator_id(project_root)
evaluator_id = _auto_select_evaluator_id(project_root, non_interactive=non_interactive)
Comment thread
xzrderek marked this conversation as resolved.
if not evaluator_id:
print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
return 1
# Persist last selected/used evaluator for next runs
_save_last_evaluator(project_root, evaluator_id)
Comment thread
xzrderek marked this conversation as resolved.
Outdated

# Resolve evaluator resource name to fully-qualified format required by API
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"
Expand Down
126 changes: 126 additions & 0 deletions tests/test_cli_create_rft_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import json
import os
import time
from types import SimpleNamespace
from unittest.mock import patch

import pytest

from eval_protocol.cli_commands import create_rft as cr


def _write_json(path: str, data: dict) -> None:
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
json.dump(data, f)


def test_load_and_save_last_evaluator(tmp_path, monkeypatch):
    """Round-trip the last-evaluator pointer through save and load."""
    # Point HOME at a temp dir so the per-user fallback path stays sandboxed.
    monkeypatch.setenv("HOME", str(tmp_path / "home"))
    project_dir = tmp_path / "proj"
    project_dir.mkdir()
    cwd = str(project_dir)

    # Nothing persisted yet -> no evaluator id.
    assert cr._load_last_evaluator(cwd) is None

    # Persist an id, then read it back.
    cr._save_last_evaluator(cwd, "evaluator-abc")
    assert cr._load_last_evaluator(cwd) == "evaluator-abc"


def test_auto_select_uses_last_pointer(tmp_path, monkeypatch):
    """A project-local last-used pointer wins over every other source."""
    monkeypatch.setenv("HOME", str(tmp_path / "home"))
    project_dir = tmp_path / "proj"
    project_dir.mkdir()

    pointer = project_dir / ".eval_protocol" / "last_evaluator.json"
    _write_json(str(pointer), {"evaluator_id": "chosen-id"})

    assert cr._auto_select_evaluator_id(str(project_dir)) == "chosen-id"


def test_auto_select_single_trace(tmp_path, monkeypatch):
    """With exactly one trace file, its basename (sans .json) is the id."""
    monkeypatch.setenv("HOME", str(tmp_path / "home"))
    project_dir = tmp_path / "proj"
    project_dir.mkdir()

    trace_path = project_dir / ".eval_protocol" / "evaluators" / "only-one.json"
    _write_json(str(trace_path), {"dummy": True})

    assert cr._auto_select_evaluator_id(str(project_dir)) == "only-one"


def test_auto_select_multiple_traces_non_interactive_most_recent(tmp_path, monkeypatch):
    """Non-interactive mode auto-picks the most recently touched trace."""
    monkeypatch.setenv("HOME", str(tmp_path / "home"))
    project_dir = tmp_path / "proj"
    project_dir.mkdir()

    traces_dir = project_dir / ".eval_protocol" / "evaluators"
    older = traces_dir / "older.json"
    newer = traces_dir / "newer.json"
    for trace in (older, newer):
        _write_json(str(trace), {})

    # Give the traces distinct mtimes: "older" is 100 seconds in the past.
    now = time.time()
    os.utime(str(older), (now - 100, now - 100))
    os.utime(str(newer), (now, now))

    assert cr._auto_select_evaluator_id(str(project_dir), non_interactive=True) == "newer"


def test_auto_select_multiple_traces_interactive_prompt(tmp_path, monkeypatch):
    """The interactive prompt honors the user's numeric choice."""
    monkeypatch.setenv("HOME", str(tmp_path / "home"))
    project_dir = tmp_path / "proj"
    project_dir.mkdir()

    traces_dir = project_dir / ".eval_protocol" / "evaluators"
    older = traces_dir / "older.json"
    newer = traces_dir / "newer.json"
    for trace in (older, newer):
        _write_json(str(trace), {})
    now = time.time()
    os.utime(str(older), (now - 100, now - 100))
    os.utime(str(newer), (now, now))

    # The prompt lists traces newest-first, so answering "2" selects "older".
    with patch("builtins.input", return_value="2"):
        assert cr._auto_select_evaluator_id(str(project_dir), non_interactive=False) == "older"


def test_auto_select_falls_back_to_single_discovered_test(tmp_path, monkeypatch):
    """With no traces, a single discovered evaluation test determines the id."""
    monkeypatch.setenv("HOME", str(tmp_path / "home"))
    project_dir = tmp_path / "proj"
    project_dir.mkdir()

    test_file = project_dir / "metric" / "test_calendar.py"
    test_file.parent.mkdir(parents=True, exist_ok=True)
    test_file.write_text("# dummy", encoding="utf-8")

    discovered = SimpleNamespace(
        qualname="calendar_agent.test_calendar_agent_evaluation",
        file_path=str(test_file),
    )
    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [discovered])

    eid = cr._auto_select_evaluator_id(str(project_dir))
    assert eid is not None
    # The derived id should reflect the test function's name in some form
    # ("evaluation" is the last "_"-separated token of the qualname's suffix).
    assert "evaluation" in eid or "test-calendar-agent-evaluation" in eid


def test_auto_select_returns_none_when_no_candidates(tmp_path, monkeypatch):
    """No pointer, no traces, and empty discovery -> give up with None."""
    monkeypatch.setenv("HOME", str(tmp_path / "home"))
    project_dir = tmp_path / "proj"
    project_dir.mkdir()

    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [])
    assert cr._auto_select_evaluator_id(str(project_dir)) is None
Loading