Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
284 changes: 214 additions & 70 deletions eval_protocol/cli_commands/create_rft.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,87 @@
from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source


def _last_evaluator_paths(cwd: str) -> list[str]:
return [
os.path.join(cwd, ".eval_protocol", "last_evaluator.json"),
os.path.expanduser(os.path.join("~", ".eval_protocol", "last_evaluator.json")),
]


def _load_last_evaluator(cwd: str) -> Optional[str]:
    """Best-effort lookup of the most recently used evaluator id.

    Checks the project-local pointer file first, then the per-user one
    (see ``_last_evaluator_paths``).

    Args:
        cwd: Project root directory to search under.

    Returns:
        The stored evaluator id as a string, or ``None`` when no readable
        pointer file exists.
    """
    import json

    for p in _last_evaluator_paths(cwd):
        try:
            if os.path.isfile(p):
                with open(p, "r", encoding="utf-8") as f:
                    data = json.load(f)
                if isinstance(data, dict) and data.get("evaluator_id"):
                    return str(data["evaluator_id"])
        except (OSError, ValueError):
            # Unreadable or malformed pointer file (json.JSONDecodeError is a
            # ValueError subclass): skip it and try the next candidate.
            # Narrowed from a bare `except Exception` so genuine programming
            # errors are no longer silently swallowed.
            continue
    return None


def _save_last_evaluator(cwd: str, evaluator_id: str) -> None:
import json

base = os.path.join(cwd, ".eval_protocol")
try:
os.makedirs(base, exist_ok=True)
with open(os.path.join(base, "last_evaluator.json"), "w", encoding="utf-8") as f:
json.dump({"evaluator_id": evaluator_id, "ts": time.time()}, f)
except Exception:
# best-effort only
pass


def _gather_evaluator_traces(cwd: str) -> list[dict]:
    """Collect evaluator trace records from the project and home caches.

    Scans ``.eval_protocol/evaluators`` under *cwd* and under the user's
    home directory for ``*.json`` trace files. Duplicate evaluator ids are
    collapsed, keeping the record with the most recent mtime.

    Returns:
        A list of ``{"id", "path", "mtime"}`` dicts, one per unique id.
    """
    search_roots = (
        os.path.join(cwd, ".eval_protocol", "evaluators"),
        os.path.expanduser(os.path.join("~", ".eval_protocol", "evaluators")),
    )
    by_id: dict[str, dict] = {}
    for root in search_roots:
        if not os.path.isdir(root):
            continue
        for entry in os.listdir(root):
            if not entry.endswith(".json"):
                continue
            full_path = os.path.join(root, entry)
            try:
                modified = os.path.getmtime(full_path)
            except Exception:
                modified = 0.0
            evaluator_id = entry[: -len(".json")]
            # Single-pass dedupe: keep the most recently modified record;
            # on a tie the earlier (project-local) record wins.
            existing = by_id.get(evaluator_id)
            if existing is None or modified > existing["mtime"]:
                by_id[evaluator_id] = {"id": evaluator_id, "path": full_path, "mtime": modified}
    return list(by_id.values())


def _prompt_select_evaluator(candidates: list[dict]) -> Optional[str]:
    """Interactively ask the user to pick one evaluator from *candidates*.

    Candidates are presented newest-first (by trace mtime).

    Args:
        candidates: Records with ``id``, ``path`` and ``mtime`` keys.

    Returns:
        The chosen evaluator id, or ``None`` when the user cancels, enters
        an invalid choice, or stdin is interrupted/closed.
    """
    print("\nMultiple evaluators detected. Select one:")
    ordered = sorted(candidates, key=lambda x: -x["mtime"])
    for i, c in enumerate(ordered, start=1):
        print(f" {i}) {c['id']} (from {c['path']})")
    try:
        choice = input("Enter a number (or press Enter to cancel): ").strip()
    except (KeyboardInterrupt, EOFError):
        # EOFError is raised by input() when stdin is closed (piped or
        # non-interactive runs); treat it like an explicit cancel instead
        # of crashing the CLI.
        print("\nCancelled.")
        return None
    if not choice or not choice.isdigit():
        return None
    n = int(choice)
    if 1 <= n <= len(ordered):
        sel = ordered[n - 1]["id"]
        print(f"✓ Using evaluator: {sel}")
        return sel
    return None


