diff --git a/README.md b/README.md index 6df0161c..d3800c4d 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,28 @@ Install with pip: pip install eval-protocol ``` +## Fireworks Login (REST) + +Use the CLI to sign in without gRPC. + +``` +# API key flow +eval-protocol login --api-key YOUR_KEY --account-id YOUR_ACCOUNT_ID --validate + +# OAuth2 device flow (like firectl) +eval-protocol login --oauth --issuer https://YOUR_ISSUER --client-id YOUR_PUBLIC_CLIENT_ID \ + --account-id YOUR_ACCOUNT_ID --open-browser +``` + +- Omit `--api-key` to be prompted securely. +- Omit `--account-id` to save only the key; you can add it later. +- Add `--api-base https://api.fireworks.ai` for a custom base, if needed. +- For OAuth2, you can also set env vars: `FIREWORKS_OIDC_ISSUER`, `FIREWORKS_OAUTH_CLIENT_ID`, `FIREWORKS_OAUTH_SCOPE`. + +Credentials are stored at `~/.fireworks/auth.ini` with 600 permissions and are read automatically by the SDK. + +Note: Model/LLM calls still require a Fireworks API key. OAuth login alone does not enable LLM calls yet; ensure `FIREWORKS_API_KEY` is set or saved via `eval-protocol login --api-key ...`. + ## License [MIT](LICENSE) diff --git a/eval_protocol/auth.py b/eval_protocol/auth.py index c90c6aef..adb54a67 100644 --- a/eval_protocol/auth.py +++ b/eval_protocol/auth.py @@ -2,8 +2,11 @@ import logging import os from pathlib import Path +import time from typing import Dict, Optional # Added Dict +import requests + logger = logging.getLogger(__name__) FIREWORKS_CONFIG_DIR = Path.home() / ".fireworks" @@ -36,7 +39,19 @@ def _parse_simple_auth_file(file_path: Path) -> Dict[str, str]: ): value = value[1:-1] - if key in ["api_key", "account_id"] and value: + if key in [ + "api_key", + "account_id", + "api_base", + # OAuth2-related keys + "issuer", + "client_id", + "access_token", + "refresh_token", + "expires_at", + "scope", + "token_type", + ] and value: creds[key] = value except Exception as e: logger.warning(f"Error during simple parsing of {file_path}: {e}") @@ -142,15 +157,135 @@ def get_fireworks_api_base() -> str: """ Retrieves the Fireworks API base URL. - The base URL is sourced from the FIREWORKS_API_BASE environment variable. - If not set, it defaults to "https://api.fireworks.ai". + The base URL is sourced in the following order: + 1. FIREWORKS_API_BASE environment variable. + 2. 'api_base' from the [fireworks] section of ~/.fireworks/auth.ini (or simple key=val). + 3. Defaults to "https://api.fireworks.ai". Returns: The API base URL. """ - api_base = os.environ.get("FIREWORKS_API_BASE", "https://api.fireworks.ai") - if os.environ.get("FIREWORKS_API_BASE"): + env_api_base = os.environ.get("FIREWORKS_API_BASE") + if env_api_base: logger.debug("Using FIREWORKS_API_BASE from environment variable.") - else: - logger.debug(f"FIREWORKS_API_BASE not set in environment, defaulting to {api_base}.") - return api_base + return env_api_base + + file_api_base = _get_credential_from_config_file("api_base") + if file_api_base: + logger.debug("Using api_base from auth.ini configuration.") + return file_api_base + + default_base = "https://api.fireworks.ai" + logger.debug(f"FIREWORKS_API_BASE not set; defaulting to {default_base}.") + return default_base + + +def _get_from_env_or_file(key_name: str) -> Optional[str]: + # 1. Check env + env_val = os.environ.get(key_name.upper()) + if env_val: + return env_val + # 2. Check config file + return _get_credential_from_config_file(key_name.lower()) + + +def _write_auth_config(updates: Dict[str, str]) -> None: + """Merge-write simple key=value pairs into AUTH_INI_FILE preserving existing values.""" + FIREWORKS_CONFIG_DIR.mkdir(parents=True, exist_ok=True) + existing = _parse_simple_auth_file(AUTH_INI_FILE) + existing.update({k: v for k, v in updates.items() if v is not None}) + lines = [f"{k}={v}" for k, v in existing.items()] + AUTH_INI_FILE.write_text("\n".join(lines) + "\n") + try: + os.chmod(AUTH_INI_FILE, 0o600) + except Exception: + pass + + +def _discover_oidc(issuer: str) -> Dict[str, str]: + """Fetch OIDC discovery doc. Returns empty dict on failure.""" + try: + url = issuer.rstrip("/") + "/.well-known/openid-configuration" + resp = requests.get(url, timeout=10) + if resp.ok: + return resp.json() + except Exception: + return {} + return {} + + +def _refresh_oauth_token_if_needed() -> Optional[str]: + """Refresh OAuth access token if expired and refresh token available. Returns current/new token or None.""" + cfg = _parse_simple_auth_file(AUTH_INI_FILE) + access_token = cfg.get("access_token") + refresh_token = cfg.get("refresh_token") + expires_at_str = cfg.get("expires_at") + issuer = cfg.get("issuer") or os.environ.get("FIREWORKS_OIDC_ISSUER") + client_id = cfg.get("client_id") or os.environ.get("FIREWORKS_OAUTH_CLIENT_ID") + + # If we have no expiry, just return access token (best effort) + if not refresh_token or not issuer or not client_id: + return access_token + + now = int(time.time()) + try: + expires_at = int(expires_at_str) if expires_at_str else None + except ValueError: + expires_at = None + + # If not expired (with 60s buffer), return current token + if access_token and expires_at and expires_at - 60 > now: + return access_token + + # Attempt refresh + discovery = _discover_oidc(issuer) + token_endpoint = discovery.get("token_endpoint") or issuer.rstrip("/") + "/oauth/token" + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": client_id, + } + try: + resp = requests.post(token_endpoint, data=data, timeout=15) + if not resp.ok: + logger.warning(f"OAuth token refresh failed: {resp.status_code} {resp.text[:200]}") + return access_token + tok = resp.json() + new_access = tok.get("access_token") + new_refresh = tok.get("refresh_token") or refresh_token + expires_in = tok.get("expires_in") + new_expires_at = str(now + int(expires_in)) if expires_in else expires_at_str + _write_auth_config( + { + "access_token": new_access, + "refresh_token": new_refresh, + "expires_at": new_expires_at, + "token_type": tok.get("token_type") or cfg.get("token_type") or "Bearer", + "scope": tok.get("scope") or cfg.get("scope") or "", + } + ) + return new_access or access_token + except Exception as e: + logger.debug(f"Exception during oauth refresh: {e}") + return access_token + + +def get_auth_bearer() -> Optional[str]: + """Return a bearer token to use in Authorization. + + Priority: + 1. FIREWORKS_ACCESS_TOKEN env + 2. FIREWORKS_API_KEY env + 3. Refreshed OAuth access_token from auth.ini (if present) + 4. api_key from auth.ini + """ + env_access = os.environ.get("FIREWORKS_ACCESS_TOKEN") + if env_access: + return env_access + env_key = os.environ.get("FIREWORKS_API_KEY") + if env_key: + return env_key + refreshed = _refresh_oauth_token_if_needed() + if refreshed: + return refreshed + return _get_credential_from_config_file("api_key") diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py index 24307eb0..bdabe187 100644 --- a/eval_protocol/cli.py +++ b/eval_protocol/cli.py @@ -28,6 +28,7 @@ from .cli_commands.logs import logs_command from .cli_commands.preview import preview_command from .cli_commands.run_eval_cmd import hydra_cli_entry_point +from .cli_commands.login import login_command def parse_args(args=None): @@ -37,6 +38,30 @@ def parse_args(args=None): subparsers = parser.add_subparsers(dest="command", help="Command to run") + # Login command + login_parser = subparsers.add_parser( + "login", help="Sign in to Fireworks via API key or OAuth2 device flow" + ) + # API key flow + login_parser.add_argument("--api-key", help="Fireworks API key (prompted if not provided)") + # OAuth2 flow toggles + login_parser.add_argument("--oauth", action="store_true", help="Use OAuth2 device flow (like firectl)") + login_parser.add_argument("--issuer", help="OIDC issuer URL (e.g., https://auth.fireworks.ai)") + login_parser.add_argument("--client-id", help="OAuth2 public client id for device flow") + login_parser.add_argument( + "--scope", + help="OAuth2 scopes (default: 'openid offline_access email profile')", + ) + login_parser.add_argument( + "--open-browser", action="store_true", help="Attempt to open the verification URL in a browser" + ) + # Common options + login_parser.add_argument("--account-id", help="Fireworks Account ID to associate with this login") + login_parser.add_argument("--api-base", help="Custom API base (defaults to https://api.fireworks.ai)") + vgroup = login_parser.add_mutually_exclusive_group() + vgroup.add_argument("--validate", action="store_true", help="Validate account with a test API call (API key flow)") + vgroup.add_argument("--no-validate", action="store_true", help="Do not validate; just write the file") + # Preview command preview_parser = subparsers.add_parser("preview", help="Preview an evaluator with sample data") preview_parser.add_argument( @@ -338,6 +363,10 @@ def main(): if args.command == "preview": return preview_command(args) + elif args.command == "login": + # translate mutually exclusive group into a single boolean + setattr(args, "validate", bool(getattr(args, "validate", False) and not getattr(args, "no_validate", False))) + return login_command(args) elif args.command == "deploy": return deploy_command(args) elif args.command == "deploy-mcp": diff --git a/eval_protocol/cli_commands/common.py b/eval_protocol/cli_commands/common.py index 1490f704..22387722 100644 --- a/eval_protocol/cli_commands/common.py +++ b/eval_protocol/cli_commands/common.py @@ -7,6 +7,8 @@ import os from typing import Any, Dict, Iterator, List, Optional +from eval_protocol.auth import get_auth_bearer + logger = logging.getLogger(__name__) @@ -42,13 +44,22 @@ def setup_logging(verbose=False, debug=False): def check_environment(): - """Check if required environment variables are set for general commands.""" - if not os.environ.get("FIREWORKS_API_KEY"): - logger.warning("FIREWORKS_API_KEY environment variable is not set.") - logger.warning("This is required for API calls. Set this variable before running the command.") - logger.warning("Example: FIREWORKS_API_KEY=$DEV_FIREWORKS_API_KEY reward-kit [command]") - return False - return True + """Check if credentials are available for non-LLM API calls. + + Accepts either FIREWORKS_API_KEY or an OAuth bearer (FIREWORKS_ACCESS_TOKEN or tokens in auth.ini). + LLM calls elsewhere still explicitly require FIREWORKS_API_KEY. + """ + if os.environ.get("FIREWORKS_API_KEY"): + return True + bearer = get_auth_bearer() + if bearer: + if not os.environ.get("FIREWORKS_API_KEY"): + logger.info( + "Using OAuth bearer for non-LLM API calls. Note: LLM/model calls still require FIREWORKS_API_KEY." + ) + return True + logger.warning("No credentials found. Set FIREWORKS_API_KEY or login via OAuth: eval-protocol login --oauth ...") + return False def check_agent_environment(test_mode=False): diff --git a/eval_protocol/cli_commands/login.py b/eval_protocol/cli_commands/login.py new file mode 100644 index 00000000..793ae317 --- /dev/null +++ b/eval_protocol/cli_commands/login.py @@ -0,0 +1,383 @@ +import getpass +import logging +import os +import time +import webbrowser +from pathlib import Path +from typing import Dict, Optional, Tuple +import secrets +import threading +from http.server import BaseHTTPRequestHandler, HTTPServer +from urllib.parse import urlencode, urlparse, parse_qs + +import requests + +from eval_protocol.auth import ( + AUTH_INI_FILE, + FIREWORKS_CONFIG_DIR, + get_fireworks_api_base, +) + +logger = logging.getLogger(__name__) + + +def _write_auth_file_kv(entries: Dict[str, str]) -> Path: + """Write key=value entries to ~/.fireworks/auth.ini with 600 perms.""" + FIREWORKS_CONFIG_DIR.mkdir(parents=True, exist_ok=True) + # Merge with any existing keys + existing: Dict[str, str] = {} + try: + with open(AUTH_INI_FILE, "r") as f: + for line in f: + line = line.strip() + if not line or line.startswith("#") or line.startswith(";"): + continue + if "=" in line: + k, v = line.split("=", 1) + existing[k.strip()] = v.strip() + except FileNotFoundError: + pass + existing.update({k: v for k, v in entries.items() if v is not None}) + AUTH_INI_FILE.write_text("\n".join([f"{k}={v}" for k, v in existing.items()]) + "\n") + try: + os.chmod(AUTH_INI_FILE, 0o600) + except Exception: + pass + return AUTH_INI_FILE + + +def _validate_account(api_key: str, account_id: str, api_base: Optional[str]) -> bool: + """Validate API key against a specific account id using Fireworks REST API. + + Performs GET /v1/accounts/{account_id}. Returns True on HTTP 200, False otherwise. + """ + base = (api_base or get_fireworks_api_base()).rstrip("/") + url = f"{base}/v1/accounts/{account_id}" + headers = {"Authorization": f"Bearer {api_key}"} + try: + resp = requests.get(url, headers=headers, timeout=10) + if resp.status_code == 200: + logger.info("Successfully validated credentials against Fireworks API.") + return True + else: + logger.warning( + f"Validation failed (status {resp.status_code}). Response: {resp.text[:200]}" + ) + return False + except requests.exceptions.RequestException as e: + logger.warning(f"Network error during validation: {e}") + return False + + +def _discover_oidc(issuer: str) -> Dict[str, str]: + try: + resp = requests.get(issuer.rstrip("/") + "/.well-known/openid-configuration", timeout=10) + if resp.ok: + return resp.json() + except Exception: + return {} + return {} + + +def _oauth_device_flow(issuer: str, client_id: str, scope: str, open_browser: bool) -> Optional[Dict[str, str]]: + """Perform OAuth2 Device Authorization Grant and return token dict {access_token, refresh_token, expires_in, token_type, scope} or None.""" + meta = _discover_oidc(issuer) + device_endpoint = meta.get("device_authorization_endpoint") or issuer.rstrip("/") + "/oauth/device/code" + token_endpoint = meta.get("token_endpoint") or issuer.rstrip("/") + "/oauth/token" + + # 1) Request device code + data = {"client_id": client_id, "scope": scope} + resp = requests.post(device_endpoint, data=data, timeout=15) + if not resp.ok: + logger.error(f"Device code request failed: {resp.status_code} {resp.text[:200]}") + return None + d = resp.json() + device_code = d.get("device_code") + verification_uri = d.get("verification_uri_complete") or d.get("verification_uri") + user_code = d.get("user_code") + interval = int(d.get("interval", 5)) + expires_in = int(d.get("expires_in", 600)) + + if not device_code or not verification_uri: + logger.error("Invalid device authorization response; missing device_code or verification_uri.") + return None + + logger.info("To authorize, visit this URL and enter the code if prompted:") + logger.info(verification_uri) + if user_code: + logger.info(f"User code: {user_code}") + if open_browser: + try: + webbrowser.open(verification_uri) + except Exception: + pass + + # 2) Poll token endpoint + start = time.time() + while True: + if time.time() - start > expires_in: + logger.error("Device code expired before authorization completed.") + return None + time.sleep(interval) + payload = { + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + "device_code": device_code, + "client_id": client_id, + } + t = requests.post(token_endpoint, data=payload, timeout=15) + if t.status_code == 200: + return t.json() + try: + err = t.json().get("error") + except Exception: + err = None + if err in ("authorization_pending", "slow_down"): + if err == "slow_down": + interval += 5 + continue + elif err == "access_denied": + logger.error("Access denied during device authorization.") + return None + else: + logger.warning(f"Unexpected token polling response: {t.status_code} {t.text[:200]}") + continue + + +def _oauth_browser_flow(issuer: str, client_id: str, scope: str) -> Optional[Dict[str, str]]: + """Perform OAuth2 Authorization Code flow using a local redirect server.""" + meta = _discover_oidc(issuer) + auth_endpoint = meta.get("authorization_endpoint") or issuer.rstrip("/") + "/oauth/authorize" + token_endpoint = meta.get("token_endpoint") or issuer.rstrip("/") + "/oauth/token" + + # Start temporary local server + state = secrets.token_urlsafe(24) + code_holder: Dict[str, Optional[str]] = {"code": None, "error": None} + + class Handler(BaseHTTPRequestHandler): + def do_GET(self): # type: ignore + try: + parsed = urlparse(self.path) + if parsed.path != "/": + self.send_error(404) + return + params = parse_qs(parsed.query) + got_state = params.get("state", [""])[0] + if got_state != state: + code_holder["error"] = f"state_mismatch" + elif "error" in params: + code_holder["error"] = params.get("error", [""])[0] + else: + code_holder["code"] = params.get("code", [None])[0] + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.end_headers() + self.wfile.write( + b"

