|
| 1 | +import json |
| 2 | +import os |
| 3 | +import sys |
| 4 | +from typing import Any, Dict, Optional |
| 5 | + |
| 6 | +from ..auth import ( |
| 7 | + get_fireworks_account_id, |
| 8 | + get_fireworks_api_base, |
| 9 | + get_fireworks_api_key, |
| 10 | + verify_api_key_and_get_account_id, |
| 11 | +) |
| 12 | +from ..fireworks_rft import ( |
| 13 | + _map_api_host_to_app_host, |
| 14 | + build_default_dataset_id, |
| 15 | + build_default_output_model, |
| 16 | + create_dataset_from_jsonl, |
| 17 | + create_reinforcement_fine_tuning_job, |
| 18 | + detect_dataset_builder, |
| 19 | + load_evaluator_trace, |
| 20 | + materialize_dataset_via_builder, |
| 21 | +) |
| 22 | + |
| 23 | + |
def _ensure_account_id() -> Optional[str]:
    """Return the Fireworks account id, resolving it from the API key if needed.

    When FIREWORKS_ACCOUNT_ID is not configured but an API key is available,
    the key is verified against the API to recover the account id; a successful
    resolution is cached into the process environment so later lookups succeed.
    """
    existing = get_fireworks_account_id()
    key = get_fireworks_api_key()
    # Nothing to resolve: either we already have an id, or we lack the key to look one up.
    if existing or not key:
        return existing
    resolved = verify_api_key_and_get_account_id(api_key=key, api_base=get_fireworks_api_base())
    if resolved:
        os.environ["FIREWORKS_ACCOUNT_ID"] = resolved
        return resolved
    return existing
| 33 | + |
| 34 | + |
def _print_links(evaluator_id: str, dataset_id: str, job_name: Optional[str]) -> None:
    """Print dashboard URLs for the evaluator, its dataset, and the RFT job (if known)."""
    app_base = _map_api_host_to_app_host(get_fireworks_api_base())
    print("\n📊 Dashboard Links:")
    print(f" Evaluator: {app_base}/dashboard/evaluators/{evaluator_id}")
    if dataset_id:
        print(f" Dataset: {app_base}/dashboard/datasets/{dataset_id}")
    if not job_name:
        return
    # job_name likely like accounts/{account}/reinforcementFineTuningJobs/{id}
    try:
        job_id = job_name.strip().split("/")[-1]
        print(f" RFT Job: {app_base}/dashboard/rft/{job_id}")
    except Exception:
        # Best-effort: an unexpected job-name shape must not break link printing.
        pass