def _ensure_account_id() -> Optional[str]:
account_id = get_fireworks_account_id()
api_key = get_fireworks_api_key()
Expand Down Expand Up @@ -240,6 +321,8 @@ def _build_trimmed_dataset_id(evaluator_id: str) -> str:
if not base:
base = "dataset"
# Ensure first char is a letter
if not base:
base = "dataset"
if not base[0].isalpha():
base = f"eval-{base}"
if len(base) > max_base_len:
Expand All @@ -248,14 +331,27 @@ def _build_trimmed_dataset_id(evaluator_id: str) -> str:
return f"{base}{suffix}"


def _auto_select_evaluator_id(cwd: str) -> Optional[str]:
# Try local traces
traces_dir = os.path.join(cwd, ".eval_protocol", "evaluators")
if os.path.isdir(traces_dir):
candidates = [f[:-5] for f in os.listdir(traces_dir) if f.endswith(".json")]
if len(candidates) == 1:
return candidates[0]
# Fall back to discovering a single evaluation_test
def _auto_select_evaluator_id(cwd: str, *, non_interactive: bool = False) -> Optional[str]:
# 1) Use last used pointer if available
last = _load_last_evaluator(cwd)
if last:
return last

# 2) Look for evaluator traces in project and home
traces = _gather_evaluator_traces(cwd)
if len(traces) == 1:
return traces[0]["id"]
if len(traces) > 1:
if non_interactive:
sel = sorted(traces, key=lambda x: -x["mtime"])[0]["id"]
print(f"⚠️ Multiple evaluators found; using most recent: {sel}. Override with --evaluator-id.")
return sel
chosen = _prompt_select_evaluator(traces)
if chosen:
return chosen
return None

# 3) Fall back to discovering a single evaluation_test
tests = _discover_tests(cwd)
if len(tests) == 1:
qualname, source_file_path = tests[0].qualname, tests[0].file_path
Expand Down Expand Up @@ -348,81 +444,129 @@ def create_rft_command(args) -> int:
# Resolve evaluator id if omitted
project_root = os.getcwd()
if not evaluator_id:
evaluator_id = _auto_select_evaluator_id(project_root)
evaluator_id = _auto_select_evaluator_id(project_root, non_interactive=non_interactive)
if not evaluator_id:
print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
return 1

# Resolve evaluator resource name to fully-qualified format required by API
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"

# Optional short-circuit: if evaluator already exists and not forcing, skip upload path
skip_upload = False
if not force:
try:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"User-Agent": get_user_agent(),
}
resp = requests.get(f"{api_base}/v1/{evaluator_resource_name}", headers=headers, timeout=10)
if resp.ok:
state = resp.json().get("state", "STATE_UNSPECIFIED")
print(f"✓ Evaluator exists (state: {state}). Skipping upload (use --force to overwrite).")
# Poll for ACTIVE before proceeding
print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...")
if not _poll_evaluator_status(
evaluator_resource_name=evaluator_resource_name,
api_key=api_key,
api_base=api_base,
timeout_minutes=10,
):
app_base = _map_api_host_to_app_host(api_base)
evaluator_slug = _extract_terminal_segment(evaluator_id)
dashboard_url = f"{app_base}/dashboard/evaluators/{evaluator_slug}"
print("\n❌ Evaluator is not ready within the timeout period.")
print(f"📊 Please check the evaluator status at: {dashboard_url}")
print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
return 1
_save_last_evaluator(project_root, evaluator_id)
skip_upload = True
except requests.exceptions.RequestException:
pass

# Ensure evaluator exists by invoking the upload flow programmatically
try:
from .upload import upload_command
if not skip_upload:
try:
from .upload import upload_command

tests = _discover_tests(project_root)
selected_entry: Optional[str] = None
if len(tests) == 1:
func_name = tests[0].qualname.split(".")[-1]
abs_path = os.path.abspath(tests[0].file_path)
try:
rel = os.path.relpath(abs_path, project_root)
except Exception:
rel = abs_path
selected_entry = f"{rel}::{func_name}"
else:
# Try to match evaluator_id to a discovered test's normalized ID
for t in tests:
func_name = t.qualname.split(".")[-1]
source_file_name = os.path.splitext(os.path.basename(t.file_path))[0]
candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
if candidate == evaluator_id:
abs_path = os.path.abspath(t.file_path)
try:
rel = os.path.relpath(abs_path, project_root)
except Exception:
rel = abs_path
selected_entry = f"{rel}::{func_name}"
break

