Skip to content

Commit fda1ba0

Browse files
committed
Create RFT on Fireworks directly
1 parent a3baa0a commit fda1ba0

File tree

4 files changed

+520
-0
lines changed

4 files changed

+520
-0
lines changed

eval_protocol/cli.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,71 @@ def parse_args(args=None):
346346
help="Non-interactive: upload all discovered evaluation tests",
347347
)
348348

349+
# Create command group
350+
create_parser = subparsers.add_parser(
351+
"create",
352+
help="Resource creation commands",
353+
)
354+
create_subparsers = create_parser.add_subparsers(dest="create_command")
355+
rft_parser = create_subparsers.add_parser(
356+
"rft",
357+
help="Create a Reinforcement Fine-tuning Job on Fireworks",
358+
)
359+
rft_parser.add_argument(
360+
"--evaluator-id",
361+
required=True,
362+
help="Evaluator ID used during upload; resolves evaluator resource via local trace",
363+
)
364+
# Dataset options
365+
rft_parser.add_argument(
366+
"--dataset-id",
367+
help="Use existing Fireworks dataset id (skip local materialization)",
368+
)
369+
rft_parser.add_argument(
370+
"--dataset-jsonl",
371+
help="Path to JSONL to upload as a new Fireworks dataset",
372+
)
373+
rft_parser.add_argument(
374+
"--dataset-builder",
375+
help="Explicit dataset builder spec (module::function or path::function)",
376+
)
377+
rft_parser.add_argument(
378+
"--dataset-display-name",
379+
help="Display name for dataset on Fireworks (defaults to dataset id)",
380+
)
381+
# Training config and evaluator/job settings
382+
rft_parser.add_argument("--base-model", help="Base model resource id")
383+
rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from")
384+
rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)")
385+
rft_parser.add_argument("--epochs", type=int)
386+
rft_parser.add_argument("--batch-size", type=int)
387+
rft_parser.add_argument("--learning-rate", type=float)
388+
rft_parser.add_argument("--max-context-length", type=int)
389+
rft_parser.add_argument("--lora-rank", type=int)
390+
rft_parser.add_argument("--accelerator-count", type=int)
391+
rft_parser.add_argument("--region", help="Fireworks region enum value")
392+
rft_parser.add_argument("--display-name", help="RFT job display name")
393+
rft_parser.add_argument("--evaluation-dataset", help="Optional separate eval dataset id")
394+
rft_parser.add_argument("--eval-auto-carveout", dest="eval_auto_carveout", action="store_true", default=True)
395+
rft_parser.add_argument("--no-eval-auto-carveout", dest="eval_auto_carveout", action="store_false")
396+
# Inference params
397+
rft_parser.add_argument("--temperature", type=float)
398+
rft_parser.add_argument("--top-p", type=float)
399+
rft_parser.add_argument("--top-k", type=int)
400+
rft_parser.add_argument("--max-tokens", type=int)
401+
rft_parser.add_argument("--n", type=int)
402+
rft_parser.add_argument("--inference-extra-body", help="JSON string for extra inference params")
403+
# Wandb
404+
rft_parser.add_argument("--wandb-enabled", action="store_true")
405+
rft_parser.add_argument("--wandb-project")
406+
rft_parser.add_argument("--wandb-entity")
407+
rft_parser.add_argument("--wandb-run-id")
408+
rft_parser.add_argument("--wandb-api-key")
409+
# Misc
410+
rft_parser.add_argument("--rft-job-id", help="Specify an explicit RFT job id")
411+
rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
412+
rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending")
413+
349414
# Run command (for Hydra-based evaluations)
350415
# This subparser intentionally defines no arguments itself.
351416
# All arguments after 'run' will be passed to Hydra by parse_known_args.
@@ -471,6 +536,13 @@ def _extract_flag_value(argv_list, flag_name):
471536
from .cli_commands.upload import upload_command
472537

