|
| 1 | +import json |
| 2 | +import os |
| 3 | +import sys |
| 4 | +from typing import Any, Dict, Optional |
| 5 | + |
| 6 | +from ..auth import ( |
| 7 | + get_fireworks_account_id, |
| 8 | + get_fireworks_api_base, |
| 9 | + get_fireworks_api_key, |
| 10 | + verify_api_key_and_get_account_id, |
| 11 | +) |
| 12 | +from ..fireworks_rft import ( |
| 13 | + _map_api_host_to_app_host, |
| 14 | + build_default_dataset_id, |
| 15 | + build_default_output_model, |
| 16 | + create_dataset_from_jsonl, |
| 17 | + create_reinforcement_fine_tuning_job, |
| 18 | + detect_dataset_builder, |
| 19 | + load_evaluator_trace, |
| 20 | + materialize_dataset_via_builder, |
| 21 | +) |
| 22 | +from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source |
| 23 | + |
| 24 | + |
def _ensure_account_id() -> Optional[str]:
    """Return the Fireworks account id, resolving and caching it from the API key if unset."""
    account_id = get_fireworks_account_id()
    if account_id:
        return account_id
    api_key = get_fireworks_api_key()
    if not api_key:
        return account_id
    resolved = verify_api_key_and_get_account_id(api_key=api_key, api_base=get_fireworks_api_base())
    if not resolved:
        return account_id
    # Cache the resolved id so subsequent lookups in this process succeed directly.
    os.environ["FIREWORKS_ACCOUNT_ID"] = resolved
    return resolved
| 34 | + |
| 35 | + |
| 36 | +def _extract_terminal_segment(resource_name: str) -> str: |
| 37 | + """Return the last path segment if a fully-qualified resource name is provided.""" |
| 38 | + try: |
| 39 | + return resource_name.strip("/").split("/")[-1] |
| 40 | + except Exception: |
| 41 | + return resource_name |
| 42 | + |
| 43 | + |
def _print_links(evaluator_id: str, dataset_id: str, job_name: Optional[str]) -> None:
    """Print dashboard URLs for the evaluator, dataset, and (optionally) the RFT job."""
    app_base = _map_api_host_to_app_host(get_fireworks_api_base())
    print("\n📊 Dashboard Links:")
    print(f" Evaluator: {app_base}/dashboard/evaluators/{_extract_terminal_segment(evaluator_id)}")
    if dataset_id:
        print(f" Dataset: {app_base}/dashboard/datasets/{dataset_id}")
    if job_name:
        # job_name likely like accounts/{account}/reinforcementFineTuningJobs/{id}
        try:
            print(f" RFT Job: {app_base}/dashboard/fine-tuning/reinforcement/{job_name.strip().split('/')[-1]}")
        except Exception:
            # Best-effort: a malformed job name should not break the summary output.
            pass
| 59 | + |
| 60 | + |
def _auto_select_evaluator_id(cwd: str) -> Optional[str]:
    """Infer an evaluator id from local traces or a single discovered evaluation test.

    Returns None when no unambiguous candidate exists.
    """
    # Prefer a locally recorded evaluator trace when exactly one exists.
    traces_dir = os.path.join(cwd, ".eval_protocol", "evaluators")
    if os.path.isdir(traces_dir):
        trace_ids = [entry[:-5] for entry in os.listdir(traces_dir) if entry.endswith(".json")]
        if len(trace_ids) == 1:
            return trace_ids[0]
    # Otherwise fall back to a single discovered evaluation_test.
    tests = _discover_tests(cwd)
    if len(tests) != 1:
        return None
    only_test = tests[0]
    func_name = only_test.qualname.split(".")[-1]
    file_stem = os.path.splitext(os.path.basename(only_test.file_path))[0]
    return _normalize_evaluator_id(f"{file_stem}-{func_name}")
| 77 | + |
| 78 | + |
def _build_training_config(args) -> Dict[str, Any]:
    """Assemble the trainingConfig payload from CLI args.

    Ensures exactly one of baseModel/warmStartFrom is present (falling back to a
    conservative default base model) and maps snake_case CLI hyperparameter
    names onto the camelCase API keys.
    """
    training_config: Dict[str, Any] = {}
    if getattr(args, "base_model", None):
        training_config["baseModel"] = args.base_model
    if getattr(args, "warm_start_from", None):
        training_config["warmStartFrom"] = args.warm_start_from
    if "baseModel" not in training_config and "warmStartFrom" not in training_config:
        # Provide a conservative default if neither is set
        training_config["baseModel"] = "accounts/fireworks/models/llama-v3p1-8b-instruct"

    # Optional hyperparameters: (API key, CLI arg name) pairs.
    for key, arg_name in [
        ("epochs", "epochs"),
        ("batchSize", "batch_size"),
        ("learningRate", "learning_rate"),
        ("maxContextLength", "max_context_length"),
        ("loraRank", "lora_rank"),
        ("acceleratorCount", "accelerator_count"),
        ("region", "region"),
    ]:
        val = getattr(args, arg_name, None)
        if val is not None:
            training_config[key] = val
    return training_config


