Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions eval_protocol/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,13 @@ def parse_args(args=None):
help="Create a Reinforcement Fine-tuning Job on Fireworks",
)
rft_parser.add_argument(
"--evaluator-id",
help="Evaluator ID used during upload; if omitted, derive from local traces or a single discovered test",
"--evaluator",
help="Evaluator ID or fully-qualified resource (accounts/{acct}/evaluators/{id}); if omitted, derive from local tests",
)
# Dataset options
rft_parser.add_argument(
"--dataset-id",
help="Use existing Fireworks dataset id (skip local materialization)",
"--dataset",
help="Use existing dataset (ID or resource 'accounts/{acct}/datasets/{id}') to skip local materialization",
)
rft_parser.add_argument(
"--dataset-jsonl",
Expand All @@ -400,6 +400,8 @@ def parse_args(args=None):
rft_parser.add_argument("--learning-rate", type=float, default=3e-5)
rft_parser.add_argument("--max-context-length", type=int, default=65536)
rft_parser.add_argument("--lora-rank", type=int, default=16)
rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps")
rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of LR warmup steps")
rft_parser.add_argument("--accelerator-count", type=int, default=1)
rft_parser.add_argument("--region", help="Fireworks region enum value")
rft_parser.add_argument("--display-name", help="RFT job display name")
Expand All @@ -412,17 +414,22 @@ def parse_args(args=None):
rft_parser.add_argument("--temperature", type=float)
rft_parser.add_argument("--top-p", type=float)
rft_parser.add_argument("--top-k", type=int)
rft_parser.add_argument("--max-tokens", type=int, default=32768)
rft_parser.add_argument("--n", type=int, default=8)
rft_parser.add_argument("--inference-extra-body", help="JSON string for extra inference params")
rft_parser.add_argument("--max-output-tokens", type=int, default=32768)
rft_parser.add_argument("--response-candidates-count", type=int, default=8)
rft_parser.add_argument("--extra-body", help="JSON string for extra inference params")
# MCP server (optional)
rft_parser.add_argument(
"--mcp-server",
help="The MCP server resource name to use for the reinforcement fine-tuning job.",
)
# Wandb
rft_parser.add_argument("--wandb-enabled", action="store_true")
rft_parser.add_argument("--wandb-project")
rft_parser.add_argument("--wandb-entity")
rft_parser.add_argument("--wandb-run-id")
rft_parser.add_argument("--wandb-api-key")
# Misc
rft_parser.add_argument("--rft-job-id", help="Specify an explicit RFT job id")
rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id")
rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending")
rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID")
Expand Down
63 changes: 45 additions & 18 deletions eval_protocol/cli_commands/create_rft.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def _poll_evaluator_status(


def create_rft_command(args) -> int:
evaluator_id: Optional[str] = getattr(args, "evaluator_id", None)
evaluator_id: Optional[str] = getattr(args, "evaluator", None)
non_interactive: bool = bool(getattr(args, "yes", False))
dry_run: bool = bool(getattr(args, "dry_run", False))
force: bool = bool(getattr(args, "force", False))
Expand Down Expand Up @@ -373,19 +373,19 @@ def create_rft_command(args) -> int:
print("No evaluation tests found.")
print("\nHint: Make sure your tests use the @evaluation_test decorator.")
return 1
# Always interactive selection here (no implicit quiet unless --evaluator-id was provided)
# Always interactive selection here
try:
selected_tests = _prompt_select(tests, non_interactive=non_interactive)
except Exception:
print("Error: Failed to open selector UI. Please pass --evaluator-id or --entry explicitly.")
print("Error: Failed to open selector UI. Please pass --evaluator or --entry explicitly.")
return 1
if not selected_tests:
print("No tests selected.")
return 1
if len(selected_tests) != 1:
if non_interactive and len(selected_tests) > 1:
print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.")
print(" Please pass --evaluator-id or --entry to disambiguate.")
print(" Please pass --evaluator or --entry to disambiguate.")
try:
# Offer candidate evaluator ids for convenience
tests = _discover_tests(project_root)
Expand All @@ -410,8 +410,13 @@ def create_rft_command(args) -> int:
selected_test_file_path, selected_test_func_name = _resolve_selected_test(
project_root, evaluator_id, selected_tests=selected_tests
)
# Resolve evaluator resource name to fully-qualified format required by API
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"
# Resolve evaluator resource name to fully-qualified format required by API.
# Allow users to pass either short id or fully-qualified resource.
if evaluator_id and evaluator_id.startswith("accounts/"):
evaluator_resource_name = evaluator_id
evaluator_id = _extract_terminal_segment(evaluator_id)
else:
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"

# Optional short-circuit: if evaluator already exists and not forcing, skip upload path
skip_upload = False
Expand Down Expand Up @@ -470,10 +475,10 @@ def create_rft_command(args) -> int:
# If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators
if selected_entry is None and len(tests) > 1:
print(
f"Error: Multiple evaluation tests found, and the selected evaluator_id {evaluator_id} does not match any discovered test.\n"
" Please re-run specifying the evaluator id.\n"
f"Error: Multiple evaluation tests found, and the selected evaluator {evaluator_id} does not match any discovered test.\n"
" Please re-run specifying the evaluator.\n"
" Hints:\n"
" - eval-protocol create rft --evaluator-id <existing-evaluator-id>\n"
" - eval-protocol create rft --evaluator <existing-evaluator-id>\n"
)
return 1

Expand Down Expand Up @@ -523,10 +528,15 @@ def create_rft_command(args) -> int:
print(f"Warning: Failed to upload evaluator automatically: {e}")

# Determine dataset id and materialization path
dataset_id = getattr(args, "dataset_id", None)
dataset_id = getattr(args, "dataset", None)
dataset_jsonl = getattr(args, "dataset_jsonl", None)
dataset_display_name = getattr(args, "dataset_display_name", None)
dataset_builder = getattr(args, "dataset_builder", None) # accepted but unused in simplified flow
dataset_resource_override: Optional[str] = None
if isinstance(dataset_id, str) and dataset_id.startswith("accounts/"):
# Caller passed a fully-qualified dataset; capture it for body and keep only terminal id for printing
dataset_resource_override = dataset_id
dataset_id = _extract_terminal_segment(dataset_id)

if not dataset_id:
# Prefer explicit --dataset-jsonl, else attempt to extract from the selected test's data loader or input_dataset.
Expand Down Expand Up @@ -573,7 +583,7 @@ def create_rft_command(args) -> int:
print(f"Warning: dataset builder failed: {e}")
if not dataset_jsonl:
print(
"Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader or input_dataset is used in your single discovered test."
"Error: Could not determine dataset. Provide --dataset or --dataset-jsonl, or ensure a JSONL-based data loader or input_dataset is used in your single discovered test."
)
return 1

Expand Down Expand Up @@ -628,6 +638,8 @@ def create_rft_command(args) -> int:
("learningRate", "learning_rate"),
("maxContextLength", "max_context_length"),
("loraRank", "lora_rank"),
("gradientAccumulationSteps", "gradient_accumulation_steps"),
("learningRateWarmupSteps", "learning_rate_warmup_steps"),
("acceleratorCount", "accelerator_count"),
("region", "region"),
]:
Expand All @@ -640,14 +652,25 @@ def create_rft_command(args) -> int:
("temperature", "temperature"),
("topP", "top_p"),
("topK", "top_k"),
("maxTokens", "max_tokens"),
("n", "n"),
("maxTokens", "max_output_tokens"),
("n", "response_candidates_count"),
]:
val = getattr(args, arg_name, None)
if val is not None:
inference_params[key] = val
if getattr(args, "inference_extra_body", None):
inference_params["extraBody"] = args.inference_extra_body
if getattr(args, "extra_body", None):
extra = getattr(args, "extra_body")
if isinstance(extra, (dict, list)):
try:
inference_params["extraBody"] = json.dumps(extra, ensure_ascii=False)
except (TypeError, ValueError) as e:
print(f"Error: --extra-body dict/list must be JSON-serializable: {e}")
return 1
elif isinstance(extra, str):
inference_params["extraBody"] = extra
else:
print("Error: --extra-body must be a JSON string or a JSON-serializable dict/list.")
return 1

