Skip to content

Commit 4d4db1f

Browse files
committed
kick off RFT in one command
1 parent 763116e commit 4d4db1f

File tree

2 files changed

+73
-33
lines changed

2 files changed

+73
-33
lines changed

eval_protocol/cli.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,7 @@ def parse_args(args=None):
368368
)
369369
rft_parser.add_argument(
370370
"--evaluator-id",
371-
required=True,
372-
help="Evaluator ID used during upload; resolves evaluator resource via local trace",
371+
help="Evaluator ID used during upload; if omitted, derive from local traces or a single discovered test",
373372
)
374373
# Dataset options
375374
rft_parser.add_argument(

eval_protocol/cli_commands/create_rft.py

Lines changed: 72 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
load_evaluator_trace,
2020
materialize_dataset_via_builder,
2121
)
22+
from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source
2223

2324

2425
def _ensure_account_id() -> Optional[str]:
@@ -32,24 +33,51 @@ def _ensure_account_id() -> Optional[str]:
3233
return account_id
3334

3435

36+
def _extract_terminal_segment(resource_name: str) -> str:
37+
"""Return the last path segment if a fully-qualified resource name is provided."""
38+
try:
39+
return resource_name.strip("/").split("/")[-1]
40+
except Exception:
41+
return resource_name
42+
43+
3544
def _print_links(evaluator_id: str, dataset_id: str, job_name: Optional[str]) -> None:
3645
api_base = get_fireworks_api_base()
3746
app_base = _map_api_host_to_app_host(api_base)
3847
print("\n📊 Dashboard Links:")
39-
print(f" Evaluator: {app_base}/dashboard/evaluators/{evaluator_id}")
48+
evaluator_slug = _extract_terminal_segment(evaluator_id)
49+
print(f" Evaluator: {app_base}/dashboard/evaluators/{evaluator_slug}")
4050
if dataset_id:
4151
print(f" Dataset: {app_base}/dashboard/datasets/{dataset_id}")
4252
if job_name:
4353
# job_name likely like accounts/{account}/reinforcementFineTuningJobs/{id}
4454
try:
4555
job_id = job_name.strip().split("/")[-1]
46-
print(f" RFT Job: {app_base}/dashboard/rft/{job_id}")
56+
print(f" RFT Job: {app_base}/dashboard/fine-tuning/reinforcement/{job_id}")
4757
except Exception:
4858
pass
4959

5060

61+
def _auto_select_evaluator_id(cwd: str) -> Optional[str]:
62+
# Try local traces
63+
traces_dir = os.path.join(cwd, ".eval_protocol", "evaluators")
64+
if os.path.isdir(traces_dir):
65+
candidates = [f[:-5] for f in os.listdir(traces_dir) if f.endswith(".json")]
66+
if len(candidates) == 1:
67+
return candidates[0]
68+
# Fall back to discovering a single evaluation_test
69+
tests = _discover_tests(cwd)
70+
if len(tests) == 1:
71+
qualname, source_file_path = tests[0].qualname, tests[0].file_path
72+
test_func_name = qualname.split(".")[-1]
73+
source_file_name = os.path.splitext(os.path.basename(source_file_path))[0]
74+
evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{test_func_name}")
75+
return evaluator_id
76+
return None
77+
78+
5179
def create_rft_command(args) -> int:
52-
evaluator_id: str = getattr(args, "evaluator_id")
80+
evaluator_id: Optional[str] = getattr(args, "evaluator_id", None)
5381
non_interactive: bool = bool(getattr(args, "yes", False))
5482
dry_run: bool = bool(getattr(args, "dry_run", False))
5583

@@ -65,15 +93,23 @@ def create_rft_command(args) -> int:
6593

6694
api_base = get_fireworks_api_base()
6795

