|
1 | 1 | import json |
2 | 2 | import os |
3 | 3 | import sys |
| 4 | +import time |
| 5 | +import argparse |
4 | 6 | from typing import Any, Dict, Optional |
5 | 7 |
|
6 | 8 | from ..auth import ( |
|
11 | 13 | ) |
12 | 14 | from ..fireworks_rft import ( |
13 | 15 | _map_api_host_to_app_host, |
14 | | - build_default_dataset_id, |
15 | | - build_default_output_model, |
16 | 16 | create_dataset_from_jsonl, |
17 | 17 | create_reinforcement_fine_tuning_job, |
18 | | - detect_dataset_builder, |
19 | | - load_evaluator_trace, |
20 | | - materialize_dataset_via_builder, |
21 | 18 | ) |
22 | 19 | from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source |
23 | 20 |
|
@@ -58,6 +55,129 @@ def _print_links(evaluator_id: str, dataset_id: str, job_name: Optional[str]) -> |
58 | 55 | pass |
59 | 56 |
|
60 | 57 |
|
| 58 | +def _auto_find_jsonl(cwd: str) -> Optional[str]: |
| 59 | + """Find a reasonable JSONL dataset file in the current project. |
| 60 | +
|
| 61 | + Priority order: |
| 62 | + - dataset.jsonl in cwd |
| 63 | + - data/dataset.jsonl |
| 64 | + - first *.jsonl under cwd (depth-first, skipping common vendor/venv/build dirs) |
| 65 | + Returns a RELATIVE path from cwd if possible. |
| 66 | + """ |
| 67 | + # Direct candidates |
| 68 | + direct_candidates = [ |
| 69 | + os.path.join(cwd, "dataset.jsonl"), |
| 70 | + os.path.join(cwd, "data", "dataset.jsonl"), |
| 71 | + ] |
| 72 | + for p in direct_candidates: |
| 73 | + if os.path.isfile(p): |
| 74 | + try: |
| 75 | + return os.path.relpath(p, cwd) |
| 76 | + except Exception: |
| 77 | + return p |
| 78 | + |
| 79 | + # Walk and find any .jsonl |
| 80 | + skip_dirs = {".venv", "venv", "node_modules", "dist", "build", "__pycache__", ".git", "vendor"} |
| 81 | + for dirpath, dirnames, filenames in os.walk(cwd): |
| 82 | + # prune |
| 83 | + dirnames[:] = [d for d in dirnames if d not in skip_dirs and not d.startswith(".")] |
| 84 | + for name in sorted(filenames): |
| 85 | + if name.endswith(".jsonl"): |
| 86 | + candidate = os.path.join(dirpath, name) |
| 87 | + try: |
| 88 | + return os.path.relpath(candidate, cwd) |
| 89 | + except Exception: |
| 90 | + return candidate |
| 91 | + return None |
| 92 | + |
| 93 | + |
def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) -> Optional[str]:
    """Import the test module and extract a JSONL path from data_loaders param if present.

    Looks for a pytest.mark.parametrize with argnames containing 'data_loaders' and attempts to
    find an object with attribute 'jsonl_path'. If a relative path is found, it is resolved
    relative to the directory of the test file.

    Returns None on any failure (import error, no matching mark, no loader
    with a usable ``jsonl_path``) — this is a best-effort probe, so every
    exception is swallowed rather than propagated.

    NOTE(review): executing the test module runs its top-level code; any
    import-time side effects of the test file will occur here. The module is
    also left registered in sys.modules after this call.
    """
    try:
        import importlib.util
        from pathlib import Path

        # Load the test file as a standalone module named after its stem.
        spec = importlib.util.spec_from_file_location(Path(test_file_path).stem, test_file_path)
        if not spec or not spec.loader:
            return None
        module = importlib.util.module_from_spec(spec)
        # Register before exec so relative/self imports inside the module resolve.
        sys.modules[spec.name] = module
        spec.loader.exec_module(module)  # type: ignore[attr-defined]
        if not hasattr(module, test_func_name):
            return None
        wrapper = getattr(module, test_func_name)
        # pytest stores parametrize (and other) marks on the function's
        # ``pytestmark`` attribute; each mark carries args/kwargs as given.
        marks = getattr(wrapper, "pytestmark", [])
        for m in marks:
            if getattr(m, "name", "") == "parametrize":
                kwargs = getattr(m, "kwargs", {})
                # parametrize accepts (argnames, argvalues) positionally or by keyword.
                argnames = kwargs.get("argnames", (m.args[0] if m.args else []))
                argvalues = kwargs.get("argvalues", (m.args[1] if len(m.args) > 1 else []))
                # Normalize argnames to list
                if isinstance(argnames, str):
                    names_list = [n.strip() for n in argnames.split(",") if n.strip()]
                else:
                    names_list = list(argnames)
                if "data_loaders" not in names_list:
                    continue
                idx = names_list.index("data_loaders")
                # argvalues is a list of tuples/values aligned with argnames
                for val in argvalues:
                    # Normalize to tuple
                    if not isinstance(val, (tuple, list)):
                        params = (val,)
                    else:
                        params = tuple(val)
                    if idx >= len(params):
                        continue
                    dataloaders_obj = params[idx]
                    # May be a list or single loader
                    candidates = (
                        list(dataloaders_obj) if isinstance(dataloaders_obj, (list, tuple)) else [dataloaders_obj]
                    )
                    for dl in candidates:
                        # Duck-typed: any object exposing a string ``jsonl_path`` wins.
                        jsonl_path = getattr(dl, "jsonl_path", None)
                        if isinstance(jsonl_path, str) and jsonl_path:
                            if os.path.isabs(jsonl_path):
                                return jsonl_path
                            # Relative paths are resolved against the test file's directory.
                            base_dir = os.path.dirname(os.path.abspath(test_file_path))
                            return os.path.abspath(os.path.join(base_dir, jsonl_path))
        return None
    except Exception:
        # Best-effort extraction: callers treat None as "could not determine".
        return None
