Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 118 additions & 146 deletions eval_protocol/cli_commands/create_rft.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,88 +20,8 @@
create_dataset_from_jsonl,
create_reinforcement_fine_tuning_job,
)
from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source


def _last_evaluator_paths(cwd: str) -> list[str]:
return [
os.path.join(cwd, ".eval_protocol", "last_evaluator.json"),
os.path.expanduser(os.path.join("~", ".eval_protocol", "last_evaluator.json")),
]


def _load_last_evaluator(cwd: str) -> Optional[str]:
    """Return the last-used evaluator id recorded on disk, if any.

    Checks each candidate pointer file in priority order (project-local
    first, then home directory) and returns the first valid id found.
    Any I/O or parse failure is treated as "no record" — the lookup is
    best-effort by design.
    """
    import json

    for candidate in _last_evaluator_paths(cwd):
        try:
            if not os.path.isfile(candidate):
                continue
            with open(candidate, "r", encoding="utf-8") as fh:
                payload = json.load(fh)
            if isinstance(payload, dict) and payload.get("evaluator_id"):
                return str(payload["evaluator_id"])
        except Exception:
            # Unreadable/corrupt pointer: ignore and try the next location.
            continue
    return None


def _save_last_evaluator(cwd: str, evaluator_id: str) -> None:
import json

base = os.path.join(cwd, ".eval_protocol")
try:
os.makedirs(base, exist_ok=True)
with open(os.path.join(base, "last_evaluator.json"), "w", encoding="utf-8") as f:
json.dump({"evaluator_id": evaluator_id, "ts": time.time()}, f)
except Exception:
# best-effort only
pass


def _gather_evaluator_traces(cwd: str) -> list[dict]:
roots = [
os.path.join(cwd, ".eval_protocol", "evaluators"),
os.path.expanduser(os.path.join("~", ".eval_protocol", "evaluators")),
]
records: list[dict] = []
for root in roots:
if os.path.isdir(root):
for name in os.listdir(root):
if name.endswith(".json"):
full = os.path.join(root, name)
try:
mtime = os.path.getmtime(full)
except Exception:
mtime = 0.0
records.append({"id": name[:-5], "path": full, "mtime": mtime})
# dedupe by id keeping most recent mtime
dedup: dict[str, dict] = {}
for rec in records:
cur = dedup.get(rec["id"])
if not cur or rec["mtime"] > cur["mtime"]:
dedup[rec["id"]] = rec
return list(dedup.values())


def _prompt_select_evaluator(candidates: list[dict]) -> Optional[str]:
    """Interactively ask the user to pick one evaluator from *candidates*.

    Candidates are listed newest-first (by mtime). Returns the chosen
    evaluator id, or ``None`` on cancel (empty, non-numeric, or
    out-of-range input, or Ctrl-C).
    """
    print("\nMultiple evaluators detected. Select one:")
    by_recency = sorted(candidates, key=lambda rec: -rec["mtime"])
    for idx, rec in enumerate(by_recency, start=1):
        print(f" {idx}) {rec['id']} (from {rec['path']})")
    try:
        raw = input("Enter a number (or press Enter to cancel): ").strip()
    except KeyboardInterrupt:
        print("\nCancelled.")
        return None
    # "".isdigit() is False, so empty input also falls through to cancel.
    if not raw.isdigit():
        return None
    pick = int(raw)
    if not (1 <= pick <= len(by_recency)):
        return None
    chosen = by_recency[pick - 1]["id"]
    print(f"✓ Using evaluator: {chosen}")
    return chosen
from ..fireworks_rft import detect_dataset_builder, materialize_dataset_via_builder
from .upload import _discover_tests, _normalize_evaluator_id, _prompt_select


def _ensure_account_id() -> Optional[str]:
Expand Down Expand Up @@ -331,35 +251,35 @@ def _build_trimmed_dataset_id(evaluator_id: str) -> str:
return f"{base}{suffix}"


