diff --git a/README.md b/README.md index 6df0161c..d3800c4d 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,28 @@ Install with pip: pip install eval-protocol ``` +## Fireworks Login (REST) + +Use the CLI to sign in without gRPC. + +``` +# API key flow +eval-protocol login --api-key YOUR_KEY --account-id YOUR_ACCOUNT_ID --validate + +# OAuth2 device flow (like firectl) +eval-protocol login --oauth --issuer https://YOUR_ISSUER --client-id YOUR_PUBLIC_CLIENT_ID \ + --account-id YOUR_ACCOUNT_ID --open-browser +``` + +- Omit `--api-key` to be prompted securely. +- Omit `--account-id` to save only the key; you can add it later. +- Add `--api-base https://api.fireworks.ai` for a custom base, if needed. +- For OAuth2, you can also set env vars: `FIREWORKS_OIDC_ISSUER`, `FIREWORKS_OAUTH_CLIENT_ID`, `FIREWORKS_OAUTH_SCOPE`. + +Credentials are stored at `~/.fireworks/auth.ini` with 600 permissions and are read automatically by the SDK. + +Note: Model/LLM calls still require a Fireworks API key. OAuth login alone does not enable LLM calls yet; ensure `FIREWORKS_API_KEY` is set or saved via `eval-protocol login --api-key ...`. + ## License [MIT](LICENSE) diff --git a/eval_protocol/auth.py b/eval_protocol/auth.py index c90c6aef..adb54a67 100644 --- a/eval_protocol/auth.py +++ b/eval_protocol/auth.py @@ -2,8 +2,11 @@ import logging import os from pathlib import Path +import time from typing import Dict, Optional # Added Dict +import requests + logger = logging.getLogger(__name__) FIREWORKS_CONFIG_DIR = Path.home() / ".fireworks" @@ -36,7 +39,19 @@ def _parse_simple_auth_file(file_path: Path) -> Dict[str, str]: ): value = value[1:-1] - if key in ["api_key", "account_id"] and value: + if key in [ + "api_key", + "account_id", + "api_base", + # OAuth2-related keys + "issuer", + "client_id", + "access_token", + "refresh_token", + "expires_at", + "scope", + "token_type", + ] and value: creds[key] = value except Exception as e: logger.warning(f"Error during simple parsing of {file_path}: {e}") @@ -142,15 +157,135 @@ def get_fireworks_api_base() -> str: """ Retrieves the Fireworks API base URL. - The base URL is sourced from the FIREWORKS_API_BASE environment variable. - If not set, it defaults to "https://api.fireworks.ai". + The base URL is sourced in the following order: + 1. FIREWORKS_API_BASE environment variable. + 2. 'api_base' from the [fireworks] section of ~/.fireworks/auth.ini (or simple key=val). + 3. Defaults to "https://api.fireworks.ai". Returns: The API base URL. """ - api_base = os.environ.get("FIREWORKS_API_BASE", "https://api.fireworks.ai") - if os.environ.get("FIREWORKS_API_BASE"): + env_api_base = os.environ.get("FIREWORKS_API_BASE") + if env_api_base: logger.debug("Using FIREWORKS_API_BASE from environment variable.") - else: - logger.debug(f"FIREWORKS_API_BASE not set in environment, defaulting to {api_base}.") - return api_base + return env_api_base + + file_api_base = _get_credential_from_config_file("api_base") + if file_api_base: + logger.debug("Using api_base from auth.ini configuration.") + return file_api_base + + default_base = "https://api.fireworks.ai" + logger.debug(f"FIREWORKS_API_BASE not set; defaulting to {default_base}.") + return default_base + + +def _get_from_env_or_file(key_name: str) -> Optional[str]: + # 1. Check env + env_val = os.environ.get(key_name.upper()) + if env_val: + return env_val + # 2. Check config file + return _get_credential_from_config_file(key_name.lower()) + + +def _write_auth_config(updates: Dict[str, str]) -> None: + """Merge-write simple key=value pairs into AUTH_INI_FILE preserving existing values.""" + FIREWORKS_CONFIG_DIR.mkdir(parents=True, exist_ok=True) + existing = _parse_simple_auth_file(AUTH_INI_FILE) + existing.update({k: v for k, v in updates.items() if v is not None}) + lines = [f"{k}={v}" for k, v in existing.items()] + AUTH_INI_FILE.write_text("\n".join(lines) + "\n") + try: + os.chmod(AUTH_INI_FILE, 0o600) + except Exception: + pass + + +def _discover_oidc(issuer: str) -> Dict[str, str]: + """Fetch OIDC discovery doc. Returns empty dict on failure.""" + try: + url = issuer.rstrip("/") + "/.well-known/openid-configuration" + resp = requests.get(url, timeout=10) + if resp.ok: + return resp.json() + except Exception: + return {} + return {} + + +def _refresh_oauth_token_if_needed() -> Optional[str]: + """Refresh OAuth access token if expired and refresh token available. Returns current/new token or None.""" + cfg = _parse_simple_auth_file(AUTH_INI_FILE) + access_token = cfg.get("access_token") + refresh_token = cfg.get("refresh_token") + expires_at_str = cfg.get("expires_at") + issuer = cfg.get("issuer") or os.environ.get("FIREWORKS_OIDC_ISSUER") + client_id = cfg.get("client_id") or os.environ.get("FIREWORKS_OAUTH_CLIENT_ID") + + # If we have no expiry, just return access token (best effort) + if not refresh_token or not issuer or not client_id: + return access_token + + now = int(time.time()) + try: + expires_at = int(expires_at_str) if expires_at_str else None + except ValueError: + expires_at = None + + # If not expired (with 60s buffer), return current token + if access_token and expires_at and expires_at - 60 > now: + return access_token + + # Attempt refresh + discovery = _discover_oidc(issuer) + token_endpoint = discovery.get("token_endpoint") or issuer.rstrip("/") + "/oauth/token" + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": client_id, + } + try: + resp = requests.post(token_endpoint, data=data, timeout=15) + if not resp.ok: + logger.warning(f"OAuth token refresh failed: {resp.status_code} {resp.text[:200]}") + return access_token + tok = resp.json() + new_access = tok.get("access_token") + new_refresh = tok.get("refresh_token") or refresh_token + expires_in = tok.get("expires_in") + new_expires_at = str(now + int(expires_in)) if expires_in else expires_at_str + _write_auth_config( + { + "access_token": new_access, + "refresh_token": new_refresh, + "expires_at": new_expires_at, + "token_type": tok.get("token_type") or cfg.get("token_type") or "Bearer", + "scope": tok.get("scope") or cfg.get("scope") or "", + } + ) + return new_access or access_token + except Exception as e: + logger.debug(f"Exception during oauth refresh: {e}") + return access_token + + +def get_auth_bearer() -> Optional[str]: + """Return a bearer token to use in Authorization. + + Priority: + 1. FIREWORKS_ACCESS_TOKEN env + 2. FIREWORKS_API_KEY env + 3. Refreshed OAuth access_token from auth.ini (if present) + 4. api_key from auth.ini + """ + env_access = os.environ.get("FIREWORKS_ACCESS_TOKEN") + if env_access: + return env_access + env_key = os.environ.get("FIREWORKS_API_KEY") + if env_key: + return env_key + refreshed = _refresh_oauth_token_if_needed() + if refreshed: + return refreshed + return _get_credential_from_config_file("api_key") diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py index 24307eb0..bdabe187 100644 --- a/eval_protocol/cli.py +++ b/eval_protocol/cli.py @@ -28,6 +28,7 @@ from .cli_commands.logs import logs_command from .cli_commands.preview import preview_command from .cli_commands.run_eval_cmd import hydra_cli_entry_point +from .cli_commands.login import login_command def parse_args(args=None): @@ -37,6 +38,30 @@ def parse_args(args=None): subparsers = parser.add_subparsers(dest="command", help="Command to run") + # Login command + login_parser = subparsers.add_parser( + "login", help="Sign in to Fireworks via API key or OAuth2 device flow" + ) + # API key flow + login_parser.add_argument("--api-key", help="Fireworks API key (prompted if not provided)") + # OAuth2 flow toggles + login_parser.add_argument("--oauth", action="store_true", help="Use OAuth2 device flow (like firectl)") + login_parser.add_argument("--issuer", help="OIDC issuer URL (e.g., https://auth.fireworks.ai)") + login_parser.add_argument("--client-id", help="OAuth2 public client id for device flow") + login_parser.add_argument( + "--scope", + help="OAuth2 scopes (default: 'openid offline_access email profile')", + ) + login_parser.add_argument( + "--open-browser", action="store_true", help="Attempt to open the verification URL in a browser" + ) + # Common options + login_parser.add_argument("--account-id", help="Fireworks Account ID to associate with this login") + login_parser.add_argument("--api-base", help="Custom API base (defaults to https://api.fireworks.ai)") + vgroup = login_parser.add_mutually_exclusive_group() + vgroup.add_argument("--validate", action="store_true", help="Validate account with a test API call (API key flow)") + vgroup.add_argument("--no-validate", action="store_true", help="Do not validate; just write the file") + # Preview command preview_parser = subparsers.add_parser("preview", help="Preview an evaluator with sample data") preview_parser.add_argument( @@ -338,6 +363,10 @@ def main(): if args.command == "preview": return preview_command(args) + elif args.command == "login": + # translate mutually exclusive group into a single boolean + setattr(args, "validate", bool(getattr(args, "validate", False) and not getattr(args, "no_validate", False))) + return login_command(args) elif args.command == "deploy": return deploy_command(args) elif args.command == "deploy-mcp": diff --git a/eval_protocol/cli_commands/common.py b/eval_protocol/cli_commands/common.py index 1490f704..22387722 100644 --- a/eval_protocol/cli_commands/common.py +++ b/eval_protocol/cli_commands/common.py @@ -7,6 +7,8 @@ import os from typing import Any, Dict, Iterator, List, Optional +from eval_protocol.auth import get_auth_bearer + logger = logging.getLogger(__name__) @@ -42,13 +44,22 @@ def setup_logging(verbose=False, debug=False): def check_environment(): - """Check if required environment variables are set for general commands.""" - if not os.environ.get("FIREWORKS_API_KEY"): - logger.warning("FIREWORKS_API_KEY environment variable is not set.") - logger.warning("This is required for API calls. Set this variable before running the command.") - logger.warning("Example: FIREWORKS_API_KEY=$DEV_FIREWORKS_API_KEY reward-kit [command]") - return False - return True + """Check if credentials are available for non-LLM API calls. + + Accepts either FIREWORKS_API_KEY or an OAuth bearer (FIREWORKS_ACCESS_TOKEN or tokens in auth.ini). + LLM calls elsewhere still explicitly require FIREWORKS_API_KEY. + """ + if os.environ.get("FIREWORKS_API_KEY"): + return True + bearer = get_auth_bearer() + if bearer: + if not os.environ.get("FIREWORKS_API_KEY"): + logger.info( + "Using OAuth bearer for non-LLM API calls. Note: LLM/model calls still require FIREWORKS_API_KEY." + ) + return True + logger.warning("No credentials found. Set FIREWORKS_API_KEY or login via OAuth: eval-protocol login --oauth ...") + return False def check_agent_environment(test_mode=False): diff --git a/eval_protocol/cli_commands/login.py b/eval_protocol/cli_commands/login.py new file mode 100644 index 00000000..793ae317 --- /dev/null +++ b/eval_protocol/cli_commands/login.py @@ -0,0 +1,383 @@ +import getpass +import logging +import os +import time +import webbrowser +from pathlib import Path +from typing import Dict, Optional, Tuple +import secrets +import threading +from http.server import BaseHTTPRequestHandler, HTTPServer +from urllib.parse import urlencode, urlparse, parse_qs + +import requests + +from eval_protocol.auth import ( + AUTH_INI_FILE, + FIREWORKS_CONFIG_DIR, + get_fireworks_api_base, +) + +logger = logging.getLogger(__name__) + + +def _write_auth_file_kv(entries: Dict[str, str]) -> Path: + """Write key=value entries to ~/.fireworks/auth.ini with 600 perms.""" + FIREWORKS_CONFIG_DIR.mkdir(parents=True, exist_ok=True) + # Merge with any existing keys + existing: Dict[str, str] = {} + try: + with open(AUTH_INI_FILE, "r") as f: + for line in f: + line = line.strip() + if not line or line.startswith("#") or line.startswith(";"): + continue + if "=" in line: + k, v = line.split("=", 1) + existing[k.strip()] = v.strip() + except FileNotFoundError: + pass + existing.update({k: v for k, v in entries.items() if v is not None}) + AUTH_INI_FILE.write_text("\n".join([f"{k}={v}" for k, v in existing.items()]) + "\n") + try: + os.chmod(AUTH_INI_FILE, 0o600) + except Exception: + pass + return AUTH_INI_FILE + + +def _validate_account(api_key: str, account_id: str, api_base: Optional[str]) -> bool: + """Validate API key against a specific account id using Fireworks REST API. + + Performs GET /v1/accounts/{account_id}. Returns True on HTTP 200, False otherwise. + """ + base = (api_base or get_fireworks_api_base()).rstrip("/") + url = f"{base}/v1/accounts/{account_id}" + headers = {"Authorization": f"Bearer {api_key}"} + try: + resp = requests.get(url, headers=headers, timeout=10) + if resp.status_code == 200: + logger.info("Successfully validated credentials against Fireworks API.") + return True + else: + logger.warning( + f"Validation failed (status {resp.status_code}). Response: {resp.text[:200]}" + ) + return False + except requests.exceptions.RequestException as e: + logger.warning(f"Network error during validation: {e}") + return False + + +def _discover_oidc(issuer: str) -> Dict[str, str]: + try: + resp = requests.get(issuer.rstrip("/") + "/.well-known/openid-configuration", timeout=10) + if resp.ok: + return resp.json() + except Exception: + return {} + return {} + + +def _oauth_device_flow(issuer: str, client_id: str, scope: str, open_browser: bool) -> Optional[Dict[str, str]]: + """Perform OAuth2 Device Authorization Grant and return token dict {access_token, refresh_token, expires_in, token_type, scope} or None.""" + meta = _discover_oidc(issuer) + device_endpoint = meta.get("device_authorization_endpoint") or issuer.rstrip("/") + "/oauth/device/code" + token_endpoint = meta.get("token_endpoint") or issuer.rstrip("/") + "/oauth/token" + + # 1) Request device code + data = {"client_id": client_id, "scope": scope} + resp = requests.post(device_endpoint, data=data, timeout=15) + if not resp.ok: + logger.error(f"Device code request failed: {resp.status_code} {resp.text[:200]}") + return None + d = resp.json() + device_code = d.get("device_code") + verification_uri = d.get("verification_uri_complete") or d.get("verification_uri") + user_code = d.get("user_code") + interval = int(d.get("interval", 5)) + expires_in = int(d.get("expires_in", 600)) + + if not device_code or not verification_uri: + logger.error("Invalid device authorization response; missing device_code or verification_uri.") + return None + + logger.info("To authorize, visit this URL and enter the code if prompted:") + logger.info(verification_uri) + if user_code: + logger.info(f"User code: {user_code}") + if open_browser: + try: + webbrowser.open(verification_uri) + except Exception: + pass + + # 2) Poll token endpoint + start = time.time() + while True: + if time.time() - start > expires_in: + logger.error("Device code expired before authorization completed.") + return None + time.sleep(interval) + payload = { + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + "device_code": device_code, + "client_id": client_id, + } + t = requests.post(token_endpoint, data=payload, timeout=15) + if t.status_code == 200: + return t.json() + try: + err = t.json().get("error") + except Exception: + err = None + if err in ("authorization_pending", "slow_down"): + if err == "slow_down": + interval += 5 + continue + elif err == "access_denied": + logger.error("Access denied during device authorization.") + return None + else: + logger.warning(f"Unexpected token polling response: {t.status_code} {t.text[:200]}") + continue + + +def _oauth_browser_flow(issuer: str, client_id: str, scope: str) -> Optional[Dict[str, str]]: + """Perform OAuth2 Authorization Code flow using a local redirect server.""" + meta = _discover_oidc(issuer) + auth_endpoint = meta.get("authorization_endpoint") or issuer.rstrip("/") + "/oauth/authorize" + token_endpoint = meta.get("token_endpoint") or issuer.rstrip("/") + "/oauth/token" + + # Start temporary local server + state = secrets.token_urlsafe(24) + code_holder: Dict[str, Optional[str]] = {"code": None, "error": None} + + class Handler(BaseHTTPRequestHandler): + def do_GET(self): # type: ignore + try: + parsed = urlparse(self.path) + if parsed.path != "/": + self.send_error(404) + return + params = parse_qs(parsed.query) + got_state = params.get("state", [""])[0] + if got_state != state: + code_holder["error"] = f"state_mismatch" + elif "error" in params: + code_holder["error"] = params.get("error", [""])[0] + else: + code_holder["code"] = params.get("code", [None])[0] + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.end_headers() + self.wfile.write( + b"