Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions eval_protocol/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,70 @@ def parse_args(args=None):
help="Non-interactive: upload all discovered evaluation tests",
)

# "create" command group: hosts resource-creation subcommands.
create_parser = subparsers.add_parser(
    "create",
    help="Resource creation commands",
)
create_subparsers = create_parser.add_subparsers(dest="create_command")

# "create rft": launch a Reinforcement Fine-tuning job on Fireworks.
rft_parser = create_subparsers.add_parser(
    "rft",
    help="Create a Reinforcement Fine-tuning Job on Fireworks",
)

# Evaluator and dataset selection (all optional; fallbacks are derived at runtime).
for flag, flag_help in (
    ("--evaluator-id", "Evaluator ID used during upload; if omitted, derive from local traces or a single discovered test"),
    ("--dataset-id", "Use existing Fireworks dataset id (skip local materialization)"),
    ("--dataset-jsonl", "Path to JSONL to upload as a new Fireworks dataset"),
    ("--dataset-builder", "Explicit dataset builder spec (module::function or path::function)"),
    ("--dataset-display-name", "Display name for dataset on Fireworks (defaults to dataset id)"),
):
    rft_parser.add_argument(flag, help=flag_help)

# Training config and evaluator/job settings.
rft_parser.add_argument("--base-model", help="Base model resource id")
rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from")
rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)")
for flag, flag_type in (
    ("--epochs", int),
    ("--batch-size", int),
    ("--learning-rate", float),
    ("--max-context-length", int),
    ("--lora-rank", int),
    ("--accelerator-count", int),
):
    rft_parser.add_argument(flag, type=flag_type)
rft_parser.add_argument("--region", help="Fireworks region enum value")
rft_parser.add_argument("--display-name", help="RFT job display name")
rft_parser.add_argument("--evaluation-dataset", help="Optional separate eval dataset id")
# Paired on/off switches sharing one dest; auto carve-out defaults to enabled.
rft_parser.add_argument("--eval-auto-carveout", dest="eval_auto_carveout", action="store_true", default=True)
rft_parser.add_argument("--no-eval-auto-carveout", dest="eval_auto_carveout", action="store_false")

# Inference-time sampling parameters.
for flag, flag_type in (
    ("--temperature", float),
    ("--top-p", float),
    ("--top-k", int),
    ("--max-tokens", int),
    ("--n", int),
):
    rft_parser.add_argument(flag, type=flag_type)
rft_parser.add_argument("--inference-extra-body", help="JSON string for extra inference params")

# Weights & Biases integration.
rft_parser.add_argument("--wandb-enabled", action="store_true")
for flag in ("--wandb-project", "--wandb-entity", "--wandb-run-id", "--wandb-api-key"):
    rft_parser.add_argument(flag)

# Miscellaneous.
rft_parser.add_argument("--rft-job-id", help="Specify an explicit RFT job id")
rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending")

# Run command (for Hydra-based evaluations)
# This subparser intentionally defines no arguments itself.
# All arguments after 'run' will be passed to Hydra by parse_known_args.
Expand Down Expand Up @@ -481,6 +545,13 @@ def _extract_flag_value(argv_list, flag_name):
from .cli_commands.upload import upload_command

return upload_command(args)
elif args.command == "create":
    # Dispatch "create" sub-subcommands; currently only "rft" exists.
    if args.create_command == "rft":
        # Imported lazily so unrelated CLI invocations stay fast.
        from .cli_commands.create_rft import create_rft_command

        return create_rft_command(args)
    # Bare "eval-protocol create" with no subcommand: print a usage hint and fail.
    print("Error: missing subcommand for 'create'. Try: eval-protocol create rft")
    return 1
elif args.command == "run":
# For the 'run' command, Hydra takes over argument parsing.

Expand Down
254 changes: 254 additions & 0 deletions eval_protocol/cli_commands/create_rft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
import json
import os
import sys
from typing import Any, Dict, Optional

from ..auth import (
get_fireworks_account_id,
get_fireworks_api_base,
get_fireworks_api_key,
verify_api_key_and_get_account_id,
)
from ..fireworks_rft import (
_map_api_host_to_app_host,
build_default_dataset_id,
build_default_output_model,
create_dataset_from_jsonl,
create_reinforcement_fine_tuning_job,
detect_dataset_builder,
load_evaluator_trace,
materialize_dataset_via_builder,
)
from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source


def _ensure_account_id() -> Optional[str]:
    """Return the Fireworks account id, resolving it via the API key if needed.

    When no account id is configured but an API key is available, the key is
    verified against the API; a successfully resolved account id is cached in
    the process environment for subsequent lookups.
    """
    configured = get_fireworks_account_id()
    key = get_fireworks_api_key()
    if configured or not key:
        return configured
    resolved = verify_api_key_and_get_account_id(api_key=key, api_base=get_fireworks_api_base())
    if not resolved:
        return configured
    os.environ["FIREWORKS_ACCOUNT_ID"] = resolved
    return resolved