68-
# Resolve evaluator resource name via local trace
96+
# Resolve evaluator id if omitted
6997
project_root = os.getcwd()
70-
trace = load_evaluator_trace(project_root, evaluator_id)
71-
if not trace or not isinstance(trace, dict):
72-
print(
73-
"Error: Evaluator trace not found. Run 'eval-protocol upload' first or provide --dataset-id/--dataset-jsonl and --evaluator-id."
74-
)
75-
return 1
76-
evaluator_resource_name = trace.get("evaluator_resource_name") or trace.get("name") or evaluator_id
98+
if not evaluator_id:
99+
evaluator_id = _auto_select_evaluator_id(project_root)
100+
if not evaluator_id:
101+
print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
102+
return 1
103+
104+
# Resolve evaluator resource name via local trace
105+
# trace = load_evaluator_trace(project_root, evaluator_id)
106+
# if not trace or not isinstance(trace, dict):
107+
# print(
108+
# "Error: Evaluator trace not found. Run 'eval-protocol upload' first or provide --dataset-id/--dataset-jsonl and --evaluator-id."
109+
# )
110+
# return 1
111+
# evaluator_resource_name = trace.get("evaluator_resource_name") or trace.get("name") or evaluator_id
112+
evaluator_resource_name = evaluator_id
77113

78114
# Determine dataset id and materialization path
79115
dataset_id = getattr(args, "dataset_id", None)
@@ -83,24 +119,27 @@ def create_rft_command(args) -> int:
83119

84120
if not dataset_id:
85121
if not dataset_jsonl:
122+
print("Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl.")
123+
return 1
86124
# Try builder from args, else from trace detection
87-
builder_spec = dataset_builder or trace.get("dataset_builder")
88-
if not builder_spec:
89-
# Attempt detect from metric_dir
90-
metric_dir = trace.get("metric_dir")
91-
if metric_dir:
92-
builder_spec = detect_dataset_builder(metric_dir)
93-
if not builder_spec:
94-
print(
95-
"Error: Could not determine dataset. Provide --dataset-id, --dataset-jsonl, or --dataset-builder."
96-
)
97-
return 1
98-
try:
99-
dataset_jsonl, count = materialize_dataset_via_builder(builder_spec)
100-
print(f"✓ Materialized dataset via builder ({builder_spec}): {count} rows → {dataset_jsonl}")
101-
except Exception as e:
102-
print(f"Error: dataset builder failed: {e}")
103-
return 1
125+
# TODO: build dataset from traces directly
126+
# builder_spec = dataset_builder or trace.get("dataset_builder")
127+
# if not builder_spec:
128+
# # Attempt detect from metric_dir
129+
# metric_dir = trace.get("metric_dir")
130+
# if metric_dir:
131+
# builder_spec = detect_dataset_builder(metric_dir)
132+
# if not builder_spec:
133+
# print(
134+
# "Error: Could not determine dataset. Provide --dataset-id, --dataset-jsonl, or --dataset-builder."
135+
# )
136+
# return 1
137+
# try:
138+
# dataset_jsonl, count = materialize_dataset_via_builder(builder_spec)
139+
# print(f"✓ Materialized dataset via builder ({builder_spec}): {count} rows → {dataset_jsonl}")
140+
# except Exception as e:
141+
# print(f"Error: dataset builder failed: {e}")
142+
# return 1
104143

105144
inferred_dataset_id = build_default_dataset_id(evaluator_id)
106145
if dry_run:
@@ -170,8 +209,8 @@ def create_rft_command(args) -> int:
170209
}
171210

172211
body: Dict[str, Any] = {
173-
"displayName": getattr(args, "display_name", None) or f"{evaluator_id}-rft",
174-
"dataset": dataset_id,
212+
# "displayName": getattr(args, "display_name", None) or f"{evaluator_id}-rft",
213+
"dataset": f"accounts/{account_id}/datasets/{dataset_id}",
175214
"evaluator": evaluator_resource_name,
176215
"evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)),
177216
"trainingConfig": training_config,
@@ -181,10 +220,12 @@ def create_rft_command(args) -> int:
181220
"outputMetrics": None,
182221
"mcpServer": None,
183222
}
223+
print("Show body:")
224+
print(json.dumps(body, indent=2))
184225
if getattr(args, "evaluation_dataset", None):
185226
body["evaluationDataset"] = args.evaluation_dataset
186227
if getattr(args, "output_model", None):
187-
body.setdefault("trainingConfig", {})["outputModel"] = args.output_model
228+
body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{args.output_model}"
188229
else:
189230
body.setdefault("trainingConfig", {})["outputModel"] = build_default_output_model(evaluator_id)
190231

0 commit comments

Comments (0)