diff --git a/src/amp/admin/client.py b/src/amp/admin/client.py index 27a6d61..e6469c0 100644 --- a/src/amp/admin/client.py +++ b/src/amp/admin/client.py @@ -4,6 +4,7 @@ with the Amp Admin API over HTTP. """ +import os from typing import Optional import httpx @@ -19,31 +20,61 @@ class AdminClient: Args: base_url: Base URL for Admin API (e.g., 'http://localhost:8080') - auth_token: Optional Bearer token for authentication + auth_token: Optional Bearer token for authentication (highest priority) + auth: If True, load auth token from ~/.amp/cache (shared with TS CLI) + + Authentication Priority (highest to lowest): + 1. Explicit auth_token parameter + 2. AMP_AUTH_TOKEN environment variable + 3. auth=True - reads from ~/.amp/cache/amp_cli_auth Example: + >>> # Use amp auth from file + >>> client = AdminClient('http://localhost:8080', auth=True) + >>> + >>> # Use manual token + >>> client = AdminClient('http://localhost:8080', auth_token='your-token') + >>> + >>> # Use environment variable + >>> # export AMP_AUTH_TOKEN="eyJhbGci..." >>> client = AdminClient('http://localhost:8080') - >>> datasets = client.datasets.list_all() """ - def __init__(self, base_url: str, auth_token: Optional[str] = None): + def __init__(self, base_url: str, auth_token: Optional[str] = None, auth: bool = False): """Initialize Admin API client. Args: base_url: Base URL for Admin API (e.g., 'http://localhost:8080') auth_token: Optional Bearer token for authentication + auth: If True, load auth token from ~/.amp/cache + + Raises: + ValueError: If both auth=True and auth_token are provided """ + if auth and auth_token: + raise ValueError('Cannot specify both auth=True and auth_token. Choose one authentication method.') + self.base_url = base_url.rstrip('/') - # Build headers - headers = {} + # Resolve auth token provider with priority: explicit param > env var > auth file + self._get_token = None if auth_token: - headers['Authorization'] = f'Bearer {auth_token}' - - # Create HTTP client + # Priority 1: Explicit auth_token parameter (static token) + self._get_token = lambda: auth_token + elif os.getenv('AMP_AUTH_TOKEN'): + # Priority 2: AMP_AUTH_TOKEN environment variable (static token) + env_token = os.getenv('AMP_AUTH_TOKEN') + self._get_token = lambda: env_token + elif auth: + # Priority 3: Load from ~/.amp-cli-config/amp_cli_auth (auto-refreshing) + from amp.auth import AuthService + + auth_service = AuthService() + self._get_token = auth_service.get_token # Callable that auto-refreshes + + # Create HTTP client (no auth header yet - will be added per-request) self._http = httpx.Client( base_url=self.base_url, - headers=headers, timeout=30.0, follow_redirects=True, ) @@ -66,6 +97,12 @@ def _request( Raises: AdminAPIError: If the API returns an error response """ + # Add auth header dynamically (auto-refreshes if needed) + headers = kwargs.get('headers', {}) + if self._get_token: + headers['Authorization'] = f'Bearer {self._get_token()}' + kwargs['headers'] = headers + response = self._http.request(method, path, json=json, params=params, **kwargs) # Handle error responses diff --git a/src/amp/auth/__init__.py b/src/amp/auth/__init__.py new file mode 100644 index 0000000..d648a6f --- /dev/null +++ b/src/amp/auth/__init__.py @@ -0,0 +1,11 @@ +"""Authentication module for amp Python client. + +Provides Privy authentication support compatible with the TypeScript CLI. +Reads and manages auth tokens from ~/.amp/cache. +""" + +from .device_flow import interactive_device_login +from .models import AuthStorage, RefreshTokenResponse +from .service import AuthService + +__all__ = ['AuthService', 'AuthStorage', 'RefreshTokenResponse', 'interactive_device_login'] diff --git a/src/amp/auth/device_flow.py b/src/amp/auth/device_flow.py new file mode 100644 index 0000000..33bb659 --- /dev/null +++ b/src/amp/auth/device_flow.py @@ -0,0 +1,289 @@ +"""OAuth2 Device Authorization Flow for Privy authentication. + +Implements the device authorization grant flow with PKCE for secure authentication. +Matches the TypeScript CLI implementation. +""" + +import base64 +import hashlib +import secrets +import time +import webbrowser +from typing import Optional, Tuple + +import httpx +from pydantic import BaseModel, Field + +from .models import AuthStorage +from .service import AUTH_PLATFORM_URL + + +class DeviceAuthorizationResponse(BaseModel): + """Response from device authorization endpoint.""" + + device_code: str = Field(..., description='Device verification code for polling') + user_code: str = Field(..., description='Code for user to enter in browser') + verification_uri: str = Field(..., description='URL where user enters the code') + expires_in: int = Field(..., description='Seconds until device code expires') + interval: int = Field(..., description='Minimum polling interval in seconds') + + +class DeviceTokenResponse(BaseModel): + """Response from device token endpoint (success case).""" + + access_token: str = Field(..., description='Access token for authenticated requests') + refresh_token: str = Field(..., description='Refresh token for renewing access') + user_id: str = Field(..., description='Authenticated user ID') + user_accounts: list[str] = Field(..., description='List of user accounts/wallets') + expires_in: int = Field(..., description='Seconds until token expires') + + +class DeviceTokenPendingResponse(BaseModel): + """Response when authorization is still pending.""" + + error: str = Field('authorization_pending', description='Error code') + + +class DeviceTokenExpiredResponse(BaseModel): + """Response when device code has expired.""" + + error: str = Field('expired_token', description='Error code') + + +def generate_pkce_pair() -> Tuple[str, str]: + """Generate PKCE code_verifier and code_challenge. + + Returns: + Tuple of (code_verifier, code_challenge) + """ + # Generate cryptographically random code_verifier + # Must be 43-128 characters using unreserved chars [A-Za-z0-9-._~] + code_verifier_bytes = secrets.token_bytes(32) + code_verifier = base64.urlsafe_b64encode(code_verifier_bytes).decode('utf-8').rstrip('=') + + # Generate code_challenge = BASE64URL(SHA256(code_verifier)) + challenge_bytes = hashlib.sha256(code_verifier.encode('utf-8')).digest() + code_challenge = base64.urlsafe_b64encode(challenge_bytes).decode('utf-8').rstrip('=') + + return code_verifier, code_challenge + + +def request_device_authorization(http_client: httpx.Client) -> Tuple[DeviceAuthorizationResponse, str]: + """Request device authorization from auth platform. + + Args: + http_client: HTTP client to use for request + + Returns: + Tuple of (DeviceAuthorizationResponse, code_verifier) + + Raises: + httpx.HTTPStatusError: If request fails + ValueError: If response is invalid + """ + # Generate PKCE parameters + code_verifier, code_challenge = generate_pkce_pair() + + # Request device authorization + url = f'{AUTH_PLATFORM_URL}api/v1/device/authorize' + response = http_client.post( + url, json={'code_challenge': code_challenge, 'code_challenge_method': 'S256'}, timeout=30.0 + ) + + if response.status_code != 200: + raise ValueError(f'Device authorization failed: {response.status_code} - {response.text}') + + device_auth = DeviceAuthorizationResponse.model_validate(response.json()) + return device_auth, code_verifier + + +def poll_for_token(http_client: httpx.Client, device_code: str, code_verifier: str) -> Optional[DeviceTokenResponse]: + """Poll device token endpoint once. + + Args: + http_client: HTTP client to use for request + device_code: Device code from authorization response + code_verifier: PKCE code verifier + + Returns: + DeviceTokenResponse if auth complete, None if still pending + + Raises: + ValueError: If device code expired or other error + """ + url = f'{AUTH_PLATFORM_URL}api/v1/device/token' + response = http_client.get(url, params={'device_code': device_code, 'code_verifier': code_verifier}, timeout=10.0) + + data = response.json() + + # Check for error responses + if 'error' in data: + error = data['error'] + if error == 'authorization_pending': + return None # Still pending + elif error == 'expired_token': + raise ValueError('Device code expired. Please try again.') + else: + raise ValueError(f'Token polling error: {error}') + + # Success - parse token response + return DeviceTokenResponse.model_validate(data) + + +def poll_until_authenticated( + http_client: httpx.Client, + device_code: str, + code_verifier: str, + interval: int, + expires_in: int, + on_poll: Optional[callable] = None, + verbose: bool = False, +) -> DeviceTokenResponse: + """Poll for token until authenticated or timeout. + + Args: + http_client: HTTP client to use for requests + device_code: Device code from authorization + code_verifier: PKCE code verifier + interval: Minimum polling interval in seconds + expires_in: Seconds until device code expires + on_poll: Optional callback called on each poll attempt + + Returns: + DeviceTokenResponse when authentication completes + + Raises: + ValueError: If authentication times out or fails + """ + start_time = time.time() + poll_count = 0 + max_polls = int(expires_in / interval) + 5 # Add some buffer + + while poll_count < max_polls: + elapsed = time.time() - start_time + if elapsed > expires_in: + raise ValueError('Authentication timed out. Please try again.') + + if on_poll: + on_poll(poll_count, elapsed) + + # Poll for token + try: + token_response = poll_for_token(http_client, device_code, code_verifier) + if token_response: + return token_response + except ValueError as e: + if 'expired' in str(e).lower(): + raise + # Other errors, log and continue polling + if verbose: + print(f'\n⚠ Polling error (will retry): {e}') + pass + except Exception as e: + # Log unexpected errors + if verbose: + print(f'\n⚠ Unexpected error (will retry): {type(e).__name__}: {e}') + pass + + # Wait before next poll + time.sleep(interval) + poll_count += 1 + + raise ValueError('Authentication timed out. Please try again.') + + +def open_browser(url: str) -> bool: + """Open URL in browser. + + Args: + url: URL to open + + Returns: + True if browser opened successfully + """ + try: + webbrowser.open(url) + return True + except Exception: + return False + + +def interactive_device_login(verbose: bool = True, auto_open_browser: bool = True) -> AuthStorage: + """Perform interactive device authorization flow. + + Args: + verbose: Print progress messages + auto_open_browser: Automatically open browser for user + + Returns: + AuthStorage with tokens + + Raises: + ValueError: If authentication fails + """ + http_client = httpx.Client() + + try: + # Step 1: Request device authorization + if verbose: + print('🔐 Starting authentication...\n') + + device_auth, code_verifier = request_device_authorization(http_client) + + # Step 2: Display user code and open browser + if verbose: + print(f'📱 Verification Code: {device_auth.user_code}') + print(f'🌐 Verification URL: {device_auth.verification_uri}\n') + + if auto_open_browser: + if verbose: + print('Opening browser...') + if open_browser(device_auth.verification_uri): + if verbose: + print('✓ Browser opened') + else: + if verbose: + print('✗ Could not open browser automatically') + print(f' Please open: {device_auth.verification_uri}') + + if verbose: + print(f'\n⏳ Waiting for authentication (expires in {device_auth.expires_in}s)...') + print(' Complete the authentication in your browser.\n') + + # Step 3: Poll for token + spinner_frames = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'] + + def poll_callback(count: int, elapsed: float): + if verbose: + spinner = spinner_frames[count % len(spinner_frames)] + print(f'\r{spinner} Polling... ({int(elapsed)}s elapsed)', end='', flush=True) + + token_response = poll_until_authenticated( + http_client, + device_auth.device_code, + code_verifier, + device_auth.interval, + device_auth.expires_in, + poll_callback, + verbose=verbose, + ) + + if verbose: + print('\r✓ Authentication successful! \n') + + # Step 4: Create auth storage + now_ms = int(time.time() * 1000) + expiry_ms = now_ms + (token_response.expires_in * 1000) + + auth_storage = AuthStorage( + accessToken=token_response.access_token, + refreshToken=token_response.refresh_token, + userId=token_response.user_id, + accounts=token_response.user_accounts, + expiry=expiry_ms, + ) + + return auth_storage + + finally: + http_client.close() diff --git a/src/amp/auth/models.py b/src/amp/auth/models.py new file mode 100644 index 0000000..d60f7c5 --- /dev/null +++ b/src/amp/auth/models.py @@ -0,0 +1,41 @@ +"""Data models for Privy authentication. + +Models match the TypeScript CLI schema for compatibility. +""" + +from typing import List, Optional + +from pydantic import BaseModel, Field + + +class AuthStorage(BaseModel): + """Auth storage schema for ~/.amp/cache/amp_cli_auth file. + + Matches the TypeScript AuthStorageSchema so they can share auth. + """ + + accessToken: str = Field(..., description='Access token for authenticated requests') + refreshToken: str = Field(..., description='Refresh token for renewing access') + userId: str = Field(..., description="User's Privy DID (format: did:privy:c...)") + accounts: Optional[List[str]] = Field(None, description='List of connected accounts/wallets') + expiry: Optional[int] = Field(None, description='Token expiry timestamp in milliseconds') + + +class RefreshTokenResponseUser(BaseModel): + """User information in the refresh token response.""" + + id: str = Field(..., description='User ID (Privy DID)') + accounts: List[str] = Field(..., description='List of connected accounts/wallets') + + +class RefreshTokenResponse(BaseModel): + """Response from the token refresh endpoint. + + Matches the TypeScript RefreshTokenResponse schema. + """ + + token: str = Field(..., description='The refreshed access token') + refresh_token: Optional[str] = Field(None, description='New refresh token (if rotated)') + session_update_action: str = Field(..., description='Session update action') + expires_in: int = Field(..., description='Seconds until token expires') + user: RefreshTokenResponseUser = Field(..., description='User information') diff --git a/src/amp/auth/service.py b/src/amp/auth/service.py new file mode 100644 index 0000000..b0c83b0 --- /dev/null +++ b/src/amp/auth/service.py @@ -0,0 +1,254 @@ +"""Auth service for managing Privy authentication. + +Handles loading, refreshing, and persisting auth tokens from ~/.amp/cache. +Compatible with the TypeScript CLI authentication system. +""" + +import json +import time +from pathlib import Path +from typing import Optional + +import httpx + +from .models import AuthStorage, RefreshTokenResponse + +# Auth platform URL (matches TypeScript implementation) +AUTH_PLATFORM_URL = 'https://auth.amp.thegraph.com/' + +# Storage location (matches TypeScript implementation) +# TypeScript CLI uses: ~/.amp/cache/amp_cli_auth (directory with file inside) +AUTH_CONFIG_DIR = Path.home() / '.amp' / 'cache' +AUTH_CONFIG_FILE = AUTH_CONFIG_DIR / 'amp_cli_auth' + + +class AuthService: + """Service for managing Privy authentication tokens. + + Loads tokens from ~/.amp/cache (shared with TypeScript CLI), + automatically refreshes expired tokens, and persists updates. + + Example: + >>> auth = AuthService() + >>> if auth.is_authenticated(): + ... token = auth.get_token() # Auto-refreshes if needed + """ + + def __init__(self, config_path: Optional[Path] = None): + """Initialize auth service. + + Args: + config_path: Optional custom path to config file (defaults to ~/.amp/cache/amp_cli_auth) + """ + self.config_path = config_path or AUTH_CONFIG_FILE + self._http = httpx.Client(timeout=30.0) + + def is_authenticated(self) -> bool: + """Check if user is authenticated. + + Returns: + True if valid auth exists in ~/.amp/cache + """ + try: + auth = self.load_auth() + return auth is not None + except Exception: + return False + + def load_auth(self) -> Optional[AuthStorage]: + """Load auth from ~/.amp/cache/amp_cli_auth file. + + Returns: + AuthStorage if found, None if not authenticated + + Raises: + FileNotFoundError: If config file doesn't exist + json.JSONDecodeError: If config file is invalid JSON + ValueError: If auth data is invalid + """ + if not self.config_path.exists(): + return None + + with open(self.config_path) as f: + auth_data = json.load(f) + + return AuthStorage.model_validate(auth_data) + + def save_auth(self, auth: AuthStorage) -> None: + """Save auth to ~/.amp/cache/amp_cli_auth file. + + Args: + auth: Auth data to persist + + Raises: + IOError: If unable to write to config file + """ + # Ensure directory exists + self.config_path.parent.mkdir(parents=True, exist_ok=True) + + # Write auth data directly to file (no wrapper object) + with open(self.config_path, 'w') as f: + json.dump(auth.model_dump(mode='json', exclude_none=False), f, indent=2) + + def get_token(self) -> str: + """Get valid access token, refreshing if needed. + + Returns: + Valid access token string + + Raises: + FileNotFoundError: If not authenticated (no ~/.amp/cache) + ValueError: If auth data is invalid or refresh fails + """ + auth = self.load_auth() + if not auth: + raise FileNotFoundError( + 'Not authenticated. Please run authentication first.\n' + "Use the TypeScript CLI 'amp auth login' or authenticate via the Python client." + ) + + # Check if we need to refresh the token + needs_refresh = self._needs_refresh(auth) + + if needs_refresh: + auth = self.refresh_token(auth) + + return auth.accessToken + + def _needs_refresh(self, auth: AuthStorage) -> bool: + """Check if token needs to be refreshed. + + Token needs refresh if: + - Missing expiry field (old format, need to refresh to populate) + - Missing accounts field (old format, need to refresh to populate) + - Token is expired + - Token is expiring within 5 minutes + + Args: + auth: Auth storage to check + + Returns: + True if token needs refresh + """ + # Missing expiry or accounts - refresh to populate + if auth.expiry is None or auth.accounts is None: + return True + + # Get current time in milliseconds (matching TypeScript) + now_ms = int(time.time() * 1000) + + # Token is expired + if auth.expiry < now_ms: + return True + + # Token is expiring within 5 minutes + five_minutes_ms = 5 * 60 * 1000 + if auth.expiry - now_ms <= five_minutes_ms: + return True + + return False + + def refresh_token(self, auth: AuthStorage) -> AuthStorage: + """Refresh an expired or expiring access token. + + Args: + auth: Current auth storage with refresh token + + Returns: + Updated auth storage with new tokens + + Raises: + httpx.HTTPStatusError: If refresh request fails + ValueError: If response validation fails or user ID mismatch + """ + # Build refresh request (matches TypeScript implementation) + url = f'{AUTH_PLATFORM_URL}api/v1/auth/refresh' + headers = { + 'Authorization': f'Bearer {auth.accessToken}', + 'Content-Type': 'application/json', + 'Accept': 'application/json', + } + body = {'refresh_token': auth.refreshToken, 'user_id': auth.userId} + + # Make request + response = self._http.post(url, headers=headers, json=body) + + # Handle errors + if response.status_code == 401 or response.status_code == 403: + raise ValueError('Token refresh failed: Authentication expired. Please log in again.') + + if response.status_code == 429: + retry_after = response.headers.get('retry-after', '60') + raise ValueError(f'Token refresh rate limited. Retry after {retry_after} seconds.') + + if response.status_code != 200: + error_msg = 'Failed to refresh token' + try: + error_data = response.json() + if 'error_description' in error_data: + error_msg = error_data['error_description'] + except Exception: + pass + raise ValueError(f'{error_msg} (status: {response.status_code})') + + # Parse response + refresh_response = RefreshTokenResponse.model_validate(response.json()) + + # Validate user ID matches (security check) + if refresh_response.user.id != auth.userId: + raise ValueError(f'User ID mismatch after refresh. Expected {auth.userId}, got {refresh_response.user.id}') + + # Calculate new expiry + now_ms = int(time.time() * 1000) + expiry_ms = now_ms + (refresh_response.expires_in * 1000) + + # Create updated auth storage + refreshed_auth = AuthStorage( + accessToken=refresh_response.token, + refreshToken=refresh_response.refresh_token or auth.refreshToken, + userId=refresh_response.user.id, + accounts=refresh_response.user.accounts, + expiry=expiry_ms, + ) + + # Persist updated tokens + self.save_auth(refreshed_auth) + + return refreshed_auth + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self._http.close() + + def login(self, verbose: bool = True, auto_open_browser: bool = True) -> None: + """Perform interactive browser-based login. + + Opens browser for OAuth2 device authorization flow with PKCE. + Saves authentication tokens to ~/.amp/cache/amp_cli_auth. + + Args: + verbose: Print progress messages + auto_open_browser: Automatically open browser + + Raises: + ValueError: If authentication fails + + Example: + >>> auth = AuthService() + >>> auth.login() # Opens browser for authentication + >>> # Auth tokens saved to ~/.amp/cache/amp_cli_auth + """ + from .device_flow import interactive_device_login + + # Perform device authorization flow + auth_storage = interactive_device_login(verbose=verbose, auto_open_browser=auto_open_browser) + + # Save to config file + self.save_auth(auth_storage) + + if verbose: + print(f'✓ Authentication saved to {self.config_path}') diff --git a/src/amp/client.py b/src/amp/client.py index 2af91ee..760bfc7 100644 --- a/src/amp/client.py +++ b/src/amp/client.py @@ -1,9 +1,11 @@ import logging +import os from typing import Dict, Iterator, List, Optional, Union import pyarrow as pa from google.protobuf.any_pb2 import Any from pyarrow import flight +from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory from . import FlightSql_pb2 from .config.connection_manager import ConnectionManager @@ -19,6 +21,38 @@ ) +class AuthMiddleware(ClientMiddleware): + """Flight middleware to add Bearer token authentication header.""" + + def __init__(self, get_token): + """Initialize auth middleware. + + Args: + get_token: Callable that returns the current access token + """ + self.get_token = get_token + + def sending_headers(self): + """Add Authorization header to outgoing requests.""" + return {'authorization': f'Bearer {self.get_token()}'} + + +class AuthMiddlewareFactory(ClientMiddlewareFactory): + """Factory for creating auth middleware instances.""" + + def __init__(self, get_token): + """Initialize auth middleware factory. + + Args: + get_token: Callable that returns the current access token + """ + self.get_token = get_token + + def start_call(self, info): + """Create auth middleware for each call.""" + return AuthMiddleware(self.get_token) + + class QueryBuilder: """Chainable query builder for data loading operations. @@ -237,17 +271,28 @@ class Client: url: Flight SQL URL (for backward compatibility, treated as query_url) query_url: Query endpoint URL via Flight SQL (e.g., 'grpc://localhost:1602') admin_url: Optional Admin API URL (e.g., 'http://localhost:8080') - auth_token: Optional Bearer token for Admin API authentication + auth_token: Optional Bearer token for authentication (highest priority) + auth: If True, load auth token from ~/.amp/cache (shared with TS CLI) + + Authentication Priority (highest to lowest): + 1. Explicit auth_token parameter + 2. AMP_AUTH_TOKEN environment variable + 3. auth=True - reads from ~/.amp/cache/amp_cli_auth Example: >>> # Query-only client (backward compatible) >>> client = Client(url='grpc://localhost:1602') >>> - >>> # Client with admin capabilities + >>> # Client with amp auth from file >>> client = Client( ... query_url='grpc://localhost:1602', - ... admin_url='http://localhost:8080' + ... admin_url='http://localhost:8080', + ... auth=True ... ) + >>> + >>> # Client with auth from environment variable + >>> # export AMP_AUTH_TOKEN="eyJhbGci..." + >>> client = Client(query_url='grpc://localhost:1602') """ def __init__( @@ -256,14 +301,39 @@ def __init__( query_url: Optional[str] = None, admin_url: Optional[str] = None, auth_token: Optional[str] = None, + auth: bool = False, ): # Backward compatibility: url parameter → query_url if url and not query_url: query_url = url + # Resolve auth token provider with priority: explicit param > env var > auth file + get_token = None + if auth_token: + # Priority 1: Explicit auth_token parameter (static token) + def get_token(): + return auth_token + elif os.getenv('AMP_AUTH_TOKEN'): + # Priority 2: AMP_AUTH_TOKEN environment variable (static token) + env_token = os.getenv('AMP_AUTH_TOKEN') + + def get_token(): + return env_token + elif auth: + # Priority 3: Load from ~/.amp/cache/amp_cli_auth (auto-refreshing) + from amp.auth import AuthService + + auth_service = AuthService() + get_token = auth_service.get_token # Callable that auto-refreshes + # Initialize Flight SQL client if query_url: - self.conn = flight.connect(query_url) + # Add auth middleware if token provider exists + if get_token: + middleware = [AuthMiddlewareFactory(get_token)] + self.conn = flight.connect(query_url, middleware=middleware) + else: + self.conn = flight.connect(query_url) else: raise ValueError('Either url or query_url must be provided for Flight SQL connection') @@ -276,7 +346,18 @@ def __init__( if admin_url: from amp.admin.client import AdminClient - self._admin_client = AdminClient(admin_url, auth_token) + # Pass auth=True if we have a get_token callable from auth file + # Otherwise pass the static token if available + if auth: + # Use auth file (auto-refreshing) + self._admin_client = AdminClient(admin_url, auth=True) + elif auth_token or os.getenv('AMP_AUTH_TOKEN'): + # Use static token + static_token = auth_token or os.getenv('AMP_AUTH_TOKEN') + self._admin_client = AdminClient(admin_url, auth_token=static_token) + else: + # No auth + self._admin_client = AdminClient(admin_url) else: self._admin_client = None diff --git a/tests/integration/admin/test_admin_client.py b/tests/integration/admin/test_admin_client.py index f8e7c31..82e7648 100644 --- a/tests/integration/admin/test_admin_client.py +++ b/tests/integration/admin/test_admin_client.py @@ -25,8 +25,7 @@ def test_admin_client_with_auth_token(self): """Test AdminClient with authentication token.""" client = AdminClient('http://localhost:8080', auth_token='test-token') - assert 'Authorization' in client._http.headers - assert client._http.headers['Authorization'] == 'Bearer test-token' + assert client._get_token() == 'test-token' @respx.mock def test_request_success(self): diff --git a/tests/integration/test_snowflake_loader.py b/tests/integration/test_snowflake_loader.py index d13f8eb..78c2c17 100644 --- a/tests/integration/test_snowflake_loader.py +++ b/tests/integration/test_snowflake_loader.py @@ -67,7 +67,7 @@ def wait_for_snowpipe_data(loader, table_name, expected_count, max_wait=30, poll # Skip all Snowflake tests -# pytestmark = pytest.mark.skip(reason='Requires active Snowflake account - see module docstring for details') +pytestmark = pytest.mark.skip(reason='Requires active Snowflake account - see module docstring for details') @pytest.fixture diff --git a/tests/unit/test_auth_service.py b/tests/unit/test_auth_service.py new file mode 100644 index 0000000..450a6e3 --- /dev/null +++ b/tests/unit/test_auth_service.py @@ -0,0 +1,345 @@ +"""Unit tests for Privy authentication service.""" + +import json +import time +from unittest.mock import Mock, patch + +import pytest + +from src.amp.auth.models import AuthStorage +from src.amp.auth.service import AuthService + + +@pytest.mark.unit +class TestAuthService: + """Test AuthService functionality.""" + + def test_is_authenticated_when_no_file(self, tmp_path): + """Test is_authenticated returns False when config file doesn't exist.""" + config_path = tmp_path / 'nonexistent.json' + auth = AuthService(config_path=config_path) + + assert auth.is_authenticated() is False + + def test_is_authenticated_when_no_auth_data(self, tmp_path): + """Test is_authenticated returns False when config exists but has no auth data.""" + config_path = tmp_path / 'config.json' + config_path.write_text('{}') + + auth = AuthService(config_path=config_path) + + assert auth.is_authenticated() is False + + def test_is_authenticated_when_auth_exists(self, tmp_path): + """Test is_authenticated returns True when valid auth exists.""" + config_path = tmp_path / 'amp_cli_auth' + auth_data = { + 'accessToken': 'test-access-token', + 'refreshToken': 'test-refresh-token', + 'userId': 'did:privy:test123', + 'accounts': ['0x123'], + 'expiry': int(time.time() * 1000) + 3600000, # 1 hour from now + } + config_path.write_text(json.dumps(auth_data)) + + auth = AuthService(config_path=config_path) + + assert auth.is_authenticated() is True + + def test_load_auth_success(self, tmp_path): + """Test loading auth from config file.""" + config_path = tmp_path / 'amp_cli_auth' + auth_data = { + 'accessToken': 'test-access-token', + 'refreshToken': 'test-refresh-token', + 'userId': 'did:privy:test123', + 'accounts': ['0x123'], + 'expiry': 1234567890000, + } + config_path.write_text(json.dumps(auth_data)) + + auth = AuthService(config_path=config_path) + result = auth.load_auth() + + assert result is not None + assert result.accessToken == 'test-access-token' + assert result.refreshToken == 'test-refresh-token' + assert result.userId == 'did:privy:test123' + assert result.accounts == ['0x123'] + assert result.expiry == 1234567890000 + + def test_load_auth_returns_none_when_missing(self, tmp_path): + """Test load_auth returns None when config doesn't exist.""" + config_path = tmp_path / 'nonexistent.json' + auth = AuthService(config_path=config_path) + + result = auth.load_auth() + + assert result is None + + def test_save_auth(self, tmp_path): + """Test saving auth to config file.""" + config_path = tmp_path / 'amp_cli_auth' + auth = AuthService(config_path=config_path) + + auth_storage = AuthStorage( + accessToken='new-access-token', + refreshToken='new-refresh-token', + userId='did:privy:test456', + accounts=['0x456'], + expiry=9876543210000, + ) + + auth.save_auth(auth_storage) + + # Verify file was created + assert config_path.exists() + + # Verify content + with open(config_path) as f: + saved_data = json.load(f) + + assert saved_data['accessToken'] == 'new-access-token' + assert saved_data['userId'] == 'did:privy:test456' + + def test_save_auth_overwrites_existing(self, tmp_path): + """Test saving auth overwrites existing auth data.""" + config_path = tmp_path / 'amp_cli_auth' + initial_data = { + 'accessToken': 'old-token', + 'refreshToken': 'old-refresh', + 'userId': 'did:privy:old', + 'accounts': [], + 'expiry': 12345, + } + config_path.write_text(json.dumps(initial_data)) + + auth = AuthService(config_path=config_path) + auth_storage = AuthStorage( + accessToken='new-token', + refreshToken='new-refresh', + userId='did:privy:new', + accounts=['0x123'], + expiry=67890, + ) + + auth.save_auth(auth_storage) + + # Verify new data saved + with open(config_path) as f: + saved_data = json.load(f) + + assert saved_data['accessToken'] == 'new-token' + assert saved_data['userId'] == 'did:privy:new' + assert saved_data['accounts'] == ['0x123'] + + def test_get_token_raises_when_not_authenticated(self, tmp_path): + """Test get_token raises error when not authenticated.""" + config_path = tmp_path / 'nonexistent.json' + auth = AuthService(config_path=config_path) + + with pytest.raises(FileNotFoundError, match='Not authenticated'): + auth.get_token() + + def test_get_token_returns_valid_token(self, tmp_path): + """Test get_token returns token when valid and not expired.""" + config_path = tmp_path / 'amp_cli_auth' + future_expiry = int(time.time() * 1000) + 3600000 # 1 hour from now + auth_data = { + 'accessToken': 'valid-token', + 'refreshToken': 'refresh-token', + 'userId': 'did:privy:test', + 'accounts': ['0x123'], + 'expiry': future_expiry, + } + config_path.write_text(json.dumps(auth_data)) + + auth = AuthService(config_path=config_path) + token = auth.get_token() + + assert token == 'valid-token' + + +@pytest.mark.unit +class TestAuthServiceRefresh: + """Test token refresh functionality.""" + + def test_needs_refresh_when_missing_expiry(self): + """Test needs refresh when expiry field is missing.""" + auth = AuthService() + auth_storage = AuthStorage( + accessToken='token', + refreshToken='refresh', + userId='did:privy:test', + accounts=['0x123'], + expiry=None, # Missing expiry + ) + + assert auth._needs_refresh(auth_storage) is True + + def test_needs_refresh_when_missing_accounts(self): + """Test needs refresh when accounts field is missing.""" + auth = AuthService() + auth_storage = AuthStorage( + accessToken='token', + refreshToken='refresh', + userId='did:privy:test', + accounts=None, # Missing accounts + expiry=int(time.time() * 1000) + 3600000, + ) + + assert auth._needs_refresh(auth_storage) is True + + def test_needs_refresh_when_expired(self): + """Test needs refresh when token is expired.""" + auth = AuthService() + past_expiry = int(time.time() * 1000) - 1000 # 1 second ago + auth_storage = AuthStorage( + accessToken='token', + refreshToken='refresh', + userId='did:privy:test', + accounts=['0x123'], + expiry=past_expiry, + ) + + assert auth._needs_refresh(auth_storage) is True + + def test_needs_refresh_when_expiring_soon(self): + """Test needs refresh when token expires within 5 minutes.""" + auth = AuthService() + soon_expiry = int(time.time() * 1000) + (4 * 60 * 1000) # 4 minutes from now + auth_storage = AuthStorage( + accessToken='token', + refreshToken='refresh', + userId='did:privy:test', + accounts=['0x123'], + expiry=soon_expiry, + ) + + assert auth._needs_refresh(auth_storage) is True + + def test_does_not_need_refresh_when_valid(self): + """Test does not need refresh when token is valid and not expiring soon.""" + auth = AuthService() + future_expiry = int(time.time() * 1000) + (10 * 60 * 1000) # 10 minutes from now + auth_storage = AuthStorage( + accessToken='token', + refreshToken='refresh', + userId='did:privy:test', + accounts=['0x123'], + expiry=future_expiry, + ) + + assert auth._needs_refresh(auth_storage) is False + + @patch('httpx.Client.post') + def test_refresh_token_success(self, mock_post, tmp_path): + """Test successful token refresh.""" + config_path = tmp_path / 'config.json' + + # Mock HTTP response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'token': 'new-access-token', + 'refresh_token': 'new-refresh-token', + 'session_update_action': 'update', + 'expires_in': 3600, + 'user': {'id': 'did:privy:test', 'accounts': ['0x123', '0x456']}, + } + mock_post.return_value = mock_response + + auth = AuthService(config_path=config_path) + old_auth = AuthStorage( + accessToken='old-token', + refreshToken='old-refresh', + userId='did:privy:test', + accounts=['0x123'], + expiry=12345, + ) + + new_auth = auth.refresh_token(old_auth) + + # Verify new tokens + assert new_auth.accessToken == 'new-access-token' + assert new_auth.refreshToken == 'new-refresh-token' + assert new_auth.userId == 'did:privy:test' + assert new_auth.accounts == ['0x123', '0x456'] + assert new_auth.expiry > int(time.time() * 1000) + + # Verify saved to file + assert config_path.exists() + + @patch('httpx.Client.post') + def test_refresh_token_user_id_mismatch(self, mock_post, tmp_path): + """Test refresh fails when user ID doesn't match.""" + config_path = tmp_path / 'config.json' + + # Mock HTTP response with different user ID + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + 'token': 'new-token', + 'refresh_token': 'new-refresh', + 'session_update_action': 'update', + 'expires_in': 3600, + 'user': {'id': 'did:privy:different-user', 'accounts': []}, + } + mock_post.return_value = mock_response + + auth = AuthService(config_path=config_path) + old_auth = AuthStorage( + accessToken='old-token', + refreshToken='old-refresh', + userId='did:privy:original-user', + accounts=[], + expiry=12345, + ) + + with pytest.raises(ValueError, match='User ID mismatch'): + auth.refresh_token(old_auth) + + @patch('httpx.Client.post') + def test_refresh_token_401_error(self, mock_post, tmp_path): + """Test refresh handles 401 authentication error.""" + config_path = tmp_path / 'config.json' + + # Mock 401 response + mock_response = Mock() + mock_response.status_code = 401 + mock_post.return_value = mock_response + + auth = AuthService(config_path=config_path) + old_auth = AuthStorage( + accessToken='old-token', + refreshToken='old-refresh', + userId='did:privy:test', + accounts=[], + expiry=12345, + ) + + with pytest.raises(ValueError, match='Authentication expired'): + auth.refresh_token(old_auth) + + @patch('httpx.Client.post') + def test_refresh_token_429_rate_limit(self, mock_post, tmp_path): + """Test refresh handles 429 rate limit error.""" + config_path = tmp_path / 'config.json' + + # Mock 429 response + mock_response = Mock() + mock_response.status_code = 429 + mock_response.headers = {'retry-after': '120'} + mock_post.return_value = mock_response + + auth = AuthService(config_path=config_path) + old_auth = AuthStorage( + accessToken='old-token', + refreshToken='old-refresh', + userId='did:privy:test', + accounts=[], + expiry=12345, + ) + + with pytest.raises(ValueError, match='rate limited'): + auth.refresh_token(old_auth) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 2a0c04c..00b0488 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -7,7 +7,7 @@ import json from pathlib import Path -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest @@ -87,6 +87,154 @@ def test_client_requires_url_or_query_url(self): Client() +@pytest.mark.unit +class TestClientAuthPriority: + """Test Client authentication priority (explicit token > env var > auth file)""" + + @patch('amp.client.os.getenv') + @patch('amp.client.flight.connect') + def test_explicit_token_highest_priority(self, mock_connect, mock_getenv): + """Test that explicit auth_token parameter has highest priority""" + mock_getenv.return_value = 'env-var-token' + + Client(query_url='grpc://localhost:1602', auth_token='explicit-token') + + # Verify that explicit token was used (not env var) + mock_connect.assert_called_once() + call_args = mock_connect.call_args + middleware = call_args[1].get('middleware', []) + assert len(middleware) == 1 + assert middleware[0].get_token() == 'explicit-token' + + @patch('amp.client.os.getenv') + @patch('amp.client.flight.connect') + def test_env_var_second_priority(self, mock_connect, mock_getenv): + """Test that AMP_AUTH_TOKEN env var has second priority""" + + # Return 'env-var-token' for AMP_AUTH_TOKEN, None for others + def getenv_side_effect(key, default=None): + if key == 'AMP_AUTH_TOKEN': + return 'env-var-token' + return default + + mock_getenv.side_effect = getenv_side_effect + + Client(query_url='grpc://localhost:1602') + + # Verify env var was checked + calls = [str(call) for call in mock_getenv.call_args_list] + assert any('AMP_AUTH_TOKEN' in call for call in calls) + mock_connect.assert_called_once() + call_args = mock_connect.call_args + middleware = call_args[1].get('middleware', []) + assert len(middleware) == 1 + assert middleware[0].get_token() == 'env-var-token' + + @patch('amp.auth.AuthService') + @patch('amp.client.os.getenv') + @patch('amp.client.flight.connect') + def test_auth_file_lowest_priority(self, mock_connect, mock_getenv, mock_auth_service): + """Test that auth=True has lowest priority""" + + # Return None for all getenv calls + def getenv_side_effect(key, default=None): + return default + + mock_getenv.side_effect = getenv_side_effect + + mock_service_instance = Mock() + mock_service_instance.get_token.return_value = 'file-token' + mock_auth_service.return_value = mock_service_instance + + Client(query_url='grpc://localhost:1602', auth=True) + + # Verify auth file was used + mock_auth_service.assert_called_once() + mock_connect.assert_called_once() + call_args = mock_connect.call_args + middleware = call_args[1].get('middleware', []) + assert len(middleware) == 1 + # The middleware should use the auth service's get_token method directly + assert middleware[0].get_token == mock_service_instance.get_token + + @patch('amp.client.os.getenv') + @patch('amp.client.flight.connect') + def test_no_auth_when_nothing_provided(self, mock_connect, mock_getenv): + """Test that no auth middleware is added when no auth is provided""" + + # Return None/default for all getenv calls + def getenv_side_effect(key, default=None): + return default + + mock_getenv.side_effect = getenv_side_effect + + Client(query_url='grpc://localhost:1602') + + # Verify no middleware was added + mock_connect.assert_called_once() + call_args = mock_connect.call_args + middleware = call_args[1].get('middleware') + assert middleware is None or len(middleware) == 0 + + +@pytest.mark.unit +class TestAdminClientAuthPriority: + """Test AdminClient authentication priority""" + + @patch('amp.admin.client.os.getenv') + def test_admin_explicit_token_highest_priority(self, mock_getenv): + """Test that explicit auth_token parameter has highest priority for AdminClient""" + from amp.admin.client import AdminClient + + mock_getenv.return_value = 'env-var-token' + + client = AdminClient('http://localhost:8080', auth_token='explicit-token') + + # Verify explicit token was used (check get_token callable) + assert client._get_token() == 'explicit-token' + + @patch('amp.admin.client.os.getenv') + def test_admin_env_var_second_priority(self, mock_getenv): + """Test that AMP_AUTH_TOKEN env var has second priority for AdminClient""" + from amp.admin.client import AdminClient + + mock_getenv.return_value = 'env-var-token' + + client = AdminClient('http://localhost:8080') + + # Verify env var was used + mock_getenv.assert_called_with('AMP_AUTH_TOKEN') + assert client._get_token() == 'env-var-token' + + @patch('amp.auth.AuthService') + @patch('amp.admin.client.os.getenv') + def test_admin_auth_file_lowest_priority(self, mock_getenv, mock_auth_service): + """Test that auth=True has lowest priority for AdminClient""" + from amp.admin.client import AdminClient + + mock_getenv.return_value = None + mock_service_instance = Mock() + mock_service_instance.get_token.return_value = 'file-token' + mock_auth_service.return_value = mock_service_instance + + client = AdminClient('http://localhost:8080', auth=True) + + # Verify auth file was used + assert client._get_token == mock_service_instance.get_token + + @patch('amp.admin.client.os.getenv') + def test_admin_no_auth_when_nothing_provided(self, mock_getenv): + """Test that no auth header is added when no auth is provided""" + from amp.admin.client import AdminClient + + mock_getenv.return_value = None + + client = AdminClient('http://localhost:8080') + + # Verify no auth header + assert 'Authorization' not in client._http.headers + + @pytest.mark.unit class TestQueryBuilderManifest: """Test QueryBuilder manifest generation"""