def _auto_select_evaluator_id(cwd: str, *, non_interactive: bool = False) -> Optional[str]:
    """Infer which evaluator id to use when none was given on the CLI.

    Resolution order:
      1. The last-used evaluator pointer, if recorded.
      2. Evaluator trace files (project + home): a single trace wins
         outright; with several, auto-pick the most recent when
         *non_interactive*, otherwise prompt the user.
      3. A single discovered evaluation test, normalized to
         '<file-stem>-<func-name>'.

    Returns ``None`` when no unambiguous choice exists.
    """
    remembered = _load_last_evaluator(cwd)
    if remembered:
        return remembered

    traces = _gather_evaluator_traces(cwd)
    if len(traces) == 1:
        return traces[0]["id"]
    if len(traces) > 1:
        if non_interactive:
            newest = sorted(traces, key=lambda rec: -rec["mtime"])[0]["id"]
            print(f"⚠️ Multiple evaluators found; using most recent: {newest}. Override with --evaluator-id.")
            return newest
        # Interactive path: a declined/invalid prompt means give up here
        # rather than falling through to test discovery.
        return _prompt_select_evaluator(traces) or None

    tests = _discover_tests(cwd)
    if len(tests) != 1:
        return None
    only = tests[0]
    func_name = only.qualname.split(".")[-1]
    file_stem = os.path.splitext(os.path.basename(only.file_path))[0]
    return _normalize_evaluator_id(f"{file_stem}-{func_name}")
def _resolve_selected_test(
project_root: str,
evaluator_id: Optional[str],
selected_tests: Optional[list] = None,
) -> tuple[Optional[str], Optional[str]]:
"""
Resolve a single test's source file path and function name to use downstream.
Priority:
1) If selected_tests provided and length == 1, use it.
2) Else discover tests; if exactly one test, use it.
3) Else, if evaluator_id provided, match by normalized '<file-stem>-<func-name>'.
Returns: (file_path, func_name) or (None, None) if unresolved.
"""
try:
tests = selected_tests if selected_tests is not None else _discover_tests(project_root)
if not tests:
return None, None
if len(tests) == 1:
return tests[0].file_path, tests[0].qualname.split(".")[-1]
if evaluator_id:
for t in tests:
func_name = t.qualname.split(".")[-1]
source_file_name = os.path.splitext(os.path.basename(t.file_path))[0]
candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
if candidate == evaluator_id:
return t.file_path, func_name
return None, None
except Exception:
return None, None


