From 763116e51df8f082608c99abbd422dcea126786e Mon Sep 17 00:00:00 2001 From: Benny Chen Date: Mon, 13 Oct 2025 15:34:40 -0700 Subject: [PATCH 1/2] Create RFT on Fireworks directly --- eval_protocol/cli.py | 72 ++++++++ eval_protocol/cli_commands/create_rft.py | 212 ++++++++++++++++++++++ eval_protocol/cli_commands/upload.py | 18 ++ eval_protocol/fireworks_rft.py | 218 +++++++++++++++++++++++ 4 files changed, 520 insertions(+) create mode 100644 eval_protocol/cli_commands/create_rft.py create mode 100644 eval_protocol/fireworks_rft.py diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py index 3b7ff58f..81c835ad 100644 --- a/eval_protocol/cli.py +++ b/eval_protocol/cli.py @@ -356,6 +356,71 @@ def parse_args(args=None): help="Non-interactive: upload all discovered evaluation tests", ) + # Create command group + create_parser = subparsers.add_parser( + "create", + help="Resource creation commands", + ) + create_subparsers = create_parser.add_subparsers(dest="create_command") + rft_parser = create_subparsers.add_parser( + "rft", + help="Create a Reinforcement Fine-tuning Job on Fireworks", + ) + rft_parser.add_argument( + "--evaluator-id", + required=True, + help="Evaluator ID used during upload; resolves evaluator resource via local trace", + ) + # Dataset options + rft_parser.add_argument( + "--dataset-id", + help="Use existing Fireworks dataset id (skip local materialization)", + ) + rft_parser.add_argument( + "--dataset-jsonl", + help="Path to JSONL to upload as a new Fireworks dataset", + ) + rft_parser.add_argument( + "--dataset-builder", + help="Explicit dataset builder spec (module::function or path::function)", + ) + rft_parser.add_argument( + "--dataset-display-name", + help="Display name for dataset on Fireworks (defaults to dataset id)", + ) + # Training config and evaluator/job settings + rft_parser.add_argument("--base-model", help="Base model resource id") + rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from") + rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)") + rft_parser.add_argument("--epochs", type=int) + rft_parser.add_argument("--batch-size", type=int) + rft_parser.add_argument("--learning-rate", type=float) + rft_parser.add_argument("--max-context-length", type=int) + rft_parser.add_argument("--lora-rank", type=int) + rft_parser.add_argument("--accelerator-count", type=int) + rft_parser.add_argument("--region", help="Fireworks region enum value") + rft_parser.add_argument("--display-name", help="RFT job display name") + rft_parser.add_argument("--evaluation-dataset", help="Optional separate eval dataset id") + rft_parser.add_argument("--eval-auto-carveout", dest="eval_auto_carveout", action="store_true", default=True) + rft_parser.add_argument("--no-eval-auto-carveout", dest="eval_auto_carveout", action="store_false") + # Inference params + rft_parser.add_argument("--temperature", type=float) + rft_parser.add_argument("--top-p", type=float) + rft_parser.add_argument("--top-k", type=int) + rft_parser.add_argument("--max-tokens", type=int) + rft_parser.add_argument("--n", type=int) + rft_parser.add_argument("--inference-extra-body", help="JSON string for extra inference params") + # Wandb + rft_parser.add_argument("--wandb-enabled", action="store_true") + rft_parser.add_argument("--wandb-project") + rft_parser.add_argument("--wandb-entity") + rft_parser.add_argument("--wandb-run-id") + rft_parser.add_argument("--wandb-api-key") + # Misc + rft_parser.add_argument("--rft-job-id", help="Specify an explicit RFT job id") + rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode") + rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending") + # Run command (for Hydra-based evaluations) # This subparser intentionally defines no arguments itself. # All arguments after 'run' will be passed to Hydra by parse_known_args. @@ -481,6 +546,13 @@ def _extract_flag_value(argv_list, flag_name): from .cli_commands.upload import upload_command return upload_command(args) + elif args.command == "create": + if args.create_command == "rft": + from .cli_commands.create_rft import create_rft_command + + return create_rft_command(args) + print("Error: missing subcommand for 'create'. Try: eval-protocol create rft") + return 1 elif args.command == "run": # For the 'run' command, Hydra takes over argument parsing. diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py new file mode 100644 index 00000000..3253d759 --- /dev/null +++ b/eval_protocol/cli_commands/create_rft.py @@ -0,0 +1,212 @@ +import json +import os +import sys +from typing import Any, Dict, Optional + +from ..auth import ( + get_fireworks_account_id, + get_fireworks_api_base, + get_fireworks_api_key, + verify_api_key_and_get_account_id, +) +from ..fireworks_rft import ( + _map_api_host_to_app_host, + build_default_dataset_id, + build_default_output_model, + create_dataset_from_jsonl, + create_reinforcement_fine_tuning_job, + detect_dataset_builder, + load_evaluator_trace, + materialize_dataset_via_builder, +) + + +def _ensure_account_id() -> Optional[str]: + account_id = get_fireworks_account_id() + api_key = get_fireworks_api_key() + if not account_id and api_key: + resolved = verify_api_key_and_get_account_id(api_key=api_key, api_base=get_fireworks_api_base()) + if resolved: + os.environ["FIREWORKS_ACCOUNT_ID"] = resolved + return resolved + return account_id + + +def _print_links(evaluator_id: str, dataset_id: str, job_name: Optional[str]) -> None: + api_base = get_fireworks_api_base() + app_base = _map_api_host_to_app_host(api_base) + print("\nšŸ“Š Dashboard Links:") + print(f" Evaluator: {app_base}/dashboard/evaluators/{evaluator_id}") + if dataset_id: + print(f" Dataset: {app_base}/dashboard/datasets/{dataset_id}") + if job_name: + # job_name likely like accounts/{account}/reinforcementFineTuningJobs/{id} + try: + job_id = job_name.strip().split("/")[-1] + print(f" RFT Job: {app_base}/dashboard/rft/{job_id}") + except Exception: + pass + + +def create_rft_command(args) -> int: + evaluator_id: str = getattr(args, "evaluator_id") + non_interactive: bool = bool(getattr(args, "yes", False)) + dry_run: bool = bool(getattr(args, "dry_run", False)) + + api_key = get_fireworks_api_key() + if not api_key: + print("Error: FIREWORKS_API_KEY not set.") + return 1 + + account_id = _ensure_account_id() + if not account_id: + print("Error: FIREWORKS_ACCOUNT_ID not set and could not be resolved.") + return 1 + + api_base = get_fireworks_api_base() + + # Resolve evaluator resource name via local trace + project_root = os.getcwd() + trace = load_evaluator_trace(project_root, evaluator_id) + if not trace or not isinstance(trace, dict): + print( + "Error: Evaluator trace not found. Run 'eval-protocol upload' first or provide --dataset-id/--dataset-jsonl and --evaluator-id." + ) + return 1 + evaluator_resource_name = trace.get("evaluator_resource_name") or trace.get("name") or evaluator_id + + # Determine dataset id and materialization path + dataset_id = getattr(args, "dataset_id", None) + dataset_jsonl = getattr(args, "dataset_jsonl", None) + dataset_display_name = getattr(args, "dataset_display_name", None) + dataset_builder = getattr(args, "dataset_builder", None) + + if not dataset_id: + if not dataset_jsonl: + # Try builder from args, else from trace detection + builder_spec = dataset_builder or trace.get("dataset_builder") + if not builder_spec: + # Attempt detect from metric_dir + metric_dir = trace.get("metric_dir") + if metric_dir: + builder_spec = detect_dataset_builder(metric_dir) + if not builder_spec: + print( + "Error: Could not determine dataset. Provide --dataset-id, --dataset-jsonl, or --dataset-builder." + ) + return 1 + try: + dataset_jsonl, count = materialize_dataset_via_builder(builder_spec) + print(f"āœ“ Materialized dataset via builder ({builder_spec}): {count} rows → {dataset_jsonl}") + except Exception as e: + print(f"Error: dataset builder failed: {e}") + return 1 + + inferred_dataset_id = build_default_dataset_id(evaluator_id) + if dry_run: + print("--dry-run: would create dataset and upload JSONL") + dataset_id = inferred_dataset_id + else: + try: + dataset_id, _ = create_dataset_from_jsonl( + account_id=account_id, + api_key=api_key, + api_base=api_base, + dataset_id=inferred_dataset_id, + display_name=dataset_display_name or inferred_dataset_id, + jsonl_path=dataset_jsonl, + ) + print(f"āœ“ Created and uploaded dataset: {dataset_id}") + except Exception as e: + print(f"Error creating/uploading dataset: {e}") + return 1 + + # Build training config/body + training_config: Dict[str, Any] = {} + if getattr(args, "base_model", None): + training_config["baseModel"] = args.base_model + if getattr(args, "warm_start_from", None): + training_config["warmStartFrom"] = args.warm_start_from + if "baseModel" not in training_config and "warmStartFrom" not in training_config: + # Provide a conservative default if neither is set + training_config["baseModel"] = "accounts/fireworks/models/llama-v3p1-8b-instruct" + + # Optional hyperparameters + for key, arg_name in [ + ("epochs", "epochs"), + ("batchSize", "batch_size"), + ("learningRate", "learning_rate"), + ("maxContextLength", "max_context_length"), + ("loraRank", "lora_rank"), + ("acceleratorCount", "accelerator_count"), + ("region", "region"), + ]: + val = getattr(args, arg_name, None) + if val is not None: + training_config[key] = val + + inference_params: Dict[str, Any] = {} + for key, arg_name in [ + ("temperature", "temperature"), + ("topP", "top_p"), + ("topK", "top_k"), + ("maxTokens", "max_tokens"), + ("n", "n"), + ]: + val = getattr(args, arg_name, None) + if val is not None: + inference_params[key] = val + if getattr(args, "inference_extra_body", None): + inference_params["extraBody"] = args.inference_extra_body + + wandb_config: Optional[Dict[str, Any]] = None + if getattr(args, "wandb_enabled", False): + wandb_config = { + "enabled": True, + "apiKey": getattr(args, "wandb_api_key", None), + "project": getattr(args, "wandb_project", None), + "entity": getattr(args, "wandb_entity", None), + "runId": getattr(args, "wandb_run_id", None), + } + + body: Dict[str, Any] = { + "displayName": getattr(args, "display_name", None) or f"{evaluator_id}-rft", + "dataset": dataset_id, + "evaluator": evaluator_resource_name, + "evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)), + "trainingConfig": training_config, + "inferenceParameters": inference_params or None, + "wandbConfig": wandb_config, + "outputStats": None, + "outputMetrics": None, + "mcpServer": None, + } + if getattr(args, "evaluation_dataset", None): + body["evaluationDataset"] = args.evaluation_dataset + if getattr(args, "output_model", None): + body.setdefault("trainingConfig", {})["outputModel"] = args.output_model + else: + body.setdefault("trainingConfig", {})["outputModel"] = build_default_output_model(evaluator_id) + + # Clean None fields to avoid noisy payloads + body = {k: v for k, v in body.items() if v is not None} + + if dry_run: + print("--dry-run: would create RFT job with body:") + print(json.dumps(body, indent=2)) + _print_links(evaluator_id, dataset_id, None) + return 0 + + try: + result = create_reinforcement_fine_tuning_job( + account_id=account_id, api_key=api_key, api_base=api_base, body=body + ) + job_name = result.get("name") if isinstance(result, dict) else None + print("\nāœ… Created Reinforcement Fine-tuning Job") + if job_name: + print(f" name: {job_name}") + _print_links(evaluator_id, dataset_id, job_name) + return 0 + except Exception as e: + print(f"Error creating RFT job: {e}") + return 1 diff --git a/eval_protocol/cli_commands/upload.py b/eval_protocol/cli_commands/upload.py index d696e664..86490f62 100644 --- a/eval_protocol/cli_commands/upload.py +++ b/eval_protocol/cli_commands/upload.py @@ -21,6 +21,7 @@ from eval_protocol.platform_api import create_or_update_fireworks_secret from eval_protocol.evaluation import create_evaluation +from eval_protocol.fireworks_rft import save_evaluator_trace, detect_dataset_builder @dataclass @@ -666,6 +667,23 @@ def upload_command(args: argparse.Namespace) -> int: ) name = result.get("name", evaluator_id) if isinstance(result, dict) else evaluator_id + # Persist local evaluator trace for later `create rft` + try: + metric_dir = os.path.dirname(source_file_path) if source_file_path else root + builder_spec = detect_dataset_builder(metric_dir) or None + trace_payload = { + "evaluator_id": evaluator_id, + "evaluator_resource_name": name, + "entry_point": entry_point, + "metric_dir": metric_dir, + "project_root": root, + "dataset_builder": builder_spec, + } + save_evaluator_trace(project_root=root, evaluator_id=evaluator_id, trace=trace_payload) + except Exception: + # Non-fatal; continue + pass + # Print success message with Fireworks dashboard link print(f"\nāœ… Successfully uploaded evaluator: {evaluator_id}") print("šŸ“Š View in Fireworks Dashboard:") diff --git a/eval_protocol/fireworks_rft.py b/eval_protocol/fireworks_rft.py new file mode 100644 index 00000000..3fd44eaa --- /dev/null +++ b/eval_protocol/fireworks_rft.py @@ -0,0 +1,218 @@ +import importlib.util +import io +import json +import os +import sys +import tempfile +import time +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, Optional, Tuple + +import requests + +from .auth import get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key + + +def _map_api_host_to_app_host(api_base: str) -> str: + try: + from urllib.parse import urlparse + + parsed = urlparse(api_base) + host = parsed.netloc or parsed.path + if host.startswith("dev.api.fireworks.ai"): + return f"{parsed.scheme or 'https'}://dev.fireworks.ai" + if host.startswith("api."): + return f"{parsed.scheme or 'https'}://{host.replace('api.', 'app.', 1)}" + return f"{parsed.scheme or 'https'}://{host}" + except Exception: + return "https://app.fireworks.ai" + + +def load_evaluator_trace(project_root: str, evaluator_id: str) -> Optional[Dict[str, Any]]: + trace_path = Path(project_root) / ".eval_protocol" / "evaluators" / f"{evaluator_id}.json" + if not trace_path.exists(): + return None + try: + with open(trace_path, "r", encoding="utf-8") as f: + return json.load(f) + except Exception: + return None + + +def save_evaluator_trace(project_root: str, evaluator_id: str, trace: Dict[str, Any]) -> None: + base_dir = Path(project_root) / ".eval_protocol" / "evaluators" + base_dir.mkdir(parents=True, exist_ok=True) + trace_path = base_dir / f"{evaluator_id}.json" + with open(trace_path, "w", encoding="utf-8") as f: + json.dump(trace, f, indent=2, ensure_ascii=False) + + +def detect_dataset_builder(metric_dir: str) -> Optional[str]: + """ + Best-effort scan for a dataset builder callable inside the metric directory. + Returns a builder spec string in the form "path/to/module.py::function" if found. + """ + try: + candidates: list[Tuple[str, str]] = [] + for root, _, files in os.walk(metric_dir): + for name in files: + if not name.endswith(".py"): + continue + file_path = os.path.join(root, name) + # Load module via file location + module_name = Path(file_path).stem + spec = importlib.util.spec_from_file_location(module_name, file_path) + if not spec or not spec.loader: + continue + module = importlib.util.module_from_spec(spec) + try: + sys.modules[spec.name] = module + spec.loader.exec_module(module) # type: ignore[attr-defined] + except Exception: + continue + # Common exported symbol names + symbol_names = [ + "build_training_dataset", + "get_training_dataset", + "get_dataset", + "dataset", + "DATASET_BUILDER", + ] + for symbol in symbol_names: + if hasattr(module, symbol): + candidates.append((file_path, symbol)) + if not candidates: + return None + # Prefer build_training_dataset then get_training_dataset, else first + preference = { + "build_training_dataset": 0, + "get_training_dataset": 1, + "get_dataset": 2, + "dataset": 3, + "DATASET_BUILDER": 4, + } + candidates.sort(key=lambda x: preference.get(x[1], 99)) + best_file, best_symbol = candidates[0] + return f"{best_file}::{best_symbol}" + except Exception: + return None + + +def _import_builder(builder_spec: str) -> Callable[[], Iterable[Dict[str, Any]]]: + target, func = builder_spec.split("::", 1) + # If target looks like a path, load from file + if "/" in target or target.endswith(".py") or os.path.exists(target): + file_path = target if target.endswith(".py") else f"{target}.py" + if not os.path.isfile(file_path): + raise ValueError(f"Builder file not found: {file_path}") + module_name = Path(file_path).stem + spec = importlib.util.spec_from_file_location(module_name, file_path) + if not spec or not spec.loader: + raise ValueError(f"Unable to load builder module: {file_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) # type: ignore[attr-defined] + else: + # Treat as module path + module = importlib.import_module(target) + if not hasattr(module, func): + raise ValueError(f"Function '{func}' not found in module '{getattr(module, '__name__', target)}'") + callable_obj = getattr(module, func) + if callable(callable_obj): + return callable_obj # type: ignore[return-value] + # If symbol is a constant like DATASET_BUILDER, expect it to be callable + if hasattr(callable_obj, "__call__"): + return callable_obj # type: ignore[return-value] + raise ValueError("Dataset builder is not callable") + + +def materialize_dataset_via_builder(builder_spec: str, output_path: Optional[str] = None) -> Tuple[str, int]: + builder = _import_builder(builder_spec) + rows_iter = builder() + if output_path is None: + fd, tmp_path = tempfile.mkstemp(prefix="ep_rft_dataset_", suffix=".jsonl") + os.close(fd) + output_path = tmp_path + count = 0 + with open(output_path, "w", encoding="utf-8") as f: + for row in rows_iter: + f.write(json.dumps(row, ensure_ascii=False) + "\n") + count += 1 + return output_path, count + + +def create_dataset_from_jsonl( + account_id: str, + api_key: str, + api_base: str, + dataset_id: str, + display_name: Optional[str], + jsonl_path: str, +) -> Tuple[str, Dict[str, Any]]: + headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + # Count examples quickly + example_count = 0 + with open(jsonl_path, "r", encoding="utf-8") as f: + for _ in f: + example_count += 1 + dataset_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets" + payload = { + "dataset": { + "displayName": display_name or dataset_id, + "evalProtocol": {}, + "format": "FORMAT_UNSPECIFIED", + "exampleCount": str(example_count), + }, + "datasetId": dataset_id, + } + resp = requests.post(dataset_url, json=payload, headers=headers, timeout=60) + if resp.status_code not in (200, 201): + raise RuntimeError(f"Dataset creation failed: {resp.status_code} {resp.text}") + ds = resp.json() + + upload_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets/{dataset_id}:upload" + with open(jsonl_path, "rb") as f: + files = {"file": f} + up_headers = {"Authorization": f"Bearer {api_key}"} + up_resp = requests.post(upload_url, files=files, headers=up_headers, timeout=600) + if up_resp.status_code not in (200, 201): + raise RuntimeError(f"Dataset upload failed: {up_resp.status_code} {up_resp.text}") + return dataset_id, ds + + +def create_reinforcement_fine_tuning_job( + account_id: str, + api_key: str, + api_base: str, + body: Dict[str, Any], +) -> Dict[str, Any]: + url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/reinforcementFineTuningJobs" + headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "Accept": "application/json"} + resp = requests.post(url, json=body, headers=headers, timeout=60) + if resp.status_code not in (200, 201): + raise RuntimeError(f"RFT job creation failed: {resp.status_code} {resp.text}") + return resp.json() + + +def build_default_dataset_id(evaluator_id: str) -> str: + ts = time.strftime("%Y%m%d%H%M%S") + base = evaluator_id.lower().replace("_", "-") + return f"{base}-dataset-{ts}" + + +def build_default_output_model(evaluator_id: str) -> str: + base = evaluator_id.lower().replace("_", "-") + return f"{base}-rft" + + +__all__ = [ + "load_evaluator_trace", + "save_evaluator_trace", + "detect_dataset_builder", + "materialize_dataset_via_builder", + "create_dataset_from_jsonl", + "create_reinforcement_fine_tuning_job", + "build_default_dataset_id", + "build_default_output_model", + "_map_api_host_to_app_host", +] From d6c76c1152c582efb767cf57d39a3a8430838284 Mon Sep 17 00:00:00 2001 From: Benny Chen Date: Mon, 13 Oct 2025 15:36:46 -0700 Subject: [PATCH 2/2] kick off RFT in one command --- eval_protocol/cli.py | 3 +- eval_protocol/cli_commands/create_rft.py | 106 ++++++++++++++++------- 2 files changed, 75 insertions(+), 34 deletions(-) diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py index 81c835ad..125198e1 100644 --- a/eval_protocol/cli.py +++ b/eval_protocol/cli.py @@ -368,8 +368,7 @@ def parse_args(args=None): ) rft_parser.add_argument( "--evaluator-id", - required=True, - help="Evaluator ID used during upload; resolves evaluator resource via local trace", + help="Evaluator ID used during upload; if omitted, derive from local traces or a single discovered test", ) # Dataset options rft_parser.add_argument( diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index 3253d759..cb78fbae 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -19,6 +19,7 @@ load_evaluator_trace, materialize_dataset_via_builder, ) +from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source def _ensure_account_id() -> Optional[str]: @@ -32,24 +33,51 @@ def _ensure_account_id() -> Optional[str]: return account_id +def _extract_terminal_segment(resource_name: str) -> str: + """Return the last path segment if a fully-qualified resource name is provided.""" + try: + return resource_name.strip("/").split("/")[-1] + except Exception: + return resource_name + + def _print_links(evaluator_id: str, dataset_id: str, job_name: Optional[str]) -> None: api_base = get_fireworks_api_base() app_base = _map_api_host_to_app_host(api_base) print("\nšŸ“Š Dashboard Links:") - print(f" Evaluator: {app_base}/dashboard/evaluators/{evaluator_id}") + evaluator_slug = _extract_terminal_segment(evaluator_id) + print(f" Evaluator: {app_base}/dashboard/evaluators/{evaluator_slug}") if dataset_id: print(f" Dataset: {app_base}/dashboard/datasets/{dataset_id}") if job_name: # job_name likely like accounts/{account}/reinforcementFineTuningJobs/{id} try: job_id = job_name.strip().split("/")[-1] - print(f" RFT Job: {app_base}/dashboard/rft/{job_id}") + print(f" RFT Job: {app_base}/dashboard/fine-tuning/reinforcement/{job_id}") except Exception: pass +def _auto_select_evaluator_id(cwd: str) -> Optional[str]: + # Try local traces + traces_dir = os.path.join(cwd, ".eval_protocol", "evaluators") + if os.path.isdir(traces_dir): + candidates = [f[:-5] for f in os.listdir(traces_dir) if f.endswith(".json")] + if len(candidates) == 1: + return candidates[0] + # Fall back to discovering a single evaluation_test + tests = _discover_tests(cwd) + if len(tests) == 1: + qualname, source_file_path = tests[0].qualname, tests[0].file_path + test_func_name = qualname.split(".")[-1] + source_file_name = os.path.splitext(os.path.basename(source_file_path))[0] + evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{test_func_name}") + return evaluator_id + return None + + def create_rft_command(args) -> int: - evaluator_id: str = getattr(args, "evaluator_id") + evaluator_id: Optional[str] = getattr(args, "evaluator_id", None) non_interactive: bool = bool(getattr(args, "yes", False)) dry_run: bool = bool(getattr(args, "dry_run", False)) @@ -65,15 +93,23 @@ def create_rft_command(args) -> int: api_base = get_fireworks_api_base() - # Resolve evaluator resource name via local trace + # Resolve evaluator id if omitted project_root = os.getcwd() - trace = load_evaluator_trace(project_root, evaluator_id) - if not trace or not isinstance(trace, dict): - print( - "Error: Evaluator trace not found. Run 'eval-protocol upload' first or provide --dataset-id/--dataset-jsonl and --evaluator-id." - ) - return 1 - evaluator_resource_name = trace.get("evaluator_resource_name") or trace.get("name") or evaluator_id + if not evaluator_id: + evaluator_id = _auto_select_evaluator_id(project_root) + if not evaluator_id: + print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.") + return 1 + + # Resolve evaluator resource name via local trace + # trace = load_evaluator_trace(project_root, evaluator_id) + # if not trace or not isinstance(trace, dict): + # print( + # "Error: Evaluator trace not found. Run 'eval-protocol upload' first or provide --dataset-id/--dataset-jsonl and --evaluator-id." + # ) + # return 1 + # evaluator_resource_name = trace.get("evaluator_resource_name") or trace.get("name") or evaluator_id + evaluator_resource_name = evaluator_id # Determine dataset id and materialization path dataset_id = getattr(args, "dataset_id", None) @@ -82,25 +118,29 @@ def create_rft_command(args) -> int: dataset_builder = getattr(args, "dataset_builder", None) if not dataset_id: + # Try builder from args, else from trace detection + # TODO: build dataset from traces directly + # builder_spec = dataset_builder or trace.get("dataset_builder") + # if not builder_spec: + # # Attempt detect from metric_dir + # metric_dir = trace.get("metric_dir") + # if metric_dir: + # builder_spec = detect_dataset_builder(metric_dir) + # if not builder_spec: + # print( + # "Error: Could not determine dataset. Provide --dataset-id, --dataset-jsonl, or --dataset-builder." + # ) + # return 1 + # try: + # dataset_jsonl, count = materialize_dataset_via_builder(builder_spec) + # print(f"āœ“ Materialized dataset via builder ({builder_spec}): {count} rows → {dataset_jsonl}") + # except Exception as e: + # print(f"Error: dataset builder failed: {e}") + # return 1 + if not dataset_jsonl: - # Try builder from args, else from trace detection - builder_spec = dataset_builder or trace.get("dataset_builder") - if not builder_spec: - # Attempt detect from metric_dir - metric_dir = trace.get("metric_dir") - if metric_dir: - builder_spec = detect_dataset_builder(metric_dir) - if not builder_spec: - print( - "Error: Could not determine dataset. Provide --dataset-id, --dataset-jsonl, or --dataset-builder." - ) - return 1 - try: - dataset_jsonl, count = materialize_dataset_via_builder(builder_spec) - print(f"āœ“ Materialized dataset via builder ({builder_spec}): {count} rows → {dataset_jsonl}") - except Exception as e: - print(f"Error: dataset builder failed: {e}") - return 1 + print("Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl.") + return 1 inferred_dataset_id = build_default_dataset_id(evaluator_id) if dry_run: @@ -170,8 +210,8 @@ def create_rft_command(args) -> int: } body: Dict[str, Any] = { - "displayName": getattr(args, "display_name", None) or f"{evaluator_id}-rft", - "dataset": dataset_id, + # "displayName": getattr(args, "display_name", None) or f"{evaluator_id}-rft", + "dataset": f"accounts/{account_id}/datasets/{dataset_id}", "evaluator": evaluator_resource_name, "evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)), "trainingConfig": training_config, @@ -181,10 +221,12 @@ def create_rft_command(args) -> int: "outputMetrics": None, "mcpServer": None, } + print("Show body:") + print(json.dumps(body, indent=2)) if getattr(args, "evaluation_dataset", None): body["evaluationDataset"] = args.evaluation_dataset if getattr(args, "output_model", None): - body.setdefault("trainingConfig", {})["outputModel"] = args.output_model + body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{args.output_model}" else: body.setdefault("trainingConfig", {})["outputModel"] = build_default_output_model(evaluator_id)