wandb_config: Optional[Dict[str, Any]] = None
if getattr(args, "wandb_enabled", False):
Expand All @@ -659,9 +682,12 @@ def create_rft_command(args) -> int:
"runId": getattr(args, "wandb_run_id", None),
}

# Build dataset resource (prefer override when provided)
dataset_resource = dataset_resource_override or f"accounts/{account_id}/datasets/{dataset_id}"

body: Dict[str, Any] = {
# "displayName": getattr(args, "display_name", None) or f"{evaluator_id}-rft",
"dataset": f"accounts/{account_id}/datasets/{dataset_id}",
"displayName": getattr(args, "display_name", None),
"dataset": dataset_resource,
"evaluator": evaluator_resource_name,
"evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)),
"trainingConfig": training_config,
Expand All @@ -670,7 +696,8 @@ def create_rft_command(args) -> int:
"chunkSize": getattr(args, "chunk_size", None),
"outputStats": None,
"outputMetrics": None,
"mcpServer": None,
"mcpServer": getattr(args, "mcp_server", None),
"jobId": getattr(args, "job_id", None),
}
# Debug: print minimal summary
print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")
Expand Down
9 changes: 9 additions & 0 deletions eval_protocol/fireworks_rft.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import uuid
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Optional, Tuple
from urllib.parse import urlencode

import requests

Expand Down Expand Up @@ -186,6 +187,14 @@ def create_reinforcement_fine_tuning_job(
body: Dict[str, Any],
) -> Dict[str, Any]:
url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/reinforcementFineTuningJobs"
# Move optional jobId from body to query parameter if provided
job_id = body.get("jobId")
if isinstance(job_id, str):
job_id = job_id.strip()
if job_id:
# Remove from body and append as query param
body.pop("jobId", None)
url = f"{url}?{urlencode({'reinforcementFineTuningJobId': job_id})}"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
Expand Down
Loading
Loading