473538
return upload_command(args)
539+
elif args.command == "create":
540+
if args.create_command == "rft":
541+
from .cli_commands.create_rft import create_rft_command
542+
543+
return create_rft_command(args)
544+
print("Error: missing subcommand for 'create'. Try: eval-protocol create rft")
545+
return 1
474546
elif args.command == "run":
475547
# For the 'run' command, Hydra takes over argument parsing.
476548

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
import json
2+
import os
3+
import sys
4+
from typing import Any, Dict, Optional
5+
6+
from ..auth import (
7+
get_fireworks_account_id,
8+
get_fireworks_api_base,
9+
get_fireworks_api_key,
10+
verify_api_key_and_get_account_id,
11+
)
12+
from ..fireworks_rft import (
13+
_map_api_host_to_app_host,
14+
build_default_dataset_id,
15+
build_default_output_model,
16+
create_dataset_from_jsonl,
17+
create_reinforcement_fine_tuning_job,
18+
detect_dataset_builder,
19+
load_evaluator_trace,
20+
materialize_dataset_via_builder,
21+
)
22+
23+
24+
def _ensure_account_id() -> Optional[str]:
25+
account_id = get_fireworks_account_id()
26+
api_key = get_fireworks_api_key()
27+
if not account_id and api_key:
28+
resolved = verify_api_key_and_get_account_id(api_key=api_key, api_base=get_fireworks_api_base())
29+
if resolved:
30+
os.environ["FIREWORKS_ACCOUNT_ID"] = resolved
31+
return resolved
32+
return account_id
33+
34+
35+
def _print_links(evaluator_id: str, dataset_id: str, job_name: Optional[str]) -> None:
36+
api_base = get_fireworks_api_base()
37+
app_base = _map_api_host_to_app_host(api_base)
38+
print("\n📊 Dashboard Links:")
39+
print(f" Evaluator: {app_base}/dashboard/evaluators/{evaluator_id}")
40+
if dataset_id:
41+
print(f" Dataset: {app_base}/dashboard/datasets/{dataset_id}")
42+
if job_name:
43+
# job_name likely like accounts/{account}/reinforcementFineTuningJobs/{id}
44+
try:
45+
job_id = job_name.strip().split("/")[-1]
46+
print(f" RFT Job: {app_base}/dashboard/rft/{job_id}")
47+
except Exception:
48+
pass
49+
50+
51+
def create_rft_command(args) -> int:
52+
evaluator_id: str = getattr(args, "evaluator_id")
53+
non_interactive: bool = bool(getattr(args, "yes", False))
54+
dry_run: bool = bool(getattr(args, "dry_run", False))
55+
56+
api_key = get_fireworks_api_key()
57+
if not api_key:
58+
print("Error: FIREWORKS_API_KEY not set.")
59+
return 1
60+
61+
account_id = _ensure_account_id()
62+
if not account_id:
63+
print("Error: FIREWORKS_ACCOUNT_ID not set and could not be resolved.")
64+
return 1
65+
66+
api_base = get_fireworks_api_base()
67+
68+
# Resolve evaluator resource name via local trace
69+
project_root = os.getcwd()
70+
trace = load_evaluator_trace(project_root, evaluator_id)
71+
if not trace or not isinstance(trace, dict):
72+
print(
73+
"Error: Evaluator trace not found. Run 'eval-protocol upload' first or provide --dataset-id/--dataset-jsonl and --evaluator-id."
74+
)
75+
return 1
76+
evaluator_resource_name = trace.get("evaluator_resource_name") or trace.get("name") or evaluator_id
77+
78+
# Determine dataset id and materialization path
79+
dataset_id = getattr(args, "dataset_id", None)
80+
dataset_jsonl = getattr(args, "dataset_jsonl", None)
81+
dataset_display_name = getattr(args, "dataset_display_name", None)
82+
dataset_builder = getattr(args, "dataset_builder", None)
83+
84+
if not dataset_id:
85+
if not dataset_jsonl:
86+
# Try builder from args, else from trace detection
87+
builder_spec = dataset_builder or trace.get("dataset_builder")
88+
if not builder_spec:
89+
# Attempt detect from metric_dir
90+
metric_dir = trace.get("metric_dir")
91+
if metric_dir:
92+
builder_spec = detect_dataset_builder(metric_dir)
93+
if not builder_spec:
94+
print(
95+
"Error: Could not determine dataset. Provide --dataset-id, --dataset-jsonl, or --dataset-builder."
96+
)
97+
return 1
98+
try:
99+
dataset_jsonl, count = materialize_dataset_via_builder(builder_spec)
100+
print(f"✓ Materialized dataset via builder ({builder_spec}): {count} rows → {dataset_jsonl}")
101+
except Exception as e:
102+
print(f"Error: dataset builder failed: {e}")
103+
return 1
104+
105+
inferred_dataset_id = build_default_dataset_id(evaluator_id)
106+
if dry_run:
107+
print("--dry-run: would create dataset and upload JSONL")
108+
dataset_id = inferred_dataset_id
109+
else:
110+
try:
111+
dataset_id, _ = create_dataset_from_jsonl(
112+
account_id=account_id,
113+
api_key=api_key,
114+
api_base=api_base,
115+
dataset_id=inferred_dataset_id,
116+
display_name=dataset_display_name or inferred_dataset_id,
117+
jsonl_path=dataset_jsonl,
118+
)
119+
print(f"✓ Created and uploaded dataset: {dataset_id}")
120+
except Exception as e:
121+
print(f"Error creating/uploading dataset: {e}")
122+
return 1
123+
124+
# Build training config/body
125+
training_config: Dict[str, Any] = {}
126+
if getattr(args, "base_model", None):
127+
training_config["baseModel"] = args.base_model
128+
if getattr(args, "warm_start_from", None):
129+
training_config["warmStartFrom"] = args.warm_start_from
130+
if "baseModel" not in training_config and "warmStartFrom" not in training_config:
131+
# Provide a conservative default if neither is set
132+
training_config["baseModel"] = "accounts/fireworks/models/llama-v3p1-8b-instruct"
133+
134+
# Optional hyperparameters
135+
for key, arg_name in [
136+
("epochs", "epochs"),
137+
("batchSize", "batch_size"),
138+
("learningRate", "learning_rate"),
139+
("maxContextLength", "max_context_length"),
140+
("loraRank", "lora_rank"),
141+
("acceleratorCount", "accelerator_count"),
142+
("region", "region"),
143+
]:
144+
val = getattr(args, arg_name, None)
145+
if val is not None:
146+
training_config[key] = val
147+
148+
inference_params: Dict[str, Any] = {}
149+
for key, arg_name in [
150+
("temperature", "temperature"),
151+
("topP", "top_p"),
152+
("topK", "top_k"),
153+
("maxTokens", "max_tokens"),
154+
("n", "n"),
155+
]:
156+
val = getattr(args, arg_name, None)
157+
if val is not None:
158+
inference_params[key] = val
159+
if getattr(args, "inference_extra_body", None):
160+
inference_params["extraBody"] = args.inference_extra_body
161+
162+
wandb_config: Optional[Dict[str, Any]] = None
163+
if getattr(args, "wandb_enabled", False):
164+
wandb_config = {
165+
"enabled": True,
166+
"apiKey": getattr(args, "wandb_api_key", None),
167+
"project": getattr(args, "wandb_project", None),
168+
"entity": getattr(args, "wandb_entity", None),
169+
"runId": getattr(args, "wandb_run_id", None),
170+
}
171+
172+
body: Dict[str, Any] = {
173+
"displayName": getattr(args, "display_name", None) or f"{evaluator_id}-rft",
174+
"dataset": dataset_id,
175+
"evaluator": evaluator_resource_name,
176+
"evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)),
177+
"trainingConfig": training_config,
178+
"inferenceParameters": inference_params or None,
179+
"wandbConfig": wandb_config,
180+
"outputStats": None,
181+
"outputMetrics": None,
182+
"mcpServer": None,
183+
}
184+
if getattr(args, "evaluation_dataset", None):
185+
body["evaluationDataset"] = args.evaluation_dataset
186+
if getattr(args, "output_model", None):
187+
body.setdefault("trainingConfig", {})["outputModel"] = args.output_model
188+
else:
189+
body.setdefault("trainingConfig", {})["outputModel"] = build_default_output_model(evaluator_id)
190+
191+
# Clean None fields to avoid noisy payloads
192+
body = {k: v for k, v in body.items() if v is not None}
193+
194+
if dry_run:
195+
print("--dry-run: would create RFT job with body:")
196+
print(json.dumps(body, indent=2))
197+
_print_links(evaluator_id, dataset_id, None)
198+
return 0
199+
200+
try:
201+
result = create_reinforcement_fine_tuning_job(
202+
account_id=account_id, api_key=api_key, api_base=api_base, body=body
203+
)
204+
job_name = result.get("name") if isinstance(result, dict) else None
205+
print("\n✅ Created Reinforcement Fine-tuning Job")
206+
if job_name:
207+
print(f" name: {job_name}")
208+
_print_links(evaluator_id, dataset_id, job_name)
209+
return 0
210+
except Exception as e:
211+
print(f"Error creating RFT job: {e}")
212+
return 1