def _poll_evaluator_status(
Expand Down Expand Up @@ -428,6 +348,9 @@ def create_rft_command(args) -> int:
non_interactive: bool = bool(getattr(args, "yes", False))
dry_run: bool = bool(getattr(args, "dry_run", False))
force: bool = bool(getattr(args, "force", False))
# Track the specifically chosen test (if any) to aid dataset inference later
selected_test_file_path: Optional[str] = None
selected_test_func_name: Optional[str] = None

api_key = get_fireworks_api_key()
if not api_key:
Expand All @@ -441,13 +364,52 @@ def create_rft_command(args) -> int:

api_base = get_fireworks_api_base()

# Resolve evaluator id if omitted
# Resolve evaluator id/entry if omitted (reuse upload's selector flow)
project_root = os.getcwd()
if not evaluator_id:
evaluator_id = _auto_select_evaluator_id(project_root, non_interactive=non_interactive)
if not evaluator_id:
print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
print("Scanning for evaluation tests...")
tests = _discover_tests(project_root)
if not tests:
print("No evaluation tests found.")
print("\nHint: Make sure your tests use the @evaluation_test decorator.")
return 1
# Always interactive selection here (no implicit quiet unless --evaluator-id was provided)
try:
selected_tests = _prompt_select(tests, non_interactive=non_interactive)
except Exception:
print("Error: Failed to open selector UI. Please pass --evaluator-id or --entry explicitly.")
return 1
if not selected_tests:
print("No tests selected.")
return 1
if len(selected_tests) != 1:
if non_interactive and len(selected_tests) > 1:
print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.")
print(" Please pass --evaluator-id or --entry to disambiguate.")
try:
# Offer candidate evaluator ids for convenience
tests = _discover_tests(project_root)
if tests:
print(" Candidate evaluator ids:")
for t in tests:
func = t.qualname.split(".")[-1]
stem = os.path.splitext(os.path.basename(t.file_path))[0]
cand = _normalize_evaluator_id(f"{stem}-{func}")
print(f" - {cand}")
except Exception:
pass
else:
print("Error: Please select exactly one evaluation test for 'create rft'.")
return 1
# Derive evaluator_id from user's single selection
chosen = selected_tests[0]
func_name = chosen.qualname.split(".")[-1]
source_file_name = os.path.splitext(os.path.basename(chosen.file_path))[0]
evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
# Resolve selected test once for downstream
selected_test_file_path, selected_test_func_name = _resolve_selected_test(
project_root, evaluator_id, selected_tests=selected_tests
)
# Resolve evaluator resource name to fully-qualified format required by API
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"

Expand Down Expand Up @@ -479,8 +441,12 @@ def create_rft_command(args) -> int:
print(f"📊 Please check the evaluator status at: {dashboard_url}")
print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
return 1
_save_last_evaluator(project_root, evaluator_id)
skip_upload = True
# Populate selected test info for dataset inference later
st_path, st_func = _resolve_selected_test(project_root, evaluator_id)
if st_path and st_func:
selected_test_file_path = st_path
selected_test_func_name = st_func
except requests.exceptions.RequestException:
pass

Expand All @@ -491,28 +457,16 @@ def create_rft_command(args) -> int:

tests = _discover_tests(project_root)
selected_entry: Optional[str] = None
if len(tests) == 1:
func_name = tests[0].qualname.split(".")[-1]
abs_path = os.path.abspath(tests[0].file_path)
st_path, st_func = _resolve_selected_test(project_root, evaluator_id, selected_tests=tests)
if st_path and st_func:
abs_path = os.path.abspath(st_path)
try:
rel = os.path.relpath(abs_path, project_root)
except Exception:
rel = abs_path
selected_entry = f"{rel}::{func_name}"
else:
# Try to match evaluator_id to a discovered test's normalized ID
for t in tests:
func_name = t.qualname.split(".")[-1]
source_file_name = os.path.splitext(os.path.basename(t.file_path))[0]
candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
if candidate == evaluator_id:
abs_path = os.path.abspath(t.file_path)
try:
rel = os.path.relpath(abs_path, project_root)
except Exception:
rel = abs_path
selected_entry = f"{rel}::{func_name}"
break
selected_entry = f"{rel}::{st_func}"
selected_test_file_path = st_path
selected_test_func_name = st_func
# If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators
if selected_entry is None and len(tests) > 1:
print(
Expand Down Expand Up @@ -561,8 +515,8 @@ def create_rft_command(args) -> int:
print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
return 1
else:
# Only persist last-used evaluator after successful ensure + ACTIVE
_save_last_evaluator(project_root, evaluator_id)
# Evaluator ACTIVE; proceed
pass
else:
print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.")
except Exception as e:
Expand All @@ -575,30 +529,48 @@ def create_rft_command(args) -> int:
dataset_builder = getattr(args, "dataset_builder", None) # accepted but unused in simplified flow

if not dataset_id:
# Prefer explicit --dataset-jsonl, else attempt to extract from data loader or input_dataset of the single discovered test
# Prefer explicit --dataset-jsonl, else attempt to extract from the selected test's data loader or input_dataset.
if not dataset_jsonl:
tests = _discover_tests(project_root)
if len(tests) == 1:
func_name = tests[0].qualname.split(".")[-1]
# Try data_loaders first (existing behavior)
dataset_jsonl = _extract_jsonl_from_dataloader(tests[0].file_path, func_name)
# Use specifically selected test if available; else only infer when exactly one test exists
test_file_for_infer = None
func_for_infer = None
if selected_test_file_path and selected_test_func_name:
test_file_for_infer = selected_test_file_path
func_for_infer = selected_test_func_name
else:
tests = _discover_tests(project_root)
if len(tests) == 1:
test_file_for_infer = tests[0].file_path
func_for_infer = tests[0].qualname.split(".")[-1]
if test_file_for_infer and func_for_infer:
# Try data_loaders first
dataset_jsonl = _extract_jsonl_from_dataloader(test_file_for_infer, func_for_infer)
if dataset_jsonl:
# Display relative path for readability
try:
rel = os.path.relpath(dataset_jsonl, project_root)
except Exception:
rel = dataset_jsonl
print(f"✓ Using JSONL from data loader: {rel}")
else:
if not dataset_jsonl:
# Fall back to input_dataset (dataset_path)
dataset_jsonl = _extract_jsonl_from_input_dataset(tests[0].file_path, func_name)
dataset_jsonl = _extract_jsonl_from_input_dataset(test_file_for_infer, func_for_infer)
if dataset_jsonl:
# Display relative path for readability
try:
rel = os.path.relpath(dataset_jsonl, project_root)
except Exception:
rel = dataset_jsonl
print(f"✓ Using JSONL from input_dataset: {rel}")
if not dataset_jsonl:
# Last resort: attempt to detect and run a dataset builder in the test's directory
metric_dir = os.path.dirname(test_file_for_infer)
builder_spec = detect_dataset_builder(metric_dir)
if builder_spec:
try:
tmp_jsonl, count = materialize_dataset_via_builder(builder_spec)
dataset_jsonl = tmp_jsonl
print(f"✓ Materialized {count} rows via dataset builder: {builder_spec}")
except Exception as e:
print(f"Warning: dataset builder failed: {e}")
if not dataset_jsonl:
print(
"Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader or input_dataset is used in your single discovered test."
Expand Down
Loading
Loading