upload_args = argparse.Namespace(
path=project_root,
entry=selected_entry,
id=evaluator_id,
display_name=None,
description=None,
force=force, # Pass through the --force flag
yes=True,
env_file=None, # Add the new env_file parameter
)
tests = _discover_tests(project_root)
selected_entry: Optional[str] = None
if len(tests) == 1:
func_name = tests[0].qualname.split(".")[-1]
abs_path = os.path.abspath(tests[0].file_path)
try:
rel = os.path.relpath(abs_path, project_root)
except Exception:
rel = abs_path
selected_entry = f"{rel}::{func_name}"
else:
# Try to match evaluator_id to a discovered test's normalized ID
for t in tests:
func_name = t.qualname.split(".")[-1]
source_file_name = os.path.splitext(os.path.basename(t.file_path))[0]
candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
if candidate == evaluator_id:
abs_path = os.path.abspath(t.file_path)
try:
rel = os.path.relpath(abs_path, project_root)
except Exception:
rel = abs_path
selected_entry = f"{rel}::{func_name}"
break
# If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators
if selected_entry is None and len(tests) > 1:
print(
f"Error: Multiple evaluation tests found, and the selected evaluator_id {evaluator_id} does not match any discovered test.\n"
" Please re-run specifying the evaluator id.\n"
" Hints:\n"
" - eval-protocol create rft --evaluator-id <existing-evaluator-id>\n"
)
return 1

if force:
print(f"🔄 Force flag enabled - will overwrite existing evaluator '{evaluator_id}'")
upload_args = argparse.Namespace(
path=project_root,
entry=selected_entry,
id=evaluator_id,
display_name=None,
description=None,
force=force, # Pass through the --force flag
yes=True,
env_file=None, # Add the new env_file parameter
)

rc = upload_command(upload_args)
if rc == 0:
print(f"✓ Uploaded/ensured evaluator: {evaluator_id}")
if force:
print(f"🔄 Force flag enabled - will overwrite existing evaluator '{evaluator_id}'")

# Poll for evaluator status
print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...")
is_active = _poll_evaluator_status(
evaluator_resource_name=evaluator_resource_name, api_key=api_key, api_base=api_base, timeout_minutes=10
)
rc = upload_command(upload_args)
if rc == 0:
print(f"✓ Uploaded/ensured evaluator: {evaluator_id}")

if not is_active:
# Print helpful message with dashboard link
app_base = _map_api_host_to_app_host(api_base)
evaluator_slug = _extract_terminal_segment(evaluator_id)
dashboard_url = f"{app_base}/dashboard/evaluators/{evaluator_slug}"
# Poll for evaluator status
print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...")
is_active = _poll_evaluator_status(
evaluator_resource_name=evaluator_resource_name,
api_key=api_key,
api_base=api_base,
timeout_minutes=10,
)

print("\n❌ Evaluator is not ready within the timeout period.")
print(f"📊 Please check the evaluator status at: {dashboard_url}")
print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
return 1
else:
print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.")
except Exception as e:
print(f"Warning: Failed to upload evaluator automatically: {e}")
if not is_active:
# Print helpful message with dashboard link
app_base = _map_api_host_to_app_host(api_base)
evaluator_slug = _extract_terminal_segment(evaluator_id)
dashboard_url = f"{app_base}/dashboard/evaluators/{evaluator_slug}"

print("\n❌ Evaluator is not ready within the timeout period.")
print(f"📊 Please check the evaluator status at: {dashboard_url}")
print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
return 1
else:
# Only persist last-used evaluator after successful ensure + ACTIVE
_save_last_evaluator(project_root, evaluator_id)
else:
print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.")
except Exception as e:
print(f"Warning: Failed to upload evaluator automatically: {e}")

# Determine dataset id and materialization path
dataset_id = getattr(args, "dataset_id", None)
Expand Down
Loading
Loading