diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py index c8ba4594..e8c0e19e 100644 --- a/eval_protocol/cli.py +++ b/eval_protocol/cli.py @@ -425,6 +425,7 @@ def parse_args(args=None): rft_parser.add_argument("--rft-job-id", help="Specify an explicit RFT job id") rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode") rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending") + rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID") # Run command (for Hydra-based evaluations) # This subparser intentionally defines no arguments itself. diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index e3687869..1f903d4b 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -5,12 +5,15 @@ import argparse from typing import Any, Dict, Optional +import requests + from ..auth import ( get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key, verify_api_key_and_get_account_id, ) +from ..common_utils import get_user_agent from ..fireworks_rft import ( _map_api_host_to_app_host, build_default_output_model, @@ -263,10 +266,72 @@ def _auto_select_evaluator_id(cwd: str) -> Optional[str]: return None +def _poll_evaluator_status( + evaluator_resource_name: str, api_key: str, api_base: str, timeout_minutes: int = 10 +) -> bool: + """ + Poll evaluator status until it becomes ACTIVE or times out. + + Args: + evaluator_resource_name: Full evaluator resource name (e.g., accounts/xxx/evaluators/yyy) + api_key: Fireworks API key + api_base: Fireworks API base URL + timeout_minutes: Maximum time to wait in minutes + + Returns: + True if evaluator becomes ACTIVE, False if timeout or BUILD_FAILED + """ + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "User-Agent": get_user_agent(), + } + + check_url = f"{api_base}/v1/{evaluator_resource_name}" + timeout_seconds = timeout_minutes * 60 + poll_interval = 10 # seconds + start_time = time.time() + + print(f"Polling evaluator status (timeout: {timeout_minutes}m, interval: {poll_interval}s)...") + + while time.time() - start_time < timeout_seconds: + try: + response = requests.get(check_url, headers=headers, timeout=30) + response.raise_for_status() + + evaluator_data = response.json() + state = evaluator_data.get("state", "STATE_UNSPECIFIED") + status = evaluator_data.get("status", "") + + if state == "ACTIVE": + print("✅ Evaluator is ACTIVE and ready!") + return True + elif state == "BUILD_FAILED": + print(f"❌ Evaluator build failed. Status: {status}") + return False + elif state == "BUILDING": + elapsed_minutes = (time.time() - start_time) / 60 + print(f"⏳ Evaluator is still building... ({elapsed_minutes:.1f}m elapsed)") + else: + print(f"⏳ Evaluator state: {state}, status: {status}") + + except requests.exceptions.RequestException as e: + print(f"Warning: Failed to check evaluator status: {e}") + + # Wait before next poll + time.sleep(poll_interval) + + # Timeout reached + elapsed_minutes = (time.time() - start_time) / 60 + print(f"⏰ Timeout after {elapsed_minutes:.1f}m - evaluator is not yet ACTIVE") + return False + + def create_rft_command(args) -> int: evaluator_id: Optional[str] = getattr(args, "evaluator_id", None) non_interactive: bool = bool(getattr(args, "yes", False)) dry_run: bool = bool(getattr(args, "dry_run", False)) + force: bool = bool(getattr(args, "force", False)) api_key = get_fireworks_api_key() if not api_key: @@ -326,12 +391,34 @@ def create_rft_command(args) -> int: id=evaluator_id, display_name=None, description=None, - force=False, + force=force, # Pass through the --force flag yes=True, + env_file=None, # Add the new env_file parameter ) + + if force: + print(f"🔄 Force flag enabled - will overwrite existing evaluator '{evaluator_id}'") + rc = upload_command(upload_args) if rc == 0: print(f"✓ Uploaded/ensured evaluator: {evaluator_id}") + + # Poll for evaluator status + print(f"Waiting for evaluator '{evaluator_id}' to become ACTIVE...") + is_active = _poll_evaluator_status( + evaluator_resource_name=evaluator_resource_name, api_key=api_key, api_base=api_base, timeout_minutes=10 + ) + + if not is_active: + # Print helpful message with dashboard link + app_base = _map_api_host_to_app_host(api_base) + evaluator_slug = _extract_terminal_segment(evaluator_id) + dashboard_url = f"{app_base}/dashboard/evaluators/{evaluator_slug}" + + print("\n❌ Evaluator is not ready within the timeout period.") + print(f"📊 Please check the evaluator status at: {dashboard_url}") + print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.") + return 1 else: print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.") except Exception as e: