
Commit 87a99e0

Create RFT on Fireworks directly (#277)
* Create RFT on Fireworks directly
* kick off RFT in one command
1 parent 8ad7d26 commit 87a99e0
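
This makes it possible to kick off RFT in one command. A hypothetical invocation (using flags defined by the new `create rft` subparser below; the ids and paths are illustrative, not from the commit):

    eval-protocol create rft --evaluator-id my-evaluator --dataset-jsonl rows.jsonl --dry-run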

File tree

4 files changed: 561 additions, 0 deletions

eval_protocol/cli.py

Lines changed: 71 additions & 0 deletions
@@ -356,6 +356,70 @@ def parse_args(args=None):
         help="Non-interactive: upload all discovered evaluation tests",
     )
 
+    # Create command group
+    create_parser = subparsers.add_parser(
+        "create",
+        help="Resource creation commands",
+    )
+    create_subparsers = create_parser.add_subparsers(dest="create_command")
+    rft_parser = create_subparsers.add_parser(
+        "rft",
+        help="Create a Reinforcement Fine-tuning Job on Fireworks",
+    )
+    rft_parser.add_argument(
+        "--evaluator-id",
+        help="Evaluator ID used during upload; if omitted, derive from local traces or a single discovered test",
+    )
+    # Dataset options
+    rft_parser.add_argument(
+        "--dataset-id",
+        help="Use existing Fireworks dataset id (skip local materialization)",
+    )
+    rft_parser.add_argument(
+        "--dataset-jsonl",
+        help="Path to JSONL to upload as a new Fireworks dataset",
+    )
+    rft_parser.add_argument(
+        "--dataset-builder",
+        help="Explicit dataset builder spec (module::function or path::function)",
+    )
+    rft_parser.add_argument(
+        "--dataset-display-name",
+        help="Display name for dataset on Fireworks (defaults to dataset id)",
+    )
+    # Training config and evaluator/job settings
+    rft_parser.add_argument("--base-model", help="Base model resource id")
+    rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from")
+    rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)")
+    rft_parser.add_argument("--epochs", type=int)
+    rft_parser.add_argument("--batch-size", type=int)
+    rft_parser.add_argument("--learning-rate", type=float)
+    rft_parser.add_argument("--max-context-length", type=int)
+    rft_parser.add_argument("--lora-rank", type=int)
+    rft_parser.add_argument("--accelerator-count", type=int)
+    rft_parser.add_argument("--region", help="Fireworks region enum value")
+    rft_parser.add_argument("--display-name", help="RFT job display name")
+    rft_parser.add_argument("--evaluation-dataset", help="Optional separate eval dataset id")
+    rft_parser.add_argument("--eval-auto-carveout", dest="eval_auto_carveout", action="store_true", default=True)
+    rft_parser.add_argument("--no-eval-auto-carveout", dest="eval_auto_carveout", action="store_false")
+    # Inference params
+    rft_parser.add_argument("--temperature", type=float)
+    rft_parser.add_argument("--top-p", type=float)
+    rft_parser.add_argument("--top-k", type=int)
+    rft_parser.add_argument("--max-tokens", type=int)
+    rft_parser.add_argument("--n", type=int)
+    rft_parser.add_argument("--inference-extra-body", help="JSON string for extra inference params")
+    # Wandb
+    rft_parser.add_argument("--wandb-enabled", action="store_true")
+    rft_parser.add_argument("--wandb-project")
+    rft_parser.add_argument("--wandb-entity")
+    rft_parser.add_argument("--wandb-run-id")
+    rft_parser.add_argument("--wandb-api-key")
+    # Misc
+    rft_parser.add_argument("--rft-job-id", help="Specify an explicit RFT job id")
+    rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
+    rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending")
+
     # Run command (for Hydra-based evaluations)
     # This subparser intentionally defines no arguments itself.
     # All arguments after 'run' will be passed to Hydra by parse_known_args.
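
The carveout toggle above uses argparse's paired-flag idiom: both options write to the same dest, so the value defaults to True and the --no- variant flips it off. A minimal self-contained sketch of the same pattern:

import argparse

parser = argparse.ArgumentParser()
# Both flags share one dest: the default is True, and the --no- flag stores False.
parser.add_argument("--eval-auto-carveout", dest="eval_auto_carveout", action="store_true", default=True)
parser.add_argument("--no-eval-auto-carveout", dest="eval_auto_carveout", action="store_false")

print(parser.parse_args([]).eval_auto_carveout)                           # True
print(parser.parse_args(["--no-eval-auto-carveout"]).eval_auto_carveout)  # False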
@@ -481,6 +545,13 @@ def _extract_flag_value(argv_list, flag_name):
         from .cli_commands.upload import upload_command
 
         return upload_command(args)
+    elif args.command == "create":
+        if args.create_command == "rft":
+            from .cli_commands.create_rft import create_rft_command
+
+            return create_rft_command(args)
+        print("Error: missing subcommand for 'create'. Try: eval-protocol create rft")
+        return 1
     elif args.command == "run":
         # For the 'run' command, Hydra takes over argument parsing.
 
eval_protocol/cli_commands/create_rft.py

Lines changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
import json
import os
import sys
from typing import Any, Dict, Optional

from ..auth import (
    get_fireworks_account_id,
    get_fireworks_api_base,
    get_fireworks_api_key,
    verify_api_key_and_get_account_id,
)
from ..fireworks_rft import (
    _map_api_host_to_app_host,
    build_default_dataset_id,
    build_default_output_model,
    create_dataset_from_jsonl,
    create_reinforcement_fine_tuning_job,
    detect_dataset_builder,
    load_evaluator_trace,
    materialize_dataset_via_builder,
)
from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source


def _ensure_account_id() -> Optional[str]:
    account_id = get_fireworks_account_id()
    api_key = get_fireworks_api_key()
    if not account_id and api_key:
        resolved = verify_api_key_and_get_account_id(api_key=api_key, api_base=get_fireworks_api_base())
        if resolved:
            os.environ["FIREWORKS_ACCOUNT_ID"] = resolved
            return resolved
    return account_id


def _extract_terminal_segment(resource_name: str) -> str:
    """Return the last path segment if a fully-qualified resource name is provided."""
    try:
        return resource_name.strip("/").split("/")[-1]
    except Exception:
        return resource_name


def _print_links(evaluator_id: str, dataset_id: str, job_name: Optional[str]) -> None:
    api_base = get_fireworks_api_base()
    app_base = _map_api_host_to_app_host(api_base)
    print("\n📊 Dashboard Links:")
    evaluator_slug = _extract_terminal_segment(evaluator_id)
    print(f"  Evaluator: {app_base}/dashboard/evaluators/{evaluator_slug}")
    if dataset_id:
        print(f"  Dataset: {app_base}/dashboard/datasets/{dataset_id}")
    if job_name:
        # job_name is likely of the form accounts/{account}/reinforcementFineTuningJobs/{id}
        try:
            job_id = job_name.strip().split("/")[-1]
            print(f"  RFT Job: {app_base}/dashboard/fine-tuning/reinforcement/{job_id}")
        except Exception:
            pass


def _auto_select_evaluator_id(cwd: str) -> Optional[str]:
    # Try local traces
    traces_dir = os.path.join(cwd, ".eval_protocol", "evaluators")
    if os.path.isdir(traces_dir):
        candidates = [f[:-5] for f in os.listdir(traces_dir) if f.endswith(".json")]
        if len(candidates) == 1:
            return candidates[0]
    # Fall back to discovering a single evaluation_test
    tests = _discover_tests(cwd)
    if len(tests) == 1:
        qualname, source_file_path = tests[0].qualname, tests[0].file_path
        test_func_name = qualname.split(".")[-1]
        source_file_name = os.path.splitext(os.path.basename(source_file_path))[0]
        evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{test_func_name}")
        return evaluator_id
    return None


def create_rft_command(args) -> int:
    evaluator_id: Optional[str] = getattr(args, "evaluator_id", None)
    non_interactive: bool = bool(getattr(args, "yes", False))
    dry_run: bool = bool(getattr(args, "dry_run", False))

    api_key = get_fireworks_api_key()
    if not api_key:
        print("Error: FIREWORKS_API_KEY not set.")
        return 1

    account_id = _ensure_account_id()
    if not account_id:
        print("Error: FIREWORKS_ACCOUNT_ID not set and could not be resolved.")
        return 1

    api_base = get_fireworks_api_base()

    # Resolve evaluator id if omitted
    project_root = os.getcwd()
    if not evaluator_id:
        evaluator_id = _auto_select_evaluator_id(project_root)
    if not evaluator_id:
        print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
        return 1

    # Resolve evaluator resource name via local trace
    # trace = load_evaluator_trace(project_root, evaluator_id)
    # if not trace or not isinstance(trace, dict):
    #     print(
    #         "Error: Evaluator trace not found. Run 'eval-protocol upload' first or provide --dataset-id/--dataset-jsonl and --evaluator-id."
    #     )
    #     return 1
    # evaluator_resource_name = trace.get("evaluator_resource_name") or trace.get("name") or evaluator_id
    evaluator_resource_name = evaluator_id

    # Determine dataset id and materialization path
    dataset_id = getattr(args, "dataset_id", None)
    dataset_jsonl = getattr(args, "dataset_jsonl", None)
    dataset_display_name = getattr(args, "dataset_display_name", None)
    dataset_builder = getattr(args, "dataset_builder", None)

    if not dataset_id:
        # Try builder from args, else from trace detection
        # TODO: build dataset from traces directly
        # builder_spec = dataset_builder or trace.get("dataset_builder")
        # if not builder_spec:
        #     # Attempt detect from metric_dir
        #     metric_dir = trace.get("metric_dir")
        #     if metric_dir:
        #         builder_spec = detect_dataset_builder(metric_dir)
        # if not builder_spec:
        #     print(
        #         "Error: Could not determine dataset. Provide --dataset-id, --dataset-jsonl, or --dataset-builder."
        #     )
        #     return 1
        # try:
        #     dataset_jsonl, count = materialize_dataset_via_builder(builder_spec)
        #     print(f"✓ Materialized dataset via builder ({builder_spec}): {count} rows → {dataset_jsonl}")
        # except Exception as e:
        #     print(f"Error: dataset builder failed: {e}")
        #     return 1

        if not dataset_jsonl:
            print("Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl.")
            return 1

        inferred_dataset_id = build_default_dataset_id(evaluator_id)
        if dry_run:
            print("--dry-run: would create dataset and upload JSONL")
            dataset_id = inferred_dataset_id
        else:
            try:
                dataset_id, _ = create_dataset_from_jsonl(
                    account_id=account_id,
                    api_key=api_key,
                    api_base=api_base,
                    dataset_id=inferred_dataset_id,
                    display_name=dataset_display_name or inferred_dataset_id,
                    jsonl_path=dataset_jsonl,
                )
                print(f"✓ Created and uploaded dataset: {dataset_id}")
            except Exception as e:
                print(f"Error creating/uploading dataset: {e}")
                return 1

    # Build training config/body
    training_config: Dict[str, Any] = {}
    if getattr(args, "base_model", None):
        training_config["baseModel"] = args.base_model
    if getattr(args, "warm_start_from", None):
        training_config["warmStartFrom"] = args.warm_start_from
    if "baseModel" not in training_config and "warmStartFrom" not in training_config:
        # Provide a conservative default if neither is set
        training_config["baseModel"] = "accounts/fireworks/models/llama-v3p1-8b-instruct"

    # Optional hyperparameters
    for key, arg_name in [
        ("epochs", "epochs"),
        ("batchSize", "batch_size"),
        ("learningRate", "learning_rate"),
        ("maxContextLength", "max_context_length"),
        ("loraRank", "lora_rank"),
        ("acceleratorCount", "accelerator_count"),
        ("region", "region"),
    ]:
        val = getattr(args, arg_name, None)
        if val is not None:
            training_config[key] = val

    inference_params: Dict[str, Any] = {}
    for key, arg_name in [
        ("temperature", "temperature"),
        ("topP", "top_p"),
        ("topK", "top_k"),
        ("maxTokens", "max_tokens"),
        ("n", "n"),
    ]:
        val = getattr(args, arg_name, None)
        if val is not None:
            inference_params[key] = val
    if getattr(args, "inference_extra_body", None):
        inference_params["extraBody"] = args.inference_extra_body

    wandb_config: Optional[Dict[str, Any]] = None
    if getattr(args, "wandb_enabled", False):
        wandb_config = {
            "enabled": True,
            "apiKey": getattr(args, "wandb_api_key", None),
            "project": getattr(args, "wandb_project", None),
            "entity": getattr(args, "wandb_entity", None),
            "runId": getattr(args, "wandb_run_id", None),
        }

    body: Dict[str, Any] = {
        # "displayName": getattr(args, "display_name", None) or f"{evaluator_id}-rft",
        "dataset": f"accounts/{account_id}/datasets/{dataset_id}",
        "evaluator": evaluator_resource_name,
        "evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)),
        "trainingConfig": training_config,
        "inferenceParameters": inference_params or None,
        "wandbConfig": wandb_config,
        "outputStats": None,
        "outputMetrics": None,
        "mcpServer": None,
    }
    # Debug: echo the body as constructed so far (before the dataset/model overrides below)
    print("Show body:")
    print(json.dumps(body, indent=2))
    if getattr(args, "evaluation_dataset", None):
        body["evaluationDataset"] = args.evaluation_dataset
    if getattr(args, "output_model", None):
        body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{args.output_model}"
    else:
        body.setdefault("trainingConfig", {})["outputModel"] = build_default_output_model(evaluator_id)

    # Clean None fields to avoid noisy payloads
    body = {k: v for k, v in body.items() if v is not None}

    if dry_run:
        print("--dry-run: would create RFT job with body:")
        print(json.dumps(body, indent=2))
        _print_links(evaluator_id, dataset_id, None)
        return 0

    try:
        result = create_reinforcement_fine_tuning_job(
            account_id=account_id, api_key=api_key, api_base=api_base, body=body
        )
        job_name = result.get("name") if isinstance(result, dict) else None
        print("\n✅ Created Reinforcement Fine-tuning Job")
        if job_name:
            print(f"  name: {job_name}")
        _print_links(evaluator_id, dataset_id, job_name)
        return 0
    except Exception as e:
        print(f"Error creating RFT job: {e}")
        return 1
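
For reference, a sketch of the request body create_rft_command assembles (the field names come from the code above; every value here is illustrative, not taken from the commit):

# Illustrative only: mirrors the body built above, after None-valued fields are stripped.
body = {
    "dataset": "accounts/my-account/datasets/my-evaluator-rft-dataset",   # hypothetical ids
    "evaluator": "my-evaluator",
    "evalAutoCarveout": True,
    "trainingConfig": {
        "baseModel": "accounts/fireworks/models/llama-v3p1-8b-instruct",  # the code's fallback default
        "epochs": 1,                                                      # present only if --epochs was passed
        "outputModel": "accounts/my-account/models/my-evaluator-model",   # hypothetical; see build_default_output_model
    },
    "inferenceParameters": {"temperature": 1.0, "n": 4},                  # present only if the flags were passed
}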

eval_protocol/cli_commands/upload.py

Lines changed: 18 additions & 0 deletions
@@ -21,6 +21,7 @@
 from eval_protocol.platform_api import create_or_update_fireworks_secret
 
 from eval_protocol.evaluation import create_evaluation
+from eval_protocol.fireworks_rft import save_evaluator_trace, detect_dataset_builder
 
 
 @dataclass
@@ -666,6 +667,23 @@ def upload_command(args: argparse.Namespace) -> int:
         )
         name = result.get("name", evaluator_id) if isinstance(result, dict) else evaluator_id
 
+        # Persist local evaluator trace for later `create rft`
+        try:
+            metric_dir = os.path.dirname(source_file_path) if source_file_path else root
+            builder_spec = detect_dataset_builder(metric_dir) or None
+            trace_payload = {
+                "evaluator_id": evaluator_id,
+                "evaluator_resource_name": name,
+                "entry_point": entry_point,
+                "metric_dir": metric_dir,
+                "project_root": root,
+                "dataset_builder": builder_spec,
+            }
+            save_evaluator_trace(project_root=root, evaluator_id=evaluator_id, trace=trace_payload)
+        except Exception:
+            # Non-fatal; continue
+            pass
+
         # Print success message with Fireworks dashboard link
         print(f"\n✅ Successfully uploaded evaluator: {evaluator_id}")
         print("📊 View in Fireworks Dashboard:")
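
The trace persisted here is what `create rft` later reads to auto-select an evaluator (it scans .eval_protocol/evaluators/*.json in the project root). A hypothetical example of the payload (field names from trace_payload above; the values and the exact on-disk layout chosen by save_evaluator_trace are assumptions):

# Hypothetical contents of .eval_protocol/evaluators/math-eval-test-accuracy.json
trace_payload = {
    "evaluator_id": "math-eval-test-accuracy",
    "evaluator_resource_name": "accounts/my-account/evaluators/math-eval-test-accuracy",
    "entry_point": "tests/test_math_eval.py::test_accuracy",  # shape assumed
    "metric_dir": "/home/me/project/tests",
    "project_root": "/home/me/project",
    "dataset_builder": None,  # detect_dataset_builder found no builder
}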