def _extract_terminal_segment(resource_name: str) -> str:
"""Return the last path segment if a fully-qualified resource name is provided."""
try:
return resource_name.strip("/").split("/")[-1]
except Exception:
return resource_name


def _print_links(evaluator_id: str, dataset_id: str, job_name: Optional[str]) -> None:
    """Print Fireworks dashboard URLs for the evaluator, dataset, and RFT job."""
    app_base = _map_api_host_to_app_host(get_fireworks_api_base())
    print("\n📊 Dashboard Links:")
    print(f" Evaluator: {app_base}/dashboard/evaluators/{_extract_terminal_segment(evaluator_id)}")
    if dataset_id:
        print(f" Dataset: {app_base}/dashboard/datasets/{dataset_id}")
    if not job_name:
        return
    # job_name is typically accounts/{account}/reinforcementFineTuningJobs/{id}
    try:
        job_id = job_name.strip().split("/")[-1]
        print(f" RFT Job: {app_base}/dashboard/fine-tuning/reinforcement/{job_id}")
    except Exception:
        pass


def _auto_select_evaluator_id(cwd: str) -> Optional[str]:
    """Best-effort inference of the evaluator id when --evaluator-id is omitted.

    Preference order:
      1. A single saved evaluator trace under ``.eval_protocol/evaluators/``.
      2. A single discovered evaluation test in the project.
    Returns None when neither source yields exactly one candidate.
    """
    trace_dir = os.path.join(cwd, ".eval_protocol", "evaluators")
    if os.path.isdir(trace_dir):
        traces = [entry[:-5] for entry in os.listdir(trace_dir) if entry.endswith(".json")]
        if len(traces) == 1:
            return traces[0]
    discovered = _discover_tests(cwd)
    if len(discovered) != 1:
        return None
    only = discovered[0]
    test_func_name = only.qualname.split(".")[-1]
    source_file_name = os.path.splitext(os.path.basename(only.file_path))[0]
    return _normalize_evaluator_id(f"{source_file_name}-{test_func_name}")