| 152 | + |
| 153 | + |
def _build_trimmed_dataset_id(evaluator_id: str) -> str:
    """Derive a dataset id from *evaluator_id*, capped at 63 characters.

    Produces ``<normalized-base>-dataset-YYYYMMDDHHMMSS``; the normalized
    base is truncated so the full id fits the 63-char limit, guaranteed to
    start with a letter, and never ends with a trailing hyphen.
    """
    # Local import to avoid an import cycle at module import time.
    from .upload import _normalize_evaluator_id

    timestamp_suffix = f"-dataset-{time.strftime('%Y%m%d%H%M%S')}"
    # Character budget left for the base once the suffix is accounted for.
    budget = max(63 - len(timestamp_suffix), 1)

    stem = _normalize_evaluator_id(evaluator_id)
    if len(stem) > budget:
        stem = stem[:budget].rstrip("-")
    if not stem:
        stem = "dataset"
    # Resource ids must begin with a letter; prefix and re-trim if needed.
    if not stem[0].isalpha():
        stem = f"eval-{stem}"
        if len(stem) > budget:
            stem = stem[:budget]
    stem = stem.rstrip("-") or "dataset"
    return f"{stem}{timestamp_suffix}"
| 179 | + |
| 180 | + |
61 | 181 | def _auto_select_evaluator_id(cwd: str) -> Optional[str]: |
62 | 182 | # Try local traces |
63 | 183 | traces_dir = os.path.join(cwd, ".eval_protocol", "evaluators") |
@@ -101,75 +221,117 @@ def create_rft_command(args) -> int: |
101 | 221 | print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.") |
102 | 222 | return 1 |
103 | 223 |
|
104 | | - # Resolve evaluator resource name via local trace |
105 | | - # trace = load_evaluator_trace(project_root, evaluator_id) |
106 | | - # if not trace or not isinstance(trace, dict): |
107 | | - # print( |
108 | | - # "Error: Evaluator trace not found. Run 'eval-protocol upload' first or provide --dataset-id/--dataset-jsonl and --evaluator-id." |
109 | | - # ) |
110 | | - # return 1 |
111 | | - # evaluator_resource_name = trace.get("evaluator_resource_name") or trace.get("name") or evaluator_id |
112 | | - evaluator_resource_name = evaluator_id |
| 224 | + # Resolve evaluator resource name to fully-qualified format required by API |
| 225 | + evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}" |
| 226 | + |
| 227 | + # Ensure evaluator exists by invoking the upload flow programmatically |
| 228 | + try: |
| 229 | + from .upload import upload_command |
| 230 | + |
| 231 | + tests = _discover_tests(project_root) |
| 232 | + selected_entry: Optional[str] = None |
| 233 | + if len(tests) == 1: |
| 234 | + func_name = tests[0].qualname.split(".")[-1] |
| 235 | + abs_path = os.path.abspath(tests[0].file_path) |
| 236 | + try: |
| 237 | + rel = os.path.relpath(abs_path, project_root) |
| 238 | + except Exception: |
| 239 | + rel = abs_path |
| 240 | + selected_entry = f"{rel}::{func_name}" |
| 241 | + else: |
| 242 | + # Try to match evaluator_id to a discovered test's normalized ID |
| 243 | + for t in tests: |
| 244 | + func_name = t.qualname.split(".")[-1] |
| 245 | + source_file_name = os.path.splitext(os.path.basename(t.file_path))[0] |
| 246 | + candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}") |
| 247 | + if candidate == evaluator_id: |
| 248 | + abs_path = os.path.abspath(t.file_path) |
| 249 | + try: |
| 250 | + rel = os.path.relpath(abs_path, project_root) |
| 251 | + except Exception: |
| 252 | + rel = abs_path |
| 253 | + selected_entry = f"{rel}::{func_name}" |
| 254 | + break |
| 255 | + |
| 256 | + upload_args = argparse.Namespace( |
| 257 | + path=project_root, |
| 258 | + entry=selected_entry, |
| 259 | + id=evaluator_id, |
| 260 | + display_name=None, |
| 261 | + description=None, |
| 262 | + force=False, |
| 263 | + yes=True, |
| 264 | + ) |
| 265 | + rc = upload_command(upload_args) |
| 266 | + if rc == 0: |
| 267 | + print(f"✓ Uploaded/ensured evaluator: {evaluator_id}") |
| 268 | + else: |
| 269 | + print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.") |
| 270 | + except Exception as e: |
| 271 | + print(f"Warning: Failed to upload evaluator automatically: {e}") |
113 | 272 |
|
114 | 273 | # Determine dataset id and materialization path |
115 | 274 | dataset_id = getattr(args, "dataset_id", None) |
116 | 275 | dataset_jsonl = getattr(args, "dataset_jsonl", None) |
117 | 276 | dataset_display_name = getattr(args, "dataset_display_name", None) |
118 | | - dataset_builder = getattr(args, "dataset_builder", None) |
| 277 | + dataset_builder = getattr(args, "dataset_builder", None) # accepted but unused in simplified flow |
119 | 278 |
|
120 | 279 | if not dataset_id: |
121 | | - # Try builder from args, else from trace detection |
122 | | - # TODO: build dataset from traces directly |
123 | | - # builder_spec = dataset_builder or trace.get("dataset_builder") |
124 | | - # if not builder_spec: |
125 | | - # # Attempt detect from metric_dir |
126 | | - # metric_dir = trace.get("metric_dir") |
127 | | - # if metric_dir: |
128 | | - # builder_spec = detect_dataset_builder(metric_dir) |
129 | | - # if not builder_spec: |
130 | | - # print( |
131 | | - # "Error: Could not determine dataset. Provide --dataset-id, --dataset-jsonl, or --dataset-builder." |
132 | | - # ) |
133 | | - # return 1 |
134 | | - # try: |
135 | | - # dataset_jsonl, count = materialize_dataset_via_builder(builder_spec) |
136 | | - # print(f"✓ Materialized dataset via builder ({builder_spec}): {count} rows → {dataset_jsonl}") |
137 | | - # except Exception as e: |
138 | | - # print(f"Error: dataset builder failed: {e}") |
139 | | - # return 1 |
140 | | - |
| 280 | + # Prefer explicit --dataset-jsonl, else attempt to extract from data loader of the single discovered test |
141 | 281 | if not dataset_jsonl: |
142 | | - print("Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl.") |
| 282 | + tests = _discover_tests(project_root) |
| 283 | + if len(tests) == 1: |
| 284 | + func_name = tests[0].qualname.split(".")[-1] |
| 285 | + dataset_jsonl = _extract_jsonl_from_dataloader(tests[0].file_path, func_name) |
| 286 | + if dataset_jsonl: |
| 287 | + # Display relative path for readability |
| 288 | + try: |
| 289 | + rel = os.path.relpath(dataset_jsonl, project_root) |
| 290 | + except Exception: |
| 291 | + rel = dataset_jsonl |
| 292 | + print(f"✓ Using JSONL from data loader: {rel}") |
| 293 | + if not dataset_jsonl: |
| 294 | + print( |
| 295 | + "Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader is used in your single discovered test." |
| 296 | + ) |
143 | 297 | return 1 |
144 | 298 |
|
145 | | - inferred_dataset_id = build_default_dataset_id(evaluator_id) |
| 299 | + inferred_dataset_id = _build_trimmed_dataset_id(evaluator_id) |
146 | 300 | if dry_run: |
147 | 301 | print("--dry-run: would create dataset and upload JSONL") |
148 | 302 | dataset_id = inferred_dataset_id |
149 | 303 | else: |
150 | 304 | try: |
| 305 | + # Resolve dataset_jsonl path relative to CWD if needed |
| 306 | + jsonl_path_for_upload = ( |
| 307 | + dataset_jsonl |
| 308 | + if os.path.isabs(dataset_jsonl) |
| 309 | + else os.path.abspath(os.path.join(project_root, dataset_jsonl)) |
| 310 | + ) |
151 | 311 | dataset_id, _ = create_dataset_from_jsonl( |
152 | 312 | account_id=account_id, |
153 | 313 | api_key=api_key, |
154 | 314 | api_base=api_base, |
155 | 315 | dataset_id=inferred_dataset_id, |
156 | 316 | display_name=dataset_display_name or inferred_dataset_id, |
157 | | - jsonl_path=dataset_jsonl, |
| 317 | + jsonl_path=jsonl_path_for_upload, |
158 | 318 | ) |
159 | 319 | print(f"✓ Created and uploaded dataset: {dataset_id}") |
160 | 320 | except Exception as e: |
161 | 321 | print(f"Error creating/uploading dataset: {e}") |
162 | 322 | return 1 |
163 | 323 |
|
164 | 324 | # Build training config/body |
165 | | - training_config: Dict[str, Any] = {} |
166 | | - if getattr(args, "base_model", None): |
167 | | - training_config["baseModel"] = args.base_model |
| 325 | + # Ensure base model is explicitly provided for clarity |
| 326 | + if not getattr(args, "base_model", None): |
| 327 | + print( |
| 328 | + "Error: --base-model is required. Please specify the base model resource id (e.g., accounts/{account}/models/<model_id>)." |
| 329 | + ) |
| 330 | + return 1 |
| 331 | + |
| 332 | + training_config: Dict[str, Any] = {"baseModel": args.base_model} |
168 | 333 | if getattr(args, "warm_start_from", None): |
169 | 334 | training_config["warmStartFrom"] = args.warm_start_from |
170 | | - if "baseModel" not in training_config and "warmStartFrom" not in training_config: |
171 | | - # Provide a conservative default if neither is set |
172 | | - training_config["baseModel"] = "accounts/fireworks/models/llama-v3p1-8b-instruct" |
173 | 335 |
|
174 | 336 | # Optional hyperparameters |
175 | 337 | for key, arg_name in [ |
@@ -221,14 +383,12 @@ def create_rft_command(args) -> int: |
221 | 383 | "outputMetrics": None, |
222 | 384 | "mcpServer": None, |
223 | 385 | } |
224 | | - print("Show body:") |
225 | | - print(json.dumps(body, indent=2)) |
| 386 | + # Debug: print minimal summary |
| 387 | + print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'") |
226 | 388 | if getattr(args, "evaluation_dataset", None): |
227 | 389 | body["evaluationDataset"] = args.evaluation_dataset |
228 | 390 | if getattr(args, "output_model", None): |
229 | 391 | body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{args.output_model}" |
230 | | - else: |
231 | | - body.setdefault("trainingConfig", {})["outputModel"] = build_default_output_model(evaluator_id) |
232 | 392 |
|
233 | 393 | # Clean None fields to avoid noisy payloads |
234 | 394 | body = {k: v for k, v in body.items() if v is not None} |
|
0 commit comments