| 49 | + |
| 50 | + |
def create_rft_command(args) -> int:
    """Create a Fireworks Reinforcement Fine-Tuning (RFT) job for an evaluator.

    Pipeline: validate credentials, resolve the evaluator resource name from the
    locally stored upload trace, ensure a dataset exists (explicit id, JSONL
    upload, or a dataset builder), assemble the job payload from CLI flags, and
    submit it to the API. Honors --dry-run by printing the would-be payload
    instead of calling the API.

    Args:
        args: Parsed CLI namespace. Optional flags are read with getattr so a
            missing attribute degrades to its default rather than raising.

    Returns:
        0 on success (or a completed dry run), 1 on any error.
    """
    evaluator_id: str = getattr(args, "evaluator_id")
    non_interactive: bool = bool(getattr(args, "yes", False))
    dry_run: bool = bool(getattr(args, "dry_run", False))

    # Credentials: API key is mandatory; account id may be derived from the key.
    api_key = get_fireworks_api_key()
    if not api_key:
        print("Error: FIREWORKS_API_KEY not set.")
        return 1

    account_id = _ensure_account_id()
    if not account_id:
        print("Error: FIREWORKS_ACCOUNT_ID not set and could not be resolved.")
        return 1

    api_base = get_fireworks_api_base()

    # Resolve evaluator resource name via local trace
    project_root = os.getcwd()
    trace = load_evaluator_trace(project_root, evaluator_id)
    if not trace or not isinstance(trace, dict):
        print(
            "Error: Evaluator trace not found. Run 'eval-protocol upload' first or provide --dataset-id/--dataset-jsonl and --evaluator-id."
        )
        return 1
    # Prefer the fully-qualified resource name from the trace; fall back to the bare id.
    evaluator_resource_name = trace.get("evaluator_resource_name") or trace.get("name") or evaluator_id

    # Determine dataset id and materialization path
    dataset_id = getattr(args, "dataset_id", None)
    dataset_jsonl = getattr(args, "dataset_jsonl", None)
    dataset_display_name = getattr(args, "dataset_display_name", None)
    dataset_builder = getattr(args, "dataset_builder", None)

    # Dataset resolution order: explicit --dataset-id wins; otherwise a JSONL file
    # (given directly or materialized by a builder) is uploaded as a new dataset.
    if not dataset_id:
        if not dataset_jsonl:
            # Try builder from args, else from trace detection
            builder_spec = dataset_builder or trace.get("dataset_builder")
            if not builder_spec:
                # Attempt detect from metric_dir
                metric_dir = trace.get("metric_dir")
                if metric_dir:
                    builder_spec = detect_dataset_builder(metric_dir)
            if not builder_spec:
                print(
                    "Error: Could not determine dataset. Provide --dataset-id, --dataset-jsonl, or --dataset-builder."
                )
                return 1
            try:
                dataset_jsonl, count = materialize_dataset_via_builder(builder_spec)
                print(f"✓ Materialized dataset via builder ({builder_spec}): {count} rows → {dataset_jsonl}")
            except Exception as e:
                print(f"Error: dataset builder failed: {e}")
                return 1

        inferred_dataset_id = build_default_dataset_id(evaluator_id)
        if dry_run:
            # Skip the upload entirely; use the inferred id so links/payload still render.
            print("--dry-run: would create dataset and upload JSONL")
            dataset_id = inferred_dataset_id
        else:
            try:
                dataset_id, _ = create_dataset_from_jsonl(
                    account_id=account_id,
                    api_key=api_key,
                    api_base=api_base,
                    dataset_id=inferred_dataset_id,
                    display_name=dataset_display_name or inferred_dataset_id,
                    jsonl_path=dataset_jsonl,
                )
                print(f"✓ Created and uploaded dataset: {dataset_id}")
            except Exception as e:
                print(f"Error creating/uploading dataset: {e}")
                return 1

    # Build training config/body
    training_config: Dict[str, Any] = {}
    if getattr(args, "base_model", None):
        training_config["baseModel"] = args.base_model
    if getattr(args, "warm_start_from", None):
        training_config["warmStartFrom"] = args.warm_start_from
    if "baseModel" not in training_config and "warmStartFrom" not in training_config:
        # Provide a conservative default if neither is set
        training_config["baseModel"] = "accounts/fireworks/models/llama-v3p1-8b-instruct"

    # Optional hyperparameters
    # Map snake_case CLI flags onto the camelCase keys the API expects; only
    # explicitly-set values are forwarded.
    for key, arg_name in [
        ("epochs", "epochs"),
        ("batchSize", "batch_size"),
        ("learningRate", "learning_rate"),
        ("maxContextLength", "max_context_length"),
        ("loraRank", "lora_rank"),
        ("acceleratorCount", "accelerator_count"),
        ("region", "region"),
    ]:
        val = getattr(args, arg_name, None)
        if val is not None:
            training_config[key] = val

    # Same snake_case → camelCase mapping for sampling/inference parameters.
    inference_params: Dict[str, Any] = {}
    for key, arg_name in [
        ("temperature", "temperature"),
        ("topP", "top_p"),
        ("topK", "top_k"),
        ("maxTokens", "max_tokens"),
        ("n", "n"),
    ]:
        val = getattr(args, arg_name, None)
        if val is not None:
            inference_params[key] = val
    if getattr(args, "inference_extra_body", None):
        inference_params["extraBody"] = args.inference_extra_body

    # Weights & Biases config is attached only when explicitly enabled.
    wandb_config: Optional[Dict[str, Any]] = None
    if getattr(args, "wandb_enabled", False):
        wandb_config = {
            "enabled": True,
            "apiKey": getattr(args, "wandb_api_key", None),
            "project": getattr(args, "wandb_project", None),
            "entity": getattr(args, "wandb_entity", None),
            "runId": getattr(args, "wandb_run_id", None),
        }

    # Job payload; None-valued keys are stripped below before submission.
    body: Dict[str, Any] = {
        "displayName": getattr(args, "display_name", None) or f"{evaluator_id}-rft",
        "dataset": dataset_id,
        "evaluator": evaluator_resource_name,
        "evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)),
        "trainingConfig": training_config,
        "inferenceParameters": inference_params or None,
        "wandbConfig": wandb_config,
        "outputStats": None,
        "outputMetrics": None,
        "mcpServer": None,
    }
    if getattr(args, "evaluation_dataset", None):
        body["evaluationDataset"] = args.evaluation_dataset
    # Output model: caller's choice, else a default derived from the evaluator id.
    if getattr(args, "output_model", None):
        body.setdefault("trainingConfig", {})["outputModel"] = args.output_model
    else:
        body.setdefault("trainingConfig", {})["outputModel"] = build_default_output_model(evaluator_id)

    # Clean None fields to avoid noisy payloads
    body = {k: v for k, v in body.items() if v is not None}

    if dry_run:
        print("--dry-run: would create RFT job with body:")
        print(json.dumps(body, indent=2))
        _print_links(evaluator_id, dataset_id, None)
        return 0

    try:
        result = create_reinforcement_fine_tuning_job(
            account_id=account_id, api_key=api_key, api_base=api_base, body=body
        )
        # The API is expected to return a dict; guard in case it does not.
        job_name = result.get("name") if isinstance(result, dict) else None
        print("\n✅ Created Reinforcement Fine-tuning Job")
        if job_name:
            print(f" name: {job_name}")
        _print_links(evaluator_id, dataset_id, job_name)
        return 0
    except Exception as e:
        print(f"Error creating RFT job: {e}")
        return 1
0 commit comments