def _build_inference_params(args) -> Dict[str, Any]:
    """Assemble the inferenceParameters payload from CLI args (may be empty)."""
    inference_params: Dict[str, Any] = {}
    for key, arg_name in [
        ("temperature", "temperature"),
        ("topP", "top_p"),
        ("topK", "top_k"),
        ("maxTokens", "max_tokens"),
        ("n", "n"),
    ]:
        val = getattr(args, arg_name, None)
        if val is not None:
            inference_params[key] = val
    if getattr(args, "inference_extra_body", None):
        inference_params["extraBody"] = args.inference_extra_body
    return inference_params


def _build_wandb_config(args) -> Optional[Dict[str, Any]]:
    """Assemble the wandbConfig payload, or None when W&B logging is disabled."""
    if not getattr(args, "wandb_enabled", False):
        return None
    return {
        "enabled": True,
        "apiKey": getattr(args, "wandb_api_key", None),
        "project": getattr(args, "wandb_project", None),
        "entity": getattr(args, "wandb_entity", None),
        "runId": getattr(args, "wandb_run_id", None),
    }


def create_rft_command(args) -> int:
    """CLI entry point: create a Fireworks reinforcement fine-tuning (RFT) job.

    Resolves credentials and the evaluator id, creates/uploads the dataset when
    only a JSONL path is given, builds the job payload, and submits it.
    Honors ``--dry-run`` by printing the payload instead of submitting.

    Returns:
        Process exit code: 0 on success, 1 on any error (message printed).
    """
    evaluator_id: Optional[str] = getattr(args, "evaluator_id", None)
    dry_run: bool = bool(getattr(args, "dry_run", False))

    api_key = get_fireworks_api_key()
    if not api_key:
        print("Error: FIREWORKS_API_KEY not set.")
        return 1

    account_id = _ensure_account_id()
    if not account_id:
        print("Error: FIREWORKS_ACCOUNT_ID not set and could not be resolved.")
        return 1

    api_base = get_fireworks_api_base()

    # Resolve evaluator id if omitted
    project_root = os.getcwd()
    if not evaluator_id:
        evaluator_id = _auto_select_evaluator_id(project_root)
        if not evaluator_id:
            print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
            return 1

    # TODO(rft): resolve the fully-qualified evaluator resource name from the
    # local trace (load_evaluator_trace) once trace-based resolution is enabled;
    # for now the bare evaluator id is passed through.
    evaluator_resource_name = evaluator_id

    # Determine dataset id and materialization path
    dataset_id = getattr(args, "dataset_id", None)
    dataset_jsonl = getattr(args, "dataset_jsonl", None)
    dataset_display_name = getattr(args, "dataset_display_name", None)

    if not dataset_id:
        # TODO(rft): materialize a dataset via --dataset-builder /
        # detect_dataset_builder + materialize_dataset_via_builder when no
        # JSONL is supplied (builder flow currently disabled).
        if not dataset_jsonl:
            print("Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl.")
            return 1

        inferred_dataset_id = build_default_dataset_id(evaluator_id)
        if dry_run:
            print("--dry-run: would create dataset and upload JSONL")
            dataset_id = inferred_dataset_id
        else:
            try:
                dataset_id, _ = create_dataset_from_jsonl(
                    account_id=account_id,
                    api_key=api_key,
                    api_base=api_base,
                    dataset_id=inferred_dataset_id,
                    display_name=dataset_display_name or inferred_dataset_id,
                    jsonl_path=dataset_jsonl,
                )
                print(f"✓ Created and uploaded dataset: {dataset_id}")
            except Exception as e:
                print(f"Error creating/uploading dataset: {e}")
                return 1

    training_config = _build_training_config(args)
    inference_params = _build_inference_params(args)
    wandb_config = _build_wandb_config(args)

    body: Dict[str, Any] = {
        "dataset": f"accounts/{account_id}/datasets/{dataset_id}",
        "evaluator": evaluator_resource_name,
        "evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)),
        "trainingConfig": training_config,
        # Empty dicts collapse to None here and are stripped below.
        "inferenceParameters": inference_params or None,
        "wandbConfig": wandb_config,
        "outputStats": None,
        "outputMetrics": None,
        "mcpServer": None,
    }
    if getattr(args, "evaluation_dataset", None):
        body["evaluationDataset"] = args.evaluation_dataset

    # trainingConfig is always present in body; attach the output model to it.
    if getattr(args, "output_model", None):
        training_config["outputModel"] = f"accounts/{account_id}/models/{args.output_model}"
    else:
        training_config["outputModel"] = build_default_output_model(evaluator_id)

    # Clean None fields to avoid noisy payloads
    body = {k: v for k, v in body.items() if v is not None}

    if dry_run:
        print("--dry-run: would create RFT job with body:")
        print(json.dumps(body, indent=2))
        _print_links(evaluator_id, dataset_id, None)
        return 0

    try:
        result = create_reinforcement_fine_tuning_job(
            account_id=account_id, api_key=api_key, api_base=api_base, body=body
        )
        job_name = result.get("name") if isinstance(result, dict) else None
        print("\n✅ Created Reinforcement Fine-tuning Job")
        if job_name:
            print(f" name: {job_name}")
        _print_links(evaluator_id, dataset_id, job_name)
        return 0
    except Exception as e:
        print(f"Error creating RFT job: {e}")
        return 1
0 commit comments