diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py
index 8b6bce4d..c0ff0358 100644
--- a/eval_protocol/cli_commands/create_rft.py
+++ b/eval_protocol/cli_commands/create_rft.py
@@ -20,88 +20,8 @@
     create_dataset_from_jsonl,
     create_reinforcement_fine_tuning_job,
 )
-from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source
-
-
-def _last_evaluator_paths(cwd: str) -> list[str]:
-    return [
-        os.path.join(cwd, ".eval_protocol", "last_evaluator.json"),
-        os.path.expanduser(os.path.join("~", ".eval_protocol", "last_evaluator.json")),
-    ]
-
-
-def _load_last_evaluator(cwd: str) -> Optional[str]:
-    import json
-
-    for p in _last_evaluator_paths(cwd):
-        try:
-            if os.path.isfile(p):
-                with open(p, "r", encoding="utf-8") as f:
-                    data = json.load(f)
-                if isinstance(data, dict) and data.get("evaluator_id"):
-                    return str(data["evaluator_id"])
-        except Exception:
-            # ignore and continue
-            pass
-    return None
-
-
-def _save_last_evaluator(cwd: str, evaluator_id: str) -> None:
-    import json
-
-    base = os.path.join(cwd, ".eval_protocol")
-    try:
-        os.makedirs(base, exist_ok=True)
-        with open(os.path.join(base, "last_evaluator.json"), "w", encoding="utf-8") as f:
-            json.dump({"evaluator_id": evaluator_id, "ts": time.time()}, f)
-    except Exception:
-        # best-effort only
-        pass
-
-
-def _gather_evaluator_traces(cwd: str) -> list[dict]:
-    roots = [
-        os.path.join(cwd, ".eval_protocol", "evaluators"),
-        os.path.expanduser(os.path.join("~", ".eval_protocol", "evaluators")),
-    ]
-    records: list[dict] = []
-    for root in roots:
-        if os.path.isdir(root):
-            for name in os.listdir(root):
-                if name.endswith(".json"):
-                    full = os.path.join(root, name)
-                    try:
-                        mtime = os.path.getmtime(full)
-                    except Exception:
-                        mtime = 0.0
-                    records.append({"id": name[:-5], "path": full, "mtime": mtime})
-    # dedupe by id keeping most recent mtime
-    dedup: dict[str, dict] = {}
-    for rec in records:
-        cur = dedup.get(rec["id"])
-        if not cur or rec["mtime"] > cur["mtime"]:
-            dedup[rec["id"]] = rec
-    return list(dedup.values())
-
-
-def _prompt_select_evaluator(candidates: list[dict]) -> Optional[str]:
-    print("\nMultiple evaluators detected. Select one:")
-    ordered = sorted(candidates, key=lambda x: -x["mtime"])
-    for i, c in enumerate(ordered, start=1):
-        print(f"  {i}) {c['id']} (from {c['path']})")
-    try:
-        choice = input("Enter a number (or press Enter to cancel): ").strip()
-    except KeyboardInterrupt:
-        print("\nCancelled.")
-        return None
-    if not choice or not choice.isdigit():
-        return None
-    n = int(choice)
-    if 1 <= n <= len(ordered):
-        sel = ordered[n - 1]["id"]
-        print(f"✓ Using evaluator: {sel}")
-        return sel
-    return None
+from ..fireworks_rft import detect_dataset_builder, materialize_dataset_via_builder
+from .upload import _discover_tests, _normalize_evaluator_id, _prompt_select
 
 
 def _ensure_account_id() -> Optional[str]:
@@ -331,35 +251,35 @@ def _build_trimmed_dataset_id(evaluator_id: str) -> str:
     return f"{base}{suffix}"
 
 
-def _auto_select_evaluator_id(cwd: str, *, non_interactive: bool = False) -> Optional[str]:
-    # 1) Use last-used pointer if available
-    last = _load_last_evaluator(cwd)
-    if last:
-        return last
-
-    # 2) Look for evaluator traces in project and home
-    traces = _gather_evaluator_traces(cwd)
-    if len(traces) == 1:
-        return traces[0]["id"]
-    if len(traces) > 1:
-        if non_interactive:
-            sel = sorted(traces, key=lambda x: -x["mtime"])[0]["id"]
-            print(f"⚠️ Multiple evaluators found; using most recent: {sel}. Override with --evaluator-id.")
-            return sel
-        chosen = _prompt_select_evaluator(traces)
-        if chosen:
-            return chosen
-        return None
-
-    # 3) Fall back to discovering a single evaluation_test
-    tests = _discover_tests(cwd)
-    if len(tests) == 1:
-        qualname, source_file_path = tests[0].qualname, tests[0].file_path
-        test_func_name = qualname.split(".")[-1]
-        source_file_name = os.path.splitext(os.path.basename(source_file_path))[0]
-        evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{test_func_name}")
-        return evaluator_id
-    return None
+def _resolve_selected_test(
+    project_root: str,
+    evaluator_id: Optional[str],
+    selected_tests: Optional[list] = None,
+) -> tuple[Optional[str], Optional[str]]:
+    """
+    Resolve a single test's source file path and function name to use downstream.
+
+    Priority:
+      1) If selected_tests is provided and has length == 1, use it.
+      2) Else discover tests; if exactly one test exists, use it.
+      3) Else, if evaluator_id is provided, match it against each test's
+         normalized "<source-file>-<function>" id.
+
+    Returns: (file_path, func_name), or (None, None) if unresolved.
+    """
+    try:
+        tests = selected_tests if selected_tests is not None else _discover_tests(project_root)
+        if not tests:
+            return None, None
+        if len(tests) == 1:
+            return tests[0].file_path, tests[0].qualname.split(".")[-1]
+        if evaluator_id:
+            for t in tests:
+                func_name = t.qualname.split(".")[-1]
+                source_file_name = os.path.splitext(os.path.basename(t.file_path))[0]
+                candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
+                if candidate == evaluator_id:
+                    return t.file_path, func_name
+        return None, None
+    except Exception:
+        return None, None
 
 
 def _poll_evaluator_status(
@@ -428,6 +348,9 @@ def create_rft_command(args) -> int:
     non_interactive: bool = bool(getattr(args, "yes", False))
     dry_run: bool = bool(getattr(args, "dry_run", False))
     force: bool = bool(getattr(args, "force", False))
+    # Track the specifically chosen test (if any) to aid dataset inference later
+    selected_test_file_path: Optional[str] = None
+    selected_test_func_name: Optional[str] = None
 
     api_key = get_fireworks_api_key()
     if not api_key:
@@ -441,13 +364,52 @@ def create_rft_command(args) -> int:
 
     api_base = get_fireworks_api_base()
 
-    # Resolve evaluator id if omitted
+    # Resolve evaluator id/entry if omitted (reuse upload's selector flow)
     project_root = os.getcwd()
     if not evaluator_id:
-        evaluator_id = _auto_select_evaluator_id(project_root, non_interactive=non_interactive)
-        if not evaluator_id:
-            print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
+        print("Scanning for evaluation tests...")
+        tests = _discover_tests(project_root)
+        if not tests:
+            print("No evaluation tests found.")
+            print("\nHint: Make sure your tests use the @evaluation_test decorator.")
+            return 1
+        # Always interactive selection here (no implicit quiet mode unless --evaluator-id was provided)
+        try:
+            selected_tests = _prompt_select(tests, non_interactive=non_interactive)
+        except Exception:
+            print("Error: Failed to open selector UI. Please pass --evaluator-id or --entry explicitly.")
+            return 1
+        if not selected_tests:
+            print("No tests selected.")
             return 1
+        if len(selected_tests) != 1:
+            if non_interactive and len(selected_tests) > 1:
+                print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.")
+                print("       Please pass --evaluator-id or --entry to disambiguate.")
+                try:
+                    # Offer candidate evaluator ids for convenience
+                    tests = _discover_tests(project_root)
+                    if tests:
+                        print("       Candidate evaluator ids:")
+                        for t in tests:
+                            func = t.qualname.split(".")[-1]
+                            stem = os.path.splitext(os.path.basename(t.file_path))[0]
+                            cand = _normalize_evaluator_id(f"{stem}-{func}")
+                            print(f"         - {cand}")
+                except Exception:
+                    pass
+            else:
+                print("Error: Please select exactly one evaluation test for 'create rft'.")
+            return 1
+        # Derive evaluator_id from the user's single selection
+        chosen = selected_tests[0]
+        func_name = chosen.qualname.split(".")[-1]
+        source_file_name = os.path.splitext(os.path.basename(chosen.file_path))[0]
+        evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
+        # Resolve the selected test once for downstream use
+        selected_test_file_path, selected_test_func_name = _resolve_selected_test(
+            project_root, evaluator_id, selected_tests=selected_tests
+        )
 
     # Resolve evaluator resource name to fully-qualified format required by API
     evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"
@@ -479,8 +441,12 @@ def create_rft_command(args) -> int:
                 print(f"📊 Please check the evaluator status at: {dashboard_url}")
                 print("   Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
                 return 1
-            _save_last_evaluator(project_root, evaluator_id)
             skip_upload = True
+            # Populate selected test info for dataset inference later
+            st_path, st_func = _resolve_selected_test(project_root, evaluator_id)
+            if st_path and st_func:
+                selected_test_file_path = st_path
+                selected_test_func_name = st_func
     except requests.exceptions.RequestException:
         pass
 
@@ -491,28 +457,16 @@ def create_rft_command(args) -> int:
             tests = _discover_tests(project_root)
 
             selected_entry: Optional[str] = None
-            if len(tests) == 1:
-                func_name = tests[0].qualname.split(".")[-1]
-                abs_path = os.path.abspath(tests[0].file_path)
+            st_path, st_func = _resolve_selected_test(project_root, evaluator_id, selected_tests=tests)
+            if st_path and st_func:
+                abs_path = os.path.abspath(st_path)
                 try:
                     rel = os.path.relpath(abs_path, project_root)
                 except Exception:
                     rel = abs_path
-                selected_entry = f"{rel}::{func_name}"
-            else:
-                # Try to match evaluator_id to a discovered test's normalized ID
-                for t in tests:
-                    func_name = t.qualname.split(".")[-1]
-                    source_file_name = os.path.splitext(os.path.basename(t.file_path))[0]
-                    candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
-                    if candidate == evaluator_id:
-                        abs_path = os.path.abspath(t.file_path)
-                        try:
-                            rel = os.path.relpath(abs_path, project_root)
-                        except Exception:
-                            rel = abs_path
-                        selected_entry = f"{rel}::{func_name}"
-                        break
+                selected_entry = f"{rel}::{st_func}"
+                selected_test_file_path = st_path
+                selected_test_func_name = st_func
 
             # If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators
             if selected_entry is None and len(tests) > 1:
                 print(
@@ -561,8 +515,8 @@ def create_rft_command(args) -> int:
                     print("   Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
                     return 1
                 else:
-                    # Only persist last-used evaluator after successful ensure + ACTIVE
-                    _save_last_evaluator(project_root, evaluator_id)
+                    # Evaluator is ACTIVE; proceed
+                    pass
             else:
                 print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.")
         except Exception as e:
@@ -575,30 +529,48 @@ def create_rft_command(args) -> int:
     dataset_builder = getattr(args, "dataset_builder", None)  # accepted but unused in simplified flow
 
     if not dataset_id:
-        # Prefer explicit --dataset-jsonl, else attempt to extract from data loader or input_dataset of the single discovered test
+        # Prefer explicit --dataset-jsonl, else attempt to extract from the selected test's data loader or input_dataset.
        if not dataset_jsonl:
-            tests = _discover_tests(project_root)
-            if len(tests) == 1:
-                func_name = tests[0].qualname.split(".")[-1]
-                # Try data_loaders first (existing behavior)
-                dataset_jsonl = _extract_jsonl_from_dataloader(tests[0].file_path, func_name)
+            # Use the specifically selected test if available; else only infer when exactly one test exists
+            test_file_for_infer = None
+            func_for_infer = None
+            if selected_test_file_path and selected_test_func_name:
+                test_file_for_infer = selected_test_file_path
+                func_for_infer = selected_test_func_name
+            else:
+                tests = _discover_tests(project_root)
+                if len(tests) == 1:
+                    test_file_for_infer = tests[0].file_path
+                    func_for_infer = tests[0].qualname.split(".")[-1]
+            if test_file_for_infer and func_for_infer:
+                # Try data_loaders first
+                dataset_jsonl = _extract_jsonl_from_dataloader(test_file_for_infer, func_for_infer)
                 if dataset_jsonl:
-                    # Display relative path for readability
                     try:
                         rel = os.path.relpath(dataset_jsonl, project_root)
                     except Exception:
                         rel = dataset_jsonl
                     print(f"✓ Using JSONL from data loader: {rel}")
-                else:
+                if not dataset_jsonl:
                     # Fall back to input_dataset (dataset_path)
-                    dataset_jsonl = _extract_jsonl_from_input_dataset(tests[0].file_path, func_name)
+                    dataset_jsonl = _extract_jsonl_from_input_dataset(test_file_for_infer, func_for_infer)
                     if dataset_jsonl:
-                        # Display relative path for readability
                        try:
                            rel = os.path.relpath(dataset_jsonl, project_root)
                        except Exception:
                            rel = dataset_jsonl
                        print(f"✓ Using JSONL from input_dataset: {rel}")
+                if not dataset_jsonl:
+                    # Last resort: attempt to detect and run a dataset builder in the test's directory
+                    metric_dir = os.path.dirname(test_file_for_infer)
+                    builder_spec = detect_dataset_builder(metric_dir)
+                    if builder_spec:
+                        try:
+                            tmp_jsonl, count = materialize_dataset_via_builder(builder_spec)
+                            dataset_jsonl = tmp_jsonl
+                            print(f"✓ Materialized {count} rows via dataset builder: {builder_spec}")
+                        except Exception as e:
+                            print(f"Warning: dataset builder failed: {e}")
        if not dataset_jsonl:
            print(
                "Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader or input_dataset is used in your single discovered test."
            )
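Note: the resolution order implemented by `_resolve_selected_test` above can be exercised in isolation. A minimal sketch (not part of the patch), assuming only that `_normalize_evaluator_id` lowercases and maps "_" to "-", as the tests below rely on; the SimpleNamespace stand-ins mirror how the test suite fakes discovered-test records:

    from types import SimpleNamespace
    from eval_protocol.cli_commands.create_rft import _resolve_selected_test

    t1 = SimpleNamespace(qualname="alpha.test_one", file_path="evals/alpha.py")
    t2 = SimpleNamespace(qualname="beta.test_two", file_path="evals/beta.py")

    # Priority 1: an explicit single selection wins outright.
    assert _resolve_selected_test(".", None, selected_tests=[t1]) == ("evals/alpha.py", "test_one")

    # Priority 3: with multiple candidates, a normalized "<file-stem>-<function>" id picks the match.
    assert _resolve_selected_test(".", "beta-test-two", selected_tests=[t1, t2]) == ("evals/beta.py", "test_two")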
diff --git a/eval_protocol/cli_commands/upload.py b/eval_protocol/cli_commands/upload.py
index aee5a8b4..d2b4100b 100644
--- a/eval_protocol/cli_commands/upload.py
+++ b/eval_protocol/cli_commands/upload.py
@@ -21,7 +21,6 @@
 from eval_protocol.platform_api import create_or_update_fireworks_secret
 from eval_protocol.evaluation import create_evaluation
-from eval_protocol.fireworks_rft import save_evaluator_trace, detect_dataset_builder
 
 
 @dataclass
@@ -444,49 +443,25 @@ def _prompt_select_interactive(tests: list[DiscoveredTest]) -> list[DiscoveredTe
         else:
             return []
 
-        # Enter-only selection UX with optional multi-select via repeat
-        remaining_indices = list(range(len(tests)))
-        selected_indices: list[int] = []
-
+        # Single-select UX
         print("\n")
-        print("Tip: Use ↑/↓ arrows to navigate and press ENTER to select.")
-        print("     After selecting one, you can choose to add more.\n")
-
-        while remaining_indices:
-            # Build choices from remaining
-            choices = []
-            for idx, test_idx in enumerate(remaining_indices, 1):
-                t = tests[test_idx]
-                choice_text = _format_test_choice(t, idx)
-                choices.append({"name": choice_text, "value": test_idx})
-
-            selected = questionary.select(
-                "Select an evaluation test to upload:", choices=choices, style=custom_style
-            ).ask()
-
-            if selected is None:  # Ctrl+C
-                print("\nUpload cancelled.")
-                return []
+        print("Tip: Use ↑/↓ arrows to navigate and press ENTER to select.\n")
 
-            if isinstance(selected, int):
-                selected_indices.append(selected)
-                # Remove from remaining
-                if selected in remaining_indices:
-                    remaining_indices.remove(selected)
+        choices = []
+        for idx, t in enumerate(tests, 1):
+            choice_text = _format_test_choice(t, idx)
+            choices.append({"name": choice_text, "value": idx - 1})
 
-                # Ask whether to add another (ENTER to finish)
-                add_more = questionary.confirm("Add another?", default=False, style=custom_style).ask()
-                if not add_more:
-                    break
-            else:
-                break
+        selected = questionary.select(
+            "Select an evaluation test to upload:", choices=choices, style=custom_style
+        ).ask()
 
-        if not selected_indices:
-            print("\n⚠️ No tests were selected.")
+        if selected is None:  # Ctrl+C
+            print("\nUpload cancelled.")
             return []
 
-        print(f"\n✓ Selected {len(selected_indices)} test(s)")
-        return [tests[i] for i in selected_indices]
+        print("\n✓ Selected 1 test")
+        return [tests[selected]]
 
     except ImportError:
         # Fallback to simpler implementation
@@ -525,22 +500,19 @@ def _prompt_select_fallback(tests: list[DiscoveredTest]) -> list[DiscoveredTest]
     print("=" * 80)
 
     try:
-        choice = input("Enter numbers to upload (comma or space-separated), or 'all': ").strip()
+        choice = input("Enter the number to upload: ").strip()
     except KeyboardInterrupt:
         print("\n\nUpload cancelled.")
         return []
 
-    if choice.lower() in ("all", "a", "*"):
-        return tests
-
-    indices: list[int] = []
-    for token in re.split(r"[\s,]+", choice):
-        if token.isdigit():
-            n = int(token)
-            if 1 <= n <= len(tests):
-                indices.append(n - 1)
-    indices = sorted(set(indices))
-    return [tests[i] for i in indices]
+    if not choice.isdigit():
+        print("\n⚠️ Invalid selection.")
+        return []
+    n = int(choice)
+    if not (1 <= n <= len(tests)):
+        print("\n⚠️ Selection out of range.")
+        return []
+    return [tests[n - 1]]
 
 
 def _prompt_select(tests: list[DiscoveredTest], non_interactive: bool) -> list[DiscoveredTest]:
@@ -718,23 +690,6 @@ def upload_command(args: argparse.Namespace) -> int:
             )
             name = result.get("name", evaluator_id) if isinstance(result, dict) else evaluator_id
 
-            # Persist local evaluator trace for later `create rft`
-            try:
-                metric_dir = os.path.dirname(source_file_path) if source_file_path else root
-                builder_spec = detect_dataset_builder(metric_dir) or None
-                trace_payload = {
-                    "evaluator_id": evaluator_id,
-                    "evaluator_resource_name": name,
-                    "entry_point": entry_point,
-                    "metric_dir": metric_dir,
-                    "project_root": root,
-                    "dataset_builder": builder_spec,
-                }
-                save_evaluator_trace(project_root=root, evaluator_id=evaluator_id, trace=trace_payload)
-            except Exception:
-                # Non-fatal; continue
-                pass
-
             # Print success message with Fireworks dashboard link
             print(f"\n✅ Successfully uploaded evaluator: {evaluator_id}")
             print("📊 View in Fireworks Dashboard:")
diff --git a/eval_protocol/fireworks_rft.py b/eval_protocol/fireworks_rft.py
index 05b49291..73472896 100644
--- a/eval_protocol/fireworks_rft.py
+++ b/eval_protocol/fireworks_rft.py
@@ -37,25 +37,6 @@ def _map_api_host_to_app_host(api_base: str) -> str:
     return "https://app.fireworks.ai"
 
 
-def load_evaluator_trace(project_root: str, evaluator_id: str) -> Optional[Dict[str, Any]]:
-    trace_path = Path(project_root) / ".eval_protocol" / "evaluators" / f"{evaluator_id}.json"
-    if not trace_path.exists():
-        return None
-    try:
-        with open(trace_path, "r", encoding="utf-8") as f:
-            return json.load(f)
-    except Exception:
-        return None
-
-
-def save_evaluator_trace(project_root: str, evaluator_id: str, trace: Dict[str, Any]) -> None:
-    base_dir = Path(project_root) / ".eval_protocol" / "evaluators"
-    base_dir.mkdir(parents=True, exist_ok=True)
-    trace_path = base_dir / f"{evaluator_id}.json"
-    with open(trace_path, "w", encoding="utf-8") as f:
-        json.dump(trace, f, indent=2, ensure_ascii=False)
-
-
 def detect_dataset_builder(metric_dir: str) -> Optional[str]:
     """
     Best-effort scan for a dataset builder callable inside the metric directory.
@@ -228,8 +209,6 @@ def build_default_output_model(evaluator_id: str) -> str:
 
 
 __all__ = [
-    "load_evaluator_trace",
-    "save_evaluator_trace",
     "detect_dataset_builder",
     "materialize_dataset_via_builder",
     "create_dataset_from_jsonl",
diff --git a/tests/test_cli_create_rft_infer.py b/tests/test_cli_create_rft_infer.py
index 42307253..86411f8d 100644
--- a/tests/test_cli_create_rft_infer.py
+++ b/tests/test_cli_create_rft_infer.py
@@ -15,117 +15,6 @@ def _write_json(path: str, data: dict) -> None:
         json.dump(data, f)
 
 
-def test_load_and_save_last_evaluator(tmp_path, monkeypatch):
-    # Force HOME to temp so expanduser paths remain inside tmp
-    monkeypatch.setenv("HOME", str(tmp_path / "home"))
-    project = tmp_path / "proj"
-    project.mkdir()
-
-    # Initially none
-    assert cr._load_last_evaluator(str(project)) is None
-
-    # Save and load
-    cr._save_last_evaluator(str(project), "evaluator-abc")
-    assert cr._load_last_evaluator(str(project)) == "evaluator-abc"
-
-
-def test_auto_select_uses_last_pointer(tmp_path, monkeypatch):
-    monkeypatch.setenv("HOME", str(tmp_path / "home"))
-    project = tmp_path / "proj"
-    project.mkdir()
-
-    # Write last pointer under project
-    last_path = project / ".eval_protocol" / "last_evaluator.json"
-    _write_json(str(last_path), {"evaluator_id": "chosen-id"})
-
-    eid = cr._auto_select_evaluator_id(str(project))
-    assert eid == "chosen-id"
-
-
-def test_auto_select_single_trace(tmp_path, monkeypatch):
-    monkeypatch.setenv("HOME", str(tmp_path / "home"))
-    project = tmp_path / "proj"
-    project.mkdir()
-
-    # Single evaluator trace under project
-    trace = project / ".eval_protocol" / "evaluators" / "only-one.json"
-    _write_json(str(trace), {"dummy": True})
-
-    eid = cr._auto_select_evaluator_id(str(project))
-    assert eid == "only-one"
-
-
-def test_auto_select_multiple_traces_non_interactive_most_recent(tmp_path, monkeypatch):
-    monkeypatch.setenv("HOME", str(tmp_path / "home"))
-    project = tmp_path / "proj"
-    project.mkdir()
-
-    # Two traces with different mtimes
-    older = project / ".eval_protocol" / "evaluators" / "older.json"
-    newer = project / ".eval_protocol" / "evaluators" / "newer.json"
-    _write_json(str(older), {})
-    _write_json(str(newer), {})
-    # Set older then newer mtime
-    t0 = time.time() - 100
-    os.utime(str(older), (t0, t0))
-    t1 = time.time()
-    os.utime(str(newer), (t1, t1))
-
-    eid = cr._auto_select_evaluator_id(str(project), non_interactive=True)
-    assert eid == "newer"
-
-
-def test_auto_select_multiple_traces_interactive_prompt(tmp_path, monkeypatch):
-    monkeypatch.setenv("HOME", str(tmp_path / "home"))
-    project = tmp_path / "proj"
-    project.mkdir()
-
-    # Two traces with different mtimes to force ordering: newer first, older second
-    older = project / ".eval_protocol" / "evaluators" / "older.json"
-    newer = project / ".eval_protocol" / "evaluators" / "newer.json"
-    _write_json(str(older), {})
-    _write_json(str(newer), {})
-    t0 = time.time() - 100
-    os.utime(str(older), (t0, t0))
-    t1 = time.time()
-    os.utime(str(newer), (t1, t1))
-
-    with patch("builtins.input", return_value="2"):
-        eid = cr._auto_select_evaluator_id(str(project), non_interactive=False)
-    # Choosing "2" should pick the second item by recency => "older"
-    assert eid == "older"
-
-
-def test_auto_select_falls_back_to_single_discovered_test(tmp_path, monkeypatch):
-    monkeypatch.setenv("HOME", str(tmp_path / "home"))
-    project = tmp_path / "proj"
-    project.mkdir()
-
-    # No traces; provide exactly one discovered test
-    test_file = project / "metric" / "test_dummy.py"
-    test_file.parent.mkdir(parents=True, exist_ok=True)
-    test_file.write_text("# dummy", encoding="utf-8")
-
-    dummy = SimpleNamespace(qualname="dummy_module.test_dummy_evaluation", file_path=str(test_file))
-    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [dummy])
-
-    eid = cr._auto_select_evaluator_id(str(project))
-    assert eid is not None
-    # Should incorporate function name suffix
-    assert "test_dummy_evaluation".split("_")[-1] in eid or "test-dummy-evaluation" in eid
-
-
-def test_auto_select_returns_none_when_no_candidates(tmp_path, monkeypatch):
-    monkeypatch.setenv("HOME", str(tmp_path / "home"))
-    project = tmp_path / "proj"
-    project.mkdir()
-
-    # No traces, no tests
-    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [])
-    eid = cr._auto_select_evaluator_id(str(project))
-    assert eid is None
-
-
 def test_create_rft_picks_most_recent_evaluator_and_dataset_id_follows(tmp_path, monkeypatch):
     # Isolate HOME so expanduser paths remain inside tmp
     monkeypatch.setenv("HOME", str(tmp_path / "home"))
@@ -135,18 +24,6 @@ def test_create_rft_picks_most_recent_evaluator_and_dataset_id_follows(tmp_path,
     project.mkdir()
     monkeypatch.chdir(project)
 
-    # Prepare two evaluator traces with different mtimes
-    traces_dir = project / ".eval_protocol" / "evaluators"
-    traces_dir.mkdir(parents=True, exist_ok=True)
-    older = traces_dir / "example-eval-1.json"
-    newer = traces_dir / "example-eval-2.json"
-    older.write_text("{}", encoding="utf-8")
-    newer.write_text("{}", encoding="utf-8")
-    t0 = time.time() - 200
-    os.utime(str(older), (t0, t0))
-    t1 = time.time()
-    os.utime(str(newer), (t1, t1))
-
     # Create a dummy dataset jsonl file
     ds_path = project / "evaluator" / "dummy_dataset.jsonl"
     ds_path.parent.mkdir(parents=True, exist_ok=True)
@@ -158,9 +35,16 @@ def test_create_rft_picks_most_recent_evaluator_and_dataset_id_follows(tmp_path,
     monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
 
     # Stub out networked/subcommands used by create_rft
-    # Patch upload command in its own module (create_rft imports it at call time)
+    # Patch selector and upload
     import eval_protocol.cli_commands.upload as upload_mod
 
+    # Simulate exactly one discovered test and the selector returning it
+    one_file = project / "metric" / "test_single.py"
+    one_file.parent.mkdir(parents=True, exist_ok=True)
+    one_file.write_text("# single", encoding="utf-8")
+    single_disc = SimpleNamespace(qualname="metric.test_single", file_path=str(one_file))
+    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc])
+    monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
     monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
     monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)
 
@@ -200,9 +84,9 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
     rc = cr.create_rft_command(args)
     assert rc == 0
 
-    # Assert dataset id followed the most recent evaluator id ("example-eval-2")
+    # Assert dataset id derived from the selected test (test_single.py::test_single)
     assert captured["dataset_id"] is not None
-    assert captured["dataset_id"].startswith("example-eval-2-dataset-")
+    assert captured["dataset_id"].startswith("test-single-test-single-dataset-")
 
 
 def test_create_rft_passes_matching_evaluator_id_and_entry_with_multiple_tests(tmp_path, monkeypatch):
@@ -214,21 +98,6 @@ def test_create_rft_passes_matching_evaluator_id_and_entry_with_multiple_tests(t
     project.mkdir()
     monkeypatch.chdir(project)
 
-    # Two evaluator traces: make the target evaluator the most recent
-    traces_dir = project / ".eval_protocol" / "evaluators"
-    traces_dir.mkdir(parents=True, exist_ok=True)
-    svg_id = "example-svg-evaluation"
-    # Use an evaluator id that matches normalization logic for mapping to foo_eval.py::test_bar_evaluation
-    target_id = cr._normalize_evaluator_id("foo_eval-test_bar_evaluation")
-    older = traces_dir / f"{svg_id}.json"
-    newer = traces_dir / f"{target_id}.json"
-    older.write_text("{}", encoding="utf-8")
-    newer.write_text("{}", encoding="utf-8")
-    t0 = time.time() - 200
-    os.utime(str(older), (t0, t0))
-    t1 = time.time()
-    os.utime(str(newer), (t1, t1))
-
     # Create dummy test files for discovery
     eval_dir = project / "evaluator"
     eval_dir.mkdir(parents=True, exist_ok=True)
@@ -274,11 +143,11 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
     ds_path = eval_dir / "dummy_dataset.jsonl"
     ds_path.write_text('{"input":"x"}\n', encoding="utf-8")
 
-    # Build args: non-interactive, no explicit evaluator id
+    # Build args: explicit evaluator id; the selector is not used here (mapping by id)
     import argparse
 
     args = argparse.Namespace(
-        evaluator_id=None,
+        evaluator_id=cr._normalize_evaluator_id("foo_eval-test_bar_evaluation"),
         yes=True,
         dry_run=False,
         force=False,
@@ -304,11 +173,538 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
     rc = cr.create_rft_command(args)
     assert rc == 0
 
-    # Assert evaluator_id passed to upload matches the most recent trace (target)
-    assert captured["id"] == target_id
+    # Assert evaluator_id passed to upload matches the provided id
+    assert captured["id"] == cr._normalize_evaluator_id("foo_eval-test_bar_evaluation")
     # Assert entry points to the foo test (should map when id matches normalization)
     assert captured["entry"] is not None and captured["entry"].endswith("foo_eval.py::test_bar_evaluation")
     # Assert dataset id is derived from the same evaluator id (trimmed base + '-dataset-')
     assert captured["dataset_id"] is not None
-    expected_prefix = cr._build_trimmed_dataset_id(target_id).split("-dataset-")[0] + "-dataset-"
+    expected_prefix = (
+        cr._build_trimmed_dataset_id(cr._normalize_evaluator_id("foo_eval-test_bar_evaluation")).split("-dataset-")[0]
+        + "-dataset-"
+    )
     assert captured["dataset_id"].startswith(expected_prefix)
+
+
+def test_create_rft_interactive_selector_single_test(tmp_path, monkeypatch):
+    # Setup project
+    project = tmp_path / "proj"
+    project.mkdir()
+    monkeypatch.chdir(project)
+
+    # Single discovered test
+    test_file = project / "metric" / "test_one.py"
+    test_file.parent.mkdir(parents=True, exist_ok=True)
+    test_file.write_text("# one", encoding="utf-8")
+    single_disc = SimpleNamespace(qualname="metric.test_one", file_path=str(test_file))
+    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc])
+
+    # Environment
+    monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123")
+    monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
+
+    # Stub the selector to return the single test; stub upload and polling
+    import eval_protocol.cli_commands.upload as upload_mod
+
+    monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
+    captured = {"id": None, "entry": None, "dataset_id": None}
+
+    def _fake_upload(ns):
+        captured["id"] = getattr(ns, "id", None)
+        captured["entry"] = getattr(ns, "entry", None)
+        return 0
+
+    monkeypatch.setattr(upload_mod, "upload_command", _fake_upload)
+    monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)
+
+    # Provide dataset jsonl
+    ds_path = project / "metric" / "dataset.jsonl"
+    ds_path.write_text('{"input":"x"}\n', encoding="utf-8")
+    monkeypatch.setattr(
+        cr,
+        "create_dataset_from_jsonl",
+        lambda account_id, api_key, api_base, dataset_id, display_name, jsonl_path: (
+            dataset_id,
+            {"name": f"accounts/{account_id}/datasets/{dataset_id}"},
+        ),
+    )
+    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})
+
+    # Run without evaluator_id; use --yes so the selector returns tests directly (no UI)
+    import argparse
+
+    args = argparse.Namespace(
+        evaluator_id=None,
+        yes=True,
+        dry_run=False,
+        force=False,
+        env_file=None,
+        dataset_id=None,
+        dataset_jsonl=str(ds_path),
+        dataset_display_name=None,
+        dataset_builder=None,
+        base_model=None,
+        warm_start_from="accounts/acct123/models/ft-abc123",
+        output_model=None,
+        n=None,
+        max_tokens=None,
+        learning_rate=None,
+        batch_size=None,
+        epochs=None,
+        lora_rank=None,
+        max_context_length=None,
+        chunk_size=None,
+        eval_auto_carveout=None,
+    )
+
+    rc = cr.create_rft_command(args)
+    assert rc == 0
+    assert captured["id"] is not None
+    assert captured["entry"] is not None and captured["entry"].endswith("test_one.py::test_one")
+
+
+def test_create_rft_quiet_existing_evaluator_skips_upload(tmp_path, monkeypatch):
+    project = tmp_path / "proj"
+    project.mkdir()
+    monkeypatch.chdir(project)
+
+    # Env
+    monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123")
+    monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
+
+    # Mock evaluator exists and is ACTIVE
+    class _Resp:
+        ok = True
+
+        def json(self):
+            return {"state": "ACTIVE"}
+
+        def raise_for_status(self):
+            return None
+
+    monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp())
+
+    # Provide dataset via --dataset-jsonl so no test discovery is needed
+    ds_path = project / "dataset.jsonl"
+    ds_path.write_text('{"input":"x"}\n', encoding="utf-8")
+    monkeypatch.setattr(
+        cr,
+        "create_dataset_from_jsonl",
+        lambda account_id, api_key, api_base, dataset_id, display_name, jsonl_path: (
+            dataset_id,
+            {"name": f"accounts/{account_id}/datasets/{dataset_id}"},
+        ),
+    )
+    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})
+
+    import argparse
+
+    args = argparse.Namespace(
+        evaluator_id="some-eval",
+        yes=True,
+        dry_run=False,
+        force=False,
+        env_file=None,
+        dataset_id=None,
+        dataset_jsonl=str(ds_path),
+        dataset_display_name=None,
+        dataset_builder=None,
+        base_model=None,
+        warm_start_from="accounts/acct123/models/ft-abc123",
+        output_model=None,
+        n=None,
+        max_tokens=None,
+        learning_rate=None,
+        batch_size=None,
+        epochs=None,
+        lora_rank=None,
+        max_context_length=None,
+        chunk_size=None,
+        eval_auto_carveout=None,
+    )
+
+    rc = cr.create_rft_command(args)
+    assert rc == 0
+
+
+def test_create_rft_quiet_new_evaluator_ambiguous_without_entry_errors(tmp_path, monkeypatch):
+    project = tmp_path / "proj"
+    project.mkdir()
+    monkeypatch.chdir(project)
+
+    # Env
+    monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123")
+    monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
+
+    # Evaluator does not exist (force the path into the upload section)
+    import requests
+
+    def _raise(*a, **k):
+        raise requests.exceptions.RequestException("nope")
+
+    monkeypatch.setattr(cr.requests, "get", _raise)
+
+    # Two discovered tests (ambiguous)
+    f1 = project / "a.py"
+    f2 = project / "b.py"
+    f1.write_text("# a", encoding="utf-8")
+    f2.write_text("# b", encoding="utf-8")
+    d1 = SimpleNamespace(qualname="a.test_one", file_path=str(f1))
+    d2 = SimpleNamespace(qualname="b.test_two", file_path=str(f2))
+    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [d1, d2])
+
+    import argparse
+
+    args = argparse.Namespace(
+        evaluator_id="some-eval",
+        yes=True,
+        dry_run=False,
+        force=False,
+        env_file=None,
+        dataset_id=None,
+        dataset_jsonl=str(project / "dataset.jsonl"),
+        dataset_display_name=None,
+        dataset_builder=None,
+        base_model=None,
+        warm_start_from="accounts/acct123/models/ft-abc123",
+        output_model=None,
+        n=None,
+        max_tokens=None,
+        learning_rate=None,
+        batch_size=None,
+        epochs=None,
+        lora_rank=None,
+        max_context_length=None,
+        chunk_size=None,
+        eval_auto_carveout=None,
+    )
+    # Create the dataset file so we don't fail earlier
+    (project / "dataset.jsonl").write_text('{"input":"x"}\n', encoding="utf-8")
+
+    rc = cr.create_rft_command(args)
+    assert rc == 1
+
+
+def test_create_rft_fallback_to_dataset_builder(tmp_path, monkeypatch):
+    # Setup project
+    project = tmp_path / "proj"
+    project.mkdir()
+    monkeypatch.chdir(project)
+
+    # Single discovered test without data_loaders or input_dataset
+    test_file = project / "metric" / "test_builder.py"
+    test_file.parent.mkdir(parents=True, exist_ok=True)
+    test_file.write_text("# builder case", encoding="utf-8")
+    single_disc = SimpleNamespace(qualname="metric.test_builder", file_path=str(test_file))
+    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc])
+
+    # Environment
+    monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123")
+    monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
+
+    # Stub selector, upload, and polling
+    import eval_protocol.cli_commands.upload as upload_mod
+
+    monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
+    monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
+    monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)
+
+    # Dataset builder fallback
+    out_jsonl = project / "metric" / "builder_out.jsonl"
+    out_jsonl.write_text('{"row":1}\n{"row":2}\n', encoding="utf-8")
+
+    monkeypatch.setattr(cr, "detect_dataset_builder", lambda metric_dir: "builder.py::build_training_dataset")
+    monkeypatch.setattr(cr, "materialize_dataset_via_builder", lambda spec: (str(out_jsonl), 2))
+
+    # Capture dataset creation args
+    captured = {"dataset_id": None, "jsonl_path": None}
+
+    def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, display_name, jsonl_path):
+        captured["dataset_id"] = dataset_id
+        captured["jsonl_path"] = jsonl_path
+        return dataset_id, {"name": f"accounts/{account_id}/datasets/{dataset_id}", "state": "UPLOADING"}
+
+    monkeypatch.setattr(cr, "create_dataset_from_jsonl", _fake_create_dataset_from_jsonl)
+    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})
+
+    # Run without dataset inputs so the builder path is used
+    import argparse
+
+    args = argparse.Namespace(
+        evaluator_id=None,
+        yes=True,
+        dry_run=False,
+        force=False,
+        env_file=None,
+        dataset_id=None,
+        dataset_jsonl=None,
+        dataset_display_name=None,
+        dataset_builder=None,
+        base_model=None,
+        warm_start_from="accounts/acct123/models/ft-abc123",
+        output_model=None,
+        n=None,
+        max_tokens=None,
+        learning_rate=None,
+        batch_size=None,
+        epochs=None,
+        lora_rank=None,
+        max_context_length=None,
+        chunk_size=None,
+        eval_auto_carveout=None,
+    )
+
+    rc = cr.create_rft_command(args)
+    assert rc == 0
+    # Evaluator id derived from test_builder -> "test-builder-test-builder"
+    assert captured["dataset_id"] is not None
+    assert captured["dataset_id"].startswith("test-builder-test-builder-dataset-")
+    # Ensure we used the materialized JSONL
+    assert captured["jsonl_path"] == str(out_jsonl)
+
+
+def test_create_rft_uses_dataloader_jsonl_when_available(tmp_path, monkeypatch):
+    # Setup project
+    project = tmp_path / "proj"
+    project.mkdir()
+    monkeypatch.chdir(project)
+
+    # Single discovered test
+    test_file = project / "metric" / "test_loader.py"
+    test_file.parent.mkdir(parents=True, exist_ok=True)
+    test_file.write_text("# loader case", encoding="utf-8")
+    single_disc = SimpleNamespace(qualname="metric.test_loader", file_path=str(test_file))
+    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc])
+
+    # Environment
+    monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123")
+    monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
+
+    # Stub selector, upload, and polling
+    import eval_protocol.cli_commands.upload as upload_mod
+
+    monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
+    monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
+    monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)
+
+    # Provide JSONL via the data loader extractor
+    dl_jsonl = project / "metric" / "loader_out.jsonl"
+    dl_jsonl.write_text('{"a":1}\n', encoding="utf-8")
+    monkeypatch.setattr(cr, "_extract_jsonl_from_dataloader", lambda f, fn: str(dl_jsonl))
+    monkeypatch.setattr(cr, "_extract_jsonl_from_input_dataset", lambda f, fn: None)
+    monkeypatch.setattr(cr, "detect_dataset_builder", lambda metric_dir: None)
+
+    captured = {"dataset_id": None, "jsonl_path": None}
+
+    def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, display_name, jsonl_path):
+        captured["dataset_id"] = dataset_id
+        captured["jsonl_path"] = jsonl_path
+        return dataset_id, {"name": f"accounts/{account_id}/datasets/{dataset_id}", "state": "UPLOADING"}
+
+    monkeypatch.setattr(cr, "create_dataset_from_jsonl", _fake_create_dataset_from_jsonl)
+    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})
+
+    import argparse
+
+    args = argparse.Namespace(
+        evaluator_id=None,
+        yes=True,
+        dry_run=False,
+        force=False,
+        env_file=None,
+        dataset_id=None,
+        dataset_jsonl=None,
+        dataset_display_name=None,
+        dataset_builder=None,
+        base_model=None,
+        warm_start_from="accounts/acct123/models/ft-abc123",
+        output_model=None,
+        n=None,
+        max_tokens=None,
+        learning_rate=None,
+        batch_size=None,
+        epochs=None,
+        lora_rank=None,
+        max_context_length=None,
+        chunk_size=None,
+        eval_auto_carveout=None,
+    )
+
+    rc = cr.create_rft_command(args)
+    assert rc == 0
+    assert captured["dataset_id"] is not None
+    assert captured["dataset_id"].startswith("test-loader-test-loader-dataset-")
+    assert captured["jsonl_path"] == str(dl_jsonl)
+
+
+def test_create_rft_uses_input_dataset_jsonl_when_available(tmp_path, monkeypatch):
+    # Setup project
+    project = tmp_path / "proj"
+    project.mkdir()
+    monkeypatch.chdir(project)
+
+    # Single discovered test
+    test_file = project / "metric" / "test_input_ds.py"
+    test_file.parent.mkdir(parents=True, exist_ok=True)
+    test_file.write_text("# input_dataset case", encoding="utf-8")
+    single_disc = SimpleNamespace(qualname="metric.test_input_ds", file_path=str(test_file))
+    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc])
+
+    # Environment
+    monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123")
+    monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
+
+    # Stub selector, upload, and polling
+    import eval_protocol.cli_commands.upload as upload_mod
+
+    monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
+    monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
+    monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)
+
+    # Provide JSONL via the input_dataset extractor
+    id_jsonl = project / "metric" / "input_ds_out.jsonl"
+    id_jsonl.write_text('{"b":2}\n', encoding="utf-8")
+    monkeypatch.setattr(cr, "_extract_jsonl_from_dataloader", lambda f, fn: None)
+    monkeypatch.setattr(cr, "_extract_jsonl_from_input_dataset", lambda f, fn: str(id_jsonl))
+    monkeypatch.setattr(cr, "detect_dataset_builder", lambda metric_dir: None)
+
+    captured = {"dataset_id": None, "jsonl_path": None}
+
+    def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, display_name, jsonl_path):
+        captured["dataset_id"] = dataset_id
+        captured["jsonl_path"] = jsonl_path
+        return dataset_id, {"name": f"accounts/{account_id}/datasets/{dataset_id}", "state": "UPLOADING"}
+
+    monkeypatch.setattr(cr, "create_dataset_from_jsonl", _fake_create_dataset_from_jsonl)
+    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})
+
+    import argparse
+
+    args = argparse.Namespace(
+        evaluator_id=None,
+        yes=True,
+        dry_run=False,
+        force=False,
+        env_file=None,
+        dataset_id=None,
+        dataset_jsonl=None,
+        dataset_display_name=None,
+        dataset_builder=None,
+        base_model=None,
+        warm_start_from="accounts/acct123/models/ft-abc123",
+        output_model=None,
+        n=None,
+        max_tokens=None,
+        learning_rate=None,
+        batch_size=None,
+        epochs=None,
+        lora_rank=None,
+        max_context_length=None,
+        chunk_size=None,
+        eval_auto_carveout=None,
+    )
+
+    rc = cr.create_rft_command(args)
+    assert rc == 0
+    assert captured["dataset_id"] is not None
+    assert captured["dataset_id"].startswith("test-input-ds-test-input-ds-dataset-")
+    assert captured["jsonl_path"] == str(id_jsonl)
+
+
+def test_create_rft_quiet_existing_evaluator_infers_dataset_from_matching_test(tmp_path, monkeypatch):
+    # Setup project with multiple tests; evaluator exists (skip upload)
+    project = tmp_path / "proj"
+    project.mkdir()
+    monkeypatch.chdir(project)
+
+    # Env
+    monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123")
+    monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
+
+    # Two tests discovered
+    f1 = project / "evals" / "alpha.py"
+    f2 = project / "evals" / "beta.py"
+    f1.parent.mkdir(parents=True, exist_ok=True)
+    f1.write_text("# alpha", encoding="utf-8")
+    f2.write_text("# beta", encoding="utf-8")
+    d1 = SimpleNamespace(qualname="alpha.test_one", file_path=str(f1))
+    d2 = SimpleNamespace(qualname="beta.test_two", file_path=str(f2))
+    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [d1, d2])
+
+    # Evaluator exists and is ACTIVE (skip upload)
+    class _Resp:
+        ok = True
+
+        def json(self):
+            return {"state": "ACTIVE"}
+
+        def raise_for_status(self):
+            return None
+
+    monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp())
+    monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)
+
+    # We will provide JSONL via the input_dataset extractor for the matching test (beta.test_two)
+    jsonl_path = project / "data.jsonl"
+    jsonl_path.write_text('{"c":3}\n', encoding="utf-8")
+
+    # Stub extractors: only the matching test should matter; the implementation calls the extractor with file+func
+    def _extract_input_jsonl(file_path, func_name):
+        # Simulate returning JSONL regardless; dataset inference uses the selected test determined by evaluator_id
+        return str(jsonl_path)
+
+    monkeypatch.setattr(cr, "_extract_jsonl_from_dataloader", lambda f, fn: None)
+    monkeypatch.setattr(cr, "_extract_jsonl_from_input_dataset", _extract_input_jsonl)
+    monkeypatch.setattr(cr, "detect_dataset_builder", lambda metric_dir: None)
+
+    captured = {"dataset_id": None, "jsonl_path": None}
+
+    def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, display_name, jsonl_path):
+        captured["dataset_id"] = dataset_id
+        captured["jsonl_path"] = jsonl_path
+        return dataset_id, {"name": f"accounts/{account_id}/datasets/{dataset_id}", "state": "UPLOADING"}
+
+    monkeypatch.setattr(cr, "create_dataset_from_jsonl", _fake_create_dataset_from_jsonl)
+    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})
+
+    import argparse
+
+    # Provide an evaluator_id that matches beta.test_two
+    eval_id = cr._normalize_evaluator_id("beta-test_two")
+    args = argparse.Namespace(
+        evaluator_id=eval_id,
+        yes=True,
+        dry_run=False,
+        force=False,
+        env_file=None,
+        dataset_id=None,
+        dataset_jsonl=None,
+        dataset_display_name=None,
+        dataset_builder=None,
+        base_model=None,
+        warm_start_from="accounts/acct123/models/ft-abc123",
+        output_model=None,
+        n=None,
+        max_tokens=None,
+        learning_rate=None,
+        batch_size=None,
+        epochs=None,
+        lora_rank=None,
+        max_context_length=None,
+        chunk_size=None,
+        eval_auto_carveout=None,
+    )
+
+    rc = cr.create_rft_command(args)
+    assert rc == 0
+    assert captured["dataset_id"] is not None
+    # Ensure the dataset id is based on evaluator_id
+    assert captured["dataset_id"].startswith(f"{eval_id}-dataset-")
+    assert captured["jsonl_path"] == str(jsonl_path)