Skip to content

Commit 97b5dc5

Browse files
committed
kick off RFT in one command
1 parent fda1ba0 commit 97b5dc5

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

eval_protocol/cli.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,7 @@ def parse_args(args=None):
358358
)
359359
rft_parser.add_argument(
360360
"--evaluator-id",
361-
required=True,
362-
help="Evaluator ID used during upload; resolves evaluator resource via local trace",
361+
help="Evaluator ID used during upload; if omitted, derive from local traces or a single discovered test",
363362
)
364363
# Dataset options
365364
rft_parser.add_argument(

eval_protocol/cli_commands/create_rft.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
load_evaluator_trace,
2020
materialize_dataset_via_builder,
2121
)
22+
from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source
2223

2324

2425
def _ensure_account_id() -> Optional[str]:
@@ -48,8 +49,26 @@ def _print_links(evaluator_id: str, dataset_id: str, job_name: Optional[str]) ->
4849
pass
4950

5051

52+
def _auto_select_evaluator_id(cwd: str) -> Optional[str]:
53+
# Try local traces
54+
traces_dir = os.path.join(cwd, ".eval_protocol", "evaluators")
55+
if os.path.isdir(traces_dir):
56+
candidates = [f[:-5] for f in os.listdir(traces_dir) if f.endswith(".json")]
57+
if len(candidates) == 1:
58+
return candidates[0]
59+
# Fall back to discovering a single evaluation_test
60+
tests = _discover_tests(cwd)
61+
if len(tests) == 1:
62+
qualname, source_file_path = tests[0].qualname, tests[0].file_path
63+
test_func_name = qualname.split(".")[-1]
64+
source_file_name = os.path.splitext(os.path.basename(source_file_path))[0]
65+
evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{test_func_name}")
66+
return evaluator_id
67+
return None
68+
69+
5170
def create_rft_command(args) -> int:
52-
evaluator_id: str = getattr(args, "evaluator_id")
71+
evaluator_id: Optional[str] = getattr(args, "evaluator_id", None)
5372
non_interactive: bool = bool(getattr(args, "yes", False))
5473
dry_run: bool = bool(getattr(args, "dry_run", False))
5574

@@ -65,8 +84,15 @@ def create_rft_command(args) -> int:
6584

6685
api_base = get_fireworks_api_base()
6786

68-
# Resolve evaluator resource name via local trace
87+
# Resolve evaluator id if omitted
6988
project_root = os.getcwd()
89+
if not evaluator_id:
90+
evaluator_id = _auto_select_evaluator_id(project_root)
91+
if not evaluator_id:
92+
print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
93+
return 1
94+
95+
# Resolve evaluator resource name via local trace
7096
trace = load_evaluator_trace(project_root, evaluator_id)
7197
if not trace or not isinstance(trace, dict):
7298
print(

0 commit comments

Comments
 (0)