diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py
index d2a8229c..4232ea2b 100644
--- a/eval_protocol/cli.py
+++ b/eval_protocol/cli.py
@@ -371,13 +371,13 @@ def parse_args(args=None):
         help="Create a Reinforcement Fine-tuning Job on Fireworks",
     )
     rft_parser.add_argument(
-        "--evaluator-id",
-        help="Evaluator ID used during upload; if omitted, derive from local traces or a single discovered test",
+        "--evaluator",
+        help="Evaluator ID or fully-qualified resource (accounts/{acct}/evaluators/{id}); if omitted, derive from local tests",
     )
     # Dataset options
     rft_parser.add_argument(
-        "--dataset-id",
-        help="Use existing Fireworks dataset id (skip local materialization)",
+        "--dataset",
+        help="Use existing dataset (ID or resource 'accounts/{acct}/datasets/{id}') to skip local materialization",
     )
     rft_parser.add_argument(
         "--dataset-jsonl",
@@ -400,6 +400,8 @@ def parse_args(args=None):
     rft_parser.add_argument("--learning-rate", type=float, default=3e-5)
     rft_parser.add_argument("--max-context-length", type=int, default=65536)
     rft_parser.add_argument("--lora-rank", type=int, default=16)
+    rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps")
+    rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of LR warmup steps")
     rft_parser.add_argument("--accelerator-count", type=int, default=1)
     rft_parser.add_argument("--region", help="Fireworks region enum value")
     rft_parser.add_argument("--display-name", help="RFT job display name")
@@ -412,9 +414,14 @@ def parse_args(args=None):
     rft_parser.add_argument("--temperature", type=float)
     rft_parser.add_argument("--top-p", type=float)
     rft_parser.add_argument("--top-k", type=int)
-    rft_parser.add_argument("--max-tokens", type=int, default=32768)
-    rft_parser.add_argument("--n", type=int, default=8)
-    rft_parser.add_argument("--inference-extra-body", help="JSON string for extra inference params")
+    rft_parser.add_argument("--max-output-tokens", type=int, default=32768)
+    rft_parser.add_argument("--response-candidates-count", type=int, default=8)
+    rft_parser.add_argument("--extra-body", help="JSON string for extra inference params")
+    # MCP server (optional)
+    rft_parser.add_argument(
+        "--mcp-server",
+        help="The MCP server resource name to use for the reinforcement fine-tuning job.",
+    )
     # Wandb
     rft_parser.add_argument("--wandb-enabled", action="store_true")
     rft_parser.add_argument("--wandb-project")
@@ -422,7 +429,7 @@ def parse_args(args=None):
     rft_parser.add_argument("--wandb-run-id")
     rft_parser.add_argument("--wandb-api-key")
     # Misc
-    rft_parser.add_argument("--rft-job-id", help="Specify an explicit RFT job id")
+    rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id")
     rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
     rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending")
     rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID")
diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py
index c0ff0358..afb5cd8d 100644
--- a/eval_protocol/cli_commands/create_rft.py
+++ b/eval_protocol/cli_commands/create_rft.py
@@ -344,7 +344,7 @@ def _poll_evaluator_status(


 def create_rft_command(args) -> int:
-    evaluator_id: Optional[str] = getattr(args, "evaluator_id", None)
+    evaluator_id: Optional[str] = getattr(args, "evaluator", None)
     non_interactive: bool = bool(getattr(args, "yes", False))
     dry_run: bool = bool(getattr(args, "dry_run", False))
     force: bool = bool(getattr(args, "force", False))
@@ -373,11 +373,11 @@ def create_rft_command(args) -> int:
         print("No evaluation tests found.")
         print("\nHint: Make sure your tests use the @evaluation_test decorator.")
         return 1
-    # Always interactive selection here (no implicit quiet unless --evaluator-id was provided)
+    # Always interactive selection here
     try:
         selected_tests = _prompt_select(tests, non_interactive=non_interactive)
     except Exception:
-        print("Error: Failed to open selector UI. Please pass --evaluator-id or --entry explicitly.")
+        print("Error: Failed to open selector UI. Please pass --evaluator or --entry explicitly.")
         return 1
     if not selected_tests:
         print("No tests selected.")
@@ -385,7 +385,7 @@
     if len(selected_tests) != 1:
         if non_interactive and len(selected_tests) > 1:
             print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.")
-            print(" Please pass --evaluator-id or --entry to disambiguate.")
+            print(" Please pass --evaluator or --entry to disambiguate.")
             try:
                 # Offer candidate evaluator ids for convenience
                 tests = _discover_tests(project_root)
@@ -410,8 +410,13 @@
     selected_test_file_path, selected_test_func_name = _resolve_selected_test(
         project_root, evaluator_id, selected_tests=selected_tests
     )
-    # Resolve evaluator resource name to fully-qualified format required by API
-    evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"
+    # Resolve evaluator resource name to fully-qualified format required by API.
+    # Allow users to pass either short id or fully-qualified resource.
+    if evaluator_id and evaluator_id.startswith("accounts/"):
+        evaluator_resource_name = evaluator_id
+        evaluator_id = _extract_terminal_segment(evaluator_id)
+    else:
+        evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"

     # Optional short-circuit: if evaluator already exists and not forcing, skip upload path
     skip_upload = False
@@ -470,10 +475,10 @@
     # If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators
     if selected_entry is None and len(tests) > 1:
         print(
-            f"Error: Multiple evaluation tests found, and the selected evaluator_id {evaluator_id} does not match any discovered test.\n"
-            " Please re-run specifying the evaluator id.\n"
+            f"Error: Multiple evaluation tests found, and the selected evaluator {evaluator_id} does not match any discovered test.\n"
+            " Please re-run specifying the evaluator.\n"
             " Hints:\n"
-            " - eval-protocol create rft --evaluator-id \n"
+            " - eval-protocol create rft --evaluator \n"
         )
         return 1

@@ -523,10 +528,15 @@
             print(f"Warning: Failed to upload evaluator automatically: {e}")

     # Determine dataset id and materialization path
-    dataset_id = getattr(args, "dataset_id", None)
+    dataset_id = getattr(args, "dataset", None)
     dataset_jsonl = getattr(args, "dataset_jsonl", None)
     dataset_display_name = getattr(args, "dataset_display_name", None)
     dataset_builder = getattr(args, "dataset_builder", None)  # accepted but unused in simplified flow
+    dataset_resource_override: Optional[str] = None
+    if isinstance(dataset_id, str) and dataset_id.startswith("accounts/"):
+        # Caller passed a fully-qualified dataset; capture it for body and keep only terminal id for printing
+        dataset_resource_override = dataset_id
+        dataset_id = _extract_terminal_segment(dataset_id)

     if not dataset_id:
         # Prefer explicit --dataset-jsonl, else attempt to extract from the selected test's data loader or input_dataset.
@@ -573,7 +583,7 @@
                 print(f"Warning: dataset builder failed: {e}")
         if not dataset_jsonl:
             print(
-                "Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader or input_dataset is used in your single discovered test."
+                "Error: Could not determine dataset. Provide --dataset or --dataset-jsonl, or ensure a JSONL-based data loader or input_dataset is used in your single discovered test."
             )
             return 1

@@ -628,6 +638,8 @@
         ("learningRate", "learning_rate"),
         ("maxContextLength", "max_context_length"),
         ("loraRank", "lora_rank"),
+        ("gradientAccumulationSteps", "gradient_accumulation_steps"),
+        ("learningRateWarmupSteps", "learning_rate_warmup_steps"),
         ("acceleratorCount", "accelerator_count"),
         ("region", "region"),
     ]:
@@ -640,14 +652,25 @@
         ("temperature", "temperature"),
         ("topP", "top_p"),
         ("topK", "top_k"),
-        ("maxTokens", "max_tokens"),
-        ("n", "n"),
+        ("maxTokens", "max_output_tokens"),
+        ("n", "response_candidates_count"),
     ]:
         val = getattr(args, arg_name, None)
         if val is not None:
             inference_params[key] = val
-    if getattr(args, "inference_extra_body", None):
-        inference_params["extraBody"] = args.inference_extra_body
+    if getattr(args, "extra_body", None):
+        extra = getattr(args, "extra_body")
+        if isinstance(extra, (dict, list)):
+            try:
+                inference_params["extraBody"] = json.dumps(extra, ensure_ascii=False)
+            except (TypeError, ValueError) as e:
+                print(f"Error: --extra-body dict/list must be JSON-serializable: {e}")
+                return 1
+        elif isinstance(extra, str):
+            inference_params["extraBody"] = extra
+        else:
+            print("Error: --extra-body must be a JSON string or a JSON-serializable dict/list.")
+            return 1

     wandb_config: Optional[Dict[str, Any]] = None
     if getattr(args, "wandb_enabled", False):
@@ -659,9 +682,12 @@
             "runId": getattr(args, "wandb_run_id", None),
         }

+    # Build dataset resource (prefer override when provided)
+    dataset_resource = dataset_resource_override or f"accounts/{account_id}/datasets/{dataset_id}"
+
     body: Dict[str, Any] = {
-        # "displayName": getattr(args, "display_name", None) or f"{evaluator_id}-rft",
-        "dataset": f"accounts/{account_id}/datasets/{dataset_id}",
+        "displayName": getattr(args, "display_name", None),
+        "dataset": dataset_resource,
         "evaluator": evaluator_resource_name,
         "evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)),
         "trainingConfig": training_config,
@@ -670,7 +696,8 @@
         "chunkSize": getattr(args, "chunk_size", None),
         "outputStats": None,
         "outputMetrics": None,
-        "mcpServer": None,
+        "mcpServer": getattr(args, "mcp_server", None),
+        "jobId": getattr(args, "job_id", None),
     }
     # Debug: print minimal summary
     print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")
diff --git a/eval_protocol/fireworks_rft.py b/eval_protocol/fireworks_rft.py
index 5c88c6dc..2e3cfc03 100644
--- a/eval_protocol/fireworks_rft.py
+++ b/eval_protocol/fireworks_rft.py
@@ -8,6 +8,7 @@
 import uuid
 from pathlib import Path
 from typing import Any, Callable, Dict, Iterable, Optional, Tuple
+from urllib.parse import urlencode

 import requests

@@ -186,6 +187,14 @@ def create_reinforcement_fine_tuning_job(
     body: Dict[str, Any],
 ) -> Dict[str, Any]:
     url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/reinforcementFineTuningJobs"
+    # Move optional jobId from body to query parameter if provided
+    job_id = body.get("jobId")
+    if isinstance(job_id, str):
+        job_id = job_id.strip()
+    if job_id:
+        # Remove from body and append as query param
+        body.pop("jobId", None)
+        url = f"{url}?{urlencode({'reinforcementFineTuningJobId': job_id})}"
     headers = {
         "Authorization": f"Bearer {api_key}",
         "Content-Type": "application/json",
diff --git a/tests/test_cli_create_rft_infer.py b/tests/test_cli_create_rft_infer.py
index 509a2d8a..9ef7d707 100644
--- a/tests/test_cli_create_rft_infer.py
+++ b/tests/test_cli_create_rft_infer.py
@@ -65,8 +65,8 @@ def _fake_create_job(account_id, api_key, api_base, body):

     args = argparse.Namespace(
         # Evaluator and dataset
-        evaluator_id="my-evaluator",
-        dataset_id=None,
+        evaluator="my-evaluator",
+        dataset=None,
         dataset_jsonl=str(ds_path),
         dataset_display_name="My Dataset",
         dataset_builder=None,
@@ -91,9 +91,9 @@
         temperature=0.9,
         top_p=0.95,
         top_k=50,
-        max_tokens=4096,
-        n=6,
-        inference_extra_body='{"foo":"bar"}',
+        max_output_tokens=4096,
+        response_candidates_count=6,
+        extra_body='{"foo":"bar"}',
         # Rollout chunking and eval carveout
         chunk_size=250,
         eval_auto_carveout=False,  # explicitly disabled via --no-eval-auto-carveout
@@ -105,7 +105,7 @@
         wandb_run_id="run123",
         wandb_api_key="key123",
         # Unused in body but accepted by parser
-        rft_job_id=None,
+        job_id=None,
         display_name=None,
     )

@@ -195,12 +195,12 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d

     # Build args: non_interactive (yes=True), no explicit evaluator_id, valid warm_start_from
     args = type("Args", (), {})()
-    setattr(args, "evaluator_id", None)
+    setattr(args, "evaluator", None)
     setattr(args, "yes", True)
     setattr(args, "dry_run", False)
     setattr(args, "force", False)
     setattr(args, "env_file", None)
-    setattr(args, "dataset_id", None)
+    setattr(args, "dataset", None)
     setattr(args, "dataset_jsonl", str(ds_path))
     setattr(args, "dataset_display_name", None)
     setattr(args, "dataset_builder", None)
@@ -283,12 +283,12 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
     import argparse

     args = argparse.Namespace(
-        evaluator_id=cr._normalize_evaluator_id("foo_eval-test_bar_evaluation"),
+        evaluator=cr._normalize_evaluator_id("foo_eval-test_bar_evaluation"),
         yes=True,
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=str(ds_path),
         dataset_display_name=None,
         dataset_builder=None,
@@ -371,12 +371,12 @@ def _fake_upload(ns):
     import argparse

     args = argparse.Namespace(
-        evaluator_id=None,
+        evaluator=None,
         yes=True,
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=str(ds_path),
         dataset_display_name=None,
         dataset_builder=None,
@@ -438,12 +438,12 @@ def raise_for_status(self):
     import argparse

     args = argparse.Namespace(
-        evaluator_id="some-eval",
+        evaluator="some-eval",
         yes=True,
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=str(ds_path),
         dataset_display_name=None,
         dataset_builder=None,
@@ -495,12 +495,12 @@ def _raise(*a, **k):
     import argparse

     args = argparse.Namespace(
-        evaluator_id="some-eval",
+        evaluator="some-eval",
         yes=True,
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=str(project / "dataset.jsonl"),
         dataset_display_name=None,
         dataset_builder=None,
@@ -571,12 +571,12 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
     import argparse

     args = argparse.Namespace(
-        evaluator_id=None,
+        evaluator=None,
         yes=True,
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=None,
         dataset_display_name=None,
         dataset_builder=None,
@@ -648,12 +648,12 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
     import argparse

     args = argparse.Namespace(
-        evaluator_id=None,
+        evaluator=None,
         yes=True,
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=None,
         dataset_display_name=None,
         dataset_builder=None,
@@ -723,12 +723,12 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
     import argparse

     args = argparse.Namespace(
-        evaluator_id=None,
+        evaluator=None,
         yes=True,
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=None,
         dataset_display_name=None,
         dataset_builder=None,
@@ -815,12 +815,12 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
     # Provide evaluator_id that matches beta.test_two
     eval_id = cr._normalize_evaluator_id("beta-test_two")
     args = argparse.Namespace(
-        evaluator_id=eval_id,
+        evaluator=eval_id,
         yes=True,
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=None,
         dataset_display_name=None,
         dataset_builder=None,
@@ -844,3 +844,195 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
     # Ensure the dataset id is based on evaluator_id
     assert captured["dataset_id"].startswith(f"{eval_id}-dataset-")
     assert captured["jsonl_path"] == str(jsonl_path)
+
+
+def test_cli_full_command_style_evaluator_and_dataset_flags(monkeypatch):
+    # Env
+    monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "pyroworks-dev")
+    monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
+
+    # Mock evaluator exists and ACTIVE
+    class _Resp:
+        ok = True
+
+        def json(self):
+            return {"state": "ACTIVE"}
+
+        def raise_for_status(self):
+            return None
+
+    from eval_protocol.cli_commands import create_rft as cr
+
+    monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp())
+
+    # Capture URL and JSON via fireworks layer
+    import eval_protocol.fireworks_rft as fr
+
+    captured = {"url": None, "json": None}
+
+    class _RespPost:
+        status_code = 200
+
+        def json(self):
+            return {"name": "accounts/pyroworks-dev/reinforcementFineTuningJobs/xyz"}
+
+    def _fake_post(url, json=None, headers=None, timeout=None):
+        captured["url"] = url
+        captured["json"] = json
+        return _RespPost()
+
+    monkeypatch.setattr(fr.requests, "post", _fake_post)
+
+    # Build args via CLI parser to validate flag names
+    from eval_protocol.cli import parse_args
+
+    argv = [
+        "create",
+        "rft",
+        "--base-model",
+        "accounts/fireworks/models/qwen3-0p6b",
+        "--dataset",
+        "svgbench-small",
+        "--output-model",
+        "svgbench-agent-small-bchen-2",
+        "--evaluator",
+        "accounts/pyroworks-dev/evaluators/test-livesvgbench-test-svg-combined-evaluation1",
+        "--max-context-length",
+        "65536",
+        "--response-candidates-count",
+        "4",
+        "--batch-size",
+        "128000",
+        "--chunk-size",
+        "50",
+        "--epochs",
+        "4",
+        "--max-output-tokens",
+        "32768",
+        "--learning-rate",
+        "0.00003",
+        "--lora-rank",
+        "16",
+        "--job-id",
+        "custom-job-123",
+        "--yes",
+    ]
+    args, _ = parse_args(argv)
+
+    # Execute command
+    rc = cr.create_rft_command(args)
+    assert rc == 0
+    assert captured["json"] is not None
+    body = captured["json"]
+
+    # Evaluator and dataset resources
+    assert body["evaluator"] == "accounts/pyroworks-dev/evaluators/test-livesvgbench-test-svg-combined-evaluation1"
+    assert body["dataset"] == "accounts/pyroworks-dev/datasets/svgbench-small"
+
+    # Training config mapping
+    tc = body["trainingConfig"]
+    assert tc["baseModel"] == "accounts/fireworks/models/qwen3-0p6b"
+    assert tc["outputModel"] == "accounts/pyroworks-dev/models/svgbench-agent-small-bchen-2"
+    assert tc["epochs"] == 4
+    assert tc["batchSize"] == 128000
+    assert abs(tc["learningRate"] - 0.00003) < 1e-12
+    assert tc["loraRank"] == 16
+    assert tc["maxContextLength"] == 65536
+
+    # Inference params mapping
+    ip = body["inferenceParameters"]
+    assert ip["n"] == 4
+    assert ip["maxTokens"] == 32768
+
+    # Other top-level
+    assert body["chunkSize"] == 50
+    # Job id sent as query param
+    assert captured["url"] is not None and "reinforcementFineTuningJobId=custom-job-123" in captured["url"]
+    assert "jobId" not in body
+
+
+def test_create_rft_prefers_explicit_dataset_jsonl_over_input_dataset(tmp_path, monkeypatch):
+    # Setup project
+    project = tmp_path / "proj"
+    project.mkdir()
+    monkeypatch.chdir(project)
+
+    # Environment
+    monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123")
+    monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
+
+    # Single discovered test
+    test_file = project / "metric" / "test_pref.py"
+    test_file.parent.mkdir(parents=True, exist_ok=True)
+    test_file.write_text("# prefer explicit dataset_jsonl", encoding="utf-8")
+    single_disc = SimpleNamespace(qualname="metric.test_pref", file_path=str(test_file))
+    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc])
+
+    # Stub selector, upload, and polling
+    import eval_protocol.cli_commands.upload as upload_mod
+
+    monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
+    monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
+    monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)
+
+    # Prepare two JSONL paths: one explicit via --dataset-jsonl and one inferable via input_dataset
+    explicit_jsonl = project / "metric" / "explicit.jsonl"
+    explicit_jsonl.write_text('{"row":"explicit"}\n', encoding="utf-8")
+    inferred_jsonl = project / "metric" / "inferred.jsonl"
+    inferred_jsonl.write_text('{"row":"inferred"}\n', encoding="utf-8")
+
+    # If inference were to happen, return inferred path — but explicit should win
+    monkeypatch.setattr(cr, "_extract_jsonl_from_dataloader", lambda f, fn: None)
+    calls = {"input_dataset": 0}
+
+    def _extract_input_dataset(file_path, func_name):
+        calls["input_dataset"] += 1
+        return str(inferred_jsonl)
+
+    monkeypatch.setattr(cr, "_extract_jsonl_from_input_dataset", _extract_input_dataset)
+    monkeypatch.setattr(cr, "detect_dataset_builder", lambda metric_dir: None)
+
+    captured = {"jsonl_path": None}
+
+    def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, display_name, jsonl_path):
+        captured["jsonl_path"] = jsonl_path
+        return dataset_id, {"name": f"accounts/{account_id}/datasets/{dataset_id}", "state": "UPLOADING"}
+
+    monkeypatch.setattr(cr, "create_dataset_from_jsonl", _fake_create_dataset_from_jsonl)
+    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})
+
+    import argparse
+
+    args = argparse.Namespace(
+        evaluator=None,
+        yes=True,
+        dry_run=False,
+        force=False,
+        env_file=None,
+        dataset=None,
+        dataset_jsonl=str(explicit_jsonl),
+        dataset_display_name=None,
+        dataset_builder=None,
+        base_model=None,
+        warm_start_from="accounts/acct123/models/ft-abc123",
+        output_model=None,
+        n=None,
+        max_tokens=None,
+        learning_rate=None,
+        batch_size=None,
+        epochs=None,
+        lora_rank=None,
+        max_context_length=None,
+        chunk_size=None,
+        eval_auto_carveout=None,
+    )
+
+    rc = cr.create_rft_command(args)
+    assert rc == 0
+    # Ensure the explicitly provided JSONL file is used, not the inferred one
+    assert captured["jsonl_path"] == str(explicit_jsonl)
+    assert captured["jsonl_path"] != str(inferred_jsonl)
+    # And because --dataset-jsonl was provided, we should never call the input_dataset extractor
+    assert calls["input_dataset"] == 0
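
Example invocation exercising the renamed flags, assembled from the argv used in test_cli_full_command_style_evaluator_and_dataset_flags above (a sketch; the account, dataset, and model names are the test's placeholder values, not real resources):

    eval-protocol create rft \
        --evaluator accounts/pyroworks-dev/evaluators/test-livesvgbench-test-svg-combined-evaluation1 \
        --dataset svgbench-small \
        --base-model accounts/fireworks/models/qwen3-0p6b \
        --output-model svgbench-agent-small-bchen-2 \
        --max-output-tokens 32768 \
        --response-candidates-count 4 \
        --job-id custom-job-123 \
        --yes

The --job-id value is stripped from the request body and sent as the reinforcementFineTuningJobId query parameter, as asserted in that test.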