def create_rft_command(args) -> int:
    """Create a Reinforcement Fine-tuning (RFT) job on Fireworks.

    Resolves the evaluator id (explicit flag or inferred from local traces /
    a single discovered test), ensures a dataset exists (an existing id or a
    fresh upload from a local JSONL), builds the REST request body, and
    submits the job.

    Returns:
        Process exit code: 0 on success (including --dry-run), 1 on failure.

    Fix: the debug preview of the request body is now printed AFTER
    ``evaluationDataset``/``outputModel`` are added and None values are
    stripped, so the preview matches the payload actually sent.
    """
    evaluator_id: Optional[str] = getattr(args, "evaluator_id", None)
    non_interactive: bool = bool(getattr(args, "yes", False))  # reserved for future prompts
    dry_run: bool = bool(getattr(args, "dry_run", False))

    api_key = get_fireworks_api_key()
    if not api_key:
        print("Error: FIREWORKS_API_KEY not set.")
        return 1

    account_id = _ensure_account_id()
    if not account_id:
        print("Error: FIREWORKS_ACCOUNT_ID not set and could not be resolved.")
        return 1

    api_base = get_fireworks_api_base()

    # Resolve evaluator id if omitted.
    project_root = os.getcwd()
    if not evaluator_id:
        evaluator_id = _auto_select_evaluator_id(project_root)
        if not evaluator_id:
            print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
            return 1

    # TODO(create-rft): resolve the fully-qualified evaluator resource name via
    # load_evaluator_trace() instead of passing the bare id straight through.
    evaluator_resource_name = evaluator_id

    # Determine dataset id and materialization path.
    dataset_id = getattr(args, "dataset_id", None)
    dataset_jsonl = getattr(args, "dataset_jsonl", None)
    dataset_display_name = getattr(args, "dataset_display_name", None)
    # NOTE(review): --dataset-builder is accepted but not yet wired up; see TODO below.
    dataset_builder = getattr(args, "dataset_builder", None)

    if not dataset_id:
        # TODO(create-rft): materialize a dataset via a builder spec
        # (detect_dataset_builder / materialize_dataset_via_builder) or build
        # one directly from stored traces instead of requiring a JSONL file.
        if not dataset_jsonl:
            print("Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl.")
            return 1

        inferred_dataset_id = build_default_dataset_id(evaluator_id)
        if dry_run:
            print("--dry-run: would create dataset and upload JSONL")
            dataset_id = inferred_dataset_id
        else:
            try:
                dataset_id, _ = create_dataset_from_jsonl(
                    account_id=account_id,
                    api_key=api_key,
                    api_base=api_base,
                    dataset_id=inferred_dataset_id,
                    display_name=dataset_display_name or inferred_dataset_id,
                    jsonl_path=dataset_jsonl,
                )
                print(f"✓ Created and uploaded dataset: {dataset_id}")
            except Exception as e:
                print(f"Error creating/uploading dataset: {e}")
                return 1

    # Build training config/body.
    training_config: Dict[str, Any] = {}
    if getattr(args, "base_model", None):
        training_config["baseModel"] = args.base_model
    if getattr(args, "warm_start_from", None):
        training_config["warmStartFrom"] = args.warm_start_from
    if "baseModel" not in training_config and "warmStartFrom" not in training_config:
        # Provide a conservative default if neither is set.
        training_config["baseModel"] = "accounts/fireworks/models/llama-v3p1-8b-instruct"

    # Optional hyperparameters; camelCase keys match the REST API schema.
    for key, arg_name in [
        ("epochs", "epochs"),
        ("batchSize", "batch_size"),
        ("learningRate", "learning_rate"),
        ("maxContextLength", "max_context_length"),
        ("loraRank", "lora_rank"),
        ("acceleratorCount", "accelerator_count"),
        ("region", "region"),
    ]:
        val = getattr(args, arg_name, None)
        if val is not None:
            training_config[key] = val

    # Optional inference-time sampling parameters.
    inference_params: Dict[str, Any] = {}
    for key, arg_name in [
        ("temperature", "temperature"),
        ("topP", "top_p"),
        ("topK", "top_k"),
        ("maxTokens", "max_tokens"),
        ("n", "n"),
    ]:
        val = getattr(args, arg_name, None)
        if val is not None:
            inference_params[key] = val
    if getattr(args, "inference_extra_body", None):
        inference_params["extraBody"] = args.inference_extra_body

    # Optional Weights & Biases reporting config.
    wandb_config: Optional[Dict[str, Any]] = None
    if getattr(args, "wandb_enabled", False):
        wandb_config = {
            "enabled": True,
            "apiKey": getattr(args, "wandb_api_key", None),
            "project": getattr(args, "wandb_project", None),
            "entity": getattr(args, "wandb_entity", None),
            "runId": getattr(args, "wandb_run_id", None),
        }

    body: Dict[str, Any] = {
        "dataset": f"accounts/{account_id}/datasets/{dataset_id}",
        "evaluator": evaluator_resource_name,
        "evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)),
        "trainingConfig": training_config,
        "inferenceParameters": inference_params or None,
        "wandbConfig": wandb_config,
        "outputStats": None,
        "outputMetrics": None,
        "mcpServer": None,
    }
    if getattr(args, "evaluation_dataset", None):
        body["evaluationDataset"] = args.evaluation_dataset
    if getattr(args, "output_model", None):
        body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{args.output_model}"
    else:
        body.setdefault("trainingConfig", {})["outputModel"] = build_default_output_model(evaluator_id)

    # Clean None fields to avoid noisy payloads.
    body = {k: v for k, v in body.items() if v is not None}

    # Debug preview of the finalized body — printed only after every field has
    # been added and None values stripped, so it matches what is actually sent.
    print("Show body:")
    print(json.dumps(body, indent=2))

    if dry_run:
        print("--dry-run: would create RFT job with body:")
        print(json.dumps(body, indent=2))
        _print_links(evaluator_id, dataset_id, None)
        return 0

    try:
        result = create_reinforcement_fine_tuning_job(
            account_id=account_id, api_key=api_key, api_base=api_base, body=body
        )
        job_name = result.get("name") if isinstance(result, dict) else None
        print("\n✅ Created Reinforcement Fine-tuning Job")
        if job_name:
            print(f" name: {job_name}")
        _print_links(evaluator_id, dataset_id, job_name)
        return 0
    except Exception as e:
        print(f"Error creating RFT job: {e}")
        return 1
18 changes: 18 additions & 0 deletions eval_protocol/cli_commands/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from eval_protocol.platform_api import create_or_update_fireworks_secret

from eval_protocol.evaluation import create_evaluation
from eval_protocol.fireworks_rft import save_evaluator_trace, detect_dataset_builder


@dataclass
Expand Down Expand Up @@ -666,6 +667,23 @@ def upload_command(args: argparse.Namespace) -> int:
)
name = result.get("name", evaluator_id) if isinstance(result, dict) else evaluator_id

# Persist local evaluator trace for later `create rft`
try:
    # Record where this evaluator came from so `eval-protocol create rft`
    # can later infer the evaluator id, dataset builder, and source paths
    # without the user re-specifying them.
    metric_dir = os.path.dirname(source_file_path) if source_file_path else root
    # May be None when no builder is detected next to the test source.
    builder_spec = detect_dataset_builder(metric_dir) or None
    trace_payload = {
        "evaluator_id": evaluator_id,
        "evaluator_resource_name": name,
        "entry_point": entry_point,
        "metric_dir": metric_dir,
        "project_root": root,
        "dataset_builder": builder_spec,
    }
    save_evaluator_trace(project_root=root, evaluator_id=evaluator_id, trace=trace_payload)
except Exception:
    # Non-fatal; continue — trace persistence is best-effort and must never
    # block an otherwise successful upload.
    pass

# Print success message with Fireworks dashboard link
print(f"\n✅ Successfully uploaded evaluator: {evaluator_id}")
print("📊 View in Fireworks Dashboard:")
Expand Down
Loading
Loading