Authenticated

You can close this window." + ) + except Exception: + pass + + def log_message(self, format, *args): # type: ignore + return + + # Bind to an available port + httpd = HTTPServer(("127.0.0.1", 0), Handler) + port = httpd.server_address[1] + redirect_uri = f"http://127.0.0.1:{port}/" + + # Launch server in thread + server_thread = threading.Thread(target=httpd.serve_forever, daemon=True) + server_thread.start() + + # Build auth URL + params = { + "response_type": "code", + "client_id": client_id, + "redirect_uri": redirect_uri, + "scope": scope, + "state": state, + } + auth_url = auth_endpoint + ("?" + urlencode(params)) + + try: + webbrowser.open(auth_url) + except Exception: + logger.info("Could not open a browser automatically. Please open this URL manually:") + logger.info(auth_url) + + # Wait for code up to 180 seconds + deadline = time.time() + 180 + while time.time() < deadline and code_holder["code"] is None and code_holder["error"] is None: + time.sleep(0.2) + + try: + httpd.shutdown() + except Exception: + pass + + if code_holder["error"]: + logger.error(f"OAuth error: {code_holder['error']}") + return None + if not code_holder["code"]: + logger.error("Timed out waiting for OAuth authorization.") + return None + + data = { + "grant_type": "authorization_code", + "code": code_holder["code"], + "redirect_uri": redirect_uri, + "client_id": client_id, + } + t = requests.post(token_endpoint, data=data, timeout=15) + if t.status_code == 200: + return t.json() + logger.error(f"Token exchange failed: {t.status_code} {t.text[:200]}") + return None + + +def _get_oauth_args_via_rest(account_id: Optional[str], api_base: Optional[str]) -> Optional[Dict[str, str]]: + """Try to fetch OAuth issuer/client args from Fireworks public API. + + Tries several likely endpoints; returns a dict with keys issuerUrl, clientId, cognitoDomain if found. + """ + base = (api_base or get_fireworks_api_base()).rstrip("/") + account = account_id or "" + candidates = [] + if account: + candidates.extend( + [ + f"{base}/v1/accounts/{account}:getOAuthArguments", + f"{base}/v1/accounts/{account}/oauth:arguments", + f"{base}/v1/accounts/{account}/oauth/arguments", + ] + ) + candidates.extend([f"{base}/v1/oauth:arguments", f"{base}/v1/oauth/arguments"]) # global + + for url in candidates: + try: + resp = requests.get(url, timeout=10) + if resp.status_code == 200: + data = resp.json() + # Normalize keys + return { + "issuerUrl": data.get("issuerUrl") or data.get("issuer_url"), + "clientId": data.get("clientId") or data.get("client_id"), + "cognitoDomain": data.get("cognitoDomain") or data.get("cognito_domain"), + } + except Exception: + continue + return None + + +def login_command(args) -> int: + """Handle `eval-protocol login` to store Fireworks credentials. + + - Accepts --api-key, --account-id, --api-base + - If --validate and account id provided, calls REST API to verify + - Writes ~/.fireworks/auth.ini (key=value) with 600 perms + """ + # 1) API key flow if explicitly provided + if getattr(args, "api_key", None): + api_key: Optional[str] = getattr(args, "api_key", None) + account_id: Optional[str] = getattr(args, "account_id", None) + api_base: Optional[str] = getattr(args, "api_base", None) + validate: bool = bool(getattr(args, "validate", False)) + if validate and account_id and api_key: + ok = _validate_account(api_key, account_id, api_base) + if not ok: + logger.error("Credential validation failed. Use --no-validate to write anyway.") + return 2 + entries = {"api_key": api_key} + if account_id: + entries["account_id"] = account_id + if api_base: + entries["api_base"] = api_base + path = _write_auth_file_kv(entries) + masked = api_key[:4] + "…" if len(api_key) >= 4 else "***" + logger.info(f"Saved Fireworks credentials to {path}. API key starts with: {masked}.") + if not account_id: + logger.info("No --account-id provided. You can add it later by re-running login.") + if api_base: + logger.info(f"Using custom API base: {api_base}") + return 0 + + # 2) OAuth is the default flow (even if --oauth not passed) + if getattr(args, "oauth", True): + issuer = getattr(args, "issuer", None) or os.environ.get("FIREWORKS_OIDC_ISSUER") + client_id = getattr(args, "client_id", None) or os.environ.get("FIREWORKS_OAUTH_CLIENT_ID") + scope = getattr(args, "scope", None) or os.environ.get("FIREWORKS_OAUTH_SCOPE", "openid offline_access email profile") + api_base: Optional[str] = getattr(args, "api_base", None) + account_id: Optional[str] = getattr(args, "account_id", None) + # If issuer/client not provided, try discovery via public API + if not issuer or not client_id: + discovered = _get_oauth_args_via_rest(account_id, api_base) + if discovered: + issuer = issuer or discovered.get("issuerUrl") + client_id = client_id or discovered.get("clientId") + # cognitoDomain unused here but could be logged + if not issuer or not client_id: + logger.error( + "Unable to discover OAuth issuer/client ID. Provide --issuer and --client-id, or set FIREWORKS_OIDC_ISSUER/FIREWORKS_OAUTH_CLIENT_ID, or use --api-key." + ) + return 1 + + # Try browser flow first; fallback to device flow if it fails + token = _oauth_browser_flow(issuer, client_id, scope) + if not token: + token = _oauth_device_flow(issuer, client_id, scope, open_browser=True) + if not token: + return 2 + now = int(time.time()) + expires_in = token.get("expires_in") + expires_at = str(now + int(expires_in)) if expires_in else "" + entries = { + "issuer": issuer, + "client_id": client_id, + "access_token": token.get("access_token", ""), + "refresh_token": token.get("refresh_token", ""), + "token_type": token.get("token_type", "Bearer"), + "scope": token.get("scope", scope), + "expires_at": expires_at, + } + if api_base: + entries["api_base"] = api_base + if account_id: + entries["account_id"] = account_id + path = _write_auth_file_kv(entries) + logger.info(f"Saved OAuth tokens to {path}.") + # Inform about API key requirement for LLM/model calls + has_env_key = bool(os.environ.get("FIREWORKS_API_KEY")) + has_file_key = False + try: + with open(path, "r") as f: + for line in f: + if line.strip().startswith("api_key=") and line.strip().split("=", 1)[1].strip(): + has_file_key = True + break + except Exception: + pass + if not (has_env_key or has_file_key): + logger.warning( + "No Fireworks API key detected. Model/LLM calls require FIREWORKS_API_KEY. " + "You can add it by re-running: eval-protocol login --api-key YOUR_KEY" + ) + if not account_id: + logger.info("Tip: pass --account-id to store your account for platform API calls.") + return 0 + + # 3) Fallback: prompt for API key if OAuth not selected/failed above + api_key = getpass.getpass(prompt="Enter Fireworks API key: ") + if not api_key: + logger.error("No credentials provided. Aborting login.") + return 1 + entries = {"api_key": api_key} + account_id = getattr(args, "account_id", None) + api_base = getattr(args, "api_base", None) + if account_id: + entries["account_id"] = account_id + if api_base: + entries["api_base"] = api_base + path = _write_auth_file_kv(entries) + masked = api_key[:4] + "…" if len(api_key) >= 4 else "***" + logger.info(f"Saved Fireworks credentials to {path}. API key starts with: {masked}.") + return 0 diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index fe58bb8a..d8e17a5e 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -15,7 +15,7 @@ import requests -from eval_protocol.auth import get_fireworks_account_id, get_fireworks_api_key +from eval_protocol.auth import get_fireworks_account_id, get_auth_bearer from eval_protocol.typed_interface import EvaluationMode logger = logging.getLogger(__name__) @@ -345,7 +345,7 @@ def preview(self, sample_file, max_samples=5): raise ValueError(f"No valid samples found in {sample_file}") account_id = self.account_id or get_fireworks_account_id() - auth_token = self.api_key or get_fireworks_api_key() + auth_token = self.api_key or get_auth_bearer() logger.debug(f"Preview using account_id: {account_id}") if not account_id or not auth_token: @@ -504,7 +504,7 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) raise ValueError("No code files loaded. Load metric folder(s) or provide ts_mode_config/remote_url first.") account_id = self.account_id or get_fireworks_account_id() - auth_token = self.api_key or get_fireworks_api_key() + auth_token = self.api_key or get_auth_bearer() if not auth_token or not account_id: logger.error("Authentication error: API credentials appear to be invalid or incomplete.") raise ValueError("Invalid or missing API credentials.") @@ -762,7 +762,7 @@ def _get_code_from_files(self, files): # This method seems unused now, consider def _get_authentication(self): account_id = get_fireworks_account_id() - auth_token = get_fireworks_api_key() + auth_token = get_auth_bearer() if not account_id: logger.error("Authentication error: Fireworks Account ID not found.") raise ValueError("Fireworks Account ID not found.") diff --git a/eval_protocol/platform_api.py b/eval_protocol/platform_api.py index c5c4d62e..ffd27c02 100644 --- a/eval_protocol/platform_api.py +++ b/eval_protocol/platform_api.py @@ -6,11 +6,7 @@ import requests from dotenv import find_dotenv, load_dotenv -from eval_protocol.auth import ( - get_fireworks_account_id, - get_fireworks_api_base, - get_fireworks_api_key, -) +from eval_protocol.auth import get_fireworks_api_base, get_auth_bearer logger = logging.getLogger(__name__) @@ -74,7 +70,7 @@ def create_or_update_fireworks_secret( Returns: True if successful, False otherwise. """ - resolved_api_key = api_key or get_fireworks_api_key() + resolved_api_key = api_key or get_auth_bearer() resolved_api_base = api_base or get_fireworks_api_base() resolved_account_id = account_id # Must be provided @@ -82,10 +78,7 @@ def create_or_update_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for creating/updating secret.") return False - headers = { - "Authorization": f"Bearer {resolved_api_key}", - "Content-Type": "application/json", - } + headers = {"Authorization": f"Bearer {resolved_api_key}", "Content-Type": "application/json"} # The secret_id for GET/PATCH/DELETE operations is the key_name. # The 'name' field in the gatewaySecret model for POST/PATCH is a bit ambiguous. @@ -200,7 +193,7 @@ def get_fireworks_secret( Note: This typically does not return the secret's actual value for security reasons, but rather its metadata. """ - resolved_api_key = api_key or get_fireworks_api_key() + resolved_api_key = api_key or get_auth_bearer() resolved_api_base = api_base or get_fireworks_api_base() resolved_account_id = account_id @@ -286,6 +279,8 @@ def delete_fireworks_secret( # FIREWORKS_ACCOUNT_ID="your_fireworks_account_id" # FIREWORKS_API_BASE="https://api.fireworks.ai" # or your dev/staging endpoint + from eval_protocol.auth import get_fireworks_account_id, get_fireworks_api_key + test_account_id = get_fireworks_account_id() test_api_key = get_fireworks_api_key() # Not passed directly, functions will resolve test_api_base = get_fireworks_api_base()