eval_protocol/cli_commands/upload.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from eval_protocol.platform_api import create_or_update_fireworks_secret
2222

2323
from eval_protocol.evaluation import create_evaluation
24+
from eval_protocol.fireworks_rft import save_evaluator_trace, detect_dataset_builder
2425

2526

2627
@dataclass
@@ -604,6 +605,23 @@ def upload_command(args: argparse.Namespace) -> int:
604605
)
605606
name = result.get("name", evaluator_id) if isinstance(result, dict) else evaluator_id
606607

608+
# Persist local evaluator trace for later `create rft`
609+
try:
610+
metric_dir = os.path.dirname(source_file_path) if source_file_path else root
611+
builder_spec = detect_dataset_builder(metric_dir) or None
612+
trace_payload = {
613+
"evaluator_id": evaluator_id,
614+
"evaluator_resource_name": name,
615+
"entry_point": entry_point,
616+
"metric_dir": metric_dir,
617+
"project_root": root,
618+
"dataset_builder": builder_spec,
619+
}
620+
save_evaluator_trace(project_root=root, evaluator_id=evaluator_id, trace=trace_payload)
621+
except Exception:
622+
# Non-fatal; continue
623+
pass
624+
607625
# Print success message with Fireworks dashboard link
608626
print(f"\n✅ Successfully uploaded evaluator: {evaluator_id}")
609627
print("📊 View in Fireworks Dashboard:")

0 commit comments

Comments
 (0)