Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 33 additions & 119 deletions eval_protocol/cli_commands/create_rft.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,87 +23,6 @@
from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source


def _last_evaluator_paths(cwd: str) -> list[str]:
return [
os.path.join(cwd, ".eval_protocol", "last_evaluator.json"),
os.path.expanduser(os.path.join("~", ".eval_protocol", "last_evaluator.json")),
]


def _load_last_evaluator(cwd: str) -> Optional[str]:
    """Return the last-used evaluator id recorded by ``_save_last_evaluator``.

    Candidate pointer files (project-local first, then home directory) are
    tried in order; the first readable JSON object with a truthy
    ``evaluator_id`` key wins.  Returns ``None`` when no usable pointer
    exists.  Lookup is best-effort: unreadable or malformed files are
    skipped rather than raising.
    """
    import json

    for path in _last_evaluator_paths(cwd):
        try:
            if os.path.isfile(path):
                with open(path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                if isinstance(data, dict) and data.get("evaluator_id"):
                    return str(data["evaluator_id"])
        except (OSError, ValueError):
            # Best-effort: a missing/unreadable/corrupt pointer file must not
            # break the command; try the next candidate.  Narrowed from a
            # bare `except Exception` — json.JSONDecodeError subclasses
            # ValueError, and file errors are OSError, so nothing else
            # (e.g. programming errors) is silently swallowed anymore.
            continue
    return None


def _save_last_evaluator(cwd: str, evaluator_id: str) -> None:
import json

base = os.path.join(cwd, ".eval_protocol")
try:
os.makedirs(base, exist_ok=True)
with open(os.path.join(base, "last_evaluator.json"), "w", encoding="utf-8") as f:
json.dump({"evaluator_id": evaluator_id, "ts": time.time()}, f)
except Exception:
# best-effort only
pass


def _gather_evaluator_traces(cwd: str) -> list[dict]:
roots = [
os.path.join(cwd, ".eval_protocol", "evaluators"),
os.path.expanduser(os.path.join("~", ".eval_protocol", "evaluators")),
]
records: list[dict] = []
for root in roots:
if os.path.isdir(root):
for name in os.listdir(root):
if name.endswith(".json"):
full = os.path.join(root, name)
try:
mtime = os.path.getmtime(full)
except Exception:
mtime = 0.0
records.append({"id": name[:-5], "path": full, "mtime": mtime})
# dedupe by id keeping most recent mtime
dedup: dict[str, dict] = {}
for rec in records:
cur = dedup.get(rec["id"])
if not cur or rec["mtime"] > cur["mtime"]:
dedup[rec["id"]] = rec
return list(dedup.values())


def _prompt_select_evaluator(candidates: list[dict]) -> Optional[str]:
    """Interactively ask the user to pick one evaluator from *candidates*.

    Candidates are listed most-recently-modified first.  Returns the
    chosen evaluator id, or ``None`` when the user cancels (empty or
    non-numeric input, out-of-range choice, Ctrl-C, or closed stdin).
    """
    print("\nMultiple evaluators detected. Select one:")
    ordered = sorted(candidates, key=lambda x: -x["mtime"])
    for i, c in enumerate(ordered, start=1):
        print(f" {i}) {c['id']} (from {c['path']})")
    try:
        choice = input("Enter a number (or press Enter to cancel): ").strip()
    except (KeyboardInterrupt, EOFError):
        # EOFError covers non-interactive stdin (piped input ran dry,
        # /dev/null); previously this crashed instead of cancelling cleanly.
        print("\nCancelled.")
        return None
    if not choice or not choice.isdigit():
        return None
    n = int(choice)
    if 1 <= n <= len(ordered):
        sel = ordered[n - 1]["id"]
        print(f"✓ Using evaluator: {sel}")
        return sel
    return None


def _ensure_account_id() -> Optional[str]:
account_id = get_fireworks_account_id()
api_key = get_fireworks_api_key()
Expand Down Expand Up @@ -331,37 +250,6 @@ def _build_trimmed_dataset_id(evaluator_id: str) -> str:
return f"{base}{suffix}"


def _auto_select_evaluator_id(cwd: str, *, non_interactive: bool = False) -> Optional[str]:
    """Best-effort inference of the evaluator id to use.

    Resolution order:
      1. the last-used pointer file, if one is readable;
      2. evaluator trace files found on disk — a lone trace is used
         directly, multiple traces either prompt the user or (when
         *non_interactive*) pick the most recent;
      3. a single discovered ``@evaluation_test`` test, from whose file
         and function names an id is derived.

    Returns ``None`` when nothing can be inferred (including when the
    user cancels the interactive selection).
    """
    # 1) Last-used pointer.
    remembered = _load_last_evaluator(cwd)
    if remembered:
        return remembered

    # 2) Evaluator traces in project and home directories.
    traces = _gather_evaluator_traces(cwd)
    if traces:
        if len(traces) == 1:
            return traces[0]["id"]
        if non_interactive:
            newest = max(traces, key=lambda rec: rec["mtime"])["id"]
            print(f"⚠️ Multiple evaluators found; using most recent: {newest}. Override with --evaluator-id.")
            return newest
        # A cancelled prompt ends resolution here; we deliberately do not
        # fall through to test discovery.
        return _prompt_select_evaluator(traces) or None

    # 3) Fall back to a single discovered evaluation test.
    tests = _discover_tests(cwd)
    if len(tests) != 1:
        return None
    only = tests[0]
    func_name = only.qualname.split(".")[-1]
    file_stem = os.path.splitext(os.path.basename(only.file_path))[0]
    return _normalize_evaluator_id(f"{file_stem}-{func_name}")


def _poll_evaluator_status(
evaluator_resource_name: str, api_key: str, api_base: str, timeout_minutes: int = 10
) -> bool:
Expand Down Expand Up @@ -441,13 +329,40 @@ def create_rft_command(args) -> int:

api_base = get_fireworks_api_base()

# Resolve evaluator id if omitted
# Resolve evaluator id/entry if omitted (reuse upload's selector flow)
project_root = os.getcwd()
preselected_entry: Optional[str] = None
if not evaluator_id:
evaluator_id = _auto_select_evaluator_id(project_root, non_interactive=non_interactive)
if not evaluator_id:
print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
print("Scanning for evaluation tests...")
tests = _discover_tests(project_root)
if not tests:
print("No evaluation tests found.")
print("\nHint: Make sure your tests use the @evaluation_test decorator.")
return 1
# Always interactive selection here (no implicit quiet unless --evaluator-id was provided)
try:
from .upload import _prompt_select # reuse the same selector UX as 'upload'

selected_tests = _prompt_select(tests, non_interactive=False)
Comment thread
cursor[bot] marked this conversation as resolved.
Outdated
except Exception:
print("Error: Failed to open selector UI. Please pass --evaluator-id or --entry explicitly.")
return 1
if not selected_tests:
print("No tests selected.")
return 1
if len(selected_tests) != 1:
print("Error: Please select exactly one evaluation test for 'create rft'.")
Comment thread
xzrderek marked this conversation as resolved.
Outdated
return 1
chosen = selected_tests[0]
func_name = chosen.qualname.split(".")[-1]
abs_path = os.path.abspath(chosen.file_path)
try:
rel = os.path.relpath(abs_path, project_root)
except Exception:
rel = abs_path
preselected_entry = f"{rel}::{func_name}"
source_file_name = os.path.splitext(os.path.basename(chosen.file_path))[0]
evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
# Resolve evaluator resource name to fully-qualified format required by API
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"

Expand Down Expand Up @@ -479,7 +394,6 @@ def create_rft_command(args) -> int:
print(f"📊 Please check the evaluator status at: {dashboard_url}")
print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
return 1
_save_last_evaluator(project_root, evaluator_id)
skip_upload = True
except requests.exceptions.RequestException:
pass
Expand Down Expand Up @@ -561,8 +475,8 @@ def create_rft_command(args) -> int:
print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
return 1
else:
# Only persist last-used evaluator after successful ensure + ACTIVE
_save_last_evaluator(project_root, evaluator_id)
# Evaluator ACTIVE; proceed
pass
else:
print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.")
except Exception as e:
Expand Down
Loading
Loading