From 227d4ad8ffa901c1ba2a45afc63c27c372e4fff1 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 25 Nov 2025 14:52:59 -0800 Subject: [PATCH 1/7] refactor cli and validate before create rft --- eval_protocol/cli.py | 20 + eval_protocol/cli_commands/create_rft.py | 681 ++++++++++++----------- eval_protocol/cli_commands/local_test.py | 121 ++-- eval_protocol/cli_commands/upload.py | 473 +--------------- eval_protocol/cli_commands/utils.py | 511 +++++++++++++++++ tests/test_cli_create_rft_infer.py | 261 ++++----- 6 files changed, 1086 insertions(+), 981 deletions(-) create mode 100644 eval_protocol/cli_commands/utils.py diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py index 89ee5df3..90b620c1 100644 --- a/eval_protocol/cli.py +++ b/eval_protocol/cli.py @@ -433,6 +433,26 @@ def parse_args(args=None): rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode") rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending") rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID") + rft_parser.add_argument( + "--skip-validation", + action="store_true", + help="Skip local dataset and evaluator validation before creating the RFT job", + ) + rft_parser.add_argument( + "--ignore-docker", + action="store_true", + help="Ignore Dockerfile even if present; run pytest on host during evaluator validation", + ) + rft_parser.add_argument( + "--docker-build-extra", + default="", + help="Extra flags to pass to 'docker build' when validating evaluator (quoted string, e.g. \"--no-cache --pull --progress=plain\")", + ) + rft_parser.add_argument( + "--docker-run-extra", + default="", + help="Extra flags to pass to 'docker run' when validating evaluator (quoted string, e.g. \"--env-file .env --memory=8g\")", + ) # Local test command local_test_parser = subparsers.add_parser( diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index afb5cd8d..e277e078 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -1,99 +1,37 @@ +import argparse import json import os import sys import time -import argparse from typing import Any, Dict, Optional import requests +from pydantic import ValidationError -from ..auth import ( - get_fireworks_account_id, - get_fireworks_api_base, - get_fireworks_api_key, - verify_api_key_and_get_account_id, -) +from ..auth import get_fireworks_api_base, get_fireworks_api_key from ..common_utils import get_user_agent from ..fireworks_rft import ( - _map_api_host_to_app_host, build_default_output_model, create_dataset_from_jsonl, create_reinforcement_fine_tuning_job, + detect_dataset_builder, + materialize_dataset_via_builder, ) -from ..fireworks_rft import detect_dataset_builder, materialize_dataset_via_builder -from .upload import _discover_tests, _normalize_evaluator_id, _prompt_select - - -def _ensure_account_id() -> Optional[str]: - account_id = get_fireworks_account_id() - api_key = get_fireworks_api_key() - if not account_id and api_key: - resolved = verify_api_key_and_get_account_id(api_key=api_key, api_base=get_fireworks_api_base()) - if resolved: - os.environ["FIREWORKS_ACCOUNT_ID"] = resolved - return resolved - return account_id - - -def _extract_terminal_segment(resource_name: str) -> str: - """Return the last path segment if a fully-qualified resource name is provided.""" - try: - return resource_name.strip("/").split("/")[-1] - except Exception: - return resource_name - - -def _print_links(evaluator_id: str, dataset_id: str, job_name: Optional[str]) -> None: - api_base = get_fireworks_api_base() - app_base = _map_api_host_to_app_host(api_base) - print("\nšŸ“Š Dashboard Links:") - evaluator_slug = _extract_terminal_segment(evaluator_id) - print(f" Evaluator: {app_base}/dashboard/evaluators/{evaluator_slug}") - if dataset_id: - print(f" Dataset: {app_base}/dashboard/datasets/{dataset_id}") - if job_name: - # job_name likely like accounts/{account}/reinforcementFineTuningJobs/{id} - try: - job_id = job_name.strip().split("/")[-1] - print(f" RFT Job: {app_base}/dashboard/fine-tuning/reinforcement/{job_id}") - except Exception: - pass - - -def _auto_find_jsonl(cwd: str) -> Optional[str]: - """Find a reasonable JSONL dataset file in the current project. - - Priority order: - - dataset.jsonl in cwd - - data/dataset.jsonl - - first *.jsonl under cwd (depth-first, skipping common vendor/venv/build dirs) - Returns a RELATIVE path from cwd if possible. - """ - # Direct candidates - direct_candidates = [ - os.path.join(cwd, "dataset.jsonl"), - os.path.join(cwd, "data", "dataset.jsonl"), - ] - for p in direct_candidates: - if os.path.isfile(p): - try: - return os.path.relpath(p, cwd) - except Exception: - return p - - # Walk and find any .jsonl - skip_dirs = {".venv", "venv", "node_modules", "dist", "build", "__pycache__", ".git", "vendor"} - for dirpath, dirnames, filenames in os.walk(cwd): - # prune - dirnames[:] = [d for d in dirnames if d not in skip_dirs and not d.startswith(".")] - for name in sorted(filenames): - if name.endswith(".jsonl"): - candidate = os.path.join(dirpath, name) - try: - return os.path.relpath(candidate, cwd) - except Exception: - return candidate - return None +from ..models import EvaluationRow +from .upload import upload_command +from .utils import ( + _build_entry_point, + _build_trimmed_dataset_id, + _build_evaluator_dashboard_url, + _discover_and_select_tests, + _discover_tests, + _ensure_account_id, + _extract_terminal_segment, + _normalize_evaluator_id, + _print_links, + _resolve_selected_test, +) +from .local_test import run_evaluator_test def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) -> Optional[str]: @@ -205,83 +143,23 @@ def _extract_jsonl_from_input_dataset(test_file_path: str, test_func_name: str) if isinstance(dataset_path, (list, tuple)) and len(dataset_path) > 0: dataset_path = dataset_path[0] if isinstance(dataset_path, str) and dataset_path: + candidate_paths = [] if os.path.isabs(dataset_path): - return dataset_path - base_dir = os.path.dirname(os.path.abspath(test_file_path)) - resolved = os.path.abspath(os.path.join(base_dir, dataset_path)) - if os.path.isfile(resolved): - return resolved - # Try resolving from project root if relative to test file doesn't work - if not os.path.isabs(dataset_path): - # Try resolving from current working directory - cwd_path = os.path.abspath(os.path.join(os.getcwd(), dataset_path)) - if os.path.isfile(cwd_path): - return cwd_path + candidate_paths.append(dataset_path) + else: + base_dir = os.path.dirname(os.path.abspath(test_file_path)) + candidate_paths.append(os.path.abspath(os.path.join(base_dir, dataset_path))) + # Also try resolving from current working directory + candidate_paths.append(os.path.abspath(os.path.join(os.getcwd(), dataset_path))) + + for candidate in candidate_paths: + if os.path.isfile(candidate) and _validate_dataset_jsonl(candidate): + return candidate return None except Exception: return None -def _build_trimmed_dataset_id(evaluator_id: str) -> str: - """Build a dataset id derived from evaluator_id, trimmed to 63 chars. - - Format: -dataset-YYYYMMDDHHMMSS, where base is trimmed to fit. - """ - # Normalize base similarly to evaluator id rules - from .upload import _normalize_evaluator_id # local import to avoid cycle at module import time - - base = _normalize_evaluator_id(evaluator_id) - suffix = f"-dataset-{time.strftime('%Y%m%d%H%M%S')}" - max_total = 63 - max_base_len = max_total - len(suffix) - if max_base_len < 1: - max_base_len = 1 - if len(base) > max_base_len: - base = base[:max_base_len].rstrip("-") - if not base: - base = "dataset" - # Ensure first char is a letter - if not base: - base = "dataset" - if not base[0].isalpha(): - base = f"eval-{base}" - if len(base) > max_base_len: - base = base[:max_base_len] - base = base.rstrip("-") or "dataset" - return f"{base}{suffix}" - - -def _resolve_selected_test( - project_root: str, - evaluator_id: Optional[str], - selected_tests: Optional[list] = None, -) -> tuple[Optional[str], Optional[str]]: - """ - Resolve a single test's source file path and function name to use downstream. - Priority: - 1) If selected_tests provided and length == 1, use it. - 2) Else discover tests; if exactly one test, use it. - 3) Else, if evaluator_id provided, match by normalized '-'. - Returns: (file_path, func_name) or (None, None) if unresolved. - """ - try: - tests = selected_tests if selected_tests is not None else _discover_tests(project_root) - if not tests: - return None, None - if len(tests) == 1: - return tests[0].file_path, tests[0].qualname.split(".")[-1] - if evaluator_id: - for t in tests: - func_name = t.qualname.split(".")[-1] - source_file_name = os.path.splitext(os.path.basename(t.file_path))[0] - candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}") - if candidate == evaluator_id: - return t.file_path, func_name - return None, None - except Exception: - return None, None - - def _poll_evaluator_status( evaluator_resource_name: str, api_key: str, api_base: str, timeout_minutes: int = 10 ) -> bool: @@ -343,45 +221,96 @@ def _poll_evaluator_status( return False -def create_rft_command(args) -> int: - evaluator_id: Optional[str] = getattr(args, "evaluator", None) - non_interactive: bool = bool(getattr(args, "yes", False)) - dry_run: bool = bool(getattr(args, "dry_run", False)) - force: bool = bool(getattr(args, "force", False)) - # Track the specifically chosen test (if any) to aid dataset inference later - selected_test_file_path: Optional[str] = None - selected_test_func_name: Optional[str] = None +def _validate_dataset_jsonl(jsonl_path: str, sample_limit: int = 50) -> bool: + """Validate that a JSONL file contains rows compatible with EvaluationRow. - api_key = get_fireworks_api_key() - if not api_key: - print("Error: FIREWORKS_API_KEY not set.") - return 1 + We stream up to `sample_limit` rows, ensuring each is JSON-decodable and can be + parsed by the EvaluationRow model. Returns True on success, False on any error. + """ + try: + if not os.path.isfile(jsonl_path): + print(f"Error: dataset JSONL not found at path: {jsonl_path}") + return False + + row_count = 0 + with open(jsonl_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + except json.JSONDecodeError as e: + print(f"Error: dataset JSONL contains invalid JSON (line {row_count + 1}): {e}") + return False - account_id = _ensure_account_id() - if not account_id: - print("Error: FIREWORKS_ACCOUNT_ID not set and could not be resolved.") - return 1 + try: + EvaluationRow.model_validate(data) + except ValidationError as e: + print(f"Error: dataset JSONL row {row_count + 1} is not a valid EvaluationRow: {e}") + return False - api_base = get_fireworks_api_base() + row_count += 1 + if row_count >= sample_limit: + break + + if row_count == 0: + print(f"Error: dataset JSONL at {jsonl_path} appears to be empty.") + return False + + return True + except Exception as e: + print(f"Error validating dataset JSONL at {jsonl_path}: {e}") + return False + + +def _validate_dataset(dataset_jsonl: Optional[str]) -> bool: + """Validate dataset JSONL path when available; no-op when using dataset IDs only.""" + if not dataset_jsonl: + return True + return _validate_dataset_jsonl(dataset_jsonl) + + +def _validate_evaluator_locally( + project_root: str, + selected_test_file: Optional[str], + selected_test_func: Optional[str], + ignore_docker: bool, + docker_build_extra: str, + docker_run_extra: str, +) -> bool: + """Run pytest locally for the selected evaluation test to validate the evaluator.""" + if not selected_test_file or not selected_test_func: + # No local test associated; skip validation but warn the user. + print("Warning: Could not resolve a local evaluation test for this evaluator; skipping local validation.") + return True + + pytest_target = _build_entry_point(project_root, selected_test_file, selected_test_func) + exit_code = run_evaluator_test( + project_root=project_root, + pytest_target=pytest_target, + ignore_docker=ignore_docker, + docker_build_extra=docker_build_extra, + docker_run_extra=docker_run_extra, + ) + return exit_code == 0 + + +def _resolve_evaluator( + project_root: str, + evaluator_arg: Optional[str], + non_interactive: bool, + account_id: str, +) -> tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: + """Resolve evaluator id/resource and associated local test (file + func).""" + evaluator_id = evaluator_arg + selected_test_file_path: Optional[str] = None + selected_test_func_name: Optional[str] = None - # Resolve evaluator id/entry if omitted (reuse upload's selector flow) - project_root = os.getcwd() if not evaluator_id: - print("Scanning for evaluation tests...") - tests = _discover_tests(project_root) - if not tests: - print("No evaluation tests found.") - print("\nHint: Make sure your tests use the @evaluation_test decorator.") - return 1 - # Always interactive selection here - try: - selected_tests = _prompt_select(tests, non_interactive=non_interactive) - except Exception: - print("Error: Failed to open selector UI. Please pass --evaluator or --entry explicitly.") - return 1 + selected_tests = _discover_and_select_tests(project_root, non_interactive=non_interactive) if not selected_tests: - print("No tests selected.") - return 1 + return None, None, None, None if len(selected_tests) != 1: if non_interactive and len(selected_tests) > 1: print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.") @@ -400,7 +329,8 @@ def create_rft_command(args) -> int: pass else: print("Error: Please select exactly one evaluation test for 'create rft'.") - return 1 + return None, None, None, None + # Derive evaluator_id from user's single selection chosen = selected_tests[0] func_name = chosen.qualname.split(".")[-1] @@ -410,129 +340,46 @@ def create_rft_command(args) -> int: selected_test_file_path, selected_test_func_name = _resolve_selected_test( project_root, evaluator_id, selected_tests=selected_tests ) - # Resolve evaluator resource name to fully-qualified format required by API. - # Allow users to pass either short id or fully-qualified resource. - if evaluator_id and evaluator_id.startswith("accounts/"): - evaluator_resource_name = evaluator_id - evaluator_id = _extract_terminal_segment(evaluator_id) else: - evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}" - - # Optional short-circuit: if evaluator already exists and not forcing, skip upload path - skip_upload = False - if not force: - try: - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } - resp = requests.get(f"{api_base}/v1/{evaluator_resource_name}", headers=headers, timeout=10) - if resp.ok: - state = resp.json().get("state", "STATE_UNSPECIFIED") - print(f"āœ“ Evaluator exists (state: {state}). Skipping upload (use --force to overwrite).") - # Poll for ACTIVE before proceeding - print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...") - if not _poll_evaluator_status( - evaluator_resource_name=evaluator_resource_name, - api_key=api_key, - api_base=api_base, - timeout_minutes=10, - ): - app_base = _map_api_host_to_app_host(api_base) - evaluator_slug = _extract_terminal_segment(evaluator_id) - dashboard_url = f"{app_base}/dashboard/evaluators/{evaluator_slug}" - print("\nāŒ Evaluator is not ready within the timeout period.") - print(f"šŸ“Š Please check the evaluator status at: {dashboard_url}") - print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") - return 1 - skip_upload = True - # Populate selected test info for dataset inference later - st_path, st_func = _resolve_selected_test(project_root, evaluator_id) - if st_path and st_func: - selected_test_file_path = st_path - selected_test_func_name = st_func - except requests.exceptions.RequestException: - pass - - # Ensure evaluator exists by invoking the upload flow programmatically - if not skip_upload: - try: - from .upload import upload_command - - tests = _discover_tests(project_root) - selected_entry: Optional[str] = None - st_path, st_func = _resolve_selected_test(project_root, evaluator_id, selected_tests=tests) - if st_path and st_func: - abs_path = os.path.abspath(st_path) - try: - rel = os.path.relpath(abs_path, project_root) - except Exception: - rel = abs_path - selected_entry = f"{rel}::{st_func}" - selected_test_file_path = st_path - selected_test_func_name = st_func - # If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators - if selected_entry is None and len(tests) > 1: - print( - f"Error: Multiple evaluation tests found, and the selected evaluator {evaluator_id} does not match any discovered test.\n" - " Please re-run specifying the evaluator.\n" - " Hints:\n" - " - eval-protocol create rft --evaluator \n" - ) - return 1 - - upload_args = argparse.Namespace( - path=project_root, - entry=selected_entry, - id=evaluator_id, - display_name=None, - description=None, - force=force, # Pass through the --force flag - yes=True, - env_file=None, # Add the new env_file parameter - ) + # Caller provided an evaluator id or fully-qualified resource; try to resolve local test + short_id = evaluator_id + if evaluator_id.startswith("accounts/"): + short_id = _extract_terminal_segment(evaluator_id) + st_path, st_func = _resolve_selected_test(project_root, short_id) + if st_path and st_func: + selected_test_file_path = st_path + selected_test_func_name = st_func + evaluator_id = short_id - if force: - print(f"šŸ”„ Force flag enabled - will overwrite existing evaluator '{evaluator_id}'") + if not evaluator_id: + return None, None, None, None - rc = upload_command(upload_args) - if rc == 0: - print(f"āœ“ Uploaded/ensured evaluator: {evaluator_id}") + # Resolve evaluator resource name to fully-qualified format required by API. + if evaluator_arg and evaluator_arg.startswith("accounts/"): + evaluator_resource_name = evaluator_arg + else: + evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}" - # Poll for evaluator status - print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...") - is_active = _poll_evaluator_status( - evaluator_resource_name=evaluator_resource_name, - api_key=api_key, - api_base=api_base, - timeout_minutes=10, - ) + return evaluator_id, evaluator_resource_name, selected_test_file_path, selected_test_func_name - if not is_active: - # Print helpful message with dashboard link - app_base = _map_api_host_to_app_host(api_base) - evaluator_slug = _extract_terminal_segment(evaluator_id) - dashboard_url = f"{app_base}/dashboard/evaluators/{evaluator_slug}" - print("\nāŒ Evaluator is not ready within the timeout period.") - print(f"šŸ“Š Please check the evaluator status at: {dashboard_url}") - print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") - return 1 - else: - # Evaluator ACTIVE; proceed - pass - else: - print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.") - except Exception as e: - print(f"Warning: Failed to upload evaluator automatically: {e}") - - # Determine dataset id and materialization path +def _resolve_and_prepare_dataset( + project_root: str, + account_id: str, + api_key: str, + api_base: str, + evaluator_id: str, + args: argparse.Namespace, + selected_test_file_path: Optional[str], + selected_test_func_name: Optional[str], + dry_run: bool, +) -> tuple[Optional[str], Optional[str], Optional[str]]: + """Resolve dataset id/resource and ensure dataset exists if using JSONL.""" dataset_id = getattr(args, "dataset", None) dataset_jsonl = getattr(args, "dataset_jsonl", None) dataset_display_name = getattr(args, "dataset_display_name", None) - dataset_builder = getattr(args, "dataset_builder", None) # accepted but unused in simplified flow dataset_resource_override: Optional[str] = None + if isinstance(dataset_id, str) and dataset_id.startswith("accounts/"): # Caller passed a fully-qualified dataset; capture it for body and keep only terminal id for printing dataset_resource_override = dataset_id @@ -553,23 +400,21 @@ def create_rft_command(args) -> int: test_file_for_infer = tests[0].file_path func_for_infer = tests[0].qualname.split(".")[-1] if test_file_for_infer and func_for_infer: - # Try data_loaders first + # Block using data loaders as a dataset source dataset_jsonl = _extract_jsonl_from_dataloader(test_file_for_infer, func_for_infer) + if dataset_jsonl: + print( + "Error: Evaluation tests that use 'data_loaders' to provide a dataset JSONL are not supported for 'create rft'.\n" + " Please switch to a JSONL-based dataset via input_dataset arg in @evaluation_test decorator." + ) + return None, None, None + dataset_jsonl = _extract_jsonl_from_input_dataset(test_file_for_infer, func_for_infer) if dataset_jsonl: try: rel = os.path.relpath(dataset_jsonl, project_root) except Exception: rel = dataset_jsonl - print(f"āœ“ Using JSONL from data loader: {rel}") - if not dataset_jsonl: - # Fall back to input_dataset (dataset_path) - dataset_jsonl = _extract_jsonl_from_input_dataset(test_file_for_infer, func_for_infer) - if dataset_jsonl: - try: - rel = os.path.relpath(dataset_jsonl, project_root) - except Exception: - rel = dataset_jsonl - print(f"āœ“ Using JSONL from input_dataset: {rel}") + print(f"āœ“ Using JSONL from input_dataset: {rel}") if not dataset_jsonl: # Last resort: attempt to detect and run a dataset builder in the test's directory metric_dir = os.path.dirname(test_file_for_infer) @@ -585,7 +430,7 @@ def create_rft_command(args) -> int: print( "Error: Could not determine dataset. Provide --dataset or --dataset-jsonl, or ensure a JSONL-based data loader or input_dataset is used in your single discovered test." ) - return 1 + return None, None, None inferred_dataset_id = _build_trimmed_dataset_id(evaluator_id) if dry_run: @@ -610,8 +455,125 @@ def create_rft_command(args) -> int: print(f"āœ“ Created and uploaded dataset: {dataset_id}") except Exception as e: print(f"Error creating/uploading dataset: {e}") - return 1 + return None, None, None + + if not dataset_id: + return None, None, None + + # Build dataset resource (prefer override when provided) + dataset_resource = dataset_resource_override or f"accounts/{account_id}/datasets/{dataset_id}" + return dataset_id, dataset_resource, dataset_jsonl + + +def _ensure_evaluator_active( + project_root: str, + evaluator_id: str, + evaluator_resource_name: str, + api_key: str, + api_base: str, + force: bool, +) -> bool: + """Ensure the evaluator exists and is ACTIVE, uploading it if needed.""" + # Optional short-circuit: if evaluator already exists and not forcing, skip upload path + if not force: + try: + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "User-Agent": get_user_agent(), + } + resp = requests.get(f"{api_base}/v1/{evaluator_resource_name}", headers=headers, timeout=10) + if resp.ok: + state = resp.json().get("state", "STATE_UNSPECIFIED") + print(f"āœ“ Evaluator exists (state: {state}). Skipping upload (use --force to overwrite).") + # Poll for ACTIVE before proceeding + print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...") + if not _poll_evaluator_status( + evaluator_resource_name=evaluator_resource_name, + api_key=api_key, + api_base=api_base, + timeout_minutes=10, + ): + dashboard_url = _build_evaluator_dashboard_url(evaluator_id) + print("\nāŒ Evaluator is not ready within the timeout period.") + print(f"šŸ“Š Please check the evaluator status at: {dashboard_url}") + print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") + return False + return True + except requests.exceptions.RequestException: + pass + # Ensure evaluator exists by invoking the upload flow programmatically + try: + tests = _discover_tests(project_root) + selected_entry: Optional[str] = None + st_path, st_func = _resolve_selected_test(project_root, evaluator_id, selected_tests=tests) + if st_path and st_func: + selected_entry = _build_entry_point(project_root, st_path, st_func) + # If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators + if selected_entry is None and len(tests) > 1: + print( + f"Error: Multiple evaluation tests found, and the selected evaluator {evaluator_id} does not match any discovered test.\n" + " Please re-run specifying the evaluator.\n" + " Hints:\n" + " - eval-protocol create rft --evaluator \n" + ) + return False + + upload_args = argparse.Namespace( + path=project_root, + entry=selected_entry, + id=evaluator_id, + display_name=None, + description=None, + force=force, # Pass through the --force flag + yes=True, + env_file=None, # Add the new env_file parameter + ) + + if force: + print(f"šŸ”„ Force flag enabled - will overwrite existing evaluator '{evaluator_id}'") + + rc = upload_command(upload_args) + if rc == 0: + print(f"āœ“ Uploaded/ensured evaluator: {evaluator_id}") + + # Poll for evaluator status + print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...") + is_active = _poll_evaluator_status( + evaluator_resource_name=evaluator_resource_name, + api_key=api_key, + api_base=api_base, + timeout_minutes=10, + ) + + if not is_active: + dashboard_url = _build_evaluator_dashboard_url(evaluator_id) + print("\nāŒ Evaluator is not ready within the timeout period.") + print(f"šŸ“Š Please check the evaluator status at: {dashboard_url}") + print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") + return False + return True + else: + print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.") + return False + except Exception as e: + print(f"Warning: Failed to upload evaluator automatically: {e}") + return False + + +def _create_rft_job( + account_id: str, + api_key: str, + api_base: str, + evaluator_id: str, + evaluator_resource_name: str, + dataset_id: str, + dataset_resource: str, + args: argparse.Namespace, + dry_run: bool, +) -> int: + """Build and submit the RFT job request.""" # Build training config/body # Exactly one of base-model or warm-start-from must be provided base_model_raw = getattr(args, "base_model", None) @@ -682,9 +644,6 @@ def create_rft_command(args) -> int: "runId": getattr(args, "wandb_run_id", None), } - # Build dataset resource (prefer override when provided) - dataset_resource = dataset_resource_override or f"accounts/{account_id}/datasets/{dataset_id}" - body: Dict[str, Any] = { "displayName": getattr(args, "display_name", None), "dataset": dataset_resource, @@ -732,3 +691,93 @@ def create_rft_command(args) -> int: except Exception as e: print(f"Error creating RFT job: {e}") return 1 + + +def create_rft_command(args) -> int: + # Pre-flight: resolve auth and environment + api_key = get_fireworks_api_key() + if not api_key: + print("Error: FIREWORKS_API_KEY not set.") + return 1 + + account_id = _ensure_account_id() + if not account_id: + print("Error: FIREWORKS_ACCOUNT_ID not set and could not be resolved.") + return 1 + + api_base = get_fireworks_api_base() + project_root = os.getcwd() + evaluator_arg: Optional[str] = getattr(args, "evaluator", None) + non_interactive: bool = bool(getattr(args, "yes", False)) + dry_run: bool = bool(getattr(args, "dry_run", False)) + force: bool = bool(getattr(args, "force", False)) + skip_validation: bool = bool(getattr(args, "skip_validation", False)) + ignore_docker: bool = bool(getattr(args, "ignore_docker", False)) + docker_build_extra: str = getattr(args, "docker_build_extra", "") or "" + docker_run_extra: str = getattr(args, "docker_run_extra", "") or "" + + # 1) Resolve evaluator and associated local test + ( + evaluator_id, + evaluator_resource_name, + selected_test_file_path, + selected_test_func_name, + ) = _resolve_evaluator(project_root, evaluator_arg, non_interactive, account_id) + if not evaluator_id or not evaluator_resource_name: + return 1 + + # 2) Resolve dataset (id/resource) and underlying JSONL (if any) + dataset_id, dataset_resource, dataset_jsonl = _resolve_and_prepare_dataset( + project_root=project_root, + account_id=account_id, + api_key=api_key, + api_base=api_base, + evaluator_id=evaluator_id, + args=args, + selected_test_file_path=selected_test_file_path, + selected_test_func_name=selected_test_func_name, + dry_run=dry_run, + ) + if not dataset_id or not dataset_resource: + return 1 + + # 3) Optional local validation + if not skip_validation: + # Dataset validation (JSONL must be EvaluationRow-compatible when present) + if not _validate_dataset(dataset_jsonl): + return 1 + + # Evaluator validation (run pytest for the selected test, possibly via Docker) + if not _validate_evaluator_locally( + project_root=project_root, + selected_test_file=selected_test_file_path, + selected_test_func=selected_test_func_name, + ignore_docker=ignore_docker, + docker_build_extra=docker_build_extra, + docker_run_extra=docker_run_extra, + ): + return 1 + + # 4) Ensure evaluator exists and is ACTIVE (upload + poll if needed) + if not _ensure_evaluator_active( + project_root=project_root, + evaluator_id=evaluator_id, + evaluator_resource_name=evaluator_resource_name, + api_key=api_key, + api_base=api_base, + force=force, + ): + return 1 + + # 5) Create the RFT job + return _create_rft_job( + account_id=account_id, + api_key=api_key, + api_base=api_base, + evaluator_id=evaluator_id, + evaluator_resource_name=evaluator_resource_name, + dataset_id=dataset_id, + dataset_resource=dataset_resource, + args=args, + dry_run=dry_run, + ) diff --git a/eval_protocol/cli_commands/local_test.py b/eval_protocol/cli_commands/local_test.py index 49d34190..545b61fd 100644 --- a/eval_protocol/cli_commands/local_test.py +++ b/eval_protocol/cli_commands/local_test.py @@ -1,11 +1,11 @@ import argparse import os +import shlex import subprocess import sys -import shlex from typing import List -from .upload import _discover_tests, _prompt_select +from .utils import _build_entry_point, _discover_and_select_tests def _find_dockerfiles(root: str) -> List[str]: @@ -19,12 +19,6 @@ def _find_dockerfiles(root: str) -> List[str]: return dockerfiles -def _run_pytest_host(pytest_target: str) -> int: - print(f"Running locally: pytest {pytest_target} -vs") - proc = subprocess.run([sys.executable, "-m", "pytest", pytest_target, "-vs"]) - return proc.returncode - - def _build_docker_image(dockerfile_path: str, image_tag: str, build_extras: List[str] | None = None) -> bool: context_dir = os.path.dirname(dockerfile_path) print(f"Building Docker image '{image_tag}' from {dockerfile_path} ...") @@ -41,6 +35,13 @@ def _build_docker_image(dockerfile_path: str, image_tag: str, build_extras: List return False +def _run_pytest_host(pytest_target: str) -> int: + """Run pytest against a target on the host and return its exit code.""" + print(f"Running locally: pytest {pytest_target} -vs") + proc = subprocess.run([sys.executable, "-m", "pytest", pytest_target, "-vs"]) + return proc.returncode + + def _run_pytest_in_docker( project_root: str, image_tag: str, pytest_target: str, run_extras: List[str] | None = None ) -> int: @@ -87,6 +88,53 @@ def _run_pytest_in_docker( return 1 +def run_evaluator_test( + project_root: str, + pytest_target: str, + ignore_docker: bool, + docker_build_extra: str = "", + docker_run_extra: str = "", +) -> int: + """Run an evaluator test either on host or in Docker, reusing local-test logic.""" + build_extras = shlex.split(docker_build_extra) if docker_build_extra else [] + run_extras = shlex.split(docker_run_extra) if docker_run_extra else [] + + if ignore_docker: + if not pytest_target: + print("Error: Failed to resolve a pytest target to run.") + return 1 + return _run_pytest_host(pytest_target) + + dockerfiles = _find_dockerfiles(project_root) + if len(dockerfiles) > 1: + print("Error: Multiple Dockerfiles found. Only one Dockerfile is allowed for evaluator validation/local-test.") + for df in dockerfiles: + print(f" - {df}") + print("Hint: or use --ignore-docker to bypass Docker and use local pytest.") + return 1 + if len(dockerfiles) == 1: + # Ensure host home logs directory exists so container writes are visible to host ep logs + try: + os.makedirs(os.path.join(os.path.expanduser("~"), ".eval_protocol"), exist_ok=True) + except Exception: + pass + image_tag = "ep-evaluator:local" + ok = _build_docker_image(dockerfiles[0], image_tag, build_extras=build_extras) + if not ok: + print("Docker build failed. See logs above.") + return 1 + if not pytest_target: + print("Error: Failed to resolve a pytest target to run.") + return 1 + return _run_pytest_in_docker(project_root, image_tag, pytest_target, run_extras=run_extras) + + # No Dockerfile: run on host + if not pytest_target: + print("Error: Failed to resolve a pytest target to run.") + return 1 + return _run_pytest_host(pytest_target) + + def local_test_command(args: argparse.Namespace) -> int: project_root = os.getcwd() @@ -99,12 +147,7 @@ def local_test_command(args: argparse.Namespace) -> int: file_path = ( file_part if os.path.isabs(file_part) else os.path.abspath(os.path.join(project_root, file_part)) ) - # Convert to project-relative like the non-:: path - try: - rel = os.path.relpath(file_path, project_root) - except Exception: - rel = file_path - pytest_target = f"{rel}::{func_part}" + pytest_target = _build_entry_point(project_root, file_path, func_part) else: file_path = entry if os.path.isabs(entry) else os.path.abspath(os.path.join(project_root, entry)) # Use path relative to project_root when possible @@ -114,14 +157,9 @@ def local_test_command(args: argparse.Namespace) -> int: rel = file_path pytest_target = rel else: - tests = _discover_tests(project_root) - if not tests: - print("No evaluation tests found.\nHint: Ensure @evaluation_test is applied.") - return 1 non_interactive = bool(getattr(args, "yes", False)) - selected = _prompt_select(tests, non_interactive=non_interactive) + selected = _discover_and_select_tests(project_root, non_interactive=non_interactive) if not selected: - print("No tests selected.") return 1 if len(selected) != 1: print("Error: Please select exactly one evaluation test for 'local-test'.") @@ -137,39 +175,10 @@ def local_test_command(args: argparse.Namespace) -> int: ignore_docker = bool(getattr(args, "ignore_docker", False)) build_extras_str = getattr(args, "docker_build_extra", "") or "" run_extras_str = getattr(args, "docker_run_extra", "") or "" - build_extras = shlex.split(build_extras_str) if build_extras_str else [] - run_extras = shlex.split(run_extras_str) if run_extras_str else [] - if ignore_docker: - if not pytest_target: - print("Error: Failed to resolve a pytest target to run.") - return 1 - return _run_pytest_host(pytest_target) - - dockerfiles = _find_dockerfiles(project_root) - if len(dockerfiles) > 1: - print("Error: Multiple Dockerfiles found. Only one Dockerfile is allowed for local-test.") - for df in dockerfiles: - print(f" - {df}") - print("Hint: use --ignore-docker to bypass Docker.") - return 1 - if len(dockerfiles) == 1: - # Ensure host home logs directory exists so container writes are visible to host ep logs - try: - os.makedirs(os.path.join(os.path.expanduser("~"), ".eval_protocol"), exist_ok=True) - except Exception: - pass - image_tag = "ep-evaluator:local" - ok = _build_docker_image(dockerfiles[0], image_tag, build_extras=build_extras) - if not ok: - print("Docker build failed. See logs above.") - return 1 - if not pytest_target: - print("Error: Failed to resolve a pytest target to run.") - return 1 - return _run_pytest_in_docker(project_root, image_tag, pytest_target, run_extras=run_extras) - - # No Dockerfile: run on host - if not pytest_target: - print("Error: Failed to resolve a pytest target to run.") - return 1 - return _run_pytest_host(pytest_target) + return run_evaluator_test( + project_root, + pytest_target, + ignore_docker=ignore_docker, + docker_build_extra=build_extras_str, + docker_run_extra=run_extras_str, + ) diff --git a/eval_protocol/cli_commands/upload.py b/eval_protocol/cli_commands/upload.py index 8c6e7baf..33e0ed2f 100644 --- a/eval_protocol/cli_commands/upload.py +++ b/eval_protocol/cli_commands/upload.py @@ -1,255 +1,24 @@ import argparse import importlib.util -import inspect -import json import os -import pkgutil import re -import runpy import sys -from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Iterable - -import pytest -from eval_protocol.auth import ( - get_fireworks_account_id, - get_fireworks_api_key, - get_fireworks_api_base, - verify_api_key_and_get_account_id, -) +from typing import Any, Dict + +from eval_protocol.auth import get_fireworks_api_key from eval_protocol.platform_api import create_or_update_fireworks_secret from eval_protocol.evaluation import create_evaluation - - -@dataclass -class DiscoveredTest: - module_path: str - module_name: str - qualname: str - file_path: str - lineno: int | None - has_parametrize: bool - param_count: int - nodeids: list[str] - - -def _iter_python_files(root: str) -> Iterable[str]: - # Don't follow symlinks to avoid infinite loops - for dirpath, dirnames, filenames in os.walk(root, followlinks=False): - # Skip common virtualenv and node paths - if any( - skip in dirpath - for skip in [ - "/.venv", - "/venv", - "/node_modules", - "/.git", - "/dist", - "/build", - "/__pycache__", - ".egg-info", - "/vendor", - ] - ): - continue - # Also skip specific directories by modifying dirnames in-place - dirnames[:] = [ - d - for d in dirnames - if not d.startswith(".") and d not in ["venv", "node_modules", "__pycache__", "dist", "build", "vendor"] - ] - - for name in filenames: - # Skip setup files, test discovery scripts, __init__, and hidden files - if ( - name.endswith(".py") - and not name.startswith(".") - and not name.startswith("test_discovery") - and name not in ["setup.py", "versioneer.py", "conf.py", "__main__.py"] - ): - yield os.path.join(dirpath, name) - - -def _is_eval_protocol_test(obj: Any) -> bool: - # evaluation_test decorator returns a dual_mode_wrapper with _origin_func and pytest marks - if not callable(obj): - return False - origin = getattr(obj, "_origin_func", None) - if origin is None: - return False - # Must have pytest marks from evaluation_test - marks = getattr(obj, "pytestmark", []) - # Handle pytest proxy objects (APIRemovedInV1Proxy) - if not isinstance(marks, (list, tuple)): - try: - marks = list(marks) if marks else [] - except (TypeError, AttributeError): - return False - return len(marks) > 0 - - -def _extract_param_info_from_marks(obj: Any) -> tuple[bool, int, list[str]]: - """Extract parametrization info from pytest marks. - - Returns: - (has_parametrize, param_count, param_ids) - """ - marks = getattr(obj, "pytestmark", []) - - # Handle pytest proxy objects (APIRemovedInV1Proxy) - same as _is_eval_protocol_test - if not isinstance(marks, (list, tuple)): - try: - marks = list(marks) if marks else [] - except (TypeError, AttributeError): - marks = [] - - has_parametrize = False - total_combinations = 0 - all_param_ids: list[str] = [] - - for m in marks: - if getattr(m, "name", "") == "parametrize": - has_parametrize = True - # The data is in kwargs for eval_protocol's parametrization - kwargs = getattr(m, "kwargs", {}) - argnames = kwargs.get("argnames", m.args[0] if m.args else "") - argvalues = kwargs.get("argvalues", m.args[1] if len(m.args) > 1 else []) - ids = kwargs.get("ids", []) - - # Count this dimension of parameters - if isinstance(argvalues, (list, tuple)): - count = len(argvalues) - total_combinations = count # For now, just use the count from this mark - - # Use provided IDs - if ids and isinstance(ids, (list, tuple)): - all_param_ids = list(ids[:count]) - else: - # Generate IDs based on argnames - if isinstance(argnames, str) and "," not in argnames: - # Single parameter - all_param_ids = [f"{argnames}={i}" for i in range(count)] - else: - # Multiple parameters - all_param_ids = [f"variant_{i}" for i in range(count)] - - return has_parametrize, total_combinations, all_param_ids - - -def _discover_tests(root: str) -> list[DiscoveredTest]: - abs_root = os.path.abspath(root) - if abs_root not in sys.path: - sys.path.insert(0, abs_root) - - discovered: list[DiscoveredTest] = [] - - class CollectionPlugin: - """Plugin to capture collected items without running code.""" - - def __init__(self): - self.items = [] - - def pytest_ignore_collect(self, collection_path, config): - """Ignore problematic files before pytest tries to import them.""" - # Ignore specific files - ignored_files = ["setup.py", "versioneer.py", "conf.py", "__main__.py"] - if collection_path.name in ignored_files: - return True - - # Ignore hidden files (starting with .) - if collection_path.name.startswith("."): - return True - - # Ignore test_discovery files - if collection_path.name.startswith("test_discovery"): - return True - - return None - - def pytest_collection_modifyitems(self, items): - """Hook called after collection is done.""" - self.items = items - - plugin = CollectionPlugin() - - # Run pytest collection only (--collect-only prevents code execution) - # Override python_files to collect from ANY .py file - args = [ - abs_root, - "--collect-only", - "-q", - "--pythonwarnings=ignore", - "-o", - "python_files=*.py", # Override to collect all .py files - ] - - try: - # Suppress pytest output - import io - import contextlib - - with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()): - pytest.main(args, plugins=[plugin]) - except Exception: - # If pytest collection fails, fall back to empty list - return [] - - # Process collected items - for item in plugin.items: - if not hasattr(item, "obj"): - continue - - obj = item.obj - if not _is_eval_protocol_test(obj): - continue - - origin = getattr(obj, "_origin_func", obj) - try: - src_file = inspect.getsourcefile(origin) or str(item.path) - _, lineno = inspect.getsourcelines(origin) - except Exception: - src_file, lineno = str(item.path), None - - # Extract parametrization info from marks - has_parametrize, param_count, param_ids = _extract_param_info_from_marks(obj) - - # Get module name and function name - module_name = ( - item.module.__name__ - if hasattr(item, "module") - else item.nodeid.split("::")[0].replace("/", ".").replace(".py", "") - ) - func_name = item.name.split("[")[0] if "[" in item.name else item.name - - # Generate nodeids - base_nodeid = f"{os.path.basename(src_file)}::{func_name}" - if param_ids: - nodeids = [f"{base_nodeid}[{pid}]" for pid in param_ids] - else: - nodeids = [base_nodeid] - - discovered.append( - DiscoveredTest( - module_path=module_name, - module_name=module_name, - qualname=f"{module_name}.{func_name}", - file_path=os.path.abspath(src_file), - lineno=lineno, - has_parametrize=has_parametrize, - param_count=param_count, - nodeids=nodeids, - ) - ) - - # Deduplicate by qualname (in case same test appears multiple times) - by_qual: dict[str, DiscoveredTest] = {} - for t in discovered: - existing = by_qual.get(t.qualname) - if not existing or t.param_count > existing.param_count: - by_qual[t.qualname] = t - return sorted(by_qual.values(), key=lambda x: (x.file_path, x.lineno or 0)) +from .utils import ( + _build_entry_point, + _build_evaluator_dashboard_url, + _discover_and_select_tests, + _discover_tests, + _ensure_account_id, + _normalize_evaluator_id, + _prompt_select, +) def _to_pyargs_nodeid(file_path: str, func_name: str) -> str | None: @@ -364,165 +133,6 @@ def _resolve_entry_to_qual_and_source(entry: str, cwd: str) -> tuple[str, str]: return qualname, os.path.abspath(source_file_path) if source_file_path else "" -def _generate_ts_mode_code(test: DiscoveredTest) -> tuple[str, str]: - # Deprecated: we no longer generate a shim; keep stub for import compatibility - return ("", "main.py") - - -def _normalize_evaluator_id(evaluator_id: str) -> str: - """ - Normalize evaluator ID to meet Fireworks requirements: - - Only lowercase a-z, 0-9, and hyphen (-) - - Maximum 63 characters - """ - # Convert to lowercase - normalized = evaluator_id.lower() - - # Replace underscores with hyphens - normalized = normalized.replace("_", "-") - - # Remove any characters that aren't alphanumeric or hyphen - normalized = re.sub(r"[^a-z0-9-]", "", normalized) - - # Remove consecutive hyphens - normalized = re.sub(r"-+", "-", normalized) - - # Remove leading/trailing hyphens - normalized = normalized.strip("-") - - # Ensure it starts with a letter (Fireworks requirement) - if normalized and not normalized[0].isalpha(): - normalized = "eval-" + normalized - - # Truncate to 63 characters - if len(normalized) > 63: - normalized = normalized[:63].rstrip("-") - - return normalized - - -def _format_test_choice(test: DiscoveredTest, idx: int) -> str: - """Format a test as a choice string for display.""" - # Shorten the qualname for display - name = test.qualname.split(".")[-1] - location = f"{Path(test.file_path).name}:{test.lineno}" if test.lineno else Path(test.file_path).name - - if test.has_parametrize and test.param_count > 1: - return f"{name} ({test.param_count} variants) - {location}" - else: - return f"{name} - {location}" - - -def _prompt_select_interactive(tests: list[DiscoveredTest]) -> list[DiscoveredTest]: - """Interactive selection with arrow keys using questionary.""" - try: - import questionary - from questionary import Style - - # Custom style similar to Vercel CLI - custom_style = Style( - [ - ("qmark", "fg:#673ab7 bold"), - ("question", "bold"), - ("answer", "fg:#f44336 bold"), - ("pointer", "fg:#673ab7 bold"), - ("highlighted", "fg:#673ab7 bold"), - ("selected", "fg:#cc5454"), - ("separator", "fg:#cc5454"), - ("instruction", ""), - ("text", ""), - ] - ) - - # Check if only one test - auto-select it - if len(tests) == 1: - print(f"\nFound 1 test: {_format_test_choice(tests[0], 1)}") - confirm = questionary.confirm("Select this test?", default=True, style=custom_style).ask() - if confirm: - return tests - else: - return [] - - # Single-select UX - print("\n") - print("Tip: Use ↑/↓ arrows to navigate and press ENTER to select.\n") - - choices = [] - for idx, t in enumerate(tests, 1): - choice_text = _format_test_choice(t, idx) - choices.append({"name": choice_text, "value": idx - 1}) - - selected = questionary.select( - "Select an evaluation test to upload:", choices=choices, style=custom_style - ).ask() - - if selected is None: # Ctrl+C - print("\nUpload cancelled.") - return [] - - print("\nāœ“ Selected 1 test") - return [tests[selected]] - - except ImportError: - # Fallback to simpler implementation - return _prompt_select_fallback(tests) - except KeyboardInterrupt: - print("\n\nUpload cancelled.") - return [] - - -def _prompt_select_fallback(tests: list[DiscoveredTest]) -> list[DiscoveredTest]: - """Fallback prompt selection for when questionary is not available.""" - print("\n" + "=" * 80) - print("Discovered evaluation tests:") - print("=" * 80) - print("\nTip: Install questionary for better UX: pip install questionary\n") - - for idx, t in enumerate(tests, 1): - loc = f"{t.file_path}:{t.lineno}" if t.lineno else t.file_path - print(f" [{idx}] {t.qualname}") - print(f" Location: {loc}") - - if t.has_parametrize and t.nodeids: - print(f" Parameterized: {t.param_count} variant(s)") - # Show first few variants as examples - example_nodeids = t.nodeids[:3] - for nodeid in example_nodeids: - # Extract just the parameter part for display - if "[" in nodeid: - param_part = nodeid.split("[", 1)[1].rstrip("]") - print(f" - {param_part}") - if len(t.nodeids) > 3: - print(f" ... and {len(t.nodeids) - 3} more") - else: - print(" Type: Single test (no parametrization)") - print() - - print("=" * 80) - try: - choice = input("Enter the number to select: ").strip() - except KeyboardInterrupt: - print("\n\nUpload cancelled.") - return [] - - if not choice.isdigit(): - print("\nāš ļø Invalid selection.") - return [] - n = int(choice) - if not (1 <= n <= len(tests)): - print("\nāš ļø Selection out of range.") - return [] - return [tests[n - 1]] - - -def _prompt_select(tests: list[DiscoveredTest], non_interactive: bool) -> list[DiscoveredTest]: - """Prompt user to select tests to upload.""" - if non_interactive: - return tests - - return _prompt_select_interactive(tests) - - def _load_secrets_from_env_file(env_file_path: str) -> Dict[str, str]: """ Load secrets from a .env file that should be uploaded to Fireworks. @@ -572,6 +182,7 @@ def _mask_secret_value(value: str) -> str: def upload_command(args: argparse.Namespace) -> int: root = os.path.abspath(getattr(args, "path", ".")) entries_arg = getattr(args, "entry", None) + non_interactive: bool = bool(getattr(args, "yes", False)) if entries_arg: entries = [e.strip() for e in re.split(r"[,\s]+", entries_arg) if e.strip()] selected_specs: list[tuple[str, str]] = [] @@ -579,17 +190,9 @@ def upload_command(args: argparse.Namespace) -> int: qualname, resolved_path = _resolve_entry_to_qual_and_source(e, root) selected_specs.append((qualname, resolved_path)) else: - print("Scanning for evaluation tests...") - tests = _discover_tests(root) - if not tests: - print("No evaluation tests found.") - print("\nHint: Make sure your tests use the @evaluation_test decorator.") - return 1 - selected_tests = _prompt_select(tests, non_interactive=bool(getattr(args, "yes", False))) + selected_tests = _discover_and_select_tests(root, non_interactive=non_interactive) if not selected_tests: - print("No tests selected.") return 1 - # Warn about parameterized tests parameterized_tests = [t for t in selected_tests if t.has_parametrize] if parameterized_tests: @@ -607,7 +210,7 @@ def upload_command(args: argparse.Namespace) -> int: # Load secrets from .env file and ensure they're available on Fireworks try: - fw_account_id = get_fireworks_account_id() + fw_account_id = _ensure_account_id() # Determine .env file path if env_file: @@ -624,15 +227,6 @@ def upload_command(args: argparse.Namespace) -> int: if fw_api_key_value and "FIREWORKS_API_KEY" not in secrets_from_file: secrets_from_file["FIREWORKS_API_KEY"] = fw_api_key_value - if not fw_account_id and fw_api_key_value: - # Attempt to verify and resolve account id from server headers - resolved = verify_api_key_and_get_account_id(api_key=fw_api_key_value, api_base=get_fireworks_api_base()) - if resolved: - fw_account_id = resolved - # Propagate to environment so downstream calls use it if needed - os.environ["FIREWORKS_ACCOUNT_ID"] = fw_account_id - print(f"Resolved FIREWORKS_ACCOUNT_ID via API verification: {fw_account_id}") - if fw_account_id and secrets_from_file: print(f"Found {len(secrets_from_file)} API keys to upload as Fireworks secrets...") if secrets_from_env_file and os.path.exists(env_file_path): @@ -684,18 +278,7 @@ def upload_command(args: argparse.Namespace) -> int: # Compute entry point metadata for backend as a pytest nodeid usable with `pytest ` # Always prefer a path-based nodeid to work in plain pytest environments (server may not use --pyargs) func_name = qualname.split(".")[-1] - entry_point = None - if source_file_path: - # Use path relative to current working directory if possible - abs_path = os.path.abspath(source_file_path) - try: - rel = os.path.relpath(abs_path, root) - except Exception: - rel = abs_path - entry_point = f"{rel}::{func_name}" - else: - # Fallback: use filename from qualname only (rare) - entry_point = f"{func_name}.py::{func_name}" + entry_point = _build_entry_point(root, source_file_path, func_name) print(f"\nUploading evaluator '{evaluator_id}' for {qualname.split('.')[-1]}...") try: @@ -714,28 +297,8 @@ def upload_command(args: argparse.Namespace) -> int: # Print success message with Fireworks dashboard link print(f"\nāœ… Successfully uploaded evaluator: {evaluator_id}") print("šŸ“Š View in Fireworks Dashboard:") - # Map API base to app host (e.g., dev.api.fireworks.ai -> dev.app.fireworks.ai) - from urllib.parse import urlparse - - api_base = os.environ.get("FIREWORKS_API_BASE", "https://api.fireworks.ai") - try: - parsed = urlparse(api_base) - host = parsed.netloc or parsed.path # handle cases where scheme may be missing - # Mapping rules: - # - dev.api.fireworks.ai → dev.fireworks.ai - # - *.api.fireworks.ai → *.app.fireworks.ai (default) - if host.startswith("dev.api.fireworks.ai"): - app_host = "dev.fireworks.ai" - elif host.startswith("api."): - app_host = host.replace("api.", "app.", 1) - else: - app_host = host - scheme = parsed.scheme or "https" - dashboard_url = f"{scheme}://{app_host}/dashboard/evaluators/{evaluator_id}" - except Exception: - dashboard_url = f"https://app.fireworks.ai/dashboard/evaluators/{evaluator_id}" - print(f" {dashboard_url}") - print() + dashboard_url = _build_evaluator_dashboard_url(evaluator_id) + print(f" {dashboard_url}\n") except Exception as e: print(f"Failed to upload {qualname}: {e}") exit_code = 2 diff --git a/eval_protocol/cli_commands/utils.py b/eval_protocol/cli_commands/utils.py new file mode 100644 index 00000000..4384e09f --- /dev/null +++ b/eval_protocol/cli_commands/utils.py @@ -0,0 +1,511 @@ +import os +import sys +import time +import inspect +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import pytest + +from ..auth import ( + get_fireworks_account_id, + get_fireworks_api_base, + get_fireworks_api_key, + verify_api_key_and_get_account_id, +) +from ..fireworks_rft import _map_api_host_to_app_host + + +@dataclass +class DiscoveredTest: + module_path: str + module_name: str + qualname: str + file_path: str + lineno: int | None + has_parametrize: bool + param_count: int + nodeids: List[str] + + +def _is_eval_protocol_test(obj: Any) -> bool: + """Return True if the given object looks like an eval_protocol evaluation test.""" + # evaluation_test decorator returns a dual_mode_wrapper with _origin_func and pytest marks + if not callable(obj): + return False + origin = getattr(obj, "_origin_func", None) + if origin is None: + return False + # Must have pytest marks from evaluation_test + marks = getattr(obj, "pytestmark", []) + # Handle pytest proxy objects (APIRemovedInV1Proxy) + if not isinstance(marks, (list, tuple)): + try: + marks = list(marks) if marks else [] + except (TypeError, AttributeError): + return False + return len(marks) > 0 + + +def _extract_param_info_from_marks(obj: Any) -> tuple[bool, int, list[str]]: + """Extract parametrization info from pytest marks. + + Returns: + (has_parametrize, param_count, param_ids) + """ + marks = getattr(obj, "pytestmark", []) + + # Handle pytest proxy objects (APIRemovedInV1Proxy) - same as _is_eval_protocol_test + if not isinstance(marks, (list, tuple)): + try: + marks = list(marks) if marks else [] + except (TypeError, AttributeError): + marks = [] + + has_parametrize = False + total_combinations = 0 + all_param_ids: list[str] = [] + + for m in marks: + if getattr(m, "name", "") == "parametrize": + has_parametrize = True + # The data is in kwargs for eval_protocol's parametrization + kwargs = getattr(m, "kwargs", {}) + argnames = kwargs.get("argnames", m.args[0] if m.args else "") + argvalues = kwargs.get("argvalues", m.args[1] if len(m.args) > 1 else []) + ids = kwargs.get("ids", []) + + # Count this dimension of parameters + if isinstance(argvalues, (list, tuple)): + count = len(argvalues) + total_combinations = count # For now, just use the count from this mark + + # Use provided IDs + if ids and isinstance(ids, (list, tuple)): + all_param_ids = list(ids[:count]) + else: + # Generate IDs based on argnames + if isinstance(argnames, str) and "," not in argnames: + # Single parameter + all_param_ids = [f"{argnames}={i}" for i in range(count)] + else: + # Multiple parameters + all_param_ids = [f"variant_{i}" for i in range(count)] + + return has_parametrize, total_combinations, all_param_ids + + +def _discover_tests(root: str) -> list[DiscoveredTest]: + """Discover eval_protocol tests under the given root directory.""" + abs_root = os.path.abspath(root) + if abs_root not in sys.path: + sys.path.insert(0, abs_root) + + discovered: list[DiscoveredTest] = [] + + class CollectionPlugin: + """Plugin to capture collected items without running code.""" + + def __init__(self) -> None: + self.items: list[Any] = [] + + def pytest_ignore_collect(self, collection_path, config): # type: ignore[override] + """Ignore problematic files before pytest tries to import them.""" + # Ignore specific files + ignored_files = ["setup.py", "versioneer.py", "conf.py", "__main__.py"] + if collection_path.name in ignored_files: + return True + + # Ignore hidden files (starting with .) + if collection_path.name.startswith("."): + return True + + # Ignore test_discovery files + if collection_path.name.startswith("test_discovery"): + return True + + return None + + def pytest_collection_modifyitems(self, items): # type: ignore[override] + """Hook called after collection is done.""" + self.items = items + + plugin = CollectionPlugin() + + # Run pytest collection only (--collect-only prevents code execution) + # Override python_files to collect from ANY .py file + args = [ + abs_root, + "--collect-only", + "-q", + "--pythonwarnings=ignore", + "-o", + "python_files=*.py", # Override to collect all .py files + ] + + try: + # Suppress pytest output + import io + import contextlib + + with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()): + pytest.main(args, plugins=[plugin]) + except Exception: + # If pytest collection fails, fall back to empty list + return [] + + # Process collected items + for item in plugin.items: + if not hasattr(item, "obj"): + continue + + obj = item.obj + if not _is_eval_protocol_test(obj): + continue + + origin = getattr(obj, "_origin_func", obj) + try: + src_file = inspect.getsourcefile(origin) or str(item.path) + _, lineno = inspect.getsourcelines(origin) + except Exception: + src_file, lineno = str(item.path), None + + # Extract parametrization info from marks + has_parametrize, param_count, param_ids = _extract_param_info_from_marks(obj) + + # Get module name and function name + module_name = ( + item.module.__name__ # type: ignore[attr-defined] + if hasattr(item, "module") + else item.nodeid.split("::")[0].replace("/", ".").replace(".py", "") + ) + func_name = item.name.split("[")[0] if "[" in item.name else item.name + + # Generate nodeids + base_nodeid = f"{os.path.basename(src_file)}::{func_name}" + if param_ids: + nodeids = [f"{base_nodeid}[{pid}]" for pid in param_ids] + else: + nodeids = [base_nodeid] + + discovered.append( + DiscoveredTest( + module_path=module_name, + module_name=module_name, + qualname=f"{module_name}.{func_name}", + file_path=os.path.abspath(src_file), + lineno=lineno, + has_parametrize=has_parametrize, + param_count=param_count, + nodeids=nodeids, + ) + ) + + # Deduplicate by qualname (in case same test appears multiple times) + by_qual: dict[str, DiscoveredTest] = {} + for t in discovered: + existing = by_qual.get(t.qualname) + if not existing or t.param_count > existing.param_count: + by_qual[t.qualname] = t + return sorted(by_qual.values(), key=lambda x: (x.file_path, x.lineno or 0)) + + +def _format_test_choice(test: DiscoveredTest, idx: int) -> str: + """Format a test as a choice string for display.""" + # Shorten the qualname for display + name = test.qualname.split(".")[-1] + location = f"{Path(test.file_path).name}:{test.lineno}" if test.lineno else Path(test.file_path).name + + if test.has_parametrize and test.param_count > 1: + return f"{name} ({test.param_count} variants) - {location}" + else: + return f"{name} - {location}" + + +def _prompt_select_interactive(tests: list[DiscoveredTest]) -> list[DiscoveredTest]: + """Interactive selection with arrow keys using questionary.""" + try: + import questionary + from questionary import Style + + # Custom style similar to Vercel CLI + custom_style = Style( + [ + ("qmark", "fg:#673ab7 bold"), + ("question", "bold"), + ("answer", "fg:#f44336 bold"), + ("pointer", "fg:#673ab7 bold"), + ("highlighted", "fg:#673ab7 bold"), + ("selected", "fg:#cc5454"), + ("separator", "fg:#cc5454"), + ("instruction", ""), + ("text", ""), + ] + ) + + # Check if only one test - auto-select it + if len(tests) == 1: + print(f"\nFound 1 test: {_format_test_choice(tests[0], 1)}") + confirm = questionary.confirm("Select this test?", default=True, style=custom_style).ask() + if confirm: + return tests + else: + return [] + + # Single-select UX + print("\n") + print("Tip: Use ↑/↓ arrows to navigate and press ENTER to select.\n") + + choices = [] + for idx, t in enumerate(tests, 1): + choice_text = _format_test_choice(t, idx) + choices.append({"name": choice_text, "value": idx - 1}) + + selected = questionary.select( + "Select an evaluation test to upload:", choices=choices, style=custom_style + ).ask() + + if selected is None: # Ctrl+C + print("\nUpload cancelled.") + return [] + + print("\nāœ“ Selected 1 test") + return [tests[selected]] + + except ImportError: + # Fallback to simpler implementation + return _prompt_select_fallback(tests) + except KeyboardInterrupt: + print("\n\nUpload cancelled.") + return [] + + +def _prompt_select_fallback(tests: list[DiscoveredTest]) -> list[DiscoveredTest]: + """Fallback prompt selection for when questionary is not available.""" + print("\n" + "=" * 80) + print("Discovered evaluation tests:") + print("=" * 80) + print("\nTip: Install questionary for better UX: pip install questionary\n") + + for idx, t in enumerate(tests, 1): + loc = f"{t.file_path}:{t.lineno}" if t.lineno else t.file_path + print(f" [{idx}] {t.qualname}") + print(f" Location: {loc}") + + if t.has_parametrize and t.nodeids: + print(f" Parameterized: {t.param_count} variant(s)") + # Show first few variants as examples + example_nodeids = t.nodeids[:3] + for nodeid in example_nodeids: + # Extract just the parameter part for display + if "[" in nodeid: + param_part = nodeid.split("[", 1)[1].rstrip("]") + print(f" - {param_part}") + if len(t.nodeids) > 3: + print(f" ... and {len(t.nodeids) - 3} more") + else: + print(" Type: Single test (no parametrization)") + print() + + print("=" * 80) + try: + choice = input("Enter the number to select: ").strip() + except KeyboardInterrupt: + print("\n\nUpload cancelled.") + return [] + + if not choice.isdigit(): + print("\nāš ļø Invalid selection.") + return [] + n = int(choice) + if not (1 <= n <= len(tests)): + print("\nāš ļø Selection out of range.") + return [] + return [tests[n - 1]] + + +def _prompt_select(tests: list[DiscoveredTest], non_interactive: bool) -> list[DiscoveredTest]: + """Prompt user to select tests to upload.""" + if non_interactive: + return tests + + return _prompt_select_interactive(tests) + + +def _discover_and_select_tests(project_root: str, non_interactive: bool) -> Optional[list[DiscoveredTest]]: + """Discover evaluation tests under the given root and prompt the user to select some. + + Returns a list of selected tests, or None if discovery/selection failed or the user + cancelled. Callers are responsible for enforcing additional constraints (e.g. exactly + one selection). + """ + print("Scanning for evaluation tests...") + tests = _discover_tests(project_root) + if not tests: + print("No evaluation tests found.") + print("\nHint: Make sure your tests use the @evaluation_test decorator.") + return None + + try: + selected_tests = _prompt_select(tests, non_interactive=non_interactive) + except Exception: + print("Error: Failed to open selector UI. Please pass --evaluator or --entry explicitly.") + return None + + if not selected_tests: + print("No tests selected.") + return None + + return selected_tests + + +def _normalize_evaluator_id(evaluator_id: str) -> str: + """ + Normalize evaluator ID to meet Fireworks requirements: + - Only lowercase a-z, 0-9, and hyphen (-) + - Maximum 63 characters + """ + import re + + # Convert to lowercase + normalized = evaluator_id.lower() + + # Replace underscores with hyphens + normalized = normalized.replace("_", "-") + + # Remove any characters that aren't alphanumeric or hyphen + normalized = re.sub(r"[^a-z0-9-]", "", normalized) + + # Remove consecutive hyphens + normalized = re.sub(r"-+", "-", normalized) + + # Remove leading/trailing hyphens + normalized = normalized.strip("-") + + # Ensure it starts with a letter (Fireworks requirement) + if normalized and not normalized[0].isalpha(): + normalized = "eval-" + normalized + + # Truncate to 63 characters + if len(normalized) > 63: + normalized = normalized[:63].rstrip("-") + + return normalized + + +def _ensure_account_id() -> Optional[str]: + """Resolve and cache FIREWORKS_ACCOUNT_ID if possible.""" + account_id = get_fireworks_account_id() + api_key = get_fireworks_api_key() + if not account_id and api_key: + resolved = verify_api_key_and_get_account_id(api_key=api_key, api_base=get_fireworks_api_base()) + if resolved: + os.environ["FIREWORKS_ACCOUNT_ID"] = resolved + return resolved + return account_id + + +def _extract_terminal_segment(resource_name: str) -> str: + """Return the last path segment if a fully-qualified resource name is provided.""" + try: + return resource_name.strip("/").split("/")[-1] + except Exception: + return resource_name + + +def _build_evaluator_dashboard_url(evaluator_id: str) -> str: + """Build the evaluator dashboard URL for the given evaluator id or resource name.""" + api_base = get_fireworks_api_base() + app_base = _map_api_host_to_app_host(api_base) + evaluator_slug = _extract_terminal_segment(evaluator_id) + return f"{app_base}/dashboard/evaluators/{evaluator_slug}" + + +def _print_links(evaluator_id: str, dataset_id: str, job_name: Optional[str]) -> None: + """Print dashboard links for evaluator, dataset, and optional RFT job.""" + evaluator_url = _build_evaluator_dashboard_url(evaluator_id) + print("\nšŸ“Š Dashboard Links:") + print(f" Evaluator: {evaluator_url}") + if dataset_id: + api_base = get_fireworks_api_base() + app_base = _map_api_host_to_app_host(api_base) + print(f" Dataset: {app_base}/dashboard/datasets/{dataset_id}") + if job_name: + # job_name likely like accounts/{account}/reinforcementFineTuningJobs/{id} + try: + job_id = job_name.strip().split("/")[-1] + print(f" RFT Job: {app_base}/dashboard/fine-tuning/reinforcement/{job_id}") + except Exception: + pass + + +def _build_trimmed_dataset_id(evaluator_id: str) -> str: + """Build a dataset id derived from evaluator_id, trimmed to 63 chars. + + Format: -dataset-YYYYMMDDHHMMSS, where base is trimmed to fit. + """ + base = _normalize_evaluator_id(evaluator_id) + suffix = f"-dataset-{time.strftime('%Y%m%d%H%M%S')}" + max_total = 63 + max_base_len = max_total - len(suffix) + if max_base_len < 1: + max_base_len = 1 + if len(base) > max_base_len: + base = base[:max_base_len].rstrip("-") + if not base: + base = "dataset" + # Ensure first char is a letter + if not base: + base = "dataset" + if not base[0].isalpha(): + base = f"eval-{base}" + if len(base) > max_base_len: + base = base[:max_base_len] + base = base.rstrip("-") or "dataset" + return f"{base}{suffix}" + + +def _resolve_selected_test( + project_root: str, + evaluator_id: Optional[str], + selected_tests: Optional[list[DiscoveredTest]] = None, +) -> tuple[Optional[str], Optional[str]]: + """ + Resolve a single test's source file path and function name to use downstream. + Priority: + 1) If selected_tests provided and length == 1, use it. + 2) Else discover tests; if exactly one test, use it. + 3) Else, if evaluator_id provided, match by normalized '-'. + Returns: (file_path, func_name) or (None, None) if unresolved. + """ + try: + tests = selected_tests if selected_tests is not None else _discover_tests(project_root) + if not tests: + return None, None + if len(tests) == 1: + return tests[0].file_path, tests[0].qualname.split(".")[-1] + if evaluator_id: + for t in tests: + func_name = t.qualname.split(".")[-1] + source_file_name = os.path.splitext(os.path.basename(t.file_path))[0] + candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}") + if candidate == evaluator_id: + return t.file_path, func_name + return None, None + except Exception: + return None, None + + +def _build_entry_point(project_root: str, source_file_path: Optional[str], func_name: str) -> str: + """Build a pytest-style entry point (path::func) relative to the given root.""" + if source_file_path: + abs_path = os.path.abspath(source_file_path) + try: + rel = os.path.relpath(abs_path, project_root) + except Exception: + rel = abs_path + return f"{rel}::{func_name}" + # Fallback: use filename only + return f"{func_name}.py::{func_name}" diff --git a/tests/test_cli_create_rft_infer.py b/tests/test_cli_create_rft_infer.py index 9ef7d707..9b54099e 100644 --- a/tests/test_cli_create_rft_infer.py +++ b/tests/test_cli_create_rft_infer.py @@ -15,7 +15,14 @@ def _write_json(path: str, data: dict) -> None: json.dump(data, f) -def test_create_rft_passes_all_flags_into_request_body(tmp_path, monkeypatch): +@pytest.fixture +def rft_test_harness(tmp_path, monkeypatch): + """ + Common setup for create_rft_command tests: + - Creates a temp project and chdirs into it + - Sets FIREWORKS_* env vars + - Stubs out upload / polling / evaluator activation to avoid real network calls + """ # Isolate HOME and CWD monkeypatch.setenv("HOME", str(tmp_path / "home")) project = tmp_path / "proj" @@ -27,6 +34,20 @@ def test_create_rft_passes_all_flags_into_request_body(tmp_path, monkeypatch): monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123") monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") + # Stub selector, upload, polling, and evaluator activation + import eval_protocol.cli_commands.upload as upload_mod + + monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) + monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) + monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + monkeypatch.setattr(cr, "_ensure_evaluator_active", lambda *a, **k: True) + + return project + + +def test_create_rft_passes_all_flags_into_request_body(rft_test_harness, monkeypatch): + project = rft_test_harness + # Provide dataset via --dataset-jsonl ds_path = project / "dataset.jsonl" ds_path.write_text('{"input":"x"}\n', encoding="utf-8") @@ -151,25 +172,14 @@ def _fake_create_job(account_id, api_key, api_base, body): assert wb["apiKey"] == "key123" -def test_create_rft_picks_most_recent_evaluator_and_dataset_id_follows(tmp_path, monkeypatch): - # Isolate HOME so expanduser paths remain inside tmp - monkeypatch.setenv("HOME", str(tmp_path / "home")) - - # Create a fake project and chdir into it (create_rft uses os.getcwd()) - project = tmp_path / "proj" - project.mkdir() - monkeypatch.chdir(project) +def test_create_rft_picks_most_recent_evaluator_and_dataset_id_follows(rft_test_harness, monkeypatch): + project = rft_test_harness # Create a dummy dataset jsonl file ds_path = project / "evaluator" / "dummy_dataset.jsonl" ds_path.parent.mkdir(parents=True, exist_ok=True) ds_path.write_text('{"input":"x"}\n', encoding="utf-8") - # Env required by create_rft_command - monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy") - monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123") - monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") - # Stub out networked/subcommands used by create_rft # Patch selector and upload import eval_protocol.cli_commands.upload as upload_mod @@ -179,7 +189,8 @@ def test_create_rft_picks_most_recent_evaluator_and_dataset_id_follows(tmp_path, one_file.parent.mkdir(parents=True, exist_ok=True) one_file.write_text("# single", encoding="utf-8") single_disc = SimpleNamespace(qualname="metric.test_single", file_path=str(one_file)) - monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc]) + # New flow uses _discover_and_select_tests; patch it to return our single test. + monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) @@ -216,6 +227,10 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d setattr(args, "max_context_length", None) setattr(args, "chunk_size", None) setattr(args, "eval_auto_carveout", None) + setattr(args, "skip_validation", True) + setattr(args, "ignore_docker", False) + setattr(args, "docker_build_extra", "") + setattr(args, "docker_run_extra", "") rc = cr.create_rft_command(args) assert rc == 0 @@ -225,14 +240,9 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d assert captured["dataset_id"].startswith("test-single-test-single-dataset-") -def test_create_rft_passes_matching_evaluator_id_and_entry_with_multiple_tests(tmp_path, monkeypatch): - # Ensure expanduser paths stay under tmp - monkeypatch.setenv("HOME", str(tmp_path / "home")) - - # Project structure and CWD - project = tmp_path / "proj" - project.mkdir() - monkeypatch.chdir(project) +def test_create_rft_passes_matching_evaluator_id_and_entry_with_multiple_tests(rft_test_harness, monkeypatch): + # Project structure and CWD from shared harness + project = rft_test_harness # Create dummy test files for discovery eval_dir = project / "evaluator" @@ -247,26 +257,8 @@ def test_create_rft_passes_matching_evaluator_id_and_entry_with_multiple_tests(t svg_disc = SimpleNamespace(qualname="bar_eval.test_baz_evaluation", file_path=str(svg_file)) monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [cal_disc, svg_disc]) - # Env for CLI - monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy") - monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123") - monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") - - # Capture what upload receives (id and entry) - captured = {"id": None, "entry": None, "dataset_id": None} - - # Monkeypatch the upload command from the upload module (the function imports it inside) - import eval_protocol.cli_commands.upload as upload_mod - - def _fake_upload(ns): - captured["id"] = getattr(ns, "id", None) - captured["entry"] = getattr(ns, "entry", None) - return 0 - - monkeypatch.setattr(upload_mod, "upload_command", _fake_upload) - - # Avoid network and capture dataset id - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + # Capture dataset id used during dataset creation + captured = {"dataset_id": None} def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, display_name, jsonl_path): captured["dataset_id"] = dataset_id @@ -304,16 +296,16 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d max_context_length=None, chunk_size=None, eval_auto_carveout=None, + skip_validation=True, + ignore_docker=False, + docker_build_extra="", + docker_run_extra="", ) rc = cr.create_rft_command(args) assert rc == 0 - # Assert evaluator_id passed to upload matches the provided id - assert captured["id"] == cr._normalize_evaluator_id("foo_eval-test_bar_evaluation") - # Assert entry points to the foo test (should map when id matches normalization) - assert captured["entry"] is not None and captured["entry"].endswith("foo_eval.py::test_bar_evaluation") - # Assert dataset id is derived from the same evaluator id (trimmed base + '-dataset-') + # Assert dataset id is derived from the evaluator id (trimmed base + '-dataset-') assert captured["dataset_id"] is not None expected_prefix = ( cr._build_trimmed_dataset_id(cr._normalize_evaluator_id("foo_eval-test_bar_evaluation")).split("-dataset-")[0] @@ -322,37 +314,20 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d assert captured["dataset_id"].startswith(expected_prefix) -def test_create_rft_interactive_selector_single_test(tmp_path, monkeypatch): - # Setup project - project = tmp_path / "proj" - project.mkdir() - monkeypatch.chdir(project) +def test_create_rft_interactive_selector_single_test(rft_test_harness, monkeypatch): + # Setup project using shared harness + project = rft_test_harness # Single discovered test test_file = project / "metric" / "test_one.py" test_file.parent.mkdir(parents=True, exist_ok=True) test_file.write_text("# one", encoding="utf-8") single_disc = SimpleNamespace(qualname="metric.test_one", file_path=str(test_file)) - monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc]) - - # Environment - monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy") - monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123") - monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") - - # Stub selector to return the single test; stub upload and polling - import eval_protocol.cli_commands.upload as upload_mod - - monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) - captured = {"id": None, "entry": None, "dataset_id": None} - - def _fake_upload(ns): - captured["id"] = getattr(ns, "id", None) - captured["entry"] = getattr(ns, "entry", None) - return 0 + # New flow uses _discover_and_select_tests; patch it to return our single test. + monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) - monkeypatch.setattr(upload_mod, "upload_command", _fake_upload) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + # Capture dataset id used during dataset creation + captured = {"dataset_id": None} # Provide dataset jsonl ds_path = project / "metric" / "dataset.jsonl" @@ -361,7 +336,7 @@ def _fake_upload(ns): cr, "create_dataset_from_jsonl", lambda account_id, api_key, api_base, dataset_id, display_name, jsonl_path: ( - dataset_id, + captured.__setitem__("dataset_id", dataset_id) or dataset_id, {"name": f"accounts/{account_id}/datasets/{dataset_id}"}, ), ) @@ -392,12 +367,21 @@ def _fake_upload(ns): max_context_length=None, chunk_size=None, eval_auto_carveout=None, + skip_validation=True, + ignore_docker=False, + docker_build_extra="", + docker_run_extra="", ) rc = cr.create_rft_command(args) assert rc == 0 - assert captured["id"] is not None - assert captured["entry"] is not None and captured["entry"].endswith("test_one.py::test_one") + # Assert dataset id is derived from the selected test's evaluator id + assert captured["dataset_id"] is not None + expected_prefix = ( + cr._build_trimmed_dataset_id(cr._normalize_evaluator_id("test_one-test_one")).split("-dataset-")[0] + + "-dataset-" + ) + assert captured["dataset_id"].startswith(expected_prefix) def test_create_rft_quiet_existing_evaluator_skips_upload(tmp_path, monkeypatch): @@ -524,31 +508,18 @@ def _raise(*a, **k): assert rc == 1 -def test_create_rft_fallback_to_dataset_builder(tmp_path, monkeypatch): - # Setup project - project = tmp_path / "proj" - project.mkdir() - monkeypatch.chdir(project) - +def test_create_rft_fallback_to_dataset_builder(rft_test_harness, monkeypatch): + project = rft_test_harness # Single discovered test without data_loaders or input_dataset test_file = project / "metric" / "test_builder.py" test_file.parent.mkdir(parents=True, exist_ok=True) test_file.write_text("# builder case", encoding="utf-8") single_disc = SimpleNamespace(qualname="metric.test_builder", file_path=str(test_file)) + # New flow uses _discover_and_select_tests for evaluator resolution; patch it to return our single test. + monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) + # Also patch _discover_tests for any direct calls during dataset inference. monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc]) - # Environment - monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy") - monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123") - monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") - - # Stub selector, upload, and polling - import eval_protocol.cli_commands.upload as upload_mod - - monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) - monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) - # Dataset builder fallback out_jsonl = project / "metric" / "builder_out.jsonl" out_jsonl.write_text('{"row":1}\n{"row":2}\n', encoding="utf-8") @@ -592,6 +563,7 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d max_context_length=None, chunk_size=None, eval_auto_carveout=None, + skip_validation=True, ) rc = cr.create_rft_command(args) @@ -603,30 +575,15 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d assert captured["jsonl_path"] == str(out_jsonl) -def test_create_rft_uses_dataloader_jsonl_when_available(tmp_path, monkeypatch): - # Setup project - project = tmp_path / "proj" - project.mkdir() - monkeypatch.chdir(project) - +def test_create_rft_rejects_dataloader_jsonl(rft_test_harness, monkeypatch): + project = rft_test_harness # Single discovered test test_file = project / "metric" / "test_loader.py" test_file.parent.mkdir(parents=True, exist_ok=True) test_file.write_text("# loader case", encoding="utf-8") single_disc = SimpleNamespace(qualname="metric.test_loader", file_path=str(test_file)) - monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc]) - - # Environment - monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy") - monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123") - monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") - - # Stub selector, upload, and polling - import eval_protocol.cli_commands.upload as upload_mod - - monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) - monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + # New flow uses _discover_and_select_tests; patch it to return our single test. + monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) # Provide JSONL via dataloader extractor dl_jsonl = project / "metric" / "loader_out.jsonl" @@ -669,39 +626,28 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d max_context_length=None, chunk_size=None, eval_auto_carveout=None, + skip_validation=True, + ignore_docker=False, + docker_build_extra="", + docker_run_extra="", ) rc = cr.create_rft_command(args) - assert rc == 0 - assert captured["dataset_id"] is not None - assert captured["dataset_id"].startswith("test-loader-test-loader-dataset-") - assert captured["jsonl_path"] == str(dl_jsonl) - + # Dataloader-provided JSONL is now rejected for create rft + assert rc == 1 + assert captured["dataset_id"] is None + assert captured["jsonl_path"] is None -def test_create_rft_uses_input_dataset_jsonl_when_available(tmp_path, monkeypatch): - # Setup project - project = tmp_path / "proj" - project.mkdir() - monkeypatch.chdir(project) +def test_create_rft_uses_input_dataset_jsonl_when_available(rft_test_harness, monkeypatch): + project = rft_test_harness # Single discovered test test_file = project / "metric" / "test_input_ds.py" test_file.parent.mkdir(parents=True, exist_ok=True) test_file.write_text("# input_dataset case", encoding="utf-8") single_disc = SimpleNamespace(qualname="metric.test_input_ds", file_path=str(test_file)) - monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc]) - - # Environment - monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy") - monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123") - monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") - - # Stub selector, upload, and polling - import eval_protocol.cli_commands.upload as upload_mod - - monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) - monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) - monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) + # New flow uses _discover_and_select_tests; patch it to return our single test. + monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) # Provide JSONL via input_dataset extractor id_jsonl = project / "metric" / "input_ds_out.jsonl" @@ -744,6 +690,10 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d max_context_length=None, chunk_size=None, eval_auto_carveout=None, + skip_validation=True, + ignore_docker=False, + docker_build_extra="", + docker_run_extra="", ) rc = cr.create_rft_command(args) @@ -753,16 +703,9 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d assert captured["jsonl_path"] == str(id_jsonl) -def test_create_rft_quiet_existing_evaluator_infers_dataset_from_matching_test(tmp_path, monkeypatch): +def test_create_rft_quiet_existing_evaluator_infers_dataset_from_matching_test(rft_test_harness, monkeypatch): # Setup project with multiple tests; evaluator exists (skip upload) - project = tmp_path / "proj" - project.mkdir() - monkeypatch.chdir(project) - - # Env - monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy") - monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123") - monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") + project = rft_test_harness # Two tests discovered f1 = project / "evals" / "alpha.py" @@ -814,6 +757,14 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d # Provide evaluator_id that matches beta.test_two eval_id = cr._normalize_evaluator_id("beta-test_two") + # Ensure evaluator_id maps back to the beta test for dataset inference + monkeypatch.setattr( + cr, + "_resolve_selected_test", + lambda project_root, evaluator_id, selected_tests=None: (str(f2), "test_two") + if evaluator_id == eval_id + else (None, None), + ) args = argparse.Namespace( evaluator=eval_id, yes=True, @@ -836,6 +787,10 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d max_context_length=None, chunk_size=None, eval_auto_carveout=None, + skip_validation=True, + ignore_docker=False, + docker_build_extra="", + docker_run_extra="", ) rc = cr.create_rft_command(args) @@ -952,23 +907,17 @@ def _fake_post(url, json=None, headers=None, timeout=None): assert "jobId" not in body -def test_create_rft_prefers_explicit_dataset_jsonl_over_input_dataset(tmp_path, monkeypatch): +def test_create_rft_prefers_explicit_dataset_jsonl_over_input_dataset(rft_test_harness, monkeypatch): # Setup project - project = tmp_path / "proj" - project.mkdir() - monkeypatch.chdir(project) - - # Environment - monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy") - monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123") - monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") + project = rft_test_harness # Single discovered test test_file = project / "metric" / "test_pref.py" test_file.parent.mkdir(parents=True, exist_ok=True) test_file.write_text("# prefer explicit dataset_jsonl", encoding="utf-8") single_disc = SimpleNamespace(qualname="metric.test_pref", file_path=str(test_file)) - monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc]) + # New flow uses _discover_and_select_tests; patch it to return our single test. + monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) # Stub selector, upload, and polling import eval_protocol.cli_commands.upload as upload_mod @@ -1027,6 +976,10 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d max_context_length=None, chunk_size=None, eval_auto_carveout=None, + skip_validation=True, + ignore_docker=False, + docker_build_extra="", + docker_run_extra="", ) rc = cr.create_rft_command(args) From 64c3b7eb9e237fb4d634de23c908f868bbd0d5cb Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 25 Nov 2025 15:04:22 -0800 Subject: [PATCH 2/7] update tests --- tests/test_cli_create_rft_infer.py | 271 ++++++++++++++++++++++++----- 1 file changed, 232 insertions(+), 39 deletions(-) diff --git a/tests/test_cli_create_rft_infer.py b/tests/test_cli_create_rft_infer.py index 9b54099e..c6714402 100644 --- a/tests/test_cli_create_rft_infer.py +++ b/tests/test_cli_create_rft_infer.py @@ -1,12 +1,15 @@ import json import os -import time +import argparse +import requests from types import SimpleNamespace from unittest.mock import patch - import pytest from eval_protocol.cli_commands import create_rft as cr +from eval_protocol.cli_commands import upload as upload_mod +import eval_protocol.fireworks_rft as fr +from eval_protocol.cli import parse_args def _write_json(path: str, data: dict) -> None: @@ -34,9 +37,6 @@ def rft_test_harness(tmp_path, monkeypatch): monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123") monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai") - # Stub selector, upload, polling, and evaluator activation - import eval_protocol.cli_commands.upload as upload_mod - monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) @@ -82,8 +82,6 @@ def _fake_create_job(account_id, api_key, api_base, body): monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", _fake_create_job) - import argparse - args = argparse.Namespace( # Evaluator and dataset evaluator="my-evaluator", @@ -96,6 +94,10 @@ def _fake_create_job(account_id, api_key, api_base, body): dry_run=False, force=False, env_file=None, + skip_validation=True, + ignore_docker=False, + docker_build_extra="--build-extra FLAG", + docker_run_extra="--run-extra FLAG", # Model selection (exactly one) base_model="accounts/fireworks/models/llama-v3p1-8b-instruct", warm_start_from=None, @@ -171,6 +173,229 @@ def _fake_create_job(account_id, api_key, api_base, body): assert wb["runId"] == "run123" assert wb["apiKey"] == "key123" + # The validation / docker flags should not appear in the request body + for k in ("skip_validation", "ignore_docker", "docker_build_extra", "docker_run_extra"): + assert k not in body + + +def test_create_rft_evaluator_validation_fails(rft_test_harness, monkeypatch): + project = rft_test_harness + + # Valid dataset JSONL so dataset validation passes; focus on evaluator validation + ds_path = project / "dataset_valid.jsonl" + ds_path.write_text('{"messages":[{"role":"user","content":"hi"}]}\n', encoding="utf-8") + + # Single discovered test for evaluator resolution + test_file = project / "metric" / "test_eval_validation.py" + test_file.parent.mkdir(parents=True, exist_ok=True) + test_file.write_text("# dummy eval test", encoding="utf-8") + single_disc = SimpleNamespace(qualname="metric.test_eval_validation", file_path=str(test_file)) + monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) + + # Force local evaluator validation to fail + calls = {"count": 0, "pytest_target": None} + + def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_build_extra, docker_run_extra): + calls["count"] += 1 + calls["pytest_target"] = pytest_target + return 1 # non-zero exit code => validation failure + + monkeypatch.setattr(cr, "run_evaluator_test", _fake_run_evaluator_test) + + args = argparse.Namespace( + evaluator=None, + yes=True, + dry_run=True, + force=False, + env_file=None, + dataset=None, + dataset_jsonl=str(ds_path), + dataset_display_name=None, + dataset_builder=None, + base_model="accounts/fireworks/models/llama-v3p1-8b-instruct", + warm_start_from=None, + output_model=None, + n=None, + max_tokens=None, + learning_rate=None, + batch_size=None, + epochs=None, + lora_rank=None, + max_context_length=None, + chunk_size=None, + eval_auto_carveout=None, + skip_validation=False, + ignore_docker=True, + docker_build_extra="", + docker_run_extra="", + ) + + rc = cr.create_rft_command(args) + assert rc == 1 + # Evaluator validation should have been invoked once and failed + assert calls["count"] == 1 + assert isinstance(calls["pytest_target"], str) + assert "test_eval_validation.py::test_eval_validation" in calls["pytest_target"] + + +def test_create_rft_evaluator_validation_passes(rft_test_harness, monkeypatch): + project = rft_test_harness + + # Valid dataset JSONL so dataset validation passes; focus on evaluator validation + ds_path = project / "dataset_valid.jsonl" + ds_path.write_text('{"messages":[{"role":"user","content":"hi"}]}\n', encoding="utf-8") + + # Single discovered test for evaluator resolution + test_file = project / "metric" / "test_eval_ok.py" + test_file.parent.mkdir(parents=True, exist_ok=True) + test_file.write_text("# dummy ok eval test", encoding="utf-8") + single_disc = SimpleNamespace(qualname="metric.test_eval_ok", file_path=str(test_file)) + monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) + + # Force local evaluator validation to succeed + calls = {"count": 0, "pytest_target": None} + + def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_build_extra, docker_run_extra): + calls["count"] += 1 + calls["pytest_target"] = pytest_target + return 0 # success + + monkeypatch.setattr(cr, "run_evaluator_test", _fake_run_evaluator_test) + + args = argparse.Namespace( + evaluator=None, + yes=True, + dry_run=True, + force=False, + env_file=None, + dataset=None, + dataset_jsonl=str(ds_path), + dataset_display_name=None, + dataset_builder=None, + base_model="accounts/fireworks/models/llama-v3p1-8b-instruct", + warm_start_from=None, + output_model=None, + n=None, + max_tokens=None, + learning_rate=None, + batch_size=None, + epochs=None, + lora_rank=None, + max_context_length=None, + chunk_size=None, + eval_auto_carveout=None, + skip_validation=False, + ignore_docker=True, + docker_build_extra="", + docker_run_extra="", + ) + + rc = cr.create_rft_command(args) + assert rc == 0 + # Evaluator validation should have been invoked once and passed + assert calls["count"] == 1 + assert isinstance(calls["pytest_target"], str) + assert "test_eval_ok.py::test_eval_ok" in calls["pytest_target"] + + +def test_create_rft_dataset_validation_fails(rft_test_harness, monkeypatch): + project = rft_test_harness + + # Invalid dataset JSONL (schema mismatch for EvaluationRow) + ds_path = project / "dataset_invalid.jsonl" + ds_path.write_text('{"messages": "not-a-list"}\n', encoding="utf-8") + + # Ensure evaluator validation would pass if reached (so failure is from dataset) + calls = {"evaluator_validation_calls": 0} + + def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_build_extra, docker_run_extra): + calls["evaluator_validation_calls"] += 1 + return 0 + + monkeypatch.setattr(cr, "run_evaluator_test", _fake_run_evaluator_test) + + args = argparse.Namespace( + evaluator="my-evaluator", + yes=True, + dry_run=True, + force=False, + env_file=None, + dataset=None, + dataset_jsonl=str(ds_path), + dataset_display_name=None, + dataset_builder=None, + base_model="accounts/fireworks/models/llama-v3p1-8b-instruct", + warm_start_from=None, + output_model=None, + n=None, + max_tokens=None, + learning_rate=None, + batch_size=None, + epochs=None, + lora_rank=None, + max_context_length=None, + chunk_size=None, + eval_auto_carveout=None, + skip_validation=False, + ignore_docker=True, + docker_build_extra="", + docker_run_extra="", + ) + + rc = cr.create_rft_command(args) + assert rc == 1 + # Dataset validation should fail before evaluator validation is invoked + assert calls["evaluator_validation_calls"] == 0 + + +def test_create_rft_dataset_validation_passes(rft_test_harness, monkeypatch): + project = rft_test_harness + + # Valid dataset JSONL compatible with EvaluationRow + ds_path = project / "dataset_valid_evalrow.jsonl" + ds_path.write_text('{"messages":[{"role":"user","content":"hi"}]}\n', encoding="utf-8") + + # Evaluator validation should run and succeed + calls = {"evaluator_validation_calls": 0} + + def _fake_run_evaluator_test(project_root, pytest_target, ignore_docker, docker_build_extra, docker_run_extra): + calls["evaluator_validation_calls"] += 1 + return 0 + + monkeypatch.setattr(cr, "run_evaluator_test", _fake_run_evaluator_test) + + args = argparse.Namespace( + evaluator="my-evaluator", + yes=True, + dry_run=True, + force=False, + env_file=None, + dataset=None, + dataset_jsonl=str(ds_path), + dataset_display_name=None, + dataset_builder=None, + base_model="accounts/fireworks/models/llama-v3p1-8b-instruct", + warm_start_from=None, + output_model=None, + n=None, + max_tokens=None, + learning_rate=None, + batch_size=None, + epochs=None, + lora_rank=None, + max_context_length=None, + chunk_size=None, + eval_auto_carveout=None, + skip_validation=False, + ignore_docker=True, + docker_build_extra="", + docker_run_extra="", + ) + + rc = cr.create_rft_command(args) + assert rc == 0 + # Dataset validation should pass; evaluator validation may be skipped when no local test is associated + def test_create_rft_picks_most_recent_evaluator_and_dataset_id_follows(rft_test_harness, monkeypatch): project = rft_test_harness @@ -180,10 +405,6 @@ def test_create_rft_picks_most_recent_evaluator_and_dataset_id_follows(rft_test_ ds_path.parent.mkdir(parents=True, exist_ok=True) ds_path.write_text('{"input":"x"}\n', encoding="utf-8") - # Stub out networked/subcommands used by create_rft - # Patch selector and upload - import eval_protocol.cli_commands.upload as upload_mod - # Simulate exactly one discovered test and selector returning it one_file = project / "metric" / "test_single.py" one_file.parent.mkdir(parents=True, exist_ok=True) @@ -272,7 +493,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d ds_path.write_text('{"input":"x"}\n', encoding="utf-8") # Build args: no explicit evaluator id, selector will not be used here; mapping by id - import argparse args = argparse.Namespace( evaluator=cr._normalize_evaluator_id("foo_eval-test_bar_evaluation"), @@ -343,7 +563,6 @@ def test_create_rft_interactive_selector_single_test(rft_test_harness, monkeypat monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"}) # Run without evaluator_id; use --yes so selector returns tests directly (no UI) - import argparse args = argparse.Namespace( evaluator=None, @@ -419,8 +638,6 @@ def raise_for_status(self): ) monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"}) - import argparse - args = argparse.Namespace( evaluator="some-eval", yes=True, @@ -463,8 +680,6 @@ def test_create_rft_quiet_new_evaluator_ambiguous_without_entry_errors(tmp_path, def _raise(*a, **k): raise requests.exceptions.RequestException("nope") - import requests - monkeypatch.setattr(cr.requests, "get", _raise) # Two discovered tests (ambiguous) @@ -476,8 +691,6 @@ def _raise(*a, **k): d2 = SimpleNamespace(qualname="b.test_two", file_path=str(f2)) monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [d1, d2]) - import argparse - args = argparse.Namespace( evaluator="some-eval", yes=True, @@ -539,7 +752,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"}) # Run without dataset inputs so builder path is used - import argparse args = argparse.Namespace( evaluator=None, @@ -602,8 +814,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d monkeypatch.setattr(cr, "create_dataset_from_jsonl", _fake_create_dataset_from_jsonl) monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"}) - import argparse - args = argparse.Namespace( evaluator=None, yes=True, @@ -666,8 +876,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d monkeypatch.setattr(cr, "create_dataset_from_jsonl", _fake_create_dataset_from_jsonl) monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"}) - import argparse - args = argparse.Namespace( evaluator=None, yes=True, @@ -753,8 +961,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d monkeypatch.setattr(cr, "create_dataset_from_jsonl", _fake_create_dataset_from_jsonl) monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"}) - import argparse - # Provide evaluator_id that matches beta.test_two eval_id = cr._normalize_evaluator_id("beta-test_two") # Ensure evaluator_id maps back to the beta test for dataset inference @@ -817,13 +1023,8 @@ def json(self): def raise_for_status(self): return None - from eval_protocol.cli_commands import create_rft as cr - monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp()) - # Capture URL and JSON via fireworks layer - import eval_protocol.fireworks_rft as fr - captured = {"url": None, "json": None} class _RespPost: @@ -839,9 +1040,6 @@ def _fake_post(url, json=None, headers=None, timeout=None): monkeypatch.setattr(fr.requests, "post", _fake_post) - # Build args via CLI parser to validate flag names - from eval_protocol.cli import parse_args - argv = [ "create", "rft", @@ -919,9 +1117,6 @@ def test_create_rft_prefers_explicit_dataset_jsonl_over_input_dataset(rft_test_h # New flow uses _discover_and_select_tests; patch it to return our single test. monkeypatch.setattr(cr, "_discover_and_select_tests", lambda cwd, non_interactive=False: [single_disc]) - # Stub selector, upload, and polling - import eval_protocol.cli_commands.upload as upload_mod - monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) @@ -952,8 +1147,6 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d monkeypatch.setattr(cr, "create_dataset_from_jsonl", _fake_create_dataset_from_jsonl) monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"}) - import argparse - args = argparse.Namespace( evaluator=None, yes=True, From 9f220282f809e4eeb054043762e9811ee190ccef Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 25 Nov 2025 15:20:44 -0800 Subject: [PATCH 3/7] update to separate validation and upload of dataset --- eval_protocol/cli_commands/create_rft.py | 137 +++++++++++++++-------- tests/test_cli_create_rft_infer.py | 28 ++++- tests/test_cli_local_test.py | 3 +- 3 files changed, 120 insertions(+), 48 deletions(-) diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index e277e078..ccc2fe55 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -363,18 +363,21 @@ def _resolve_evaluator( return evaluator_id, evaluator_resource_name, selected_test_file_path, selected_test_func_name -def _resolve_and_prepare_dataset( +def _resolve_dataset( project_root: str, account_id: str, - api_key: str, - api_base: str, evaluator_id: str, args: argparse.Namespace, selected_test_file_path: Optional[str], selected_test_func_name: Optional[str], - dry_run: bool, ) -> tuple[Optional[str], Optional[str], Optional[str]]: - """Resolve dataset id/resource and ensure dataset exists if using JSONL.""" + """Resolve dataset source without performing any uploads. + + Returns a tuple of: + - dataset_id: existing dataset id when using --dataset or fully-qualified dataset resource + - dataset_resource: fully-qualified dataset resource for existing datasets; None for JSONL sources + - dataset_jsonl: local JSONL path when using --dataset-jsonl or inferred sources; None for id-only datasets + """ dataset_id = getattr(args, "dataset", None) dataset_jsonl = getattr(args, "dataset_jsonl", None) dataset_display_name = getattr(args, "dataset_display_name", None) @@ -432,40 +435,72 @@ def _resolve_and_prepare_dataset( ) return None, None, None - inferred_dataset_id = _build_trimmed_dataset_id(evaluator_id) - if dry_run: - print("--dry-run: would create dataset and upload JSONL") - dataset_id = inferred_dataset_id - else: - try: - # Resolve dataset_jsonl path relative to CWD if needed - jsonl_path_for_upload = ( - dataset_jsonl - if os.path.isabs(dataset_jsonl) - else os.path.abspath(os.path.join(project_root, dataset_jsonl)) - ) - dataset_id, _ = create_dataset_from_jsonl( - account_id=account_id, - api_key=api_key, - api_base=api_base, - dataset_id=inferred_dataset_id, - display_name=dataset_display_name or inferred_dataset_id, - jsonl_path=jsonl_path_for_upload, - ) - print(f"āœ“ Created and uploaded dataset: {dataset_id}") - except Exception as e: - print(f"Error creating/uploading dataset: {e}") - return None, None, None - - if not dataset_id: - return None, None, None + # Build dataset resource for existing datasets; JSONL-based datasets will be uploaded later. + dataset_resource = None + if dataset_id: + dataset_resource = dataset_resource_override or f"accounts/{account_id}/datasets/{dataset_id}" - # Build dataset resource (prefer override when provided) - dataset_resource = dataset_resource_override or f"accounts/{account_id}/datasets/{dataset_id}" return dataset_id, dataset_resource, dataset_jsonl -def _ensure_evaluator_active( +def _upload_dataset( + project_root: str, + account_id: str, + api_key: str, + api_base: str, + evaluator_id: str, + dataset_id: Optional[str], + dataset_resource: Optional[str], + dataset_jsonl: Optional[str], + args: argparse.Namespace, + dry_run: bool, +) -> tuple[Optional[str], Optional[str]]: + """Create/upload the dataset when using a local JSONL source. + + For existing datasets (--dataset or fully-qualified ids), this is a no-op that + simply ensures dataset_id and dataset_resource are populated. + """ + # Existing dataset case: nothing to upload + if not dataset_jsonl: + if not dataset_id: + return None, None + if not dataset_resource: + dataset_resource = f"accounts/{account_id}/datasets/{dataset_id}" + return dataset_id, dataset_resource + + # JSONL-based dataset: upload or simulate upload + inferred_dataset_id = _build_trimmed_dataset_id(evaluator_id) + dataset_display_name = getattr(args, "dataset_display_name", None) or inferred_dataset_id + + # Resolve dataset_jsonl path relative to CWD if needed + jsonl_path_for_upload = ( + dataset_jsonl if os.path.isabs(dataset_jsonl) else os.path.abspath(os.path.join(project_root, dataset_jsonl)) + ) + + if dry_run: + print("--dry-run: would create dataset and upload JSONL") + dataset_id = inferred_dataset_id + dataset_resource = f"accounts/{account_id}/datasets/{dataset_id}" + return dataset_id, dataset_resource + + try: + dataset_id, _ = create_dataset_from_jsonl( + account_id=account_id, + api_key=api_key, + api_base=api_base, + dataset_id=inferred_dataset_id, + display_name=dataset_display_name, + jsonl_path=jsonl_path_for_upload, + ) + print(f"āœ“ Created and uploaded dataset: {dataset_id}") + dataset_resource = f"accounts/{account_id}/datasets/{dataset_id}" + return dataset_id, dataset_resource + except Exception as e: + print(f"Error creating/uploading dataset: {e}") + return None, None + + +def _upload_and_ensure_evaluator( project_root: str, evaluator_id: str, evaluator_resource_name: str, @@ -726,19 +761,17 @@ def create_rft_command(args) -> int: if not evaluator_id or not evaluator_resource_name: return 1 - # 2) Resolve dataset (id/resource) and underlying JSONL (if any) - dataset_id, dataset_resource, dataset_jsonl = _resolve_and_prepare_dataset( + # 2) Resolve dataset source (id or JSONL path) + dataset_id, dataset_resource, dataset_jsonl = _resolve_dataset( project_root=project_root, account_id=account_id, - api_key=api_key, - api_base=api_base, evaluator_id=evaluator_id, args=args, selected_test_file_path=selected_test_file_path, selected_test_func_name=selected_test_func_name, - dry_run=dry_run, ) - if not dataset_id or not dataset_resource: + # Require either an existing dataset id or a JSONL source to materialize from + if dataset_jsonl is None and not dataset_id: return 1 # 3) Optional local validation @@ -758,8 +791,24 @@ def create_rft_command(args) -> int: ): return 1 - # 4) Ensure evaluator exists and is ACTIVE (upload + poll if needed) - if not _ensure_evaluator_active( + # 4) Upload dataset when using JSONL sources (no-op for existing datasets) + dataset_id, dataset_resource = _upload_dataset( + project_root=project_root, + account_id=account_id, + api_key=api_key, + api_base=api_base, + evaluator_id=evaluator_id, + dataset_id=dataset_id, + dataset_resource=dataset_resource, + dataset_jsonl=dataset_jsonl, + args=args, + dry_run=dry_run, + ) + if not dataset_id or not dataset_resource: + return 1 + + # 5) Ensure evaluator exists and is ACTIVE (upload + poll if needed) + if not _upload_and_ensure_evaluator( project_root=project_root, evaluator_id=evaluator_id, evaluator_resource_name=evaluator_resource_name, @@ -769,7 +818,7 @@ def create_rft_command(args) -> int: ): return 1 - # 5) Create the RFT job + # 6) Create the RFT job return _create_rft_job( account_id=account_id, api_key=api_key, diff --git a/tests/test_cli_create_rft_infer.py b/tests/test_cli_create_rft_infer.py index c6714402..71f2a064 100644 --- a/tests/test_cli_create_rft_infer.py +++ b/tests/test_cli_create_rft_infer.py @@ -40,7 +40,7 @@ def rft_test_harness(tmp_path, monkeypatch): monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0) monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True) - monkeypatch.setattr(cr, "_ensure_evaluator_active", lambda *a, **k: True) + monkeypatch.setattr(cr, "_upload_and_ensure_evaluator", lambda *a, **k: True) return project @@ -82,6 +82,25 @@ def _fake_create_job(account_id, api_key, api_base, body): monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", _fake_create_job) + # Stub validation helpers: dataset always valid; capture evaluator validation flags + monkeypatch.setattr(cr, "_validate_dataset", lambda dataset_jsonl: True) + flag_calls = {"ignore_docker": None, "docker_build_extra": None, "docker_run_extra": None} + + def _fake_validate_evaluator_locally( + project_root, + selected_test_file, + selected_test_func, + ignore_docker, + docker_build_extra, + docker_run_extra, + ): + flag_calls["ignore_docker"] = ignore_docker + flag_calls["docker_build_extra"] = docker_build_extra + flag_calls["docker_run_extra"] = docker_run_extra + return True + + monkeypatch.setattr(cr, "_validate_evaluator_locally", _fake_validate_evaluator_locally) + args = argparse.Namespace( # Evaluator and dataset evaluator="my-evaluator", @@ -94,7 +113,7 @@ def _fake_create_job(account_id, api_key, api_base, body): dry_run=False, force=False, env_file=None, - skip_validation=True, + skip_validation=False, ignore_docker=False, docker_build_extra="--build-extra FLAG", docker_run_extra="--run-extra FLAG", @@ -177,6 +196,11 @@ def _fake_create_job(account_id, api_key, api_base, body): for k in ("skip_validation", "ignore_docker", "docker_build_extra", "docker_run_extra"): assert k not in body + # But they should be propagated into local evaluator validation + assert flag_calls["ignore_docker"] is False + assert flag_calls["docker_build_extra"] == "--build-extra FLAG" + assert flag_calls["docker_run_extra"] == "--run-extra FLAG" + def test_create_rft_evaluator_validation_fails(rft_test_harness, monkeypatch): project = rft_test_harness diff --git a/tests/test_cli_local_test.py b/tests/test_cli_local_test.py index 6ab0b14e..0be9c2fa 100644 --- a/tests/test_cli_local_test.py +++ b/tests/test_cli_local_test.py @@ -126,8 +126,7 @@ def test_local_test_selector_single_test(tmp_path, monkeypatch): # No entry; force discover + selector disc = SimpleNamespace(qualname="metric.test_sel", file_path=str(test_file)) - monkeypatch.setattr(lt, "_discover_tests", lambda root: [disc]) - monkeypatch.setattr(lt, "_prompt_select", lambda tests, non_interactive=False: tests[:1]) + monkeypatch.setattr(lt, "_discover_and_select_tests", lambda cwd, non_interactive=False: [disc]) monkeypatch.setattr(lt, "_find_dockerfiles", lambda root: []) called = {"host": False} From 83dc904c3652188b2283a0c06547b3a272c0fa65 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 25 Nov 2025 15:39:58 -0800 Subject: [PATCH 4/7] don't validate when looking for candidate --- eval_protocol/cli_commands/create_rft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index ccc2fe55..511633cd 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -153,7 +153,7 @@ def _extract_jsonl_from_input_dataset(test_file_path: str, test_func_name: str) candidate_paths.append(os.path.abspath(os.path.join(os.getcwd(), dataset_path))) for candidate in candidate_paths: - if os.path.isfile(candidate) and _validate_dataset_jsonl(candidate): + if os.path.isfile(candidate): return candidate return None except Exception: From 17896a7e5d20f02b43c5e6781498e7c4bb79b931 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 25 Nov 2025 15:40:50 -0800 Subject: [PATCH 5/7] rename test file --- tests/{test_cli_create_rft_infer.py => test_cli_create_rft.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test_cli_create_rft_infer.py => test_cli_create_rft.py} (100%) diff --git a/tests/test_cli_create_rft_infer.py b/tests/test_cli_create_rft.py similarity index 100% rename from tests/test_cli_create_rft_infer.py rename to tests/test_cli_create_rft.py From 722d3d2c117a7a6262b34d09494290cf645c82fc Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 25 Nov 2025 15:52:14 -0800 Subject: [PATCH 6/7] don't allow dataset id and dataset jsonl --- eval_protocol/cli_commands/create_rft.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index 511633cd..43c11c9b 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -366,7 +366,6 @@ def _resolve_evaluator( def _resolve_dataset( project_root: str, account_id: str, - evaluator_id: str, args: argparse.Namespace, selected_test_file_path: Optional[str], selected_test_func_name: Optional[str], @@ -380,9 +379,15 @@ def _resolve_dataset( """ dataset_id = getattr(args, "dataset", None) dataset_jsonl = getattr(args, "dataset_jsonl", None) - dataset_display_name = getattr(args, "dataset_display_name", None) dataset_resource_override: Optional[str] = None + if dataset_id and dataset_jsonl: + print( + "Error: --dataset and --dataset-jsonl cannot be used together.\n" + " Use --dataset to reference an existing dataset, or --dataset-jsonl to create a new one from JSONL." + ) + return None, None, None + if isinstance(dataset_id, str) and dataset_id.startswith("accounts/"): # Caller passed a fully-qualified dataset; capture it for body and keep only terminal id for printing dataset_resource_override = dataset_id @@ -765,7 +770,6 @@ def create_rft_command(args) -> int: dataset_id, dataset_resource, dataset_jsonl = _resolve_dataset( project_root=project_root, account_id=account_id, - evaluator_id=evaluator_id, args=args, selected_test_file_path=selected_test_file_path, selected_test_func_name=selected_test_func_name, From 4a25e453524e1f48c18b063b9e7ecd8f008c3f88 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 25 Nov 2025 16:03:42 -0800 Subject: [PATCH 7/7] use func_name --- eval_protocol/cli_commands/local_test.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/eval_protocol/cli_commands/local_test.py b/eval_protocol/cli_commands/local_test.py index 545b61fd..704345ac 100644 --- a/eval_protocol/cli_commands/local_test.py +++ b/eval_protocol/cli_commands/local_test.py @@ -165,12 +165,8 @@ def local_test_command(args: argparse.Namespace) -> int: print("Error: Please select exactly one evaluation test for 'local-test'.") return 1 chosen = selected[0] - abs_path = os.path.abspath(chosen.file_path) - try: - rel = os.path.relpath(abs_path, project_root) - except Exception: - rel = abs_path - pytest_target = rel + func_name = chosen.qualname.split(".")[-1] + pytest_target = _build_entry_point(project_root, chosen.file_path, func_name) ignore_docker = bool(getattr(args, "ignore_docker", False)) build_extras_str = getattr(args, "docker_build_extra", "") or ""